Skip to content

Commit cae7bb0

Browse files
liufengdbtensorflower-gardener
authored andcommitted
Add end2end examples for defining and running composite ops
PiperOrigin-RevId: 337520243 Change-Id: I32265c18b3736990765a9f6c416b00c03c93d335
1 parent 6976e16 commit cae7bb0

File tree

14 files changed

+1080
-54
lines changed

14 files changed

+1080
-54
lines changed

tensorflow/compiler/mlir/tfr/BUILD

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_py_test")
22
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
3+
load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries")
34
load(
45
"//third_party/mlir:tblgen.bzl",
56
"gentbl",
@@ -347,3 +348,17 @@ py_test(
347348
"//tensorflow/python:client_testlib",
348349
],
349350
)
351+
352+
py_library(
353+
name = "test_utils",
354+
srcs = ["python/test_utils.py"],
355+
srcs_version = "PY2AND3",
356+
deps = [
357+
"//tensorflow/python:client_testlib",
358+
],
359+
)
360+
361+
gen_op_libraries(
362+
name = "one_op",
363+
src = "define_op_template.py",
364+
)
Lines changed: 39 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
"""BUILD extension for TF composition project."""
22

3-
load("//tensorflow:tensorflow.bzl", "py_binary", "tf_gen_op_wrapper_py")
4-
load("//tensorflow:tensorflow.google.bzl", "pytype_library")
3+
load("//tensorflow:tensorflow.bzl", "py_binary", "tf_custom_op_library", "tf_gen_op_wrapper_py")
4+
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
55

