Commit 5a8af55

Switch FakeQuant and Dequantize ops in array_ops to use C++ shape functions.
Change: 137756007
1 parent: 12629a0 · commit: 5a8af55

3 files changed: +60 −22 lines

tensorflow/core/ops/array_ops.cc

Lines changed: 22 additions & 0 deletions
@@ -4392,6 +4392,7 @@ REGISTER_OP("FakeQuantWithMinMaxArgs")
     .Attr("max: float = 6.0")
     .Input("inputs: float")
     .Output("outputs: float")
+    .SetShapeFn(shape_inference::UnchangedShape)
     .Doc(R"doc(
 Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type.
 
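For reference, `shape_inference::UnchangedShape` simply forwards the input shape to the output (roughly `c->set_output(0, c->input(0))`). The user-visible effect is that the output's static shape is now inferred by the C++ shape function at graph-construction time. A minimal sketch, assuming the generated wrapper is exported as `tf.fake_quant_with_min_max_args` as in contemporary TensorFlow releases:

import tensorflow as tf

# The output should carry the same static shape as the input,
# courtesy of the UnchangedShape C++ shape function.
x = tf.placeholder(tf.float32, shape=[None, 28, 28, 3])
y = tf.fake_quant_with_min_max_args(x, min=-6.0, max=6.0)
print(y.get_shape())  # => (?, 28, 28, 3)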
@@ -4422,6 +4423,13 @@ REGISTER_OP("FakeQuantWithMinMaxVars")
     .Input("min: float")
     .Input("max: float")
     .Output("outputs: float")
+    .SetShapeFn([](InferenceContext* c) {
+      TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
+      ShapeHandle unused;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+      return Status::OK();
+    })
     .Doc(R"doc(
 Fake-quantize the 'inputs' tensor of type float and shape `[b, h, w, d]` via
 global float scalars `min` and `max` to 'outputs' tensor of same shape as
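The lambda first propagates the input shape via `UnchangedShape`, then additionally constrains inputs 1 and 2 (`min`, `max`) to rank 0. A sketch of what that buys at graph-construction time, assuming the generated `tf.fake_quant_with_min_max_vars` wrapper; the exact error text may differ:

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[8, 128])
min_var = tf.Variable(-6.0)  # rank 0: accepted
max_var = tf.Variable(6.0)
y = tf.fake_quant_with_min_max_vars(x, min_var, max_var)
print(y.get_shape())  # => (8, 128)

bad_min = tf.Variable([-6.0])  # rank 1: should now be rejected early
try:
  tf.fake_quant_with_min_max_vars(x, bad_min, max_var)
except ValueError as e:
  print(e)  # "... must be rank 0 ..."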
@@ -4461,6 +4469,20 @@ REGISTER_OP("FakeQuantWithMinMaxVarsPerChannel")
     .Input("min: float")
     .Input("max: float")
     .Output("outputs: float")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle input, min, max;
+      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &min));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &max));
+
+      DimensionHandle unused;
+      TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(min, 0), &unused));
+      TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -1), c->Dim(max, 0), &unused));
+      TF_RETURN_IF_ERROR(c->Merge(c->Dim(min, 0), c->Dim(max, 0), &unused));
+
+      c->set_output(0, input);
+      return Status::OK();
+    })
     .Doc(R"doc(
 Fake-quantize the 'inputs' tensor of type float and one of the shapes: `[d]`,
 `[b, d]` `[b, h, w, d]` via per-channel floats `min` and `max` of shape `[d]`
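This shape function encodes the per-channel contract: `inputs` must have rank at least 1, `min` and `max` must be rank-1 vectors, and all three must agree on the channel dimension (the last dimension of `inputs`). `Merge` unifies two dimensions and errors out on a mismatch. A hedged usage sketch, assuming the generated `tf.fake_quant_with_min_max_vars_per_channel` wrapper:

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 32, 32, 16])
min_var = tf.Variable(tf.zeros([16]))  # one entry per channel
max_var = tf.Variable(tf.ones([16]))
y = tf.fake_quant_with_min_max_vars_per_channel(x, min_var, max_var)
print(y.get_shape())  # => (?, 32, 32, 16)

# A vector whose length differs from the last input dimension should
# now fail during graph construction rather than at run time:
try:
  tf.fake_quant_with_min_max_vars_per_channel(
      x, tf.Variable(tf.zeros([8])), max_var)
except ValueError as e:
  print(e)  # dimensions 16 and 8 "must be equal"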

tensorflow/core/ops/array_ops_test.cc

Lines changed: 32 additions & 0 deletions
@@ -1493,4 +1493,36 @@ TEST(ArrayOpsTest, StridedSliceGrad_ShapeFn) {
   INFER_OK(op, "[4];?;?;?;?", "[1,2,3,4]");
 }
 
+TEST(ArrayOpsTest, UnchangedWithQuantizationScalars_ShapeFn) {
+  for (const char* op_name : {"Dequantize", "FakeQuantWithMinMaxVars"}) {
+    ShapeInferenceTestOp op(op_name);
+
+    INFER_OK(op, "?;?;?", "in0");
+    INFER_OK(op, "[1,?,3];[];[]", "in0");
+
+    // Rank check scalars.
+    INFER_ERROR("be rank 0", op, "[1,?,3];[1];[]");
+    INFER_ERROR("be rank 0", op, "[1,?,3];[];[1]");
+  }
+}
+
+TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannel) {
+  ShapeInferenceTestOp op("FakeQuantWithMinMaxVarsPerChannel");
+
+  INFER_OK(op, "?;?;?", "?");
+  INFER_OK(op, "[?];?;?", "in0");
+  INFER_OK(op, "[1,?,3];[3];[3]", "in0");
+  INFER_OK(op, "[3];[3];[3]", "in0");
+
+  // Rank check vectors.
+  INFER_ERROR("be rank 1", op, "[1,?,3];[1];[]");
+  INFER_ERROR("be rank 1", op, "[1,?,3];[];[1]");
+
+  // Vectors must match each other, and match last dim of input.
+  INFER_ERROR("must be equal", op, "[1,?,3];[2];[?]");
+  INFER_ERROR("must be equal", op, "[1,?,3];[?];[2]");
+  INFER_ERROR("must be equal", op, "[1,?,?];[1];[2]");
+  INFER_ERROR("must be equal", op, "[5];[4];[?]");
+}
+
 }  // end namespace tensorflow
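In these tests, each `INFER_OK`/`INFER_ERROR` spec string lists one shape per input, separated by `;`: `?` is a fully unknown shape, `[1,?,3]` is rank 3 with an unknown middle dimension, `[]` is a scalar, and an expected output of `in0` means the inferred output shape must equal input 0's shape. A rough Python analogue of `INFER_OK(op, "[1,?,3];[];[]", "in0")`, assuming the generated wrapper:

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[1, None, 3])
y = tf.fake_quant_with_min_max_vars(x, tf.Variable(0.0), tf.Variable(1.0))
assert y.get_shape().as_list() == [1, None, 3]  # "in0": output == input 0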

tensorflow/python/ops/array_ops.py

Lines changed: 6 additions & 22 deletions
@@ -2011,18 +2011,8 @@ def _EditDistanceShape(op):
   return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[2, 5])
 
 
-@ops.RegisterShape("Quantize")
-@ops.RegisterShape("Dequantize")
-def _QuantizeDequantizeShape(op):
-  unused_min_range = op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
-  unused_max_range = op.inputs[2].get_shape().merge_with(tensor_shape.scalar())
-  return common_shapes.unchanged_shape(op)
-
-
-@ops.RegisterShape("FakeQuantWithMinMaxArgs")
-def _FakeQuantWithMinMaxArgsShape(op):
-  """Shape function for FakeQuantWithMinMaxArgs op: preserve the input shape."""
-  return [op.inputs[0].get_shape()]
+ops.RegisterShape("Quantize")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Dequantize")(common_shapes.call_cpp_shape_fn)
 
 
 @ops.RegisterGradient("FakeQuantWithMinMaxArgs")
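`common_shapes.call_cpp_shape_fn` asks the C++ shape function registered via `SetShapeFn` for the op's output shapes, which is why the hand-written Python shape functions above can be deleted outright. The same one-line delegation works for any op whose C++ registration carries a shape function; a sketch using a hypothetical op name:

from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import ops

# "MyQuantizeLikeOp" is a placeholder name, not an op from this commit.
ops.RegisterShape("MyQuantizeLikeOp")(common_shapes.call_cpp_shape_fn)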
@@ -2031,10 +2021,10 @@ def _FakeQuantWithMinMaxArgsGradient(op, grad):
   return fake_quant_with_min_max_args_gradient(grad, op.inputs[0])
 
 
-@ops.RegisterShape("FakeQuantWithMinMaxVars")
-def _FakeQuantWithMinMaxVarsShape(op):
-  """Shape function for FakeQuantWithMinMaxVars op: preserve the input shape."""
-  return [op.inputs[0].get_shape()]
+ops.RegisterShape("FakeQuantWithMinMaxArgs")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("FakeQuantWithMinMaxVars")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("FakeQuantWithMinMaxVarsPerChannel")(
+    common_shapes.call_cpp_shape_fn)
 
 
 @ops.RegisterGradient("FakeQuantWithMinMaxVars")
@@ -2044,12 +2034,6 @@ def _FakeQuantWithMinMaxVarsGradient(op, grad):
                                               op.inputs[2])
 
 
-@ops.RegisterShape("FakeQuantWithMinMaxVarsPerChannel")
-def _FakeQuantWithMinMaxVarsPerChannelShape(op):
-  """Shape function for FakeQuantWithMinMaxVarsPerChannel op: input shape."""
-  return [op.inputs[0].get_shape()]
-
-
 @ops.RegisterGradient("FakeQuantWithMinMaxVarsPerChannel")
 def _FakeQuantWithMinMaxVarsPerChannelGradient(op, grad):
   """Gradient for FakeQuantWithMinMaxVarsPerChannel op."""
