-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[CIR][LLVMLowering] Upstream unary operators for VectorType #139444
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-clang @llvm/pr-subscribers-clangir Author: Amr Hesham (AmrDeveloper) ChangesThis change adds support for unary ops for VectorType Issue #136487 Full diff: https://github.com/llvm/llvm-project/pull/139444.diff 3 Files Affected:
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 5986655ababe9..fb55b2984daea 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -54,10 +54,10 @@ namespace direct {
namespace {
/// If the given type is a vector type, return the vector's element type.
/// Otherwise return the given type unchanged.
-// TODO(cir): Return the vector element type once we have support for vectors
-// instead of the identity type.
mlir::Type elementTypeIfVector(mlir::Type type) {
- assert(!cir::MissingFeatures::vectorType());
+ if (const auto vecType = mlir::dyn_cast<cir::VectorType>(type)) {
+ return vecType.getElementType();
+ }
return type;
}
} // namespace
@@ -1043,12 +1043,11 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
mlir::ConversionPatternRewriter &rewriter) const {
assert(op.getType() == op.getInput().getType() &&
"Unary operation's operand type and result type are different");
- mlir::Type type = op.getType();
- mlir::Type elementType = type;
- bool isVector = false;
- assert(!cir::MissingFeatures::vectorType());
- mlir::Type llvmType = getTypeConverter()->convertType(type);
- mlir::Location loc = op.getLoc();
+ const mlir::Type type = op.getType();
+ const mlir::Type elementType = elementTypeIfVector(type);
+ const bool isVector = mlir::isa<cir::VectorType>(type);
+ const mlir::Type llvmType = getTypeConverter()->convertType(type);
+ const mlir::Location loc = op.getLoc();
// Integer unary operations: + - ~ ++ --
if (mlir::isa<cir::IntType>(elementType)) {
@@ -1076,20 +1075,41 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
rewriter.replaceOp(op, adaptor.getInput());
return mlir::success();
case cir::UnaryOpKind::Minus: {
- assert(!isVector &&
- "Add vector handling when vector types are supported");
- mlir::LLVM::ConstantOp zero = rewriter.create<mlir::LLVM::ConstantOp>(
- loc, llvmType, mlir::IntegerAttr::get(llvmType, 0));
+ mlir::Value zero;
+ if (isVector)
+ zero = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmType);
+ else
+ zero = rewriter.create<mlir::LLVM::ConstantOp>(
+ loc, llvmType, mlir::IntegerAttr::get(llvmType, 0));
rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(
op, llvmType, zero, adaptor.getInput(), maybeNSW);
return mlir::success();
}
case cir::UnaryOpKind::Not: {
// bit-wise compliment operator, implemented as an XOR with -1.
- assert(!isVector &&
- "Add vector handling when vector types are supported");
- mlir::LLVM::ConstantOp minusOne = rewriter.create<mlir::LLVM::ConstantOp>(
- loc, llvmType, mlir::IntegerAttr::get(llvmType, -1));
+ mlir::Value minusOne;
+ if (isVector) {
+ // Creating a vector object with all -1 values is easier said than
+ // done. It requires a series of insertelement ops.
+ const mlir::Type llvmElementType =
+ getTypeConverter()->convertType(elementType);
+ const mlir::Value minusOneInt = rewriter.create<mlir::LLVM::ConstantOp>(
+ loc, llvmElementType, mlir::IntegerAttr::get(llvmElementType, -1));
+ minusOne = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmType);
+
+ const uint64_t numElements =
+ mlir::dyn_cast<cir::VectorType>(type).getSize();
+ for (uint64_t i = 0; i < numElements; ++i) {
+ const mlir::Value indexValue =
+ rewriter.create<mlir::LLVM::ConstantOp>(loc,
+ rewriter.getI64Type(), i);
+ minusOne = rewriter.create<mlir::LLVM::InsertElementOp>(
+ loc, minusOne, minusOneInt, indexValue);
+ }
+ } else {
+ minusOne = rewriter.create<mlir::LLVM::ConstantOp>(
+ loc, llvmType, mlir::IntegerAttr::get(llvmType, -1));
+ }
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(
op, llvmType, adaptor.getInput(), minusOne);
return mlir::success();
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index 0756497bf6b96..005f629b88143 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -213,3 +213,59 @@ void foo4() {
// OGCG: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP1]], i32 %[[TMP2]]
// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
+
+void foo8() {
+ vi4 a = { 1, 2, 3, 4 };
+ vi4 plus_res = +a;
+ vi4 minus_res = -a;
+ vi4 not_res = ~a;
+}
+
+// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[PLUS_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["plus_res", init]
+// CIR: %[[MINUS_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["minus_res", init]
+// CIR: %[[NOT_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["not_res", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP1:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[PLUS:.*]] = cir.unary(plus, %[[TMP1]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+// CIR: cir.store %[[PLUS]], %[[PLUS_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP2:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[MINUS:.*]] = cir.unary(minus, %[[TMP2]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+// CIR: cir.store %[[MINUS]], %[[MINUS_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP3:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[NOT:.*]] = cir.unary(not, %[[TMP3]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+// CIR: cir.store %[[NOT]], %[[NOT_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+
+// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[PLUS_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[MINUS_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[NOT_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
+// LLVM: %[[TMP1:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// LLVM: store <4 x i32> %[[TMP1]], ptr %[[PLUS_RES]], align 16
+// LLVM: %[[TMP2:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// LLVM: %[[SUB:.*]] = sub <4 x i32> zeroinitializer, %[[TMP2]]
+// LLVM: store <4 x i32> %[[SUB]], ptr %[[MINUS_RES]], align 16
+// LLVM: %[[TMP3:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// LLVM: %[[NOT:.*]] = xor <4 x i32> %[[TMP3]], splat (i32 -1)
+// LLVM: store <4 x i32> %[[NOT]], ptr %[[NOT_RES]], align 16
+
+// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[PLUS_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[MINUS_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[NOT_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
+// OGCG: %[[TMP1:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// OGCG: store <4 x i32> %[[TMP1]], ptr %[[PLUS_RES]], align 16
+// OGCG: %[[TMP2:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// OGCG: %[[SUB:.*]] = sub <4 x i32> zeroinitializer, %[[TMP2]]
+// OGCG: store <4 x i32> %[[SUB]], ptr %[[MINUS_RES]], align 16
+// OGCG: %[[TMP3:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// OGCG: %[[NOT:.*]] = xor <4 x i32> %[[TMP3]], splat (i32 -1)
+// OGCG: store <4 x i32> %[[NOT]], ptr %[[NOT_RES]], align 16
diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp
index 530018108c6d9..3f1cf830982e1 100644
--- a/clang/test/CIR/CodeGen/vector.cpp
+++ b/clang/test/CIR/CodeGen/vector.cpp
@@ -201,3 +201,59 @@ void foo4() {
// OGCG: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP1]], i32 %[[TMP2]]
// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
+
+void foo8() {
+ vi4 a = { 1, 2, 3, 4 };
+ vi4 plus_res = +a;
+ vi4 minus_res = -a;
+ vi4 not_res = ~a;
+}
+
+// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[PLUS_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["plus_res", init]
+// CIR: %[[MINUS_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["minus_res", init]
+// CIR: %[[NOT_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["not_res", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP1:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[PLUS:.*]] = cir.unary(plus, %[[TMP1]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+// CIR: cir.store %[[PLUS]], %[[PLUS_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP2:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[MINUS:.*]] = cir.unary(minus, %[[TMP2]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+// CIR: cir.store %[[MINUS]], %[[MINUS_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP3:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[NOT:.*]] = cir.unary(not, %[[TMP3]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+// CIR: cir.store %[[NOT]], %[[NOT_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+
+// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[PLUS_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[MINUS_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[NOT_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
+// LLVM: %[[TMP1:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// LLVM: store <4 x i32> %[[TMP1]], ptr %[[PLUS_RES]], align 16
+// LLVM: %[[TMP2:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// LLVM: %[[SUB:.*]] = sub <4 x i32> zeroinitializer, %[[TMP2]]
+// LLVM: store <4 x i32> %[[SUB]], ptr %[[MINUS_RES]], align 16
+// LLVM: %[[TMP3:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// LLVM: %[[NOT:.*]] = xor <4 x i32> %[[TMP3]], splat (i32 -1)
+// LLVM: store <4 x i32> %[[NOT]], ptr %[[NOT_RES]], align 16
+
+// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[PLUS_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[MINUS_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[NOT_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
+// OGCG: %[[TMP1:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// OGCG: store <4 x i32> %[[TMP1]], ptr %[[PLUS_RES]], align 16
+// OGCG: %[[TMP2:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// OGCG: %[[SUB:.*]] = sub <4 x i32> zeroinitializer, %[[TMP2]]
+// OGCG: store <4 x i32> %[[SUB]], ptr %[[MINUS_RES]], align 16
+// OGCG: %[[TMP3:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// OGCG: %[[NOT:.*]] = xor <4 x i32> %[[TMP3]], splat (i32 -1)
+// OGCG: store <4 x i32> %[[NOT]], ptr %[[NOT_RES]], align 16
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM after minor nit
mlir::Type elementTypeIfVector(mlir::Type type) { | ||
assert(!cir::MissingFeatures::vectorType()); | ||
if (const auto vecType = mlir::dyn_cast<cir::VectorType>(type)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we probably don't need the curly braces here!
9fd2f92
to
b462141
Compare
This change adds support for unary ops for VectorType
Issue #136487