Skip to content

Commit 5952da0

Browse files
liufengdbtensorflower-gardener
authored andcommitted
Clean up the TFR Decompose context implementation
This cl did the code refactoring and resolved the tsan error for the tests. PiperOrigin-RevId: 337314570 Change-Id: I72454b187ed0e6d4a4cb5a8381ace73d55bbad69
1 parent 3e3d85e commit 5952da0

File tree

10 files changed

+97
-87
lines changed

10 files changed

+97
-87
lines changed

tensorflow/compiler/mlir/tfr/BUILD

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,9 @@ tf_py_test(
232232
data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"],
233233
python_version = "PY3",
234234
tags = [
235+
"no_pip",
235236
"no_windows", # TODO(b/170752141)
236237
"nomac", # TODO(b/170752141)
237-
"notsan", # TODO(b/170752141)
238238
],
239239
deps = [
240240
"//tensorflow/compiler/mlir/tfr/resources:composite_ops",
@@ -263,9 +263,9 @@ tf_py_test(
263263
data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"],
264264
python_version = "PY3",
265265
tags = [
266+
"no_pip",
266267
"no_windows", # TODO(b/170752141)
267268
"nomac", # TODO(b/170752141)
268-
"notsan", # TODO(b/170862433)
269269
],
270270
deps = [
271271
"//tensorflow/compiler/mlir/tfr/resources:composite_ops",

tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@ limitations under the License.
1414
==============================================================================*/
1515
#include "tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h"
1616

17+
#include "mlir/IR/MLIRContext.h" // from @llvm-project
1718
#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h"
1819
#include "tensorflow/stream_executor/lib/statusor.h"
1920

2021
namespace tensorflow {
22+
namespace tfr {
2123

2224
bool GraphDecomposePass::IsEnabled(const ConfigProto& config_proto) const {
2325
const char* tfr_lib_env_val = getenv(std::string(kTFRLibEnv).c_str());
@@ -27,14 +29,18 @@ bool GraphDecomposePass::IsEnabled(const ConfigProto& config_proto) const {
2729
Status GraphDecomposePass::Run(const ConfigProto& config_proto,
2830
mlir::ModuleOp module) {
2931
if (!IsEnabled(config_proto)) {
30-
VLOG(1) << "Skipping Graph Decomposition Pass, decompositin library was "
31-
"not found";
32+
LOG_FIRST_N(INFO, 1) << "Skipping Graph Decomposition Pass, decompositin "
33+
"library was not found";
3234
return Status::OK();
3335
}
34-
VLOG(1) << "Run Graph Decomposition Passes";
35-
TF_ASSIGN_OR_RETURN(ctx_, TFRDecomposeContext::Get(module.getContext()));
36-
TF_RETURN_IF_ERROR(ctx_->Decompose(module));
37-
return ctx_->Destroy();
36+
37+
LOG_FIRST_N(INFO, 1) << "Run Graph Decomposition Passes";
38+
39+
TF_RETURN_IF_ERROR(DecomposeGraph(module));
40+
41+
LOG_FIRST_N(INFO, 1) << "Finish Graph Decomposition Passes";
42+
43+
return Status::OK();
3844
}
3945

4046
namespace {
@@ -45,4 +51,5 @@ static mlir_pass_registration::MlirOptimizationPassRegistration
4551
std::make_unique<GraphDecomposePass>());
4652
} // namespace
4753

54+
} // namespace tfr
4855
} // namespace tensorflow

tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
#include "tensorflow/stream_executor/lib/statusor.h"
2222

2323
namespace tensorflow {
24+
namespace tfr {
2425

2526
// An optimization pass that decompose the composite ops in a module according
2627
// to the decomposition library. Currently the decomposition library is loaded
@@ -37,11 +38,9 @@ class GraphDecomposePass : public MlirOptimizationPass {
3738
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
3839
// API integrated with the Tensorflow runtime.
3940
Status Run(const ConfigProto& config_proto, mlir::ModuleOp module) override;
40-
41-
private:
42-
std::unique_ptr<TFRDecomposeContext> ctx_;
4341
};
4442

43+
} // namespace tfr
4544
} // namespace tensorflow
4645

4746
#endif // TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_GRAPH_DECOMPOSE_PASS_H_

tensorflow/compiler/mlir/tfr/integration/graph_decompose_test.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,6 @@
3535

3636
class GraphDecomposeTest(test.TestCase):
3737

38-
def setUp(self):
39-
os.environ['TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/resources'
40-
super(GraphDecomposeTest, self).setUp()
41-
42-
def tearDown(self):
43-
del os.environ['TF_MLIR_TFR_LIB_DIR']
44-
super(GraphDecomposeTest, self).tearDown()
45-
4638
def testAddN(self):
4739
add = def_function.function(gen_composite_ops.my_add_n)
4840
t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
@@ -86,5 +78,6 @@ def biasd_dense_elu(x, y, z):
8678

8779

8880
if __name__ == '__main__':
81+
os.environ['TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/resources'
8982
ops.enable_eager_execution()
9083
test.main()

tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@ limitations under the License.
2121
#include "tensorflow/stream_executor/lib/statusor.h"
2222

2323
namespace tensorflow {
24+
namespace tfr {
2425

2526
Status CompositeOpExpansion::Run(EagerOperation* orig_op,
2627
std::unique_ptr<EagerOperation>* out_op) {
2728
if (!IsEnabled()) return Status::OK();
2829
if (orig_op->Device() != kVariantDeviceNull) return Status::OK();
2930

30-
VLOG(1) << "Run Node Expansion Passes";
31+
LOG_FIRST_N(INFO, 1) << "Run Node Expansion Passes";
3132

3233
// Get the FunctionDef and insert that into the context
3334
const NodeDef& ndef = orig_op->MutableAttrs()->BuildNodeDef();
@@ -40,7 +41,7 @@ Status CompositeOpExpansion::Run(EagerOperation* orig_op,
4041
std::string fname =
4142
absl::StrCat("_expanded_", ndef.name(), "_", std::to_string(x));
4243
if (!ctx.FindFunctionByName(fname)) {
43-
TF_ASSIGN_OR_RETURN(auto func, TFRDecomposeContext::Expand(ndef, fname));
44+
TF_ASSIGN_OR_RETURN(auto func, ExpandNode(ndef, fname));
4445
TF_RETURN_IF_ERROR(ctx.AddFunctionDef(func));
4546
}
4647

@@ -55,11 +56,14 @@ Status CompositeOpExpansion::Run(EagerOperation* orig_op,
5556
new_op->MutableAttrs()->CopyAttributes(orig_op->Attrs());
5657
out_op->reset(new_op);
5758

58-
VLOG(1) << "Rewrite the op to call function: " << fname;
59+
LOG_FIRST_N(INFO, 1)
60+
<< "Finish Node Expansion Passes. Rewrite the op to call function: "
61+
<< fname;
5962

6063
return Status::OK();
6164
}
6265

6366
REGISTER_REWRITE(EagerOpRewriteRegistry::POST_PLACEMENT, CompositeOpExpansion);
6467

68+
} // namespace tfr
6569
} // namespace tensorflow

tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020
#include "tensorflow/stream_executor/lib/statusor.h"
2121

2222
namespace tensorflow {
23+
namespace tfr {
2324

2425
// An optimization pass that decompose the composite ops in a module according
2526
// to the decomposition library. Currently the decomposition library is loaded
@@ -42,6 +43,7 @@ class CompositeOpExpansion : public EagerOpRewrite {
4243
}
4344
};
4445

46+
} // namespace tfr
4547
} // namespace tensorflow
4648

4749
#endif // TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_NODE_EXPANSION_PASS_H_

tensorflow/compiler/mlir/tfr/integration/node_expansion_test.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,6 @@
3434

3535
class NodeExpansionTest(test.TestCase):
3636

37-
def setUp(self):
38-
os.environ['TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/resources'
39-
super(NodeExpansionTest, self).setUp()
40-
41-
def tearDown(self):
42-
del os.environ['TF_MLIR_TFR_LIB_DIR']
43-
super(NodeExpansionTest, self).tearDown()
44-
4537
def testAddN(self):
4638
t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
4739
t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
@@ -81,5 +73,6 @@ def biasd_dense_elu(x, y, z):
8173

8274

8375
if __name__ == '__main__':
76+
os.environ['TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/resources'
8477
ops.enable_eager_execution()
8578
test.main()

tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ limitations under the License.
5454
#include "tensorflow/stream_executor/lib/statusor.h"
5555

5656
namespace tensorflow {
57+
namespace tfr {
5758

5859
const char* const kTFRLibEnv = "TF_MLIR_TFR_LIB_DIR";
5960

@@ -66,6 +67,10 @@ StatusOr<std::unique_ptr<TFRDecomposeContext>> TFRDecomposeContext::Get(
6667
string composite_mlir_dir = io::JoinPath(env->GetRunfilesDir(), tfr_lib_dir);
6768
std::vector<string> files;
6869
TF_RETURN_IF_ERROR(env->GetChildren(composite_mlir_dir, &files));
70+
if (files.empty()) {
71+
return errors::Internal(absl::StrCat(
72+
"Failed to find the decomposition lib from path ", composite_mlir_dir));
73+
}
6974
std::string tfr_raw_text;
7075
for (const auto& file : files) {
7176
string fullpath = io::JoinPath(composite_mlir_dir, file);
@@ -76,15 +81,15 @@ StatusOr<std::unique_ptr<TFRDecomposeContext>> TFRDecomposeContext::Get(
7681
}
7782
}
7883

79-
auto ctx = TFRDecomposeContext::Get(tfr_raw_text, mlir_ctx);
84+
auto ctx = TFRDecomposeContext::GetFromText(tfr_raw_text, mlir_ctx);
8085
if (!ctx) {
8186
return errors::Internal(absl::StrCat(
8287
"Failed to load the imported decomposition lib: ", tfr_raw_text));
8388
}
8489
return ctx;
8590
}
8691

87-
std::unique_ptr<TFRDecomposeContext> TFRDecomposeContext::Get(
92+
std::unique_ptr<TFRDecomposeContext> TFRDecomposeContext::GetFromText(
8893
StringPiece tfr_raw_text, mlir::MLIRContext* mlir_ctx) {
8994
mlir_ctx->allowUnregisteredDialects(/*allow=*/true);
9095
// Load dialects involved in the conversion
@@ -105,20 +110,22 @@ std::unique_ptr<TFRDecomposeContext> TFRDecomposeContext::Get(
105110
llvm::SourceMgr source_mgr;
106111
source_mgr.AddNewSourceBuffer(std::move(memory_buffer), llvm::SMLoc());
107112
mlir::OwningModuleRef module = mlir::parseSourceFile(source_mgr, mlir_ctx);
113+
// The MLIRContext owns the module
114+
auto module_op = module.release();
108115

109116
// Create the context
110-
return absl::make_unique<TFRDecomposeContext>(std::move(module));
117+
return absl::make_unique<TFRDecomposeContext>(module_op);
111118
}
112119

113-
StatusOr<FunctionDef> TFRDecomposeContext::Decompose(const NodeDef& node_def,
114-
StringPiece func_name) {
120+
StatusOr<FunctionDef> TFRDecomposeContext::ExpandNode(const NodeDef& node_def,
121+
StringPiece func_name) {
115122
const OpDef* op_def;
116123
TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def));
117124
DataTypeVector input_dtys, output_dtys;
118125
TF_RETURN_IF_ERROR(InputTypesForNode(node_def, *op_def, &input_dtys));
119126
TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, *op_def, &output_dtys));
120127

121-
mlir::MLIRContext* context = tfr_module_->getContext();
128+
mlir::MLIRContext* context = tfr_module_.getContext();
122129
llvm::SmallVector<mlir::Type, 4> input_tys, output_tys;
123130
mlir::Builder builder(context);
124131
for (auto ty : input_dtys) {
@@ -159,15 +166,8 @@ StatusOr<FunctionDef> TFRDecomposeContext::Decompose(const NodeDef& node_def,
159166
mlir::Operation* tf_op = op_builder.createOperation(op_state);
160167
op_builder.create<mlir::ReturnOp>(loc, tf_op->getResults());
161168

162-
if (failed(mlir::verify(module))) {
163-
return errors::Internal(absl::StrCat(
164-
"Failed to verify the imported NodeDef: ", node_def.DebugString()));
165-
}
166-
167-
// Call the decompose passes by using the external symbol table.
168-
if (failed(pm_.run(module))) {
169-
return errors::Internal("Failed to run the decompose passes.");
170-
}
169+
// Run the decompose passes on the module
170+
TF_RETURN_IF_ERROR(DecomposeGraph(module));
171171

172172
// Export the result as a FunctionDef.
173173
FunctionDef func_def;
@@ -177,43 +177,46 @@ StatusOr<FunctionDef> TFRDecomposeContext::Decompose(const NodeDef& node_def,
177177
return func_def;
178178
}
179179

180-
Status TFRDecomposeContext::Decompose(mlir::ModuleOp user_module) {
180+
Status TFRDecomposeContext::DecomposeGraph(mlir::ModuleOp user_module) {
181181
// Call the decompose passes by using the external symbol table.
182182
if (failed(pm_.run(user_module))) {
183183
return errors::Internal("Failed to run the decompose passes.");
184184
}
185185
return Status::OK();
186186
}
187187

188-
StatusOr<FunctionDef> TFRDecomposeContext::Expand(const NodeDef& node_def,
189-
StringPiece func_name) {
190-
mlir::MLIRContext mlir_ctx;
191-
mlir_ctx.allowUnregisteredDialects(/*allow=*/true);
192-
TF_ASSIGN_OR_RETURN(auto ctx, Get(&mlir_ctx));
193-
return ctx->Decompose(node_def, func_name);
194-
}
195-
196-
Status TFRDecomposeContext::Destroy() {
197-
tfr_module_.release().erase();
198-
return Status::OK();
199-
}
200-
201188
// Constructor of the decompose context.
202-
TFRDecomposeContext::TFRDecomposeContext(mlir::OwningModuleRef tfr_module)
203-
: tfr_module_(std::move(tfr_module)), pm_(tfr_module_->getContext()) {
189+
TFRDecomposeContext::TFRDecomposeContext(mlir::ModuleOp tfr_module)
190+
: tfr_module_(tfr_module), pm_(tfr_module_.getContext()) {
204191
mlir::OpPassManager& func_pm = pm_.nest<mlir::FuncOp>();
205192

206193
// Prepare the imported graph.
207194
func_pm.addPass(mlir::CreateExecutorDialectToFunctionalConversionPass());
208195

209196
// Run TFR lowering, inlining and raising to tf.
210-
func_pm.addPass(mlir::TFR::CreateDecomposeTFOpsPass(tfr_module_.get()));
197+
func_pm.addPass(mlir::TFR::CreateDecomposeTFOpsPass(tfr_module_));
211198
func_pm.addPass(mlir::TFR::CreateRaiseToTFOpsPass(
212-
tfr_module_.get(), /*materialize_derived_attrs=*/true));
199+
tfr_module_, /*materialize_derived_attrs=*/true));
213200

214201
// Prepare to be exported.
215202
func_pm.addPass(mlir::CreateFunctionalToExecutorDialectConversionPass());
216203
pm_.addPass(mlir::CreateBreakUpIslandsPass());
217204
}
218205

206+
void TFRDecomposeContext::Destroy() { tfr_module_.erase(); }
207+
208+
StatusOr<FunctionDef> ExpandNode(const NodeDef& node_def,
209+
StringPiece func_name) {
210+
mlir::MLIRContext mlir_ctx;
211+
TF_ASSIGN_OR_RETURN(auto ctx, TFRDecomposeContext::Get(&mlir_ctx));
212+
return ctx->ExpandNode(node_def, func_name);
213+
}
214+
215+
Status DecomposeGraph(mlir::ModuleOp user_module) {
216+
mlir::MLIRContext* mlir_ctx = user_module.getContext();
217+
TF_ASSIGN_OR_RETURN(auto ctx, TFRDecomposeContext::Get(mlir_ctx));
218+
return ctx->DecomposeGraph(user_module);
219+
}
220+
221+
} // namespace tfr
219222
} // namespace tensorflow

0 commit comments

Comments
 (0)