Skip to content

Commit 54975a3

Browse files
authored
[RELAY][FIX] Fix hang in MergeCompilerRegions (apache#5227)
For certain network topologies, MCR could hang. This patch fixes that case. Change-Id: I3edd8a8a6b452b2b838b777720adea22a3b995b4
1 parent b796c13 commit 54975a3

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

src/relay/analysis/annotated_region_set.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#include <tvm/relay/expr.h>
2323
#include <tvm/ir/error.h>
2424

25-
#include <algorithm>
2625
#include <unordered_map>
2726
#include <vector>
2827

@@ -58,8 +57,8 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
5857
std::vector<Expr> ins_to_remove;
5958
for (const auto& input : dest->ins) {
6059
auto call = Downcast<Call>(input);
61-
auto it = std::find(src->outs.begin(), src->outs.end(), call->args[0]);
62-
if (it != src->outs.end()) {
60+
auto it = src->nodes.find(call->args[0]);
61+
if (it != src->nodes.end()) {
6362
dest->outs.remove(*it);
6463
ins_to_remove.push_back(input);
6564
}

src/relay/transforms/merge_compiler_regions.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ class RegionMerger : public ExprVisitor {
263263
void VisitExpr_(const CallNode* call) final {
264264
if (call->op == compiler_end_op) {
265265
auto region = regions_->GetRegion(GetRef<Call>(call));
266+
if (merged_regions_.find(region->GetID()) != merged_regions_.end()) return;
266267
// set the region target
267268
auto compiler_attrs = call->attrs.as<CompilerAttrs>();
268269
region_targets_[region->GetID()] = compiler_attrs->compiler;
@@ -281,13 +282,13 @@ class RegionMerger : public ExprVisitor {
281282
}
282283
}
283284
// get the mergeable regions now all the parents have been visited
284-
std::vector<AnnotatedRegion> mergeable_regions;
285+
std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> mergeable_regions;
285286
for (const auto& arg : region->GetInputs()) {
286287
auto begin = Downcast<Call>(arg);
287288
CHECK_EQ(begin->op, compiler_begin_op);
288289
auto parent_region = regions_->GetRegion(begin->args[0]);
289290
if (!parent_region.defined()) continue;
290-
mergeable_regions.push_back(parent_region);
291+
mergeable_regions.insert(parent_region);
291292
}
292293
auto& region_restrictions = region_restrictions_[region->GetID()];
293294
for (const auto& parent_region : mergeable_regions) {

0 commit comments

Comments
 (0)