Skip to content

[CIR] Upstream shift operators for VectorType #139465

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1401,18 +1401,20 @@ def ShiftOp : CIR_Op<"shift", [Pure]> {
The `cir.shift` operation performs a bitwise shift, either to the left or to
the right, based on the first operand. The second operand specifies the
value to be shifted, and the third operand determines the number of
positions by which the shift is applied. Both the second and third operands
are required to be integers.
positions by which the shift is applied, They must be either all vector of
integer type, or all integer type. If they are vectors, each vector element of
the shift target is shifted by the corresponding shift amount in
the shift amount vector.

```mlir
%7 = cir.shift(left, %1 : !u64i, %4 : !s32i) -> !u64i
%res = cir.shift(left, %lhs : !u64i, %amount : !s32i) -> !u64i
%new_vec = cir.shift(left, %lhs : !cir.vector<2 x !s32i>, %rhs :
!cir.vector<2 x !s32i>) -> !cir.vector<2 x !s32i>
```
}];

// TODO(cir): Support vectors. CIR_IntType -> CIR_AnyIntOrVecOfInt. Also
// update the description above.
let results = (outs CIR_IntType:$result);
let arguments = (ins CIR_IntType:$value, CIR_IntType:$amount,
let results = (outs CIR_AnyIntOrVecOfInt:$result);
let arguments = (ins CIR_AnyIntOrVecOfInt:$value, CIR_AnyIntOrVecOfInt:$amount,
UnitAttr:$isShiftleft);

let assemblyFormat = [{
Expand Down
19 changes: 19 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,23 @@ def CIR_PtrToVoidPtrType
"$_builder.getType<" # cppType # ">("
"cir::VoidType::get($_builder.getContext())))">;

//===----------------------------------------------------------------------===//
// Vector Type predicates
//===----------------------------------------------------------------------===//

// Vector of integral type
def IntegerVector : Type<
And<[
CPred<"::mlir::isa<::cir::VectorType>($_self)">,
CPred<"::mlir::isa<::cir::IntType>("
"::mlir::cast<::cir::VectorType>($_self).getElementType())">,
CPred<"::mlir::cast<::cir::IntType>("
"::mlir::cast<::cir::VectorType>($_self).getElementType())"
".isFundamental()">
]>, "!cir.vector of !cir.int"> {
}

// Any Integer or Vector of Integer Constraints
def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_AnyIntType, IntegerVector]>;

#endif // CLANG_CIR_DIALECT_IR_CIRTYPECONSTRAINTS_TD
5 changes: 2 additions & 3 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1297,9 +1297,8 @@ OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
LogicalResult cir::ShiftOp::verify() {
mlir::Operation *op = getOperation();
mlir::Type resType = getResult().getType();
assert(!cir::MissingFeatures::vectorType());
bool isOp0Vec = false;
bool isOp1Vec = false;
const bool isOp0Vec = mlir::isa<cir::VectorType>(op->getOperand(0).getType());
const bool isOp1Vec = mlir::isa<cir::VectorType>(op->getOperand(1).getType());
if (isOp0Vec != isOp1Vec)
return emitOpError() << "input types cannot be one vector and one scalar";
if (isOp1Vec && op->getOperand(1).getType() != resType) {
Expand Down
51 changes: 26 additions & 25 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1372,41 +1372,42 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
cir::ShiftOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto cirAmtTy = mlir::dyn_cast<cir::IntType>(op.getAmount().getType());
auto cirValTy = mlir::dyn_cast<cir::IntType>(op.getValue().getType());
assert((op.getValue().getType() == op.getType()) &&
"inconsistent operands' types NYI");

// Operands could also be vector type
assert(!cir::MissingFeatures::vectorType());
mlir::Type llvmTy = getTypeConverter()->convertType(op.getType());
const mlir::Type llvmTy = getTypeConverter()->convertType(op.getType());
mlir::Value amt = adaptor.getAmount();
mlir::Value val = adaptor.getValue();

// TODO(cir): Assert for vector types
assert((cirValTy && cirAmtTy) &&
"shift input type must be integer or vector type, otherwise NYI");

assert((cirValTy == op.getType()) && "inconsistent operands' types NYI");

// Ensure shift amount is the same type as the value. Some undefined
// behavior might occur in the casts below as per [C99 6.5.7.3].
// Vector type shift amount needs no cast as type consistency is expected to
// be already be enforced at CIRGen.
if (cirAmtTy)
amt = getLLVMIntCast(rewriter, amt, mlir::cast<mlir::IntegerType>(llvmTy),
true, cirAmtTy.getWidth(), cirValTy.getWidth());
auto cirAmtTy = mlir::dyn_cast<cir::IntType>(op.getAmount().getType());
bool isUnsigned;
if (cirAmtTy) {
auto cirValTy = mlir::cast<cir::IntType>(op.getValue().getType());
isUnsigned = cirValTy.isUnsigned();

// Ensure shift amount is the same type as the value. Some undefined
// behavior might occur in the casts below as per [C99 6.5.7.3].
// Vector type shift amount needs no cast as type consistency is expected to
// be already be enforced at CIRGen.
if (cirAmtTy)
amt = getLLVMIntCast(rewriter, amt, llvmTy, true, cirAmtTy.getWidth(),
cirValTy.getWidth());
} else {
auto cirValVTy = mlir::cast<cir::VectorType>(op.getValue().getType());
isUnsigned =
mlir::cast<cir::IntType>(cirValVTy.getElementType()).isUnsigned();
}

// Lower to the proper LLVM shift operation.
if (op.getIsShiftleft()) {
rewriter.replaceOpWithNewOp<mlir::LLVM::ShlOp>(op, llvmTy, val, amt);
} else {
assert(!cir::MissingFeatures::vectorType());
bool isUnsigned = !cirValTy.isSigned();
if (isUnsigned)
rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt);
else
rewriter.replaceOpWithNewOp<mlir::LLVM::AShrOp>(op, llvmTy, val, amt);
return mlir::success();
}

if (isUnsigned)
rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt);
else
rewriter.replaceOpWithNewOp<mlir::LLVM::AShrOp>(op, llvmTy, val, amt);
return mlir::success();
}

Expand Down
65 changes: 65 additions & 0 deletions clang/test/CIR/CodeGen/vector-ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,68 @@ void foo7() {
// OGCG: %[[TMP2:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
// OGCG: %[[NEW_VEC:.*]] = insertelement <4 x i32> %[[TMP2]], i32 %[[RES]], i32 2
// OGCG: store <4 x i32> %[[NEW_VEC]], ptr %[[VEC]], align 16

void foo9() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test with a scalar shift value?

vi4 splat_shl = a << 3;

Also, a test with unsigned values and shift-right is needed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both of them still require upstreaming VecSplatOp, I will upstream it next in another PR and add those test cases there, if that's okey

vi4 a = {1, 2, 3, 4};
vi4 b = {5, 6, 7, 8};

vi4 shl = a << b;
vi4 shr = a >> b;
}

// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["b", init]
// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shr", init]
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
// CIR: cir.store %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
// CIR: %[[CONST_5:.*]] = cir.const #cir.int<5> : !s32i
// CIR: %[[CONST_6:.*]] = cir.const #cir.int<6> : !s32i
// CIR: %[[CONST_7:.*]] = cir.const #cir.int<7> : !s32i
// CIR: %[[CONST_8:.*]] = cir.const #cir.int<8> : !s32i
// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_5]], %[[CONST_6]], %[[CONST_7]], %[[CONST_8]] :
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
// CIR: cir.store %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
// CIR: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
// CIR: cir.store %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>

// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
// LLVM: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
// LLVM: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16

// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
// OGCG: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
// OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
65 changes: 65 additions & 0 deletions clang/test/CIR/CodeGen/vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,68 @@ void foo7() {
// OGCG: %[[TMP2:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
// OGCG: %[[NEW_VEC:.*]] = insertelement <4 x i32> %[[TMP2]], i32 %[[RES]], i32 2
// OGCG: store <4 x i32> %[[NEW_VEC]], ptr %[[VEC]], align 16

void foo9() {
vi4 a = {1, 2, 3, 4};
vi4 b = {5, 6, 7, 8};

vi4 shl = a << b;
vi4 shr = a >> b;
}

// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["b", init]
// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shr", init]
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
// CIR: cir.store %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
// CIR: %[[CONST_5:.*]] = cir.const #cir.int<5> : !s32i
// CIR: %[[CONST_6:.*]] = cir.const #cir.int<6> : !s32i
// CIR: %[[CONST_7:.*]] = cir.const #cir.int<7> : !s32i
// CIR: %[[CONST_8:.*]] = cir.const #cir.int<8> : !s32i
// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_5]], %[[CONST_6]], %[[CONST_7]], %[[CONST_8]] :
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
// CIR: cir.store %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
// CIR: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
// CIR: cir.store %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>

// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
// LLVM: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
// LLVM: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16

// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
// OGCG: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
// OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16