@@ -916,31 +916,63 @@ def before():
916916 x = relay .var ('x' , shape = (1 , 10 , 10 , 10 ))
917917 w = relay .var ('w' , shape = (10 , 10 , 3 , 3 ))
918918 b = relay .var ('b' , shape = (8 ,))
919- conv = relay .nn .conv2d (x ,
919+ add = relay .op .add (x , x )
920+ relu = relay .nn .relu (add )
921+ conv = relay .nn .conv2d (relu ,
920922 w ,
921923 kernel_size = (3 , 3 ),
922924 kernel_layout = "OIHW" ,
923925 data_layout = "NHWC" )
924926 bias = relay .nn .bias_add (conv , b )
925- relu = relay .nn .relu (bias )
926- return relay .Function ([x , w , b ], relu )
927+ relu2 = relay .nn .relu (bias )
928+ return run_opt_pass ( relay .Function ([x , w , b ], relu2 ), relay . transform . InferType () )
927929
928- def expected ():
929- x = relay .var ('x' )
930- w = relay .var ('w' )
931- b = relay .var ('b' )
932- conv = relay .nn .conv2d (x , w , kernel_size = (3 , 3 ), kernel_layout = "OIHW" , data_layout = "NHWC" )
930+ def expected_false ():
931+ x = relay .var ('x' , shape = (1 , 10 , 10 , 10 ))
932+ w = relay .var ('w' , shape = (10 , 10 , 3 , 3 ))
933+ b = relay .var ('b' , shape = (8 , ))
934+
935+ x0 = relay .var ('x' )
936+ y0 = relay .var ('y' )
937+
938+ add = relay .op .add (y0 , y0 )
939+ relu = relay .nn .relu (add )
940+ func = relay .Function ([x0 , y0 ], relu )
941+ func = func .with_attr ("PartitionedFromPattern" , "add_nn.relu_" )
942+ func = func .with_attr ("Composite" , "add_relu" )
943+ call = relay .Call (func , [x , x ])
944+
945+ conv = relay .nn .conv2d (call , w , kernel_size = (3 , 3 ), kernel_layout = "OIHW" , data_layout = "NHWC" )
933946 bias = relay .nn .bias_add (conv , b )
934- relu = relay .nn .relu (bias )
935- func = relay .Function ([x , w , b ], relu )
936- func = func .with_attr ("Composite" , "conv_bias_relu" )
937- func = func .with_attr ("PartitionedFromPattern" , "nn.conv2d_nn.bias_add_nn.relu_" )
947+ relu2 = relay .nn .relu (bias )
948+ return relay .Function ([x , w , b ], relu2 )
938949
950+ def expected_true ():
939951 x = relay .var ('x' , shape = (1 , 10 , 10 , 10 ))
940952 w = relay .var ('w' , shape = (10 , 10 , 3 , 3 ))
941953 b = relay .var ('b' , shape = (8 , ))
942- return relay .Function ([x , w , b ], func (x , w , b ))
943954
955+ x0 = relay .var ('x' )
956+ y0 = relay .var ('y' )
957+
958+ add = relay .op .add (y0 , y0 )
959+ relu = relay .nn .relu (add )
960+ func = relay .Function ([x0 , y0 ], relu )
961+ func = func .with_attr ("PartitionedFromPattern" , "add_nn.relu_" )
962+ func = func .with_attr ("Composite" , "add_relu" )
963+ call = relay .Call (func , [x , x ])
964+
965+ x2 = relay .var ('x' )
966+ w1 = relay .var ('w' )
967+ b1 = relay .var ('b' )
968+ conv = relay .nn .conv2d (x2 , w1 , kernel_size = (3 , 3 ), kernel_layout = "OIHW" , data_layout = "NHWC" )
969+ bias = relay .nn .bias_add (conv , b1 )
970+ relu2 = relay .nn .relu (bias )
971+ func = relay .Function ([x2 , w1 , b1 ], relu2 )
972+ func = func .with_attr ("Composite" , "conv_bias_relu" )
973+ func = func .with_attr ("PartitionedFromPattern" , "nn.conv2d_nn.bias_add_nn.relu_" )
974+ call = relay .Call (func , [call , w , b ])
975+ return relay .Function ([x , w , b ], call )
944976
945977 def _check_type_true (extract ):
946978 conv = extract .args [0 ].args [0 ]
@@ -953,14 +985,16 @@ def _check_type_false(extract):
953985 return bool (typ .shape [0 ] != 1 )
954986
955987 pattern_table_false = [
988+ ("add_relu" , make_add_relu_pattern ()),
956989 ("conv_bias_relu" , make_conv_bias_relu_pattern (), _check_type_false )
957990 ]
958- check_result (pattern_table_false , before (), before ())
991+ check_result (pattern_table_false , before (), expected_false ())
959992
960993 pattern_table_true = [
994+ ("add_relu" , make_add_relu_pattern ()),
961995 ("conv_bias_relu" , make_conv_bias_relu_pattern (), _check_type_true )
962996 ]
963- check_result (pattern_table_true , before (), expected ())
997+ check_result (pattern_table_true , before (), expected_true ())
964998
965999
9661000if __name__ == "__main__" :
0 commit comments