Skip to content

Transforms match into an assignment statement #120614

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Transforms a match containing negative numbers into an assignment sta…
…tement as well
  • Loading branch information
dianqk committed Apr 8, 2024
commit e752af765ea04ba663d82524cfdcc2b7b6cb58aa
49 changes: 38 additions & 11 deletions compiler/rustc_mir_transform/src/match_branches.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use rustc_index::IndexVec;
use rustc_middle::mir::*;
use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt};
use rustc_target::abi::Size;
use std::iter;

use super::simplify::simplify_cfg;
Expand Down Expand Up @@ -67,13 +68,13 @@ trait SimplifyMatch<'tcx> {
_ => unreachable!(),
};

if !self.can_simplify(tcx, targets, param_env, bbs) {
let discr_ty = discr.ty(local_decls, tcx);
if !self.can_simplify(tcx, targets, param_env, bbs, discr_ty) {
return false;
}

// Take ownership of items now that we know we can optimize.
let discr = discr.clone();
let discr_ty = discr.ty(local_decls, tcx);

// Introduce a temporary for the discriminant value.
let source_info = bbs[switch_bb_idx].terminator().source_info;
Expand Down Expand Up @@ -104,6 +105,7 @@ trait SimplifyMatch<'tcx> {
targets: &SwitchTargets,
param_env: ParamEnv<'tcx>,
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
discr_ty: Ty<'tcx>,
) -> bool;

