
Commit 05303d0

liufengdb authored and tensorflower-gardener committed
Fix the op quant trait of tfl.rsqrt and the nudging for the calibration data
PiperOrigin-RevId: 348884204
Change-Id: Iabac61239cf394fa18d491e0da1fb54d17c3b23a
1 parent 7494180 commit 05303d0

3 files changed (+38 -7 lines)

tensorflow/compiler/mlir/lite/ir/tfl_ops.td

Lines changed: 1 addition & 2 deletions

@@ -2649,8 +2649,7 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension

 def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect,
     TFL_SameFirstOperandAndFirstResultElementType,
-    SameOperandsAndResultShape,
-    NoQuantizableResult]> {
+    SameOperandsAndResultShape]> {
   let summary = "Reciprocal of square root operator";

   let description = [{

tensorflow/compiler/mlir/lite/quantization/quantization_utils.h

Lines changed: 15 additions & 5 deletions

@@ -104,18 +104,28 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
       if (!stats) return failure();

       for (auto it = stats.begin(), e = stats.end(); it != e; ++it) {
-        double min = FloatAttr::getValueAsDouble(*it++);
-        double max = FloatAttr::getValueAsDouble(*it);
-        TensorRangeSanityCheck(op, min, max);
-        mins.push_back(min);
-        maxs.push_back(max);
+        double rmin = FloatAttr::getValueAsDouble(*it++);
+        double rmax = FloatAttr::getValueAsDouble(*it);
+        // The default nudging implementation of mlir quant library might cause
+        // clamping during inference if the calibration range isn't wide enough.
+        // So here we adjust the range to include 0.0.
+        rmin = std::min(rmin, 0.0);
+        rmax = std::max(rmax, 0.0);
+        TensorRangeSanityCheck(op, rmin, rmax);
+        mins.push_back(rmin);
+        maxs.push_back(rmax);
       }
       quant_type =
           quant::fakeQuantAttrsToType(op.getLoc(), num_bits, *op.axis(), mins,
                                       maxs, narrow_range, expressed, is_signed);
     } else if (auto stats = op.layerStats().dyn_cast<DenseFPElementsAttr>()) {
       double rmin = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}));
       double rmax = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}));
+      // The default nudging implementation of mlir quant library might cause
+      // clamping during inference if the calibration range isn't wide enough.
+      // So here we adjust the range to include 0.0.
+      rmin = std::min(rmin, 0.0);
+      rmax = std::max(rmax, 0.0);
       TensorRangeSanityCheck(op, rmin, rmax);
       quant_type =
           quant::fakeQuantAttrsToType(op.getLoc(), num_bits, rmin, rmax,
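For illustration, a standalone C++ sketch of the arithmetic behind the added comments, assuming a plain int8 affine scheme with qmin = -128 and qmax = 127 (this is not the actual quant::fakeQuantAttrsToType implementation): with a positive-only calibration range such as [0.1, 1.0], the float value 0.0 would map to a zero point outside the int8 storage range, so the library has to nudge it; widening the range to include 0.0 first makes the zero point exactly representable.

// Standalone sketch, not the mlir quant library's own nudging code.
// Assumes a plain int8 affine scheme with qmin = -128, qmax = 127.
#include <algorithm>
#include <cstdio>

// Derive scale and (unrounded) zero point for a floating-point range.
static void PrintParams(const char* tag, double rmin, double rmax) {
  const int qmin = -128, qmax = 127;
  double scale = (rmax - rmin) / (qmax - qmin);
  double zero_point = qmin - rmin / scale;  // where float 0.0 would map
  std::printf("%s scale=%.19f zero_point=%.3f\n", tag, scale, zero_point);
}

int main() {
  double rmin = 0.1, rmax = 1.0;       // calibration range from the new test
  PrintParams("raw:   ", rmin, rmax);  // zero_point ~ -156.3, not an int8 value
  rmin = std::min(rmin, 0.0);          // widen to include 0.0, as the patch does
  rmax = std::max(rmax, 0.0);
  PrintParams("nudged:", rmin, rmax);  // scale = 1/255, zero_point = -128
  return 0;
}

After widening, the derived parameters are scale = 1/255 and zero point -128, which is the 0.0039215686274509803:-128 pair checked in the new test below.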

tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir

Lines changed: 22 additions & 0 deletions

@@ -56,6 +56,28 @@ func @prepareStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 // CHECK: return %[[dq2]]
 }

+// CHECK-LABEL: prepareStatisticsNudge
+func @prepareStatisticsNudge(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+  %0 = "quant.stats"(%arg0) {
+    layerStats = dense<[0.1, 1.0]> : tensor<2xf32>
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  %1 = "quant.stats"(%0) {
+    layerStats = dense<[0.1, 1.0]> : tensor<2xf32>,
+    axisStats = dense<[
+      [-1.0, 1.0],
+      [-8.0, -1.0],
+      [-0.5, 0.5]
+    ]> : tensor<3x2xf32>, axis = 2 : i64
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %1 : tensor<8x4x3xf32>
+
+// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>}
+// CHECK: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]])
+// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[dq1]]) {qtype = tensor<8x4x3x!quant.uniform<i8:f32:2, {0.0078431372549019607:-1,0.031372549019607843:127,0.0039215686274509803:-1}>>}
+// CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]])
+// CHECK: return %[[dq2]]
+}
+
 // CHECK-LABEL: preparePrelu
 func @preparePrelu(%arg0: tensor<1x10x10x3xf32>) -> tensor<1x10x10x3xf32> {
   %cst = "tfl.pseudo_const"() {value = dense<[[[1.66394591, 3.61694336, 2.0382936]]]> : tensor<1x1x3xf32>} : () -> tensor<1x1x3xf32>