Skip to content

[SCEVPatternMatch] Extend with more matchers #138836

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 9, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
Original file line number Diff line number Diff line change
@@ -23,6 +23,8 @@ template <typename Pattern> bool match(const SCEV *S, const Pattern &P) {
}

template <typename Predicate> struct cst_pred_ty : public Predicate {
cst_pred_ty() = default;
cst_pred_ty(uint64_t V) : Predicate(V) {}
bool match(const SCEV *S) const {
assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
"no vector types expected from SCEVs");
@@ -58,6 +60,8 @@ template <typename Class> struct class_match {
template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); }
};

inline class_match<const SCEV> m_SCEV() { return class_match<const SCEV>(); }

template <typename Class> struct bind_ty {
Class *&VR;

@@ -93,6 +97,34 @@ struct specificscev_ty {
/// Match if we have a specific specified SCEV.
inline specificscev_ty m_Specific(const SCEV *S) { return S; }

struct is_specific_cst {
uint64_t CV;
is_specific_cst(uint64_t C) : CV(C) {}
bool isValue(const APInt &C) const { return C == CV; }
};

/// Match an SCEV constant with a plain unsigned integer.
inline cst_pred_ty<is_specific_cst> m_scev_SpecificInt(uint64_t V) { return V; }

struct bind_cst_ty {
const APInt *&CR;

bind_cst_ty(const APInt *&Op0) : CR(Op0) {}

bool match(const SCEV *S) const {
assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
"no vector types expected from SCEVs");
auto *C = dyn_cast<SCEVConstant>(S);
if (!C)
return false;
CR = &C->getAPInt();
return true;
}
};

/// Match an SCEV constant and bind it to an APInt.
inline bind_cst_ty m_scev_APInt(const APInt *&C) { return C; }

/// Match a unary SCEV.
template <typename SCEVTy, typename Op0_t> struct SCEVUnaryExpr_match {
Op0_t Op0;
@@ -149,6 +181,17 @@ m_scev_Add(const Op0_t &Op0, const Op1_t &Op1) {
return m_scev_Binary<SCEVAddExpr>(Op0, Op1);
}

template <typename Op0_t, typename Op1_t>
inline SCEVBinaryExpr_match<SCEVMulExpr, Op0_t, Op1_t>
m_scev_Mul(const Op0_t &Op0, const Op1_t &Op1) {
return m_scev_Binary<SCEVMulExpr>(Op0, Op1);
}

template <typename Op0_t, typename Op1_t>
inline SCEVBinaryExpr_match<SCEVUDivExpr, Op0_t, Op1_t>
m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
return m_scev_Binary<SCEVUDivExpr>(Op0, Op1);
}
} // namespace SCEVPatternMatch
} // namespace llvm

23 changes: 11 additions & 12 deletions llvm/lib/Analysis/LoopAccessAnalysis.cpp
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
@@ -65,6 +66,7 @@
#include <vector>

using namespace llvm;
using namespace llvm::SCEVPatternMatch;

#define DEBUG_TYPE "loop-accesses"

@@ -811,8 +813,8 @@ getStrideFromAddRec(const SCEVAddRecExpr *AR, const Loop *Lp, Type *AccessTy,
const SCEV *Step = AR->getStepRecurrence(*PSE.getSE());

// Calculate the pointer stride and check if it is constant.
const SCEVConstant *C = dyn_cast<SCEVConstant>(Step);
if (!C) {
const APInt *APStepVal;
if (!match(Step, m_scev_APInt(APStepVal))) {
LLVM_DEBUG({
dbgs() << "LAA: Bad stride - Not a constant strided ";
if (Ptr)
@@ -825,13 +827,12 @@ getStrideFromAddRec(const SCEVAddRecExpr *AR, const Loop *Lp, Type *AccessTy,
const auto &DL = Lp->getHeader()->getDataLayout();
TypeSize AllocSize = DL.getTypeAllocSize(AccessTy);
int64_t Size = AllocSize.getFixedValue();
const APInt &APStepVal = C->getAPInt();

// Huge step value - give up.
if (APStepVal.getBitWidth() > 64)
if (APStepVal->getBitWidth() > 64)
return std::nullopt;

int64_t StepVal = APStepVal.getSExtValue();
int64_t StepVal = APStepVal->getSExtValue();

// Strided access.
int64_t Stride = StepVal / Size;
@@ -2061,11 +2062,10 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
DL, SE, *(PSE.getSymbolicMaxBackedgeTakenCount()), *Dist, MaxStride))
return Dependence::NoDep;

const SCEVConstant *ConstDist = dyn_cast<SCEVConstant>(Dist);

// Attempt to prove strided accesses independent.
if (ConstDist) {
uint64_t Distance = ConstDist->getAPInt().abs().getZExtValue();
const APInt *ConstDist = nullptr;
if (match(Dist, m_scev_APInt(ConstDist))) {
uint64_t Distance = ConstDist->abs().getZExtValue();

// If the distance between accesses and their strides are known constants,
// check whether the accesses interlace each other.
@@ -2111,9 +2111,8 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
FoundNonConstantDistanceDependence |= ShouldRetryWithRuntimeCheck;
return Dependence::Unknown;
}
if (!HasSameSize ||
couldPreventStoreLoadForward(
ConstDist->getAPInt().abs().getZExtValue(), TypeByteSize)) {
if (!HasSameSize || couldPreventStoreLoadForward(
ConstDist->abs().getZExtValue(), TypeByteSize)) {
LLVM_DEBUG(
dbgs() << "LAA: Forward but may prevent st->ld forwarding\n");
return Dependence::ForwardButPreventsForwarding;
21 changes: 8 additions & 13 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
@@ -7149,16 +7149,11 @@ ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
"Should be!");

// Peel off a constant offset:
if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
// In the future we could consider being smarter here and handle
// {Start+Step,+,Step} too.
if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
return;

Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
S = SA->getOperand(1);
}
// Peel off a constant offset. In the future we could consider being
// smarter here and handle {Start+Step,+,Step} too.
const APInt *Off;
if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
Offset = *Off;

// Peel off a cast operation
if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
@@ -7337,11 +7332,11 @@ bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {

bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
return !SCEVExprContains(Op, [this](const SCEV *S) {
auto *UDiv = dyn_cast<SCEVUDivExpr>(S);
const SCEV *Op1;
bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
// The UDiv may be UB if the divisor is poison or zero. Unless the divisor
// is a non-zero constant, we have to assume the UDiv may be UB.
return UDiv && (!isKnownNonZero(UDiv->getOperand(1)) ||
!isGuaranteedNotToBePoison(UDiv->getOperand(1)));
return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
});
}

Loading