@@ -271,7 +271,8 @@ static void RegisterDialects(mlir::DialectRegistry& registry) {
271271void CreateConvertMlirToXlaHloPipeline (
272272 mlir::OpPassManager& pm, llvm::StringRef device_type,
273273 llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
274- custom_legalization_passes) {
274+ custom_legalization_passes,
275+ bool inline_after_legalization) {
275276 pm.addPass (mlir::TF::CreateTFFunctionalControlFlowToRegions ());
276277 pm.addNestedPass <mlir::FuncOp>(mlir::createCanonicalizerPass ());
277278 // Run shape inference pass before tensorlist decomposition to get buffer
@@ -319,6 +320,9 @@ void CreateConvertMlirToXlaHloPipeline(
319320 pm.addNestedPass <mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass (
320321 /* allow_partial_conversion=*/ false , /* legalize_chlo=*/ true ,
321322 /* tf2xla_fallback_device_type=*/ device_type));
323+
324+ if (inline_after_legalization) pm.addPass (mlir::createInlinerPass ());
325+
322326 // In order to export to XLA, we must sink constants to control flow regions,
323327 // since XLA uses functional control flow.
324328 pm.addNestedPass <mlir::FuncOp>(
@@ -331,11 +335,13 @@ Status ConvertMLIRToXlaComputation(
331335 bool return_tuple,
332336 const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
333337 llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
334- custom_legalization_passes) {
338+ custom_legalization_passes,
339+ bool inline_after_legalization) {
335340 mlir::PassManager tf2xla (module_op.getContext ());
336341 applyTensorflowAndCLOptions (tf2xla);
337342 CreateConvertMlirToXlaHloPipeline (tf2xla, device_type,
338- custom_legalization_passes);
343+ custom_legalization_passes,
344+ inline_after_legalization);
339345
340346 if (VLOG_IS_ON (1 )) {
341347 // Print the whole module after each pass which requires disabling
@@ -373,7 +379,8 @@ Status CompileMlirToXlaHlo(
373379 XlaHelpers::ShapeRepresentationFn shape_representation_fn,
374380 XlaCompilationResult* compilation_result,
375381 llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
376- custom_legalization_passes) {
382+ custom_legalization_passes,
383+ bool inline_after_legalization) {
377384 if (VLOG_IS_ON (1 ))
378385 tensorflow::DumpMlirOpToFile (" mlir_compile_before" , module_op);
379386
@@ -391,7 +398,7 @@ Status CompileMlirToXlaHlo(
391398 TF_RETURN_IF_ERROR (ConvertMLIRToXlaComputation (
392399 module_op, device_type, compilation_result->computation .get (),
393400 use_tuple_args, use_return_tuple, shape_representation_fn,
394- custom_legalization_passes));
401+ custom_legalization_passes, inline_after_legalization ));
395402
396403 // Construct mapping from XlaComputation's arg to input edges of execute
397404 // node.
@@ -434,7 +441,8 @@ Status CompileSerializedMlirToXlaHlo(
434441 return CompileMlirToXlaHlo (
435442 mlir_module.get (), tensor_or_resource_shapes, device_type, use_tuple_args,
436443 /* use_return_tuple=*/ true , /* use_resource_updates_for_aliases=*/ false ,
437- shape_representation_fn, compilation_result, custom_legalization_passes);
444+ shape_representation_fn, compilation_result, custom_legalization_passes,
445+ /* inline_after_legalization=*/ true );
438446}
439447
440448// Rewrites the given module with specified args. For each of the constant args,
@@ -535,7 +543,8 @@ Status CompileGraphToXlaHlo(
535543 auto status = CompileMlirToXlaHlo (
536544 module_op, arg_shapes, device_type, use_tuple_args, use_return_tuple,
537545 /* use_resource_updates_for_aliases=*/ true , shape_representation_fn,
538- compilation_result, custom_legalization_passes);
546+ compilation_result, custom_legalization_passes,
547+ /* inline_after_legalization=*/ false );
539548 compilation_result->input_mapping = remaining_params;
540549 return status;
541550}
0 commit comments