@@ -45513,6 +45513,7 @@ static SDValue combinevXi1ConstantToInteger(SDValue Op, SelectionDAG &DAG) {
45513
45513
static SDValue combineCastedMaskArithmetic(SDNode *N, SelectionDAG &DAG,
45514
45514
TargetLowering::DAGCombinerInfo &DCI,
45515
45515
const X86Subtarget &Subtarget) {
45516
+ using namespace SDPatternMatch;
45516
45517
assert(N->getOpcode() == ISD::BITCAST && "Expected a bitcast");
45517
45518
45518
45519
if (!DCI.isBeforeLegalizeOps())
@@ -45526,34 +45527,25 @@ static SDValue combineCastedMaskArithmetic(SDNode *N, SelectionDAG &DAG,
45526
45527
SDValue Op = N->getOperand(0);
45527
45528
EVT SrcVT = Op.getValueType();
45528
45529
45529
- if (!Op.hasOneUse())
45530
- return SDValue();
45531
-
45532
- // Look for logic ops.
45533
- if (Op.getOpcode() != ISD::AND &&
45534
- Op.getOpcode() != ISD::OR &&
45535
- Op.getOpcode() != ISD::XOR)
45536
- return SDValue();
45537
-
45538
45530
// Make sure we have a bitcast between mask registers and a scalar type.
45539
45531
if (!(SrcVT.isVector() && SrcVT.getVectorElementType() == MVT::i1 &&
45540
45532
DstVT.isScalarInteger()) &&
45541
45533
!(DstVT.isVector() && DstVT.getVectorElementType() == MVT::i1 &&
45542
45534
SrcVT.isScalarInteger()))
45543
45535
return SDValue();
45544
45536
45545
- SDValue LHS = Op.getOperand(0);
45546
- SDValue RHS = Op.getOperand(1);
45537
+ SDValue LHS, RHS;
45547
45538
45548
- if (LHS.hasOneUse() && LHS.getOpcode() == ISD::BITCAST &&
45549
- LHS.getOperand(0).getValueType() == DstVT)
45550
- return DAG.getNode(Op.getOpcode(), SDLoc(N), DstVT, LHS.getOperand(0),
45551
- DAG.getBitcast(DstVT, RHS));
45539
+ // Look for logic ops.
45540
+ if (!sd_match(Op, m_OneUse(m_BitwiseLogic(m_Value(LHS), m_Value(RHS)))))
45541
+ return SDValue();
45552
45542
45553
- if (RHS.hasOneUse() && RHS.getOpcode() == ISD::BITCAST &&
45554
- RHS.getOperand(0).getValueType() == DstVT)
45543
+ // If either operand was bitcast from DstVT, then perform logic with DstVT (at
45544
+ // least one of the getBitcast() will fold away).
45545
+ if (sd_match(LHS, m_OneUse(m_BitCast(m_SpecificVT(DstVT)))) ||
45546
+ sd_match(RHS, m_OneUse(m_BitCast(m_SpecificVT(DstVT)))))
45555
45547
return DAG.getNode(Op.getOpcode(), SDLoc(N), DstVT,
45556
- DAG.getBitcast(DstVT, LHS), RHS.getOperand(0 ));
45548
+ DAG.getBitcast(DstVT, LHS), DAG.getBitcast(DstVT, RHS ));
45557
45549
45558
45550
// If the RHS is a vXi1 build vector, this is a good reason to flip too.
45559
45551
// Most of these have to move a constant from the scalar domain anyway.
0 commit comments