Skip to content

[CIR] Add support for indirect calls #139748

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

Lancern
Copy link
Member

@Lancern Lancern commented May 13, 2025

This PR adds support for indirect calls to the cir.call operation.

@Lancern Lancern requested a review from andykaylor May 13, 2025 15:03
@llvmbot llvmbot added clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project labels May 13, 2025
@llvmbot
Copy link
Member

llvmbot commented May 13, 2025

@llvm/pr-subscribers-clangir

@llvm/pr-subscribers-clang

Author: Sirui Mu (Lancern)

Changes

This PR adds support for indirect calls to the cir.call operation.


Patch is 21.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/139748.diff

12 Files Affected:

  • (modified) clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h (+8)
  • (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+31-15)
  • (modified) clang/include/clang/CIR/MissingFeatures.h (-1)
  • (modified) clang/lib/CIR/CodeGen/CIRGenCall.cpp (+49-6)
  • (modified) clang/lib/CIR/CodeGen/CIRGenCall.h (+10-1)
  • (modified) clang/lib/CIR/CodeGen/CIRGenExpr.cpp (+22-2)
  • (modified) clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h (+13)
  • (modified) clang/lib/CIR/CodeGen/CIRGenTypes.h (+3)
  • (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+42-7)
  • (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+17-2)
  • (modified) clang/test/CIR/CodeGen/call.cpp (+14)
  • (modified) clang/test/CIR/IR/call.cir (+14)
diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
index a63bf4f8858d0..b680e4162a5ce 100644
--- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
+++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
@@ -225,6 +225,14 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
                         callee.getFunctionType().getReturnType(), operands);
   }
 