fn new_stmts(
Expand Down Expand Up @@ -157,6 +159,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
targets: &SwitchTargets,
param_env: ParamEnv<'tcx>,
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
_discr_ty: Ty<'tcx>,
) -> bool {
if targets.iter().len() != 1 {
return false;
Expand Down Expand Up @@ -268,7 +271,7 @@ struct SimplifyToExp {
enum CompareType<'tcx, 'a> {
Same(&'a StatementKind<'tcx>),
Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt),
Discr(&'a Place<'tcx>, Ty<'tcx>),
Discr(&'a Place<'tcx>, Ty<'tcx>, bool),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you comment the 3 variants?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added. I've also updated Discr. But I'm not sure the current comments are clear enough.

}

enum TransfromType {
Expand All @@ -282,7 +285,7 @@ impl From<CompareType<'_, '_>> for TransfromType {
match compare_type {
CompareType::Same(_) => TransfromType::Same,
CompareType::Eq(_, _, _) => TransfromType::Eq,
CompareType::Discr(_, _) => TransfromType::Discr,
CompareType::Discr(_, _, _) => TransfromType::Discr,
}
}
}
Expand Down Expand Up @@ -333,6 +336,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
targets: &SwitchTargets,
param_env: ParamEnv<'tcx>,
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
discr_ty: Ty<'tcx>,
) -> bool {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code readability idea: would it simplify anything to return an Option<()> here, and use ? for short-circuiting?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I'm using it in the wrong way. I didn't feel the improvement. Maybe because there is not Option here?
Probably better to use something like anyhow::bail!.

if targets.iter().len() < 2 || targets.iter().len() > 64 {
return false;
Expand All @@ -355,13 +359,19 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
return false;
}

let discr_size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size;
let first_stmts = &bbs[first_target].statements;
let (second_val, second_target) = target_iter.next().unwrap();
let second_stmts = &bbs[second_target].statements;
if first_stmts.len() != second_stmts.len() {
return false;
}

fn int_equal(l: ScalarInt, r: impl Into<u128>, size: Size) -> bool {
l.try_to_int(l.size()).unwrap()
== ScalarInt::try_from_uint(r, size).unwrap().try_to_int(size).unwrap()
Copy link
Member

@RalfJung RalfJung Apr 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the point of doing the comparison in such a complicated way? Why turn r into a ScalarInt and then back into a i128?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because SwitchTargets only saves the value of the corresponding bit value by u128, it lacks information on bit-width and sign.

Copy link
Member

@RalfJung RalfJung Apr 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't need the sign, as I said: == is sign-independent, so we can figure out the switch target without knowing the sign. And it has the width, it is determined by the type of the match operand.

Copy link
Member Author

@dianqk dianqk Apr 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, you're right. This code is used to compare signed integers of different widths, such as i8 and i16. We must add additional conversions for this scenario.
Hmm, it's been just a month, and I have to carefully review the code to respond correctly. It seems I indeed wrote some hard-to-understand code. :>
I'll add some comments. I should also move the unsigned comparison to the front, as I expect this might make the code a bit faster.

}

let mut compare_types = Vec::new();
for (f, s) in iter::zip(first_stmts, second_stmts) {
let compare_type = match (&f.kind, &s.kind) {
Expand All @@ -382,12 +392,22 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
) {
(Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f),
(Some(f), Some(s))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This branch definitely needs a comment.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added. The main thing here is to deal with the various enum variants.

if Some(f) == ScalarInt::try_from_uint(first_val, f.size())
&& Some(s) == ScalarInt::try_from_uint(second_val, s.size()) =>
if ((f_c.const_.ty().is_signed() || discr_ty.is_signed())
&& int_equal(f, first_val, discr_size)
&& int_equal(s, second_val, discr_size))
|| (Some(f) == ScalarInt::try_from_uint(first_val, f.size())
&& Some(s)
== ScalarInt::try_from_uint(second_val, s.size())) =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very strange. == is sign-independent, so I don't understand why both cases need to be considered here. Furthermore, if the sign mattered, then surely it makes no sense to check they_are_equal_signed || they_are_equal_unsigned; instead you have to check if they_are_signed { they_are_equal_signed } else { they_are_equal_unsigned }. Finally, f_c.const_.ty().is_signed() || discr_ty.is_signed() sounds like you are mixing signed and unsigned values (as in, LHS and RHS can have different sign), which should never happen.

What is going on here?

Copy link
Member Author

@dianqk dianqk Apr 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is used to handle conversions such as from enum(u32) to i32 or from enum(i32) to u32:

#[repr(u8)]
enum EnumAu8 {
A = 1,
B = 2,
}
// EMIT_MIR matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff
fn match_u8_i16(i: EnumAu8) -> i16 {
// CHECK-LABEL: fn match_u8_i16(
// CHECK-NOT: switchInt
// CHECK: _0 = _3 as i16 (IntToInt);
// CHECH: return
match i {
EnumAu8::A => 1,
EnumAu8::B => 2,
}
}
.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha... I have no idea what this means. ;)

But it makes no sense to compare things twice here. == is entirely based on being equal bitwise, so the sign doesn't matter.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above (different width), corresponding test case:

#[repr(i8)]
enum EnumAi8 {
A = -1,
B = 2,
C = -3,
}
// EMIT_MIR matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff
fn match_i8_i16(i: EnumAi8) -> i16 {
// CHECK-LABEL: fn match_i8_i16(
// CHECK-NOT: switchInt
// CHECK: _0 = _3 as i16 (IntToInt);
// CHECH: return
match i {
EnumAi8::A => -1,
EnumAi8::B => 2,
EnumAi8::C => -3,
}
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But you are comparing both signed and unsigned representation completely disregarding whether the value is actually signed or unsigned. It sounds like you want two cases

  • value is signed; then convert everything to i128 (with sign extension, e.g. via try_to_int) and compare there
  • value is unsigned; the convert everything to u128 and compare there

But currently you're interpreting the number both ways and then checking if either comparison succeeds. It seems to me that can sometimes lead to blatantly wrong results, e.g. when two numbers are equal unsigned but different after sign extension, and they are actually signed -- you code will treat them as equal, I think?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I'm fixing it right now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in #124122, thanks for pointing it out.

{
CompareType::Discr(lhs_f, f_c.const_.ty())
CompareType::Discr(
lhs_f,
f_c.const_.ty(),
f_c.const_.ty().is_signed() || discr_ty.is_signed(),
)
}
_ => {
return false;
}
_ => return false,
}
}

Expand All @@ -413,15 +433,22 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
&& s_c.const_.ty() == f_ty
&& s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {}
(
CompareType::Discr(lhs_f, f_ty),
CompareType::Discr(lhs_f, f_ty, is_signed),
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => {
let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else {
return false;
};
if Some(f) != ScalarInt::try_from_uint(other_val, f.size()) {
return false;
if is_signed
&& s_c.const_.ty().is_signed()
&& int_equal(f, other_val, discr_size)
{
continue;
}
if Some(f) == ScalarInt::try_from_uint(other_val, f.size()) {
continue;
}
return false;
}
_ => return false,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,37 @@
debug i => _1;
let mut _0: i8;
let mut _2: i16;
+ let mut _3: i16;

bb0: {
_2 = discriminant(_1);
switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb2, otherwise: bb1];
}

bb1: {
unreachable;
}

bb2: {
_0 = const -3_i8;
goto -> bb5;
}

bb3: {
_0 = const -1_i8;
goto -> bb5;
}

bb4: {
_0 = const 2_i8;
goto -> bb5;
}

bb5: {
- switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb2, otherwise: bb1];
- }
-
- bb1: {
- unreachable;
- }
-
- bb2: {
- _0 = const -3_i8;
- goto -> bb5;
- }
-
- bb3: {
- _0 = const -1_i8;
- goto -> bb5;
- }
-
- bb4: {
- _0 = const 2_i8;
- goto -> bb5;
- }
-
- bb5: {
+ StorageLive(_3);
+ _3 = move _2;
+ _0 = _3 as i8 (IntToInt);
+ StorageDead(_3);
return;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,37 @@
debug i => _1;
let mut _0: i16;
let mut _2: i8;
+ let mut _3: i8;

bb0: {
_2 = discriminant(_1);
switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb2, otherwise: bb1];
}

bb1: {
unreachable;
}

bb2: {
_0 = const -3_i16;
goto -> bb5;
}

bb3: {
_0 = const -1_i16;
goto -> bb5;
}

bb4: {
_0 = const 2_i16;
goto -> bb5;
}

bb5: {
- switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb2, otherwise: bb1];
- }
-
- bb1: {
- unreachable;
- }
-
- bb2: {
- _0 = const -3_i16;
- goto -> bb5;
- }
-
- bb3: {
- _0 = const -1_i16;
- goto -> bb5;
- }
-
- bb4: {
- _0 = const 2_i16;
- goto -> bb5;
- }
-
- bb5: {
+ StorageLive(_3);
+ _3 = move _2;
+ _0 = _3 as i16 (IntToInt);
+ StorageDead(_3);
return;
}
}
Expand Down
8 changes: 6 additions & 2 deletions tests/mir-opt/matches_reduce_branches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,9 @@ enum EnumAi8 {
// EMIT_MIR matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff
fn match_i8_i16(i: EnumAi8) -> i16 {
// CHECK-LABEL: fn match_i8_i16(
// CHECK: switchInt
// CHECK-NOT: switchInt
// CHECK: _0 = _3 as i16 (IntToInt);
// CHECH: return
match i {
EnumAi8::A => -1,
EnumAi8::B => 2,
Expand Down Expand Up @@ -233,7 +235,9 @@ enum EnumAi16 {
// EMIT_MIR matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff
fn match_i16_i8(i: EnumAi16) -> i8 {
// CHECK-LABEL: fn match_i16_i8(
// CHECK: switchInt
// CHECK-NOT: switchInt
// CHECK: _0 = _3 as i8 (IntToInt);
// CHECH: return
match i {
EnumAi16::A => -1,
EnumAi16::B => 2,
Expand Down