@@ -35,6 +35,11 @@ limitations under the License.
 
 namespace tensorflow {
 namespace {
+
+// TODO(suharshs): If desired, make these values configurable.
+const uint32 kAllowedInputs = 2;
+const float kEMADecay = 0.999;
+
 // Node types to rewrite. Insert quantize_and_dequantize op for their inputs.
 const std::unordered_set<string, StringPiece::Hasher> nodes_to_rewrite{
     "MatMul", "Conv2D"};
@@ -50,14 +55,13 @@ struct EdgeToConvert {
   float input_max;
 
   EdgeToConvert(const Edge* e, int32 bits, bool sign, bool range, float min,
-                float max) {
-    edge = e;
-    num_bits = bits;
-    signed_input = sign;
-    range_given = range;
-    input_min = min;
-    input_max = max;
-  }
+                float max)
+      : edge(e),
+        num_bits(bits),
+        signed_input(sign),
+        range_given(range),
+        input_min(min),
+        input_max(max) {}
 };
 
 // Decide if a node is in backward pass by checking if its name is led by
@@ -83,6 +87,9 @@ bool FindType(const Graph* graph, const Node* node, bool* signed_input,
     *signed_input = false;
     *range_given = false;
   } else if (src_op == "Relu6") {
+    // TODO(suharshs): Also, the theoretical min and max are 0 and 6; if the
+    // actual activations fall somewhere within this range, we can quantize
+    // even further. This is true for other activations like Sigmoid6 too.
     *signed_input = false;
     *range_given = true;
     *input_min = 0;
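// Resolution math behind the TODO above (8-bit unsigned example; the measured
// range [0, 3] is hypothetical): step = (max - min) / (2^8 - 1), so the
// theoretical range [0, 6] gives a step of 6 / 255 ~= 0.0235, while a measured
// [0, 3] would give 3 / 255 ~= 0.0118, halving the rounding error.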
@@ -117,7 +124,7 @@ bool FindType(const Graph* graph, const Node* node, bool* signed_input,
     }
   } else {
     // Unknown type, could be the model input examples.
-    // TODO: Set the params for input with user's hint.
+    // TODO(jmchen): Set the params for input with user's hint.
     *signed_input = true;
     *range_given = false;
     return false;
@@ -126,29 +133,210 @@ bool FindType(const Graph* graph, const Node* node, bool* signed_input,
   return true;
 }
 
+// Sets output to the Node that computes the reduction axes corresponding to
+// all dimensions of input.
+Status MakeReductionAxes(Graph* graph, string name_prefix, Node* input,
+                         Node** output) {
+  name_prefix = strings::StrCat(name_prefix, "/ReductionAxes");
+  Node* start;
+  Tensor zero_tensor(DT_INT32, TensorShape());
+  zero_tensor.flat<int32>()(0) = 0;
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/RangeStart"), "Const")
+          .Attr("dtype", DT_INT32)
+          .Attr("value", zero_tensor)
+          .Finalize(graph, &start));
+  Node* delta;
+  Tensor one_tensor(DT_INT32, TensorShape());
+  one_tensor.flat<int32>()(0) = 1;
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/RangeDelta"), "Const")
+          .Attr("dtype", DT_INT32)
+          .Attr("value", one_tensor)
+          .Finalize(graph, &delta));
+  Node* rank;
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/InputRank"), "Rank")
+          .Input(input)
+          .Finalize(graph, &rank));
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/ReductionAxes"), "Range")
+          .Input(start)
+          .Input(rank)
+          .Input(delta)
+          .Finalize(graph, output));
+  return Status::OK();
+}
+
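// The Const/Rank/Range nodes above amount to Range(0, Rank(input), 1): the
// list of every axis, so a following Min/Max reduces a tensor of any rank to
// a scalar. A minimal host-side sketch of the same computation (illustrative
// only, not part of this change):
#include <numeric>
#include <vector>

std::vector<int> AllReductionAxes(int rank) {
  std::vector<int> axes(rank);             // one entry per dimension
  std::iota(axes.begin(), axes.end(), 0);  // {0, 1, ..., rank - 1}
  return axes;
}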
+// Computes the exponential moving average of input, updated in update_variable.
+Status MakeExponentialMovingAverage(Graph* graph, string name_prefix,
+                                    const NodeBuilder::NodeOut& input,
+                                    Node* decay, Node* update_variable,
+                                    Node** assign_value) {
+  // variable_t+1 = variable_t - [(variable_t - value) * (1 - decay)]
+  name_prefix = strings::StrCat(name_prefix, "/EMA");
+  Node* one;
+  Tensor one_tensor(DT_FLOAT, TensorShape());
+  one_tensor.flat<float>()(0) = 1.0;
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/OneConst"), "Const")
+          .Attr("dtype", DT_FLOAT)
+          .Attr("value", one_tensor)
+          .Finalize(graph, &one));
+  Node* decay_complement;
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/DecayComplement"), "Sub")
+          .Input(one)
+          .Input(decay)
+          .Finalize(graph, &decay_complement));
+
+  Node* value_diff;
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/ValueDiff"), "Sub")
+          .Input(update_variable)
+          .Input(input)
+          .Finalize(graph, &value_diff));
+  Node* update_value;
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/UpdateValue"), "Mul")
+          .Input(value_diff)
+          .Input(decay_complement)
+          .Finalize(graph, &update_value));
+
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/EMAValue"), "Sub")
+          .Input(update_variable)
+          .Input(update_value)
+          .Finalize(graph, assign_value));
+  return Status::OK();
+}
+
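// In plain arithmetic, the Sub/Mul chain above computes the commented update.
// A standalone sketch (illustrative; the sample values are hypothetical):
float EmaUpdate(float variable, float value, float decay) {
  return variable - (variable - value) * (1.0f - decay);
}
// EmaUpdate(5.0f, 3.0f, 0.999f) == 4.998f: with kEMADecay = 0.999, each new
// value moves the estimate by only 0.1% of the gap, so the tracked range
// stays stable across noisy batches.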
+// Creates an automatically initialized exponential moving average variable.
+// This uses a switch op to assign a value to the variable on the first run,
+// and update with the moving average for all other runs:
+//          init_val
+//             |
+// var--is_init--switch
+//  |     true /    \ false
+//  |         |      |
+//  |        EMA  init_val
+//  |          \    /
+//  +---------- assign
+Status MakeInitializedEMAVariable(Graph* graph, const string& name, Node* decay,
+                                  Node* init_val, Node** var) {
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name, "/Variable"), "VariableV2")
+          .Attr("shape", TensorShape())
+          .Attr("dtype", DT_FLOAT)
+          .Finalize(graph, var));
+
+  Node* is_initialized;
+  TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/IsInitialized"),
+                                 "IsVariableInitialized")
+                         .Input(*var)
+                         .Finalize(graph, &is_initialized));
+  Node* switch_node;
+  TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Switch"), "Switch")
+                         .Input(init_val)
+                         .Input(is_initialized)
+                         .Finalize(graph, &switch_node));
+  NodeBuilder::NodeOut output_false = NodeBuilder::NodeOut(switch_node, 0);
+  NodeBuilder::NodeOut output_true = NodeBuilder::NodeOut(switch_node, 1);
+
+  Node* ema_value;
+  TF_RETURN_IF_ERROR(MakeExponentialMovingAverage(graph, name, output_true,
+                                                  decay, *var, &ema_value));
+
+  Node* assign_value;
+  TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Merge"), "Merge")
+                         .Input({output_false, ema_value})
+                         .Finalize(graph, &assign_value));
+
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name, "/AssignValue"), "Assign")
+          .Input(*var)
+          .Input(assign_value)
+          .Finalize(graph, var));
+  return Status::OK();
+}
+
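// Switch routes its input to output port 1 when the predicate is true and to
// port 0 otherwise, and Merge forwards whichever branch actually ran. The
// dataflow above therefore reduces to this control flow (illustrative sketch):
float AssignedValue(bool is_initialized, float var, float init_val,
                    float decay) {
  if (!is_initialized) return init_val;  // first run: adopt the raw value
  // Later runs: fold the new value into the variable via the EMA update.
  return var - (var - init_val) * (1.0f - decay);
}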
+// Computes the min and max EMA of input and stores them in min_var and max_var.
+Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix, Node* input,
+                         Node** min_var, Node** max_var) {
+  // TODO(suharshs): The decay will be constant, so we could make only one for
+  // all quantize_and_dequantize ops to share; it would have to live outside
+  // this function.
+  Tensor decay_tensor(DT_FLOAT, TensorShape());
+  decay_tensor.flat<float>()(0) = kEMADecay;
+  Node* decay;
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat(name_prefix, "/Decay"), "Const")
+          .Attr("dtype", DT_FLOAT)
+          .Attr("value", decay_tensor)
+          .Finalize(graph, &decay));
+
+  Node* reduction_axes;
+  TF_RETURN_IF_ERROR(
+      MakeReductionAxes(graph, name_prefix, input, &reduction_axes));
+  Node* min;
+  string min_name = strings::StrCat(name_prefix, "/Min");
+  TF_RETURN_IF_ERROR(NodeBuilder(min_name, "Min")
+                         .Input(input)
+                         .Input(reduction_axes)
+                         .Finalize(graph, &min));
+  Node* max;
+  string max_name = strings::StrCat(name_prefix, "/Max");
+  TF_RETURN_IF_ERROR(NodeBuilder(max_name, "Max")
+                         .Input(input)
+                         .Input(reduction_axes)
+                         .Finalize(graph, &max));
+  TF_RETURN_IF_ERROR(
+      MakeInitializedEMAVariable(graph, min_name, decay, min, min_var));
+  TF_RETURN_IF_ERROR(
+      MakeInitializedEMAVariable(graph, max_name, decay, max, max_var));
+  return Status::OK();
+}
+
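// Per step, the subgraph above takes the batch-wide min and max and folds each
// into its EMA variable. The same bookkeeping in host code (a sketch; the
// function and its inputs are hypothetical):
#include <algorithm>
#include <vector>

void UpdateRange(const std::vector<float>& batch, float decay,
                 float* min_var, float* max_var) {
  *min_var -= (*min_var - *std::min_element(batch.begin(), batch.end())) *
              (1.0f - decay);
  *max_var -= (*max_var - *std::max_element(batch.begin(), batch.end())) *
              (1.0f - decay);
}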
+// Makes an input min and max constant if the range is given. Otherwise, makes
+// min and max variables that are updated by an EMA.
+Status MakeInputMinMax(Graph* graph, const string& name_prefix,
+                       const EdgeToConvert& edge, Node** input_min,
+                       Node** input_max) {
+  if (edge.range_given) {
+    // Make constant nodes for the input_min and input_max if the range is
+    // provided.
+    Tensor input_min_tensor(DT_FLOAT, TensorShape());
+    input_min_tensor.flat<float>()(0) = edge.input_min;
+    TF_RETURN_IF_ERROR(
+        NodeBuilder(strings::StrCat(name_prefix, "/InputMin"), "Const")
+            .Attr("dtype", DT_FLOAT)
+            .Attr("value", input_min_tensor)
+            .Finalize(graph, input_min));
+    Tensor input_max_tensor(DT_FLOAT, TensorShape());
+    input_max_tensor.flat<float>()(0) = edge.input_max;
+    TF_RETURN_IF_ERROR(
+        NodeBuilder(strings::StrCat(name_prefix, "/InputMax"), "Const")
+            .Attr("dtype", DT_FLOAT)
+            .Attr("value", input_max_tensor)
+            .Finalize(graph, input_max));
+  } else {
+    // If the range is not given, estimate the range with EMA variables.
+    TF_RETURN_IF_ERROR(MakeEMAMinMaxVars(graph, name_prefix, edge.edge->src(),
+                                         input_min, input_max));
+  }
+
+  return Status::OK();
+}
+
 // Adds a QuantizeAndDequantizeV2Op (and required input nodes) based on edge.
 // The result is stored in convert_node.
 Status MakeQuantizeAndDequantizeV2(Graph* graph, const string& name_prefix,
                                    const EdgeToConvert& edge,
                                    Node** convert_node) {
   Node* input_min;
   Node* input_max;
-  // Make constant nodes for the input_min and input_max if the range is
-  // provided.
-  Tensor input_min_tensor(DT_FLOAT, TensorShape());
-  input_min_tensor.flat<float>()(0) = edge.input_min;
-  string min_name = strings::StrCat(name_prefix, "/InputMin");
-  TF_RETURN_IF_ERROR(NodeBuilder(min_name, "Const")
-                         .Attr("dtype", DT_FLOAT)
-                         .Attr("value", input_min_tensor)
-                         .Finalize(graph, &input_min));
-  Tensor input_max_tensor(DT_FLOAT, TensorShape());
-  input_max_tensor.flat<float>()(0) = edge.input_max;
-  string max_name = strings::StrCat(name_prefix, "/InputMax");
-  TF_RETURN_IF_ERROR(NodeBuilder(max_name, "Const")
-                         .Attr("dtype", DT_FLOAT)
-                         .Attr("value", input_max_tensor)
-                         .Finalize(graph, &input_max));
+  TF_RETURN_IF_ERROR(
+      MakeInputMinMax(graph, name_prefix, edge, &input_min, &input_max));
 
   string quant_name = strings::StrCat(name_prefix, "/QuantizeAndDequantizeV2");
   TF_RETURN_IF_ERROR(NodeBuilder(quant_name, "QuantizeAndDequantizeV2")
@@ -157,16 +345,16 @@ Status MakeQuantizeAndDequantizeV2(Graph* graph, const string& name_prefix,
                          .Input(input_max)
                          .Attr("signed_input", edge.signed_input)
                          .Attr("num_bits", edge.num_bits)
-                         .Attr("range_given", edge.range_given)
+                         .Attr("range_given", true)
                          .Finalize(graph, convert_node));
   return Status::OK();
 }
 
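// Why "range_given" is now hard-coded to true: MakeInputMinMax always wires
// real min/max inputs into the op (Consts when edge.range_given is set, EMA
// variables otherwise), so the op itself never has to infer the range;
// edge.range_given only selects how that range is produced.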
 // Insert conversion op, connect it to the graph and remove the old edge.
 Status ProcessTargetEdges(Graph* graph,
                           const std::vector<EdgeToConvert>& target_edges) {
-  // Remember previous convert ops to avoid duplicated conversion on the same
-  // input.
+  // Remember previously converted ops to avoid duplicated conversion on the
+  // same input.
   std::unordered_map<string, Node*, StringPiece::Hasher> name_index;
   for (const EdgeToConvert edge : target_edges) {
     Node* convert_node;
@@ -230,17 +418,15 @@ Status DoQuantizeTraining(int32 num_bits, Graph* graph) {
                                &range_given, &input_min, &input_max);
       if (!known_op) {
         // Unknown op is considered as input.
-        // Only support one input for now.
-        // TODO: Make this configurable if this is the desirable way to find
-        // input.
-        if (potential_input > 0) {
+        potential_input++;
+        if (potential_input > kAllowedInputs) {
           return errors::Unimplemented(
-              "Find a second unknown op: ", edge->src()->name(),
+              "Found an unknown op: ", edge->src()->name(),
               " with type: ", edge->src()->type_string(),
               "; Unknown ops are considered as model input for now and "
-              "only 1 input is supported currently.");
+              "only ",
+              kAllowedInputs, " inputs are supported currently.");
         }
-        potential_input++;
       }
 
       target_edges.emplace_back(EdgeToConvert(
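// Effect of moving the increment above the check (kAllowedInputs = 2):
//   1st unknown op: potential_input = 1 -> accepted as a model input
//   2nd unknown op: potential_input = 2 -> accepted as a model input
//   3rd unknown op: potential_input = 3 -> errors::Unimplemented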