Commit 04496d3

[BYOC][FIX] Infer types in MergeComposite (apache#5766)
If InferType isn't run between partitioning passes, function calls are inserted that don't have a type. This can result in failures for patterns that want to check types. This change works around the problem by simply running InferType after every partitioning.

Change-Id: Ie0887f0564a41eb0913bfe42a362e8effe9681b9
1 parent f672639 commit 04496d3
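For context, the fix leans on a standard TVM idiom: a bare Relay Function carries no checked types until it is wrapped in an IRModule and passed through the InferType pass. A minimal Python sketch of that round trip (assuming the Relay API of this era; the helper name is illustrative):

import tvm
from tvm import relay

def infer_type(func):
    # Wrap the bare Function in an IRModule, run type inference,
    # and read back the now fully typed "main" function. This is
    # the Python-side mirror of the C++ helper added in this commit.
    mod = tvm.IRModule.from_expr(func)
    mod = relay.transform.InferType()(mod)
    return mod["main"]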

2 files changed: +59 −18


src/relay/transforms/merge_composite.cc
Lines changed: 10 additions & 3 deletions

@@ -36,17 +36,24 @@ namespace tvm {
 namespace relay {
 namespace merge_composite {
 
+Function InferType(const Function& expr) {
+  auto mod = IRModule::FromExpr(expr);
+  mod = transform::InferType()(mod);
+  return Downcast<Function>(mod->Lookup("main"));
+}
+
 Expr MergeComposite(const Function& func, const Array<runtime::String>& pattern_names,
                     const Array<DFPattern>& patterns, const std::vector<PackedFunc>& checks) {
   CHECK_EQ(pattern_names.size(), patterns.size());
-  Expr merged_expr = func->body;
+  Function merged_func = func;
   // merge the patterns one-by-one in order
   for (size_t i = 0; i < patterns.size(); i++) {
     Map<String, ObjectRef> attrs;
     attrs.Set("Composite", pattern_names[i]);
-    merged_expr = PartitionPattern(patterns[i], merged_expr, attrs, checks[i]);
+    merged_func = Downcast<Function>(PartitionPattern(patterns[i], merged_func, attrs, checks[i]));
+    merged_func = InferType(merged_func);
   }
-  return Function(func->params, merged_expr, func->ret_type, func->type_params, func->attrs);
+  return std::move(merged_func);
 }
 
 }  // namespace merge_composite
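Why the per-iteration InferType matters: a composite check function typically dereferences checked_type on the extracted subgraph, and that field is populated only by type inference. A hedged sketch of such a check, modeled on the _check_type_* helpers in the test below (the traversal path and the name are illustrative):

def check_conv_input_batch(extract):
    # 'extract' is the body of a partitioned composite function. Walk
    # down to the input of conv2d; if an earlier partitioning pass
    # inserted an untyped composite call there, reading checked_type
    # would fail without the InferType run added above.
    conv_input = extract.args[0].args[0]
    typ = conv_input.checked_type
    return bool(typ.shape[0] == 1)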

tests/python/relay/test_pass_merge_composite.py
Lines changed: 49 additions & 15 deletions

@@ -916,31 +916,63 @@ def before():
         x = relay.var('x', shape=(1, 10, 10, 10))
         w = relay.var('w', shape=(10, 10, 3, 3))
         b = relay.var('b', shape=(8,))
-        conv = relay.nn.conv2d(x,
+        add = relay.op.add(x, x)
+        relu = relay.nn.relu(add)
+        conv = relay.nn.conv2d(relu,
                                w,
                                kernel_size=(3, 3),
                                kernel_layout="OIHW",
                                data_layout="NHWC")
         bias = relay.nn.bias_add(conv, b)
-        relu = relay.nn.relu(bias)
-        return relay.Function([x, w, b], relu)
+        relu2 = relay.nn.relu(bias)
+        return run_opt_pass(relay.Function([x, w, b], relu2), relay.transform.InferType())
 
-    def expected():
-        x = relay.var('x')
-        w = relay.var('w')
-        b = relay.var('b')
-        conv = relay.nn.conv2d(x, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC")
+    def expected_false():
+        x = relay.var('x', shape=(1, 10, 10, 10))
+        w = relay.var('w', shape=(10, 10, 3, 3))
+        b = relay.var('b', shape=(8, ))
+
+        x0 = relay.var('x')
+        y0 = relay.var('y')
+
+        add = relay.op.add(y0, y0)
+        relu = relay.nn.relu(add)
+        func = relay.Function([x0, y0], relu)
+        func = func.with_attr("PartitionedFromPattern", "add_nn.relu_")
+        func = func.with_attr("Composite", "add_relu")
+        call = relay.Call(func, [x, x])
+
+        conv = relay.nn.conv2d(call, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC")
         bias = relay.nn.bias_add(conv, b)
-        relu = relay.nn.relu(bias)
-        func = relay.Function([x, w, b], relu)
-        func = func.with_attr("Composite", "conv_bias_relu")
-        func = func.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_")
+        relu2 = relay.nn.relu(bias)
+        return relay.Function([x, w, b], relu2)
 
+    def expected_true():
         x = relay.var('x', shape=(1, 10, 10, 10))
         w = relay.var('w', shape=(10, 10, 3, 3))
         b = relay.var('b', shape=(8, ))
-        return relay.Function([x, w, b], func(x, w, b))
 
+        x0 = relay.var('x')
+        y0 = relay.var('y')
+
+        add = relay.op.add(y0, y0)
+        relu = relay.nn.relu(add)
+        func = relay.Function([x0, y0], relu)
+        func = func.with_attr("PartitionedFromPattern", "add_nn.relu_")
+        func = func.with_attr("Composite", "add_relu")
+        call = relay.Call(func, [x, x])
+
+        x2 = relay.var('x')
+        w1 = relay.var('w')
+        b1 = relay.var('b')
+        conv = relay.nn.conv2d(x2, w1, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC")
+        bias = relay.nn.bias_add(conv, b1)
+        relu2 = relay.nn.relu(bias)
+        func = relay.Function([x2, w1, b1], relu2)
+        func = func.with_attr("Composite", "conv_bias_relu")
+        func = func.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_")
+        call = relay.Call(func, [call, w, b])
+        return relay.Function([x, w, b], call)
 
     def _check_type_true(extract):
         conv = extract.args[0].args[0]

@@ -953,14 +985,16 @@ def _check_type_false(extract):
         return bool(typ.shape[0] != 1)
 
     pattern_table_false = [
+        ("add_relu", make_add_relu_pattern()),
         ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_false)
     ]
-    check_result(pattern_table_false, before(), before())
+    check_result(pattern_table_false, before(), expected_false())
 
     pattern_table_true = [
+        ("add_relu", make_add_relu_pattern()),
         ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_true)
     ]
-    check_result(pattern_table_true, before(), expected())
+    check_result(pattern_table_true, before(), expected_true())
 
 
 if __name__ == "__main__":
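The reworked before() now routes the function through run_opt_pass so the input to the pass under test is already typed. That helper is defined elsewhere in TVM's test suite; a sketch of its usual shape (an assumption here, not the verbatim definition):

import tvm

def run_opt_pass(expr, passes):
    # Normalize to a list of passes, wrap the expression in a module,
    # run the sequence, and return the transformed "main" function.
    passes = passes if isinstance(passes, list) else [passes]
    mod = tvm.IRModule.from_expr(expr)
    seq = tvm.transform.Sequential(passes)
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)
    return mod["main"]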
