Skip to content

[LV] Bundle sub reductions into VPExpressionRecipe #147255

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: users/SamTebbs33/sub-reductions
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -1645,8 +1645,10 @@ class TargetTransformInfo {
/// extensions. This is the cost of as:
/// ResTy vecreduce.add(mul (A, B)).
/// ResTy vecreduce.add(mul(ext(Ty A), ext(Ty B)).
/// The multiply can optionally be negated, which signifies that it is a sub
/// reduction.
LLVM_ABI InstructionCost getMulAccReductionCost(
bool IsUnsigned, Type *ResTy, VectorType *Ty,
bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth keeping the booleans together, i.e. next to IsUnsigned?

TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;

/// Calculate the cost of an extended reduction pattern, similar to
Expand Down
2 changes: 1 addition & 1 deletion llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ class TargetTransformInfoImplBase {

virtual InstructionCost
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
TTI::TargetCostKind CostKind) const {
bool Negated, TTI::TargetCostKind CostKind) const {
return 1;
}

Expand Down
3 changes: 3 additions & 0 deletions llvm/include/llvm/CodeGen/BasicTTIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3116,7 +3116,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {

InstructionCost
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
bool Negated,
TTI::TargetCostKind CostKind) const override {
if (Negated)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't we add a cost for this?

return InstructionCost::getInvalid(CostKind);
// Without any native support, this is equivalent to the cost of
// vecreduce.add(mul(ext(Ty A), ext(Ty B))) or
// vecreduce.add(mul(A, B)).
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1274,9 +1274,10 @@ InstructionCost TargetTransformInfo::getExtendedReductionCost(
}

InstructionCost TargetTransformInfo::getMulAccReductionCost(
bool IsUnsigned, Type *ResTy, VectorType *Ty,
bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
TTI::TargetCostKind CostKind) const {
return TTIImpl->getMulAccReductionCost(IsUnsigned, ResTy, Ty, CostKind);
return TTIImpl->getMulAccReductionCost(IsUnsigned, ResTy, Ty, Negated,
CostKind);
}

InstructionCost
Expand Down
7 changes: 5 additions & 2 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5316,8 +5316,10 @@ InstructionCost AArch64TTIImpl::getExtendedReductionCost(

InstructionCost
AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
VectorType *VecTy,
VectorType *VecTy, bool Negated,
TTI::TargetCostKind CostKind) const {
if (Negated)
return InstructionCost::getInvalid(CostKind);
EVT VecVT = TLI->getValueType(DL, VecTy);
EVT ResVT = TLI->getValueType(DL, ResTy);

Expand All @@ -5332,7 +5334,8 @@ AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
return LT.first + 2;
}

return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, CostKind);
return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, Negated,
CostKind);
}

InstructionCost
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
TTI::TargetCostKind CostKind) const override;

InstructionCost getMulAccReductionCost(
bool IsUnsigned, Type *ResTy, VectorType *Ty,
bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const override;

InstructionCost
Expand Down
7 changes: 5 additions & 2 deletions llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1884,8 +1884,10 @@ InstructionCost ARMTTIImpl::getExtendedReductionCost(

InstructionCost
ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
VectorType *ValTy,
VectorType *ValTy, bool Negated,
TTI::TargetCostKind CostKind) const {
if (Negated)
return InstructionCost::getInvalid(CostKind);
EVT ValVT = TLI->getValueType(DL, ValTy);
EVT ResVT = TLI->getValueType(DL, ResTy);

Expand All @@ -1906,7 +1908,8 @@ ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
return ST->getMVEVectorCostFactor(CostKind) * LT.first;
}

return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, ValTy, CostKind);
return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, ValTy, Negated,
CostKind);
}

InstructionCost
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/ARM/ARMTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
TTI::TargetCostKind CostKind) const override;
InstructionCost
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *ValTy,
bool Negated,
TTI::TargetCostKind CostKind) const override;

