[Uniformity] Fixed control-div early stop #138806

Closed
wants to merge 1 commit into from

Conversation

jgu222
Contributor

@jgu222 jgu222 commented May 7, 2025

Control-divergence analysis finds joins by propagating labels from the
divergent control branch. The code that checks for the early stop of
propagation is not correct when a sequence of blocks that could be merged
into a single block stays unmerged.

This change fixes the issue by handling such a sequence of blocks as if
they were merged.

Note that the existing early-stop check is not optimal. Ideally,
the IPD (immediate post-dominator) should be the stop point, but this
algorithm can go beyond the IPD.
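
To make "a sequence of blocks that could be merged" concrete, here is a minimal sketch on a toy CFG. It is not the code of this patch: `ToyCFG`, `lastBlockOfChain`, and `makeExampleCFG` are hypothetical names, blocks are plain integers, and the example graph (b0 branching to the chains b1, b3, b5 and b2, b4, which meet in b6) is an assumption made for illustration.

```cpp
#include <cassert>
#include <vector>

// Hypothetical toy CFG: blocks are indices, edges are adjacency lists.
struct ToyCFG {
  std::vector<std::vector<int>> Succs;
  std::vector<std::vector<int>> Preds;

  explicit ToyCFG(int NumBlocks) : Succs(NumBlocks), Preds(NumBlocks) {}

  void addEdge(int From, int To) {
    Succs[From].push_back(To);
    Preds[To].push_back(From);
  }
};

// Follow the straight-line chain that starts at Start and return its last
// block. A block extends the chain only if it is the sole successor of the
// previous block and has that block as its sole predecessor, i.e. the two
// could be merged without changing control flow.
int lastBlockOfChain(const ToyCFG &G, int Start) {
  assert(Start >= 0 && Start < static_cast<int>(G.Succs.size()));
  int Last = Start;
  if (G.Preds[Last].size() != 1)
    return Last;
  while (G.Succs[Last].size() == 1) {
    int Next = G.Succs[Last][0];
    if (G.Preds[Next].size() != 1)
      break;
    if (Next == Start) // guard against an (unreachable) closed ring of blocks
      break;
    Last = Next;
  }
  return Last;
}

// Assumed example: 0 = b0 (divergent branch), chains 1 -> 3 -> 5 and 2 -> 4,
// both ending in the join block 6.
ToyCFG makeExampleCFG() {
  ToyCFG G(7);
  G.addEdge(0, 1);
  G.addEdge(0, 2);
  G.addEdge(1, 3);
  G.addEdge(3, 5);
  G.addEdge(2, 4);
  G.addEdge(5, 6);
  G.addEdge(4, 6);
  return G;
}
```

In this toy graph, `lastBlockOfChain(G, 1)` walks b1, b3, b5 and stops at b5 because b6 has two predecessors, so a propagation step could treat the whole chain as one block.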

@llvmbot
Member

llvmbot commented May 7, 2025

@llvm/pr-subscribers-llvm-adt

Author: Junjie Gu (jgu222)

Changes

Control-divergence analysis finds joins by propagating labels from the divergent control branch. The code that checks whether the early stop has been reached is not correct.

This change fixes the issue by checking whether a join has been reached. The propagation itself is unchanged: it starts by adding the successors of the divergent block to a set, and then propagates from one block of the set at a time, in topological order.
If there is only one block in the set, then this block must be a join block, and thus propagation stops.
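
The stop rule described above can be sketched on a toy model. This is a simplified illustration, not LLVM's `DivergencePropagator`: `JoinFinder` is a hypothetical type, the CFG is assumed to be acyclic, and block indices are assumed to already be in topological order, so the smallest pending index is always processed next.

```cpp
#include <cassert>
#include <set>
#include <utility>
#include <vector>

// Hypothetical toy model of label propagation on an acyclic CFG whose block
// indices are already in topological order. Not LLVM's data structures.
struct JoinFinder {
  std::vector<std::vector<int>> Succs;
  std::vector<int> Label; // -1 means "no label yet"
  std::set<int> Fresh;    // blocks whose label still has to be propagated
  std::set<int> Joins;    // blocks where two different labels met

  explicit JoinFinder(std::vector<std::vector<int>> S)
      : Succs(std::move(S)), Label(Succs.size(), -1) {}

  void visitEdge(int Succ, int Pushed) {
    int Old = Label[Succ];
    if (Old == Pushed)
      return; // same label again: nothing new
    if (Old != Succ)
      Fresh.insert(Succ); // the label changes, so Succ must be processed
    if (Old == -1) {
      Label[Succ] = Pushed; // first label to arrive
      return;
    }
    Label[Succ] = Succ; // two different labels meet: Succ is a join
    Joins.insert(Succ);
  }

  std::set<int> run(int DivBlock) {
    assert(DivBlock >= 0 && DivBlock < static_cast<int>(Succs.size()));
    // Bootstrap: every branch target starts with itself as its label.
    for (int S : Succs[DivBlock])
      visitEdge(S, S);
    // Stop as soon as at most one block is pending: only one label is
    // still in flight, so it cannot meet a different one further down.
    while (Fresh.size() > 1) {
      int B = *Fresh.begin();
      Fresh.erase(Fresh.begin());
      int L = Label[B];
      for (int S : Succs[B])
        visitEdge(S, L);
    }
    return Joins;
  }
};
```

For a diamond with successor lists `{{1, 2}, {3}, {3}, {4}, {}}`, `run(0)` reports block 3 as the only join and never visits block 4. The sketch deliberately ignores cycles; the conversation on this pull request notes that irreducible cycles need additional care.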


Full diff: https://github.com/llvm/llvm-project/pull/138806.diff

1 Files Affected:

  • (modified) llvm/include/llvm/ADT/GenericUniformityImpl.h (+13-32)
diff --git a/llvm/include/llvm/ADT/GenericUniformityImpl.h b/llvm/include/llvm/ADT/GenericUniformityImpl.h
index d10355fff1bea..b373710340181 100644
--- a/llvm/include/llvm/ADT/GenericUniformityImpl.h
+++ b/llvm/include/llvm/ADT/GenericUniformityImpl.h
@@ -610,9 +610,6 @@ template <typename ContextT> class DivergencePropagator {
     LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: "
                       << Context.print(&DivTermBlock) << "\n");
 
-    // Early stopping criterion
-    int FloorIdx = CyclePOT.size() - 1;
-    const BlockT *FloorLabel = nullptr;
     int DivTermIdx = CyclePOT.getIndex(&DivTermBlock);
 
     // Bootstrap with branch targets
@@ -626,15 +623,19 @@ template <typename ContextT> class DivergencePropagator {
         LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: "
                           << Context.print(SuccBlock) << "\n");
       }
-      auto SuccIdx = CyclePOT.getIndex(SuccBlock);
       visitEdge(*SuccBlock, *SuccBlock);
-      FloorIdx = std::min<int>(FloorIdx, SuccIdx);
     }
 
-    while (true) {
+
+    // Propagation shall stop at the IPD (immediate post-dominator)
+    // of DivTermBlock.
+    //
+    // If the number of blocks in FreshLabels is one, the block in FreshLabels
+    // must be a PD. As propagation follows RPO, the first PD reached should
+    // be IPD.
+    while (FreshLabels.count() > 1) {
       auto BlockIdx = FreshLabels.find_last();
-      if (BlockIdx == -1 || BlockIdx < FloorIdx)
-        break;
+      assert(BlockIdx >= 0);
 
       LLVM_DEBUG(dbgs() << "Current labels:\n"; printDefs(dbgs()));
 
@@ -651,9 +652,6 @@ template <typename ContextT> class DivergencePropagator {
       const auto *Label = BlockLabels[Block];
       assert(Label);
 
-      bool CausedJoin = false;
-      int LoweredFloorIdx = FloorIdx;
-
       // If the current block is the header of a reducible cycle that
       // contains the divergent branch, then the label should be
       // propagated to the cycle exits. Such a header is the "last
@@ -681,28 +679,11 @@ template <typename ContextT> class DivergencePropagator {
       if (const auto *BlockCycle = getReducibleParent(Block)) {
         SmallVector<BlockT *, 4> BlockCycleExits;
         BlockCycle->getExitBlocks(BlockCycleExits);
-        for (auto *BlockCycleExit : BlockCycleExits) {
-          CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label);
-          LoweredFloorIdx =
-              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(BlockCycleExit));
-        }
+        for (auto *BlockCycleExit : BlockCycleExits)
+          visitCycleExitEdge(*BlockCycleExit, *Label);
       } else {
-        for (const auto *SuccBlock : successors(Block)) {
-          CausedJoin |= visitEdge(*SuccBlock, *Label);
-          LoweredFloorIdx =
-              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(SuccBlock));
-        }
-      }
-
-      // Floor update
-      if (CausedJoin) {
-        // 1. Different labels pushed to successors
-        FloorIdx = LoweredFloorIdx;
-      } else if (FloorLabel != Label) {
-        // 2. No join caused BUT we pushed a label that is different than the
-        // last pushed label
-        FloorIdx = LoweredFloorIdx;
-        FloorLabel = Label;
+        for (const auto *SuccBlock : successors(Block))
+          visitEdge(*SuccBlock, *Label);
       }
     }
 

@jgu222 jgu222 marked this pull request as draft May 7, 2025 05:51

github-actions bot commented May 7, 2025

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff HEAD~1 HEAD --extensions h -- llvm/include/llvm/ADT/GenericUniformityImpl.h
View the diff from clang-format here.
diff --git a/llvm/include/llvm/ADT/GenericUniformityImpl.h b/llvm/include/llvm/ADT/GenericUniformityImpl.h
index d45ad6098..8b9cb524e 100644
--- a/llvm/include/llvm/ADT/GenericUniformityImpl.h
+++ b/llvm/include/llvm/ADT/GenericUniformityImpl.h
@@ -604,9 +604,9 @@ public:
     return true;
   }
 
-  // getRealSucc() gets all successors of \p SuccBlock that can be merged with it
-  // and returns the last one of them (called it real succ of SuccBlock).
-  // For example,
+  // getRealSucc() gets all successors of \p SuccBlock that can be merged with
+  // it and returns the last one of them (called it real succ of SuccBlock). For
+  // example,
   //
   //    div-b0
   //    /    \
@@ -626,13 +626,13 @@ public:
   // This is necessary as the algorithm of propagating control-divergence
   // assumes that CFG has been optimized so that (b1,b3,b5) and (b2, b4)
   // are merged into a single block, respectively.
-  const BlockT* getRealSucc(const BlockT& SuccBlock, const BlockT& Label) {
+  const BlockT *getRealSucc(const BlockT &SuccBlock, const BlockT &Label) {
     const BlockT *LastBlock = &SuccBlock;
     if (pred_size(LastBlock) != 1)
       return LastBlock;
 
     while (succ_size(LastBlock) == 1) {
-      const BlockT* NextBlock = *succ_begin(LastBlock);
+      const BlockT *NextBlock = *succ_begin(LastBlock);
 
       if (pred_size(NextBlock) != 1)
         break;
@@ -649,7 +649,7 @@ public:
       BlockLabels[LastBlock] = &Label;
       LastBlock = NextBlock;
     }
- 
+
     return LastBlock;
   }
 
@@ -675,7 +675,7 @@ public:
         LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: "
                           << Context.print(SuccBlock) << "\n");
       }
-      const auto* SuccBB = getRealSucc(*SuccBlock, *SuccBlock);
+      const auto *SuccBB = getRealSucc(*SuccBlock, *SuccBlock);
       auto SuccIdx = CyclePOT.getIndex(SuccBB);
       visitEdge(*SuccBB, *SuccBlock);
       FloorIdx = std::min<int>(FloorIdx, SuccIdx);
@@ -738,7 +738,7 @@ public:
         }
       } else {
         for (const auto *SuccBlock : successors(Block)) {
-          const auto* SuccBB = getRealSucc(*SuccBlock, *Label);
+          const auto *SuccBB = getRealSucc(*SuccBlock, *Label);
           CausedJoin |= visitEdge(*SuccBB, *Label);
           LoweredFloorIdx =
               std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(SuccBB));

@ssahasra
Collaborator

ssahasra commented May 7, 2025

Quoting the last comment from #137277:

> Yes, setting FloorIdx to the IPD of DivTermBlock seems to work. I have a similar approach (based on this IPD concept). It does not use FloorIdx; instead it checks FreshLabels.count() (replacing while (true) with while (FreshLabels.count() > 1)). [ draft: https://github.com/https://github.com//pull/138806 ]

> But my local patch has two failures on lit tests, both in irreducible tests. It looks to me as if the irreducible cycles do not have their headers rewired when propagation is done (not sure if they should be).

Correct. The analysis does not make any assumptions about how the header is connected in an irreducible cycle, and does not itself attempt any rewiring. The top of the file has a comment explaining the nature of the RPOT ModifiedPO numbering, where we "virtually" rewire the header, but not actually in the CFG. This is why we have to separately handle cycle exit blocks.

This is why using the RPOT ModifiedPO traversal and FreshLabels to find the IPD might be tricky. Maybe there is a way to do it, but I have not thought about it enough yet. If we can't work with RPOT ModifiedPO, then we will have to fetch the postdom analysis, which is kinda expensive.
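
For readers unfamiliar with the term, the IPD of a block can be computed directly from post-dominator sets. The following is a naive sketch under stated assumptions, not LLVM's post-dominator analysis: `computePostDomSets` and `immediatePostDom` are hypothetical helpers, the CFG is assumed to have a single exit, and every block is assumed to reach that exit.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

using BlockSet = std::vector<bool>;

// Iterative data-flow computation of post-dominator sets:
//   PDom(Exit) = {Exit}
//   PDom(B)    = {B} united with the intersection of PDom(S) over all
//                successors S of B
// Assumes a single exit block that every block can reach.
std::vector<BlockSet>
computePostDomSets(const std::vector<std::vector<int>> &Succs, int Exit) {
  const std::size_t N = Succs.size();
  std::vector<BlockSet> PDom(N, BlockSet(N, true));
  PDom[Exit].assign(N, false);
  PDom[Exit][Exit] = true;

  bool Changed = true;
  while (Changed) {
    Changed = false;
    for (std::size_t B = 0; B < N; ++B) {
      if (static_cast<int>(B) == Exit)
        continue;
      assert(!Succs[B].empty() && "every non-exit block needs a successor");
      BlockSet New(N, true);
      for (int S : Succs[B])
        for (std::size_t I = 0; I < N; ++I)
          New[I] = New[I] && PDom[S][I];
      New[B] = true;
      if (New != PDom[B]) {
        PDom[B] = std::move(New);
        Changed = true;
      }
    }
  }
  return PDom;
}

// The post-dominators of a block form a chain, so the immediate one is the
// strict post-dominator whose own set is exactly one element smaller.
// Returns -1 if the block has no strict post-dominator (the exit block).
int immediatePostDom(const std::vector<BlockSet> &PDom, int Block) {
  auto Count = [](const BlockSet &S) {
    std::size_t C = 0;
    for (bool X : S)
      C += X ? 1 : 0;
    return C;
  };
  const std::size_t Want = Count(PDom[Block]) - 1;
  for (std::size_t P = 0; P < PDom.size(); ++P)
    if (static_cast<int>(P) != Block && PDom[Block][P] &&
        Count(PDom[P]) == Want)
      return static_cast<int>(P);
  return -1;
}
```

On the diamond `{{1, 2}, {3}, {3}, {4}, {}}` with exit block 4, the IPD of block 0 is block 3. The repeated sweeps over all blocks give a sense of why running a separate post-dominator analysis just to bound the propagation adds cost.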

@jgu222 jgu222 marked this pull request as ready for review May 9, 2025 06:45
@jgu222 jgu222 closed this May 13, 2025
@jgu222
Contributor Author

jgu222 commented May 13, 2025

replaced with #139667

3 participants