[CIR] Upstream support for FlattenCFG switch and SwitchFlatOp #139154

Open
wants to merge 4 commits into
base: main

Andres-Salamanca
Contributor

This PR adds support for the FlattenCFG transformation on switch statements. It also introduces the SwitchFlatOp, which is necessary for subsequent lowering to LLVM.
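
As a rough illustration of what a flat switch expresses, here is a minimal standalone C++ model. It assumes the op behaves like LLVM's `switch` instruction: control goes to the destination of the first matching case value, otherwise to the default destination. `FlatSwitchModel`, `selectDestination`, and `demoFlatSwitch` are hypothetical names for this sketch and are not part of the patch; blocks are represented by plain integer ids.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical model of a region-less switch: one default destination plus a
// list of (case value, destination block id) pairs.
struct FlatSwitchModel {
  int defaultDestination;
  std::vector<std::pair<int64_t, int>> cases;
};

// Pick the destination for a condition value, assuming LLVM-like semantics.
int selectDestination(const FlatSwitchModel &sw, int64_t condition) {
  for (const auto &c : sw.cases)
    if (c.first == condition)
      return c.second;
  return sw.defaultDestination;
}

// Models `cir.switch.flat %arg0 : !s32i, ^bb2 [ 1: ^bb1 ]` from the new
// switch-flat.cir test: value 1 goes to block 1, everything else to block 2.
void demoFlatSwitch() {
  FlatSwitchModel sw{2, {{1, 1}}};
  assert(selectDestination(sw, 1) == 1);
  assert(selectDestination(sw, 7) == 2);
}
```

The model ignores block operands, which the real op carries per destination.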

@llvmbot added the clang and ClangIR labels May 8, 2025
@llvmbot
Member

llvmbot commented May 8, 2025

@llvm/pr-subscribers-clangir

Author: None (Andres-Salamanca)

Changes

This PR adds support for the FlattenCFG transformation on switch statements. It also introduces the SwitchFlatOp, which is necessary for subsequent lowering to LLVM.


Patch is 28.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/139154.diff

6 Files Affected:

  • (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+46)
  • (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+97)
  • (modified) clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp (+14-1)
  • (modified) clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp (+231-4)
  • (added) clang/test/CIR/IR/switch-flat.cir (+68)
  • (added) clang/test/CIR/Transforms/switch.cir (+278)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 7ffa10464dcd3..914af6d1dc6bd 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -971,6 +971,52 @@ def SwitchOp : CIR_Op<"switch",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// SwitchFlatOp
+//===----------------------------------------------------------------------===//
+
+def SwitchFlatOp : CIR_Op<"switch.flat", [AttrSizedOperandSegments,
+                                          Terminator]> {
+
+  let description = [{
+    The `cir.switch.flat` operation is a region-less and simplified
+    version of the `cir.switch`.
+    Its representation is closer to the LLVM IR dialect
+    than to the C/C++ language feature.
+  }];
+
+  let arguments = (ins
+    CIR_IntType:$condition,
+    Variadic<AnyType>:$defaultOperands,
+    VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
+    ArrayAttr:$case_values,
+    DenseI32ArrayAttr:$case_operand_segments
+  );
+
+  let successors = (successor
+    AnySuccessor:$defaultDestination,
+    VariadicSuccessor<AnySuccessor>:$caseDestinations
+  );
+
+  let assemblyFormat = [{
+    $condition `:` type($condition) `,`
+    $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
+    custom<SwitchFlatOpCases>(ref(type($condition)), $case_values,
+                              $caseDestinations, $caseOperands,
+                              type($caseOperands))
+    attr-dict
+  }];
+
+  let builders = [
+    OpBuilder<(ins "mlir::Value":$condition,
+      "mlir::Block *":$defaultDestination,
+      "mlir::ValueRange":$defaultOperands,
+      CArg<"llvm::ArrayRef<llvm::APInt>", "{}">:$caseValues,
+      CArg<"mlir::BlockRange", "{}">:$caseDestinations,
+      CArg<"llvm::ArrayRef<mlir::ValueRange>", "{}">:$caseOperands)>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // BrOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index b131edaf403ed..ca03013edb485 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -22,6 +22,7 @@
 #include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
 #include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
 #include "clang/CIR/MissingFeatures.h"
+#include <numeric>
 
 using namespace mlir;
 using namespace cir;
@@ -962,6 +963,102 @@ bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
   });
 }
 
+//===----------------------------------------------------------------------===//
+// SwitchFlatOp
+//===----------------------------------------------------------------------===//
+
+void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
+                              Value value, Block *defaultDestination,
+                              ValueRange defaultOperands,
+                              ArrayRef<APInt> caseValues,
+                              BlockRange caseDestinations,
+                              ArrayRef<ValueRange> caseOperands) {
+
+  std::vector<mlir::Attribute> caseValuesAttrs;
+  for (auto &val : caseValues) {
+    caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
+  }
+  mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
+
+  build(builder, result, value, defaultOperands, caseOperands, attrs,
+        defaultDestination, caseDestinations);
+}
+
+/// <cases> ::= `[` (case (`,` case )* )? `]`
+/// <case>  ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
+static ParseResult parseSwitchFlatOpCases(
+    OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
+    SmallVectorImpl<Block *> &caseDestinations,
+    SmallVectorImpl<llvm::SmallVector<OpAsmParser::UnresolvedOperand>>
+        &caseOperands,
+    SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
+  if (failed(parser.parseLSquare()))
+    return failure();
+  if (succeeded(parser.parseOptionalRSquare()))
+    return success();
+  llvm::SmallVector<mlir::Attribute> values;
+
+  auto parseCase = [&]() {
+    int64_t value = 0;
+    if (failed(parser.parseInteger(value)))
+      return failure();
+
+    values.push_back(cir::IntAttr::get(flagType, value));
+
+    Block *destination;
+    llvm::SmallVector<OpAsmParser::UnresolvedOperand> operands;
+    llvm::SmallVector<Type> operandTypes;
+    if (parser.parseColon() || parser.parseSuccessor(destination))
+      return failure();
+    if (!parser.parseOptionalLParen()) {
+      if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
+                                  /*allowResultNumber=*/false) ||
+          parser.parseColonTypeList(operandTypes) || parser.parseRParen())
+        return failure();
+    }
+    caseDestinations.push_back(destination);
+    caseOperands.emplace_back(operands);
+    caseOperandTypes.emplace_back(operandTypes);
+    return success();
+  };
+  if (failed(parser.parseCommaSeparatedList(parseCase)))
+    return failure();
+
+  caseValues = ArrayAttr::get(flagType.getContext(), values);
+
+  return parser.parseRSquare();
+}
+
+static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
+                                   Type flagType, mlir::ArrayAttr caseValues,
+                                   SuccessorRange caseDestinations,
+                                   OperandRangeRange caseOperands,
+                                   const TypeRangeRange &caseOperandTypes) {
+  p << '[';
+  p.printNewline();
+  if (!caseValues) {
+    p << ']';
+    return;
+  }
+
+  size_t index = 0;
+  llvm::interleave(
+      llvm::zip(caseValues, caseDestinations),
+      [&](auto i) {
+        p << "  ";
+        mlir::Attribute a = std::get<0>(i);
+        p << mlir::cast<cir::IntAttr>(a).getValue();
+        p << ": ";
+        p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
+      },
+      [&] {
+        p << ',';
+        p.printNewline();
+      });
+  p.printNewline();
+  p << ']';
+}
+
 //===----------------------------------------------------------------------===//
 // GlobalOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index 3b4c7bc613133..edbb848322d41 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -84,6 +84,19 @@ struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> {
   }
 };
 
+struct RemoveEmptySwitch : public OpRewritePattern<SwitchOp> {
+  using OpRewritePattern<SwitchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SwitchOp op,
+                                PatternRewriter &rewriter) const final {
+    if (!(op.getBody().empty() || isa<YieldOp>(op.getBody().front().front())))
+      return failure();
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // CIRCanonicalizePass
 //===----------------------------------------------------------------------===//
@@ -127,7 +140,7 @@ void CIRCanonicalizePass::runOnOperation() {
     assert(!cir::MissingFeatures::callOp());
     // CastOp and UnaryOp are here to perform a manual `fold` in
     // applyOpPatternsGreedily.
-    if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp>(op))
+    if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp>(op))
       ops.push_back(op);
   });
 
diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
index 4a936d33b022a..70f383b556567 100644
--- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
@@ -171,6 +171,232 @@ class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> {
   }
 };
 
+class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
+public:
+  using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;
+
+  inline void rewriteYieldOp(mlir::PatternRewriter &rewriter,
+                             cir::YieldOp yieldOp,
+                             mlir::Block *destination) const {
+    rewriter.setInsertionPoint(yieldOp);
+    rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(),
+                                           destination);
+  }
+
+  // Return the new defaultDestination block.
+  Block *condBrToRangeDestination(cir::SwitchOp op,
+                                  mlir::PatternRewriter &rewriter,
+                                  mlir::Block *rangeDestination,
+                                  mlir::Block *defaultDestination,
+                                  const APInt &lowerBound,
+                                  const APInt &upperBound) const {
+    assert(lowerBound.sle(upperBound) && "Invalid range");
+    mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
+    cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true);
+    cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false);
+
+    cir::ConstantOp rangeLength = rewriter.create<cir::ConstantOp>(
+        op.getLoc(), cir::IntAttr::get(sIntType, upperBound - lowerBound));
+
+    cir::ConstantOp lowerBoundValue = rewriter.create<cir::ConstantOp>(
+        op.getLoc(), cir::IntAttr::get(sIntType, lowerBound));
+    cir::BinOp diffValue =
+        rewriter.create<cir::BinOp>(op.getLoc(), sIntType, cir::BinOpKind::Sub,
+                                    op.getCondition(), lowerBoundValue);
+
+    // Use unsigned comparison to check if the condition is in the range.
+    cir::CastOp uDiffValue = rewriter.create<cir::CastOp>(
+        op.getLoc(), uIntType, CastKind::integral, diffValue);
+    cir::CastOp uRangeLength = rewriter.create<cir::CastOp>(
+        op.getLoc(), uIntType, CastKind::integral, rangeLength);
+
+    cir::CmpOp cmpResult = rewriter.create<cir::CmpOp>(
+        op.getLoc(), cir::BoolType::get(op.getContext()), cir::CmpOpKind::le,
+        uDiffValue, uRangeLength);
+    rewriter.create<cir::BrCondOp>(op.getLoc(), cmpResult, rangeDestination,
+                                   defaultDestination);
+    return resBlock;
+  }
+
+  mlir::LogicalResult
+  matchAndRewrite(cir::SwitchOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    llvm::SmallVector<CaseOp> cases;
+    op.collectCases(cases);
+
+    // Empty switch statement: just erase it.
+    if (cases.empty()) {
+      rewriter.eraseOp(op);
+      return mlir::success();
+    }
+
+    // Create exit block from the next node of cir.switch op.
+    mlir::Block *exitBlock = rewriter.splitBlock(
+        rewriter.getBlock(), op->getNextNode()->getIterator());
+
+    // We lower cir.switch op in the following process:
+    // 1. Inline the region from the switch op after switch op.
+    // 2. Traverse each cir.case op:
+    //    a. Record each case's entry block, block arguments and condition.
+    //    b. Inline the case region after the case op.
+    // 3. Replace the emptied cir.switch op with a new cir.switch.flat op built
+    //    from the recorded blocks and conditions.
+
+    // Inline everything from the switch body between the switch op and the
+    // exit block.
+    {
+      cir::YieldOp switchYield = nullptr;
+      // Find the yield that terminates the switch body, if any.
+      for (auto &block : llvm::make_early_inc_range(op.getBody().getBlocks()))
+        if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
+          switchYield = yieldOp;
+
+      assert(!op.getBody().empty());
+      mlir::Block *originalBlock = op->getBlock();
+      mlir::Block *swopBlock =
+          rewriter.splitBlock(originalBlock, op->getIterator());
+      rewriter.inlineRegionBefore(op.getBody(), exitBlock);
+
+      if (switchYield)
+        rewriteYieldOp(rewriter, switchYield, exitBlock);
+
+      rewriter.setInsertionPointToEnd(originalBlock);
+      rewriter.create<cir::BrOp>(op.getLoc(), swopBlock);
+    }
+
+    // Allocate the required data structures (the default case is not stored
+    // in these vectors).
+    llvm::SmallVector<mlir::APInt, 8> caseValues;
+    llvm::SmallVector<mlir::Block *, 8> caseDestinations;
+    llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
+
+    llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
+    llvm::SmallVector<mlir::Block *> rangeDestinations;
+    llvm::SmallVector<mlir::ValueRange> rangeOperands;
+
+    // The default case is optional: without one, fall back to the exit block.
+    mlir::Block *defaultDestination = exitBlock;
+    mlir::ValueRange defaultOperands = exitBlock->getArguments();
+
+    // Digest the case statements values and bodies.
+    for (auto caseOp : cases) {
+      mlir::Region &region = caseOp.getCaseRegion();
+
+      // Found default case: save destination and operands.
+      switch (caseOp.getKind()) {
+      case cir::CaseOpKind::Default:
+        defaultDestination = &region.front();
+        defaultOperands = defaultDestination->getArguments();
+        break;
+      case cir::CaseOpKind::Range:
+        assert(caseOp.getValue().size() == 2 &&
+               "Case range should have 2 case values");
+        rangeValues.push_back(
+            {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(),
+             cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()});
+        rangeDestinations.push_back(&region.front());
+        rangeOperands.push_back(rangeDestinations.back()->getArguments());
+        break;
+      case cir::CaseOpKind::Anyof:
+      case cir::CaseOpKind::Equal:
+        // AnyOf cases can have multiple values, hence the loop below.
+        for (auto &value : caseOp.getValue()) {
+          caseValues.push_back(cast<cir::IntAttr>(value).getValue());
+          caseDestinations.push_back(&region.front());
+          caseOperands.push_back(caseDestinations.back()->getArguments());
+        }
+        break;
+      }
+
+      // Handle break statements.
+      walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
+          region, [&](mlir::Operation *op) {
+            if (!isa<cir::BreakOp>(op))
+              return mlir::WalkResult::advance();
+
+            lowerTerminator(op, exitBlock, rewriter);
+            return mlir::WalkResult::skip();
+          });
+
+      // Track fallthrough in cases.
+      for (auto &blk : region.getBlocks()) {
+        if (blk.getNumSuccessors())
+          continue;
+
+        if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) {
+          mlir::Operation *nextOp = caseOp->getNextNode();
+          assert(nextOp && "caseOp is not expected to be the last op");
+          mlir::Block *oldBlock = nextOp->getBlock();
+          mlir::Block *newBlock =
+              rewriter.splitBlock(oldBlock, nextOp->getIterator());
+          rewriter.setInsertionPointToEnd(oldBlock);
+          rewriter.create<cir::BrOp>(nextOp->getLoc(), mlir::ValueRange(),
+                                     newBlock);
+          rewriteYieldOp(rewriter, yieldOp, newBlock);
+        }
+      }
+
+      mlir::Block *oldBlock = caseOp->getBlock();
+      mlir::Block *newBlock =
+          rewriter.splitBlock(oldBlock, caseOp->getIterator());
+
+      mlir::Block &entryBlock = caseOp.getCaseRegion().front();
+      rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock);
+
+      // Create a branch to the entry of the inlined region.
+      rewriter.setInsertionPointToEnd(oldBlock);
+      rewriter.create<cir::BrOp>(caseOp.getLoc(), &entryBlock);
+    }
+
+    // Remove all cases since we've inlined the regions.
+    for (auto caseOp : cases) {
+      mlir::Block *caseBlock = caseOp->getBlock();
+      // Erase blocks with no predecessors here to make the generated code
+      // a little simpler.
+      if (caseBlock->hasNoPredecessors())
+        rewriter.eraseBlock(caseBlock);
+      else
+        rewriter.eraseOp(caseOp);
+    }
+
+    for (size_t index = 0; index < rangeValues.size(); ++index) {
+      APInt lowerBound = rangeValues[index].first;
+      APInt upperBound = rangeValues[index].second;
+
+      // The case range is unreachable, skip it.
+      if (lowerBound.sgt(upperBound))
+        continue;
+
+      // If the range is small, expand it into individual switch cases.
+      // This magic number comes from the original CGStmt code.
+      constexpr int kSmallRangeThreshold = 64;
+      if ((upperBound - lowerBound)
+              .ult(llvm::APInt(32, kSmallRangeThreshold))) {
+        for (APInt iValue = lowerBound; iValue.sle(upperBound);
+             (void)iValue++) {
+          caseValues.push_back(iValue);
+          caseOperands.push_back(rangeOperands[index]);
+          caseDestinations.push_back(rangeDestinations[index]);
+        }
+        continue;
+      }
+
+      defaultDestination =
+          condBrToRangeDestination(op, rewriter, rangeDestinations[index],
+                                   defaultDestination, lowerBound, upperBound);
+      defaultOperands = rangeOperands[index];
+    }
+
+    // Set switch op to branch to the newly created blocks.
+    rewriter.setInsertionPoint(op);
+    rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>(
+        op, op.getCondition(), defaultDestination, defaultOperands, caseValues,
+        caseDestinations, caseOperands);
+
+    return mlir::success();
+  }
+};
+
 class CIRLoopOpInterfaceFlattening
     : public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> {
 public:
@@ -306,9 +532,10 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
 };
 
 void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
-  patterns.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening,
-               CIRScopeOpFlattening, CIRTernaryOpFlattening>(
-      patterns.getContext());
+  patterns
+      .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
+           CIRSwitchOpFlattening, CIRTernaryOpFlattening>(
+          patterns.getContext());
 }
 
 void CIRFlattenCFGPass::runOnOperation() {
@@ -321,7 +548,7 @@ void CIRFlattenCFGPass::runOnOperation() {
     assert(!cir::MissingFeatures::ifOp());
     assert(!cir::MissingFeatures::switchOp());
     assert(!cir::MissingFeatures::tryOp());
-    if (isa<IfOp, ScopeOp, LoopOpInterface, TernaryOp>(op))
+    if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp>(op))
       ops.push_back(op);
   });
 
diff --git a/clang/test/CIR/IR/switch-flat.cir b/clang/test/CIR/IR/switch-flat.cir
new file mode 100644
index 0000000000000..b072c224b4a2c
--- /dev/null
+++ b/clang/test/CIR/IR/switch-flat.cir
@@ -0,0 +1,68 @@
+// RUN: cir-opt %s | FileCheck %s
+!s32i = !cir.int<s, 32>
+
+cir.func @FlatSwitchWithoutDefault(%arg0: !s32i) {
+  cir.switch.flat %arg0 : !s32i, ^bb2 [
+    1: ^bb1
+  ]
+  ^bb1:
+    cir.br ^bb2
+  ^bb2:
+    cir.return
+}
+
+// CHECK: cir.switch.flat %arg0 : !s32i, ^bb2 [
+// CHECK-NEXT:  1: ^bb1
+// CHECK-NEXT: ]
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT:   cir.br ^bb2
+// CHECK-NEXT: ^bb2:
+// CHECK-NEXT:   cir.return
+
+cir.func @FlatSwitchWithDefault(%arg0: !s32i) {
+  cir.switch.flat %arg0 : !s32i, ^bb2 [
+    1: ^bb1
+  ]
+  ^bb1:
+    cir.br ^bb3
+  ^bb2:
+    cir.br ^bb3
+  ^bb3:
+    cir.return
+}
+
+// CHECK: cir.switch.flat %arg0 : !s32i, ^bb2 [
+// CHECK-NEXT:  1: ^bb1
+// CHECK-NEXT: ]
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT:   cir.br ^bb3
+// CHECK-NEXT: ^bb2:
+// CHECK-NEXT:   cir.br ^bb3
+// CHECK-NEXT: ^bb3:
+// CHECK-NEXT:   cir.return
+
+cir.func @switchWithOperands(%arg0: !s32i, %arg1: !s32i, %arg2: !s32i) {
+  cir.switch.flat %arg0 : !s32i, ^bb3 [
+    0: ^bb1(%arg1, %arg2 : !s32i, !s32i),
+    1: ^bb2(%arg2, %arg1 : !s32i, !s32i)
+  ]
+^bb1:
+  cir.br ^bb3
+
+^bb2:
+  cir.br ^bb3
+
+^bb3:
+  cir.return
+}
+
+// CHECK: cir.switch.flat %arg0 : !s32i, ^bb3 [
+// CHECK-NEXT:  0: ^bb1(%arg1, %arg2 : !s32i, !s32i),
+// CHECK-NEXT:  1: ^bb2(%arg2, %arg1 : !s32i, !s32i)
+// CHECK-NEXT: ]
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT:   cir.br ^bb3
+// CHECK-NEXT: ^bb2:
+// CHECK-NEXT:    cir.br ^bb3
+// CHECK-NEXT: ^bb3:
+// CHECK-NEXT:    cir.return
diff --git a/clang/test/CIR/Transforms/switch.cir b/clang/test/CIR/Transforms/switch.cir
new file mode 100644
index 0000000000000..a05cf37e39728
--- /dev/null
+++ b/clang/te...
[truncated]
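
The range handling in `CIRSwitchOpFlattening` above can be sketched outside MLIR. The standalone C++ below is only an approximation under stated assumptions: it uses plain 32-bit integers where the patch uses `APInt` and CIR ops, and `inRangeUnsigned` and `expandSmallRange` are hypothetical helper names. It shows the two strategies visible in the diff: ranges whose length is below `kSmallRangeThreshold` (64) become individual case values, while larger ranges use a single unsigned comparison, `(x - lo) <= (hi - lo)`, as emitted by `condBrToRangeDestination`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Same value as the threshold in the diff; the type is chosen for this sketch.
constexpr uint32_t kSmallRangeThreshold = 64;

// One unsigned comparison tests lo <= x <= hi: if x < lo, the unsigned
// difference wraps around to a large value and the comparison fails.
bool inRangeUnsigned(int32_t x, int32_t lo, int32_t hi) {
  assert(lo <= hi && "Invalid range");
  // Subtract in unsigned arithmetic so that wraparound is well defined.
  uint32_t diff = static_cast<uint32_t>(x) - static_cast<uint32_t>(lo);
  uint32_t length = static_cast<uint32_t>(hi) - static_cast<uint32_t>(lo);
  return diff <= length;
}

// Returns the individual case values for a small range. Returns an empty
// vector when the range is unreachable (lo > hi) or too large, in which case
// the compare-and-branch form above would be used instead.
std::vector<int32_t> expandSmallRange(int32_t lo, int32_t hi) {
  std::vector<int32_t> values;
  if (lo > hi)
    return values;
  uint32_t length = static_cast<uint32_t>(hi) - static_cast<uint32_t>(lo);
  if (length >= kSmallRangeThreshold)
    return values;
  for (uint32_t i = 0; i <= length; ++i)
    values.push_back(static_cast<int32_t>(static_cast<uint32_t>(lo) + i));
  return values;
}
```

For example, a range `1...3` expands to three case values, whereas `0...64` has length 64 and would take the compare-and-branch path.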

@llvmbot
Member

llvmbot commented May 8, 2025

@llvm/pr-subscribers-clang

Author: None (Andres-Salamanca)

Changes

This PR adds support for the FlattenCFG transformation on switch statements. It also introduces the SwitchFlatOp, which is necessary for subsequent lowering to LLVM.


Patch is 28.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/139154.diff

6 Files Affected:

  • (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+46)
  • (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+97)
  • (modified) clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp (+14-1)
  • (modified) clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp (+231-4)
  • (added) clang/test/CIR/IR/switch-flat.cir (+68)
  • (added) clang/test/CIR/Transforms/switch.cir (+278)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 7ffa10464dcd3..914af6d1dc6bd 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -971,6 +971,52 @@ def SwitchOp : CIR_Op<"switch",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// SwitchFlatOp
+//===----------------------------------------------------------------------===//
+
+def SwitchFlatOp : CIR_Op<"switch.flat", [AttrSizedOperandSegments,
+                                          Terminator]> {
+
+  let description = [{
+    The `cir.switch.flat` operation is a region-less and simplified
+    version of the `cir.switch`.
+    Its representation is closer to the LLVM IR dialect
+    than to the C/C++ language feature.
+  }];
+
+  let arguments = (ins
+    CIR_IntType:$condition,
+    Variadic<AnyType>:$defaultOperands,
+    VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
+    ArrayAttr:$case_values,
+    DenseI32ArrayAttr:$case_operand_segments
+  );
+
+  let successors = (successor
+    AnySuccessor:$defaultDestination,
+    VariadicSuccessor<AnySuccessor>:$caseDestinations
+  );
+
+  let assemblyFormat = [{
+    $condition `:` type($condition) `,`
+    $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
+    custom<SwitchFlatOpCases>(ref(type($condition)), $case_values,
+                              $caseDestinations, $caseOperands,
+                              type($caseOperands))
+    attr-dict
+  }];
+
+  let builders = [
+    OpBuilder<(ins "mlir::Value":$condition,
+      "mlir::Block *":$defaultDestination,
+      "mlir::ValueRange":$defaultOperands,
+      CArg<"llvm::ArrayRef<llvm::APInt>", "{}">:$caseValues,
+      CArg<"mlir::BlockRange", "{}">:$caseDestinations,
+      CArg<"llvm::ArrayRef<mlir::ValueRange>", "{}">:$caseOperands)>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // BrOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index b131edaf403ed..ca03013edb485 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -22,6 +22,7 @@
 #include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
 #include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
 #include "clang/CIR/MissingFeatures.h"
+#include <numeric>
 
 using namespace mlir;
 using namespace cir;
@@ -962,6 +963,102 @@ bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
   });
 }
 
+//===----------------------------------------------------------------------===//
+// SwitchFlatOp
+//===----------------------------------------------------------------------===//
+
+void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
+                              Value value, Block *defaultDestination,
+                              ValueRange defaultOperands,
+                              ArrayRef<APInt> caseValues,
+                              BlockRange caseDestinations,
+                              ArrayRef<ValueRange> caseOperands) {
+
+  std::vector<mlir::Attribute> caseValuesAttrs;
+  for (auto &val : caseValues) {
+    caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
+  }
+  mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
+
+  build(builder, result, value, defaultOperands, caseOperands, attrs,
+        defaultDestination, caseDestinations);
+}
+
+/// <cases> ::= `[` (case (`,` case )* )? `]`
+/// <case>  ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
+static ParseResult parseSwitchFlatOpCases(
+    OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
+    SmallVectorImpl<Block *> &caseDestinations,
+    SmallVectorImpl<llvm::SmallVector<OpAsmParser::UnresolvedOperand>>
+        &caseOperands,
+    SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
+  if (failed(parser.parseLSquare()))
+    return failure();
+  if (succeeded(parser.parseOptionalRSquare()))
+    return success();
+  llvm::SmallVector<mlir::Attribute> values;
+
+  auto parseCase = [&]() {
+    int64_t value = 0;
+    if (failed(parser.parseInteger(value)))
+      return failure();
+
+    values.push_back(cir::IntAttr::get(flagType, value));
+
+    Block *destination;
+    llvm::SmallVector<OpAsmParser::UnresolvedOperand> operands;
+    llvm::SmallVector<Type> operandTypes;
+    if (parser.parseColon() || parser.parseSuccessor(destination))
+      return failure();
+    if (!parser.parseOptionalLParen()) {
+      if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
+                                  /*allowResultNumber=*/false) ||
+          parser.parseColonTypeList(operandTypes) || parser.parseRParen())
+        return failure();
+    }
+    caseDestinations.push_back(destination);
+    caseOperands.emplace_back(operands);
+    caseOperandTypes.emplace_back(operandTypes);
+    return success();
+  };
+  if (failed(parser.parseCommaSeparatedList(parseCase)))
+    return failure();
+
+  caseValues = ArrayAttr::get(flagType.getContext(), values);
+
+  return parser.parseRSquare();
+}
+
+static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
+                                   Type flagType, mlir::ArrayAttr caseValues,
+                                   SuccessorRange caseDestinations,
+                                   OperandRangeRange caseOperands,
+                                   const TypeRangeRange &caseOperandTypes) {
+  p << '[';
+  p.printNewline();
+  if (!caseValues) {
+    p << ']';
+    return;
+  }
+
+  size_t index = 0;
+  llvm::interleave(
+      llvm::zip(caseValues, caseDestinations),
+      [&](auto i) {
+        p << "  ";
+        mlir::Attribute a = std::get<0>(i);
+        p << mlir::cast<cir::IntAttr>(a).getValue();
+        p << ": ";
+        p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
+      },
+      [&] {
+        p << ',';
+        p.printNewline();
+      });
+  p.printNewline();
+  p << ']';
+}
+
 //===----------------------------------------------------------------------===//
 // GlobalOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index 3b4c7bc613133..edbb848322d41 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -84,6 +84,19 @@ struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> {
   }
 };
 
+struct RemoveEmptySwitch : public OpRewritePattern<SwitchOp> {
+  using OpRewritePattern<SwitchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SwitchOp op,
+                                PatternRewriter &rewriter) const final {
+    if (!(op.getBody().empty() || isa<YieldOp>(op.getBody().front().front())))
+      return failure();
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // CIRCanonicalizePass
 //===----------------------------------------------------------------------===//
@@ -127,7 +140,7 @@ void CIRCanonicalizePass::runOnOperation() {
     assert(!cir::MissingFeatures::callOp());
     // CastOp and UnaryOp are here to perform a manual `fold` in
     // applyOpPatternsGreedily.
-    if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp>(op))
+    if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp>(op))
       ops.push_back(op);
   });
 
diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
index 4a936d33b022a..70f383b556567 100644
--- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
@@ -171,6 +171,232 @@ class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> {
   }
 };
 
+class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
+public:
+  using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;
+
+  inline void rewriteYieldOp(mlir::PatternRewriter &rewriter,
+                             cir::YieldOp yieldOp,
+                             mlir::Block *destination) const {
+    rewriter.setInsertionPoint(yieldOp);
+    rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(),
+                                           destination);
+  }
+
+  // Return the new defaultDestination block.
+  Block *condBrToRangeDestination(cir::SwitchOp op,
+                                  mlir::PatternRewriter &rewriter,
+                                  mlir::Block *rangeDestination,
+                                  mlir::Block *defaultDestination,
+                                  const APInt &lowerBound,
+                                  const APInt &upperBound) const {
+    assert(lowerBound.sle(upperBound) && "Invalid range");
+    mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
+    cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true);
+    cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false);
+
+    cir::ConstantOp rangeLength = rewriter.create<cir::ConstantOp>(
+        op.getLoc(), cir::IntAttr::get(sIntType, upperBound - lowerBound));
+
+    cir::ConstantOp lowerBoundValue = rewriter.create<cir::ConstantOp>(
+        op.getLoc(), cir::IntAttr::get(sIntType, lowerBound));
+    cir::BinOp diffValue =
+        rewriter.create<cir::BinOp>(op.getLoc(), sIntType, cir::BinOpKind::Sub,
+                                    op.getCondition(), lowerBoundValue);
+
+    // Use unsigned comparison to check if the condition is in the range.
+    cir::CastOp uDiffValue = rewriter.create<cir::CastOp>(
+        op.getLoc(), uIntType, CastKind::integral, diffValue);
+    cir::CastOp uRangeLength = rewriter.create<cir::CastOp>(
+        op.getLoc(), uIntType, CastKind::integral, rangeLength);
+
+    cir::CmpOp cmpResult = rewriter.create<cir::CmpOp>(
+        op.getLoc(), cir::BoolType::get(op.getContext()), cir::CmpOpKind::le,
+        uDiffValue, uRangeLength);
+    rewriter.create<cir::BrCondOp>(op.getLoc(), cmpResult, rangeDestination,
+                                   defaultDestination);
+    return resBlock;
+  }
+
+  mlir::LogicalResult
+  matchAndRewrite(cir::SwitchOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    llvm::SmallVector<CaseOp> cases;
+    op.collectCases(cases);
+
+    // Empty switch statement: just erase it.
+    if (cases.empty()) {
+      rewriter.eraseOp(op);
+      return mlir::success();
+    }
+
+    // Create exit block from the next node of cir.switch op.
+    mlir::Block *exitBlock = rewriter.splitBlock(
+        rewriter.getBlock(), op->getNextNode()->getIterator());
+
+    // We lower cir.switch op in the following process:
+    // 1. Inline the region from the switch op after switch op.
+    // 2. Traverse each cir.case op:
+    //    a. Record the entry block, block arguments and condition for every
+    //    case. b. Inline the case region after the case op.
+    // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the
+    //    recorded block and conditions.
+
+    // inline everything from switch body between the switch op and the exit
+    // block.
+    {
+      cir::YieldOp switchYield = nullptr;
+      // Clear switch operation.
+      for (auto &block : llvm::make_early_inc_range(op.getBody().getBlocks()))
+        if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
+          switchYield = yieldOp;
+
+      assert(!op.getBody().empty());
+      mlir::Block *originalBlock = op->getBlock();
+      mlir::Block *swopBlock =
+          rewriter.splitBlock(originalBlock, op->getIterator());
+      rewriter.inlineRegionBefore(op.getBody(), exitBlock);
+
+      if (switchYield)
+        rewriteYieldOp(rewriter, switchYield, exitBlock);
+
+      rewriter.setInsertionPointToEnd(originalBlock);
+      rewriter.create<cir::BrOp>(op.getLoc(), swopBlock);
+    }
+
+    // Allocate required data structures (disconsider default case in
+    // vectors).
+    llvm::SmallVector<mlir::APInt, 8> caseValues;
+    llvm::SmallVector<mlir::Block *, 8> caseDestinations;
+    llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
+
+    llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
+    llvm::SmallVector<mlir::Block *> rangeDestinations;
+    llvm::SmallVector<mlir::ValueRange> rangeOperands;
+
+    // Initialize default case as optional.
+    mlir::Block *defaultDestination = exitBlock;
+    mlir::ValueRange defaultOperands = exitBlock->getArguments();
+
+    // Digest the case statements values and bodies.
+    for (auto caseOp : cases) {
+      mlir::Region &region = caseOp.getCaseRegion();
+
+      // Found default case: save destination and operands.
+      switch (caseOp.getKind()) {
+      case cir::CaseOpKind::Default:
+        defaultDestination = &region.front();
+        defaultOperands = defaultDestination->getArguments();
+        break;
+      case cir::CaseOpKind::Range:
+        assert(caseOp.getValue().size() == 2 &&
+               "Case range should have 2 case value");
+        rangeValues.push_back(
+            {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(),
+             cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()});
+        rangeDestinations.push_back(&region.front());
+        rangeOperands.push_back(rangeDestinations.back()->getArguments());
+        break;
+      case cir::CaseOpKind::Anyof:
+      case cir::CaseOpKind::Equal:
+        // AnyOf cases kind can have multiple values, hence the loop below.
+        for (auto &value : caseOp.getValue()) {
+          caseValues.push_back(cast<cir::IntAttr>(value).getValue());
+          caseDestinations.push_back(&region.front());
+          caseOperands.push_back(caseDestinations.back()->getArguments());
+        }
+        break;
+      }
+
+      // Handle break statements.
+      walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
+          region, [&](mlir::Operation *op) {
+            if (!isa<cir::BreakOp>(op))
+              return mlir::WalkResult::advance();
+
+            lowerTerminator(op, exitBlock, rewriter);
+            return mlir::WalkResult::skip();
+          });
+
+      // Track fallthrough in cases.
+      for (auto &blk : region.getBlocks()) {
+        if (blk.getNumSuccessors())
+          continue;
+
+        if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) {
+          mlir::Operation *nextOp = caseOp->getNextNode();
+          assert(nextOp && "caseOp is not expected to be the last op");
+          mlir::Block *oldBlock = nextOp->getBlock();
+          mlir::Block *newBlock =
+              rewriter.splitBlock(oldBlock, nextOp->getIterator());
+          rewriter.setInsertionPointToEnd(oldBlock);
+          rewriter.create<cir::BrOp>(nextOp->getLoc(), mlir::ValueRange(),
+                                     newBlock);
+          rewriteYieldOp(rewriter, yieldOp, newBlock);
+        }
+      }
+
+      mlir::Block *oldBlock = caseOp->getBlock();
+      mlir::Block *newBlock =
+          rewriter.splitBlock(oldBlock, caseOp->getIterator());
+
+      mlir::Block &entryBlock = caseOp.getCaseRegion().front();
+      rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock);
+
+      // Create a branch to the entry of the inlined region.
+      rewriter.setInsertionPointToEnd(oldBlock);
+      rewriter.create<cir::BrOp>(caseOp.getLoc(), &entryBlock);
+    }
+
+    // Remove all cases since we've inlined the regions.
+    for (auto caseOp : cases) {
+      mlir::Block *caseBlock = caseOp->getBlock();
+      // Erase the block with no predecessors here to make the generated code
+      // simpler a little bit.
+      if (caseBlock->hasNoPredecessors())
+        rewriter.eraseBlock(caseBlock);
+      else
+        rewriter.eraseOp(caseOp);
+    }
+
+    for (size_t index = 0; index < rangeValues.size(); ++index) {
+      APInt lowerBound = rangeValues[index].first;
+      APInt upperBound = rangeValues[index].second;
+
+      // The case range is unreachable, skip it.
+      if (lowerBound.sgt(upperBound))
+        continue;
+
+      // If range is small, add multiple switch instruction cases.
+      // This magical number is from the original CGStmt code.
+      constexpr int kSmallRangeThreshold = 64;
+      if ((upperBound - lowerBound)
+              .ult(llvm::APInt(32, kSmallRangeThreshold))) {
+        for (APInt iValue = lowerBound; iValue.sle(upperBound);
+             (void)iValue++) {
+          caseValues.push_back(iValue);
+          caseOperands.push_back(rangeOperands[index]);
+          caseDestinations.push_back(rangeDestinations[index]);
+        }
+        continue;
+      }
+
+      defaultDestination =
+          condBrToRangeDestination(op, rewriter, rangeDestinations[index],
+                                   defaultDestination, lowerBound, upperBound);
+      defaultOperands = rangeOperands[index];
+    }
+
+    // Set switch op to branch to the newly created blocks.
+    rewriter.setInsertionPoint(op);
+    rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>(
+        op, op.getCondition(), defaultDestination, defaultOperands, caseValues,
+        caseDestinations, caseOperands);
+
+    return mlir::success();
+  }
+};
+
 class CIRLoopOpInterfaceFlattening
     : public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> {
 public:
@@ -306,9 +532,10 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
 };
 
 void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
-  patterns.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening,
-               CIRScopeOpFlattening, CIRTernaryOpFlattening>(
-      patterns.getContext());
+  patterns
+      .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
+           CIRSwitchOpFlattening, CIRTernaryOpFlattening>(
+          patterns.getContext());
 }
 
 void CIRFlattenCFGPass::runOnOperation() {
@@ -321,7 +548,7 @@ void CIRFlattenCFGPass::runOnOperation() {
     assert(!cir::MissingFeatures::ifOp());
     assert(!cir::MissingFeatures::switchOp());
     assert(!cir::MissingFeatures::tryOp());
-    if (isa<IfOp, ScopeOp, LoopOpInterface, TernaryOp>(op))
+    if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp>(op))
       ops.push_back(op);
   });
 
diff --git a/clang/test/CIR/IR/switch-flat.cir b/clang/test/CIR/IR/switch-flat.cir
new file mode 100644
index 0000000000000..b072c224b4a2c
--- /dev/null
+++ b/clang/test/CIR/IR/switch-flat.cir
@@ -0,0 +1,68 @@
+// RUN: cir-opt %s | FileCheck %s
+!s32i = !cir.int<s, 32>
+
+cir.func @FlatSwitchWithoutDefault(%arg0: !s32i) {
+  cir.switch.flat %arg0 : !s32i, ^bb2 [
+    1: ^bb1
+  ]
+  ^bb1:
+    cir.br ^bb2
+  ^bb2:
+    cir.return
+}
+
+// CHECK: cir.switch.flat %arg0 : !s32i, ^bb2 [
+// CHECK-NEXT:  1: ^bb1
+// CHECK-NEXT: ]
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT:   cir.br ^bb2
+// CHECK-NEXT: ^bb2:
+//CHECK-NEXT:    cir.return
+
+cir.func @FlatSwitchWithDefault(%arg0: !s32i) {
+  cir.switch.flat %arg0 : !s32i, ^bb2 [
+    1: ^bb1
+  ]
+  ^bb1:
+    cir.br ^bb3
+  ^bb2:
+    cir.br ^bb3
+  ^bb3:
+    cir.return
+}
+
+// CHECK: cir.switch.flat %arg0 : !s32i, ^bb2 [
+// CHECK-NEXT:  1: ^bb1
+// CHECK-NEXT: ]
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT:   cir.br ^bb3
+// CHECK-NEXT: ^bb2:
+// CHECK-NEXT:   cir.br ^bb3
+// CHECK-NEXT: ^bb3:
+// CHECK-NEXT:   cir.return
+
+cir.func @switchWithOperands(%arg0: !s32i, %arg1: !s32i, %arg2: !s32i) {
+  cir.switch.flat %arg0 : !s32i, ^bb3 [
+    0: ^bb1(%arg1, %arg2 : !s32i, !s32i),
+    1: ^bb2(%arg2, %arg1 : !s32i, !s32i)
+  ]
+^bb1:
+  cir.br ^bb3
+
+^bb2:
+  cir.br ^bb3
+
+^bb3:
+  cir.return
+}
+
+// CHECK: cir.switch.flat %arg0 : !s32i, ^bb3 [
+// CHECK-NEXT:  0: ^bb1(%arg1, %arg2 : !s32i, !s32i),
+// CHECK-NEXT:  1: ^bb2(%arg2, %arg1 : !s32i, !s32i)
+// CHECK-NEXT: ]
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT:   cir.br ^bb3
+// CHECK-NEXT: ^bb2:
+// CHECK-NEXT:    cir.br ^bb3
+// CHECK-NEXT: ^bb3:
+// CHECK-NEXT:    cir.return
diff --git a/clang/test/CIR/Transforms/switch.cir b/clang/test/CIR/Transforms/switch.cir
new file mode 100644
index 0000000000000..a05cf37e39728
--- /dev/null
+++ b/clang/te...
[truncated]

@Andres-Salamanca
Contributor Author

@andykaylor

andykaylor requested review from mmha and andykaylor May 8, 2025 21:59
let description = [{
The `cir.switch.flat` operation is a region-less and simplified
version of the `cir.switch`.
It's representation is closer to LLVM IR dialect
Contributor

Suggested change
- It's representation is closer to LLVM IR dialect
+ Its representation is closer to LLVM IR dialect

Also, the formatting is odd here. Can you clean up the word wrap locations?

CIR_IntType:$condition,
Variadic<AnyType>:$defaultOperands,
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
ArrayAttr:$case_values,
Contributor

There's an inconsistency between case_values here and on line 1004, and caseValues on line 1014. I see that the LLVM dialect's LLVM_SwitchOp has the same inconsistency, but I don't see any other place that it's referenced directly.

Contributor Author

I'm going to leave it as caseValues

ArrayRef<ValueRange> caseOperands) {

std::vector<mlir::Attribute> caseValuesAttrs;
for (auto &val : caseValues) {
Contributor

Suggested change
- for (auto &val : caseValues) {
+ for (APInt &val : caseValues) {

Also, braces aren't needed here.

{
cir::YieldOp switchYield = nullptr;
// Clear switch operation.
for (auto &block : llvm::make_early_inc_range(op.getBody().getBlocks()))
Contributor

Suggested change
- for (auto &block : llvm::make_early_inc_range(op.getBody().getBlocks()))
+ for (mlir::Block &block : llvm::make_early_inc_range(op.getBody().getBlocks()))

mlir::ValueRange defaultOperands = exitBlock->getArguments();

// Digest the case statements values and bodies.
for (auto caseOp : cases) {
Contributor

Suggested change
- for (auto caseOp : cases) {
+ for (cir::CaseOp caseOp : cases) {

case cir::CaseOpKind::Anyof:
case cir::CaseOpKind::Equal:
// AnyOf cases kind can have multiple values, hence the loop below.
for (auto &value : caseOp.getValue()) {
Contributor

Don't use auto here.

}

// Remove all cases since we've inlined the regions.
for (auto caseOp : cases) {
Contributor

Don't use auto here.

rewriter.eraseOp(caseOp);
}

for (size_t index = 0; index < rangeValues.size(); ++index) {
Contributor

Can you use llvm::zip to make this a range-for loop?

Contributor Author

I used llvm::enumerate instead because we need the index in the loop, which llvm::zip doesn't provide directly.

// If range is small, add multiple switch instruction cases.
// This magical number is from the original CGStmt code.
constexpr int kSmallRangeThreshold = 64;
if ((upperBound - lowerBound)
Contributor

Can you add a test case for a range that is within this threshold?
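
For reference, the range handling under discussion can be sketched outside MLIR. This is a minimal standalone C++ sketch, not the patch's code: `RangeLowering`, `inRangeUnsigned`, and `lowerRange` are hypothetical names, and plain 32-bit integers stand in for `APInt`. Only the threshold value 64, the skip of ranges whose lower bound exceeds the upper bound, and the unsigned-comparison idea are taken from the diff above.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Same value as kSmallRangeThreshold in the patch.
constexpr std::uint32_t kSmallRangeThreshold = 64;

// Hypothetical result type: how one `case lo ... hi` range was handled.
struct RangeLowering {
  bool expanded = false;                 // true: emitted as individual cases
  std::vector<std::int32_t> caseValues;  // filled only when expanded
};

// Tests lo <= value <= hi with one unsigned comparison.
// The subtraction is done on unsigned values, so wraparound is well defined:
// anything below `lo` wraps to a huge number and fails the comparison.
bool inRangeUnsigned(std::int32_t value, std::int32_t lo, std::int32_t hi) {
  assert(lo <= hi && "Invalid range");
  std::uint32_t diff =
      static_cast<std::uint32_t>(value) - static_cast<std::uint32_t>(lo);
  std::uint32_t length =
      static_cast<std::uint32_t>(hi) - static_cast<std::uint32_t>(lo);
  return diff <= length;
}

// Decides between expanding a range into explicit case values and leaving it
// to a range check (here: inRangeUnsigned) in front of the default block.
RangeLowering lowerRange(std::int32_t lo, std::int32_t hi) {
  RangeLowering result;
  if (lo > hi)
    return result;  // unreachable range: nothing is emitted
  std::uint32_t length =
      static_cast<std::uint32_t>(hi) - static_cast<std::uint32_t>(lo);
  if (length < kSmallRangeThreshold) {
    result.expanded = true;
    // Iterate in 64 bits so hi == INT32_MAX cannot overflow the counter.
    for (std::int64_t v = lo; v <= hi; ++v)
      result.caseValues.push_back(static_cast<std::int32_t>(v));
  }
  return result;
}
```

Under these assumptions, a range such as `1 ... 5` becomes five individual case values, while `0 ... 64` (length 64, which is not below the threshold) is left to the single unsigned comparison. A test for the small-range path would exercise the first branch of `lowerRange`.
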

Member

bcardosolopes left a comment

This is also missing an end-to-end test that proves we generate SwitchFlatOp when coming all the way from source code. There are examples in the incubator showing how to check that a pass runs (you can use print-before/after) and how to check the IR for it.

// CHECK-NEXT: ^bb1:
// CHECK-NEXT: cir.br ^bb2
// CHECK-NEXT: ^bb2:
//CHECK-NEXT: cir.return
Member

Odd indentation compared with other check lines

});

// Track fallthrough in cases.
for (auto &blk : region.getBlocks()) {
Member

This should be Block &blk

@Andres-Salamanca
Contributor Author

This is also missing an end-to-end test that proves we generate SwitchFlatOp when coming all the way from source code. There are examples in the incubator showing how to check that a pass runs (you can use print-before/after) and how to check the IR for it.

In the incubator, the examples use %clang_cc1 -mmlir --mlir-print-ir-before=cir-lowering-prepare, but since that flag isn't available in our current setup, I followed this pipeline instead:

clang_cc1 (emit-cir) → cir-opt (cir-flatten-cfg with --mlir-print-ir-before/after)