66
def gen_op_libraries(
77
name,
88
src,
9-
deps,
9+
deps = [],
1010
tags = [],
1111
test = False):
1212
"""gen_op_libraries() generates all cc and py libraries for composite op source.
@@ -21,36 +21,33 @@ def gen_op_libraries(
2121
if not src.endswith(".py") or name == src[:-3]:
2222
fail("'src' %s conflicts with op Python wrapper. Rename it to be different from 'name'." % src)
2323

24-
gen_op_lib_exec = src[:-3]
24+
gen_op_lib_exec = src[:-3] # Strip off the .py
2525
py_binary(
2626
name = gen_op_lib_exec,
2727
srcs = [src],
2828
srcs_version = "PY2AND3",
2929
python_version = "PY3",
3030
deps = [
31-
"//tensorflow/python:platform",
31+
"//tensorflow/compiler/mlir/tfr:op_reg_gen",
32+
"//tensorflow/compiler/mlir/tfr:tfr_gen",
33+
"//tensorflow/compiler/mlir/tfr:composite",
3234
] + deps,
3335
)
3436

35-
register_op = "register_" + name
37+
registed_op = "registed_" + name
3638
native.genrule(
37-
name = register_op,
39+
name = registed_op,
3840
srcs = [],
3941
outs = [name + ".inc.cc"],
4042
cmd = "$(location %s) --output=$@ --gen_register_op=true" % gen_op_lib_exec,
4143
exec_tools = [":" + gen_op_lib_exec],
42-
local = 1,
4344
tags = tags,
4445
)
4546

4647
native.cc_library(
4748
name = name + "_cc",
4849
testonly = test,
49-
srcs = [":" + register_op],
50-
copts = [
51-
"-Wno-unused-result",
52-
"-Wno-unused-variable",
53-
],
50+
srcs = [":" + registed_op],
5451
deps = [
5552
"//tensorflow/core:framework",
5653
"//tensorflow/core:lib",
@@ -59,62 +56,61 @@ def gen_op_libraries(
5956
alwayslink = 1,
6057
)
6158

59+
tf_custom_op_library(
60+
name = name + ".so",
61+
srcs = [":" + registed_op],
62+
)
63+
6264
tf_gen_op_wrapper_py(
63-
name = name,
64-
out = name + ".py",
65+
name = "gen_" + name,
66+
out = "gen_" + name + ".py",
6567
deps = [
6668
":%s_cc" % name,
6769
],
6870
)
6971

70-
pytype_library(
71-
name = name + "_grads",
72-
srcs = [
73-
src,
74-
],
72+
tf_custom_op_py_library(
73+
name = name,
74+
dso = [":%s.so" % name],
75+
kernels = [":%s_cc" % name],
7576
srcs_version = "PY2AND3",
7677
deps = [
77-
"//third_party/py/numpy",
78-
"//third_party/py/tensorflow",
79-
] + deps,
80-
)
81-
82-
pytype_library(
83-
name = name + "_lib",
84-
srcs = [
85-
name + ".py",
78+
":gen_%s" % name,
8679
],
87-
srcs_version = "PY2AND3",
88-
deps = [
89-
":%s" % name,
90-
":%s_cc" % name,
91-
":%s_grads" % name,
92-
"//third_party/py/numpy",
93-
"//third_party/py/tensorflow",
94-
] + deps,
9580
)
9681

9782
# Link the register op and rebuild the binary
98-
gen_tfr_lib_exec = gen_op_lib_exec + "_registered"
83+
gen_tfr_lib_exec = gen_op_lib_exec + "_with_op_library"
9984
py_binary(
10085
name = gen_tfr_lib_exec,
10186
main = src,
10287
srcs = [src],
10388
srcs_version = "PY2AND3",
10489
python_version = "PY3",
10590
deps = [
106-
"//tensorflow/python:platform",
107-
":%s" % name + "_cc",
91+
"//tensorflow/compiler/mlir/tfr:op_reg_gen",
92+
"//tensorflow/compiler/mlir/tfr:tfr_gen",
93+
"//tensorflow/compiler/mlir/tfr:composite",
94+
":%s" % name,
10895
] + deps,
10996
)
11097

111-
op_tfr = "composite_" + name
11298
native.genrule(
113-
name = op_tfr,
99+
name = name + "_mlir",
114100
srcs = [],
115101
outs = [name + ".mlir"],
116102
cmd = "$(location %s) --output=$@ --gen_register_op=false" % gen_tfr_lib_exec,
117103
exec_tools = [":" + gen_tfr_lib_exec],
118-
local = 1,
119104
tags = tags,
120105
)
106+
107+
native.py_library(
108+
name = name + "_py",
109+
srcs = [src],
110+
srcs_version = "PY2AND3",
111+
deps = [
112+
"//tensorflow/compiler/mlir/tfr:op_reg_gen",
113+
"//tensorflow/compiler/mlir/tfr:tfr_gen",
114+
"//tensorflow/compiler/mlir/tfr:composite",
115+
] + deps,
116+
)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""A template to define composite ops."""
15+
16+
# pylint: disable=g-direct-tensorflow-import
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import os
23+
import sys
24+
25+
from tensorflow.compiler.mlir.tfr.python.composite import Composite
26+
from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op
27+
from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module
28+
from tensorflow.python.platform import app
29+
from tensorflow.python.platform import flags
30+
31+
FLAGS = flags.FLAGS
32+
33+
flags.DEFINE_string(
34+
'output', None,
35+
'Path to write the genereated register op file and MLIR file.')
36+
37+
flags.DEFINE_bool('gen_register_op', True,
38+
'Generate register op cc file or tfr mlir file.')
39+
40+
flags.mark_flag_as_required('output')
41+
42+
43+
@Composite('TestRandom', derived_attrs=['T: numbertype'], outputs=['o: T'])
44+
def _composite_random_op():
45+
pass
46+
47+
48+
def main(_):
49+
if FLAGS.gen_register_op:
50+
assert FLAGS.output.endswith('.cc')
51+
generated_code = gen_register_op(sys.modules[__name__], '_composite_')
52+
else:
53+
assert FLAGS.output.endswith('.mlir')
54+
generated_code = tfr_gen_from_module(sys.modules[__name__], '_composite_')
55+
56+
dirname = os.path.dirname(FLAGS.output)
57+
if not os.path.exists(dirname):
58+
os.makedirs(dirname)
59+
with open(FLAGS.output, 'w') as f:
60+
f.write(generated_code)
61+
62+
63+
if __name__ == '__main__':
64+
app.run(main=main)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
load("//tensorflow:tensorflow.bzl", "py_binary")
2+
load("//tensorflow:tensorflow.bzl", "tf_py_test")
3+
load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries")
4+
5+
package(
6+
default_visibility = [
7+
":friends",
8+
],
9+
licenses = ["notice"], # Apache 2.0
10+
)
11+
12+
package_group(
13+
name = "friends",
14+
includes = ["//third_party/mlir:subpackages"],
15+
packages = [
16+
"//tensorflow/compiler/mlir/tfr/...",
17+
],
18+
)
19+
20+
gen_op_libraries(
21+
name = "mnist_ops",
22+
src = "ops_defs.py",
23+
deps = [
24+
"//tensorflow:tensorflow_py",
25+
],
26+
)
27+
28+
tf_py_test(
29+
name = "mnist_ops_test",
30+
size = "small",
31+
srcs = ["mnist_ops_test.py"],
32+
data = [":mnist_ops_mlir"],
33+
python_version = "PY3",
34+
srcs_version = "PY2AND3",
35+
tags = [
36+
"no_pip",
37+
"no_windows", # TODO(b/170752141)
38+
"nomac", # TODO(b/170752141)
39+
],
40+
deps = [
41+
":mnist_ops",
42+
":mnist_ops_py",
43+
"//tensorflow:tensorflow_py",
44+
"//tensorflow/compiler/mlir/tfr:test_utils",
45+
],
46+
)
47+
48+
py_binary(
49+
name = "mnist_train",
50+
srcs = ["mnist_train.py"],
51+
data = [":mnist_ops_mlir"],
52+
python_version = "PY3",
53+
deps = [
54+
":mnist_ops",
55+
":mnist_ops_py",
56+
"//tensorflow:tensorflow_py",
57+
"@absl_py//absl:app",
58+
"@absl_py//absl/flags",
59+
],
60+
)

0 commit comments

Comments
 (0)