Skip to content

[CIR] Upstream shift operators for VectorType #139465

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 13, 2025

Conversation

AmrDeveloper
Copy link
Member

This change adds support for shift ops for VectorType

Issue #136487

@llvmbot llvmbot added clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project labels May 11, 2025
@llvmbot
Copy link
Member

llvmbot commented May 11, 2025

@llvm/pr-subscribers-clang

@llvm/pr-subscribers-clangir

Author: Amr Hesham (AmrDeveloper)

Changes

This change adds support for shift ops for VectorType

Issue #136487


Full diff: https://github.com/llvm/llvm-project/pull/139465.diff

6 Files Affected:

  • (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+8-7)
  • (modified) clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td (+19)
  • (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+2-3)
  • (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+9-6)
  • (modified) clang/test/CIR/CodeGen/vector-ext.cpp (+65)
  • (modified) clang/test/CIR/CodeGen/vector.cpp (+65)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 7aff5edb88167..b0e593b011109 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1401,18 +1401,19 @@ def ShiftOp : CIR_Op<"shift", [Pure]> {
     The `cir.shift` operation performs a bitwise shift, either to the left or to
     the right, based on the first operand. The second operand specifies the
     value to be shifted, and the third operand determines the number of
-    positions by which the shift is applied. Both the second and third operands
-    are required to be integers.
+    positions by which the shift is applied, they must be either all vector of
+    integer type, or all integer type. If they are vectors, each vector element of
+    the shift target is shifted by the corresponding shift amount in
+    the shift amount vector.
 
     ```mlir
-    %7 = cir.shift(left, %1 : !u64i, %4 : !s32i) -> !u64i
+    %res = cir.shift(left, %lhs : !u64i, %amount : !s32i) -> !u64i
+    %new_vec = cir.shift(left, %lhs : !cir.vector<2 x !s32i>, %rhs : !cir.vector<2 x !s32i>) -> !cir.vector<2 x !s32i>
     ```
   }];
 
