Skip to content

[RISCV] Fold add_vl into accumulator operand of vqdot* #139484

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 12, 2025

Conversation

preames
Copy link
Collaborator

@preames preames commented May 12, 2025

If we have an add_vl following a vqdot* instruction, we can move the add before the vqdot instead. For cases where the prior accumulator was zero, we can fold the add into the vqdot* instruction entirely. This directly parallels the folding we do for multiply-add variants.

If we have an add_vl following a vqdot* instruction, we can move the
add before the vqdot instead.  For cases where the prior accumulator
was zero, we can fold the add into the vqdot* instruction entirely.
This directly parallels the folding we do for multiply-add variants.
@preames preames requested review from lukel97 and topperc May 12, 2025 00:22
@preames preames changed the title [RISCV] Fold add_vl into accumulator opeerand of vqdot* [RISCV] Fold add_vl into accumulator operand of vqdot* May 12, 2025
@llvmbot
Copy link
Member

llvmbot commented May 12, 2025

@llvm/pr-subscribers-backend-risc-v

Author: Philip Reames (preames)

Changes

If we have an add_vl following a vqdot* instruction, we can move the add before the vqdot instead. For cases where the prior accumulator was zero, we can fold the add into the vqdot* instruction entirely. This directly parallels the folding we do for multiply-add variants.


Full diff: https://github.com/llvm/llvm-project/pull/139484.diff

2 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+70-3)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll (+12-17)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c53550ea3b23b..93aabdc004b42 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18459,9 +18459,74 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(Opc, DL, VT, Ops);
 }
 
-static bool legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index,
-                                           ISD::MemIndexType &IndexType,
-                                           RISCVTargetLowering::DAGCombinerInfo &DCI) {
+static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
+                                 const RISCVSubtarget &Subtarget) {
+
+  assert(N->getOpcode() == RISCVISD::ADD_VL);
+
+  if (!N->getValueType(0).isVector())
+    return SDValue();
+
+  SDValue Addend = N->getOperand(0);
+  SDValue DotOp = N->getOperand(1);
+
+  SDValue AddPassthruOp = N->getOperand(2);
+  if (!AddPassthruOp.isUndef())
+    return SDValue();
+
+  auto IsVqdotqOpc = [](unsigned Opc) {
+    switch (Opc) {
+    case RISCVISD::VQDOT_VL:
+    case RISCVISD::VQDOTU_VL:
+    case RISCVISD::VQDOTSU_VL:
+      return true;
+    default:
+      return false;
+    }
+  };
+
+  if (!IsVqdotqOpc(DotOp.getOpcode()))
+    std::swap(Addend, DotOp);
+
+  if (!IsVqdotqOpc(DotOp.getOpcode()))
+    return SDValue();
+
+  SDValue AddMask = N->getOperand(3);
+  SDValue AddVL = N->getOperand(4);
+
+  SDValue MulVL = DotOp.getOperand(4);
+  if (AddVL != MulVL)
+    return SDValue();
+
+  if (AddMask.getOpcode() != RISCVISD::VMSET_VL ||
+      AddMask.getOperand(0) != MulVL)
+    return SDValue();
+
+  SDValue AccumOp = DotOp.getOperand(2);
+  bool IsNullAdd = ISD::isConstantSplatVectorAllZeros(AccumOp.getNode());
+  // Peek through fixed to scalable
+  if (!IsNullAdd && AccumOp.getOpcode() == ISD::INSERT_SUBVECTOR &&
+      AccumOp.getOperand(0).isUndef())
+    IsNullAdd =
+        ISD::isConstantSplatVectorAllZeros(AccumOp.getOperand(1).getNode());
+
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  // The manual constant folding is required because this case is not constant
+  // folded or combined.
+  if (!IsNullAdd)
+    Addend = DAG.getNode(RISCVISD::ADD_VL, DL, VT, AccumOp, Addend,
+                         DAG.getUNDEF(VT), AddMask, AddVL);
+
+  SDValue Ops[] = {DotOp.getOperand(0), DotOp.getOperand(1), Addend,
+                   DotOp.getOperand(3), DotOp->getOperand(4)};
+  return DAG.getNode(DotOp->getOpcode(), DL, VT, Ops);
+}
+
+static bool
+legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index,
+                               ISD::MemIndexType &IndexType,
+                               RISCVTargetLowering::DAGCombinerInfo &DCI) {
   if (!DCI.isBeforeLegalize())
     return false;
 
@@ -19582,6 +19647,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   case RISCVISD::ADD_VL:
     if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
       return V;
+    if (SDValue V = combineVqdotAccum(N, DAG, Subtarget))
+      return V;
     return combineToVWMACC(N, DAG, Subtarget);
   case RISCVISD::VWADD_W_VL:
   case RISCVISD::VWADDU_W_VL:
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index e5546ad404c1b..ff61ef82176e6 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -314,11 +314,10 @@ define i32 @vqdot_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
 ; DOT-LABEL: vqdot_vv_accum:
 ; DOT:       # %bb.0: # %entry
 ; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; DOT-NEXT:    vmv.v.i v10, 0
-; DOT-NEXT:    vqdot.vv v10, v8, v9
-; DOT-NEXT:    vadd.vv v8, v10, v12
+; DOT-NEXT:    vmv1r.v v16, v12
+; DOT-NEXT:    vqdot.vv v16, v8, v9
 ; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
-; DOT-NEXT:    vmv.v.v v12, v8
+; DOT-NEXT:    vmv.v.v v12, v16
 ; DOT-NEXT:    vmv.s.x v8, zero
 ; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
 ; DOT-NEXT:    vredsum.vs v8, v12, v8
@@ -349,11 +348,10 @@ define i32 @vqdotu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
 ; DOT-LABEL: vqdotu_vv_accum:
 ; DOT:       # %bb.0: # %entry
 ; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; DOT-NEXT:    vmv.v.i v10, 0
-; DOT-NEXT:    vqdotu.vv v10, v8, v9
-; DOT-NEXT:    vadd.vv v8, v10, v12
+; DOT-NEXT:    vmv1r.v v16, v12
+; DOT-NEXT:    vqdotu.vv v16, v8, v9
 ; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
-; DOT-NEXT:    vmv.v.v v12, v8
+; DOT-NEXT:    vmv.v.v v12, v16
 ; DOT-NEXT:    vmv.s.x v8, zero
 ; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
 ; DOT-NEXT:    vredsum.vs v8, v12, v8
@@ -384,11 +382,10 @@ define i32 @vqdotsu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
 ; DOT-LABEL: vqdotsu_vv_accum:
 ; DOT:       # %bb.0: # %entry
 ; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; DOT-NEXT:    vmv.v.i v10, 0
-; DOT-NEXT:    vqdotsu.vv v10, v8, v9
-; DOT-NEXT:    vadd.vv v8, v10, v12
+; DOT-NEXT:    vmv1r.v v16, v12
+; DOT-NEXT:    vqdotsu.vv v16, v8, v9
 ; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
-; DOT-NEXT:    vmv.v.v v12, v8
+; DOT-NEXT:    vmv.v.v v12, v16
 ; DOT-NEXT:    vmv.s.x v8, zero
 ; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
 ; DOT-NEXT:    vredsum.vs v8, v12, v8
@@ -516,12 +513,10 @@ define i32 @vqdot_vv_split(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %
 ; DOT:       # %bb.0: # %entry
 ; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
 ; DOT-NEXT:    vmv.v.i v12, 0
-; DOT-NEXT:    vmv.v.i v13, 0
 ; DOT-NEXT:    vqdot.vv v12, v8, v9
-; DOT-NEXT:    vqdot.vv v13, v10, v11
-; DOT-NEXT:    vadd.vv v8, v12, v13
-; DOT-NEXT:    vmv.s.x v9, zero
-; DOT-NEXT:    vredsum.vs v8, v8, v9
+; DOT-NEXT:    vqdot.vv v12, v10, v11
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v12, v8
 ; DOT-NEXT:    vmv.x.s a0, v8
 ; DOT-NEXT:    ret
 entry:

Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@preames preames merged commit 6408291 into llvm:main May 12, 2025
7 of 10 checks passed
@preames preames deleted the pr-vqdotq-accumulator-folding branch May 12, 2025 22:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants