
Commit 8ba27e3

Suharsh Sivakumar authored and tensorflower-gardener committed

Estimated moving average for quantize training activations with unknown range.

Change: 150038629

1 parent 5388b22 · commit 8ba27e3

File tree

4 files changed: +470 -94 lines changed

tensorflow/core/BUILD

Lines changed: 22 additions & 1 deletion

@@ -1849,6 +1849,28 @@ cc_test(
     ],
 )
 
+tf_cc_test(
+    name = "quantize_training_test",
+    srcs = ["graph/quantize_training_test.cc"],
+    deps = [
+        ":all_kernels",
+        ":core",
+        ":core_cpu",
+        ":core_cpu_internal",
+        ":direct_session_internal",
+        ":framework",
+        ":framework_internal",
+        ":lib",
+        ":lib_internal",
+        ":ops",
+        ":protos_all_cc",
+        ":protos_test_cc",
+        ":test",
+        ":test_main",
+        ":testlib",
+    ],
+)
+
 tf_cc_tests(
     name = "higher_level_tests",
     size = "small",
@@ -1897,7 +1919,6 @@ tf_cc_tests(
         "graph/graph_test.cc",
         "graph/node_builder_test.cc",
         "graph/optimizer_cse_test.cc",
-        "graph/quantize_training_test.cc",
         "graph/subgraph_test.cc",
         "graph/tensor_id_test.cc",
         "graph/validate_test.cc",

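Note: since quantize_training_test now builds as its own tf_cc_test target rather than as one of the higher_level_tests sources, it can presumably be run in isolation, e.g. with: bazel test //tensorflow/core:quantize_training_test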
tensorflow/core/graph/quantize_training.cc

Lines changed: 221 additions & 35 deletions
@@ -35,6 +35,11 @@ limitations under the License.
 
 namespace tensorflow {
 namespace {
+
+// TODO(suharshs): If desired, make these values configurable.
+const uint32 kAllowedInputs = 2;
+const float kEMADecay = 0.999;
+
 // Node types to rewrite. Insert quantize_and_dequantize op for their inputs.
 const std::unordered_set<string, StringPiece::Hasher> nodes_to_rewrite{
     "MatMul", "Conv2D"};
@@ -50,14 +55,13 @@ struct EdgeToConvert {
   float input_max;
 
   EdgeToConvert(const Edge* e, int32 bits, bool sign, bool range, float min,
-                float max) {
-    edge = e;
-    num_bits = bits;
-    signed_input = sign;
-    range_given = range;
-    input_min = min;
-    input_max = max;
-  }
+                float max)
+      : edge(e),
+        num_bits(bits),
+        signed_input(sign),
+        range_given(range),
+        input_min(min),
+        input_max(max) {}
 };
 
 // Decide if a node is in backward pass by checking if its name is led by
@@ -83,6 +87,9 @@ bool FindType(const Graph* graph, const Node* node, bool* signed_input,
     *signed_input = false;
     *range_given = false;
   } else if (src_op == "Relu6") {
+    // TODO(suharshs): Also, the theoretical min and max are 0 and 6; if the
+    // actual activations fall somewhere within this range, we can quantize
+    // even further. This is true for other activations like Sigmoid6 too.
     *signed_input = false;
     *range_given = true;
     *input_min = 0;
@@ -117,7 +124,7 @@ bool FindType(const Graph* graph, const Node* node, bool* signed_input,
     }
   } else {
     // Unknown type, could be the model input examples.
-    // TODO: Set the params for input with user's hint.
+    // TODO(jmchen): Set the params for input with user's hint.
    *signed_input = true;
     *range_given = false;
     return false;
@@ -126,29 +133,210 @@ bool FindType(const Graph* graph, const Node* node, bool* signed_input,
   return true;
 }
 
+// Sets output to the Node that computes the reduction axes corresponding to
+// all dimensions of input.
+Status MakeReductionAxes(Graph* graph, string name_prefix, Node* input,
+                         Node** output) {
+  name_prefix = strings::StrCat(name_prefix, "/ReductionAxes");
+  Node* start;
+  Tensor zero_tensor(DT_INT32, TensorShape());
+  zero_tensor.flat<int32>()(0) = 0;
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/RangeStart"), "Const")
+          .Attr("dtype", DT_INT32)
+          .Attr("value", zero_tensor)
+          .Finalize(graph, &start));
+  Node* delta;
+  Tensor one_tensor(DT_INT32, TensorShape());
+  one_tensor.flat<int32>()(0) = 1;
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/RangeDelta"), "Const")
+          .Attr("dtype", DT_INT32)
+          .Attr("value", one_tensor)
+          .Finalize(graph, &delta));
+  Node* rank;
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/InputRank"), "Rank")
+          .Input(input)
+          .Finalize(graph, &rank));
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/ReductionAxes"), "Range")
+          .Input(start)
+          .Input(rank)
+          .Input(delta)
+          .Finalize(graph, output));
+  return Status::OK();
+}
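For reference, the Const/Rank/Range subgraph above evaluates to Range(0, Rank(input), 1), i.e. every dimension index of the input. A standalone sketch of the same computation (this helper is ours, not part of the commit):

#include <numeric>
#include <vector>

// Returns [0, 1, ..., rank - 1]; a Min/Max reduction over these axes
// collapses the whole tensor to a scalar.
std::vector<int> ReductionAxesForRank(int rank) {
  std::vector<int> axes(rank);
  std::iota(axes.begin(), axes.end(), 0);
  return axes;
}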
+
+// Computes the exponential moving average of input, updated in update_variable.
+Status MakeExponentialMovingAverage(Graph* graph, string name_prefix,
+                                    const NodeBuilder::NodeOut& input,
+                                    Node* decay, Node* update_variable,
+                                    Node** assign_value) {
+  // variable_t+1 = variable_t - [(variable_t - value) * (1 - decay)]
+  name_prefix = strings::StrCat(name_prefix, "/EMA");
+  Node* one;
+  Tensor one_tensor(DT_FLOAT, TensorShape());
+  one_tensor.flat<float>()(0) = 1.0;
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/OneConst"), "Const")
+          .Attr("dtype", DT_FLOAT)
+          .Attr("value", one_tensor)
+          .Finalize(graph, &one));
+  Node* decay_complement;
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/DecayComplement"), "Sub")
+          .Input(one)
+          .Input(decay)
+          .Finalize(graph, &decay_complement));
+
+  Node* value_diff;
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/ValueDiff"), "Sub")
+          .Input(update_variable)
+          .Input(input)
+          .Finalize(graph, &value_diff));
+  Node* update_value;
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/UpdateValue"), "Mul")
+          .Input(value_diff)
+          .Input(decay_complement)
+          .Finalize(graph, &update_value));
+
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/EMAValue"), "Sub")
+          .Input(update_variable)
+          .Input(update_value)
+          .Finalize(graph, assign_value));
+  return Status::OK();
+}
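The Sub/Mul/Sub chain above encodes the update written in the leading comment. A minimal scalar sketch (this helper is ours, not part of the commit):

// variable_{t+1} = variable_t - (variable_t - value) * (1 - decay).
// With kEMADecay = 0.999, each step moves the estimate 0.1% of the way
// toward the newly observed value.
float EmaUpdate(float variable, float value, float decay) {
  return variable - (variable - value) * (1.0f - decay);
}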
+
+// Creates an automatically initialized exponential moving average variable.
+// This uses a switch op to assign a value to the variable on the first run,
+// and update with the moving average for all other runs:
+//                   init_val
+//                      |
+//      var--is_init--switch
+//       |    true /      \ false
+//       |        |        |
+//       |       EMA    init_val
+//       |         \      /
+//       +---------- assign
+Status MakeInitializedEMAVariable(Graph* graph, const string& name, Node* decay,
+                                  Node* init_val, Node** var) {
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name, "/Variable"), "VariableV2")
+          .Attr("shape", TensorShape())
+          .Attr("dtype", DT_FLOAT)
+          .Finalize(graph, var));
+
+  Node* is_initialized;
+  TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/IsInitialized"),
+                                 "IsVariableInitialized")
+                         .Input(*var)
+                         .Finalize(graph, &is_initialized));
+  Node* switch_node;
+  TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Switch"), "Switch")
+                         .Input(init_val)
+                         .Input(is_initialized)
+                         .Finalize(graph, &switch_node));
+  NodeBuilder::NodeOut output_false = NodeBuilder::NodeOut(switch_node, 0);
+  NodeBuilder::NodeOut output_true = NodeBuilder::NodeOut(switch_node, 1);
+
+  Node* ema_value;
+  TF_RETURN_IF_ERROR(MakeExponentialMovingAverage(graph, name, output_true,
+                                                  decay, *var, &ema_value));
+
+  Node* assign_value;
+  TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Merge"), "Merge")
+                         .Input({output_false, ema_value})
+                         .Finalize(graph, &assign_value));
+
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name, "/AssignValue"), "Assign")
+          .Input(*var)
+          .Input(assign_value)
+          .Finalize(graph, var));
+  return Status::OK();
+}
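Per the diagram, the Switch/Merge wiring reduces to one conditional per run. A control-flow sketch (ours, not the commit's; it inlines the EMA update sketched above):

// Switch output 0 (is_initialized == false): assign init_val directly.
// Switch output 1 (is_initialized == true): fold init_val into the EMA.
float InitializedEmaStep(bool is_initialized, float var, float init_val,
                         float decay) {
  return is_initialized ? var - (var - init_val) * (1.0f - decay)
                        : init_val;
}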
+
+// Computes the min and max EMA of input and stores them in min_var and max_var.
+Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix, Node* input,
+                         Node** min_var, Node** max_var) {
+  // TODO(suharshs): The decay will be constant, so we could make only one for
+  // all quantize_and_dequantize ops to share; this would have to live outside
+  // this function.
+  Tensor decay_tensor(DT_FLOAT, TensorShape());
+  decay_tensor.flat<float>()(0) = kEMADecay;
+  Node* decay;
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/Decay"), "Const")
+          .Attr("dtype", DT_FLOAT)
+          .Attr("value", decay_tensor)
+          .Finalize(graph, &decay));
+
+  Node* reduction_axes;
+  TF_RETURN_IF_ERROR(
+      MakeReductionAxes(graph, name_prefix, input, &reduction_axes));
+  Node* min;
+  string min_name = strings::StrCat(name_prefix, "/Min");
+  TF_RETURN_IF_ERROR(NodeBuilder(min_name, "Min")
+                         .Input(input)
+                         .Input(reduction_axes)
+                         .Finalize(graph, &min));
+  Node* max;
+  string max_name = strings::StrCat(name_prefix, "/Max");
+  TF_RETURN_IF_ERROR(NodeBuilder(max_name, "Max")
+                         .Input(input)
+                         .Input(reduction_axes)
+                         .Finalize(graph, &max));
+  TF_RETURN_IF_ERROR(
+      MakeInitializedEMAVariable(graph, min_name, decay, min, min_var));
+  TF_RETURN_IF_ERROR(
+      MakeInitializedEMAVariable(graph, max_name, decay, max, max_var));
+  return Status::OK();
+}
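Putting the pieces together, MakeEMAMinMaxVars tracks a smoothed range of the activation tensor. A host-side sketch of one step (names and helper are ours, not the commit's; assumes a non-empty tensor):

#include <algorithm>
#include <vector>

void UpdateRangeEstimate(const std::vector<float>& activations, float decay,
                         bool* initialized, float* min_var, float* max_var) {
  // Min/Max over all reduction axes == min/max over every element.
  const auto [lo, hi] =
      std::minmax_element(activations.begin(), activations.end());
  if (!*initialized) {
    *initialized = true;  // First run: Switch routes the raw value to Assign.
    *min_var = *lo;
    *max_var = *hi;
  } else {
    *min_var -= (*min_var - *lo) * (1.0f - decay);  // EMA updates.
    *max_var -= (*max_var - *hi) * (1.0f - decay);
  }
}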
+
+// Makes an input min and max constant if the range is given. Otherwise, makes
+// min and max variables that are updated by an EMA.
+Status MakeInputMinMax(Graph* graph, const string& name_prefix,
+                       const EdgeToConvert& edge, Node** input_min,
+                       Node** input_max) {
+  if (edge.range_given) {
+    // Make constant nodes for the input_min and input_max if the range is
+    // provided.
+    Tensor input_min_tensor(DT_FLOAT, TensorShape());
+    input_min_tensor.flat<float>()(0) = edge.input_min;
+    TF_RETURN_IF_ERROR(
+        NodeBuilder(strings::StrCat(name_prefix, "/InputMin"), "Const")
+            .Attr("dtype", DT_FLOAT)
+            .Attr("value", input_min_tensor)
+            .Finalize(graph, input_min));
+    Tensor input_max_tensor(DT_FLOAT, TensorShape());
+    input_max_tensor.flat<float>()(0) = edge.input_max;
+    TF_RETURN_IF_ERROR(
+        NodeBuilder(strings::StrCat(name_prefix, "/InputMax"), "Const")
+            .Attr("dtype", DT_FLOAT)
+            .Attr("value", input_max_tensor)
+            .Finalize(graph, input_max));
+  } else {
+    // If the range is not given, estimate the range with EMA variables.
+    TF_RETURN_IF_ERROR(MakeEMAMinMaxVars(graph, name_prefix, edge.edge->src(),
+                                         input_min, input_max));
+  }
+
+  return Status::OK();
+}
+
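This branch also explains why the hunk below hard-codes range_given to true on the QuantizeAndDequantizeV2 node: its min/max inputs now always carry a concrete range, either as constants or as the EMA variables. A sketch of the selection (names are ours, not the commit's):

#include <utility>

// range_given == true (e.g. Relu6): static bounds become Const nodes.
// range_given == false: the EMA variables supply the estimated bounds.
std::pair<float, float> SelectRange(bool range_given, float given_min,
                                    float given_max, float ema_min,
                                    float ema_max) {
  return range_given ? std::make_pair(given_min, given_max)
                     : std::make_pair(ema_min, ema_max);
}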
 // Adds a QuantizeAndDequantizeV2Op (and required input nodes) based on edge.
 // The result is stored in convert_node.
 Status MakeQuantizeAndDequantizeV2(Graph* graph, const string& name_prefix,
                                    const EdgeToConvert& edge,
                                    Node** convert_node) {
   Node* input_min;
   Node* input_max;
-  // Make constant nodes for the input_min and input_max if the range is
-  // provided.
-  Tensor input_min_tensor(DT_FLOAT, TensorShape());
-  input_min_tensor.flat<float>()(0) = edge.input_min;
-  string min_name = strings::StrCat(name_prefix, "/InputMin");
-  TF_RETURN_IF_ERROR(NodeBuilder(min_name, "Const")
-                         .Attr("dtype", DT_FLOAT)
-                         .Attr("value", input_min_tensor)
-                         .Finalize(graph, &input_min));
-  Tensor input_max_tensor(DT_FLOAT, TensorShape());
-  input_max_tensor.flat<float>()(0) = edge.input_max;
-  string max_name = strings::StrCat(name_prefix, "/InputMax");
-  TF_RETURN_IF_ERROR(NodeBuilder(max_name, "Const")
-                         .Attr("dtype", DT_FLOAT)
-                         .Attr("value", input_max_tensor)
-                         .Finalize(graph, &input_max));
+  TF_RETURN_IF_ERROR(
+      MakeInputMinMax(graph, name_prefix, edge, &input_min, &input_max));
 
   string quant_name = strings::StrCat(name_prefix, "/QuantizeAndDequantizeV2");
   TF_RETURN_IF_ERROR(NodeBuilder(quant_name, "QuantizeAndDequantizeV2")
@@ -157,16 +345,16 @@ Status MakeQuantizeAndDequantizeV2(Graph* graph, const string& name_prefix,
                          .Input(input_max)
                          .Attr("signed_input", edge.signed_input)
                          .Attr("num_bits", edge.num_bits)
-                         .Attr("range_given", edge.range_given)
+                         .Attr("range_given", true)
                          .Finalize(graph, convert_node));
   return Status::OK();
 }
 
 // Insert conversion op, connect it to the graph and remove the old edge.
 Status ProcessTargetEdges(Graph* graph,
                           const std::vector<EdgeToConvert>& target_edges) {
-  // Remember previous convert ops to avoid duplicated conversion on the same
-  // input.
+  // Remember previously converted ops to avoid duplicated conversion on the
+  // same input.
   std::unordered_map<string, Node*, StringPiece::Hasher> name_index;
   for (const EdgeToConvert edge : target_edges) {
     Node* convert_node;
@@ -230,17 +418,15 @@ Status DoQuantizeTraining(int32 num_bits, Graph* graph) {
                              &range_given, &input_min, &input_max);
     if (!known_op) {
       // Unknown op is considered as input.
-      // Only support one input for now.
-      // TODO: Make this configurable if this is the desirable way to find
-      // input.
-      if (potential_input > 0) {
+      potential_input++;
+      if (potential_input > kAllowedInputs) {
         return errors::Unimplemented(
-            "Find a second unknown op: ", edge->src()->name(),
+            "Found an unknown op: ", edge->src()->name(),
             " with type: ", edge->src()->type_string(),
             "; Unknown ops are considered as model input for now and "
-            "only 1 input is supported currently.");
+            "only ",
+            kAllowedInputs, " inputs are supported currently.");
       }
-      potential_input++;
     }
 
     target_edges.emplace_back(EdgeToConvert(
