@@ -22,6 +22,7 @@ limitations under the License.
2222#include < tuple>
2323
2424#include < gtest/gtest.h>
25+ #include " tensorflow/lite/c/builtin_op_data.h"
2526#include " tensorflow/lite/c/common.h"
2627#include " tensorflow/lite/interpreter.h"
2728#include " tensorflow/lite/kernels/register.h"
@@ -230,6 +231,50 @@ TEST_P(SingleSubgraphTest, CustomInputOutputErrorCasesTest) {
230231 kTfLiteOk );
231232}
232233
234+ // Tests if SetCustomInputOutput handles variable tensors correctly.
235+ TEST_P (SingleSubgraphTest, CustomInputOutputVariableTensorTest) {
236+ Interpreter interpreter;
237+ tflite::ops::builtin::BuiltinOpResolver resolver;
238+
239+ // Create tensors.
240+ interpreter.AddTensors (3 );
241+ interpreter.SetTensorParametersReadWrite (0 , kTfLiteFloat32 , " a" , {3 },
242+ TfLiteQuantization ());
243+ interpreter.SetTensorParametersReadWrite (1 , kTfLiteFloat32 , " b" , {3 },
244+ TfLiteQuantization (),
245+ /* is_variable=*/ true );
246+ interpreter.SetTensorParametersReadWrite (2 , kTfLiteFloat32 , " c" , {3 },
247+ TfLiteQuantization ());
248+ interpreter.SetInputs ({0 });
249+ interpreter.SetOutputs ({2 });
250+ interpreter.SetVariables ({1 });
251+
252+ // Create an Add node.
253+ TfLiteAddParams* builtin_data =
254+ reinterpret_cast <TfLiteAddParams*>(malloc (sizeof (TfLiteAddParams)));
255+ builtin_data->activation = kTfLiteActNone ;
256+ builtin_data->pot_scale_int16 = false ;
257+ interpreter.AddNodeWithParameters ({0 , 1 }, {2 }, nullptr , 0 ,
258+ reinterpret_cast <void *>(builtin_data),
259+ resolver.FindOp (BuiltinOperator_ADD, 1 ));
260+
261+ // Write model to file.
262+ const std::string test_file = CreateFilePath (" test_variables.tflite" );
263+ SubgraphWriter writer (&interpreter.primary_subgraph ());
264+ EXPECT_EQ (writer.SetCustomInputOutput (/* inputs=*/ {0 }, /* outputs=*/ {2 },
265+ /* execution_plan=*/ {0 }),
266+ kTfLiteOk );
267+ writer.Write (test_file);
268+
269+ // Read model and test.
270+ std::unique_ptr<FlatBufferModel> model =
271+ FlatBufferModel::BuildFromFile (test_file.c_str ());
272+ InterpreterBuilder builder (*model, resolver);
273+ std::unique_ptr<Interpreter> new_interpreter;
274+ builder (&new_interpreter);
275+ CHECK_EQ (new_interpreter->AllocateTensors (), kTfLiteOk );
276+ }
277+
233278TEST_P (SingleSubgraphTest, PerTensorQuantizedModelTest) {
234279 Interpreter interpreter;
235280 interpreter.AddTensors (3 );
0 commit comments