@@ -54,6 +54,7 @@ limitations under the License.
5454#include " tensorflow/stream_executor/lib/statusor.h"
5555
5656namespace tensorflow {
57+ namespace tfr {
5758
5859const char * const kTFRLibEnv = " TF_MLIR_TFR_LIB_DIR" ;
5960
@@ -66,6 +67,10 @@ StatusOr<std::unique_ptr<TFRDecomposeContext>> TFRDecomposeContext::Get(
6667 string composite_mlir_dir = io::JoinPath (env->GetRunfilesDir (), tfr_lib_dir);
6768 std::vector<string> files;
6869 TF_RETURN_IF_ERROR (env->GetChildren (composite_mlir_dir, &files));
70+ if (files.empty ()) {
71+ return errors::Internal (absl::StrCat (
72+ " Failed to find the decomposition lib from path " , composite_mlir_dir));
73+ }
6974 std::string tfr_raw_text;
7075 for (const auto & file : files) {
7176 string fullpath = io::JoinPath (composite_mlir_dir, file);
@@ -76,15 +81,15 @@ StatusOr<std::unique_ptr<TFRDecomposeContext>> TFRDecomposeContext::Get(
7681 }
7782 }
7883
79- auto ctx = TFRDecomposeContext::Get (tfr_raw_text, mlir_ctx);
84+ auto ctx = TFRDecomposeContext::GetFromText (tfr_raw_text, mlir_ctx);
8085 if (!ctx) {
8186 return errors::Internal (absl::StrCat (
8287 " Failed to load the imported decomposition lib: " , tfr_raw_text));
8388 }
8489 return ctx;
8590}
8691
87- std::unique_ptr<TFRDecomposeContext> TFRDecomposeContext::Get (
92+ std::unique_ptr<TFRDecomposeContext> TFRDecomposeContext::GetFromText (
8893 StringPiece tfr_raw_text, mlir::MLIRContext* mlir_ctx) {
8994 mlir_ctx->allowUnregisteredDialects (/* allow=*/ true );
9095 // Load dialects involved in the conversion
@@ -105,20 +110,22 @@ std::unique_ptr<TFRDecomposeContext> TFRDecomposeContext::Get(
105110 llvm::SourceMgr source_mgr;
106111 source_mgr.AddNewSourceBuffer (std::move (memory_buffer), llvm::SMLoc ());
107112 mlir::OwningModuleRef module = mlir::parseSourceFile (source_mgr, mlir_ctx);
113+ // The MLIRContext owns the module
114+ auto module_op = module .release ();
108115
109116 // Create the context
110- return absl::make_unique<TFRDecomposeContext>(std::move ( module ) );
117+ return absl::make_unique<TFRDecomposeContext>(module_op );
111118}
112119
113- StatusOr<FunctionDef> TFRDecomposeContext::Decompose (const NodeDef& node_def,
114- StringPiece func_name) {
120+ StatusOr<FunctionDef> TFRDecomposeContext::ExpandNode (const NodeDef& node_def,
121+ StringPiece func_name) {
115122 const OpDef* op_def;
116123 TF_RETURN_IF_ERROR (OpRegistry::Global ()->LookUpOpDef (node_def.op (), &op_def));
117124 DataTypeVector input_dtys, output_dtys;
118125 TF_RETURN_IF_ERROR (InputTypesForNode (node_def, *op_def, &input_dtys));
119126 TF_RETURN_IF_ERROR (OutputTypesForNode (node_def, *op_def, &output_dtys));
120127
121- mlir::MLIRContext* context = tfr_module_-> getContext ();
128+ mlir::MLIRContext* context = tfr_module_. getContext ();
122129 llvm::SmallVector<mlir::Type, 4 > input_tys, output_tys;
123130 mlir::Builder builder (context);
124131 for (auto ty : input_dtys) {
@@ -159,15 +166,8 @@ StatusOr<FunctionDef> TFRDecomposeContext::Decompose(const NodeDef& node_def,
159166 mlir::Operation* tf_op = op_builder.createOperation (op_state);
160167 op_builder.create <mlir::ReturnOp>(loc, tf_op->getResults ());
161168
162- if (failed (mlir::verify (module ))) {
163- return errors::Internal (absl::StrCat (
164- " Failed to verify the imported NodeDef: " , node_def.DebugString ()));
165- }
166-
167- // Call the decompose passes by using the external symbol table.
168- if (failed (pm_.run (module ))) {
169- return errors::Internal (" Failed to run the decompose passes." );
170- }
169+ // Run the decompose passes on the module
170+ TF_RETURN_IF_ERROR (DecomposeGraph (module ));
171171
172172 // Export the result as a FunctionDef.
173173 FunctionDef func_def;
@@ -177,43 +177,46 @@ StatusOr<FunctionDef> TFRDecomposeContext::Decompose(const NodeDef& node_def,
177177 return func_def;
178178}
179179
180- Status TFRDecomposeContext::Decompose (mlir::ModuleOp user_module) {
180+ Status TFRDecomposeContext::DecomposeGraph (mlir::ModuleOp user_module) {
181181 // Call the decompose passes by using the external symbol table.
182182 if (failed (pm_.run (user_module))) {
183183 return errors::Internal (" Failed to run the decompose passes." );
184184 }
185185 return Status::OK ();
186186}
187187
188- StatusOr<FunctionDef> TFRDecomposeContext::Expand (const NodeDef& node_def,
189- StringPiece func_name) {
190- mlir::MLIRContext mlir_ctx;
191- mlir_ctx.allowUnregisteredDialects (/* allow=*/ true );
192- TF_ASSIGN_OR_RETURN (auto ctx, Get (&mlir_ctx));
193- return ctx->Decompose (node_def, func_name);
194- }
195-
196- Status TFRDecomposeContext::Destroy () {
197- tfr_module_.release ().erase ();
198- return Status::OK ();
199- }
200-
201188// Constructor of the decompose context.
202- TFRDecomposeContext::TFRDecomposeContext (mlir::OwningModuleRef tfr_module)
203- : tfr_module_(std::move( tfr_module)) , pm_(tfr_module_-> getContext ()) {
189+ TFRDecomposeContext::TFRDecomposeContext (mlir::ModuleOp tfr_module)
190+ : tfr_module_(tfr_module), pm_(tfr_module_. getContext()) {
204191 mlir::OpPassManager& func_pm = pm_.nest <mlir::FuncOp>();
205192
206193 // Prepare the imported graph.
207194 func_pm.addPass (mlir::CreateExecutorDialectToFunctionalConversionPass ());
208195
209196 // Run TFR lowering, inlining and raising to tf.
210- func_pm.addPass (mlir::TFR::CreateDecomposeTFOpsPass (tfr_module_. get () ));
197+ func_pm.addPass (mlir::TFR::CreateDecomposeTFOpsPass (tfr_module_));
211198 func_pm.addPass (mlir::TFR::CreateRaiseToTFOpsPass (
212- tfr_module_. get () , /* materialize_derived_attrs=*/ true ));
199+ tfr_module_, /* materialize_derived_attrs=*/ true ));
213200
214201 // Prepare to be exported.
215202 func_pm.addPass (mlir::CreateFunctionalToExecutorDialectConversionPass ());
216203 pm_.addPass (mlir::CreateBreakUpIslandsPass ());
217204}
218205
206+ void TFRDecomposeContext::Destroy () { tfr_module_.erase (); }
207+
208+ StatusOr<FunctionDef> ExpandNode (const NodeDef& node_def,
209+ StringPiece func_name) {
210+ mlir::MLIRContext mlir_ctx;
211+ TF_ASSIGN_OR_RETURN (auto ctx, TFRDecomposeContext::Get (&mlir_ctx));
212+ return ctx->ExpandNode (node_def, func_name);
213+ }
214+
215+ Status DecomposeGraph (mlir::ModuleOp user_module) {
216+ mlir::MLIRContext* mlir_ctx = user_module.getContext ();
217+ TF_ASSIGN_OR_RETURN (auto ctx, TFRDecomposeContext::Get (mlir_ctx));
218+ return ctx->DecomposeGraph (user_module);
219+ }
220+
221+ } // namespace tfr
219222} // namespace tensorflow
0 commit comments