-
Notifications
You must be signed in to change notification settings - Fork 13.5k
Add support for single reductions in ComplexDeinterleavingPass #112875
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for single reductions in ComplexDeinterleavingPass #112875
Conversation
@llvm/pr-subscribers-backend-aarch64 Author: Nicholas Guy (NickGuy-Arm) Changes: The Complex Deinterleaving pass assumes that all values emitted will result in complex numbers; this patch aims to remove that assumption and add support for emitting just the real or imaginary components, not both. Patch is 25.02 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/112875.diff 4 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h b/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
index 84a2673fecb5bf..a3fa2197727701 100644
--- a/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
+++ b/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
@@ -43,6 +43,7 @@ enum class ComplexDeinterleavingOperation {
ReductionPHI,
ReductionOperation,
ReductionSelect,
+ ReductionSingle
};
enum class ComplexDeinterleavingRotation {
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index 8573b016d1e5bb..08287a4d5ed022 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -145,6 +145,7 @@ struct ComplexDeinterleavingCompositeNode {
friend class ComplexDeinterleavingGraph;
using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
using RawNodePtr = ComplexDeinterleavingCompositeNode *;
+ bool OperandsValid = true;
public:
ComplexDeinterleavingOperation Operation;
@@ -161,7 +162,11 @@ struct ComplexDeinterleavingCompositeNode {
SmallVector<RawNodePtr> Operands;
Value *ReplacementNode = nullptr;
- void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
+ void addOperand(NodePtr Node) {
+ if (!Node || !Node.get())
+ OperandsValid = false;
+ Operands.push_back(Node.get());
+ }
void dump() { dump(dbgs()); }
void dump(raw_ostream &OS) {
@@ -195,6 +200,10 @@ struct ComplexDeinterleavingCompositeNode {
PrintNodeRef(Op);
}
}
+
+ bool AreOperandsValid() {
+ return OperandsValid;
+ }
};
class ComplexDeinterleavingGraph {
@@ -294,7 +303,7 @@ class ComplexDeinterleavingGraph {
NodePtr submitCompositeNode(NodePtr Node) {
CompositeNodes.push_back(Node);
- if (Node->Real && Node->Imag)
+ if (Node->Real)
CachedResult[{Node->Real, Node->Imag}] = Node;
return Node;
}
@@ -328,8 +337,10 @@ class ComplexDeinterleavingGraph {
/// i: ai - br
NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
+ NodePtr identifyPartialReduction(Value *R, Value *I);
NodePtr identifyNode(Value *R, Value *I);
+ NodePtr identifyNode(Value *R, Value *I, bool &FromCache);
/// Determine if a sum of complex numbers can be formed from \p RealAddends
/// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
@@ -397,6 +408,7 @@ class ComplexDeinterleavingGraph {
/// * Deinterleave the final value outside of the loop and repurpose original
/// reduction users
void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
+ void processReductionSingle(Value *OperationReplacement, RawNodePtr Node);
public:
void dump() { dump(dbgs()); }
@@ -893,16 +905,26 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
- LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
- assert(R->getType() == I->getType() &&
- "Real and imaginary parts should not have different types");
+ bool _;
+ return identifyNode(R, I, _);
+}
+ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I, bool &FromCache) {
auto It = CachedResult.find({R, I});
if (It != CachedResult.end()) {
LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
+ FromCache = true;
return It->second;
}
+ if(NodePtr CN = identifyPartialReduction(R, I))
+ return CN;
+
+ bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
+ if(!IsReduction && R->getType() != I->getType())
+ return nullptr;
+
if (NodePtr CN = identifySplat(R, I))
return CN;
@@ -1428,12 +1450,18 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
if (It != RootToNode.end()) {
auto RootNode = It->second;
assert(RootNode->Operation ==
- ComplexDeinterleavingOperation::ReductionOperation);
+ ComplexDeinterleavingOperation::ReductionOperation || RootNode->Operation == ComplexDeinterleavingOperation::ReductionSingle);
// Find out which part, Real or Imag, comes later, and only if we come to
// the latest part, add it to OrderedRoots.
auto *R = cast<Instruction>(RootNode->Real);
- auto *I = cast<Instruction>(RootNode->Imag);
- auto *ReplacementAnchor = R->comesBefore(I) ? I : R;
+ auto *I = RootNode->Imag ? cast<Instruction>(RootNode->Imag) : nullptr;
+
+ Instruction *ReplacementAnchor;
+ if(I)
+ ReplacementAnchor = R->comesBefore(I) ? I : R;
+ else
+ ReplacementAnchor = R;
+
if (ReplacementAnchor != RootI)
return false;
OrderedRoots.push_back(RootI);
@@ -1521,11 +1549,11 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
for (size_t i = 0; i < OperationInstruction.size(); ++i) {
if (Processed[i])
continue;
+ auto *Real = OperationInstruction[i];
for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
if (Processed[j])
continue;
-
- auto *Real = OperationInstruction[i];
+
auto *Imag = OperationInstruction[j];
if (Real->getType() != Imag->getType())
continue;
@@ -1557,6 +1585,25 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
break;
}
}
+
+ // We want to check that we have 2 operands, but the function attributes
+ // being counted as operands bloats this value.
+ if(Real->getNumOperands() < 2)
+ continue;
+
+ RealPHI = ReductionInfo[Real].first;
+ ImagPHI = nullptr;
+ PHIsFound = false;
+ auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
+ if(Node && PHIsFound) {
+ LLVM_DEBUG(dbgs() << "Identified single reduction starting from instruction: "
+ << *Real << "/" << *ReductionInfo[Real].second << "\n");
+ Processed[i] = true;
+ auto RootNode = prepareCompositeNode(ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
+ RootNode->addOperand(Node);
+ RootToNode[Real] = RootNode;
+ submitCompositeNode(RootNode);
+ }
}
RealPHI = nullptr;
@@ -1564,6 +1611,12 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
}
bool ComplexDeinterleavingGraph::checkNodes() {
+
+ for (NodePtr N : CompositeNodes) {
+ if (!N->AreOperandsValid())
+ return false;
+ }
+
// Collect all instructions from roots to leaves
SmallPtrSet<Instruction *, 16> AllInstructions;
SmallVector<Instruction *, 8> Worklist;
@@ -1832,7 +1885,7 @@ ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
Instruction *Imag) {
- if (Real != RealPHI || Imag != ImagPHI)
+ if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
return nullptr;
PHIsFound = true;
@@ -1970,13 +2023,18 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
case ComplexDeinterleavingOperation::ReductionPHI: {
// If Operation is ReductionPHI, a new empty PHINode is created.
// It is filled later when the ReductionOperation is processed.
+ auto *OldPHI = cast<PHINode>(Node->Real);
auto *VTy = cast<VectorType>(Node->Real->getType());
auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
- OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
+ OldToNewPHI[OldPHI] = NewPHI;
ReplacementNode = NewPHI;
break;
}
+ case ComplexDeinterleavingOperation::ReductionSingle:
+ ReplacementNode = replaceNode(Builder, Node->Operands[0]);
+ processReductionSingle(ReplacementNode, Node);
+ break;
case ComplexDeinterleavingOperation::ReductionOperation:
ReplacementNode = replaceNode(Builder, Node->Operands[0]);
processReductionOperation(ReplacementNode, Node);
@@ -2001,6 +2059,37 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
return ReplacementNode;
}
+void ComplexDeinterleavingGraph::processReductionSingle(Value *OperationReplacement, RawNodePtr Node) {
+ auto *Real = cast<Instruction>(Node->Real);
+ auto *OldPHI = ReductionInfo[Real].first;
+ auto *NewPHI = OldToNewPHI[OldPHI];
+ auto *VTy = cast<VectorType>(Real->getType());
+ auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
+
+ Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
+
+ IRBuilder<> Builder(Incoming->getTerminator());
+
+ Value *NewInit = nullptr;
+ if(auto *C = dyn_cast<Constant>(Init)) {
+ if(C->isZeroValue())
+ NewInit = Constant::getNullValue(NewVTy);
+ }
+
+ if (!NewInit)
+ NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
+ {Init, Constant::getNullValue(VTy)});
+
+ NewPHI->addIncoming(NewInit, Incoming);
+ NewPHI->addIncoming(OperationReplacement, BackEdge);
+
+ auto *FinalReduction = ReductionInfo[Real].second;
+ Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
+ // TODO Ensure that the `AddReduce` here matches the original, found in `FinalReduction`
+ auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
+ FinalReduction->replaceAllUsesWith(AddReduce);
+}
+
void ComplexDeinterleavingGraph::processReductionOperation(
Value *OperationReplacement, RawNodePtr Node) {
auto *Real = cast<Instruction>(Node->Real);
@@ -2060,8 +2149,12 @@ void ComplexDeinterleavingGraph::replaceNodes() {
auto *RootImag = cast<Instruction>(RootNode->Imag);
ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
- DeadInstrRoots.push_back(cast<Instruction>(RootReal));
- DeadInstrRoots.push_back(cast<Instruction>(RootImag));
+ DeadInstrRoots.push_back(RootReal);
+ DeadInstrRoots.push_back(RootImag);
+ } else if(RootNode->Operation == ComplexDeinterleavingOperation::ReductionSingle) {
+ auto *RootInst = cast<Instruction>(RootNode->Real);
+ ReductionInfo[RootInst].first->removeIncomingValue(BackEdge);
+ DeadInstrRoots.push_back(ReductionInfo[RootInst].second);
} else {
assert(R && "Unable to find replacement for RootInstruction");
DeadInstrRoots.push_back(RootInstruction);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 5e5afdb7fa0a6c..8068bb67408814 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -29171,6 +29171,8 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB,
Value *Accumulator) const {
VectorType *Ty = cast<VectorType>(InputA->getType());
+ if (Accumulator == nullptr)
+ Accumulator = Constant::getNullValue(Ty);
bool IsScalable = Ty->isScalableTy();
bool IsInt = Ty->getElementType()->isIntegerTy();
@@ -29182,6 +29184,7 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
if (TyWidth > 128) {
int Stride = Ty->getElementCount().getKnownMinValue() / 2;
+ int AccStride = cast<VectorType>(Accumulator->getType())->getElementCount().getKnownMinValue() / 2;
auto *HalfTy = VectorType::getHalfElementsVectorType(Ty);
auto *LowerSplitA = B.CreateExtractVector(HalfTy, InputA, B.getInt64(0));
auto *LowerSplitB = B.CreateExtractVector(HalfTy, InputB, B.getInt64(0));
@@ -29191,25 +29194,23 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
B.CreateExtractVector(HalfTy, InputB, B.getInt64(Stride));
Value *LowerSplitAcc = nullptr;
Value *UpperSplitAcc = nullptr;
- if (Accumulator) {
- LowerSplitAcc = B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(0));
+ Type *FullTy = Ty;
+ FullTy = Accumulator->getType();
+ auto *HalfAccTy = VectorType::getHalfElementsVectorType(cast<VectorType>(Accumulator->getType()));
+ LowerSplitAcc = B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(0));
UpperSplitAcc =
- B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(Stride));
- }
+ B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(AccStride));
auto *LowerSplitInt = createComplexDeinterleavingIR(
B, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc);
auto *UpperSplitInt = createComplexDeinterleavingIR(
B, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc);
- auto *Result = B.CreateInsertVector(Ty, PoisonValue::get(Ty), LowerSplitInt,
+ auto *Result = B.CreateInsertVector(FullTy, PoisonValue::get(FullTy), LowerSplitInt,
B.getInt64(0));
- return B.CreateInsertVector(Ty, Result, UpperSplitInt, B.getInt64(Stride));
+ return B.CreateInsertVector(FullTy, Result, UpperSplitInt, B.getInt64(AccStride));
}
if (OperationType == ComplexDeinterleavingOperation::CMulPartial) {
- if (Accumulator == nullptr)
- Accumulator = Constant::getNullValue(Ty);
-
if (IsScalable) {
if (IsInt)
return B.CreateIntrinsic(
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-cdot.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-cdot.ll
new file mode 100644
index 00000000000000..6277f9a3842bbe
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-cdot.ll
@@ -0,0 +1,170 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S --passes=complex-deinterleaving %s --mattr=+sve2 -o - | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+define i32 @cdotp(ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %N) {
+; CHECK-LABEL: define i32 @cdotp(
+; CHECK-SAME: ptr nocapture noundef readonly [[A:%.*]], ptr nocapture noundef readonly [[B:%.*]], i32 noundef [[N:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT: [[ENTRY:.*]]:
+; CHECK-NEXT: [[CMP28_NOT:%.*]] = icmp ult i32 [[N]], 2
+; CHECK-NEXT: br i1 [[CMP28_NOT]], label %[[FOR_COND_CLEANUP:.*]], label %[[FOR_BODY_PREHEADER:.*]]
+; CHECK: [[FOR_BODY_PREHEADER]]:
+; CHECK-NEXT: [[DIV27:%.*]] = lshr i32 [[N]], 1
+; CHECK-NEXT: [[WIDE_TRIP_COUNT:%.*]] = zext nneg i32 [[DIV27]] to i64
+; CHECK-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 16
+; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], [[TMP1]]
+; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
+; CHECK: [[VECTOR_PH]]:
+; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], 16
+; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], [[TMP3]]
+; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
+; CHECK-NEXT: [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT: [[TMP5:%.*]] = mul i64 [[TMP4]], 16
+; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
+; CHECK: [[VECTOR_BODY]]:
+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[TMP11:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP20:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[INDEX_I:%.*]] = shl nuw nsw i64 [[INDEX]], 1
+; CHECK-NEXT: [[A_PTR:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[INDEX_I]]
+; CHECK-NEXT: [[A_LOAD:%.*]] = load <vscale x 32 x i8>, ptr [[A_PTR]], align 32
+; CHECK-NEXT: [[B_PTR:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[INDEX_I]]
+; CHECK-NEXT: [[B_LOAD:%.*]] = load <vscale x 32 x i8>, ptr [[B_PTR]], align 32
+; CHECK-NEXT: [[TMP6:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[A_LOAD]], i64 0)
+; CHECK-NEXT: [[TMP7:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[B_LOAD]], i64 0)
+; CHECK-NEXT: [[TMP8:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[A_LOAD]], i64 16)
+; CHECK-NEXT: [[TMP9:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[B_LOAD]], i64 16)
+; CHECK-NEXT: [[VEC_PHI:%.*]] = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> [[TMP11]], i64 0)
+; CHECK-NEXT: [[TMP13:%.*]] = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> [[TMP11]], i64 4)
+; CHECK-NEXT: [[TMP10:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> [[VEC_PHI]], <vscale x 16 x i8> [[TMP6]], <vscale x 16 x i8> [[TMP7]], i32 0)
+; CHECK-NEXT: [[TMP21:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> [[TMP13]], <vscale x 16 x i8> [[TMP8]], <vscale x 16 x i8> [[TMP9]], i32 0)
+; CHECK-NEXT: [[TMP22:%.*]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> poison, <vscale x 4 x i32> [[TMP10]], i64 0)
+; CHECK-NEXT: [[TMP20]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> [[TMP22]], <vscale x 4 x i32> [[TMP21]], i64 4)
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]]
+; CHECK: [[MIDDLE_BLOCK]]:
+; CHECK-NEXT: [[TMP23:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP20]])
+; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP_LOOPEXIT:.*]], label %[[SCALAR_PH]]
+; CHECK: [[SCALAR_PH]]:
+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[FOR_BODY_PREHEADER]] ]
+; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP23]], %[[MIDDLE_BLOCK]] ], [ 0, %[[FOR_BODY_PREHEADER]] ]
+; CHECK-NEXT: br label %[[FOR_BODY:.*]]
+; CHECK: [[FOR_COND_CLEANUP_LOOPEXIT]]:
+; CHECK-NEXT: [[SUB_LCSSA:%.*]] = phi i32 [ [[SUB:%.*]], %[[FOR_BODY]] ], [ [[TMP23]], %[[MIDDLE_BLOCK]] ]
+; CHECK-NEXT: br label %[[FOR_COND_CLEANUP]]
+; CHECK: [[FOR_COND_CLEANUP]]:
+; CHECK-NEXT: [[RES_0_LCSSA:%.*]] = phi i32 [ 0, %[[ENTRY]] ], [ [[SUB_LCSSA]], %[[FOR_COND_CLEANUP_LOOPEXIT]] ]
+; CHECK-NEXT: ret i32 [[RES_0_LCSSA]]
+; CHECK: [[FOR_BODY]]:
+; CHECK-NEXT: [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ]
+; CHECK-NEXT: [[RES_030:%.*]] = phi i32 [ [[BC_MERGE_RDX]], %[[SCALAR_PH]] ], [ [[SUB]], %[[FOR_BODY]] ]
+; CHECK-NEXT: [[TMP14:%.*]] = shl nuw nsw i64 [[INDVARS_IV]], 1
+; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP14]]
+; CHECK-NEXT: [[TMP15:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
+; CHECK-NEXT: [[CONV:%.*]] = sext i8 [[TMP15]] to i32
+; CHECK-NEXT: [[TMP16:%.*]] = or disjoint i64 [[TMP14]], 1
+; CHECK-NEXT: [[ARRAYIDX4:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP16]]
+; CHECK-NEXT: [[TMP17:%.*]] = load i8, ptr [[ARRAYIDX4]], align 1
+; CHECK-NEXT: [[CONV5:%.*]] = sext i8 [[TMP17]] to i32
+; CHECK-NEXT: [[ARRAYIDX9:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP14]]
+; CHECK-NEXT: [[TMP18:%.*]] = load i8, ptr [[ARRAYIDX9]], align 1
+; CHECK-NEXT: [[CONV10:%.*]] = sext i8 [[TMP18]] to i32
+; CHECK-NEXT: [[ARRAYIDX14:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP16]]
+; CHECK-NEXT: [[TMP19:%.*]] = load i8, ptr [[ARRAYIDX14]], align 1
+; CHECK-NEXT: [[CONV15:%.*]] = sext i8 [[TMP19]] to i32
+; CHECK-NEXT: [[MUL16:%.*]] = mul nsw i32 [[CONV10]], [[CONV]]
+; CHECK-NEXT: [[ADD17:%.*]] = add nsw i32 [[MUL16]], [[RES_030]]
+; CHECK-NEXT: [[MUL18:%.*]] = mul nsw i32 [[CONV15]], [[CONV5]]
+; CHECK-NEXT: [[SUB]] = sub i32 [[ADD17]], [[MUL18]]
+; CHECK-NEXT: [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 1
+...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
aac0ea6
to
918312c
Compare
LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n"); | ||
assert(R->getType() == I->getType() && | ||
"Real and imaginary parts should not have different types"); | ||
bool _; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's this here for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For one of the sanity-checks elsewhere (https://github.com/llvm/llvm-project/pull/112875/files/b2410688146531936db5f58ed2f0ebf78bf8387a#diff-ebbbd6cbc055d2185b50e106f58ee13188b47cd5fa49f21fb66a9ea82d54b086R993), we check that we're pulling the already-matched node from the cache after unwrapping any casts.
The bool here is to act as an optional parameter to `identifyNode`, working around the lvalue/rvalue mismatch of having it inlined.
It's either this approach, or we update every call to `identifyNode` to have an extra bool variable assigned nearby. (Or we remove the sanity check altogether.)
Edit: That said, I could just use a pointer and use conventional default arguments.
@@ -0,0 +1,58 @@ | |||
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 | |||
; RUN: opt -S --passes=complex-deinterleaving %s --mattr=+sve2 -o - | FileCheck %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need some more tests, such as:
- No sve, no sve2
- More rotations
- I don't really understand the caching system, but some tests that test that functionality are needed
- More types
And negative tests for other code paths.
Ping |
@@ -1,41 +1,100 @@ | |||
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 | |||
; RUN: opt -S --passes=complex-deinterleaving %s --mattr=+sve2 -o - | FileCheck %s | |||
; RUN: opt -S --passes=complex-deinterleaving %s --mattr=+sve2 -o - | FileCheck %s --check-prefix=CHECK-SVE2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like we're still missing some negative tests. Also, can the caching system be tested at all?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added a negative test, but the caching is an implementation detail and is not exposed in any testable way. It's also not new in this PR.
if (!I->hasOneUser()) | ||
return nullptr; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why can there only be one user? We should have a negative test for it so someone removing the check in the future doesn't think all is well!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `hasOneUser` check was to simplify things later, but it's not strictly necessary, so I've removed it and fixed the succeeding code to not assume one user.
} | ||
|
||
ComplexDeinterleavingGraph::NodePtr | ||
ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I, bool *FromCache) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can `FromCache` be a reference instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've just removed this parameter altogether; it's only used in one place, and the value of the `FromCache` parameter was dubious at best.
VectorType *RealTy = dyn_cast<VectorType>(R->getType()); | ||
if (!RealTy) | ||
return nullptr; | ||
VectorType *ImagTy = dyn_cast<VectorType>(I->getType()); | ||
if (!ImagTy) | ||
return nullptr; | ||
|
||
if (RealTy->isScalableTy() != ImagTy->isScalableTy()) | ||
return nullptr; | ||
if (RealTy->getElementType() != ImagTy->getElementType()) | ||
return nullptr; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should have some tests for these.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added some tests
ccce001
to
b3550a5
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with one request.
ret i16 %0 | ||
} | ||
|
||
define i32 @cdotp_i8_rot0_fixed_length(<32 x i8> %a, <32 x i8> %b) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This name makes it seem like a positive test, so I'd suggest it be changed to `not_cdotp_fixed_length`.
It looks like this may be causing some buildbot failures: https://lab.llvm.org/buildbot/#/builders/41/builds/4171. Could you please take a look, and revert if it isn't a trivial fix?
|
#112875)" This reverts commit b3eede5. This has been breaking most AArch64 stage2 builds for 4+ hours, reverting to get the bots back to green. https://lab.llvm.org/buildbot/#/builders/41/builds/4172 https://lab.llvm.org/buildbot/#/builders/4/builds/4281 https://lab.llvm.org/buildbot/#/builders/199/builds/263 https://lab.llvm.org/buildbot/#/builders/198/builds/334 https://lab.llvm.org/buildbot/#/builders/143/builds/4276 https://lab.llvm.org/buildbot/#/builders/17/builds/4725
Reverted for now in 76714be as this has been breaking most stage2 builds on AArch64 for 4+ hours |
Thanks, I'm now testing a fix with a local stage2 build. Not sure why I didn't receive the buildbot emails though, so I didn't realise it was failing things at first. |
…ss (llvm#112875)" This reverts commit 76714be.
Opened a new PR at #120441 to reland this patch with a fix for the build failures. |
…rt for single reductions in ComplexDeinterleavingPass (#112875)" (#120441) This reverts commit 76714be, fixing the build failure that caused the revert. The failure stemmed from the complex deinterleaving pass identifying a series of add operations as a "complex to single reduction", so when it tried to transform this erroneously identified pattern, it faulted. The fix applied is to ensure that complex numbers (or patterns that match them) are used throughout, by checking if there is a deinterleave node amidst the graph.
…rt for single reductions in ComplexDeinterleavingPass (llvm#112875)" (llvm#120441) This reverts commit 76714be, fixing the build failure that caused the revert. The failure stemmed from the complex deinterleaving pass identifying a series of add operations as a "complex to single reduction", so when it tried to transform this erroneously identified pattern, it faulted. The fix applied is to ensure that complex numbers (or patterns that match them) are used throughout, by checking if there is a deinterleave node amidst the graph.
auto *Real = OperationInstruction[i]; | ||
// We want to check that we have 2 operands, but the function attributes | ||
// being counted as operands bloats this value. | ||
if (Real->getNumOperands() < 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like we are missing a check to ensure we only have integer types here. With FP types, we currently crash due to creating `llvm.reduce.add` reductions for floating point types. It also doesn't check for the right FP flags AFAICT.
Put up #139469 to bail out.
The Complex Deinterleaving pass assumes that all values emitted will result in complex numbers, this patch aims to remove that assumption and adds support for emitting just the real or imaginary components, not both.