-  // TODO(cir): Support vectors. CIR_IntType -> CIR_AnyIntOrVecOfInt. Also
-  // update the description above.
-  let results = (outs CIR_IntType:$result);
-  let arguments = (ins CIR_IntType:$value, CIR_IntType:$amount,
+  let results = (outs CIR_AnyIntOrVecOfInt:$result);
+  let arguments = (ins CIR_AnyIntOrVecOfInt:$value, CIR_AnyIntOrVecOfInt:$amount,
                        UnitAttr:$isShiftleft);
 
   let assemblyFormat = [{
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
index 00f67e2a03a25..902d6535ff717 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
@@ -174,4 +174,23 @@ def CIR_PtrToVoidPtrType
         "$_builder.getType<" # cppType # ">("
         "cir::VoidType::get($_builder.getContext())))">;
 
+//===----------------------------------------------------------------------===//
+// Vector Type predicates
+//===----------------------------------------------------------------------===//
+
+// Vector of integral type
+def IntegerVector : Type<
+    And<[
+      CPred<"::mlir::isa<::cir::VectorType>($_self)">,
+      CPred<"::mlir::isa<::cir::IntType>("
+            "::mlir::cast<::cir::VectorType>($_self).getElementType())">,
+      CPred<"::mlir::cast<::cir::IntType>("
+            "::mlir::cast<::cir::VectorType>($_self).getElementType())"
+            ".isFundamental()">
+    ]>, "!cir.vector of !cir.int"> {
+}
+
+// Any Integer or Vector of Integer Constraints
+def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_AnyIntType, IntegerVector]>;
+
 #endif // CLANG_CIR_DIALECT_IR_CIRTYPECONSTRAINTS_TD
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index b131edaf403ed..2f7e3496d55a5 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1297,9 +1297,8 @@ OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
 LogicalResult cir::ShiftOp::verify() {
   mlir::Operation *op = getOperation();
   mlir::Type resType = getResult().getType();
-  assert(!cir::MissingFeatures::vectorType());
-  bool isOp0Vec = false;
-  bool isOp1Vec = false;
+  const bool isOp0Vec = mlir::isa<cir::VectorType>(op->getOperand(0).getType());
+  const bool isOp1Vec = mlir::isa<cir::VectorType>(op->getOperand(1).getType());
   if (isOp0Vec != isOp1Vec)
     return emitOpError() << "input types cannot be one vector and one scalar";
   if (isOp1Vec && op->getOperand(1).getType() != resType) {
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 5986655ababe9..1951d2e3a3b79 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1376,16 +1376,17 @@ mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
   auto cirValTy = mlir::dyn_cast<cir::IntType>(op.getValue().getType());
 
   // Operands could also be vector type
-  assert(!cir::MissingFeatures::vectorType());
+  auto cirAmtVTy = mlir::dyn_cast<cir::VectorType>(op.getAmount().getType());
+  auto cirValVTy = mlir::dyn_cast<cir::VectorType>(op.getValue().getType());
   mlir::Type llvmTy = getTypeConverter()->convertType(op.getType());
   mlir::Value amt = adaptor.getAmount();
   mlir::Value val = adaptor.getValue();
 
-  // TODO(cir): Assert for vector types
-  assert((cirValTy && cirAmtTy) &&
+  assert(((cirValTy && cirAmtTy) || (cirAmtVTy && cirValVTy)) &&
          "shift input type must be integer or vector type, otherwise NYI");
 
-  assert((cirValTy == op.getType()) && "inconsistent operands' types NYI");
+  assert((cirValTy == op.getType() || cirValVTy == op.getType()) &&
+         "inconsistent operands' types NYI");
 
   // Ensure shift amount is the same type as the value. Some undefined
   // behavior might occur in the casts below as per [C99 6.5.7.3].
@@ -1399,8 +1400,10 @@ mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
   if (op.getIsShiftleft()) {
     rewriter.replaceOpWithNewOp<mlir::LLVM::ShlOp>(op, llvmTy, val, amt);
   } else {
-    assert(!cir::MissingFeatures::vectorType());
-    bool isUnsigned = !cirValTy.isSigned();
+    const bool isUnsigned =
+        cirValTy
+            ? !cirValTy.isSigned()
+            : !mlir::cast<cir::IntType>(cirValVTy.getElementType()).isSigned();
     if (isUnsigned)
       rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt);
     else
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index 0756497bf6b96..4b3a271587f67 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -213,3 +213,68 @@ void foo4() {
 // OGCG: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
 // OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP1]], i32 %[[TMP2]]
 // OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
+
+void foo9() {
+  vi4 a = {1, 2, 3, 4};
+  vi4 b = {5, 6, 7, 8};
+
+  vi4 shl = a << b;
+  vi4 shr = a >> b;
+}
+
+// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["b", init]
+// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shr", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[CONST_5:.*]] = cir.const #cir.int<5> : !s32i
+// CIR: %[[CONST_6:.*]] = cir.const #cir.int<6> : !s32i
+// CIR: %[[CONST_7:.*]] = cir.const #cir.int<7> : !s32i
+// CIR: %[[CONST_8:.*]] = cir.const #cir.int<8> : !s32i
+// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_5]], %[[CONST_6]], %[[CONST_7]], %[[CONST_8]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+
+// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// LLVM: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// OGCG: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp
index 530018108c6d9..cf23a97ac61f2 100644
--- a/clang/test/CIR/CodeGen/vector.cpp
+++ b/clang/test/CIR/CodeGen/vector.cpp
@@ -201,3 +201,68 @@ void foo4() {
 // OGCG: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
 // OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP1]], i32 %[[TMP2]]
 // OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
+
+void foo9() {
+  vi4 a = {1, 2, 3, 4};
+  vi4 b = {5, 6, 7, 8};
+
+  vi4 shl = a << b;
+  vi4 shr = a >> b;
+}
+
+// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["b", init]
+// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shr", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[CONST_5:.*]] = cir.const #cir.int<5> : !s32i
+// CIR: %[[CONST_6:.*]] = cir.const #cir.int<6> : !s32i
+// CIR: %[[CONST_7:.*]] = cir.const #cir.int<7> : !s32i
+// CIR: %[[CONST_8:.*]] = cir.const #cir.int<8> : !s32i
+// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_5]], %[[CONST_6]], %[[CONST_7]], %[[CONST_8]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+
+// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// LLVM: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// OGCG: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16

Copy link
Contributor

@andykaylor andykaylor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good. I just have a few minor suggestions.

@@ -1401,18 +1401,19 @@ def ShiftOp : CIR_Op<"shift", [Pure]> {
The `cir.shift` operation performs a bitwise shift, either to the left or to
the right, based on the first operand. The second operand specifies the
value to be shifted, and the third operand determines the number of
positions by which the shift is applied. Both the second and third operands
are required to be integers.
positions by which the shift is applied, they must be either all vector of
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
positions by which the shift is applied, they must be either all vector of
positions by which the shift is applied. They must be either all vector of


```mlir
%7 = cir.shift(left, %1 : !u64i, %4 : !s32i) -> !u64i
%res = cir.shift(left, %lhs : !u64i, %amount : !s32i) -> !u64i
%new_vec = cir.shift(left, %lhs : !cir.vector<2 x !s32i>, %rhs : !cir.vector<2 x !s32i>) -> !cir.vector<2 x !s32i>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you reformat this to fit in 80 columns?

const bool isUnsigned =
cirValTy
? !cirValTy.isSigned()
: !mlir::cast<cir::IntType>(cirValVTy.getElementType()).isSigned();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense for the vector type to support isSigned()? The instead of having separate cirValTy and cirValVTy (one of which would always be null), we could do something like this above:

if (!cirValTy) {
  cirValTy = mlir::cast<cir::VectorType>(op.getValue().getType());
  assert(mlir::isa<cir::VectorType>(op.getAmount().getType());
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mmmmm, not sure if we should support isSinged() for Vector because it can be a vector of bool, floats ..etc and also we will need to cast again to get width from Vec here

if (cirAmtTy)
    amt = getLLVMIntCast(rewriter, amt, mlir::cast<mlir::IntegerType>(llvmTy),
                         true, cirAmtTy.getWidth(), cirValTy.getWidth());

I will try first to change the code and add bool isAmtVecTy and see if it will become better :D

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used your suggestion and assigned isUnsigned and getLLVMIntCast in the same branch

@@ -213,3 +213,68 @@ void foo4() {
// OGCG: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP1]], i32 %[[TMP2]]
// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4

void foo9() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test with a scalar shift value?

vi4 splat_shl = a << 3;

Also, a test with unsigned values and shift-right is needed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both of them still require upstreaming VecSplatOp, I will upstream it next in another PR and add those test cases there, if that's okey

@AmrDeveloper AmrDeveloper force-pushed the cir_upstream_vec_sh_op branch from 7a480c7 to af65eca Compare May 12, 2025 20:11
Copy link
Contributor

@andykaylor andykaylor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good!

Copy link
Member

@bcardosolopes bcardosolopes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing to add on top of Andy's review. LGTM

@AmrDeveloper AmrDeveloper merged commit 377cb7f into llvm:main May 13, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants