Skip to content

Commit 40cd08d

Browse files
Update NNAPI delegate to support state API for RNN.
PiperOrigin-RevId: 210603975
1 parent 31d10e3 commit 40cd08d

File tree

2 files changed

+11
-22
lines changed

2 files changed

+11
-22
lines changed

tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -779,23 +779,21 @@ class NNAPIDelegateKernel {
779779
return nullptr;
780780
}
781781
break;
782-
#if 0
783782
case kTfLiteBuiltinRnn:
784783
// NNAPI only support float32 weights.
785-
// TODO(miaowang): check the number of inputs before accessing it.
786-
if (version == 1 &&
784+
if (version == 1 && node->inputs->size == 5 &&
787785
context->tensors[node->inputs->data[/*kWeightsTensor*/ 1]].type ==
788786
kTfLiteFloat32) {
789787
return [](const NNAPIOpMappingArgs& mapping_args)
790788
-> ANeuralNetworksOperationType {
791789
// NNAPI need both state_in and state_out.
792790
int ann_index;
793791
mapping_args.builder->AddStateFloat32Tensor(
794-
mapping_args.node->outputs->data[/*kHiddenStateTensor*/ 0],
792+
mapping_args.node->inputs->data[/*kHiddenStateTensor*/ 4],
795793
&ann_index);
796794
mapping_args.model_state_outputs->push_back(ann_index);
797795
mapping_args.model_state_tfl_inputs->push_back(
798-
mapping_args.node->outputs->data[/*kHiddenStateTensor*/ 0]);
796+
mapping_args.node->inputs->data[/*kHiddenStateTensor*/ 4]);
799797
auto builtin = reinterpret_cast<TfLiteRNNParams*>(
800798
mapping_args.node->builtin_data);
801799
mapping_args.builder->AddScalarInt32Operand(builtin->activation);
@@ -805,10 +803,9 @@ class NNAPIDelegateKernel {
805803
return nullptr;
806804
}
807805
break;
808-
#endif
809806
case kTfLiteBuiltinSvdf:
810807
// NNAPI only support float32 weights.
811-
if (version == 1 &&
808+
if (version == 1 && node->inputs->size == 5 &&
812809
context->tensors[node->inputs->data[/*kWeightsFeatureTensor*/ 1]]
813810
.type == kTfLiteFloat32) {
814811
return [](const NNAPIOpMappingArgs& mapping_args)

tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1773,15 +1773,16 @@ class RNNOpModel : public SingleOpModelWithNNAPI {
17731773
weights_ = AddInput(weights);
17741774
recurrent_weights_ = AddInput(recurrent_weights);
17751775
bias_ = AddInput(TensorType_FLOAT32);
1776-
hidden_state_ = AddOutput(TensorType_FLOAT32);
1776+
hidden_state_ = AddInput(TensorType_FLOAT32, true);
17771777
output_ = AddOutput(TensorType_FLOAT32);
17781778
SetBuiltinOp(
17791779
BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
17801780
CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
1781-
BuildInterpreter({{batches_, input_size_},
1782-
{units_, input_size_},
1783-
{units_, units_},
1784-
{units_}});
1781+
BuildInterpreter({{batches_, input_size_}, // input tensor
1782+
{units_, input_size_}, // weights tensor
1783+
{units_, units_}, // recurrent weights tensor
1784+
{units_}, // bias tensor
1785+
{batches_, units_}}); // hidden state tensor
17851786
}
17861787

17871788
void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
@@ -1802,14 +1803,6 @@ class RNNOpModel : public SingleOpModelWithNNAPI {
18021803
PopulateTensor(input_, offset, begin, end);
18031804
}
18041805

1805-
void ResetHiddenState() {
1806-
const int zero_buffer_size = units_ * batches_;
1807-
std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
1808-
memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
1809-
PopulateTensor(hidden_state_, 0, zero_buffer.get(),
1810-
zero_buffer.get() + zero_buffer_size);
1811-
}
1812-
18131806
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
18141807

18151808
int input_size() { return input_size_; }
@@ -1829,13 +1822,12 @@ class RNNOpModel : public SingleOpModelWithNNAPI {
18291822
int input_size_;
18301823
};
18311824

1832-
TEST(NNAPIDelegate, DISABLED_RnnBlackBoxTest) {
1825+
TEST(NNAPIDelegate, RnnBlackBoxTest) {
18331826
RNNOpModel rnn(2, 16, 8);
18341827
rnn.SetWeights(rnn_weights);
18351828
rnn.SetBias(rnn_bias);
18361829
rnn.SetRecurrentWeights(rnn_recurrent_weights);
18371830

1838-
rnn.ResetHiddenState();
18391831
const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
18401832
(rnn.input_size() * rnn.num_batches());
18411833

0 commit comments

Comments
 (0)