InstructionCost
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5538,7 +5538,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
TTI::CastContextHint::None, CostKind, RedOp);

InstructionCost RedCost = TTI.getMulAccReductionCost(
IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind);
IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: /*Negated=*/false and same for other below.


if (RedCost.isValid() &&
RedCost < ExtCost * 2 + MulCost + Ext2Cost + BaseCost)
Expand Down Expand Up @@ -5583,7 +5583,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);

InstructionCost RedCost = TTI.getMulAccReductionCost(
IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind);
IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind);
InstructionCost ExtraExtCost = 0;
if (Op0Ty != LargestOpTy || Op1Ty != LargestOpTy) {
Instruction *ExtraExtOp = (Op0Ty != LargestOpTy) ? Op0 : Op1;
Expand All @@ -5602,7 +5602,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);

InstructionCost RedCost = TTI.getMulAccReductionCost(
true, RdxDesc.getRecurrenceType(), VectorTy, CostKind);
true, RdxDesc.getRecurrenceType(), VectorTy, false, CostKind);

if (RedCost.isValid() && RedCost < MulCost + BaseCost)
return I == RetI ? RedCost : 0;
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -2757,6 +2757,12 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
/// vector operands, performing a reduction.add on the result, and adding
/// the scalar result to a chain.
MulAccReduction,
/// Represent an inloop multiply-accumulate reduction, multiplying the
/// extended vector operands, negating the multiplication, performing a
/// reduction.add
/// on the result, and adding
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Formatting of the comment looks a bit odd - can you fix it?

/// the scalar result to a chain.
ExtNegatedMulAccReduction,
};

/// Type of the expression.
Expand All @@ -2780,6 +2786,11 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
VPWidenRecipe *Mul, VPReductionRecipe *Red)
: VPExpressionRecipe(ExpressionTypes::ExtMulAccReduction,
{Ext0, Ext1, Mul, Red}) {}
VPExpressionRecipe(VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1,
VPWidenRecipe *Mul, VPWidenRecipe *Sub,
VPReductionRecipe *Red)
: VPExpressionRecipe(ExpressionTypes::ExtNegatedMulAccReduction,
{Ext0, Ext1, Mul, Sub, Red}) {}

~VPExpressionRecipe() override {
for (auto *R : reverse(ExpressionRecipes))
Expand Down
35 changes: 32 additions & 3 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2672,13 +2672,17 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
RedTy, SrcVecTy, std::nullopt, Ctx.CostKind);
}
case ExpressionTypes::MulAccReduction:
return Ctx.TTI.getMulAccReductionCost(false, RedTy, SrcVecTy, Ctx.CostKind);
return Ctx.TTI.getMulAccReductionCost(false, RedTy, SrcVecTy, false,
Ctx.CostKind);

