@@ -1120,12 +1120,130 @@ void testSimplifyDiv() {
1120
1120
1121
1121
IS_VAR_WITH_NAME (simplified.node (), " x" );
1122
1122
}
1123
+ }
1124
+
1125
+ void testSimplifyMod () {
1126
+ KernelScope kernel_scope;
1127
+ VarHandle x (" x" , kInt );
1128
+ VarHandle y (" y" , kInt );
1129
+ VarHandle z (" z" , kInt );
1130
+
1131
+ {
1132
+ // Constant folding works.
1133
+ ExprHandle body = ExprHandle (10 ) % 8 ;
1134
+ ExprHandle simplified = IRSimplifier::simplify (body);
1135
+ IS_IMM_WITH_VAL (Int, simplified.node (), 2 );
1136
+ }
1123
1137
1124
1138
{
1125
- ExprHandle body = x / x;
1139
+ // x % x => 0
1140
+ ExprHandle body = x % x;
1126
1141
ExprHandle simplified = IRSimplifier::simplify (body);
1142
+ IS_IMM_WITH_VAL (Int, simplified.node (), 0 );
1143
+ }
1127
1144
1128
- IS_IMM_WITH_VAL (Int, simplified.node (), 1 );
1145
+ {
1146
+ // 0 % x => 0
1147
+ ExprHandle body = ExprHandle (0 ) % x;
1148
+ ExprHandle simplified = IRSimplifier::simplify (body);
1149
+ IS_IMM_WITH_VAL (Int, simplified.node (), 0 );
1150
+ }
1151
+
1152
+ {
1153
+ // x % 1 => 0
1154
+ ExprHandle body = x % 1 ;
1155
+ ExprHandle simplified = IRSimplifier::simplify (body);
1156
+ IS_IMM_WITH_VAL (Int, simplified.node (), 0 );
1157
+ }
1158
+
1159
+ {
1160
+ // Doesn't change unknown mods.
1161
+ // x % y => x % y
1162
+ ExprHandle body = x % y;
1163
+ ExprHandle simplified = IRSimplifier::simplify (body);
1164
+ IS_NODE_WITH_NAME (Mod, simplified.node (), mod);
1165
+ IS_VAR_WITH_NAME (mod->lhs (), " x" );
1166
+ IS_VAR_WITH_NAME (mod->rhs (), " y" );
1167
+ }
1168
+
1169
+ {
1170
+ // don't touch if RHS is unknown.
1171
+ // 4 % x => 4 % x
1172
+ ExprHandle body = ExprHandle (4 ) % x;
1173
+ ExprHandle simplified = IRSimplifier::simplify (body);
1174
+ IS_NODE_WITH_NAME (Mod, simplified.node (), mod);
1175
+ IS_IMM_WITH_VAL (Int, mod->lhs (), 4 );
1176
+ IS_VAR_WITH_NAME (mod->rhs (), " x" );
1177
+ }
1178
+
1179
+ {
1180
+ // don't touch if LHS is unknown.
1181
+ // x % 4 => x % 4
1182
+ ExprHandle body = x % 4 ;
1183
+ ExprHandle simplified = IRSimplifier::simplify (body);
1184
+ IS_NODE_WITH_NAME (Mod, simplified.node (), mod);
1185
+ IS_VAR_WITH_NAME (mod->lhs (), " x" );
1186
+ IS_IMM_WITH_VAL (Int, mod->rhs (), 4 );
1187
+ }
1188
+
1189
+ {
1190
+ // if LHS is a multiple of RHS, mod is zero.
1191
+ // 2 * x % x => 0
1192
+ ExprHandle body = (x * 2 ) % x;
1193
+ ExprHandle simplified = IRSimplifier::simplify (body);
1194
+ IS_IMM_WITH_VAL (Int, simplified.node (), 0 );
1195
+ }
1196
+
1197
+ {
1198
+ // true even if the multiple is not constant.
1199
+ // x * y % x => 0
1200
+ ExprHandle body = (x * y) % x;
1201
+ ExprHandle simplified = IRSimplifier::simplify (body);
1202
+ IS_IMM_WITH_VAL (Int, simplified.node (), 0 );
1203
+ }
1204
+
1205
+ {
1206
+ // true with multiple unknown values in LHS.
1207
+ // x * y * z % x => 0
1208
+ ExprHandle body = (x * y * z) % x;
1209
+ ExprHandle simplified = IRSimplifier::simplify (body);
1210
+ IS_IMM_WITH_VAL (Int, simplified.node (), 0 );
1211
+ }
1212
+
1213
+ {
1214
+ // true if the denom is compound.
1215
+ // x * y * z % y * z => 0
1216
+ ExprHandle body = (x * y * z) % (y * z);
1217
+ ExprHandle simplified = IRSimplifier::simplify (body);
1218
+ IS_IMM_WITH_VAL (Int, simplified.node (), 0 );
1219
+ }
1220
+
1221
+ {
1222
+ // Sanity check true with scalars that are multiples.
1223
+ // 12 * x % 4 => 0
1224
+ ExprHandle body = (x * 12 ) % 4 ;
1225
+ ExprHandle simplified = IRSimplifier::simplify (body);
1226
+ IS_IMM_WITH_VAL (Int, simplified.node (), 0 );
1227
+ }
1228
+
1229
+ {
1230
+ // Sanity check not true if the smaller scalar is on LHS.
1231
+ // 4 * x % 12 => 4 * x % 12
1232
+ ExprHandle body = (x * 4 ) % 12 ;
1233
+ ExprHandle simplified = IRSimplifier::simplify (body);
1234
+ IS_NODE_WITH_NAME (Mod, simplified.node (), mod);
1235
+ IS_NODE_WITH_NAME (Mul, mod->lhs (), mul);
1236
+ IS_IMM_WITH_VAL (Int, mul->lhs (), 4 );
1237
+ IS_VAR_WITH_NAME (mul->rhs (), " x" );
1238
+ IS_IMM_WITH_VAL (Int, mod->rhs (), 12 );
1239
+ }
1240
+
1241
+ {
1242
+ // Both scalar and symbolic in multiple.
1243
+ // (6 * x * y) % (3 * x * y) => 0
1244
+ ExprHandle body = (ExprHandle (6 ) * x * y) % (x * y * 3 );
1245
+ ExprHandle simplified = IRSimplifier::simplify (body);
1246
+ IS_IMM_WITH_VAL (Int, simplified.node (), 0 );
1129
1247
}
1130
1248
}
1131
1249
@@ -2807,6 +2925,189 @@ void testSimplifyEliminateEmptyCond() {
2807
2925
}
2808
2926
}
2809
2927
2928
+ void testSimplifyConstantComparisons () {
2929
+ KernelScope kernel_scope;
2930
+
2931
+ auto ComparisonTest =
2932
+ [](ExprHandle a, ExprHandle b, CompareSelectOperation op, int result) {
2933
+ ExprHandle body = CompareSelect::make (a, b, op);
2934
+ ExprHandle simplified = IRSimplifier::simplify (body);
2935
+ IS_IMM_WITH_VAL (Int, simplified.node (), result);
2936
+ };
2937
+
2938
+ // Equals.
2939
+ ComparisonTest (2 , 2 , kEQ , 1 );
2940
+ ComparisonTest (1 , 2 , kEQ , 0 );
2941
+ ComparisonTest (2 , 1 , kEQ , 0 );
2942
+
2943
+ // Greater than.
2944
+ ComparisonTest (2 , 2 , kGT , 0 );
2945
+ ComparisonTest (1 , 2 , kGT , 0 );
2946
+ ComparisonTest (2 , 1 , kGT , 1 );
2947
+
2948
+ // Greater or Equal.
2949
+ ComparisonTest (2 , 2 , kGE , 1 );
2950
+ ComparisonTest (1 , 2 , kGE , 0 );
2951
+ ComparisonTest (2 , 1 , kGE , 1 );
2952
+
2953
+ // Less Than.
2954
+ ComparisonTest (2 , 2 , kLT , 0 );
2955
+ ComparisonTest (1 , 2 , kLT , 1 );
2956
+ ComparisonTest (2 , 1 , kLT , 0 );
2957
+
2958
+ // Less or Equal.
2959
+ ComparisonTest (2 , 2 , kLE , 1 );
2960
+ ComparisonTest (1 , 2 , kLE , 1 );
2961
+ ComparisonTest (2 , 1 , kLE , 0 );
2962
+
2963
+ // Not equal.
2964
+ ComparisonTest (2 , 2 , kNE , 0 );
2965
+ ComparisonTest (1 , 2 , kNE , 1 );
2966
+ ComparisonTest (2 , 1 , kNE , 1 );
2967
+
2968
+ // With specified results:
2969
+ ExprHandle body = CompareSelect::make (2 , 2 , 5 , 42 , kNE );
2970
+ ExprHandle simplified = IRSimplifier::simplify (body);
2971
+ IS_IMM_WITH_VAL (Int, simplified.node (), 42 );
2972
+ }
2973
+
2974
+ void testSimplifySymbolicComparisons () {
2975
+ KernelScope kernel_scope;
2976
+ VarHandle x (" x" , kInt );
2977
+ VarHandle y (" y" , kInt );
2978
+
2979
+ auto TookTrueBranch = [](ExprHandle a) { IS_IMM_WITH_VAL (Int, a.node (), 1 ); };
2980
+ auto TookFalseBranch = [](ExprHandle a) {
2981
+ IS_IMM_WITH_VAL (Int, a.node (), 0 );
2982
+ };
2983
+
2984
+ // EQ
2985
+
2986
+ // x == x => 1
2987
+ ExprHandle body = CompareSelect::make (x, x, kEQ );
2988
+ TookTrueBranch (IRSimplifier::simplify (body));
2989
+
2990
+ // x == x+1 => 0
2991
+ body = CompareSelect::make (x, x + 1 , kEQ );
2992
+ TookFalseBranch (IRSimplifier::simplify (body));
2993
+
2994
+ // x == x * 2 cannot simplify since we don't know x is nonzero.
2995
+ body = CompareSelect::make (x, x * 2 , kEQ );
2996
+ IS_NODE (CompareSelect, IRSimplifier::simplify (body).node ());
2997
+
2998
+ // x == x * 1 => 1
2999
+ body = CompareSelect::make (x, x * 1 , kEQ );
3000
+ TookTrueBranch (IRSimplifier::simplify (body));
3001
+
3002
+ {
3003
+ // x == y => x == y
3004
+ body = CompareSelect::make (x, y, kEQ );
3005
+ ExprHandle simplified = IRSimplifier::simplify (body);
3006
+ IS_NODE_WITH_NAME (CompareSelect, simplified.node (), cmp);
3007
+ ASSERT_EQ (cmp->compare_select_op (), kEQ );
3008
+ IS_VAR_WITH_NAME (cmp->lhs (), " x" );
3009
+ IS_VAR_WITH_NAME (cmp->rhs (), " y" );
3010
+ }
3011
+
3012
+ {
3013
+ // x == 5 => x == 5
3014
+ body = CompareSelect::make (x, 5 , kEQ );
3015
+ ExprHandle simplified = IRSimplifier::simplify (body);
3016
+ IS_NODE_WITH_NAME (CompareSelect, simplified.node (), cmp);
3017
+ ASSERT_EQ (cmp->compare_select_op (), kEQ );
3018
+ IS_VAR_WITH_NAME (cmp->lhs (), " x" );
3019
+ IS_IMM_WITH_VAL (Int, cmp->rhs (), 5 );
3020
+ }
3021
+
3022
+ // GT
3023
+
3024
+ // x+1 > x => 1
3025
+ body = CompareSelect::make (x + 1 , x, kGT );
3026
+ TookTrueBranch (IRSimplifier::simplify (body));
3027
+
3028
+ // x > x + 1 => 0
3029
+ body = CompareSelect::make (x, x + 1 , kGT );
3030
+ TookFalseBranch (IRSimplifier::simplify (body));
3031
+
3032
+ // x > x - 1 => 1
3033
+ body = CompareSelect::make (x, x - 1 , kGT );
3034
+ TookTrueBranch (IRSimplifier::simplify (body));
3035
+
3036
+ // x - 1 > x => 0
3037
+ body = CompareSelect::make (x - 1 , x, kGT );
3038
+ TookFalseBranch (IRSimplifier::simplify (body));
3039
+
3040
+ // x > x => 0
3041
+ body = CompareSelect::make (x, x, kGT );
3042
+ TookFalseBranch (IRSimplifier::simplify (body));
3043
+
3044
+ // x * 2 > x => x * 2 > x
3045
+ // since we don't know the sign of x.
3046
+ body = CompareSelect::make (x * 2 , x, kGT );
3047
+ IS_NODE (CompareSelect, IRSimplifier::simplify (body).node ());
3048
+
3049
+ // GE
3050
+
3051
+ // x+1 >= x => 1
3052
+ body = CompareSelect::make (x + 1 , x, kGE );
3053
+ TookTrueBranch (IRSimplifier::simplify (body));
3054
+
3055
+ // x >= x + 1 => 0
3056
+ body = CompareSelect::make (x, x + 1 , kGE );
3057
+ TookFalseBranch (IRSimplifier::simplify (body));
3058
+
3059
+ // x >= x => 1
3060
+ body = CompareSelect::make (x, x, kGE );
3061
+ TookTrueBranch (IRSimplifier::simplify (body));
3062
+
3063
+ // x * 2 >= x => x * 2 >= x
3064
+ // since we don't know the sign of x.
3065
+ body = CompareSelect::make (x * 2 , x, kGE );
3066
+ IS_NODE (CompareSelect, IRSimplifier::simplify (body).node ());
3067
+
3068
+ // LT
3069
+
3070
+ // x+1 < x => 0
3071
+ body = CompareSelect::make (x + 1 , x, kLT );
3072
+ TookFalseBranch (IRSimplifier::simplify (body));
3073
+
3074
+ // x < x + 1 => 1
3075
+ body = CompareSelect::make (x, x + 1 , kLT );
3076
+ TookTrueBranch (IRSimplifier::simplify (body));
3077
+
3078
+ // x < x => 0
3079
+ body = CompareSelect::make (x, x, kLT );
3080
+ TookFalseBranch (IRSimplifier::simplify (body));
3081
+
3082
+ // LE
3083
+
3084
+ // x+1 <= x => 0
3085
+ body = CompareSelect::make (x + 1 , x, kLE );
3086
+ TookFalseBranch (IRSimplifier::simplify (body));
3087
+
3088
+ // x <= x + 1 => 1
3089
+ body = CompareSelect::make (x, x + 1 , kLE );
3090
+ TookTrueBranch (IRSimplifier::simplify (body));
3091
+
3092
+ // x <= x => 1
3093
+ body = CompareSelect::make (x, x, kLE );
3094
+ TookTrueBranch (IRSimplifier::simplify (body));
3095
+
3096
+ // NE
3097
+
3098
+ // x+1 != x => 1
3099
+ body = CompareSelect::make (x + 1 , x, kNE );
3100
+ TookTrueBranch (IRSimplifier::simplify (body));
3101
+
3102
+ // x != x + 1 => 1
3103
+ body = CompareSelect::make (x, x + 1 , kNE );
3104
+ TookTrueBranch (IRSimplifier::simplify (body));
3105
+
3106
+ // x != x => 0
3107
+ body = CompareSelect::make (x, x, kNE );
3108
+ TookFalseBranch (IRSimplifier::simplify (body));
3109
+ }
3110
+
2810
3111
void testSimplifyEliminateZeroLengthFor () {
2811
3112
KernelScope kernel_scope;
2812
3113
0 commit comments