Skip to content

Commit 0884659

Browse files
authored
[RELAY][BYOC] Preserve type information in Merge Composite (apache#5640)
Keep the type information when extracting patterns so that it can be used as part of 'check' functions. Change-Id: I16cc70c3d013a794d2ceefb5bec815129c7b8825
1 parent c365c2a commit 0884659

File tree

2 files changed

+51
-3
lines changed

2 files changed

+51
-3
lines changed

src/relay/transforms/merge_composite.cc

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ class MergeCompositeWrapper : public ExprMutator {
4646
if (var_map->find(pattern->name_hint()) == var_map->end()) {
4747
// if we haven't encountered this var yet, make a new free var and associate
4848
// it with the value at 'root'
49-
auto free_var = Var(pattern->name_hint(), Type());
49+
auto free_var = Var(pattern->name_hint(), root->checked_type());
50+
free_var->checked_type_ = root->checked_type();
5051
var_map->Set(pattern->name_hint(), Array<Expr>({free_var, root}));
5152
return std::move(free_var);
5253
} else {
@@ -147,7 +148,9 @@ class MergeCompositeWrapper : public ExprMutator {
147148
new_args.push_back(new_arg);
148149
i++;
149150
}
150-
return Call(root->op, new_args, root->attrs);
151+
Call new_call = Call(root->op, new_args, root->attrs);
152+
new_call->checked_type_ = root->checked_type();
153+
return std::move(new_call);
151154
}
152155

153156
Expr VisitExpr_(const CallNode* cn) {
@@ -163,12 +166,15 @@ class MergeCompositeWrapper : public ExprMutator {
163166
auto new_e = this->Mutate(arg);
164167
new_args.push_back(new_e);
165168
}
166-
return Call(call->op, new_args, call->attrs);
169+
Call new_call = Call(call->op, new_args, call->attrs);
170+
new_call->checked_type_ = call->checked_type();
171+
return std::move(new_call);
167172
}
168173
}
169174

170175
Expr expr = ExprMutator::VisitExpr_(cn);
171176
call = Downcast<Call>(expr);
177+
call->checked_type_ = cn->checked_type();
172178
if (!call->op->IsInstance<OpNode>()) return std::move(call);
173179

174180
// only call patterns are supported
@@ -189,6 +195,7 @@ class MergeCompositeWrapper : public ExprMutator {
189195
args.push_back(args_map[free_var->name_hint()][1]);
190196
}
191197
auto new_call = Call(f, args);
198+
new_call->checked_type_ = call->checked_type();
192199
return std::move(new_call);
193200
}
194201
return std::move(call);

tests/python/relay/test_pass_merge_composite.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,46 @@ def get_net():
803803
assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
804804

805805

806+
def test_type_check():
807+
"""Test that we can query tensor types in the 'check' function."""
808+
def before():
809+
x = relay.var('x', shape=(1, 10, 10, 10))
810+
w = relay.var('w', shape=(10, 10, 3, 3))
811+
b = relay.var('b', shape=(8,))
812+
conv = relay.nn.conv2d(x,
813+
w,
814+
kernel_size=(3, 3),
815+
kernel_layout="OIHW",
816+
data_layout="NHWC")
817+
bias = relay.nn.bias_add(conv, b)
818+
relu = relay.nn.relu(bias)
819+
return relay.Function([x, w, b], relu)
820+
821+
def _check_type_true(extract):
822+
conv = extract.args[0].args[0]
823+
typ = conv.checked_type
824+
return bool(typ.shape[0] == 1)
825+
826+
def _check_type_false(extract):
827+
conv = extract.args[0].args[0]
828+
typ = conv.checked_type
829+
return bool(typ.shape[0] != 1)
830+
831+
pattern_table_true = [
832+
("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_true)
833+
]
834+
pattern_table_false = [
835+
("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_false)
836+
]
837+
838+
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_false))
839+
expected = run_opt_pass(before(), relay.transform.InferType())
840+
assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
841+
842+
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_true))
843+
assert result.body.op.attrs["Composite"] == "conv_bias_relu"
844+
845+
806846
if __name__ == "__main__":
807847
test_simple_merge()
808848
test_branch_merge()
@@ -814,3 +854,4 @@ def get_net():
814854
test_tuple_get_item_merge()
815855
test_pattern_with_check()
816856
test_diamond_not_merge()
857+
test_type_check()

0 commit comments

Comments
 (0)