Skip to content

[mlir][OpenMP] convert wsloop cancellation to LLVMIR #137194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 8, 2025

Conversation

tblah
Copy link
Contributor

@tblah tblah commented Apr 24, 2025

Taskloop support will follow in a later patch.

@llvmbot
Copy link
Member

llvmbot commented Apr 24, 2025

@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-mlir-openmp

@llvm/pr-subscribers-mlir

Author: Tom Eccles (tblah)

Changes

Taskloop support will follow in a later patch.


Full diff: https://github.com/llvm/llvm-project/pull/137194.diff

3 Files Affected:

  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+38-2)
  • (modified) mlir/test/Target/LLVMIR/openmp-cancel.mlir (+87)
  • (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-16)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d1885641f389d..7d8a7ccb6e4ac 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -161,8 +161,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
   auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
     omp::ClauseCancellationConstructType cancelledDirective =
         op.getCancelDirective();
-    if (cancelledDirective != omp::ClauseCancellationConstructType::Parallel &&
-        cancelledDirective != omp::ClauseCancellationConstructType::Sections)
+    if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup)
       result = todo("cancel directive");
   };
   auto checkDepend = [&todo](auto op, LogicalResult &result) {
@@ -2360,6 +2359,30 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
           ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
           : llvm::omp::WorksharingLoopType::ForStaticLoop;
 
+  SmallVector<llvm::BranchInst *> cancelTerminators;
+  // This callback is invoked only if there is cancellation inside of the wsloop
+  // body.
+  auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
+    llvm::IRBuilderBase &llvmBuilder = ompBuilder->Builder;
+    llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
+
+    // ip is currently in the block branched to if cancellation occured.
+    // We need to create a branch to terminate that block.
+    llvmBuilder.restoreIP(ip);
+
+    // We must still clean up the wsloop after cancelling it, so we need to
+    // branch to the block that finalizes the wsloop.
+    // That block has not been created yet so use this block as a dummy for now
+    // and fix this after creating the wsloop.
+    cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
+    return llvm::Error::success();
+  };
+  // We have to add the cleanup to the OpenMPIRBuilder before the body gets
+  // created in case the body contains omp.cancel (which will then expect to be
+  // able to find this cleanup callback).
+  ompBuilder->pushFinalizationCB({finiCB, llvm::omp::Directive::OMPD_for,
+                                  constructIsCancellable(wsloopOp)});
+
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
   llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
       wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
@@ -2381,6 +2404,19 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(handleError(wsloopIP, opInst)))
     return failure();
 
