Skip to content

Commit a05f8bd

Browse files
liufengdbtensorflower-gardener
authored andcommitted
Add the op expansion graph optimization pass to tensorflow
PiperOrigin-RevId: 336770266 Change-Id: Iadc1917c9be25bc9010129af551814ae72160347
1 parent e5f3013 commit a05f8bd

File tree

9 files changed

+267
-31
lines changed

9 files changed

+267
-31
lines changed

tensorflow/compiler/mlir/tfr/BUILD

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ package_group(
1818
includes = ["//third_party/mlir:subpackages"],
1919
packages = [
2020
"//learning/brain/experimental/mlir/tfr/...",
21-
"//tensorflow/compiler/mlir/...",
21+
"//tensorflow/c/...",
22+
"//tensorflow/compiler/...",
2223
],
2324
)
2425

@@ -179,7 +180,6 @@ cc_library(
179180
"//tensorflow/compiler/mlir/tfr:passes",
180181
"//tensorflow/core:framework",
181182
"//tensorflow/core:graph",
182-
"//tensorflow/core:lib",
183183
"//tensorflow/core:protos_all_cc",
184184
"//tensorflow/core/common_runtime:optimization_registry",
185185
"//tensorflow/stream_executor/lib",
@@ -201,7 +201,6 @@ tf_cc_test(
201201
":tfr_decompose_ctx",
202202
"//tensorflow/compiler/xla:test",
203203
"//tensorflow/core:framework",
204-
"//tensorflow/core:lib",
205204
"//tensorflow/core:protos_all_cc",
206205
"//tensorflow/core:test",
207206
"//tensorflow/core:test_main",
@@ -216,7 +215,6 @@ cc_library(
216215
name = "graph_decompose_pass",
217216
srcs = ["integration/graph_decompose_pass.cc"],
218217
hdrs = ["integration/graph_decompose_pass.h"],
219-
data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"],
220218
deps = [
221219
":tfr_decompose_ctx",
222220
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
@@ -228,6 +226,22 @@ cc_library(
228226
alwayslink = 1,
229227
)
230228

229+
tf_py_test(
230+
name = "graph_decompose_test",
231+
size = "small",
232+
srcs = ["integration/graph_decompose_test.py"],
233+
data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"],
234+
python_version = "PY3",
235+
tags = [
236+
"no_oss",
237+
"notap",
238+
],
239+
deps = [
240+
"//tensorflow/compiler/mlir/tfr/resources:composite_ops",
241+
"//tensorflow/python/eager:def_function",
242+
],
243+
)
244+
231245
tf_python_pybind_extension(
232246
name = "tfr_wrapper",
233247
srcs = ["python/tfr_wrapper.cc"],

tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,28 @@ limitations under the License.
1515
#include "tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h"
1616

1717
#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h"
18-
#include "tensorflow/core/lib/core/status.h"
1918
#include "tensorflow/core/platform/env.h"
2019
#include "tensorflow/core/platform/path.h"
2120
#include "tensorflow/core/util/env_var.h"
2221
#include "tensorflow/stream_executor/lib/statusor.h"
2322

2423
namespace tensorflow {
2524

25+
constexpr const char* const kTFRLibEnv = "TF_MLIR_TFR_LIB_DIR";
26+
27+
bool GraphDecomposePass::IsEnabled(const ConfigProto& config_proto) const {
28+
const char* tfr_lib_env_val = getenv(string(kTFRLibEnv).c_str());
29+
return tfr_lib_env_val != nullptr;
30+
}
31+
2632
Status GraphDecomposePass::Run(const ConfigProto& config_proto,
2733
mlir::ModuleOp module) {
34+
if (!IsEnabled(config_proto)) {
35+
VLOG(1) << "Skipping Graph Decomposition Pass, decompositin library was "
36+
"not found";
37+
return Status::OK();
38+
}
39+
VLOG(1) << "Run Graph Decomposition Passes";
2840
TF_ASSIGN_OR_RETURN(ctx_, LoadDecompositionLib(module.getContext()));
2941
TF_RETURN_IF_ERROR(ctx_->Decompose(module));
3042
return ctx_->Destroy();
@@ -35,8 +47,7 @@ GraphDecomposePass::LoadDecompositionLib(mlir::MLIRContext* mlir_ctx) {
3547
Env* env = Env::Default();
3648
std::string tfr_lib_dir;
3749
TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
38-
"TF_MLIR_TFR_LIB_DIR", "tensorflow/compiler/mlir/tfr/resources",
39-
&tfr_lib_dir));
50+
kTFRLibEnv, "tensorflow/compiler/mlir/tfr/resources", &tfr_lib_dir));
4051
string composite_mlir_dir = io::JoinPath(env->GetRunfilesDir(), tfr_lib_dir);
4152
std::vector<string> files;
4253
TF_RETURN_IF_ERROR(env->GetChildren(composite_mlir_dir, &files));
@@ -59,7 +70,7 @@ GraphDecomposePass::LoadDecompositionLib(mlir::MLIRContext* mlir_ctx) {
5970
}
6071

6172
namespace {
62-
constexpr int kMlirGraphDecomposePassPriority = 1;
73+
constexpr int kMlirGraphDecomposePassPriority = -1;
6374

6475
static mlir_pass_registration::MlirOptimizationPassRegistration
6576
register_mlir_graph_decompose_pass(kMlirGraphDecomposePassPriority,

tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ limitations under the License.
1616
#define TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_GRAPH_DECOMPOSE_PASS_H_
1717

1818
#include "mlir/IR/MLIRContext.h" // from @llvm-project
19-
#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h"
2019
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
21-
#include "tensorflow/core/lib/core/status.h"
20+
#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h"
21+
#include "tensorflow/stream_executor/lib/statusor.h"
2222

2323
namespace tensorflow {
2424

@@ -30,10 +30,9 @@ class GraphDecomposePass : public MlirOptimizationPass {
3030
public:
3131
llvm::StringRef name() const override { return "tfr"; }
3232

33-
bool IsEnabled(const ConfigProto& config_proto) const override {
34-
// TODO(fengliuai): make a new flag in config_proto.experimental()
35-
return true;
36-
}
33+
// Whether to run this pass. If this is enabled, the GraphDef will be imported
34+
// to MLIR even no tf composition file is found.
35+
bool IsEnabled(const ConfigProto& config_proto) const override;
3736

3837
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
3938
// API integrated with the Tensorflow runtime.
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for third_party.tensorflow.compiler.mlir.tfr.integrattion.graph_decompose."""
15+
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import os
21+
22+
from tensorflow.compiler.mlir.tfr.resources import gen_composite_ops
23+
from tensorflow.python.eager import def_function
24+
from tensorflow.python.framework import constant_op
25+
from tensorflow.python.framework import load_library
26+
from tensorflow.python.framework import ops
27+
from tensorflow.python.ops import nn_ops
28+
from tensorflow.python.platform import test
29+
30+
_lib_dir = os.path.dirname(gen_composite_ops.__file__)
31+
_lib_name = os.path.basename(gen_composite_ops.__file__)[4:].replace(
32+
'.py', '.so')
33+
load_library.load_op_library(os.path.join(_lib_dir, _lib_name))
34+
35+
36+
class GraphDecomposeTest(test.TestCase):
37+
38+
def setUp(self):
39+
os.environ['TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/resources'
40+
super(GraphDecomposeTest, self).setUp()
41+
42+
def tearDown(self):
43+
del os.environ['TF_MLIR_TFR_LIB_DIR']
44+
super(GraphDecomposeTest, self).tearDown()
45+
46+
def testAddN(self):
47+
add = def_function.function(gen_composite_ops.my_add_n)
48+
t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
49+
t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
50+
t3 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
51+
sq1 = add([t1])
52+
sq2 = add([t1, t2])
53+
sq3 = add([t1, t2, t3])
54+
self.assertAllEqual(sq1.numpy().reshape(-1), [1, 2, 3, 4])
55+
self.assertAllEqual(sq2.numpy().reshape(-1), [2, 4, 6, 8])
56+
self.assertAllEqual(sq3.numpy().reshape(-1), [3, 6, 9, 12])
57+
58+
def testBiasedDense(self):
59+
biased_dense = def_function.function(gen_composite_ops.my_biased_dense)
60+
t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
61+
t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
62+
t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]])
63+
sq = biased_dense(t1, t2, t3)
64+
self.assertAllEqual(sq.numpy().reshape(-1), [-3, 0, 5, 12])
65+
66+
def testBiasedDenseRelu(self):
67+
biased_dense = def_function.function(gen_composite_ops.my_biased_dense)
68+
t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
69+
t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
70+
t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]])
71+
sq = biased_dense(t1, t2, t3, act='relu')
72+
self.assertAllEqual(sq.numpy().reshape(-1), [0, 0, 5, 12])
73+
74+
def testWithKnownKernel(self):
75+
76+
@def_function.function
77+
def biasd_dense_elu(x, y, z):
78+
dot = gen_composite_ops.my_biased_dense(x, y, z)
79+
return nn_ops.elu(dot) # with known kernel, should not expand.
80+
81+
t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
82+
t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
83+
t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]])
84+
sq = biasd_dense_elu(t1, t2, t3)
85+
self.assertAllClose(sq.numpy().reshape(-1), [-0.950213, 0, 5, 12])
86+
87+
88+
if __name__ == '__main__':
89+
ops.enable_eager_execution()
90+
test.main()

tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,9 @@ limitations under the License.
3636
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
3737
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
3838
#include "tensorflow/compiler/mlir/tfr/passes/passes.h"
39-
#include "tensorflow/core/framework/graph.pb.h"
40-
#include "tensorflow/core/framework/node_def.pb.h"
41-
#include "tensorflow/core/framework/node_def_util.h"
42-
#include "tensorflow/core/framework/versions.pb.h"
4339
#include "tensorflow/core/graph/graph.h"
4440
#include "tensorflow/core/graph/node_builder.h"
4541
#include "tensorflow/core/graph/tensor_id.h"
46-
#include "tensorflow/core/lib/core/errors.h"
47-
#include "tensorflow/core/lib/strings/str_util.h"
48-
#include "tensorflow/core/platform/errors.h"
49-
#include "tensorflow/core/platform/protobuf.h"
50-
#include "tensorflow/core/platform/types.h"
5142
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
5243
#include "tensorflow/core/protobuf/struct.pb.h"
5344
#include "tensorflow/stream_executor/lib/statusor.h"

tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,8 @@ limitations under the License.
2020
#include "mlir/IR/Module.h" // from @llvm-project
2121
#include "mlir/Pass/PassManager.h" // from @llvm-project
2222
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
23-
#include "tensorflow/core/common_runtime/optimization_registry.h"
2423
#include "tensorflow/core/framework/function.h"
2524
#include "tensorflow/core/framework/graph.pb.h"
26-
#include "tensorflow/core/framework/node_def.pb.h"
27-
#include "tensorflow/core/framework/node_def_builder.h"
28-
#include "tensorflow/core/framework/types.pb.h"
29-
#include "tensorflow/core/platform/stringpiece.h"
3025
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
3126
#include "tensorflow/stream_executor/lib/statusor.h"
3227

tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ limitations under the License.
2929
#include "tensorflow/core/framework/node_def_builder.h"
3030
#include "tensorflow/core/framework/types.h"
3131
#include "tensorflow/core/framework/types.pb.h"
32-
#include "tensorflow/core/lib/core/errors.h"
3332
#include "tensorflow/core/lib/core/status_test_util.h"
3433
#include "tensorflow/core/platform/test.h"
3534
#include "tensorflow/stream_executor/lib/statusor.h"

tensorflow/compiler/mlir/tfr/resources/BUILD

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,31 @@ cc_library(
3333
alwayslink = 1,
3434
)
3535

36+
tf_custom_op_library(
37+
name = "composite_ops.so",
38+
srcs = [
39+
"composite_ops.cc",
40+
],
41+
)
42+
3643
tf_gen_op_wrapper_py(
37-
name = "composite_ops",
38-
out = "composite_ops.py",
44+
name = "gen_composite_ops",
45+
out = "gen_composite_ops.py",
3946
deps = [
4047
":composite_ops_cc",
4148
],
4249
)
4350

51+
tf_custom_op_py_library(
52+
name = "composite_ops",
53+
dso = [":composite_ops.so"],
54+
kernels = [":composite_ops_cc"],
55+
visibility = ["//visibility:public"],
56+
deps = [
57+
":gen_composite_ops",
58+
],
59+
)
60+
4461
cc_library(
4562
name = "test_ops_cc",
4663
srcs = ["test_ops.cc"],

0 commit comments

Comments
 (0)