case ExpressionTypes::ExtMulAccReduction:
case ExpressionTypes::ExtNegatedMulAccReduction:
case ExpressionTypes::ExtMulAccReduction: {
bool Negated = ExpressionType == ExpressionTypes::ExtNegatedMulAccReduction;
return Ctx.TTI.getMulAccReductionCost(
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
Instruction::ZExt,
RedTy, SrcVecTy, Ctx.CostKind);
RedTy, SrcVecTy, Negated, Ctx.CostKind);
}
}
llvm_unreachable("Unknown VPExpressionRecipe::ExpressionTypes enum");
}
Expand Down Expand Up @@ -2725,6 +2729,31 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
O << ")";
break;
}
case ExpressionTypes::ExtNegatedMulAccReduction: {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to commonise this with the ExtMulAccReduction case if the only difference is a negate?

getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
O << " + ";
O << "reduce."
<< Instruction::getOpcodeName(
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))
<< " (sub (0, mul";
auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]);
Mul->printFlags(O);
O << "(";
getOperand(0)->printAsOperand(O, SlotTracker);
auto *Ext0 = cast<VPWidenCastRecipe>(ExpressionRecipes[0]);
O << " " << Instruction::getOpcodeName(Ext0->getOpcode()) << " to "
<< *Ext0->getResultType() << "), (";
getOperand(1)->printAsOperand(O, SlotTracker);
auto *Ext1 = cast<VPWidenCastRecipe>(ExpressionRecipes[1]);
O << " " << Instruction::getOpcodeName(Ext1->getOpcode()) << " to "
<< *Ext1->getResultType() << ")";
if (Red->isConditional()) {
O << ", ";
Red->getCondOp()->printAsOperand(O, SlotTracker);
}
O << "))";
break;
}
case ExpressionTypes::MulAccReduction:
case ExpressionTypes::ExtMulAccReduction: {
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
Expand Down
33 changes: 23 additions & 10 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2908,16 +2908,17 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,

// Clamp the range if using multiply-accumulate-reduction is profitable.
auto IsMulAccValidAndClampRange =
[&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool {
[&](bool IsZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt,
bool Negated = false) -> bool {
return LoopVectorizationPlanner::getDecisionAndClampRange(
[&](ElementCount VF) {
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
Type *SrcTy =
Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
InstructionCost MulAccCost =
Ctx.TTI.getMulAccReductionCost(isZExt, RedTy, SrcVecTy, CostKind);
InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
IsZExt, RedTy, SrcVecTy, Negated, CostKind);
InstructionCost MulCost = Mul->computeCost(VF, Ctx);
InstructionCost RedCost = Red->computeCost(VF, Ctx);
InstructionCost ExtCost = 0;
Expand All @@ -2935,14 +2936,22 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
};

VPValue *VecOp = Red->getVecOp();
VPValue *Mul = nullptr;
VPValue *Sub = nullptr;
VPValue *A, *B;
// Sub reductions will have a sub between the add reduction and vec op.
if (match(VecOp,
m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Mul))))
Sub = VecOp;
else
Mul = VecOp;
// Try to match reduce.add(mul(...)).
if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
if (match(Mul, m_Mul(m_VPValue(A), m_VPValue(B)))) {
auto *RecipeA =
dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
auto *RecipeB =
dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
auto *MulR = cast<VPWidenRecipe>(Mul->getDefiningRecipe());

// Match reduce.add(mul(ext, ext)).
if (RecipeA && RecipeB &&
Expand All @@ -2951,12 +2960,16 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
Instruction::CastOps::ZExt,
Mul, RecipeA, RecipeB, nullptr)) {
return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red);
MulR, RecipeA, RecipeB, nullptr, Sub)) {
if (Sub)
return new VPExpressionRecipe(
RecipeA, RecipeB, MulR,
cast<VPWidenRecipe>(Sub->getDefiningRecipe()), Red);
return new VPExpressionRecipe(RecipeA, RecipeB, MulR, Red);
}
// Match reduce.add(mul).
if (IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))
return new VPExpressionRecipe(Mul, Red);
if (IsMulAccValidAndClampRange(true, MulR, nullptr, nullptr, nullptr, Sub))
return new VPExpressionRecipe(MulR, Red);
}
// Match reduce.add(ext(mul(ext(A), ext(B)))).
// All extend recipes must have same opcode or A == B
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1401,8 +1401,8 @@ static void analyzeCostOfVecReduction(const IntrinsicInst &II,
TTI::CastContextHint::None, CostKind, RedOp);

CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
CostAfterReduction =
TTI.getMulAccReductionCost(IsUnsigned, II.getType(), ExtType, CostKind);
CostAfterReduction = TTI.getMulAccReductionCost(IsUnsigned, II.getType(),
ExtType, false, CostKind);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Probably better written as /*Negated=*/false

return;
}
CostAfterReduction = TTI.getArithmeticReductionCost(ReductionOpc, VecRedTy,
Expand Down
Loading
Loading