+  ompBuilder->popFinalizationCB();
+  if (!cancelTerminators.empty()) {
+    // If we cancelled the loop, we should branch to the finalization block of
+    // the wsloop (which is always immediately before the loop continuation
+    // block). Now the finalization has been created, we can fix the branch.
+    llvm::BasicBlock *wsloopFini = wsloopIP->getBlock()->getSinglePredecessor();
+    for (llvm::BranchInst *cancelBranch : cancelTerminators) {
+      assert(cancelBranch->getNumSuccessors() == 1 &&
+             "cancel branch should have one target");
+      cancelBranch->setSuccessor(0, wsloopFini);
+    }
+  }
+
   // Process the reductions if required.
   if (failed(createReductionsAndCleanup(
           wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
diff --git a/mlir/test/Target/LLVMIR/openmp-cancel.mlir b/mlir/test/Target/LLVMIR/openmp-cancel.mlir
index fca16b936fc85..3c195a98d1000 100644
--- a/mlir/test/Target/LLVMIR/openmp-cancel.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-cancel.mlir
@@ -156,3 +156,90 @@ llvm.func @cancel_sections_if(%cond : i1) {
 // CHECK:         ret void
 // CHECK:       .cncl:                                            ; preds = %[[VAL_27]]
 // CHECK:         br label %[[VAL_19]]
+
+llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) {
+  omp.wsloop {
+    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+      omp.cancel cancellation_construct_type(loop) if(%cond)
+      omp.yield
+    }
+  }
+  llvm.return
+}
+// CHECK-LABEL: define void @cancel_wsloop_if
+// CHECK:         %[[VAL_0:.*]] = alloca i32, align 4
+// CHECK:         %[[VAL_1:.*]] = alloca i32, align 4
+// CHECK:         %[[VAL_2:.*]] = alloca i32, align 4
+// CHECK:         %[[VAL_3:.*]] = alloca i32, align 4
+// CHECK:         br label %[[VAL_4:.*]]
+// CHECK:       omp.region.after_alloca:                          ; preds = %[[VAL_5:.*]]
+// CHECK:         br label %[[VAL_6:.*]]
+// CHECK:       entry:                                            ; preds = %[[VAL_4]]
+// CHECK:         br label %[[VAL_7:.*]]
+// CHECK:       omp.wsloop.region:                                ; preds = %[[VAL_6]]
+// CHECK:         %[[VAL_8:.*]] = icmp slt i32 %[[VAL_9:.*]], 0
+// CHECK:         %[[VAL_10:.*]] = sub i32 0, %[[VAL_9]]
+// CHECK:         %[[VAL_11:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_10]], i32 %[[VAL_9]]
+// CHECK:         %[[VAL_12:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_13:.*]], i32 %[[VAL_14:.*]]
+// CHECK:         %[[VAL_15:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_14]], i32 %[[VAL_13]]
+// CHECK:         %[[VAL_16:.*]] = sub nsw i32 %[[VAL_15]], %[[VAL_12]]
+// CHECK:         %[[VAL_17:.*]] = icmp sle i32 %[[VAL_15]], %[[VAL_12]]
+// CHECK:         %[[VAL_18:.*]] = sub i32 %[[VAL_16]], 1
+// CHECK:         %[[VAL_19:.*]] = udiv i32 %[[VAL_18]], %[[VAL_11]]
+// CHECK:         %[[VAL_20:.*]] = add i32 %[[VAL_19]], 1
+// CHECK:         %[[VAL_21:.*]] = icmp ule i32 %[[VAL_16]], %[[VAL_11]]
+// CHECK:         %[[VAL_22:.*]] = select i1 %[[VAL_21]], i32 1, i32 %[[VAL_20]]
+// CHECK:         %[[VAL_23:.*]] = select i1 %[[VAL_17]], i32 0, i32 %[[VAL_22]]
+// CHECK:         br label %[[VAL_24:.*]]
+// CHECK:       omp_loop.preheader:                               ; preds = %[[VAL_7]]
+// CHECK:         store i32 0, ptr %[[VAL_1]], align 4
+// CHECK:         %[[VAL_25:.*]] = sub i32 %[[VAL_23]], 1
+// CHECK:         store i32 %[[VAL_25]], ptr %[[VAL_2]], align 4
+// CHECK:         store i32 1, ptr %[[VAL_3]], align 4
+// CHECK:         %[[VAL_26:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK:         call void @__kmpc_for_static_init_4u(ptr @1, i32 %[[VAL_26]], i32 34, ptr %[[VAL_0]], ptr %[[VAL_1]], ptr %[[VAL_2]], ptr %[[VAL_3]], i32 1, i32 0)
+// CHECK:         %[[VAL_27:.*]] = load i32, ptr %[[VAL_1]], align 4
+// CHECK:         %[[VAL_28:.*]] = load i32, ptr %[[VAL_2]], align 4
+// CHECK:         %[[VAL_29:.*]] = sub i32 %[[VAL_28]], %[[VAL_27]]
+// CHECK:         %[[VAL_30:.*]] = add i32 %[[VAL_29]], 1
+// CHECK:         br label %[[VAL_31:.*]]
+// CHECK:       omp_loop.header:                                  ; preds = %[[VAL_32:.*]], %[[VAL_24]]
+// CHECK:         %[[VAL_33:.*]] = phi i32 [ 0, %[[VAL_24]] ], [ %[[VAL_34:.*]], %[[VAL_32]] ]
+// CHECK:         br label %[[VAL_35:.*]]
+// CHECK:       omp_loop.cond:                                    ; preds = %[[VAL_31]]
+// CHECK:         %[[VAL_36:.*]] = icmp ult i32 %[[VAL_33]], %[[VAL_30]]
+// CHECK:         br i1 %[[VAL_36]], label %[[VAL_37:.*]], label %[[VAL_38:.*]]
+// CHECK:       omp_loop.body:                                    ; preds = %[[VAL_35]]
+// CHECK:         %[[VAL_39:.*]] = add i32 %[[VAL_33]], %[[VAL_27]]
+// CHECK:         %[[VAL_40:.*]] = mul i32 %[[VAL_39]], %[[VAL_9]]
+// CHECK:         %[[VAL_41:.*]] = add i32 %[[VAL_40]], %[[VAL_14]]
+// CHECK:         br label %[[VAL_42:.*]]
+// CHECK:       omp.loop_nest.region:                             ; preds = %[[VAL_37]]
+// CHECK:         br i1 %[[VAL_43:.*]], label %[[VAL_44:.*]], label %[[VAL_45:.*]]
+// CHECK:       25:                                               ; preds = %[[VAL_42]]
+// CHECK:         %[[VAL_46:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK:         %[[VAL_47:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_46]], i32 2)
+// CHECK:         %[[VAL_48:.*]] = icmp eq i32 %[[VAL_47]], 0
+// CHECK:         br i1 %[[VAL_48]], label %[[VAL_49:.*]], label %[[VAL_50:.*]]
+// CHECK:       .split:                                           ; preds = %[[VAL_44]]
+// CHECK:         br label %[[VAL_51:.*]]
+// CHECK:       28:                                               ; preds = %[[VAL_42]]
+// CHECK:         br label %[[VAL_51]]
+// CHECK:       29:                                               ; preds = %[[VAL_45]], %[[VAL_49]]
+// CHECK:         br label %[[VAL_52:.*]]
+// CHECK:       omp.region.cont1:                                 ; preds = %[[VAL_51]]
+// CHECK:         br label %[[VAL_32]]
+// CHECK:       omp_loop.inc:                                     ; preds = %[[VAL_52]]
+// CHECK:         %[[VAL_34]] = add nuw i32 %[[VAL_33]], 1
+// CHECK:         br label %[[VAL_31]]
+// CHECK:       omp_loop.exit:                                    ; preds = %[[VAL_50]], %[[VAL_35]]
+// CHECK:         call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_26]])
+// CHECK:         %[[VAL_53:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK:         call void @__kmpc_barrier(ptr @2, i32 %[[VAL_53]])
+// CHECK:         br label %[[VAL_54:.*]]
+// CHECK:       omp_loop.after:                                   ; preds = %[[VAL_38]]
+// CHECK:         br label %[[VAL_55:.*]]
+// CHECK:       omp.region.cont:                                  ; preds = %[[VAL_54]]
+// CHECK:         ret void
+// CHECK:       .cncl:                                            ; preds = %[[VAL_44]]
+// CHECK:         br label %[[VAL_38]]
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 6f904d0647285..f8d720dfe420c 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -26,22 +26,6 @@ llvm.func @atomic_hint(%v : !llvm.ptr, %x : !llvm.ptr, %expr : i32) {
 
 // -----
 
-llvm.func @cancel_wsloop(%lb : i32, %ub : i32, %step: i32) {
-  // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}
-  omp.wsloop {
-    // expected-error@below {{LLVM Translation failed for operation: omp.loop_nest}}
-    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
-      // expected-error@below {{not yet implemented: Unhandled clause cancel directive in omp.cancel operation}}
-      // expected-error@below {{LLVM Translation failed for operation: omp.cancel}}
-      omp.cancel cancellation_construct_type(loop)
-      omp.yield
-    }
-  }
-  llvm.return
-}
-
-// -----
-
 llvm.func @cancel_taskgroup() {
   // expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
   omp.taskgroup {

@tblah tblah force-pushed the users/tblah/omp-cancel-codegen-1 branch from aa2445b to 4d28bd8 Compare April 25, 2025 14:17
@tblah tblah force-pushed the users/tblah/omp-cancel-codegen-2 branch from bb374c9 to bb1af01 Compare April 25, 2025 14:47
@tblah tblah force-pushed the users/tblah/omp-cancel-codegen-1 branch from 4d28bd8 to f3b8eeb Compare April 26, 2025 11:48
@tblah tblah force-pushed the users/tblah/omp-cancel-codegen-2 branch from bb1af01 to 3a198c8 Compare April 26, 2025 11:49
@tblah tblah force-pushed the users/tblah/omp-cancel-codegen-1 branch from 60592cf to 6c678b7 Compare April 28, 2025 15:09
@tblah tblah force-pushed the users/tblah/omp-cancel-codegen-2 branch from 3a198c8 to 9a8ed32 Compare April 28, 2025 15:12
Base automatically changed from users/tblah/omp-cancel-codegen-1 to main April 29, 2025 16:19
Taskloop support will follow in a later patch.
@tblah
Copy link
Contributor Author

tblah commented May 6, 2025

Ping for review

@skatrak
Copy link
Member

skatrak commented May 6, 2025

Does this also work with composite omp.distribute + omp.wsloop loops? The new block of code in convertOmpWsloop() is ran for do/for and also for distribute parallel do/for. Just wondering if it should work (so adding a unit test for it would be a good idea) or if we should be triggering a TODO in that case.

@tblah
Copy link
Contributor Author

tblah commented May 7, 2025

Does this also work with composite omp.distribute + omp.wsloop loops?

Thanks for taking a look. Yes it does work correctly. I have added a test. See the branches from the cancel block.

Copy link
Member

@skatrak skatrak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, LGTM!

@tblah tblah merged commit a385c47 into main May 8, 2025
11 checks passed
@tblah tblah deleted the users/tblah/omp-cancel-codegen-2 branch May 8, 2025 10:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants