@@ -125,6 +125,117 @@ class PrepareCompositeFunctionsPass
125125 void runOnOperation () override ;
126126};
127127
128+ LogicalResult CheckFusableLayerNormalizedLstmCellSimple (FuncOp lstm_func) {
129+ for (int i = 0 ; i < 5 ; ++i) {
130+ auto input = lstm_func.getArgument (i);
131+ auto input_type = input.getType ().dyn_cast_or_null <RankedTensorType>();
132+ if (!input_type) {
133+ lstm_func.emitWarning (
134+ " we cannot fuse this lstm func because all the inputs have not "
135+ " ranked tensor type." );
136+ return failure ();
137+ }
138+ }
139+
140+ return success ();
141+ }
142+
143+ LogicalResult CheckFusableLstmCellSimple (FuncOp lstm_func) {
144+ for (int i = 0 ; i < 4 ; ++i) {
145+ auto input = lstm_func.getArgument (i);
146+ auto input_type = input.getType ().dyn_cast_or_null <RankedTensorType>();
147+ if (!input_type) {
148+ lstm_func.emitWarning (
149+ " we cannot fuse this lstm func because all the inputs have not "
150+ " ranked tensor type." );
151+ return failure ();
152+ }
153+ }
154+
155+ return success ();
156+ }
157+
158+ LogicalResult CheckOutputConsumer (
159+ Operation* call_op, int expected_num_outputs,
160+ llvm::DenseSet<int > expected_consumer_indices) {
161+ const int num_results = call_op->getNumResults ();
162+ if (num_results != expected_num_outputs) return failure ();
163+
164+ for (int i = 0 ; i < expected_num_outputs; ++i) {
165+ auto it = expected_consumer_indices.find (i);
166+ if (it == expected_consumer_indices.end ()) {
167+ // Unexpected consumer.
168+ if (!call_op->getResult (i).use_empty ()) return failure ();
169+ }
170+ }
171+ return success ();
172+ }
173+
174+ LogicalResult CheckFusableKerasLstm (FuncOp lstm_func, ModuleOp module ) {
175+ for (auto func : module .getOps <FuncOp>()) {
176+ if (func == lstm_func) continue ;
177+ auto result = func.walk ([&](CallOpInterface op) {
178+ if (dyn_cast<FuncOp>(op.resolveCallable ()) == lstm_func) {
179+ // Keras LSTM have 5 outputs.
180+ // We should make sure only the first or the second output are
181+ // consumed.
182+ if (failed (CheckOutputConsumer (op.getOperation (), 5 , {0 , 1 })))
183+ return WalkResult::interrupt ();
184+ }
185+ return WalkResult::advance ();
186+ });
187+
188+ if (result.wasInterrupted ()) return failure ();
189+ }
190+
191+ // We should know the batch size in advance for the lstm fusion.
192+ // A good indicator of batch size is both cell state and input state (indices
193+ // 1 & 2) have fixed shape and other input tenors should have ranked tensor
194+ // types.
195+ for (int i = 0 ; i < 6 ; ++i) {
196+ auto input = lstm_func.getArgument (i);
197+ auto input_type = input.getType ().dyn_cast_or_null <RankedTensorType>();
198+ if (!input_type) {
199+ lstm_func.emitWarning (
200+ " we cannot fuse this lstm func because all the inputs have not "
201+ " ranked tensor type." );
202+ return failure ();
203+ }
204+ switch (i) {
205+ case 1 : // output_init_state
206+ case 2 : // hidden_init_state
207+ if (!input_type.hasStaticShape ()) {
208+ lstm_func.emitWarning (
209+ " we cannot fuse this lstm func because the batch size is not "
210+ " fixed, please consider setting fixed batch size like "
211+ " https://github.com/tensorflow/tensorflow/blob/master/tensorflow/"
212+ " lite/examples/experimental_new_converter/"
213+ " Keras_LSTM_fusion_Codelab.ipynb" );
214+ return failure ();
215+ }
216+ break ;
217+ case 3 : // wiehgt
218+ case 4 : // recurrent_kernel
219+ case 5 : // bias
220+ if (!input_type.hasStaticShape ()) {
221+ lstm_func.emitWarning (
222+ " we cannot fuse this lstm func because the weight & bias are not "
223+ " fixed, please consider setting fixed batch size like "
224+ " https://github.com/tensorflow/tensorflow/blob/master/tensorflow/"
225+ " lite/examples/experimental_new_converter/"
226+ " Keras_LSTM_fusion_Codelab.ipynb" );
227+ return failure ();
228+ }
229+ break ;
230+ default :
231+ // No op.
232+ break ;
233+ }
234+ }
235+
236+ return success ();
237+ }
238+
128239void PrepareCompositeFunctionsPass::ConvertTFImplements (FuncOp func,
129240 StringAttr attr) {
130241 if (attr.getValue () == " embedding_matmul" ) {
@@ -138,13 +249,19 @@ void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
138249 }
139250 convert_embedded_lookup.RewriteFunc ();
140251 } else if (attr.getValue () == mlir::TFL::kLstmCellSimple ) {
252+ // Check if the lstm cell simple can be fused, if not, we just don't do
253+ // anything.
254+ if (failed (CheckFusableLstmCellSimple (func))) return ;
141255 func.eraseBody ();
142256 func.addEntryBlock ();
143257 ConvertLSTMCellSimpleToFusedLSTM convert_lstm_cell_simple (func);
144258 if (failed (convert_lstm_cell_simple.RewriteFunc ())) {
145259 return signalPassFailure ();
146260 }
147261 } else if (attr.getValue () == mlir::TFL::kLayerNormalizedLstmCellSimple ) {
262+ // Check if the layer normalized lstm cell simple can be fused, if not, we
263+ // just don't do anything.
264+ if (failed (CheckFusableLayerNormalizedLstmCellSimple (func))) return ;
148265 func.eraseBody ();
149266 func.addEntryBlock ();
150267 ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM
@@ -181,59 +298,6 @@ void PrepareCompositeFunctionsPass::ConvertTFImplementsWithAttributes(
181298 }
182299}
183300
184- LogicalResult CheckOutputConsumer (
185- Operation* call_op, int expected_num_outputs,
186- llvm::DenseSet<int > expected_consumer_indices) {
187- const int num_results = call_op->getNumResults ();
188- if (num_results != expected_num_outputs) return failure ();
189-
190- for (int i = 0 ; i < expected_num_outputs; ++i) {
191- auto it = expected_consumer_indices.find (i);
192- if (it == expected_consumer_indices.end ()) {
193- // Unexpected consumer.
194- if (!call_op->getResult (i).use_empty ()) return failure ();
195- }
196- }
197- return success ();
198- }
199-
200- LogicalResult CheckFusableKerasLstm (FuncOp lstm_func, ModuleOp module ) {
201- for (auto func : module .getOps <FuncOp>()) {
202- if (func == lstm_func) continue ;
203- auto result = func.walk ([&](CallOpInterface op) {
204- if (dyn_cast<FuncOp>(op.resolveCallable ()) == lstm_func) {
205- // Keras LSTM have 5 outputs.
206- // We should make sure only the first or the second output are
207- // consumed.
208- if (failed (CheckOutputConsumer (op.getOperation (), 5 , {0 , 1 })))
209- return WalkResult::interrupt ();
210- }
211- return WalkResult::advance ();
212- });
213-
214- if (result.wasInterrupted ()) return failure ();
215- }
216-
217- // We should know the batch size in advance for the lstm fusion.
218- // A good indicator of batch size is both cell state and input state have
219- // fixed shape. (indices 1 & 2).
220- for (int i = 1 ; i < 3 ; ++i) {
221- auto input = lstm_func.getArgument (i);
222- auto input_type = input.getType ().dyn_cast_or_null <RankedTensorType>();
223- if (!input_type || !input_type.hasStaticShape ()) {
224- lstm_func.emitWarning (
225- " we cannot fuse this lstm func because the batch size is not fixed, "
226- " please consider setting fixed batch size like "
227- " https://github.com/tensorflow/tensorflow/blob/master/tensorflow/"
228- " lite/examples/experimental_new_converter/"
229- " Keras_LSTM_fusion_Codelab.ipynb" );
230- return failure ();
231- }
232- }
233-
234- return success ();
235- }
236-
237301void PrepareCompositeFunctionsPass::ConvertTFAPIImplements (FuncOp func,
238302 StringAttr attr,
239303 ModuleOp module ) {
0 commit comments