@@ -435,6 +435,16 @@ func @scatterNdHigherRankIndices(%arg0: tensor<4x2x2xi32>, %arg1: tensor<4x2x3xf
435435// CHECK: return %[[RES]]
436436}
437437
438+ func @scatter_nd_i64 (%arg0: tensor <4 x2 x2 xi64 >, %arg1: tensor <4 x2 x3 xf32 >, %arg2: tensor <3 xi64 >) -> tensor <10 x2 x3 xf32 > {
439+ %0 = " tf.ScatterNd" (%arg0 , %arg1 , %arg2 ) : (tensor <4 x2 x2 xi64 >, tensor <4 x2 x3 xf32 >, tensor <3 xi64 >) -> tensor <10 x2 x3 xf32 >
440+ return %0 : tensor <10 x2 x3 xf32 >
441+
442+ // CHECK-LABEL:scatter_nd_i64
443+ // CHECK: "tfl.cast"
444+ // CHECK: "tfl.cast"
445+ // CHECK: "tfl.scatter_nd"
446+ }
447+
438448func @gatherV2VectorIndices (%arg0 : tensor <1 x2 x20 xf32 >, %arg1 : tensor <3 x5 xi32 >) -> tensor <1 x3 x5 x20 xf32 > {
439449 %0 = " tf.Const" () { value = dense <[1 ]> : tensor <1 xi32 > } : () -> tensor <1 xi32 >
440450 %1 = " tf.GatherV2" (%arg0 , %arg1 , %0 ) : (tensor <1 x2 x20 xf32 >, tensor <3 x5 xi32 >, tensor <1 xi32 >) -> tensor <1 x3 x5 x20 xf32 >
@@ -689,6 +699,16 @@ func @reverse_v2(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1xi32>) -> tensor<1x2
689699// CHECK: return
690700}
691701
702+ func @reverse_v2_i64 (%arg0: tensor <1 x2 x3 x4 xf32 >, %arg1: tensor <1 xi64 >) -> tensor <1 x2 x3 x4 xf32 > {
703+ %0 = " tf.ReverseV2" (%arg0 , %arg1 ) : (tensor <1 x2 x3 x4 xf32 >, tensor <1 xi64 >) -> tensor <1 x2 x3 x4 xf32 >
704+ return %0 : tensor <1 x2 x3 x4 xf32 >
705+
706+ // CHECK-LABEL:reverse_v2_i64
707+ // CHECK: "tfl.cast"
708+ // CHECK: "tfl.reverse_v2"
709+ // CHECK: return
710+ }
711+
692712func @matrix_diag (%arg0: tensor <8 x16 xf32 >) -> tensor <8 x16 x16 xf32 > {
693713 %0 = " tf.MatrixDiag" (%arg0 ) : (tensor <8 x16 xf32 >) -> tensor <8 x16 x16 xf32 >
694714 return %0 : tensor <8 x16 x16 xf32 >
@@ -996,13 +1016,31 @@ func @batch_to_space_nd_unsupported(%arg0: tensor<?x1x1x1x4xf32>, %arg1: tensor<
9961016 // CHECK: "tf.BatchToSpaceND"
9971017}
9981018
1019+ func @batch_to_space_nd_i64 (%arg0: tensor <4 x2 x2 x3 xf32 >, %arg1: tensor <2 xi64 >, %arg2: tensor <2 x2 xi64 >) -> tensor <?xf32 > {
1020+ %0 = " tf.BatchToSpaceND" (%arg0 , %arg1 , %arg2 ) : (tensor <4 x2 x2 x3 xf32 >, tensor <2 xi64 >, tensor <2 x2 xi64 >) -> tensor <?xf32 >
1021+ return %0 : tensor <?xf32 >
1022+ // CHECK-LABEL: batch_to_space_nd_i64
1023+ // CHECK: "tfl.cast"
1024+ // CHECK: "tfl.cast"
1025+ // CHECK: "tfl.batch_to_space_nd"
1026+ }
1027+
9991028func @space_to_batch_nd (%arg0: tensor <1 x4 x4 x3 xf32 >, %arg1: tensor <2 xi32 >, %arg2: tensor <2 x2 xi32 >) -> tensor <*xf32 > {
10001029 %0 = " tf.SpaceToBatchND" (%arg0 , %arg1 , %arg2 ) : (tensor <1 x4 x4 x3 xf32 >, tensor <2 xi32 >, tensor <2 x2 xi32 >) -> tensor <*xf32 >
10011030 return %0 : tensor <*xf32 >
10021031 // CHECK-LABEL: space_to_batch_nd
10031032 // CHECK: "tfl.space_to_batch_nd"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<*xf32>
10041033}
10051034
1035+ func @space_to_batch_nd_i64 (%arg0: tensor <1 x4 x4 x3 xf32 >, %arg1: tensor <2 xi64 >, %arg2: tensor <2 x2 xi64 >) -> tensor <*xf32 > {
1036+ %0 = " tf.SpaceToBatchND" (%arg0 , %arg1 , %arg2 ) : (tensor <1 x4 x4 x3 xf32 >, tensor <2 xi64 >, tensor <2 x2 xi64 >) -> tensor <*xf32 >
1037+ return %0 : tensor <*xf32 >
1038+ // CHECK-LABEL: space_to_batch_nd_i64
1039+ // CHECK: "tfl.cast"
1040+ // CHECK: "tfl.cast"
1041+ // CHECK: "tfl.space_to_batch_nd"
1042+ }
1043+
10061044func @split (%arg0: tensor <i32 >, %arg1: tensor <1 x4 x3 x3 xf32 >) -> tensor <1 x4 x3 xf32 > {
10071045 %0:3 = " tf.Split" (%arg0 , %arg1 ) : (tensor <i32 >, tensor <1 x4 x3 x3 xf32 >) -> (tensor <1 x4 x3 xf32 >, tensor <1 x4 x3 xf32 >, tensor <1 x4 x3 xf32 >)
10081046 return %0#0 : tensor <1 x4 x3 xf32 >
@@ -1361,8 +1399,7 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, %
13611399
13621400 // CHECK-LABEL: conv2d_backprop_input
13631401 // CHECK: %[[CST:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32>
1364- // CHECK: %[[CAST:.*]] = "tfl.cast"(%[[CST]]) : (tensor<4xi32>) -> tensor<4xi32>
1365- // CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CAST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
1402+ // CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
13661403 // CHECK: %[[CST_0:.*]] = constant unit
13671404 // CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
13681405 // CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
@@ -1797,10 +1834,25 @@ func @cumsum(%arg0: tensor<3x3xf32>, %arg1: tensor<i32>) -> tensor<3x3xf32> {
17971834 // CHECK: "tfl.cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i32>) -> tensor<3x3xf32>
17981835}
17991836
1800- func @cumsum_invalid (%arg0: tensor <3 x3 xf32 >, %arg1: tensor <i64 >) -> tensor <3 x3 xf32 > {
1837+ func @cumsum_i64 (%arg0: tensor <3 x3 xf32 >, %arg1: tensor <i64 >) -> tensor <3 x3 xf32 > {
18011838 %0 = " tf.Cumsum" (%arg0 , %arg1 ) {exclusive = false , reverse = false } : (tensor <3 x3 xf32 >, tensor <i64 >) -> tensor <3 x3 xf32 >
18021839 return %0 : tensor <3 x3 xf32 >
1803- // CHECK-LABEL: cumsum_invalid
1804- // CHECK-NOT: "tfl.cumsum"
1840+ // CHECK-LABEL: cumsum_i64
1841+ // CHECK: "tfl.cast"
1842+ // CHECK: "tfl.cumsum"
18051843}
18061844
1845+ func @segmentsum (%arg0: tensor <3 x3 xf32 >, %arg1: tensor <i32 >) -> tensor <*xf32 > {
1846+ %0 = " tf.SegmentSum" (%arg0 , %arg1 ) : (tensor <3 x3 xf32 >, tensor <i32 >) -> tensor <*xf32 >
1847+ return %0 : tensor <*xf32 >
1848+ // CHECK-LABEL: segmentsum
1849+ // CHECK: "tfl.segment_sum"(%arg0, %arg1) : (tensor<3x3xf32>, tensor<i32>) -> tensor<*xf32>
1850+ }
1851+
1852+ func @segmentsum_i64 (%arg0: tensor <3 x3 xf32 >, %arg1: tensor <i64 >) -> tensor <*xf32 > {
1853+ %0 = " tf.SegmentSum" (%arg0 , %arg1 ) : (tensor <3 x3 xf32 >, tensor <i64 >) -> tensor <*xf32 >
1854+ return %0 : tensor <*xf32 >
1855+ // CHECK-LABEL: segmentsum_i64
1856+ // CHECK: "tfl.cast"
1857+ // CHECK: "tfl.segment_sum"
1858+ }
0 commit comments