[RISCV] Initial codegen support for zvqdotq extension #137039
Conversation
This patch adds pattern matching for the basic usages of the dot product instructions introduced by the experimental zvqdotq extension. It specifically only handles the case where the pattern feeds an i32 sum reduction, since we need to reassociate the reduction tree to use these instructions. The vecreduce_add (sext) and vecreduce_add (zext) cases are included mostly to exercise the VX matchers; for the generic matching, we fail to match due to a combine-ordering issue which results in the bitcast being separated from the splat. I chose to do this lowering as an early combine to avoid having to integrate the entire logic into the reduction lowering flow. In particular, that would get a lot more complicated as we extend this to handle add-trees feeding the reductions.
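For reference, a minimal LLVM IR sketch (mirroring the vqdot_vv test added in this patch) of the shape the new combine rewrites: a mul of two sign-extended i8 vectors feeding an i32 sum reduction.

```llvm
define i32 @vqdot_vv(<16 x i8> %a, <16 x i8> %b) {
entry:
  ; With +experimental-zvqdotq this now selects vqdot.vv instead of a
  ; widening multiply followed by a vredsum (see the test diff below).
  %a.sext = sext <16 x i8> %a to <16 x i32>
  %b.sext = sext <16 x i8> %b to <16 x i32>
  %mul = mul <16 x i32> %a.sext, %b.sext
  %res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %mul)
  ret i32 %res
}
```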
@llvm/pr-subscribers-backend-risc-v Author: Philip Reames (preames). Changes: as described above. Full diff: https://github.com/llvm/llvm-project/pull/137039.diff. 4 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 11f2095ac9bce..f0c80da123fb1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -6956,7 +6956,7 @@ static bool hasPassthruOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(
- RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 133 &&
+ RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 136 &&
RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
"adding target specific op should update this function");
if (Opcode >= RISCVISD::ADD_VL && Opcode <= RISCVISD::VFMAX_VL)
@@ -6980,7 +6980,7 @@ static bool hasMaskOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(
- RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 133 &&
+ RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 136 &&
RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
"adding target specific op should update this function");
if (Opcode >= RISCVISD::TRUNCATE_VECTOR_VL && Opcode <= RISCVISD::SETCC_VL)
@@ -18003,6 +18003,118 @@ static SDValue performBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
DAG.getBuildVector(VT, DL, RHSOps));
}
+static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
+ const SDLoc &DL, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ assert(RISCVISD::VQDOT_VL == Opc || RISCVISD::VQDOTU_VL == Opc ||
+ RISCVISD::VQDOTSU_VL == Opc);
+ MVT VT = Op0.getSimpleValueType();
+ assert(VT == Op1.getSimpleValueType() &&
+ VT.getVectorElementType() == MVT::i32);
+
+ assert(VT.isFixedLengthVector());
+ MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
+ SDValue Passthru = convertToScalableVector(
+ ContainerVT, DAG.getConstant(0, DL, VT), DAG, Subtarget);
+ Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
+ Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
+
+ auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
+ const unsigned Policy = RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC;
+ SDValue PolicyOp = DAG.getTargetConstant(Policy, DL, Subtarget.getXLenVT());
+ SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT,
+ {Op0, Op1, Passthru, Mask, VL, PolicyOp});
+ return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
+}
+
+static MVT getQDOTXResultType(MVT OpVT) {
+ ElementCount OpEC = OpVT.getVectorElementCount();
+ assert(OpEC.isKnownMultipleOf(4) && OpVT.getVectorElementType() == MVT::i8);
+ return MVT::getVectorVT(MVT::i32, OpEC.divideCoefficientBy(4));
+}
+
+static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
+ SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget,
+ const RISCVTargetLowering &TLI) {
+ // Note: We intentionally do not check the legality of the reduction type.
+ // We want to handle the m4/m8 *src* types, and thus need to let illegal
+ // intermediate types flow through here.
+ if (InVec.getValueType().getVectorElementType() != MVT::i32 ||
+ !InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
+ return SDValue();
+
+ // reduce (sext a) <--> reduce (mul sext a, sext 1)
+ // reduce (zext a) <--> reduce (mul zext a, zext 1)
+ if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
+ InVec.getOpcode() == ISD::SIGN_EXTEND) {
+ SDValue A = InVec.getOperand(0);
+ if (A.getValueType().getVectorElementType() != MVT::i8 ||
+ !TLI.isTypeLegal(A.getValueType()))
+ return SDValue();
+
+ MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
+ A = DAG.getBitcast(ResVT, A);
+ SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
+
+ bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
+ unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
+ return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+ }
+
+ // mul (sext, sext) -> vqdot
+ // mul (zext, zext) -> vqdotu
+ // mul (sext, zext) -> vqdotsu
+ // mul (zext, sext) -> vqdotsu (swapped)
+ // TODO: Improve .vx handling - we end up with a sub-vector insert
+ // which confuses the splat pattern matching. Also, match vqdotus.vx
+ if (InVec.getOpcode() != ISD::MUL)
+ return SDValue();
+
+ SDValue A = InVec.getOperand(0);
+ SDValue B = InVec.getOperand(1);
+ unsigned Opc = 0;
+ if (A.getOpcode() == B.getOpcode()) {
+ if (A.getOpcode() == ISD::SIGN_EXTEND)
+ Opc = RISCVISD::VQDOT_VL;
+ else if (A.getOpcode() == ISD::ZERO_EXTEND)
+ Opc = RISCVISD::VQDOTU_VL;
+ else
+ return SDValue();
+ } else {
+ if (B.getOpcode() != ISD::ZERO_EXTEND)
+ std::swap(A, B);
+ if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
+ return SDValue();
+ Opc = RISCVISD::VQDOTSU_VL;
+ }
+ assert(Opc);
+
+ if (A.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
+ A.getOperand(0).getValueType() != B.getOperand(0).getValueType() ||
+ !TLI.isTypeLegal(A.getValueType()))
+ return SDValue();
+
+ MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType());
+ A = DAG.getBitcast(ResVT, A.getOperand(0));
+ B = DAG.getBitcast(ResVT, B.getOperand(0));
+ return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+}
+
+static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget,
+ const RISCVTargetLowering &TLI) {
+ if (!Subtarget.hasStdExtZvqdotq())
+ return SDValue();
+
+ SDLoc DL(N);
+ MVT VT = N->getSimpleValueType(0);
+ SDValue InVec = N->getOperand(0);
+ if (SDValue V = foldReduceOperandViaVQDOT(InVec, DL, DAG, Subtarget, TLI))
+ return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, V);
+ return SDValue();
+}
+
static SDValue performINSERT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
const RISCVTargetLowering &TLI) {
@@ -19779,8 +19891,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return SDValue();
}
- case ISD::CTPOP:
case ISD::VECREDUCE_ADD:
+ if (SDValue V = performVECREDUCECombine(N, DAG, Subtarget, *this))
+ return V;
+ [[fallthrough]];
+ case ISD::CTPOP:
if (SDValue V = combineToVCPOP(N, DAG, Subtarget))
return V;
break;
@@ -22306,6 +22421,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(RI_VZIP2B_VL)
NODE_NAME_CASE(RI_VUNZIP2A_VL)
NODE_NAME_CASE(RI_VUNZIP2B_VL)
+ NODE_NAME_CASE(VQDOT_VL)
+ NODE_NAME_CASE(VQDOTU_VL)
+ NODE_NAME_CASE(VQDOTSU_VL)
NODE_NAME_CASE(READ_CSR)
NODE_NAME_CASE(WRITE_CSR)
NODE_NAME_CASE(SWAP_CSR)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 6e50ab8e1f296..c59baf1fd3e58 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -412,7 +412,12 @@ enum NodeType : unsigned {
RI_VUNZIP2A_VL,
RI_VUNZIP2B_VL,
- LAST_VL_VECTOR_OP = RI_VUNZIP2B_VL,
+ // zvqdot instructions with additional passthru, mask and VL operands
+ VQDOT_VL,
+ VQDOTU_VL,
+ VQDOTSU_VL,
+
+ LAST_VL_VECTOR_OP = VQDOTSU_VL,
// Read VLENB CSR
READ_VLENB,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
index 205fffd5115ee..25192ec7e0db2 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
@@ -26,3 +26,34 @@ let Predicates = [HasStdExtZvqdotq] in {
def VQDOTSU_VX : VALUVX<0b101010, OPMVX, "vqdotsu.vx">;
def VQDOTUS_VX : VALUVX<0b101110, OPMVX, "vqdotus.vx">;
} // Predicates = [HasStdExtZvqdotq]
+
+
+def riscv_vqdot_vl : SDNode<"RISCVISD::VQDOT_VL", SDT_RISCVIntBinOp_VL>;
+def riscv_vqdotu_vl : SDNode<"RISCVISD::VQDOTU_VL", SDT_RISCVIntBinOp_VL>;
+def riscv_vqdotsu_vl : SDNode<"RISCVISD::VQDOTSU_VL", SDT_RISCVIntBinOp_VL>;
+
+multiclass VPseudoVQDOT_VV_VX {
+ foreach m = MxSet<32>.m in {
+ defm "" : VPseudoBinaryV_VV<m>,
+ SchedBinary<"WriteVIALUV", "ReadVIALUV", "ReadVIALUV", m.MX,
+ forcePassthruRead=true>;
+ defm "" : VPseudoBinaryV_VX<m>,
+ SchedBinary<"WriteVIALUX", "ReadVIALUV", "ReadVIALUX", m.MX,
+ forcePassthruRead=true>;
+ }
+}
+
+// TODO: Add pseudo and patterns for vqdotus.vx
+let Predicates = [HasStdExtZvqdotq], mayLoad = 0, mayStore = 0,
+ hasSideEffects = 0 in {
+ defm PseudoVQDOT : VPseudoVQDOT_VV_VX;
+ defm PseudoVQDOTU : VPseudoVQDOT_VV_VX;
+ defm PseudoVQDOTSU : VPseudoVQDOT_VV_VX;
+}
+
+
+defvar AllE32Vectors = [VI32MF2, VI32M1, VI32M2, VI32M4, VI32M8];
+defm : VPatBinaryVL_VV_VX<riscv_vqdot_vl, "PseudoVQDOT", AllE32Vectors>;
+defm : VPatBinaryVL_VV_VX<riscv_vqdotu_vl, "PseudoVQDOTU", AllE32Vectors>;
+defm : VPatBinaryVL_VV_VX<riscv_vqdotsu_vl, "PseudoVQDOTSU", AllE32Vectors>;
+
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 25192ea19aab3..e48bc9cdfea4e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -1,21 +1,31 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
define i32 @vqdot_vv(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdot_vv:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v12, v8
-; CHECK-NEXT: vsext.vf2 v14, v9
-; CHECK-NEXT: vwmul.vv v8, v12, v14
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v12, zero
-; CHECK-NEXT: vredsum.vs v8, v8, v12
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdot_vv:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v12, v8
+; NODOT-NEXT: vsext.vf2 v14, v9
+; NODOT-NEXT: vwmul.vv v8, v12, v14
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v12, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v12
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdot.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.sext = sext <16 x i8> %b to <16 x i32>
@@ -63,17 +73,27 @@ entry:
}
define i32 @vqdotu_vv(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdotu_vv:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
-; CHECK-NEXT: vwmulu.vv v10, v8, v9
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT: vwredsumu.vs v8, v10, v8
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdotu_vv:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; NODOT-NEXT: vwmulu.vv v10, v8, v9
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vsetvli zero, zero, e16, m2, ta, ma
+; NODOT-NEXT: vwredsumu.vs v8, v10, v8
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotu_vv:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotu.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.zext = zext <16 x i8> %a to <16 x i32>
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -102,17 +122,27 @@ entry:
}
define i32 @vqdotsu_vv(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdotsu_vv:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v12, v8
-; CHECK-NEXT: vzext.vf2 v14, v9
-; CHECK-NEXT: vwmulsu.vv v8, v12, v14
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v12, zero
-; CHECK-NEXT: vredsum.vs v8, v8, v12
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdotsu_vv:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v12, v8
+; NODOT-NEXT: vzext.vf2 v14, v9
+; NODOT-NEXT: vwmulsu.vv v8, v12, v14
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v12, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v12
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotsu_vv:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotsu.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -122,17 +152,27 @@ entry:
}
define i32 @vqdotsu_vv_swapped(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdotsu_vv_swapped:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v12, v8
-; CHECK-NEXT: vzext.vf2 v14, v9
-; CHECK-NEXT: vwmulsu.vv v8, v12, v14
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v12, zero
-; CHECK-NEXT: vredsum.vs v8, v8, v12
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdotsu_vv_swapped:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v12, v8
+; NODOT-NEXT: vzext.vf2 v14, v9
+; NODOT-NEXT: vwmulsu.vv v8, v12, v14
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v12, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v12
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotsu_vv_swapped:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotsu.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -181,14 +221,38 @@ entry:
}
define i32 @reduce_of_sext(<16 x i8> %a) {
-; CHECK-LABEL: reduce_of_sext:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vsext.vf4 v12, v8
-; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vredsum.vs v8, v12, v8
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: reduce_of_sext:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; NODOT-NEXT: vsext.vf4 v12, v8
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v12, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT32-LABEL: reduce_of_sext:
+; DOT32: # %bb.0: # %entry
+; DOT32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT32-NEXT: vmv.v.i v9, 0
+; DOT32-NEXT: lui a0, 4112
+; DOT32-NEXT: addi a0, a0, 257
+; DOT32-NEXT: vqdot.vx v9, v8, a0
+; DOT32-NEXT: vmv.s.x v8, zero
+; DOT32-NEXT: vredsum.vs v8, v9, v8
+; DOT32-NEXT: vmv.x.s a0, v8
+; DOT32-NEXT: ret
+;
+; DOT64-LABEL: reduce_of_sext:
+; DOT64: # %bb.0: # %entry
+; DOT64-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT64-NEXT: vmv.v.i v9, 0
+; DOT64-NEXT: lui a0, 4112
+; DOT64-NEXT: addiw a0, a0, 257
+; DOT64-NEXT: vqdot.vx v9, v8, a0
+; DOT64-NEXT: vmv.s.x v8, zero
+; DOT64-NEXT: vredsum.vs v8, v9, v8
+; DOT64-NEXT: vmv.x.s a0, v8
+; DOT64-NEXT: ret
entry:
%a.ext = sext <16 x i8> %a to <16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
@@ -196,14 +260,38 @@ entry:
}
define i32 @reduce_of_zext(<16 x i8> %a) {
-; CHECK-LABEL: reduce_of_zext:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vzext.vf4 v12, v8
-; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vredsum.vs v8, v12, v8
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: reduce_of_zext:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; NODOT-NEXT: vzext.vf4 v12, v8
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v12, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT32-LABEL: reduce_of_zext:
+; DOT32: # %bb.0: # %entry
+; DOT32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT32-NEXT: vmv.v.i v9, 0
+; DOT32-NEXT: lui a0, 4112
+; DOT32-NEXT: addi a0, a0, 257
+; DOT32-NEXT: vqdotu.vx v9, v8, a0
+; DOT32-NEXT: vmv.s.x v8, zero
+; DOT32-NEXT: vredsum.vs v8, v9, v8
+; DOT32-NEXT: vmv.x.s a0, v8
+; DOT32-NEXT: ret
+;
+; DOT64-LABEL: reduce_of_zext:
+; DOT64: # %bb.0: # %entry
+; DOT64-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT64-NEXT: vmv.v.i v9, 0
+; DOT64-NEXT: lui a0, 4112
+; DOT64-NEXT: addiw a0, a0, 257
+; DOT64-NEXT: vqdotu.vx v9, v8, a0
+; DOT64-NEXT: vmv.s.x v8, zero
+; DOT64-NEXT: vredsum.vs v8, v9, v8
+; DOT64-NEXT: vmv.x.s a0, v8
+; DOT64-NEXT: ret
entry:
%a.ext = zext <16 x i8> %a to <16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
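A note on the sext/zext-only reductions in the tests above: the combine rewrites reduce (sext a) as a dot product of a (bitcast to i32 lanes) against the constant 0x01010101, i.e. four i8 ones packed into each i32 lane. That is why the DOT check lines materialize the constant via lui a0, 4112 and addi a0, a0, 257 (4112 << 12 | 257 == 0x01010101) and then use vqdot.vx. In generic IR, the equivalence being exploited is roughly:

```llvm
define i32 @reduce_of_sext_as_dot(<16 x i8> %a) {
entry:
  ; reduce (sext a) == reduce (mul (sext a), 1): multiplying by one is a
  ; no-op, but it exposes the dot-product shape that lowers to vqdot.vx.
  %a.sext = sext <16 x i8> %a to <16 x i32>
  %mul = mul <16 x i32> %a.sext, splat (i32 1)
  %res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %mul)
  ret i32 %res
}
```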
; NODOT-LABEL: reduce_of_sext:
; NODOT: # %bb.0: # %entry
; NODOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
; NODOT-NEXT: vsext.vf4 v12, v8
Not related to this patch, but is this a missed opportunity to use vsext.vf2+vwredsum?
Yep, agreed. I was a bit surprised by the codegen as I swear I've seen us do that in workloads.
    return SDValue();
  Opc = RISCVISD::VQDOTSU_VL;
}
assert(Opc);
will this ever get triggered?
It shouldn't with the current code structure, but asserts exist to check assumptions?
// mul (zext, sext) -> vqdotsu (swapped)
// TODO: Improve .vx handling - we end up with a sub-vector insert
// which confuses the splat pattern matching. Also, match vqdotus.vx
if (InVec.getOpcode() != ISD::MUL)
just curious, what about left shifts?
Annoyingly complicated, possible future work.
The problem is that we have to expand the shift as a multiply by 2^N, and the range of shift amounts we can handle is very limited due to the input being an i8.
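For illustration, a hypothetical shift variant (not handled by this patch) showing why it reduces to the multiply case:

```llvm
define i32 @reduce_of_shl(<16 x i8> %a) {
entry:
  ; Sketch only: shl by N on the extended value is a multiply by 2^N, so it
  ; could in principle feed vqdot.vx with a splat of that constant. The
  ; catch is that 2^N must still fit in an i8 lane, which sharply bounds
  ; the usable shift amounts.
  %a.sext = sext <16 x i8> %a to <16 x i32>
  %shl = shl <16 x i32> %a.sext, splat (i32 2)
  %res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %shl)
  ret i32 %res
}
```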
p.s. I don't have a real case which benefits from this, if you do, please share and I'll re-prioritize.
LGTM other than the comment about adding a FIXME for isCommutable.
LLVM Buildbot has detected new failures on several builders after this change. Full details are available at:
https://lab.llvm.org/buildbot/#/builders/162/builds/21818
https://lab.llvm.org/buildbot/#/builders/159/builds/21691
https://lab.llvm.org/buildbot/#/builders/175/builds/18068
https://lab.llvm.org/buildbot/#/builders/185/builds/17855
https://lab.llvm.org/buildbot/#/builders/137/builds/18100
https://lab.llvm.org/buildbot/#/builders/153/builds/30986
https://lab.llvm.org/buildbot/#/builders/33/builds/16045
https://lab.llvm.org/buildbot/#/builders/16/builds/18545
https://lab.llvm.org/buildbot/#/builders/56/builds/25226
https://lab.llvm.org/buildbot/#/builders/60/builds/26713
https://lab.llvm.org/buildbot/#/builders/92/builds/18330
The GitHub interface for this is now really confusing. This change was not reverted. I'd considered doing so, but the web interface got confused, so I switched to just pushing a fix. I fixed forward the build failure in 5f7213e. The build breakage was because I'd merged main to resolve a conflict, and then forgot to commit the final edits on the merge before pushing to the PR and merging via the web interface. These half-terminal, half-web-interface flows are extremely error-prone.