Skip to content

Commit 50ea0bc

Browse files
Abdurrahman Akkastensorflower-gardener
authored andcommitted
Mark variable tensors as known in SubgraphWriter::CheckInputOutput.
PiperOrigin-RevId: 348859590 Change-Id: I234b9c170c24863e6221167b27a7dcccd9bf126c
1 parent 7dfcca7 commit 50ea0bc

File tree

3 files changed

+50
-1
lines changed

3 files changed

+50
-1
lines changed

tensorflow/lite/tools/serialization/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ cc_library(
3636
"//tensorflow/lite/schema:schema_fbs_with_reflection",
3737
"//tensorflow/lite/schema:schema_utils",
3838
"@com_google_absl//absl/container:flat_hash_map",
39+
"@com_google_absl//absl/container:flat_hash_set",
3940
],
4041
)
4142

tensorflow/lite/tools/serialization/writer_lib.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include <unordered_map>
2020
#include <unordered_set>
2121

22+
#include "absl/container/flat_hash_set.h"
2223
#include "tensorflow/lite/builtin_op_data.h"
2324
#include "tensorflow/lite/c/common.h"
2425
#include "tensorflow/lite/context_util.h"
@@ -326,7 +327,9 @@ TfLiteStatus SubgraphWriter::RegisterCustomWriter(
326327
TfLiteStatus SubgraphWriter::CheckInputOutput(
327328
const std::vector<int>& inputs, const std::vector<int>& outputs,
328329
const std::vector<int>& execution_plan) {
329-
std::unordered_set<int> known_tensors(inputs.begin(), inputs.end());
330+
absl::flat_hash_set<int> known_tensors(inputs.begin(), inputs.end());
331+
known_tensors.insert(subgraph_->variables().begin(),
332+
subgraph_->variables().end());
330333
// Scan execution plan and confirm input tensors are known before each node
331334
// executes. Then append output tensors to known tensors.
332335
for (int op_index : execution_plan) {

tensorflow/lite/tools/serialization/writer_lib_test.cc

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ limitations under the License.
2222
#include <tuple>
2323

2424
#include <gtest/gtest.h>
25+
#include "tensorflow/lite/c/builtin_op_data.h"
2526
#include "tensorflow/lite/c/common.h"
2627
#include "tensorflow/lite/interpreter.h"
2728
#include "tensorflow/lite/kernels/register.h"
@@ -230,6 +231,50 @@ TEST_P(SingleSubgraphTest, CustomInputOutputErrorCasesTest) {
230231
kTfLiteOk);
231232
}
232233

234+
// Tests if SetCustomInputOutput handles variable tensors correctly.
235+
TEST_P(SingleSubgraphTest, CustomInputOutputVariableTensorTest) {
236+
Interpreter interpreter;
237+
tflite::ops::builtin::BuiltinOpResolver resolver;
238+
239+
// Create tensors.
240+
interpreter.AddTensors(3);
241+
interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
242+
TfLiteQuantization());
243+
interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "b", {3},
244+
TfLiteQuantization(),
245+
/*is_variable=*/true);
246+
interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
247+
TfLiteQuantization());
248+
interpreter.SetInputs({0});
249+
interpreter.SetOutputs({2});
250+
interpreter.SetVariables({1});
251+
252+
// Create an Add node.
253+
TfLiteAddParams* builtin_data =
254+
reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
255+
builtin_data->activation = kTfLiteActNone;
256+
builtin_data->pot_scale_int16 = false;
257+
interpreter.AddNodeWithParameters({0, 1}, {2}, nullptr, 0,
258+
reinterpret_cast<void*>(builtin_data),
259+
resolver.FindOp(BuiltinOperator_ADD, 1));
260+
261+
// Write model to file.
262+
const std::string test_file = CreateFilePath("test_variables.tflite");
263+
SubgraphWriter writer(&interpreter.primary_subgraph());
264+
EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{0}, /*outputs=*/{2},
265+
/*execution_plan=*/{0}),
266+
kTfLiteOk);
267+
writer.Write(test_file);
268+
269+
// Read model and test.
270+
std::unique_ptr<FlatBufferModel> model =
271+
FlatBufferModel::BuildFromFile(test_file.c_str());
272+
InterpreterBuilder builder(*model, resolver);
273+
std::unique_ptr<Interpreter> new_interpreter;
274+
builder(&new_interpreter);
275+
CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
276+
}
277+
233278
TEST_P(SingleSubgraphTest, PerTensorQuantizedModelTest) {
234279
Interpreter interpreter;
235280
interpreter.AddTensors(3);

0 commit comments

Comments
 (0)