Skip to content

Commit 01e4f05

Browse files
committed
TFLiteConverter: Support QuantizeAndDequantizeV4
1 parent 9ec5804 commit 01e4f05

File tree

6 files changed

+43
-2
lines changed

6 files changed

+43
-2
lines changed

tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2068,3 +2068,15 @@ func @all_i64axes(%arg0: tensor<8x16x16xi1>, %arg1: tensor<2xi64>) -> tensor<?xi
20682068
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
20692069
// CHECK: "tfl.reduce_all"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xi1>, tensor<2xi32>) -> tensor<?xi1>
20702070
}
2071+
2072+
func @quantize_dequantize_v4(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
2073+
%cst = constant dense<0.0> : tensor<f32>
2074+
%cst_0 = constant dense<255.0> : tensor<f32>
2075+
%0 = "tf.QuantizeAndDequantizeV4"(%arg0, %cst, %cst_0) : (tensor<?x?xf32>, tensor<f32>, tensor<f32>) -> tensor<?x?xf32>
2076+
return %0 : tensor<?x?xf32>
2077+
2078+
// CHECK-LABEL: quantize_dequantize_v4
2079+
// CHECK: %[[QUANT:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<?x?x!quant.uniform<u8:f32, 1.000000e+00>>} : (tensor<?x?xf32>) -> tensor<?x?x!quant.uniform<u8:f32, 1.000000e+00>>
2080+
// CHECK: %[[DEQUANT:.*]] = "tfl.dequantize"(%[[QUANT]]) : (tensor<?x?x!quant.uniform<u8:f32, 1.000000e+00>>) -> tensor<?x?xf32>
2081+
// CHECK: return %[[DEQUANT]]
2082+
}

tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,8 @@ def LegalizeFakeQuantWithMinMaxVars: Pat<
297297

298298
// TODO(rocky): Not all of the attributes are handled correctly. Make this
299299
// more general if there is a need.
300-
def LegalizeQuantizeAndDequantizeV2 : Pat<
301-
(TF_QuantizeAndDequantizeV2Op $inputs, (ConstantOp F32ElementsAttr:$min),
300+
def LegalizeQuantizeAndDequantizeV4 : Pat<
301+
(TF_QuantizeAndDequantizeV4Op $inputs, (ConstantOp F32ElementsAttr:$min),
302302
(ConstantOp F32ElementsAttr:$max),
303303
$signed_input, $num_bits, $range_given, $round_mode, $narrow_range, $axis),
304304
(TFL_DequantizeOp

tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9862,6 +9862,8 @@ tensor.}]>:$input_max,
98629862
);
98639863

98649864
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
9865+
9866+
let hasCanonicalizer = 1;
98659867
}
98669868

98679869
def TF_QuantizeAndDequantizeV3Op : TF_Op<"QuantizeAndDequantizeV3", [NoSideEffect]> {

tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,15 @@ OpFoldResult PowOp::fold(ArrayRef<Attribute> operands) {
491491
return {};
492492
}
493493

494+
//===----------------------------------------------------------------------===//
495+
// QuantizeAndDequantizeV2Op
496+
//===----------------------------------------------------------------------===//
497+
498+
void QuantizeAndDequantizeV2Op::getCanonicalizationPatterns(
499+
OwningRewritePatternList& results, MLIRContext* context) {
500+
results.insert<QuantizeAndDequantizeV2ToQuantizeAndDequantizeV4>(context);
501+
}
502+
494503
//===----------------------------------------------------------------------===//
495504
// QrOp
496505
//===----------------------------------------------------------------------===//

tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1695,3 +1695,11 @@ func @while_with_id_passthrough(%arg0: tensor<7xf32> {tf._user_specified_name =
16951695
%7 = "tf.Identity"(%6#2) {device = ""} : (tensor<?xf32>) -> tensor<?xf32>
16961696
return %7 : tensor<?xf32>
16971697
}
1698+
1699+
// CHECK-LABEL: testConvertQuantizeAndDequantizeV2ToQuantizeAndDequantizeV4
1700+
func @testConvertQuantizeAndDequantizeV2ToQuantizeAndDequantizeV4(%arg0 : tensor<?x?xf32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<?x?xf32> {
1701+
%0 = "tf.QuantizeAndDequantizeV2"(%arg0, %arg1, %arg2) : (tensor<?x?xf32>, tensor<f32>, tensor<f32>) -> tensor<?x?xf32>
1702+
return %0 : tensor<?x?xf32>
1703+
// CHECK: %[[QUANT:.*]] = "tf.QuantizeAndDequantizeV4"(%arg0, %arg1, %arg2) {axis = -1 : i64, narrow_range = false, num_bits = 8 : i64, range_given = false, round_mode = "HALF_TO_EVEN", signed_input = true} : (tensor<?x?xf32>, tensor<f32>, tensor<f32>) -> tensor<?x?xf32>
1704+
// CHECK: return %[[QUANT]] : tensor<?x?xf32>
1705+
}

tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,16 @@ def MatrixSetDiagV2ToV3 : Pat<(TF_MatrixSetDiagV2Op $input, $diag, $k),
188188
(TF_MatrixSetDiagV3Op $input, $diag, $k,
189189
(GetStrAttr<"LEFT_LEFT">))>;
190190

191+
//===----------------------------------------------------------------------===//
192+
// QuantizeAndDequantizeV2 op patterns.
193+
//===----------------------------------------------------------------------===//
194+
195+
def QuantizeAndDequantizeV2ToQuantizeAndDequantizeV4 : Pat<
196+
(TF_QuantizeAndDequantizeV2Op $inputs, $min, $max, $signed_input, $num_bits,
197+
$range_given, $round_mode, $narrow_range, $axis),
198+
(TF_QuantizeAndDequantizeV4Op $inputs, $min, $max, $signed_input, $num_bits,
199+
$range_given, $round_mode, $narrow_range, $axis)>;
200+
191201
//===----------------------------------------------------------------------===//
192202
// RealDiv op patterns.
193203
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)