@@ -1773,15 +1773,16 @@ class RNNOpModel : public SingleOpModelWithNNAPI {
1773
1773
weights_ = AddInput (weights);
1774
1774
recurrent_weights_ = AddInput (recurrent_weights);
1775
1775
bias_ = AddInput (TensorType_FLOAT32);
1776
- hidden_state_ = AddOutput (TensorType_FLOAT32);
1776
+ hidden_state_ = AddInput (TensorType_FLOAT32, true );
1777
1777
output_ = AddOutput (TensorType_FLOAT32);
1778
1778
SetBuiltinOp (
1779
1779
BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
1780
1780
CreateRNNOptions (builder_, ActivationFunctionType_RELU).Union ());
1781
- BuildInterpreter ({{batches_, input_size_},
1782
- {units_, input_size_},
1783
- {units_, units_},
1784
- {units_}});
1781
+ BuildInterpreter ({{batches_, input_size_}, // input tensor
1782
+ {units_, input_size_}, // weights tensor
1783
+ {units_, units_}, // recurrent weights tensor
1784
+ {units_}, // bias tensor
1785
+ {batches_, units_}}); // hidden state tensor
1785
1786
}
1786
1787
1787
1788
void SetBias (std::initializer_list<float > f) { PopulateTensor (bias_, f); }
@@ -1802,14 +1803,6 @@ class RNNOpModel : public SingleOpModelWithNNAPI {
1802
1803
PopulateTensor (input_, offset, begin, end);
1803
1804
}
1804
1805
1805
- void ResetHiddenState () {
1806
- const int zero_buffer_size = units_ * batches_;
1807
- std::unique_ptr<float []> zero_buffer (new float [zero_buffer_size]);
1808
- memset (zero_buffer.get (), 0 , zero_buffer_size * sizeof (float ));
1809
- PopulateTensor (hidden_state_, 0 , zero_buffer.get (),
1810
- zero_buffer.get () + zero_buffer_size);
1811
- }
1812
-
1813
1806
std::vector<float > GetOutput () { return ExtractVector<float >(output_); }
1814
1807
1815
1808
int input_size () { return input_size_; }
@@ -1829,13 +1822,12 @@ class RNNOpModel : public SingleOpModelWithNNAPI {
1829
1822
int input_size_;
1830
1823
};
1831
1824
1832
- TEST (NNAPIDelegate, DISABLED_RnnBlackBoxTest ) {
1825
+ TEST (NNAPIDelegate, RnnBlackBoxTest ) {
1833
1826
RNNOpModel rnn (2 , 16 , 8 );
1834
1827
rnn.SetWeights (rnn_weights);
1835
1828
rnn.SetBias (rnn_bias);
1836
1829
rnn.SetRecurrentWeights (rnn_recurrent_weights);
1837
1830
1838
- rnn.ResetHiddenState ();
1839
1831
const int input_sequence_size = sizeof (rnn_input) / sizeof (float ) /
1840
1832
(rnn.input_size () * rnn.num_batches ());
1841
1833
0 commit comments