-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[CIR] Upstream support for FlattenCFG switch and SwitchFlatOp #139154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-clangir Author: None (Andres-Salamanca) ChangesThis PR adds support for the Patch is 28.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/139154.diff 6 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 7ffa10464dcd3..914af6d1dc6bd 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -971,6 +971,52 @@ def SwitchOp : CIR_Op<"switch",
}];
}
+//===----------------------------------------------------------------------===//
+// SwitchFlatOp
+//===----------------------------------------------------------------------===//
+
+def SwitchFlatOp : CIR_Op<"switch.flat", [AttrSizedOperandSegments,
+ Terminator]> {
+
+ let description = [{
+ The `cir.switch.flat` operation is a region-less and simplified
+ version of the `cir.switch`.
+ It's representation is closer to LLVM IR dialect
+ than the C/C++ language feature.
+ }];
+
+ let arguments = (ins
+ CIR_IntType:$condition,
+ Variadic<AnyType>:$defaultOperands,
+ VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
+ ArrayAttr:$case_values,
+ DenseI32ArrayAttr:$case_operand_segments
+ );
+
+ let successors = (successor
+ AnySuccessor:$defaultDestination,
+ VariadicSuccessor<AnySuccessor>:$caseDestinations
+ );
+
+ let assemblyFormat = [{
+ $condition `:` type($condition) `,`
+ $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
+ custom<SwitchFlatOpCases>(ref(type($condition)), $case_values,
+ $caseDestinations, $caseOperands,
+ type($caseOperands))
+ attr-dict
+ }];
+
+ let builders = [
+ OpBuilder<(ins "mlir::Value":$condition,
+ "mlir::Block *":$defaultDestination,
+ "mlir::ValueRange":$defaultOperands,
+ CArg<"llvm::ArrayRef<llvm::APInt>", "{}">:$caseValues,
+ CArg<"mlir::BlockRange", "{}">:$caseDestinations,
+ CArg<"llvm::ArrayRef<mlir::ValueRange>", "{}">:$caseOperands)>
+ ];
+}
+
//===----------------------------------------------------------------------===//
// BrOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index b131edaf403ed..ca03013edb485 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -22,6 +22,7 @@
#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
#include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
#include "clang/CIR/MissingFeatures.h"
+#include <numeric>
using namespace mlir;
using namespace cir;
@@ -962,6 +963,102 @@ bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
});
}
+//===----------------------------------------------------------------------===//
+// SwitchFlatOp
+//===----------------------------------------------------------------------===//
+
+void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
+ Value value, Block *defaultDestination,
+ ValueRange defaultOperands,
+ ArrayRef<APInt> caseValues,
+ BlockRange caseDestinations,
+ ArrayRef<ValueRange> caseOperands) {
+
+ std::vector<mlir::Attribute> caseValuesAttrs;
+ for (auto &val : caseValues) {
+ caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
+ }
+ mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
+
+ build(builder, result, value, defaultOperands, caseOperands, attrs,
+ defaultDestination, caseDestinations);
+}
+
+/// <cases> ::= `[` (case (`,` case )* )? `]`
+/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
+static ParseResult parseSwitchFlatOpCases(
+ OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
+ SmallVectorImpl<Block *> &caseDestinations,
+ SmallVectorImpl<llvm::SmallVector<OpAsmParser::UnresolvedOperand>>
+ &caseOperands,
+ SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
+ if (failed(parser.parseLSquare()))
+ return failure();
+ if (succeeded(parser.parseOptionalRSquare()))
+ return success();
+ llvm::SmallVector<mlir::Attribute> values;
+
+ auto parseCase = [&]() {
+ int64_t value = 0;
+ if (failed(parser.parseInteger(value)))
+ return failure();
+
+ values.push_back(cir::IntAttr::get(flagType, value));
+
+ Block *destination;
+ llvm::SmallVector<OpAsmParser::UnresolvedOperand> operands;
+ llvm::SmallVector<Type> operandTypes;
+ if (parser.parseColon() || parser.parseSuccessor(destination))
+ return failure();
+ if (!parser.parseOptionalLParen()) {
+ if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
+ /*allowResultNumber=*/false) ||
+ parser.parseColonTypeList(operandTypes) || parser.parseRParen())
+ return failure();
+ }
+ caseDestinations.push_back(destination);
+ caseOperands.emplace_back(operands);
+ caseOperandTypes.emplace_back(operandTypes);
+ return success();
+ };
+ if (failed(parser.parseCommaSeparatedList(parseCase)))
+ return failure();
+
+ caseValues = ArrayAttr::get(flagType.getContext(), values);
+
+ return parser.parseRSquare();
+}
+
+static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
+ Type flagType, mlir::ArrayAttr caseValues,
+ SuccessorRange caseDestinations,
+ OperandRangeRange caseOperands,
+ const TypeRangeRange &caseOperandTypes) {
+ p << '[';
+ p.printNewline();
+ if (!caseValues) {
+ p << ']';
+ return;
+ }
+
+ size_t index = 0;
+ llvm::interleave(
+ llvm::zip(caseValues, caseDestinations),
+ [&](auto i) {
+ p << " ";
+ mlir::Attribute a = std::get<0>(i);
+ p << mlir::cast<cir::IntAttr>(a).getValue();
+ p << ": ";
+ p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
+ },
+ [&] {
+ p << ',';
+ p.printNewline();
+ });
+ p.printNewline();
+ p << ']';
+}
+
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index 3b4c7bc613133..edbb848322d41 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -84,6 +84,19 @@ struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> {
}
};
+struct RemoveEmptySwitch : public OpRewritePattern<SwitchOp> {
+ using OpRewritePattern<SwitchOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(SwitchOp op,
+ PatternRewriter &rewriter) const final {
+ if (!(op.getBody().empty() || isa<YieldOp>(op.getBody().front().front())))
+ return failure();
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// CIRCanonicalizePass
//===----------------------------------------------------------------------===//
@@ -127,7 +140,7 @@ void CIRCanonicalizePass::runOnOperation() {
assert(!cir::MissingFeatures::callOp());
// CastOp and UnaryOp are here to perform a manual `fold` in
// applyOpPatternsGreedily.
- if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp>(op))
+ if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp>(op))
ops.push_back(op);
});
diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
index 4a936d33b022a..70f383b556567 100644
--- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
@@ -171,6 +171,232 @@ class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> {
}
};
+class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
+public:
+ using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;
+
+ inline void rewriteYieldOp(mlir::PatternRewriter &rewriter,
+ cir::YieldOp yieldOp,
+ mlir::Block *destination) const {
+ rewriter.setInsertionPoint(yieldOp);
+ rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(),
+ destination);
+ }
+
+ // Return the new defaultDestination block.
+ Block *condBrToRangeDestination(cir::SwitchOp op,
+ mlir::PatternRewriter &rewriter,
+ mlir::Block *rangeDestination,
+ mlir::Block *defaultDestination,
+ const APInt &lowerBound,
+ const APInt &upperBound) const {
+ assert(lowerBound.sle(upperBound) && "Invalid range");
+ mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
+ cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true);
+ cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false);
+
+ cir::ConstantOp rangeLength = rewriter.create<cir::ConstantOp>(
+ op.getLoc(), cir::IntAttr::get(sIntType, upperBound - lowerBound));
+
+ cir::ConstantOp lowerBoundValue = rewriter.create<cir::ConstantOp>(
+ op.getLoc(), cir::IntAttr::get(sIntType, lowerBound));
+ cir::BinOp diffValue =
+ rewriter.create<cir::BinOp>(op.getLoc(), sIntType, cir::BinOpKind::Sub,
+ op.getCondition(), lowerBoundValue);
+
+ // Use unsigned comparison to check if the condition is in the range.
+ cir::CastOp uDiffValue = rewriter.create<cir::CastOp>(
+ op.getLoc(), uIntType, CastKind::integral, diffValue);
+ cir::CastOp uRangeLength = rewriter.create<cir::CastOp>(
+ op.getLoc(), uIntType, CastKind::integral, rangeLength);
+
+ cir::CmpOp cmpResult = rewriter.create<cir::CmpOp>(
+ op.getLoc(), cir::BoolType::get(op.getContext()), cir::CmpOpKind::le,
+ uDiffValue, uRangeLength);
+ rewriter.create<cir::BrCondOp>(op.getLoc(), cmpResult, rangeDestination,
+ defaultDestination);
+ return resBlock;
+ }
+
+ mlir::LogicalResult
+ matchAndRewrite(cir::SwitchOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ llvm::SmallVector<CaseOp> cases;
+ op.collectCases(cases);
+
+ // Empty switch statement: just erase it.
+ if (cases.empty()) {
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+
+ // Create exit block from the next node of cir.switch op.
+ mlir::Block *exitBlock = rewriter.splitBlock(
+ rewriter.getBlock(), op->getNextNode()->getIterator());
+
+ // We lower cir.switch op in the following process:
+ // 1. Inline the region from the switch op after switch op.
+ // 2. Traverse each cir.case op:
+ // a. Record the entry block, block arguments and condition for every
+ // case. b. Inline the case region after the case op.
+ // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the
+ // recorded block and conditions.
+
+ // inline everything from switch body between the switch op and the exit
+ // block.
+ {
+ cir::YieldOp switchYield = nullptr;
+ // Clear switch operation.
+ for (auto &block : llvm::make_early_inc_range(op.getBody().getBlocks()))
+ if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
+ switchYield = yieldOp;
+
+ assert(!op.getBody().empty());
+ mlir::Block *originalBlock = op->getBlock();
+ mlir::Block *swopBlock =
+ rewriter.splitBlock(originalBlock, op->getIterator());
+ rewriter.inlineRegionBefore(op.getBody(), exitBlock);
+
+ if (switchYield)
+ rewriteYieldOp(rewriter, switchYield, exitBlock);
+
+ rewriter.setInsertionPointToEnd(originalBlock);
+ rewriter.create<cir::BrOp>(op.getLoc(), swopBlock);
+ }
+
+ // Allocate required data structures (disconsider default case in
+ // vectors).
+ llvm::SmallVector<mlir::APInt, 8> caseValues;
+ llvm::SmallVector<mlir::Block *, 8> caseDestinations;
+ llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
+
+ llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
+ llvm::SmallVector<mlir::Block *> rangeDestinations;
+ llvm::SmallVector<mlir::ValueRange> rangeOperands;
+
+ // Initialize default case as optional.
+ mlir::Block *defaultDestination = exitBlock;
+ mlir::ValueRange defaultOperands = exitBlock->getArguments();
+
+ // Digest the case statements values and bodies.
+ for (auto caseOp : cases) {
+ mlir::Region ®ion = caseOp.getCaseRegion();
+
+ // Found default case: save destination and operands.
+ switch (caseOp.getKind()) {
+ case cir::CaseOpKind::Default:
+ defaultDestination = ®ion.front();
+ defaultOperands = defaultDestination->getArguments();
+ break;
+ case cir::CaseOpKind::Range:
+ assert(caseOp.getValue().size() == 2 &&
+ "Case range should have 2 case value");
+ rangeValues.push_back(
+ {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(),
+ cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()});
+ rangeDestinations.push_back(®ion.front());
+ rangeOperands.push_back(rangeDestinations.back()->getArguments());
+ break;
+ case cir::CaseOpKind::Anyof:
+ case cir::CaseOpKind::Equal:
+ // AnyOf cases kind can have multiple values, hence the loop below.
+ for (auto &value : caseOp.getValue()) {
+ caseValues.push_back(cast<cir::IntAttr>(value).getValue());
+ caseDestinations.push_back(®ion.front());
+ caseOperands.push_back(caseDestinations.back()->getArguments());
+ }
+ break;
+ }
+
+ // Handle break statements.
+ walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
+ region, [&](mlir::Operation *op) {
+ if (!isa<cir::BreakOp>(op))
+ return mlir::WalkResult::advance();
+
+ lowerTerminator(op, exitBlock, rewriter);
+ return mlir::WalkResult::skip();
+ });
+
+ // Track fallthrough in cases.
+ for (auto &blk : region.getBlocks()) {
+ if (blk.getNumSuccessors())
+ continue;
+
+ if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) {
+ mlir::Operation *nextOp = caseOp->getNextNode();
+ assert(nextOp && "caseOp is not expected to be the last op");
+ mlir::Block *oldBlock = nextOp->getBlock();
+ mlir::Block *newBlock =
+ rewriter.splitBlock(oldBlock, nextOp->getIterator());
+ rewriter.setInsertionPointToEnd(oldBlock);
+ rewriter.create<cir::BrOp>(nextOp->getLoc(), mlir::ValueRange(),
+ newBlock);
+ rewriteYieldOp(rewriter, yieldOp, newBlock);
+ }
+ }
+
+ mlir::Block *oldBlock = caseOp->getBlock();
+ mlir::Block *newBlock =
+ rewriter.splitBlock(oldBlock, caseOp->getIterator());
+
+ mlir::Block &entryBlock = caseOp.getCaseRegion().front();
+ rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock);
+
+ // Create a branch to the entry of the inlined region.
+ rewriter.setInsertionPointToEnd(oldBlock);
+ rewriter.create<cir::BrOp>(caseOp.getLoc(), &entryBlock);
+ }
+
+ // Remove all cases since we've inlined the regions.
+ for (auto caseOp : cases) {
+ mlir::Block *caseBlock = caseOp->getBlock();
+ // Erase the block with no predecessors here to make the generated code
+ // simpler a little bit.
+ if (caseBlock->hasNoPredecessors())
+ rewriter.eraseBlock(caseBlock);
+ else
+ rewriter.eraseOp(caseOp);
+ }
+
+ for (size_t index = 0; index < rangeValues.size(); ++index) {
+ APInt lowerBound = rangeValues[index].first;
+ APInt upperBound = rangeValues[index].second;
+
+ // The case range is unreachable, skip it.
+ if (lowerBound.sgt(upperBound))
+ continue;
+
+ // If range is small, add multiple switch instruction cases.
+ // This magical number is from the original CGStmt code.
+ constexpr int kSmallRangeThreshold = 64;
+ if ((upperBound - lowerBound)
+ .ult(llvm::APInt(32, kSmallRangeThreshold))) {
+ for (APInt iValue = lowerBound; iValue.sle(upperBound);
+ (void)iValue++) {
+ caseValues.push_back(iValue);
+ caseOperands.push_back(rangeOperands[index]);
+ caseDestinations.push_back(rangeDestinations[index]);
+ }
+ continue;
+ }
+
+ defaultDestination =
+ condBrToRangeDestination(op, rewriter, rangeDestinations[index],
+ defaultDestination, lowerBound, upperBound);
+ defaultOperands = rangeOperands[index];
+ }
+
+ // Set switch op to branch to the newly created blocks.
+ rewriter.setInsertionPoint(op);
+ rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>(
+ op, op.getCondition(), defaultDestination, defaultOperands, caseValues,
+ caseDestinations, caseOperands);
+
+ return mlir::success();
+ }
+};
+
class CIRLoopOpInterfaceFlattening
: public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> {
public:
@@ -306,9 +532,10 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
};
void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
- patterns.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening,
- CIRScopeOpFlattening, CIRTernaryOpFlattening>(
- patterns.getContext());
+ patterns
+ .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
+ CIRSwitchOpFlattening, CIRTernaryOpFlattening>(
+ patterns.getContext());
}
void CIRFlattenCFGPass::runOnOperation() {
@@ -321,7 +548,7 @@ void CIRFlattenCFGPass::runOnOperation() {
assert(!cir::MissingFeatures::ifOp());
assert(!cir::MissingFeatures::switchOp());
assert(!cir::MissingFeatures::tryOp());
- if (isa<IfOp, ScopeOp, LoopOpInterface, TernaryOp>(op))
+ if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp>(op))
ops.push_back(op);
});
diff --git a/clang/test/CIR/IR/switch-flat.cir b/clang/test/CIR/IR/switch-flat.cir
new file mode 100644
index 0000000000000..b072c224b4a2c
--- /dev/null
+++ b/clang/test/CIR/IR/switch-flat.cir
@@ -0,0 +1,68 @@
+// RUN: cir-opt %s | FileCheck %s
+!s32i = !cir.int<s, 32>
+
+cir.func @FlatSwitchWithoutDefault(%arg0: !s32i) {
+ cir.switch.flat %arg0 : !s32i, ^bb2 [
+ 1: ^bb1
+ ]
+ ^bb1:
+ cir.br ^bb2
+ ^bb2:
+ cir.return
+}
+
+// CHECK: cir.switch.flat %arg0 : !s32i, ^bb2 [
+// CHECK-NEXT: 1: ^bb1
+// CHECK-NEXT: ]
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT: cir.br ^bb2
+// CHECK-NEXT: ^bb2:
+//CHECK-NEXT: cir.return
+
+cir.func @FlatSwitchWithDefault(%arg0: !s32i) {
+ cir.switch.flat %arg0 : !s32i, ^bb2 [
+ 1: ^bb1
+ ]
+ ^bb1:
+ cir.br ^bb3
+ ^bb2:
+ cir.br ^bb3
+ ^bb3:
+ cir.return
+}
+
+// CHECK: cir.switch.flat %arg0 : !s32i, ^bb2 [
+// CHECK-NEXT: 1: ^bb1
+// CHECK-NEXT: ]
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT: cir.br ^bb3
+// CHECK-NEXT: ^bb2:
+// CHECK-NEXT: cir.br ^bb3
+// CHECK-NEXT: ^bb3:
+// CHECK-NEXT: cir.return
+
+cir.func @switchWithOperands(%arg0: !s32i, %arg1: !s32i, %arg2: !s32i) {
+ cir.switch.flat %arg0 : !s32i, ^bb3 [
+ 0: ^bb1(%arg1, %arg2 : !s32i, !s32i),
+ 1: ^bb2(%arg2, %arg1 : !s32i, !s32i)
+ ]
+^bb1:
+ cir.br ^bb3
+
+^bb2:
+ cir.br ^bb3
+
+^bb3:
+ cir.return
+}
+
+// CHECK: cir.switch.flat %arg0 : !s32i, ^bb3 [
+// CHECK-NEXT: 0: ^bb1(%arg1, %arg2 : !s32i, !s32i),
+// CHECK-NEXT: 1: ^bb2(%arg2, %arg1 : !s32i, !s32i)
+// CHECK-NEXT: ]
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT: cir.br ^bb3
+// CHECK-NEXT: ^bb2:
+// CHECK-NEXT: cir.br ^bb3
+// CHECK-NEXT: ^bb3:
+// CHECK-NEXT: cir.return
diff --git a/clang/test/CIR/Transforms/switch.cir b/clang/test/CIR/Transforms/switch.cir
new file mode 100644
index 0000000000000..a05cf37e39728
--- /dev/null
+++ b/clang/te...
[truncated]
|
@llvm/pr-subscribers-clang Author: None (Andres-Salamanca) ChangesThis PR adds support for the Patch is 28.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/139154.diff 6 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 7ffa10464dcd3..914af6d1dc6bd 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -971,6 +971,52 @@ def SwitchOp : CIR_Op<"switch",
}];
}
+//===----------------------------------------------------------------------===//
+// SwitchFlatOp
+//===----------------------------------------------------------------------===//
+
+def SwitchFlatOp : CIR_Op<"switch.flat", [AttrSizedOperandSegments,
+ Terminator]> {
+
+ let description = [{
+ The `cir.switch.flat` operation is a region-less and simplified
+ version of the `cir.switch`.
+ It's representation is closer to LLVM IR dialect
+ than the C/C++ language feature.
+ }];
+
+ let arguments = (ins
+ CIR_IntType:$condition,
+ Variadic<AnyType>:$defaultOperands,
+ VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
+ ArrayAttr:$case_values,
+ DenseI32ArrayAttr:$case_operand_segments
+ );
+
+ let successors = (successor
+ AnySuccessor:$defaultDestination,
+ VariadicSuccessor<AnySuccessor>:$caseDestinations
+ );
+
+ let assemblyFormat = [{
+ $condition `:` type($condition) `,`
+ $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
+ custom<SwitchFlatOpCases>(ref(type($condition)), $case_values,
+ $caseDestinations, $caseOperands,
+ type($caseOperands))
+ attr-dict
+ }];
+
+ let builders = [
+ OpBuilder<(ins "mlir::Value":$condition,
+ "mlir::Block *":$defaultDestination,
+ "mlir::ValueRange":$defaultOperands,
+ CArg<"llvm::ArrayRef<llvm::APInt>", "{}">:$caseValues,
+ CArg<"mlir::BlockRange", "{}">:$caseDestinations,
+ CArg<"llvm::ArrayRef<mlir::ValueRange>", "{}">:$caseOperands)>
+ ];
+}
+
//===----------------------------------------------------------------------===//
// BrOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index b131edaf403ed..ca03013edb485 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -22,6 +22,7 @@
#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
#include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
#include "clang/CIR/MissingFeatures.h"
+#include <numeric>
using namespace mlir;
using namespace cir;
@@ -962,6 +963,102 @@ bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
});
}
+//===----------------------------------------------------------------------===//
+// SwitchFlatOp
+//===----------------------------------------------------------------------===//
+
+void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
+ Value value, Block *defaultDestination,
+ ValueRange defaultOperands,
+ ArrayRef<APInt> caseValues,
+ BlockRange caseDestinations,
+ ArrayRef<ValueRange> caseOperands) {
+
+ std::vector<mlir::Attribute> caseValuesAttrs;
+ for (auto &val : caseValues) {
+ caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
+ }
+ mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
+
+ build(builder, result, value, defaultOperands, caseOperands, attrs,
+ defaultDestination, caseDestinations);
+}
+
+/// <cases> ::= `[` (case (`,` case )* )? `]`
+/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
+static ParseResult parseSwitchFlatOpCases(
+ OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
+ SmallVectorImpl<Block *> &caseDestinations,
+ SmallVectorImpl<llvm::SmallVector<OpAsmParser::UnresolvedOperand>>
+ &caseOperands,
+ SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
+ if (failed(parser.parseLSquare()))
+ return failure();
+ if (succeeded(parser.parseOptionalRSquare()))
+ return success();
+ llvm::SmallVector<mlir::Attribute> values;
+
+ auto parseCase = [&]() {
+ int64_t value = 0;
+ if (failed(parser.parseInteger(value)))
+ return failure();
+
+ values.push_back(cir::IntAttr::get(flagType, value));
+
+ Block *destination;
+ llvm::SmallVector<OpAsmParser::UnresolvedOperand> operands;
+ llvm::SmallVector<Type> operandTypes;
+ if (parser.parseColon() || parser.parseSuccessor(destination))
+ return failure();
+ if (!parser.parseOptionalLParen()) {
+ if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
+ /*allowResultNumber=*/false) ||
+ parser.parseColonTypeList(operandTypes) || parser.parseRParen())
+ return failure();
+ }
+ caseDestinations.push_back(destination);
+ caseOperands.emplace_back(operands);
+ caseOperandTypes.emplace_back(operandTypes);
+ return success();
+ };
+ if (failed(parser.parseCommaSeparatedList(parseCase)))
+ return failure();
+
+ caseValues = ArrayAttr::get(flagType.getContext(), values);
+
+ return parser.parseRSquare();
+}
+
+static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
+ Type flagType, mlir::ArrayAttr caseValues,
+ SuccessorRange caseDestinations,
+ OperandRangeRange caseOperands,
+ const TypeRangeRange &caseOperandTypes) {
+ p << '[';
+ p.printNewline();
+ if (!caseValues) {
+ p << ']';
+ return;
+ }
+
+ size_t index = 0;
+ llvm::interleave(
+ llvm::zip(caseValues, caseDestinations),
+ [&](auto i) {
+ p << " ";
+ mlir::Attribute a = std::get<0>(i);
+ p << mlir::cast<cir::IntAttr>(a).getValue();
+ p << ": ";
+ p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
+ },
+ [&] {
+ p << ',';
+ p.printNewline();
+ });
+ p.printNewline();
+ p << ']';
+}
+
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index 3b4c7bc613133..edbb848322d41 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -84,6 +84,19 @@ struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> {
}
};
+struct RemoveEmptySwitch : public OpRewritePattern<SwitchOp> {
+ using OpRewritePattern<SwitchOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(SwitchOp op,
+ PatternRewriter &rewriter) const final {
+ if (!(op.getBody().empty() || isa<YieldOp>(op.getBody().front().front())))
+ return failure();
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// CIRCanonicalizePass
//===----------------------------------------------------------------------===//
@@ -127,7 +140,7 @@ void CIRCanonicalizePass::runOnOperation() {
assert(!cir::MissingFeatures::callOp());
// CastOp and UnaryOp are here to perform a manual `fold` in
// applyOpPatternsGreedily.
- if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp>(op))
+ if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp>(op))
ops.push_back(op);
});
diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
index 4a936d33b022a..70f383b556567 100644
--- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
@@ -171,6 +171,232 @@ class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> {
}
};
+class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
+public:
+ using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;
+
+ inline void rewriteYieldOp(mlir::PatternRewriter &rewriter,
+ cir::YieldOp yieldOp,
+ mlir::Block *destination) const {
+ rewriter.setInsertionPoint(yieldOp);
+ rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(),
+ destination);
+ }
+
+ // Return the new defaultDestination block.
+ Block *condBrToRangeDestination(cir::SwitchOp op,
+ mlir::PatternRewriter &rewriter,
+ mlir::Block *rangeDestination,
+ mlir::Block *defaultDestination,
+ const APInt &lowerBound,
+ const APInt &upperBound) const {
+ assert(lowerBound.sle(upperBound) && "Invalid range");
+ mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
+ cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true);
+ cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false);
+
+ cir::ConstantOp rangeLength = rewriter.create<cir::ConstantOp>(
+ op.getLoc(), cir::IntAttr::get(sIntType, upperBound - lowerBound));
+
+ cir::ConstantOp lowerBoundValue = rewriter.create<cir::ConstantOp>(
+ op.getLoc(), cir::IntAttr::get(sIntType, lowerBound));
+ cir::BinOp diffValue =
+ rewriter.create<cir::BinOp>(op.getLoc(), sIntType, cir::BinOpKind::Sub,
+ op.getCondition(), lowerBoundValue);
+
+ // Use unsigned comparison to check if the condition is in the range.
+ cir::CastOp uDiffValue = rewriter.create<cir::CastOp>(
+ op.getLoc(), uIntType, CastKind::integral, diffValue);
+ cir::CastOp uRangeLength = rewriter.create<cir::CastOp>(
+ op.getLoc(), uIntType, CastKind::integral, rangeLength);
+
+ cir::CmpOp cmpResult = rewriter.create<cir::CmpOp>(
+ op.getLoc(), cir::BoolType::get(op.getContext()), cir::CmpOpKind::le,
+ uDiffValue, uRangeLength);
+ rewriter.create<cir::BrCondOp>(op.getLoc(), cmpResult, rangeDestination,
+ defaultDestination);
+ return resBlock;
+ }
+
+ mlir::LogicalResult
+ matchAndRewrite(cir::SwitchOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ llvm::SmallVector<CaseOp> cases;
+ op.collectCases(cases);
+
+ // Empty switch statement: just erase it.
+ if (cases.empty()) {
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+
+ // Create exit block from the next node of cir.switch op.
+ mlir::Block *exitBlock = rewriter.splitBlock(
+ rewriter.getBlock(), op->getNextNode()->getIterator());
+
+ // We lower cir.switch op in the following process:
+ // 1. Inline the region from the switch op after switch op.
+ // 2. Traverse each cir.case op:
+ // a. Record the entry block, block arguments and condition for every
+ // case. b. Inline the case region after the case op.
+ // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the
+ // recorded block and conditions.
+
+ // inline everything from switch body between the switch op and the exit
+ // block.
+ {
+ cir::YieldOp switchYield = nullptr;
+ // Clear switch operation.
+ for (auto &block : llvm::make_early_inc_range(op.getBody().getBlocks()))
+ if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
+ switchYield = yieldOp;
+
+ assert(!op.getBody().empty());
+ mlir::Block *originalBlock = op->getBlock();
+ mlir::Block *swopBlock =
+ rewriter.splitBlock(originalBlock, op->getIterator());
+ rewriter.inlineRegionBefore(op.getBody(), exitBlock);
+
+ if (switchYield)
+ rewriteYieldOp(rewriter, switchYield, exitBlock);
+
+ rewriter.setInsertionPointToEnd(originalBlock);
+ rewriter.create<cir::BrOp>(op.getLoc(), swopBlock);
+ }
+
+ // Allocate required data structures (disconsider default case in
+ // vectors).
+ llvm::SmallVector<mlir::APInt, 8> caseValues;
+ llvm::SmallVector<mlir::Block *, 8> caseDestinations;
+ llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
+
+ llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
+ llvm::SmallVector<mlir::Block *> rangeDestinations;
+ llvm::SmallVector<mlir::ValueRange> rangeOperands;
+
+ // Initialize default case as optional.
+ mlir::Block *defaultDestination = exitBlock;
+ mlir::ValueRange defaultOperands = exitBlock->getArguments();
+
+ // Digest the case statements values and bodies.
+ for (auto caseOp : cases) {
+ mlir::Region ®ion = caseOp.getCaseRegion();
+
+ // Found default case: save destination and operands.
+ switch (caseOp.getKind()) {
+ case cir::CaseOpKind::Default:
+ defaultDestination = ®ion.front();
+ defaultOperands = defaultDestination->getArguments();
+ break;
+ case cir::CaseOpKind::Range:
+ assert(caseOp.getValue().size() == 2 &&
+ "Case range should have 2 case value");
+ rangeValues.push_back(
+ {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(),
+ cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()});
+ rangeDestinations.push_back(®ion.front());
+ rangeOperands.push_back(rangeDestinations.back()->getArguments());
+ break;
+ case cir::CaseOpKind::Anyof:
+ case cir::CaseOpKind::Equal:
+ // AnyOf cases kind can have multiple values, hence the loop below.
+ for (auto &value : caseOp.getValue()) {
+ caseValues.push_back(cast<cir::IntAttr>(value).getValue());
+ caseDestinations.push_back(®ion.front());
+ caseOperands.push_back(caseDestinations.back()->getArguments());
+ }
+ break;
+ }
+
+ // Handle break statements.
+ walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
+ region, [&](mlir::Operation *op) {
+ if (!isa<cir::BreakOp>(op))
+ return mlir::WalkResult::advance();
+
+ lowerTerminator(op, exitBlock, rewriter);
+ return mlir::WalkResult::skip();
+ });
+
+ // Track fallthrough in cases.
+ for (auto &blk : region.getBlocks()) {
+ if (blk.getNumSuccessors())
+ continue;
+
+ if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) {
+ mlir::Operation *nextOp = caseOp->getNextNode();
+ assert(nextOp && "caseOp is not expected to be the last op");
+ mlir::Block *oldBlock = nextOp->getBlock();
+ mlir::Block *newBlock =
+ rewriter.splitBlock(oldBlock, nextOp->getIterator());
+ rewriter.setInsertionPointToEnd(oldBlock);
+ rewriter.create<cir::BrOp>(nextOp->getLoc(), mlir::ValueRange(),
+ newBlock);
+ rewriteYieldOp(rewriter, yieldOp, newBlock);
+ }
+ }
+
+ mlir::Block *oldBlock = caseOp->getBlock();
+ mlir::Block *newBlock =
+ rewriter.splitBlock(oldBlock, caseOp->getIterator());
+
+ mlir::Block &entryBlock = caseOp.getCaseRegion().front();
+ rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock);
+
+ // Create a branch to the entry of the inlined region.
+ rewriter.setInsertionPointToEnd(oldBlock);
+ rewriter.create<cir::BrOp>(caseOp.getLoc(), &entryBlock);
+ }
+
+ // Remove all cases since we've inlined the regions.
+ for (auto caseOp : cases) {
+ mlir::Block *caseBlock = caseOp->getBlock();
+ // Erase the block with no predecessors here to make the generated code
+ // simpler a little bit.
+ if (caseBlock->hasNoPredecessors())
+ rewriter.eraseBlock(caseBlock);
+ else
+ rewriter.eraseOp(caseOp);
+ }
+
+ for (size_t index = 0; index < rangeValues.size(); ++index) {
+ APInt lowerBound = rangeValues[index].first;
+ APInt upperBound = rangeValues[index].second;
+
+ // The case range is unreachable, skip it.
+ if (lowerBound.sgt(upperBound))
+ continue;
+
+ // If range is small, add multiple switch instruction cases.
+ // This magical number is from the original CGStmt code.
+ constexpr int kSmallRangeThreshold = 64;
+ if ((upperBound - lowerBound)
+ .ult(llvm::APInt(32, kSmallRangeThreshold))) {
+ for (APInt iValue = lowerBound; iValue.sle(upperBound);
+ (void)iValue++) {
+ caseValues.push_back(iValue);
+ caseOperands.push_back(rangeOperands[index]);
+ caseDestinations.push_back(rangeDestinations[index]);
+ }
+ continue;
+ }
+
+ defaultDestination =
+ condBrToRangeDestination(op, rewriter, rangeDestinations[index],
+ defaultDestination, lowerBound, upperBound);
+ defaultOperands = rangeOperands[index];
+ }
+
+ // Set switch op to branch to the newly created blocks.
+ rewriter.setInsertionPoint(op);
+ rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>(
+ op, op.getCondition(), defaultDestination, defaultOperands, caseValues,
+ caseDestinations, caseOperands);
+
+ return mlir::success();
+ }
+};
+
class CIRLoopOpInterfaceFlattening
: public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> {
public:
@@ -306,9 +532,10 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
};
void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
- patterns.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening,
- CIRScopeOpFlattening, CIRTernaryOpFlattening>(
- patterns.getContext());
+ patterns
+ .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
+ CIRSwitchOpFlattening, CIRTernaryOpFlattening>(
+ patterns.getContext());
}
void CIRFlattenCFGPass::runOnOperation() {
@@ -321,7 +548,7 @@ void CIRFlattenCFGPass::runOnOperation() {
assert(!cir::MissingFeatures::ifOp());
assert(!cir::MissingFeatures::switchOp());
assert(!cir::MissingFeatures::tryOp());
- if (isa<IfOp, ScopeOp, LoopOpInterface, TernaryOp>(op))
+ if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp>(op))
ops.push_back(op);
});
diff --git a/clang/test/CIR/IR/switch-flat.cir b/clang/test/CIR/IR/switch-flat.cir
new file mode 100644
index 0000000000000..b072c224b4a2c
--- /dev/null
+++ b/clang/test/CIR/IR/switch-flat.cir
@@ -0,0 +1,68 @@
+// RUN: cir-opt %s | FileCheck %s
+!s32i = !cir.int<s, 32>
+
+cir.func @FlatSwitchWithoutDefault(%arg0: !s32i) {
+ cir.switch.flat %arg0 : !s32i, ^bb2 [
+ 1: ^bb1
+ ]
+ ^bb1:
+ cir.br ^bb2
+ ^bb2:
+ cir.return
+}
+
+// CHECK: cir.switch.flat %arg0 : !s32i, ^bb2 [
+// CHECK-NEXT: 1: ^bb1
+// CHECK-NEXT: ]
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT: cir.br ^bb2
+// CHECK-NEXT: ^bb2:
+//CHECK-NEXT: cir.return
+
+cir.func @FlatSwitchWithDefault(%arg0: !s32i) {
+ cir.switch.flat %arg0 : !s32i, ^bb2 [
+ 1: ^bb1
+ ]
+ ^bb1:
+ cir.br ^bb3
+ ^bb2:
+ cir.br ^bb3
+ ^bb3:
+ cir.return
+}
+
+// CHECK: cir.switch.flat %arg0 : !s32i, ^bb2 [
+// CHECK-NEXT: 1: ^bb1
+// CHECK-NEXT: ]
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT: cir.br ^bb3
+// CHECK-NEXT: ^bb2:
+// CHECK-NEXT: cir.br ^bb3
+// CHECK-NEXT: ^bb3:
+// CHECK-NEXT: cir.return
+
+cir.func @switchWithOperands(%arg0: !s32i, %arg1: !s32i, %arg2: !s32i) {
+ cir.switch.flat %arg0 : !s32i, ^bb3 [
+ 0: ^bb1(%arg1, %arg2 : !s32i, !s32i),
+ 1: ^bb2(%arg2, %arg1 : !s32i, !s32i)
+ ]
+^bb1:
+ cir.br ^bb3
+
+^bb2:
+ cir.br ^bb3
+
+^bb3:
+ cir.return
+}
+
+// CHECK: cir.switch.flat %arg0 : !s32i, ^bb3 [
+// CHECK-NEXT: 0: ^bb1(%arg1, %arg2 : !s32i, !s32i),
+// CHECK-NEXT: 1: ^bb2(%arg2, %arg1 : !s32i, !s32i)
+// CHECK-NEXT: ]
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT: cir.br ^bb3
+// CHECK-NEXT: ^bb2:
+// CHECK-NEXT: cir.br ^bb3
+// CHECK-NEXT: ^bb3:
+// CHECK-NEXT: cir.return
diff --git a/clang/test/CIR/Transforms/switch.cir b/clang/test/CIR/Transforms/switch.cir
new file mode 100644
index 0000000000000..a05cf37e39728
--- /dev/null
+++ b/clang/te...
[truncated]
|
let description = [{ | ||
The `cir.switch.flat` operation is a region-less and simplified | ||
version of the `cir.switch`. | ||
It's representation is closer to LLVM IR dialect |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's representation is closer to LLVM IR dialect | |
Its representation is closer to LLVM IR dialect |
Also, the formatting is odd here. Can you clean up the word wrap locations?
CIR_IntType:$condition, | ||
Variadic<AnyType>:$defaultOperands, | ||
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands, | ||
ArrayAttr:$case_values, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's an inconsistency between case_values
here and on line 1004, and caseValues
on line 1014. I see that the LLVM dialect's LLVM_SwitchOp
has the same inconsistency, but I don't see any other place that it's referenced directly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm going to leave it as caseValues
ArrayRef<ValueRange> caseOperands) { | ||
|
||
std::vector<mlir::Attribute> caseValuesAttrs; | ||
for (auto &val : caseValues) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for (auto &val : caseValues) { | |
for (APInt &val : caseValues) { |
Also, braces aren't needed here.
{ | ||
cir::YieldOp switchYield = nullptr; | ||
// Clear switch operation. | ||
for (auto &block : llvm::make_early_inc_range(op.getBody().getBlocks())) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for (auto &block : llvm::make_early_inc_range(op.getBody().getBlocks())) | |
for (mlir::Block &block : llvm::make_early_inc_range(op.getBody().getBlocks())) |
mlir::ValueRange defaultOperands = exitBlock->getArguments(); | ||
|
||
// Digest the case statements values and bodies. | ||
for (auto caseOp : cases) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for (auto caseOp : cases) { | |
for (cir::CaseOp caseOp : cases) { |
case cir::CaseOpKind::Anyof: | ||
case cir::CaseOpKind::Equal: | ||
// AnyOf cases kind can have multiple values, hence the loop below. | ||
for (auto &value : caseOp.getValue()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't use auto
here.
} | ||
|
||
// Remove all cases since we've inlined the regions. | ||
for (auto caseOp : cases) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't use auto
here.
rewriter.eraseOp(caseOp); | ||
} | ||
|
||
for (size_t index = 0; index < rangeValues.size(); ++index) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you use llvm::zip to make this a range-for loop?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I used llvm::enumerate instead because we need the index in the loop, which llvm::zip doesn't provide directly.
// If range is small, add multiple switch instruction cases. | ||
// This magical number is from the original CGStmt code. | ||
constexpr int kSmallRangeThreshold = 64; | ||
if ((upperBound - lowerBound) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a test case for a range that is within this threshold?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is also missing a end-to-end test that proves we generate switchflatop if we are coming all the way from source code. There are examples in the incubator on how to check if a pass runs (you can use print-before/after) and the IR for it.
clang/test/CIR/IR/switch-flat.cir
Outdated
// CHECK-NEXT: ^bb1: | ||
// CHECK-NEXT: cir.br ^bb2 | ||
// CHECK-NEXT: ^bb2: | ||
//CHECK-NEXT: cir.return |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Odd indentation compared with other check lines
}); | ||
|
||
// Track fallthrough in cases. | ||
for (auto &blk : region.getBlocks()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be Block &blk
In the incubator, the examples use
|
This PR adds support for the
FlattenCFG
transformation onswitch
statements. It also introduces theSwitchFlatOp
, which is necessary for subsequent lowering to LLVM.