+  cir::CallOp createIndirectCallOp(mlir::Location loc,
+                                   mlir::Value indirectTarget,
+                                   cir::FuncType funcType,
+                                   mlir::ValueRange operands) {
+    return create<cir::CallOp>(loc, indirectTarget, funcType.getReturnType(),
+                               operands);
+  }
+
   //===--------------------------------------------------------------------===//
   // Cast/Conversion Operators
   //===--------------------------------------------------------------------===//
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 7aff5edb88167..f08818d0e82d1 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1796,13 +1796,8 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
                       DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
   let extraClassDeclaration = [{
     /// Get the argument operands to the called function.
-    mlir::OperandRange getArgOperands() {
-      return getArgs();
-    }
-
-    mlir::MutableOperandRange getArgOperandsMutable() {
-      return getArgsMutable();
-    }
+    mlir::OperandRange getArgOperands();
+    mlir::MutableOperandRange getArgOperandsMutable();
 
     /// Return the callee of this operation
     mlir::CallInterfaceCallable getCallableForCallee() {
@@ -1824,6 +1819,9 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
     ::mlir::Attribute removeArgAttrsAttr() { return {}; }
     ::mlir::Attribute removeResAttrsAttr() { return {}; }
 
+    bool isIndirect() { return !getCallee(); }
+    mlir::Value getIndirectCall();
+
     void setArg(unsigned index, mlir::Value value) {
       setOperand(index, value);
     }
@@ -1837,16 +1835,24 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
   // the upstreaming process moves on. The verifiers is also missing for now,
   // will add in the future.
 
-  dag commonArgs = (ins FlatSymbolRefAttr:$callee,
-                        Variadic<CIR_AnyType>:$args);
+  dag commonArgs = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
+      Variadic<CIR_AnyType>:$args);
 }
 
 def CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
   let summary = "call a function";
   let description = [{
-    The `cir.call` operation represents a direct call to a function that is
-    within the same symbol scope as the call. The callee is encoded as a symbol
-    reference attribute named `callee`.
+    The `cir.call` operation represents a function call. It could represent
+    either a direct call or an indirect call.
+
+    If the operation represents a direct call, the callee should be defined
+    within the same symbol scope as the call. The `callee` attribute contains a
+    symbo reference to the callee function. All operands of this operation are
+    arguments to the callee function.
+
+    If the operation represents an indirect call, the `callee` attribute is
+    empty. The first operand of this operation must be a pointer to the callee
+    function. All the rest operands are arguments to the callee function.
 
     Example:
 
@@ -1859,13 +1865,23 @@ def CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
   let arguments = commonArgs;
 
   let builders = [OpBuilder<(ins "mlir::SymbolRefAttr":$callee,
-                                 "mlir::Type":$resType,
-                                 "mlir::ValueRange":$operands), [{
+                                "mlir::Type":$resType,
+                                "mlir::ValueRange":$operands),
+                            [{
       $_state.addOperands(operands);
       $_state.addAttribute("callee", callee);
       if (resType && !isa<VoidType>(resType))
         $_state.addTypes(resType);
-    }]>];
+    }]>,
+                  OpBuilder<(ins "mlir::Value":$callee, "mlir::Type":$resType,
+                                "mlir::ValueRange":$operands),
+                            [{
+      $_state.addOperands(callee);
+      $_state.addOperands(operands);
+      if (resType && !isa<VoidType>(resType))
+        $_state.addTypes(resType);
+    }]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h
index e148a0a636fa5..faf16bdb735bd 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -93,7 +93,6 @@ struct MissingFeatures {
   static bool opCallChainCall() { return false; }
   static bool opCallNoPrototypeFunc() { return false; }
   static bool opCallMustTail() { return false; }
-  static bool opCallIndirect() { return false; }
   static bool opCallVirtual() { return false; }
   static bool opCallInAlloca() { return false; }
   static bool opCallAttrs() { return false; }
diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.cpp b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
index 70d45dc383fd1..3f0c54a939be4 100644
--- a/clang/lib/CIR/CodeGen/CIRGenCall.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
@@ -39,6 +39,27 @@ CIRGenFunctionInfo::create(CanQualType resultType,
   return fi;
 }
 
+cir::FuncType CIRGenTypes::getFunctionType(const CIRGenFunctionInfo &info) {
+  [[maybe_unused]] bool inserted = functionsBeingProcessed.insert(&info).second;
+  assert(inserted && "Recursively being processed?");
+
+  mlir::Type resultType = convertType(info.getReturnType());
+  SmallVector<mlir::Type, 8> argTypes;
+  argTypes.reserve(info.getNumRequiredArgs());
+
+  // Add in all of the required arguments.
+  for (const CIRGenFunctionInfoArgInfo &argInfo : info.requiredArguments())
+    argTypes.push_back(convertType(argInfo.type));
+
+  [[maybe_unused]] bool erased = functionsBeingProcessed.erase(&info);
+  assert(erased && "Not in set?");
+
+  assert(!cir::MissingFeatures::opCallVariadic());
+  return cir::FuncType::get(argTypes,
+                            (resultType ? resultType : builder.getVoidTy()),
+                            /*isVarArg=*/false);
+}
+
 CIRGenCallee CIRGenCallee::prepareConcreteCallee(CIRGenFunction &cgf) const {
   assert(!cir::MissingFeatures::opCallVirtual());
   return *this;
@@ -75,6 +96,7 @@ CIRGenTypes::arrangeFreeFunctionCall(const CallArgList &args,
 
 static cir::CIRCallOpInterface
 emitCallLikeOp(CIRGenFunction &cgf, mlir::Location callLoc,
+               cir::FuncType indirectFuncTy, mlir::Value indirectFuncVal,
                cir::FuncOp directFuncOp,
                const SmallVectorImpl<mlir::Value> &cirCallArgs) {
   CIRGenBuilderTy &builder = cgf.getBuilder();
@@ -83,7 +105,13 @@ emitCallLikeOp(CIRGenFunction &cgf, mlir::Location callLoc,
   assert(!cir::MissingFeatures::invokeOp());
 
   assert(builder.getInsertionBlock() && "expected valid basic block");
-  assert(!cir::MissingFeatures::opCallIndirect());
+
+  if (indirectFuncTy) {
+    // TODO(cir): Set calling convention for indirect calls.
+    assert(!cir::MissingFeatures::opCallCallConv());
+    return builder.createIndirectCallOp(callLoc, indirectFuncVal,
+                                        indirectFuncTy, cirCallArgs);
+  }
 
   return builder.createCallOp(callLoc, directFuncOp, cirCallArgs);
 }
@@ -95,6 +123,7 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo,
                                 cir::CIRCallOpInterface *callOp,
                                 mlir::Location loc) {
   QualType retTy = funcInfo.getReturnType();
+  cir::FuncType cirFuncTy = getTypes().getFunctionType(funcInfo);
 
   SmallVector<mlir::Value, 16> cirCallArgs(args.size());
 
@@ -145,12 +174,26 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo,
 
   assert(!cir::MissingFeatures::invokeOp());
 
-  auto directFuncOp = dyn_cast<cir::FuncOp>(calleePtr);
-  assert(!cir::MissingFeatures::opCallIndirect());
+  cir::FuncType indirectFuncTy;
+  mlir::Value indirectFuncVal;
+  cir::FuncOp directFuncOp;
+  if (auto fnOp = dyn_cast<cir::FuncOp>(calleePtr))
+    directFuncOp = fnOp;
+  else {
+    [[maybe_unused]] auto resultTypes = calleePtr->getResultTypes();
+    [[maybe_unused]] auto funcPtrTy =
+        mlir::dyn_cast<cir::PointerType>(resultTypes.front());
+    assert(funcPtrTy && mlir::isa<cir::FuncType>(funcPtrTy.getPointee()) &&
+           "expected pointer to function");
+
+    indirectFuncTy = cirFuncTy;
+    indirectFuncVal = calleePtr->getResult(0);
+  }
+
   assert(!cir::MissingFeatures::opCallAttrs());
 
-  cir::CIRCallOpInterface theCall =
-      emitCallLikeOp(*this, loc, directFuncOp, cirCallArgs);
+  cir::CIRCallOpInterface theCall = emitCallLikeOp(
+      *this, loc, indirectFuncTy, indirectFuncVal, directFuncOp, cirCallArgs);
 
   if (callOp)
     *callOp = theCall;
@@ -250,7 +293,7 @@ void CIRGenFunction::emitCallArgs(
 
   auto maybeEmitImplicitObjectSize = [&](size_t i, const Expr *arg,
                                          RValue emittedArg) {
-    if (callee.hasFunctionDecl() || i >= callee.getNumParams())
+    if (!callee.hasFunctionDecl() || i >= callee.getNumParams())
       return;
     auto *ps = callee.getParamDecl(i)->getAttr<PassObjectSizeAttr>();
     if (!ps)
diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.h b/clang/lib/CIR/CodeGen/CIRGenCall.h
index 2ba1676eb6b97..e4fd9c1c506d8 100644
--- a/clang/lib/CIR/CodeGen/CIRGenCall.h
+++ b/clang/lib/CIR/CodeGen/CIRGenCall.h
@@ -25,11 +25,20 @@ class CIRGenFunction;
 
 /// Abstract information about a function or function prototype.
 class CIRGenCalleeInfo {
+  const clang::FunctionProtoType *calleeProtoTy;
   clang::GlobalDecl calleeDecl;
 
 public:
-  explicit CIRGenCalleeInfo() : calleeDecl() {}
+  explicit CIRGenCalleeInfo() : calleeProtoTy(nullptr), calleeDecl() {}
+  CIRGenCalleeInfo(const clang::FunctionProtoType *calleeProtoTy,
+                   clang::GlobalDecl calleeDecl)
+      : calleeProtoTy(calleeProtoTy), calleeDecl(calleeDecl) {}
   CIRGenCalleeInfo(clang::GlobalDecl calleeDecl) : calleeDecl(calleeDecl) {}
+
+  const clang::FunctionProtoType *getCalleeFunctionProtoType() const {
+    return calleeProtoTy;
+  }
+  clang::GlobalDecl getCalleeDecl() const { return calleeDecl; }
 };
 
 class CIRGenCallee {
diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
index 711a65215b043..1fffec78a658f 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
@@ -915,8 +915,28 @@ CIRGenCallee CIRGenFunction::emitCallee(const clang::Expr *e) {
       return emitDirectCallee(cgm, funcDecl);
   }
 
-  cgm.errorNYI(e->getSourceRange(), "Unsupported callee kind");
-  return {};
+  assert(!cir::MissingFeatures::opCallPseudoDtor());
+
+  // Otherwise, we have an indirect reference.
+  mlir::Value calleePtr;
+  QualType functionType;
+  if (const auto *ptrType = e->getType()->getAs<clang::PointerType>()) {
+    calleePtr = emitScalarExpr(e);
+    functionType = ptrType->getPointeeType();
+  } else {
+    functionType = e->getType();
+    calleePtr = emitLValue(e).getPointer();
+  }
+  assert(functionType->isFunctionType());
+
+  GlobalDecl gd;
+  if (const auto *vd =
+          dyn_cast_or_null<VarDecl>(e->getReferencedDeclOfCallee()))
+    gd = GlobalDecl(vd);
+
+  CIRGenCalleeInfo calleeInfo(functionType->getAs<FunctionProtoType>(), gd);
+  CIRGenCallee callee(calleeInfo, calleePtr.getDefiningOp());
+  return callee;
 }
 
 RValue CIRGenFunction::emitCallExpr(const clang::CallExpr *e,
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h b/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h
index 645e6b23c4f76..a3e7025d4137c 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h
+++ b/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h
@@ -16,6 +16,7 @@
 #define LLVM_CLANG_CIR_CIRGENFUNCTIONINFO_H
 
 #include "clang/AST/CanonicalType.h"
+#include "clang/CIR/MissingFeatures.h"
 #include "llvm/ADT/FoldingSet.h"
 #include "llvm/Support/TrailingObjects.h"
 
@@ -67,6 +68,13 @@ class CIRGenFunctionInfo final
     return llvm::ArrayRef<ArgInfo>(arg_begin(), numArgs);
   }
 
+  llvm::MutableArrayRef<ArgInfo> requiredArguments() {
+    return llvm::MutableArrayRef<ArgInfo>(arg_begin(), getNumRequiredArgs());
+  }
+  llvm::ArrayRef<ArgInfo> requiredArguments() const {
+    return llvm::ArrayRef<ArgInfo>(arg_begin(), getNumRequiredArgs());
+  }
+
   const_arg_iterator arg_begin() const { return getArgsBuffer() + 1; }
   const_arg_iterator arg_end() const { return getArgsBuffer() + 1 + numArgs; }
   arg_iterator arg_begin() { return getArgsBuffer() + 1; }
@@ -75,6 +83,11 @@ class CIRGenFunctionInfo final
   unsigned arg_size() const { return numArgs; }
 
   CanQualType getReturnType() const { return getArgsBuffer()[0].type; }
+
+  unsigned getNumRequiredArgs() const {
+    assert(!cir::MissingFeatures::opCallVariadic());
+    return arg_size();
+  }
 };
 
 } // namespace clang::CIRGen
diff --git a/clang/lib/CIR/CodeGen/CIRGenTypes.h b/clang/lib/CIR/CodeGen/CIRGenTypes.h
index cf94375d17e12..4d392400e3924 100644
--- a/clang/lib/CIR/CodeGen/CIRGenTypes.h
+++ b/clang/lib/CIR/CodeGen/CIRGenTypes.h
@@ -117,6 +117,9 @@ class CIRGenTypes {
   // TODO: convert this comment to account for MLIR's equivalence
   mlir::Type convertTypeForMem(clang::QualType, bool forBitField = false);
 
+  /// Get the CIR function type for \arg Info.
+  cir::FuncType getFunctionType(const CIRGenFunctionInfo &info);
+
   /// Return whether a type can be zero-initialized (in the C++ sense) with an
   /// LLVM zeroinitializer.
   bool isZeroInitializable(clang::QualType ty);
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index b131edaf403ed..06f6391ecd537 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -464,15 +464,35 @@ OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
 // CallOp
 //===----------------------------------------------------------------------===//
 
+mlir::OperandRange cir::CallOp::getArgOperands() {
+  if (isIndirect())
+    return getArgs().drop_front(1);
+  return getArgs();
+}
+
+mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() {
+  mlir::MutableOperandRange args = getArgsMutable();
+  if (isIndirect())
+    return args.slice(1, args.size() - 1);
+  return args;
+}
+
+mlir::Value cir::CallOp::getIndirectCall() {
+  assert(isIndirect());
+  return getOperand(0);
+}
+
 /// Return the operand at index 'i'.
 Value cir::CallOp::getArgOperand(unsigned i) {
-  assert(!cir::MissingFeatures::opCallIndirect());
+  if (isIndirect())
+    ++i;
   return getOperand(i);
 }
 
 /// Return the number of operands.
 unsigned cir::CallOp::getNumArgOperands() {
-  assert(!cir::MissingFeatures::opCallIndirect());
+  if (isIndirect())
+    return this->getOperation()->getNumOperands() - 1;
   return this->getOperation()->getNumOperands();
 }
 
@@ -483,9 +503,15 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
   mlir::FlatSymbolRefAttr calleeAttr;
   llvm::ArrayRef<mlir::Type> allResultTypes;
 
+  // If we cannot parse a string callee, it means this is an indirect call.
   if (!parser.parseOptionalAttribute(calleeAttr, "callee", result.attributes)
-           .has_value())
-    return mlir::failure();
+           .has_value()) {
+    OpAsmParser::UnresolvedOperand indirectVal;
+    // Do not resolve right now, since we need to figure out the type
+    if (parser.parseOperand(indirectVal).failed())
+      return failure();
+    ops.push_back(indirectVal);
+  }
 
   if (parser.parseLParen())
     return mlir::failure();
@@ -517,13 +543,21 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
 
 static void printCallCommon(mlir::Operation *op,
                             mlir::FlatSymbolRefAttr calleeSym,
+                            mlir::Value indirectCallee,
                             mlir::OpAsmPrinter &printer) {
   printer << ' ';
 
   auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
   auto ops = callLikeOp.getArgOperands();
 
-  printer.printAttributeWithoutType(calleeSym);
+  if (calleeSym) {
+    // Direct calls
+    printer.printAttributeWithoutType(calleeSym);
+  } else {
+    // Indirect calls
+    assert(indirectCallee);
+    printer << indirectCallee;
+  }
   printer << "(" << ops << ")";
 
   printer.printOptionalAttrDict(op->getAttrs(), {"callee"});
@@ -539,7 +573,8 @@ mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,
 }
 
 void cir::CallOp::print(mlir::OpAsmPrinter &p) {
-  printCallCommon(*this, getCalleeAttr(), p);
+  mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
+  printCallCommon(*this, getCalleeAttr(), indirectCallee, p);
 }
 
 static LogicalResult
@@ -547,7 +582,7 @@ verifyCallCommInSymbolUses(mlir::Operation *op,
                            SymbolTableCollection &symbolTable) {
   auto fnAttr = op->getAttrOfType<FlatSymbolRefAttr>("callee");
   if (!fnAttr)
-    return mlir::failure();
+    return mlir::success();
 
   auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
   if (!fn)
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 5986655ababe9..dbc4904736a40 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -674,8 +674,15 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
     llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>(
         converter->convertType(fn.getFunctionType()));
   } else { // indirect call
-    assert(!cir::MissingFeatures::opCallIndirect());
-    return op->emitError("Indirect calls are NYI");
+    assert(!op->getOperands().empty() &&
+           "operands list must no be empty for the indirect call");
+    auto calleeTy = op->getOperands().front().getType();
+    auto calleePtrTy = cast<cir::PointerType>(calleeTy);
+    auto calleeFuncTy = cast<cir::FuncType>(calleePtrTy.getPointee());
+    calleeFuncTy.dump();
+    converter->convertType(calleeFuncTy).dump();
+    llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>(
+        converter->convertType(calleeFuncTy));
   }
 
   assert(!cir::MissingFeatures::opCallLandingPad());
@@ -1500,6 +1507,14 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
   converter.addConversion([&](cir::BF16Type type) -> mlir::Type {
     return mlir::BFloat16Type::get(type.getContext());
   });
+  converter.addConversion([&](cir::FuncType type) -> mlir::Type {
+    auto result = converter.convertType(type.getReturnType());
+    llvm::SmallVector<mlir::Type> arguments;
+    if (converter.convertTypes(type.getInputs(), arguments).failed())
+      llvm_unreachable("Failed to convert function type parameters");
+    auto varArg = type.isVarArg();
+    return mlir::LLVM::LLVMFunctionType::get(result, arguments, varArg);
+  });
   converter.addConversion([&](cir::RecordType type) -> mlir::Type {
     // Convert struct members.
     llvm::SmallVector<mlir::Type> llvmMembers;
diff --git a/clang/test/CIR/CodeGen/call.cpp b/clang/test/CIR/CodeGen/call.cpp
index 3b1ab8b5fc498..8b8f1296b5108 100644
--- a/clang/test/CIR/CodeGen/call.cpp
+++ b/clang/test/CIR/CodeGen/call.cpp
@@ -42,3 +42,17 @@ int f6() {
 
 // LLVM-LABEL: define i32 @_Z2f6v() {
 // LLVM:         %{{.+}} = call i32 @_Z2f5iPib(i32 2, ptr %{{.+}}, i1 false)
+
+int f7(int (*ptr)(int, int)) {
+  return ptr(1, 2);
+}
+
+// CIR-LABEL: cir.func @_Z2f7PFiiiE
+// CIR:         %[[#ptr:]] = cir.load %{{.+}} : !cir.ptr<!cir.ptr<!cir.func<(!s32i, !s32i) -> !s32i>>>, !cir.ptr<!cir.func<(!s32i, !s32i) -> !s32i>>
+// CIR-NEXT:    %[[#a:]] = cir.const #cir.int<1> : !s32i
+// CIR-NEXT:    %[[#b:]] = cir.const #cir.int<2> : !s32i
+// CIR-NEXT:    %{{.+}} = cir.call %[[#ptr]](%[[#a]], %[[#b]]) : (!cir.ptr<!cir.func<(!s32i, !s32i) -> !s32i>>, !s32i, !s32i) -> !s32i
+
+// LLVM-LABEL: define i32 @_Z2f7PFiiiE
+// LLVM:         %[[#ptr:]] = load ptr, ptr %{{.+}}
+// LLVM...
[truncated]

Copy link
Member

@bcardosolopes bcardosolopes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly good, some nits!


If the operation represents a direct call, the callee should be defined
within the same symbol scope as the call. The `callee` attribute contains a
symbo reference to the callee function. All operands of this operation are
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

symbo -> symbol

auto result = converter.convertType(type.getReturnType());
llvm::SmallVector<mlir::Type> arguments;
if (converter.convertTypes(type.getInputs(), arguments).failed())
llvm_unreachable("Failed to convert function type parameters");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too bad we can't emit an error here. Perhaps instead of saying failed use something like "cir::FuncType conversion NYI"?

@@ -1500,6 +1507,14 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
converter.addConversion([&](cir::BF16Type type) -> mlir::Type {
return mlir::BFloat16Type::get(type.getContext());
});
converter.addConversion([&](cir::FuncType type) -> mlir::Type {
auto result = converter.convertType(type.getReturnType());
llvm::SmallVector<mlir::Type> arguments;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use the type.getInputs() length to reserve the capacity for arguments.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants