Skip to content

Commit 17f8c32

Browse files
nickggfacebook-github-bot
authored andcommitted
[NNC] IRSimplifier rules for Compare and Mod (#46412)
Summary: Adds new rules to the NNC IRSimplifier to take care of the following cases: * Comparisons which are symbolic but have a constant difference. E.g. this is most useful in cases like `if (x > x + 4) ...` which we can now eliminate. * Simplification of `Mod` nodes, including simple rules such as `0 % x` and `x % 1`, but also factorization of both sides to find common symbolic multiples. E.g. `(x * y) % x` can be cancelled out to `0`. See tests for many more examples! Pull Request resolved: #46412 Reviewed By: navahgar Differential Revision: D24396151 Pulled By: nickgg fbshipit-source-id: abb954dc930867d62010dcbcd8a4701430733715
1 parent a06b95b commit 17f8c32

File tree

4 files changed

+442
-18
lines changed

4 files changed

+442
-18
lines changed

test/cpp/tensorexpr/test_simplify.cpp

+303-2
Original file line numberDiff line numberDiff line change
@@ -1120,12 +1120,130 @@ void testSimplifyDiv() {
11201120

11211121
IS_VAR_WITH_NAME(simplified.node(), "x");
11221122
}
1123+
}
1124+
1125+
void testSimplifyMod() {
1126+
KernelScope kernel_scope;
1127+
VarHandle x("x", kInt);
1128+
VarHandle y("y", kInt);
1129+
VarHandle z("z", kInt);
1130+
1131+
{
1132+
// Constant folding works.
1133+
ExprHandle body = ExprHandle(10) % 8;
1134+
ExprHandle simplified = IRSimplifier::simplify(body);
1135+
IS_IMM_WITH_VAL(Int, simplified.node(), 2);
1136+
}
11231137

11241138
{
1125-
ExprHandle body = x / x;
1139+
// x % x => 0
1140+
ExprHandle body = x % x;
11261141
ExprHandle simplified = IRSimplifier::simplify(body);
1142+
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1143+
}
11271144

1128-
IS_IMM_WITH_VAL(Int, simplified.node(), 1);
1145+
{
1146+
// 0 % x => 0
1147+
ExprHandle body = ExprHandle(0) % x;
1148+
ExprHandle simplified = IRSimplifier::simplify(body);
1149+
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1150+
}
1151+
1152+
{
1153+
// x % 1 => 0
1154+
ExprHandle body = x % 1;
1155+
ExprHandle simplified = IRSimplifier::simplify(body);
1156+
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1157+
}
1158+
1159+
{
1160+
// Doesn't change unknown mods.
1161+
// x % y => x % y
1162+
ExprHandle body = x % y;
1163+
ExprHandle simplified = IRSimplifier::simplify(body);
1164+
IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
1165+
IS_VAR_WITH_NAME(mod->lhs(), "x");
1166+
IS_VAR_WITH_NAME(mod->rhs(), "y");
1167+
}
1168+
1169+
{
1170+
// don't touch if RHS is unknown.
1171+
// 4 % x => 4 % x
1172+
ExprHandle body = ExprHandle(4) % x;
1173+
ExprHandle simplified = IRSimplifier::simplify(body);
1174+
IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
1175+
IS_IMM_WITH_VAL(Int, mod->lhs(), 4);
1176+
IS_VAR_WITH_NAME(mod->rhs(), "x");
1177+
}
1178+
1179+
{
1180+
// don't touch if LHS is unknown.
1181+
// x % 4 => x % 4
1182+
ExprHandle body = x % 4;
1183+
ExprHandle simplified = IRSimplifier::simplify(body);
1184+
IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
1185+
IS_VAR_WITH_NAME(mod->lhs(), "x");
1186+
IS_IMM_WITH_VAL(Int, mod->rhs(), 4);
1187+
}
1188+
1189+
{
1190+
// if LHS is a multiple of RHS, mod is zero.
1191+
// 2 * x % x => 0
1192+
ExprHandle body = (x * 2) % x;
1193+
ExprHandle simplified = IRSimplifier::simplify(body);
1194+
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1195+
}
1196+
1197+
{
1198+
// true even if the multiple is not constant.
1199+
// x * y % x => 0
1200+
ExprHandle body = (x * y) % x;
1201+
ExprHandle simplified = IRSimplifier::simplify(body);
1202+
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1203+
}
1204+
1205+
{
1206+
// true with multiple unknown values in LHS.
1207+
// x * y * z % x => 0
1208+
ExprHandle body = (x * y * z) % x;
1209+
ExprHandle simplified = IRSimplifier::simplify(body);
1210+
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1211+
}
1212+
1213+
{
1214+
// true if the denom is compound.
1215+
// x * y * z % y * z => 0
1216+
ExprHandle body = (x * y * z) % (y * z);
1217+
ExprHandle simplified = IRSimplifier::simplify(body);
1218+
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1219+
}
1220+
1221+
{
1222+
// Sanity check true with scalars that are multiples.
1223+
// 12 * x % 4 => 0
1224+
ExprHandle body = (x * 12) % 4;
1225+
ExprHandle simplified = IRSimplifier::simplify(body);
1226+
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
1227+
}
1228+
1229+
{
1230+
// Sanity check not true if the smaller scalar is on LHS.
1231+
// 4 * x % 12 => 4 * x % 12
1232+
ExprHandle body = (x * 4) % 12;
1233+
ExprHandle simplified = IRSimplifier::simplify(body);
1234+
IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
1235+
IS_NODE_WITH_NAME(Mul, mod->lhs(), mul);
1236+
IS_IMM_WITH_VAL(Int, mul->lhs(), 4);
1237+
IS_VAR_WITH_NAME(mul->rhs(), "x");
1238+
IS_IMM_WITH_VAL(Int, mod->rhs(), 12);
1239+
}
1240+
1241+
{
1242+
// Both scalar and symbolic in multiple.
1243+
// (6 * x * y) % (3 * x * y) => 0
1244+
ExprHandle body = (ExprHandle(6) * x * y) % (x * y * 3);
1245+
ExprHandle simplified = IRSimplifier::simplify(body);
1246+
IS_IMM_WITH_VAL(Int, simplified.node(), 0);
11291247
}
11301248
}
11311249

@@ -2807,6 +2925,189 @@ void testSimplifyEliminateEmptyCond() {
28072925
}
28082926
}
28092927

2928+
void testSimplifyConstantComparisons() {
2929+
KernelScope kernel_scope;
2930+
2931+
auto ComparisonTest =
2932+
[](ExprHandle a, ExprHandle b, CompareSelectOperation op, int result) {
2933+
ExprHandle body = CompareSelect::make(a, b, op);
2934+
ExprHandle simplified = IRSimplifier::simplify(body);
2935+
IS_IMM_WITH_VAL(Int, simplified.node(), result);
2936+
};
2937+
2938+
// Equals.
2939+
ComparisonTest(2, 2, kEQ, 1);
2940+
ComparisonTest(1, 2, kEQ, 0);
2941+
ComparisonTest(2, 1, kEQ, 0);
2942+
2943+
// Greater than.
2944+
ComparisonTest(2, 2, kGT, 0);
2945+
ComparisonTest(1, 2, kGT, 0);
2946+
ComparisonTest(2, 1, kGT, 1);
2947+
2948+
// Greater or Equal.
2949+
ComparisonTest(2, 2, kGE, 1);
2950+
ComparisonTest(1, 2, kGE, 0);
2951+
ComparisonTest(2, 1, kGE, 1);
2952+
2953+
// Less Than.
2954+
ComparisonTest(2, 2, kLT, 0);
2955+
ComparisonTest(1, 2, kLT, 1);
2956+
ComparisonTest(2, 1, kLT, 0);
2957+
2958+
// Less or Equal.
2959+
ComparisonTest(2, 2, kLE, 1);
2960+
ComparisonTest(1, 2, kLE, 1);
2961+
ComparisonTest(2, 1, kLE, 0);
2962+
2963+
// Not equal.
2964+
ComparisonTest(2, 2, kNE, 0);
2965+
ComparisonTest(1, 2, kNE, 1);
2966+
ComparisonTest(2, 1, kNE, 1);
2967+
2968+
// With specified results:
2969+
ExprHandle body = CompareSelect::make(2, 2, 5, 42, kNE);
2970+
ExprHandle simplified = IRSimplifier::simplify(body);
2971+
IS_IMM_WITH_VAL(Int, simplified.node(), 42);
2972+
}
2973+
2974+
void testSimplifySymbolicComparisons() {
2975+
KernelScope kernel_scope;
2976+
VarHandle x("x", kInt);
2977+
VarHandle y("y", kInt);
2978+
2979+
auto TookTrueBranch = [](ExprHandle a) { IS_IMM_WITH_VAL(Int, a.node(), 1); };
2980+
auto TookFalseBranch = [](ExprHandle a) {
2981+
IS_IMM_WITH_VAL(Int, a.node(), 0);
2982+
};
2983+
2984+
// EQ
2985+
2986+
// x == x => 1
2987+
ExprHandle body = CompareSelect::make(x, x, kEQ);
2988+
TookTrueBranch(IRSimplifier::simplify(body));
2989+
2990+
// x == x+1 => 0
2991+
body = CompareSelect::make(x, x + 1, kEQ);
2992+
TookFalseBranch(IRSimplifier::simplify(body));
2993+
2994+
// x == x * 2 cannot simplify since we don't know x is nonzero.
2995+
body = CompareSelect::make(x, x * 2, kEQ);
2996+
IS_NODE(CompareSelect, IRSimplifier::simplify(body).node());
2997+
2998+
// x == x * 1 => 1
2999+
body = CompareSelect::make(x, x * 1, kEQ);
3000+
TookTrueBranch(IRSimplifier::simplify(body));
3001+
3002+
{
3003+
// x == y => x == y
3004+
body = CompareSelect::make(x, y, kEQ);
3005+
ExprHandle simplified = IRSimplifier::simplify(body);
3006+
IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp);
3007+
ASSERT_EQ(cmp->compare_select_op(), kEQ);
3008+
IS_VAR_WITH_NAME(cmp->lhs(), "x");
3009+
IS_VAR_WITH_NAME(cmp->rhs(), "y");
3010+
}
3011+
3012+
{
3013+
// x == 5 => x == 5
3014+
body = CompareSelect::make(x, 5, kEQ);
3015+
ExprHandle simplified = IRSimplifier::simplify(body);
3016+
IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp);
3017+
ASSERT_EQ(cmp->compare_select_op(), kEQ);
3018+
IS_VAR_WITH_NAME(cmp->lhs(), "x");
3019+
IS_IMM_WITH_VAL(Int, cmp->rhs(), 5);
3020+
}
3021+
3022+
// GT
3023+
3024+
// x+1 > x => 1
3025+
body = CompareSelect::make(x + 1, x, kGT);
3026+
TookTrueBranch(IRSimplifier::simplify(body));
3027+
3028+
// x > x + 1 => 0
3029+
body = CompareSelect::make(x, x + 1, kGT);
3030+
TookFalseBranch(IRSimplifier::simplify(body));
3031+
3032+
// x > x - 1 => 1
3033+
body = CompareSelect::make(x, x - 1, kGT);
3034+
TookTrueBranch(IRSimplifier::simplify(body));
3035+
3036+
// x - 1 > x => 0
3037+
body = CompareSelect::make(x - 1, x, kGT);
3038+
TookFalseBranch(IRSimplifier::simplify(body));
3039+
3040+
// x > x => 0
3041+
body = CompareSelect::make(x, x, kGT);
3042+
TookFalseBranch(IRSimplifier::simplify(body));
3043+
3044+
// x * 2 > x => x * 2 > x
3045+
// since we don't know the sign of x.
3046+
body = CompareSelect::make(x * 2, x, kGT);
3047+
IS_NODE(CompareSelect, IRSimplifier::simplify(body).node());
3048+
3049+
// GE
3050+
3051+
// x+1 >= x => 1
3052+
body = CompareSelect::make(x + 1, x, kGE);
3053+
TookTrueBranch(IRSimplifier::simplify(body));
3054+
3055+
// x >= x + 1 => 0
3056+
body = CompareSelect::make(x, x + 1, kGE);
3057+
TookFalseBranch(IRSimplifier::simplify(body));
3058+
3059+
// x >= x => 1
3060+
body = CompareSelect::make(x, x, kGE);
3061+
TookTrueBranch(IRSimplifier::simplify(body));
3062+
3063+
// x * 2 >= x => x * 2 >= x
3064+
// since we don't know the sign of x.
3065+
body = CompareSelect::make(x * 2, x, kGE);
3066+
IS_NODE(CompareSelect, IRSimplifier::simplify(body).node());
3067+
3068+
// LT
3069+
3070+
// x+1 < x => 0
3071+
body = CompareSelect::make(x + 1, x, kLT);
3072+
TookFalseBranch(IRSimplifier::simplify(body));
3073+
3074+
// x < x + 1 => 1
3075+
body = CompareSelect::make(x, x + 1, kLT);
3076+
TookTrueBranch(IRSimplifier::simplify(body));
3077+
3078+
// x < x => 0
3079+
body = CompareSelect::make(x, x, kLT);
3080+
TookFalseBranch(IRSimplifier::simplify(body));
3081+
3082+
// LE
3083+
3084+
// x+1 <= x => 0
3085+
body = CompareSelect::make(x + 1, x, kLE);
3086+
TookFalseBranch(IRSimplifier::simplify(body));
3087+
3088+
// x <= x + 1 => 1
3089+
body = CompareSelect::make(x, x + 1, kLE);
3090+
TookTrueBranch(IRSimplifier::simplify(body));
3091+
3092+
// x <= x => 1
3093+
body = CompareSelect::make(x, x, kLE);
3094+
TookTrueBranch(IRSimplifier::simplify(body));
3095+
3096+
// NE
3097+
3098+
// x+1 != x => 1
3099+
body = CompareSelect::make(x + 1, x, kNE);
3100+
TookTrueBranch(IRSimplifier::simplify(body));
3101+
3102+
// x != x + 1 => 1
3103+
body = CompareSelect::make(x, x + 1, kNE);
3104+
TookTrueBranch(IRSimplifier::simplify(body));
3105+
3106+
// x != x => 0
3107+
body = CompareSelect::make(x, x, kNE);
3108+
TookFalseBranch(IRSimplifier::simplify(body));
3109+
}
3110+
28103111
void testSimplifyEliminateZeroLengthFor() {
28113112
KernelScope kernel_scope;
28123113

test/cpp/tensorexpr/tests.h

+3
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ namespace jit {
194194
_(SimplifyMuls) \
195195
_(SimplifySubs) \
196196
_(SimplifyDiv) \
197+
_(SimplifyMod) \
197198
_(SimplifyMultiOp) \
198199
_(SimplifyManyOps) \
199200
_(SimplifyFactorization) \
@@ -214,6 +215,8 @@ namespace jit {
214215
_(SimplifyConstantBranches) \
215216
_(SimplifyConstantCond) \
216217
_(SimplifyEliminateEmptyCond) \
218+
_(SimplifyConstantComparisons) \
219+
_(SimplifySymbolicComparisons) \
217220
_(SimplifyEliminateZeroLengthFor) \
218221
_(SimplifyOneLoopFor) \
219222
_(SimplifyForWontLoseLoopOptions) \

0 commit comments

Comments
 (0)