Skip to content

Commit e1cb1f1

Browse files
abatterytensorflower-gardener
authored andcommitted
Provide user friendly errors when upgrading legacy control flow ops is failed
PiperOrigin-RevId: 347731417 Change-Id: Id9ddaf308be4b9b13d9aa266f6dda9165fb4f4a6
1 parent 442824b commit e1cb1f1

File tree

8 files changed

+123
-4
lines changed

8 files changed

+123
-4
lines changed

tensorflow/compiler/mlir/tensorflow/translate/import_model.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,13 @@ Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def,
190190
restrict_functionalization_to_tpu_nodes
191191
? [](const Node* n) { return n->attrs().Find(kTpuReplicateAttr); }
192192
: NodeFilter{};
193-
return FunctionalizeControlFlow(graph, flib_def, node_filter,
194-
/*include_functions=*/true);
193+
TF_RETURN_WITH_CONTEXT_IF_ERROR(
194+
FunctionalizeControlFlow(graph, flib_def, node_filter,
195+
/*include_functions=*/true),
196+
"Failed to functionalize Control Flow V1 ops. Consider using Control "
197+
"Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/tf/"
198+
"compat/v1/enable_control_flow_v2.");
199+
return Status::OK();
195200
}
196201

197202
// Stateful helper class to import a TensorFlow model into an MLIR Module.

tensorflow/lite/python/BUILD

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,10 @@ py_library(
186186
py_test(
187187
name = "lite_test",
188188
srcs = ["lite_test.py"],
189-
data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb"],
189+
data = [
190+
"//tensorflow/lite/python/testdata:control_flow_v1.pbtxt",
191+
"@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb",
192+
],
190193
python_version = "PY3",
191194
shard_count = 4,
192195
srcs_version = "PY2AND3",
@@ -205,6 +208,9 @@ py_test(
205208
py_test(
206209
name = "lite_v2_test",
207210
srcs = ["lite_v2_test.py"],
211+
data = [
212+
"//tensorflow/lite/python/testdata/control_flow_v1_saved_model:saved_model.pb",
213+
],
208214
python_version = "PY3",
209215
shard_count = 12,
210216
srcs_version = "PY2AND3",

tensorflow/lite/python/lite_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2740,5 +2740,24 @@ def testAttrs(self):
27402740
self.assertIsNone(converter.conversion_summary_dir)
27412741

27422742

2743+
class ControlFlowV1OpsTest(LiteTest):
2744+
2745+
def testConverterErrorOnControlFlowV1Ops(self):
2746+
graph_def_file = resource_loader.get_path_to_datafile(
2747+
'testdata/control_flow_v1.pbtxt')
2748+
input_arrays = ['a', 'b', 'c', 'd']
2749+
output_arrays = ['Merge']
2750+
2751+
converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
2752+
input_arrays,
2753+
output_arrays)
2754+
with self.assertRaises(ConverterError) as error:
2755+
converter.convert()
2756+
self.assertIn(
2757+
'Failed to functionalize Control Flow V1 ops. Consider using Control '
2758+
'Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/'
2759+
'tf/compat/v1/enable_control_flow_v2.', str(error.exception))
2760+
2761+
27432762
if __name__ == '__main__':
27442763
test.main()

tensorflow/lite/python/lite_v2_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import tensorflow as tf
2929

3030
from tensorflow.lite.kernels.hashtable import pywrap_hashtable_ops as hashtable_ops_registerer
31+
from tensorflow.lite.python import convert
3132
from tensorflow.lite.python import lite
3233
from tensorflow.lite.python import lite_v2_test_util
3334
from tensorflow.lite.python.convert import mlir_quantize
@@ -38,6 +39,7 @@
3839
from tensorflow.python.framework import ops
3940
from tensorflow.python.framework import test_util
4041
from tensorflow.python.lib.io import file_io
42+
from tensorflow.python.platform import resource_loader
4143
from tensorflow.python.platform import test
4244
from tensorflow.python.saved_model import save_options
4345
from tensorflow.python.saved_model import saved_model
@@ -1263,6 +1265,18 @@ def model(x, b):
12631265
tflite_model, [input_data['x'], input_data['b']])[0]
12641266
self.assertAllClose(expected_value, actual_value)
12651267

1268+
@test_util.run_v2_only
1269+
def testConverterErrorOnControlFlowV1Ops(self):
1270+
filename = resource_loader.get_path_to_datafile(
1271+
'testdata/control_flow_v1_saved_model')
1272+
converter = lite.TFLiteConverterV2.from_saved_model(filename)
1273+
with self.assertRaises(convert.ConverterError) as error:
1274+
converter.convert()
1275+
self.assertIn(
1276+
'Failed to functionalize Control Flow V1 ops. Consider using Control '
1277+
'Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/'
1278+
'tf/compat/v1/enable_control_flow_v2.', str(error.exception))
1279+
12661280
@test_util.run_v2_only
12671281
def testStaticRnn(self):
12681282
input_data = tf.constant(

tensorflow/lite/python/testdata/BUILD

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ package(
1212
licenses = ["notice"], # Apache 2.0,
1313
)
1414

15-
exports_files(glob(["*.pb"]))
15+
exports_files(glob([
16+
"*.pb",
17+
"*.pbtxt",
18+
]))
1619

1720
tf_to_tflite(
1821
name = "permute_float",
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
node {
2+
name: "a"
3+
op: "Placeholder"
4+
attr {
5+
key: "dtype"
6+
value {
7+
type: DT_FLOAT
8+
}
9+
}
10+
}
11+
node {
12+
name: "b"
13+
op: "Placeholder"
14+
attr {
15+
key: "dtype"
16+
value {
17+
type: DT_FLOAT
18+
}
19+
}
20+
}
21+
node {
22+
name: "c"
23+
op: "Placeholder"
24+
attr {
25+
key: "dtype"
26+
value {
27+
type: DT_FLOAT
28+
}
29+
}
30+
}
31+
node {
32+
name: "d"
33+
op: "Placeholder"
34+
attr {
35+
key: "dtype"
36+
value {
37+
type: DT_FLOAT
38+
}
39+
}
40+
}
41+
node {
42+
name: "Merge"
43+
op: "Merge"
44+
input: "a"
45+
input: "b"
46+
input: "c"
47+
input: "d"
48+
attr {
49+
key: "N"
50+
value {
51+
i: 4
52+
}
53+
}
54+
attr {
55+
key: "T"
56+
value {
57+
type: DT_FLOAT
58+
}
59+
}
60+
}
61+
62+
versions {
63+
producer: 27
64+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package(
2+
default_visibility = ["//tensorflow:internal"],
3+
licenses = ["notice"], # Apache 2.0,
4+
)
5+
6+
exports_files([
7+
"saved_model.pb",
8+
])
Binary file not shown.

0 commit comments

Comments
 (0)