Commit f233dc7

abattery authored and tensorflower-gardener committed
Support 64-bit integer indices via TFLite builtin ops

Added cast ops in front of the 64-bit integer indices.

PiperOrigin-RevId: 343000749
Change-Id: Idbb474581a5918bb630fbfa27647031d794280af
1 parent aef6ae1 commit f233dc7
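
The net effect: where a TF op carries 64-bit integer index operands, the legalization now inserts casts down to 32 bits instead of refusing to convert the op. A rough before/after sketch in MLIR, inferred from the new test expectations below (operand names are illustrative, not from the commit):

```mlir
// Before legalization: tf.ScatterNd with i64 indices and shape
// (%indices, %updates, %shape are illustrative names).
%0 = "tf.ScatterNd"(%indices, %updates, %shape)
    : (tensor<4x2x2xi64>, tensor<4x2x3xf32>, tensor<3xi64>) -> tensor<10x2x3xf32>

// After legalization: casts are inserted in front of the i64 operands,
// and the TFLite builtin op consumes i32.
%i32_indices = "tfl.cast"(%indices) : (tensor<4x2x2xi64>) -> tensor<4x2x2xi32>
%i32_shape = "tfl.cast"(%shape) : (tensor<3xi64>) -> tensor<3xi32>
%1 = "tfl.scatter_nd"(%i32_indices, %updates, %i32_shape)
    : (tensor<4x2x2xi32>, tensor<4x2x3xf32>, tensor<3xi32>) -> tensor<10x2x3xf32>
```

The same shape of rewrite applies to the other ops touched here: reverse_v2, batch_to_space_nd, space_to_batch_nd, cumsum, segment_sum, and reshape.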

3 files changed, +81 −46 lines changed


tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir

Lines changed: 57 additions & 5 deletions
```diff
@@ -435,6 +435,16 @@ func @scatterNdHigherRankIndices(%arg0: tensor<4x2x2xi32>, %arg1: tensor<4x2x3xf
 // CHECK: return %[[RES]]
 }
 
+func @scatter_nd_i64(%arg0: tensor<4x2x2xi64>, %arg1: tensor<4x2x3xf32>, %arg2: tensor<3xi64>) -> tensor<10x2x3xf32> {
+  %0 = "tf.ScatterNd"(%arg0, %arg1, %arg2) : (tensor<4x2x2xi64>, tensor<4x2x3xf32>, tensor<3xi64>) -> tensor<10x2x3xf32>
+  return %0 : tensor<10x2x3xf32>
+
+// CHECK-LABEL:scatter_nd_i64
+// CHECK: "tfl.cast"
+// CHECK: "tfl.cast"
+// CHECK: "tfl.scatter_nd"
+}
+
 func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x3x5x20xf32> {
   %0 = "tf.Const"() { value = dense<[1]> : tensor<1xi32> } : () -> tensor<1xi32>
   %1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x3x5x20xf32>
@@ -689,6 +699,16 @@ func @reverse_v2(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1xi32>) -> tensor<1x2
 // CHECK: return
 }
 
+func @reverse_v2_i64(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1xi64>) -> tensor<1x2x3x4xf32> {
+  %0 = "tf.ReverseV2"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<1xi64>) -> tensor<1x2x3x4xf32>
+  return %0 : tensor<1x2x3x4xf32>
+
+// CHECK-LABEL:reverse_v2_i64
+// CHECK: "tfl.cast"
+// CHECK: "tfl.reverse_v2"
+// CHECK: return
+}
+
 func @matrix_diag(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
   %0 = "tf.MatrixDiag"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16x16xf32>
   return %0 : tensor<8x16x16xf32>
@@ -996,13 +1016,31 @@ func @batch_to_space_nd_unsupported(%arg0: tensor<?x1x1x1x4xf32>, %arg1: tensor<
 // CHECK: "tf.BatchToSpaceND"
 }
 
+func @batch_to_space_nd_i64(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi64>, %arg2: tensor<2x2xi64>) -> tensor<?xf32> {
+  %0 = "tf.BatchToSpaceND"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+// CHECK-LABEL: batch_to_space_nd_i64
+// CHECK: "tfl.cast"
+// CHECK: "tfl.cast"
+// CHECK: "tfl.batch_to_space_nd"
+}
+
 func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<*xf32> {
   %0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 // CHECK-LABEL: space_to_batch_nd
 // CHECK: "tfl.space_to_batch_nd"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<*xf32>
 }
 
+func @space_to_batch_nd_i64(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi64>, %arg2: tensor<2x2xi64>) -> tensor<*xf32> {
+  %0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+// CHECK-LABEL: space_to_batch_nd_i64
+// CHECK: "tfl.cast"
+// CHECK: "tfl.cast"
+// CHECK: "tfl.space_to_batch_nd"
+}
+
 func @split(%arg0: tensor<i32>, %arg1: tensor<1x4x3x3xf32>) -> tensor<1x4x3xf32> {
   %0:3 = "tf.Split"(%arg0, %arg1) : (tensor<i32>, tensor<1x4x3x3xf32>) -> (tensor<1x4x3xf32>, tensor<1x4x3xf32>, tensor<1x4x3xf32>)
   return %0#0 : tensor<1x4x3xf32>
@@ -1361,8 +1399,7 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, %
 
 // CHECK-LABEL: conv2d_backprop_input
 // CHECK: %[[CST:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32>
-// CHECK: %[[CAST:.*]] = "tfl.cast"(%[[CST]]) : (tensor<4xi32>) -> tensor<4xi32>
-// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CAST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
+// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
 // CHECK: %[[CST_0:.*]] = constant unit
 // CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
 // CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
@@ -1797,10 +1834,25 @@ func @cumsum(%arg0: tensor<3x3xf32>, %arg1: tensor<i32>) -> tensor<3x3xf32> {
 // CHECK: "tfl.cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i32>) -> tensor<3x3xf32>
 }
 
-func @cumsum_invalid(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<3x3xf32> {
+func @cumsum_i64(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<3x3xf32> {
   %0 = "tf.Cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i64>) -> tensor<3x3xf32>
   return %0 : tensor<3x3xf32>
-// CHECK-LABEL: cumsum_invalid
-// CHECK-NOT: "tfl.cumsum"
+// CHECK-LABEL: cumsum_i64
+// CHECK: "tfl.cast"
+// CHECK: "tfl.cumsum"
 }
 
+func @segmentsum(%arg0: tensor<3x3xf32>, %arg1: tensor<i32>) -> tensor<*xf32> {
+  %0 = "tf.SegmentSum"(%arg0, %arg1) : (tensor<3x3xf32>, tensor<i32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+// CHECK-LABEL: segmentsum
+// CHECK: "tfl.segment_sum"(%arg0, %arg1) : (tensor<3x3xf32>, tensor<i32>) -> tensor<*xf32>
+}
+
+func @segmentsum_i64(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<*xf32> {
+  %0 = "tf.SegmentSum"(%arg0, %arg1) : (tensor<3x3xf32>, tensor<i64>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+// CHECK-LABEL: segmentsum_i64
+// CHECK: "tfl.cast"
+// CHECK: "tfl.segment_sum"
+}
```

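These are lit/FileCheck tests: each function is run through the TFLite legalization pass and the CHECK lines are matched against the pass output. A typical invocation for this kind of test file (not part of this commit, and the real file's RUN line and flags may differ; shown only to illustrate how the CHECKs fire):

```mlir
// RUN: tf-opt %s -tfl-legalize-tf | FileCheck %s
```
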
tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td

Lines changed: 19 additions & 11 deletions
```diff
@@ -54,7 +54,7 @@ def ExtractSingleElementAsInt32 : NativeCodeCall<
     "$_builder.getI32IntegerAttr(ExtractSingleElementAsInteger($_self.cast<ElementsAttr>()).getInt())">;
 
 // Converts tensor with int64 to int32.
-def CreateTFLCastToInt32Op : NativeCodeCall<
+def CreateTFCastToInt32Op : NativeCodeCall<
     "CreateCastToInt32($0, $_loc, $_builder)">;
 
 // Checks whether the given operation has static shapes and same shapes of all inputs.
@@ -193,8 +193,8 @@ def LegalizeRound : Pat<(TF_RoundOp $arg), (TFL_RoundOp $arg)>;
 def LegalizeRsqrt : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>;
 def LegalizeSqrt : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>;
 def LegalizeSquare : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>;
-def LegalizeSegmentSum : Pat<(TF_SegmentSumOp $data, I32Tensor:$segment_ids),
-                             (TFL_SegmentSumOp $data, $segment_ids)>;
+def LegalizeSegmentSum : Pat<(TF_SegmentSumOp $data, $segment_ids),
+                             (TFL_SegmentSumOp $data, (CreateTFCastToInt32Op $segment_ids))>;
 def LegalizeSelect : Pat<(TF_SelectOp $cond, $x, $y),
                          (TFL_SelectOp $cond, $x, $y)>;
 def LegalizeSelectV2SameStaticShape : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y),
@@ -221,7 +221,7 @@ def LegalizeTanh : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>;
 
 def LegalizeTranspose : Pat<(TF_TransposeOp $arg, $perm),
                             (TFL_TransposeOp $arg,
-                             (CreateTFLCastToInt32Op $perm))>;
+                             (CreateTFCastToInt32Op $perm))>;
 
 def LegalizeWhere : Pat<(TF_WhereOp $arg), (TFL_WhereOp $arg)>;
 def LegalizeZerosLike : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
@@ -309,8 +309,9 @@ def LegalizeRank : Pat<(TF_RankOp $input), (TFL_RankOp $input)>;
 def LegalizeSquaredDifference : Pat<(TF_SquaredDifferenceOp $l, $r),
                                     (TFL_SquaredDifferenceOp $l, $r)>;
 
-def LegalizeReverseV2 : Pat<(TF_ReverseV2Op $arg0, $arg1),
-                            (TFL_ReverseV2Op $arg0, $arg1)>;
+def LegalizeReverseV2 : Pat<
+  (TF_ReverseV2Op $arg0, $axis),
+  (TFL_ReverseV2Op $arg0, (CreateTFCastToInt32Op $axis))>;
 
 def LegalizeEqual : Pat<(TF_EqualOp $arg0, $arg1,
                          /*incompatible_shape_error=*/ConstBoolAttrTrue),
@@ -349,11 +350,13 @@ def LegalizeCast : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>;
 
 def LegalizeBatchToSpaceND : Pat<
   (TF_BatchToSpaceNDOp $input, $block_shape, $crops),
-  (TFL_BatchToSpaceNdOp $input, $block_shape, $crops)>;
+  (TFL_BatchToSpaceNdOp $input, (CreateTFCastToInt32Op $block_shape),
+                        (CreateTFCastToInt32Op $crops))>;
 
 def LegalizeSpaceToBatchND : Pat<
   (TF_SpaceToBatchNDOp $input, $block_shape, $paddings),
-  (TFL_SpaceToBatchNdOp $input, $block_shape, $paddings)>;
+  (TFL_SpaceToBatchNdOp $input, (CreateTFCastToInt32Op $block_shape),
+                        (CreateTFCastToInt32Op $paddings))>;
 
 def LegalizeSpaceToDepth : Pat<
   (TF_SpaceToDepthOp $input, $block_size, IsDataFormatNHWC:$data_format),
@@ -442,9 +445,14 @@ def LegalizeMatrixSetDiag : Pat<
   (TFL_MatrixSetDiagOp $input, $diagonal)>;
 
 def LegalizeScatterNd : Pat<
-  (TF_ScatterNdOp I32Tensor:$indices, $updates, $shape),
-  (TFL_ScatterNdOp I32Tensor:$indices, $updates, $shape)>;
+  (TF_ScatterNdOp $indices, $updates, $shape),
+  (TFL_ScatterNdOp (CreateTFCastToInt32Op $indices), $updates,
+                   (CreateTFCastToInt32Op $shape))>;
 
 def LegalizeCumsum : Pat<
   (TF_CumsumOp $input, $axis, $exclusive, $reverse),
-  (TFL_CumsumOp $input, $axis, $exclusive, $reverse)>;
+  (TFL_CumsumOp $input, (CreateTFCastToInt32Op $axis), $exclusive, $reverse)>;
+
+def LegalizeReshape : Pat<
+  (TF_ReshapeOp $input, $shape),
+  (TFL_ReshapeOp $input, (CreateTFCastToInt32Op $shape))>;
```

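Each `CreateTFCastToInt32Op` in these DRR patterns invokes the C++ helper `CreateCastToInt32` (see legalize_tf.cc below) while the pattern is applied, so the cast is materialized as part of the rewrite. For `LegalizeCumsum`, for example, the result should look roughly like this (a sketch consistent with the cumsum_i64 test above; value names are illustrative):

```mlir
// Input: the axis operand is a scalar i64 tensor.
%0 = "tf.Cumsum"(%x, %axis) {exclusive = false, reverse = false}
    : (tensor<3x3xf32>, tensor<i64>) -> tensor<3x3xf32>

// After LegalizeCumsum: (CreateTFCastToInt32Op $axis) inserts a cast to i32,
// which itself legalizes to tfl.cast in the same pass.
%axis_i32 = "tfl.cast"(%axis) : (tensor<i64>) -> tensor<i32>
%1 = "tfl.cumsum"(%x, %axis_i32) {exclusive = false, reverse = false}
    : (tensor<3x3xf32>, tensor<i32>) -> tensor<3x3xf32>
```
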
tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc

Lines changed: 5 additions & 30 deletions
```diff
@@ -123,7 +123,8 @@ Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) {
   auto shape = val.getType().dyn_cast<RankedTensorType>().getShape();
   IntegerType new_ele_type = rewriter.getIntegerType(32);
   ShapedType new_type = RankedTensorType::get(shape, new_ele_type);
-  return rewriter.create<TFL::CastOp>(loc, new_type, val);
+  return rewriter.createOrFold<TF::CastOp>(loc, new_type, val,
+                                           rewriter.getBoolAttr(false));
 }
 
 #include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc"
@@ -145,7 +146,6 @@ DECL_CONVERT_OP(MatMul);
 DECL_CONVERT_OP(MatrixDiagV2);
 DECL_CONVERT_OP(MatrixDiagV3);
 DECL_CONVERT_OP(Pack);
-DECL_CONVERT_OP(Reshape);
 DECL_CONVERT_OP(Split);
 DECL_CONVERT_OP(SplitV);
 DECL_CONVERT_OP(StridedSlice);
@@ -299,30 +299,6 @@ LogicalResult ConvertTFPackOp::matchAndRewrite(
   return success();
 }
 
-LogicalResult ConvertTFReshapeOp::matchAndRewrite(
-    Operation* op, PatternRewriter& rewriter) const {
-  auto tf_reshape_op = cast<TF::ReshapeOp>(op);
-
-  auto input = tf_reshape_op.tensor();
-  auto shape = tf_reshape_op.shape();
-
-  ShapedType shape_type = shape.getType().cast<ShapedType>();
-  // The tfl reshape's #2 operand needs to i32 tensor type, so we have to cast.
-  if (!shape_type.getElementType().isSignlessInteger(32)) {
-    auto new_shape = shape_type.getShape();
-    IntegerType new_ele_type = rewriter.getIntegerType(32);
-    ShapedType new_type = RankedTensorType::get(new_shape, new_ele_type);
-    // Uses TF::CastOp to be folded if the shape input is a constant.
-    shape = rewriter
-                .create<TF::CastOp>(op->getLoc(), new_type, shape,
-                                    rewriter.getBoolAttr(false))
-                .y();
-  }
-  rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output().getType(),
-                                         input, shape);
-  return success();
-}
-
 LogicalResult ConvertTFSplitOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tf_split_op = cast<TF::SplitOp>(op);
@@ -792,10 +768,9 @@ void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) {
   populateWithGenerated(context, patterns);
   patterns
      .insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
-             ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp,
-             ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp,
-             ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFRandomUniformOp>(
-         context);
+             ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFSplitOp,
+             ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp,
+             ConvertTFAssertOp, ConvertTFRandomUniformOp>(context);
 
   // Ophint python converter converted tf node pattern.
   patterns.insert<LegalizeUnidirectionalSequenceLstm,
```
patterns.insert<LegalizeUnidirectionalSequenceLstm,
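Switching the helper from `rewriter.create<TFL::CastOp>` to `rewriter.createOrFold<TF::CastOp>` preserves the behavior the deleted ConvertTFReshapeOp comment called out: when the operand is a constant, the cast folds away instead of surviving as a runtime op. A sketch of the folded case (constant values and names are illustrative):

```mlir
// A tf.Reshape whose shape operand is a constant i64 tensor...
%shape = "tf.Const"() { value = dense<[2, 8]> : tensor<2xi64> } : () -> tensor<2xi64>
%0 = "tf.Reshape"(%x, %shape) : (tensor<4x4xf32>, tensor<2xi64>) -> tensor<2x8xf32>

// ...legalizes with the cast folded into an i32 constant, so no cast op
// remains in the final TFLite graph.
%cst = constant dense<[2, 8]> : tensor<2xi32>
%1 = "tfl.reshape"(%x, %cst) : (tensor<4x4xf32>, tensor<2xi32>) -> tensor<2x8xf32>
```

This is also why the hand-written ConvertTFReshapeOp pattern could be deleted: the generic `LegalizeReshape` DRR pattern plus the shared cast helper now cover it.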
