Skip to content

Commit 0212138

Browse files
authored
[RELAY] Add MergeCompilerRegions pass (apache#5134)
* [RELAY] Add MergeCompilerRegions pass This pass is part of the flow to support creating compiler regions with multiple outputs. It should be called after AnnotateTarget and will merge together regions that share the same target to create larger compiler regions that can be off-loaded to external codegens. This pass implements an algorithm to ensure that during the merging, no data dependency issues are created. See the tests for an example of this case. Co-authored-by: Ramana Radhakrishnan <[email protected]> Co-authored-by: Manupa Karunaratne <[email protected]> Change-Id: Ibd99083564608d888482f57c5080109f3eefec88 * [RELAY] Annotate compiler_ends on each edge This alters the behaviour of the AnnotateTarget pass to enforce the property that all compiler annotations exist along a single data flow edge. Specifically, this means they should have exactly one parent and one child. Change-Id: I0e74803a77767f4f377d17755a13a74a30909797 * Fix comment * Rebase *Node::make * Moved block outside for loop * Code style * Update make API * Remove comment * Remove redundant 'else's * Make one line * Fix comment * RefWrite * Fix merge ordering * Add the RFC example as a test * [FIX] Fixed merging behaviour in AnnotateRegionSet Deleting items from a list while iterating it seems to result in undefined behaviour which sometimes segfaults. This makes sure all the item deletion happens separately. * Added checks * Move comment * Update comments
1 parent 3326031 commit 0212138

File tree

6 files changed

+726
-27
lines changed

6 files changed

+726
-27
lines changed

python/tvm/relay/transform/transform.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,17 @@ def MergeComposite(pattern_table):
397397
return _ffi_api.MergeComposite(pattern_names, patterns)
398398

399399

400+
def MergeCompilerRegions():
401+
"""Merge together compiler regions.
402+
403+
Returns
404+
-------
405+
ret : tvm.relay.Pass
406+
The registered pass that merges compiler regions.
407+
"""
408+
return _ffi_api.MergeCompilerRegions()
409+
410+
400411
def RewriteAnnotatedOps(fallback_device):
401412
"""Rewrite the annotated program where annotation operators, e.g.
402413
`on_device`, mark which device an expression should be scheduled to.

src/relay/analysis/annotated_region_set.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,18 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
5555
}
5656
// if any of the outputs of src are inputs of dest, they become internal nodes
5757
// so remove them from outs
58+
std::vector<Expr> ins_to_remove;
5859
for (const auto& input : dest->ins) {
5960
auto call = Downcast<Call>(input);
6061
auto it = std::find(src->outs.begin(), src->outs.end(), call->args[0]);
6162
if (it != src->outs.end()) {
6263
dest->outs.remove(*it);
63-
dest->ins.remove(input);
64+
ins_to_remove.push_back(input);
6465
}
6566
}
67+
for (const auto& input : ins_to_remove) {
68+
dest->ins.remove(input);
69+
}
6670
regions_.erase(src);
6771
}
6872

src/relay/transforms/annotate_target.cc

Lines changed: 124 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -38,46 +38,144 @@ class AnnotateTargetWrapper : public ExprMutator {
3838
public:
3939
explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {}
4040

41+
Expr Annotate(const Expr& expr) {
42+
return InsertEnd(Mutate(expr));
43+
}
44+
45+
bool IsSupported(const Expr& expr) {
46+
if (expr->IsInstance<CallNode>()) {
47+
Call call = Downcast<Call>(expr);
48+
auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
49+
Op op = Downcast<Op>(call->op);
50+
CHECK(op.defined());
51+
if (fannotate.count(op)) {
52+
return fannotate[op](call->attrs, call->args);
53+
}
54+
}
55+
return false;
56+
}
57+
58+
Expr InsertEnd(const Expr& arg) {
59+
if (IsSupported(arg)) {
60+
const auto *end_op =
61+
runtime::Registry::Get("relay.op.annotation._make.compiler_end");
62+
CHECK(end_op);
63+
Expr end = (*end_op)(arg, target_);
64+
return end;
65+
}
66+
return arg;
67+
}
68+
4169
Expr VisitExpr_(const CallNode* cn) {
4270
// TODO(@zhiics, @comaniac) Handle composite functions.
4371
auto new_e = ExprMutator::VisitExpr_(cn);
4472

4573
Call call = Downcast<Call>(new_e);
46-
auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
47-
Op op = Downcast<Op>(call->op);
48-
CHECK(op.defined());
49-
50-
if (fannotate.count(op)) {
51-
bool external = fannotate[op](call->attrs, call->args);
52-
if (external) {
53-
tvm::Array<tvm::relay::Expr> compiler_begins;
54-
for (const auto& it : call->args) {
55-
const auto* begin_op =
56-
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
57-
CHECK(begin_op);
58-
Expr begin = (*begin_op)(it, target_);
59-
compiler_begins.push_back(begin);
60-
}
61-
Expr update_call = Call(call->op, compiler_begins, call->attrs);
62-
const auto* end_op =
63-
runtime::Registry::Get("relay.op.annotation._make.compiler_end");
64-
CHECK(end_op);
65-
Expr end = (*end_op)(update_call, target_);
66-
return end;
74+
75+
// add end annotations if the args are supported
76+
Array<Expr> compiler_ends;
77+
for (const auto& it : call->args) {
78+
compiler_ends.push_back(InsertEnd(it));
79+
}
80+
call = Call(call->op, compiler_ends, call->attrs);
81+
82+
// add begin annotations if the call node is supported
83+
if (IsSupported(call)) {
84+
tvm::Array<tvm::relay::Expr> compiler_begins;
85+
const auto* begin_op =
86+
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
87+
for (const auto& it : call->args) {
88+
CHECK(begin_op);
89+
Expr begin = (*begin_op)(it, target_);
90+
compiler_begins.push_back(begin);
6791
}
68-
} else {
69-
LOG(WARNING) << op->name << " in " << target_
70-
<< " is not registered. It will be executed on CPU.";
92+
call = Call(call->op, compiler_begins, call->attrs);
7193
}
72-
return new_e;
94+
95+
return std::move(call);
96+
}
97+
98+
Expr VisitExpr_(const TupleNode* op) {
99+
auto new_e = ExprMutator::VisitExpr_(op);
100+
101+
auto tup = Downcast<Tuple>(new_e);
102+
Array<Expr> new_fields;
103+
for (auto field : tup->fields) {
104+
new_fields.push_back(InsertEnd(field));
105+
}
106+
return Tuple(new_fields);
107+
}
108+
109+
Expr VisitExpr_(const TupleGetItemNode* op) {
110+
auto new_e = ExprMutator::VisitExpr_(op);
111+
112+
auto get = Downcast<TupleGetItem>(new_e);
113+
return TupleGetItem(
114+
InsertEnd(get->tuple),
115+
get->index);
116+
}
117+
118+
Expr VisitExpr_(const FunctionNode* op) {
119+
auto new_e = ExprMutator::VisitExpr_(op);
120+
121+
auto func = Downcast<Function>(new_e);
122+
return Function(
123+
func->params,
124+
InsertEnd(func->body),
125+
func->ret_type,
126+
func->type_params,
127+
func->attrs);
128+
}
129+
130+
Expr VisitExpr_(const LetNode* op) {
131+
auto new_e = ExprMutator::VisitExpr_(op);
132+
133+
auto let = Downcast<Let>(new_e);
134+
return Let(
135+
let->var,
136+
InsertEnd(let->value),
137+
InsertEnd(let->body));
138+
}
139+
140+
Expr VisitExpr_(const IfNode* op) {
141+
auto new_e = ExprMutator::VisitExpr_(op);
142+
143+
auto iff = Downcast<If>(new_e);
144+
return If(
145+
InsertEnd(iff->cond),
146+
InsertEnd(iff->true_branch),
147+
InsertEnd(iff->false_branch));
148+
}
149+
150+
Expr VisitExpr_(const RefCreateNode* op) {
151+
auto new_e = ExprMutator::VisitExpr_(op);
152+
153+
auto create = Downcast<RefCreate>(new_e);
154+
return RefCreate(InsertEnd(create->value));
155+
}
156+
157+
Expr VisitExpr_(const RefReadNode* op) {
158+
auto new_e = ExprMutator::VisitExpr_(op);
159+
160+
auto read = Downcast<RefRead>(new_e);
161+
return RefRead(InsertEnd(read->ref));
162+
}
163+
164+
Expr VisitExpr_(const RefWriteNode* op) {
165+
auto new_e = ExprMutator::VisitExpr_(op);
166+
167+
auto write = Downcast<RefWrite>(new_e);
168+
return RefWrite(
169+
InsertEnd(write->ref),
170+
InsertEnd(write->value));
73171
}
74172

75173
private:
76174
std::string target_;
77175
};
78176

79177
Expr AnnotateTarget(const Expr& expr, const std::string& target) {
80-
return AnnotateTargetWrapper(target).Mutate(expr);
178+
return AnnotateTargetWrapper(target).Annotate(expr);
81179
}
82180

83181
} // namespace annotate_target

0 commit comments

Comments
 (0)