Skip to content

Commit 5eede88

Browse files
andylytensorflower-gardener
authored andcommitted
Add inliner pass to CompileSerializedMlirToXlaHlo.
Graphs generated via model parallelism has sharding annotations but does not propagate them into functions as it is assumed functions are inlined. PiperOrigin-RevId: 346637211 Change-Id: Ifc9a95b6036fd11770e12893b374940a248589c2
1 parent 0cf527a commit 5eede88

File tree

4 files changed

+25
-12
lines changed

4 files changed

+25
-12
lines changed

tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,8 @@ static void RegisterDialects(mlir::DialectRegistry& registry) {
271271
void CreateConvertMlirToXlaHloPipeline(
272272
mlir::OpPassManager& pm, llvm::StringRef device_type,
273273
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
274-
custom_legalization_passes) {
274+
custom_legalization_passes,
275+
bool inline_after_legalization) {
275276
pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
276277
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
277278
// Run shape inference pass before tensorlist decomposition to get buffer
@@ -319,6 +320,9 @@ void CreateConvertMlirToXlaHloPipeline(
319320
pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
320321
/*allow_partial_conversion=*/false, /*legalize_chlo=*/true,
321322
/*tf2xla_fallback_device_type=*/device_type));
323+
324+
if (inline_after_legalization) pm.addPass(mlir::createInlinerPass());
325+
322326
// In order to export to XLA, we must sink constants to control flow regions,
323327
// since XLA uses functional control flow.
324328
pm.addNestedPass<mlir::FuncOp>(
@@ -331,11 +335,13 @@ Status ConvertMLIRToXlaComputation(
331335
bool return_tuple,
332336
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
333337
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
334-
custom_legalization_passes) {
338+
custom_legalization_passes,
339+
bool inline_after_legalization) {
335340
mlir::PassManager tf2xla(module_op.getContext());
336341
applyTensorflowAndCLOptions(tf2xla);
337342
CreateConvertMlirToXlaHloPipeline(tf2xla, device_type,
338-
custom_legalization_passes);
343+
custom_legalization_passes,
344+
inline_after_legalization);
339345

340346
if (VLOG_IS_ON(1)) {
341347
// Print the whole module after each pass which requires disabling
@@ -373,7 +379,8 @@ Status CompileMlirToXlaHlo(
373379
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
374380
XlaCompilationResult* compilation_result,
375381
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
376-
custom_legalization_passes) {
382+
custom_legalization_passes,
383+
bool inline_after_legalization) {
377384
if (VLOG_IS_ON(1))
378385
tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);
379386

@@ -391,7 +398,7 @@ Status CompileMlirToXlaHlo(
391398
TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
392399
module_op, device_type, compilation_result->computation.get(),
393400
use_tuple_args, use_return_tuple, shape_representation_fn,
394-
custom_legalization_passes));
401+
custom_legalization_passes, inline_after_legalization));
395402

396403
// Construct mapping from XlaComputation's arg to input edges of execute
397404
// node.
@@ -434,7 +441,8 @@ Status CompileSerializedMlirToXlaHlo(
434441
return CompileMlirToXlaHlo(
435442
mlir_module.get(), tensor_or_resource_shapes, device_type, use_tuple_args,
436443
/*use_return_tuple=*/true, /*use_resource_updates_for_aliases=*/false,
437-
shape_representation_fn, compilation_result, custom_legalization_passes);
444+
shape_representation_fn, compilation_result, custom_legalization_passes,
445+
/*inline_after_legalization=*/true);
438446
}
439447

440448
// Rewrites the given module with specified args. For each of the constant args,
@@ -535,7 +543,8 @@ Status CompileGraphToXlaHlo(
535543
auto status = CompileMlirToXlaHlo(
536544
module_op, arg_shapes, device_type, use_tuple_args, use_return_tuple,
537545
/*use_resource_updates_for_aliases=*/true, shape_representation_fn,
538-
compilation_result, custom_legalization_passes);
546+
compilation_result, custom_legalization_passes,
547+
/*inline_after_legalization=*/false);
539548
compilation_result->input_mapping = remaining_params;
540549
return status;
541550
}

tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ namespace tensorflow {
3939
void CreateConvertMlirToXlaHloPipeline(
4040
mlir::OpPassManager& pm, llvm::StringRef device_type,
4141
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
42-
custom_legalization_passes);
42+
custom_legalization_passes,
43+
bool inline_after_legalization);
4344

4445
// Lowers MLIR module to XLA HLO inside an XlaComputation. The input module
4546
// should only contain operations in tf dialect. If the input module contains
@@ -73,7 +74,8 @@ Status ConvertMLIRToXlaComputation(
7374
bool return_tuple,
7475
const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr,
7576
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
76-
custom_legalization_passes = {});
77+
custom_legalization_passes = {},
78+
bool inline_after_legalization = false);
7779

7880
// Helper struct representing argument tensor or resource handle shapes.
7981
struct TensorOrResourceShape {
@@ -83,14 +85,16 @@ struct TensorOrResourceShape {
8385

8486
// Compiles a MLIR module into XLA HLO, generates all accompanying metadata and
8587
// stores them in CompilationResult.
88+
// TODO(hinsu): Migrate options to separate struct.
8689
Status CompileMlirToXlaHlo(
8790
mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
8891
llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
8992
bool use_resource_updates_for_aliases,
9093
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
9194
XlaCompilationResult* compilation_result,
9295
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
93-
custom_legalization_passes);
96+
custom_legalization_passes,
97+
bool inline_after_legalization);
9498

9599
// Compiles a serialized MLIR module into XLA HLO, generates all accompanying
96100
// metadata and stores them in CompilationResult.

tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_pass.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace {
2121
void CreateConvertMlirToXlaHloPipelineWithDefaults(mlir::OpPassManager& pm) {
2222
tensorflow::CreateConvertMlirToXlaHloPipeline(
2323
pm, /*device_type=*/"XLA_CPU_JIT",
24-
/*custom_legalization_passes=*/{});
24+
/*custom_legalization_passes=*/{}, /*inline_after_legalization=*/false);
2525
}
2626

2727
mlir::PassPipelineRegistration<> pipeline(

tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ static mlir::LogicalResult MlirTfToHloTextTranslateFunction(
250250
module_op, arg_shapes, /*device_type=*/"XLA_CPU_JIT", emit_use_tuple_arg,
251251
emit_return_tuple, /*use_resource_updates_for_aliases=*/true,
252252
IdentityShapeRepresentationFn(), &compilation_result,
253-
/*custom_legalization_passes=*/{});
253+
/*custom_legalization_passes=*/{}, /*inline_after_legalization=*/false);
254254
if (!compilation_status.ok()) {
255255
LOG(ERROR) << "TF/XLA compilation failed: "
256256
<< compilation_status.ToString();

0 commit comments

Comments
 (0)