Skip to content

[AMDGPU] Refine GCNHazardRecognizer hasHazard() #138841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

perlfu
Copy link
Contributor

@perlfu perlfu commented May 7, 2025

Remove recursion to avoid stack overflow on large CFGs.
Avoid worklist for hazard search within single MachineBasicBlock.
Ensure predecessors are visited for all state combinations.

Remove recursion to avoid stack overflow on large CFGs.
Avoid worklist for hazard search within single MachineBasicBlock.
Ensure predecessors are visited for all state combinations.
@perlfu perlfu requested review from jayfoad and rampitec May 7, 2025 10:22
@llvmbot
Copy link
Member

llvmbot commented May 7, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: Carl Ritson (perlfu)

Changes

Remove recursion to avoid stack overflow on large CFGs.
Avoid worklist for hazard search within single MachineBasicBlock.
Ensure predecessors are visited for all state combinations.


Full diff: https://github.com/llvm/llvm-project/pull/138841.diff

1 Files Affected:

  • (modified) llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp (+48-31)
diff --git a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
index aaefe27b1324f..644fbb77a495a 100644
--- a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
@@ -436,42 +436,55 @@ using IsExpiredFn = function_ref<bool(const MachineInstr &, int WaitStates)>;
 using GetNumWaitStatesFn = function_ref<unsigned int(const MachineInstr &)>;
 
 // Search for a hazard in a block and its predecessors.
+// StateT must implement getHashValue().
 template <typename StateT>
 static bool
-hasHazard(StateT State,
+hasHazard(StateT InitialState,
           function_ref<HazardFnResult(StateT &, const MachineInstr &)> IsHazard,
           function_ref<void(StateT &, const MachineInstr &)> UpdateState,
-          const MachineBasicBlock *MBB,
-          MachineBasicBlock::const_reverse_instr_iterator I,
-          DenseSet<const MachineBasicBlock *> &Visited) {
-  for (auto E = MBB->instr_rend(); I != E; ++I) {
-    // No need to look at parent BUNDLE instructions.
-    if (I->isBundle())
-      continue;
+          const MachineBasicBlock *InitialMBB,
+          MachineBasicBlock::const_reverse_instr_iterator InitialI) {
+  SmallVector<std::pair<const MachineBasicBlock *, StateT>> Worklist;
+  DenseSet<std::pair<const MachineBasicBlock *, unsigned>> Visited;
+  const MachineBasicBlock *MBB = InitialMBB;
+  StateT State = InitialState;
+  auto I = InitialI;
+
+  for (;;) {
+    bool Expired = false;
+    for (auto E = MBB->instr_rend(); I != E; ++I) {
+      // No need to look at parent BUNDLE instructions.
+      if (I->isBundle())
+        continue;
 
-    switch (IsHazard(State, *I)) {
-    case HazardFound:
-      return true;
-    case HazardExpired:
-      return false;
-    default:
-      // Continue search
-      break;
-    }
+      auto Result = IsHazard(State, *I);
+      if (Result == HazardFound)
+        return true;
+      if (Result == HazardExpired) {
+        Expired = true;
+        break;
+      }
 
-    if (I->isInlineAsm() || I->isMetaInstruction())
-      continue;
+      if (I->isInlineAsm() || I->isMetaInstruction())
+        continue;
 
-    UpdateState(State, *I);
-  }
+      UpdateState(State, *I);
+    }
 
-  for (MachineBasicBlock *Pred : MBB->predecessors()) {
-    if (!Visited.insert(Pred).second)
-      continue;
+    if (!Expired) {
+      unsigned StateHash = State.getHashValue();
+      for (MachineBasicBlock *Pred : MBB->predecessors()) {
+        if (!Visited.insert(std::pair(Pred, StateHash)).second)
+          continue;
+        Worklist.emplace_back(Pred, State);
+      }
+    }
 
-    if (hasHazard(State, IsHazard, UpdateState, Pred, Pred->instr_rbegin(),
-                  Visited))
-      return true;
+    if (Worklist.empty())
+      break;
+
+    std::tie(MBB, State) = Worklist.pop_back_val();
+    I = MBB->instr_rbegin();
   }
 
   return false;
@@ -1624,6 +1637,10 @@ bool GCNHazardRecognizer::fixVALUPartialForwardingHazard(MachineInstr *MI) {
     SmallDenseMap<Register, int, 4> DefPos;
     int ExecPos = std::numeric_limits<int>::max();
     int VALUs = 0;
+
+    unsigned getHashValue() const {
+      return hash_combine(ExecPos, VALUs, hash_combine_range(DefPos));
+    }
   };
 
   StateType State;
@@ -1718,9 +1735,8 @@ bool GCNHazardRecognizer::fixVALUPartialForwardingHazard(MachineInstr *MI) {
       State.VALUs += 1;
   };
 
-  DenseSet<const MachineBasicBlock *> Visited;
   if (!hasHazard<StateType>(State, IsHazardFn, UpdateStateFn, MI->getParent(),
-                            std::next(MI->getReverseIterator()), Visited))
+                            std::next(MI->getReverseIterator())))
     return false;
 
   BuildMI(*MI->getParent(), MI, MI->getDebugLoc(),
@@ -1761,6 +1777,8 @@ bool GCNHazardRecognizer::fixVALUTransUseHazard(MachineInstr *MI) {
   struct StateType {
     int VALUs = 0;
     int TRANS = 0;
+
+    unsigned getHashValue() const { return hash_combine(VALUs, TRANS); }
   };
 
   StateType State;
@@ -1796,9 +1814,8 @@ bool GCNHazardRecognizer::fixVALUTransUseHazard(MachineInstr *MI) {
       State.TRANS += 1;
   };
 
-  DenseSet<const MachineBasicBlock *> Visited;
   if (!hasHazard<StateType>(State, IsHazardFn, UpdateStateFn, MI->getParent(),
-                            std::next(MI->getReverseIterator()), Visited))
+                            std::next(MI->getReverseIterator())))
     return false;
 
   // Hazard is observed - insert a wait on va_dst counter to ensure hazard is

if (!Expired) {
unsigned StateHash = State.getHashValue();
for (MachineBasicBlock *Pred : MBB->predecessors()) {
if (!Visited.insert(std::pair(Pred, StateHash)).second)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't safe in the presence of hash collisions.

As an alternative, could StateT have a "merge" or "min" function which takes two states and chooses the most constrained one, or the most contraining combination of them? Then if we reach the same predecessor twice by different paths we can merge the states and continue with the merged state.

Then perhaps the worklist could be a MapVector<const MachineBasicBlock *, StateT> with no need for a separate Visited set.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrote this on the basis that is safer than current solution that never revisits a predecessor.

I had considered a solution similar to the one you propose.
But had performance concerns.

I'll work it through and test performance, but I do think we should consider a partial solution such as the one thing patch.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants