[CIR] Upstream shift operators for VectorType #139465
Conversation
@llvm/pr-subscribers-clang @llvm/pr-subscribers-clangir

Author: Amr Hesham (AmrDeveloper)

Changes: This change adds support for shift ops for VectorType. Issue #136487

Full diff: https://github.com/llvm/llvm-project/pull/139465.diff

6 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 7aff5edb88167..b0e593b011109 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1401,18 +1401,19 @@ def ShiftOp : CIR_Op<"shift", [Pure]> {
The `cir.shift` operation performs a bitwise shift, either to the left or to
the right, based on the first operand. The second operand specifies the
value to be shifted, and the third operand determines the number of
- positions by which the shift is applied. Both the second and third operands
- are required to be integers.
+ positions by which the shift is applied, they must be either all vector of
+ integer type, or all integer type. If they are vectors, each vector element of
+ the shift target is shifted by the corresponding shift amount in
+ the shift amount vector.
```mlir
- %7 = cir.shift(left, %1 : !u64i, %4 : !s32i) -> !u64i
+ %res = cir.shift(left, %lhs : !u64i, %amount : !s32i) -> !u64i
+ %new_vec = cir.shift(left, %lhs : !cir.vector<2 x !s32i>, %rhs : !cir.vector<2 x !s32i>) -> !cir.vector<2 x !s32i>
```
}];
- // TODO(cir): Support vectors. CIR_IntType -> CIR_AnyIntOrVecOfInt. Also
- // update the description above.
- let results = (outs CIR_IntType:$result);
- let arguments = (ins CIR_IntType:$value, CIR_IntType:$amount,
+ let results = (outs CIR_AnyIntOrVecOfInt:$result);
+ let arguments = (ins CIR_AnyIntOrVecOfInt:$value, CIR_AnyIntOrVecOfInt:$amount,
UnitAttr:$isShiftleft);
let assemblyFormat = [{
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
index 00f67e2a03a25..902d6535ff717 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
@@ -174,4 +174,23 @@ def CIR_PtrToVoidPtrType
"$_builder.getType<" # cppType # ">("
"cir::VoidType::get($_builder.getContext())))">;
+//===----------------------------------------------------------------------===//
+// Vector Type predicates
+//===----------------------------------------------------------------------===//
+
+// Vector of integral type
+def IntegerVector : Type<
+ And<[
+ CPred<"::mlir::isa<::cir::VectorType>($_self)">,
+ CPred<"::mlir::isa<::cir::IntType>("
+ "::mlir::cast<::cir::VectorType>($_self).getElementType())">,
+ CPred<"::mlir::cast<::cir::IntType>("
+ "::mlir::cast<::cir::VectorType>($_self).getElementType())"
+ ".isFundamental()">
+ ]>, "!cir.vector of !cir.int"> {
+}
+
+// Any Integer or Vector of Integer Constraints
+def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_AnyIntType, IntegerVector]>;
+
#endif // CLANG_CIR_DIALECT_IR_CIRTYPECONSTRAINTS_TD
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index b131edaf403ed..2f7e3496d55a5 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1297,9 +1297,8 @@ OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
LogicalResult cir::ShiftOp::verify() {
mlir::Operation *op = getOperation();
mlir::Type resType = getResult().getType();
- assert(!cir::MissingFeatures::vectorType());
- bool isOp0Vec = false;
- bool isOp1Vec = false;
+ const bool isOp0Vec = mlir::isa<cir::VectorType>(op->getOperand(0).getType());
+ const bool isOp1Vec = mlir::isa<cir::VectorType>(op->getOperand(1).getType());
if (isOp0Vec != isOp1Vec)
return emitOpError() << "input types cannot be one vector and one scalar";
if (isOp1Vec && op->getOperand(1).getType() != resType) {
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 5986655ababe9..1951d2e3a3b79 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1376,16 +1376,17 @@ mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
auto cirValTy = mlir::dyn_cast<cir::IntType>(op.getValue().getType());
// Operands could also be vector type
- assert(!cir::MissingFeatures::vectorType());
+ auto cirAmtVTy = mlir::dyn_cast<cir::VectorType>(op.getAmount().getType());
+ auto cirValVTy = mlir::dyn_cast<cir::VectorType>(op.getValue().getType());
mlir::Type llvmTy = getTypeConverter()->convertType(op.getType());
mlir::Value amt = adaptor.getAmount();
mlir::Value val = adaptor.getValue();
- // TODO(cir): Assert for vector types
- assert((cirValTy && cirAmtTy) &&
+ assert(((cirValTy && cirAmtTy) || (cirAmtVTy && cirValVTy)) &&
"shift input type must be integer or vector type, otherwise NYI");
- assert((cirValTy == op.getType()) && "inconsistent operands' types NYI");
+ assert((cirValTy == op.getType() || cirValVTy == op.getType()) &&
+ "inconsistent operands' types NYI");
// Ensure shift amount is the same type as the value. Some undefined
// behavior might occur in the casts below as per [C99 6.5.7.3].
@@ -1399,8 +1400,10 @@ mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
if (op.getIsShiftleft()) {
rewriter.replaceOpWithNewOp<mlir::LLVM::ShlOp>(op, llvmTy, val, amt);
} else {
- assert(!cir::MissingFeatures::vectorType());
- bool isUnsigned = !cirValTy.isSigned();
+ const bool isUnsigned =
+ cirValTy
+ ? !cirValTy.isSigned()
+ : !mlir::cast<cir::IntType>(cirValVTy.getElementType()).isSigned();
if (isUnsigned)
rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt);
else
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index 0756497bf6b96..4b3a271587f67 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -213,3 +213,68 @@ void foo4() {
// OGCG: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP1]], i32 %[[TMP2]]
// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
+
+void foo9() {
+ vi4 a = {1, 2, 3, 4};
+ vi4 b = {5, 6, 7, 8};
+
+ vi4 shl = a << b;
+ vi4 shr = a >> b;
+}
+
+// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["b", init]
+// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shr", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[CONST_5:.*]] = cir.const #cir.int<5> : !s32i
+// CIR: %[[CONST_6:.*]] = cir.const #cir.int<6> : !s32i
+// CIR: %[[CONST_7:.*]] = cir.const #cir.int<7> : !s32i
+// CIR: %[[CONST_8:.*]] = cir.const #cir.int<8> : !s32i
+// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_5]], %[[CONST_6]], %[[CONST_7]], %[[CONST_8]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+
+// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// LLVM: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// OGCG: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp
index 530018108c6d9..cf23a97ac61f2 100644
--- a/clang/test/CIR/CodeGen/vector.cpp
+++ b/clang/test/CIR/CodeGen/vector.cpp
@@ -201,3 +201,68 @@ void foo4() {
// OGCG: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP1]], i32 %[[TMP2]]
// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
+
+void foo9() {
+ vi4 a = {1, 2, 3, 4};
+ vi4 b = {5, 6, 7, 8};
+
+ vi4 shl = a << b;
+ vi4 shr = a >> b;
+}
+
+// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["b", init]
+// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shr", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[CONST_5:.*]] = cir.const #cir.int<5> : !s32i
+// CIR: %[[CONST_6:.*]] = cir.const #cir.int<6> : !s32i
+// CIR: %[[CONST_7:.*]] = cir.const #cir.int<7> : !s32i
+// CIR: %[[CONST_8:.*]] = cir.const #cir.int<8> : !s32i
+// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_5]], %[[CONST_6]], %[[CONST_7]], %[[CONST_8]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+
+// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// LLVM: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// OGCG: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
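For quick reference, the source-level behavior the new foo9 tests exercise can be sketched in plain C++ using clang's vector extension (the vi4 typedef here is an assumption; the test files define their own):

```cpp
// Sketch of the element-wise shift semantics checked by foo9() above.
// Assumed typedef; not copied from the test files.
typedef int vi4 __attribute__((vector_size(16))); // 4 x i32

void foo9_semantics() {
  vi4 a = {1, 2, 3, 4};
  vi4 b = {5, 6, 7, 8};
  vi4 shl = a << b; // shl[i] == a[i] << b[i] for each lane i
  vi4 shr = a >> b; // signed elements, so this lowers to ashr
}
```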
This looks good. I just have a few minor suggestions.
Suggested change:
- positions by which the shift is applied, they must be either all vector of
+ positions by which the shift is applied. They must be either all vector of
Can you reformat the %new_vec example line to fit in 80 columns?
On the new signedness check in the LLVM lowering:

const bool isUnsigned =
    cirValTy
        ? !cirValTy.isSigned()
        : !mlir::cast<cir::IntType>(cirValVTy.getElementType()).isSigned();
Would it make sense for the vector type to support isSigned()? Then, instead of having separate cirValTy and cirValVTy (one of which would always be null), we could do something like this above:

if (!cirValTy) {
  cirValTy = mlir::cast<cir::VectorType>(op.getValue().getType());
  assert(mlir::isa<cir::VectorType>(op.getAmount().getType()));
}
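A compilable variant of that idea, absent an isSigned() on VectorType, might be to peel the element type off up front (the valElemTy name here is illustrative, not from the patch):

```cpp
// Sketch only: resolve the element IntType once, whether the shifted value is
// a scalar !cir.int or a !cir.vector of !cir.int, so the signedness query
// below only ever deals with cir::IntType.
cir::IntType valElemTy = mlir::dyn_cast<cir::IntType>(op.getValue().getType());
if (!valElemTy) {
  auto vecTy = mlir::cast<cir::VectorType>(op.getValue().getType());
  assert(mlir::isa<cir::VectorType>(op.getAmount().getType()) &&
         "shift amount must be a vector when the value is a vector");
  valElemTy = mlir::cast<cir::IntType>(vecTy.getElementType());
}
const bool isUnsigned = !valElemTy.isSigned();
```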
Mmmmm, I'm not sure we should support isSigned() for VectorType, because it can be a vector of bool, floats, etc., and we would also need to cast again to get the width from the vector here:

if (cirAmtTy)
  amt = getLLVMIntCast(rewriter, amt, mlir::cast<mlir::IntegerType>(llvmTy),
                       true, cirAmtTy.getWidth(), cirValTy.getWidth());
I'll first try changing the code and adding a bool isAmtVecTy to see if it becomes better :D
I used your suggestion and assigned isUnsigned and getLLVMIntCast in the same branch
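Concretely, that presumably ends up shaped roughly like the sketch below (based on the code in this diff, not the exact committed version; the shl/lshr/ashr selection stays as before):

```cpp
// Rough sketch: decide scalar vs. vector once, computing the signedness and
// the amount-width adjustment for the scalar case in the same branch.
bool isUnsigned = false;
if (cirValTy) {
  // Scalar case: signedness comes from the value type, and the shift amount
  // may need widening/truncation to the value's width.
  isUnsigned = !cirValTy.isSigned();
  amt = getLLVMIntCast(rewriter, amt, mlir::cast<mlir::IntegerType>(llvmTy),
                       true, cirAmtTy.getWidth(), cirValTy.getWidth());
} else {
  // Vector case: operands already share the same vector type; only the
  // element signedness matters for picking lshr vs. ashr.
  isUnsigned =
      !mlir::cast<cir::IntType>(cirValVTy.getElementType()).isSigned();
}
```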
Can you add a test with a scalar shift value?
vi4 splat_shl = a << 3;
Also, a test with unsigned values and shift-right is needed.
Both of them still require upstreaming VecSplatOp. I will upstream it next in another PR and add those test cases there, if that's okay.
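Once VecSplatOp is upstreamed, those deferred tests would presumably look something like this (the vu4 typedef is hypothetical; the current tests only use vi4):

```cpp
// Hypothetical follow-up test shapes for the scalar-amount and unsigned cases.
typedef int vi4 __attribute__((vector_size(16)));
typedef unsigned int vu4 __attribute__((vector_size(16)));

void shift_followups(vi4 a, vu4 u, vu4 amt) {
  vi4 splat_shl = a << 3; // scalar amount: needs the splat op to broadcast 3
  vu4 ushr = u >> amt;    // unsigned elements: should lower to LLVM lshr
}
```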
Force-pushed from 7a480c7 to af65eca.
Looks good!
Nothing to add on top of Andy's review. LGTM