[mlir][LLVM] handle ArrayAttr for constant array of structs #139724

Open
wants to merge 1 commit into
base: main

jeanPerier
Contributor

While the LLVM IR dialect can represent an arbitrary LLVM constant array of structs via an insert chain, this is in practice very expensive in compilation time as soon as the array is bigger than a couple hundred elements, because generating and later folding such an insert chain is not cheap.

This is an issue for flang because it is very easy to write rather big array-of-struct constant initializers in Fortran and, unlike in C++, dynamic initialization of globals is not a feature of the language: initializers must be static.

For instance, here are the compile times I measured for the following program while varying N:

! test.F90
module m
  type t
    integer :: i = 42
    real :: x = 1.0
  end type
  type(t) :: some_global(N)
end module

/usr/bin/time flang -c -DN=1000 test.F90
0.08user 0.07system 0:00.11elapsed 140%CPU (0avgtext+0avgdata 83968maxresident)k
8inputs+40outputs (13major+4840minor)pagefaults 0swaps

/usr/bin/time  flang -c -DN=10000 test.F90 
1.40user 0.08system 0:01.44elapsed 102%CPU (0avgtext+0avgdata 89088maxresident)k
8inputs+184outputs (13major+8764minor)pagefaults 0swaps

/usr/bin/time  flang -c -DN=100000 test.F90 
137.79user 0.22system 2:18.00elapsed 100%CPU (0avgtext+0avgdata 145540maxresident)k
8inputs+1584outputs (10major+82461minor)pagefaults 0swap

In the last case, more than 99.99% of the time is spent folding the insert chain in ModuleTranslation.cpp.
With this patch (and updating flang to generate an ArrayAttr instead of an insert chain), the last case with 100000 elements takes 0.15s to compile (~1000x compilation speedup :)).

This is not a silver bullet: an insert chain is currently still needed in some cases, such as when the initial values contain symbol references, but these are not very common in my use case.
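
The timings above grow by roughly 100x when N grows by 10x (1.40s to 137.79s of user time), which is consistent with a cost that is quadratic in N. The sketch below is a toy cost model of that behavior, not the actual folding code in ModuleTranslation.cpp. It assumes that every folded insert copies the whole aggregate, and the names `Elem`, `foldInsertChain`, and `buildDirectly` are hypothetical, introduced only for illustration.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Toy stand-in for one element of `some_global` (type `t` in test.F90).
struct Elem {
  std::int32_t i;
  float x;
};

// Cost model of folding an insert chain (assumption: constants are immutable,
// so each of the n inserts is modeled as copying the whole n-element
// aggregate). `copies` accumulates the number of element copies performed.
std::vector<Elem> foldInsertChain(std::size_t n, Elem value,
                                  std::uint64_t &copies) {
  std::vector<Elem> current(n, Elem{0, 0.0f});
  for (std::size_t idx = 0; idx < n; ++idx) {
    std::vector<Elem> next(current); // copy all n elements
    copies += n;
    next[idx] = value;               // then apply the single insert
    current = std::move(next);
  }
  return current;
}

// Cost model of building the array in one pass from a list of element values,
// which is what a direct array-attribute translation amounts to.
std::vector<Elem> buildDirectly(std::size_t n, Elem value,
                                std::uint64_t &copies) {
  std::vector<Elem> result;
  result.reserve(n);
  for (std::size_t idx = 0; idx < n; ++idx) {
    result.push_back(value);
    ++copies;
  }
  return result;
}
```

In this model, n = 1000 costs 1,000,000 element copies for the chain and 1,000 for the direct build, and both produce the same array.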

@llvmbot
Member

llvmbot commented May 13, 2025

@llvm/pr-subscribers-mlir

Author: None (jeanPerier)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/139724.diff

5 Files Affected:

  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+44-9)
  • (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+27)
  • (modified) mlir/test/Dialect/LLVMIR/invalid.mlir (+22)
  • (modified) mlir/test/Target/LLVMIR/llvmir-invalid.mlir (-5)
  • (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+17)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index c757f3ceb90e3..1868995e3f5ed 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3221,15 +3221,50 @@ LogicalResult LLVM::ConstantOp::verify() {
     if (!isa<VectorType, LLVM::LLVMArrayType>(getType()))
       return emitOpError() << "expected vector or array type";
     // The number of elements of the attribute and the type must match.
-    int64_t attrNumElements;
-    if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue()))
-      attrNumElements = elementsAttr.getNumElements();
-    else
-      attrNumElements = cast<ArrayAttr>(getValue()).size();
-    if (getNumElements(getType()) != attrNumElements)
-      return emitOpError()
-             << "type and attribute have a different number of elements: "
-             << getNumElements(getType()) << " vs. " << attrNumElements;
+    if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) {
+      int64_t attrNumElements = elementsAttr.getNumElements();
+      if (getNumElements(getType()) != attrNumElements)
+        return emitOpError()
+               << "type and attribute have a different number of elements: "
+               << getNumElements(getType()) << " vs. " << attrNumElements;
+    } else {
+      // When the attribute is an ArrayAttr, check that its nesting matches the
+      // corresponding ArrayType or VectorType nesting.
+      Type dimType = getType();
+      Attribute dimVal = getValue();
+      int dim = 0;
+      while (true) {
+        int64_t dimSize =
+            llvm::TypeSwitch<Type, int64_t>(dimType)
+                .Case<VectorType, LLVMArrayType>([&dimType](auto t) -> int64_t {
+                  dimType = t.getElementType();
+                  return t.getNumElements();
+                })
+                .Default([](auto) -> int64_t { return -1; });
+        if (dimSize < 0)
+          break;
+        auto arrayAttr = dyn_cast<ArrayAttr>(dimVal);
+        if (!arrayAttr)
+          return emitOpError()
+                 << "array attribute nesting must match array type nesting";
+        if (dimSize != static_cast<int64_t>(arrayAttr.size()))
+          return emitOpError()
+                 << "array attribute size does not match array type size in "
+                    "dimension "
+                 << dim << ": " << arrayAttr.size() << " vs. " << dimSize;
+        if (arrayAttr.size() == 0)
+          break;
+        dimVal = arrayAttr.getValue()[0];
+        ++dim;
+      }
+      if (auto structType = dyn_cast<LLVMStructType>(dimType)) {
+        auto arrayAttr = dyn_cast<ArrayAttr>(dimVal);
+        if (!arrayAttr || arrayAttr.size() != structType.getBody().size())
+          return emitOpError()
+                 << "nested attribute must be an array attribute with the same "
+                    "number of elements as the struct type";
+      }
+    }
   } else {
     return emitOpError()
            << "only supports integer, float, string or elements attributes";
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 1168b9f339904..1d4509ccb044e 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -713,6 +713,33 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
         ArrayRef<char>{stringAttr.getValue().data(),
                        stringAttr.getValue().size()});
   }
+
+  // Handle arrays of structs that cannot be represented as DenseElementsAttr
+  // in MLIR.
+  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+    if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
+      llvm::Type *elementType = arrayTy->getElementType();
+      Attribute previousElementAttr;
+      llvm::Constant *elementCst = nullptr;
+      SmallVector<llvm::Constant *> constants;
+      constants.reserve(arrayTy->getNumElements());
+      for (auto elementAttr : arrayAttr) {
+        // Arrays with a single value or with repeating values are quite common.
+        // short-circuit the translation when the element value is the same as
+        // the previous one.
+        if (!previousElementAttr || previousElementAttr != elementAttr) {
+          previousElementAttr = elementAttr;
+          elementCst =
+              getLLVMConstant(elementType, elementAttr, loc, moduleTranslation);
+          if (!elementCst)
+            return nullptr;
+        }
+        constants.push_back(elementCst);
+      }
+      return llvm::ConstantArray::get(arrayTy, constants);
+    }
+  }
+
   emitError(loc, "unsupported constant value");
   return nullptr;
 }
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index f9ea066a63624..4c82e586b8a3c 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1850,3 +1850,25 @@ llvm.func @gep_inbounds_flag_usage(%ptr: !llvm.ptr, %idx: i64) {
   llvm.getelementptr inbounds_flag %ptr[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
   llvm.return
 }
+
+// -----
+
+llvm.mlir.global @x1() : !llvm.array<2x!llvm.struct<(i32, f32)>> {
+  // expected-error@+1{{'llvm.mlir.constant' op array attribute size does not match array type size in dimension 0: 1 vs. 2}}
+  %0 = llvm.mlir.constant([[42 : i32, 1.000000e+00 : f32]]) : !llvm.array<2x!llvm.struct<(i32, f32)>>
+  llvm.return %0 : !llvm.array<2x!llvm.struct<(i32, f32)>>
+}
+
+// -----
+llvm.mlir.global @x2() : !llvm.array<1x!llvm.array<1x!llvm.array<1x!llvm.struct<(i32)>>>> {
+  // expected-error@+1{{'llvm.mlir.constant' op array attribute nesting must match array type nesting}}
+  %0 = llvm.mlir.constant([[1 : i32]]) : !llvm.array<1x!llvm.array<1x!llvm.array<1x!llvm.struct<(i32)>>>>
+  llvm.return %0 : !llvm.array<1x!llvm.array<1x!llvm.array<1x!llvm.struct<(i32)>>>>
+}
+
+// -----
+llvm.mlir.global @x3() : !llvm.array<1x!llvm.struct<(i32, f32)>> {
+  // expected-error@+1{{'llvm.mlir.constant' op nested attribute must be an array attribute with the same number of elements as the struct type}}
+  %0 = llvm.mlir.constant([[1 : i32]]) : !llvm.array<1x!llvm.struct<(i32, f32)>>
+  llvm.return %0 : !llvm.array<1x!llvm.struct<(i32, f32)>>
+}
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index 90c0f5ac55cb1..24a7b42557278 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -79,11 +79,6 @@ llvm.func @incompatible_integer_type_for_float_attr() -> i32 {
 
 // -----
 
-// expected-error @below{{unsupported constant value}}
-llvm.mlir.global internal constant @test([2.5, 7.4]) : !llvm.array<2 x f64>
-
-// -----
-
 // expected-error @below{{LLVM attribute 'readonly' does not expect a value}}
 llvm.func @passthrough_unexpected_value() attributes {passthrough = [["readonly", "42"]]}
 
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 4ef68fa83a70d..242a151116fb3 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -3000,3 +3000,20 @@ llvm.func internal @i(%arg0: i32) attributes {dso_local} {
   llvm.call @testfn3(%arg0) : (i32 {llvm.alignstack = 8 : i64}) -> ()
   llvm.return
 }
+
+// -----
+
+// CHECK: @test_array_attr_1 = internal constant [2 x double] [double 2.500000e+00, double 7.400000e+00]
+llvm.mlir.global internal constant @test_array_attr_1([2.5, 7.4]) : !llvm.array<2 x f64>
+
+// CHECK: @test_array_attr_2 = global [2 x { i32, float }] [{ i32, float } { i32 42, float 1.000000e+00 }, { i32, float } { i32 42, float 1.000000e+00 }]
+llvm.mlir.global @test_array_attr_2() : !llvm.array<2x!llvm.struct<(i32, f32)>> {
+  %0 = llvm.mlir.constant([[42 : i32, 1.000000e+00 : f32],[42 : i32, 1.000000e+00 : f32]]) : !llvm.array<2x!llvm.struct<(i32, f32)>>
+  llvm.return %0 : !llvm.array<2x!llvm.struct<(i32, f32)>>
+}
+
+// CHECK: @test_array_attr_3 = global [2 x [3 x { i32, float }]{{.*}}[3 x { i32, float }] [{ i32, float } { i32 1, float 1.000000e+00 }, { i32, float } { i32 2, float 1.000000e+00 }, { i32, float } { i32 3, float 1.000000e+00 }], [3 x { i32, float }] [{ i32, float } { i32 4, float 1.000000e+00 }, { i32, float } { i32 5, float 1.000000e+00 }, { i32, float } { i32 6, float 1.000000e+00 }
+llvm.mlir.global @test_array_attr_3() : !llvm.array<2x!llvm.array<3x!llvm.struct<(i32, f32)>>> {
+  %0 = llvm.mlir.constant([[[1 : i32, 1.000000e+00 : f32],[2 : i32, 1.000000e+00 : f32],[3 : i32, 1.000000e+00 : f32]],[[4 : i32, 1.000000e+00 : f32],[5 : i32, 1.000000e+00 : f32],[6 : i32, 1.000000e+00 : f32]]]) : !llvm.array<2x!llvm.array<3x!llvm.struct<(i32, f32)>>>
+  llvm.return %0 : !llvm.array<2x!llvm.array<3x!llvm.struct<(i32, f32)>>>
+}

@llvmbot
Member

llvmbot commented May 13, 2025

@llvm/pr-subscribers-mlir-llvm

Author: None (jeanPerier)

// -----

llvm.mlir.global @x1() : !llvm.array<2x!llvm.struct<(i32, f32)>> {
// expected-error@+1{{'llvm.mlir.constant' op array attribute size does not match array type size in dimension 0: 1 vs. 2}}
Contributor

Suggested change
// expected-error@+1{{'llvm.mlir.constant' op array attribute size does not match array type size in dimension 0: 1 vs. 2}}
// expected-error@below{{'llvm.mlir.constant' op array attribute size does not match array type size in dimension 0: 1 vs. 2}}

Nit: Same for the other test cases.


// -----

llvm.mlir.global @x1() : !llvm.array<2x!llvm.struct<(i32, f32)>> {
Contributor

Nit: Maybe use a slightly more descriptive function name?


// CHECK: @test_array_attr_2 = global [2 x { i32, float }] [{ i32, float } { i32 42, float 1.000000e+00 }, { i32, float } { i32 42, float 1.000000e+00 }]
llvm.mlir.global @test_array_attr_2() : !llvm.array<2x!llvm.struct<(i32, f32)>> {
%0 = llvm.mlir.constant([[42 : i32, 1.000000e+00 : f32],[42 : i32, 1.000000e+00 : f32]]) : !llvm.array<2x!llvm.struct<(i32, f32)>>
Contributor

Suggested change
%0 = llvm.mlir.constant([[42 : i32, 1.000000e+00 : f32],[42 : i32, 1.000000e+00 : f32]]) : !llvm.array<2x!llvm.struct<(i32, f32)>>
%0 = llvm.mlir.constant([[42 : i32, 1.000000e+00 : f32], [42 : i32, 1.000000e+00 : f32]]) : !llvm.array<2 x !llvm.struct<(i32, f32)>>

Ultra nit: Using spaces like this should make the tests slightly more readable. Feel free to ignore, though.

Contributor

@gysit gysit left a comment

There have been multiple gradual additions to the constant operation in the past, and I wonder if it would be worth trying to come up with some alternative modeling for such large constants.

Do you need the full flexibility of this PR, or are you mainly interested in a solution for arrays of structs? If it is the latter, I would favor making the PR more restrictive to only support what we need.

Also note that the doc string of the constant operation will need to be updated if we extend the supported element types.

dimType = t.getElementType();
return t.getNumElements();
})
.Default([](auto) -> int64_t { return -1; });
Contributor

Let's use std::optional to signal that an unexpected type is hit?
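
A minimal sketch of that suggestion, using a simplified stand-in for the type hierarchy rather than the real MLIR classes. `ToyType`, `peelDimension`, and `collectShape` are hypothetical names, and the sketch assumes every sequence type has a non-null element type.

```cpp
#include <cstdint>
#include <memory>
#include <optional>
#include <vector>

// Hypothetical stand-in for a type: either a scalar or a sequence
// (array/vector) with a size and an element type.
struct ToyType {
  bool isSequence = false;
  std::int64_t numElements = 0;
  std::shared_ptr<ToyType> elementType; // null for scalars
};

// Returns the size of the outermost dimension and advances `type` to its
// element type. Returns std::nullopt when `type` is not a sequence, instead
// of encoding that case as a -1 sentinel.
std::optional<std::int64_t> peelDimension(const ToyType *&type) {
  if (!type->isSequence)
    return std::nullopt;
  std::int64_t size = type->numElements;
  type = type->elementType.get();
  return size;
}

// Walks the nesting from the outside in and records each dimension size.
std::vector<std::int64_t> collectShape(const ToyType &root) {
  std::vector<std::int64_t> shape;
  const ToyType *type = &root;
  // The loop condition tests whether the optional holds a value, so a
  // dimension of size 0 is still a valid dimension.
  while (std::optional<std::int64_t> dimSize = peelDimension(type))
    shape.push_back(*dimSize);
  return shape;
}
```

With an optional, "not an array or vector" and "dimension of size 0" stay distinct without reserving a negative value.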

<< dim << ": " << arrayAttr.size() << " vs. " << dimSize;
if (arrayAttr.size() == 0)
break;
dimVal = arrayAttr.getValue()[0];
Contributor

My understanding is that we want to support potentially large arrays with different element types? In the Fortran example all types are the same, though?

In any case, testing just the first element of the array seems a bit arbitrary? Assuming the array is large, we may not want to verify at all? But maybe a solution that skips previously verified attributes would be fine?
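
One way the last idea could look, as a sketch only: verify every element, but run the check once per distinct attribute. This relies on attributes being uniqued, so that identical elements share one object and can be recognized by pointer. `ToyAttr`, `matchesStruct`, and `verifyElements` are hypothetical names, not MLIR APIs.

```cpp
#include <cstddef>
#include <memory>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for a uniqued attribute: a leaf value or a list of
// child attributes.
struct ToyAttr {
  bool isList = false;
  std::vector<std::shared_ptr<ToyAttr>> children; // empty for a leaf
};

// Toy version of the struct check: the attribute must be a list with one
// entry per struct field.
bool matchesStruct(const ToyAttr &attr, std::size_t fieldCount) {
  return attr.isList && attr.children.size() == fieldCount;
}

// Verifies all elements of `array`, skipping elements whose attribute object
// was already checked. `checksRun` counts how many checks actually ran.
bool verifyElements(const ToyAttr &array, std::size_t fieldCount,
                    std::size_t &checksRun) {
  std::unordered_set<const ToyAttr *> verified;
  for (const std::shared_ptr<ToyAttr> &element : array.children) {
    if (!verified.insert(element.get()).second)
      continue; // same uniqued attribute as an earlier element
    ++checksRun;
    if (!matchesStruct(*element, fieldCount))
      return false;
  }
  return true;
}
```

For an array that repeats a single element value, as in the Fortran example, this runs one check regardless of the array size, while an array of distinct values still pays one check per element.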

<< "type and attribute have a different number of elements: "
<< getNumElements(getType()) << " vs. " << attrNumElements;
} else {
// When the attribute is an ArrayAttr, check that its nesting matches the
Contributor

I suggest factoring out some helper function(s) to reduce the complexity of the verify function (I know that it was already complex before...). Maybe the check for arrays/vectors could be part of a separate function? Additionally, the test of whether a scalar value is supported could possibly be factored out as well? It seems that the struct test below is similar to the struct test at the beginning of the verify function?

llvm::Constant *elementCst = nullptr;
SmallVector<llvm::Constant *> constants;
constants.reserve(arrayTy->getNumElements());
for (auto elementAttr : arrayAttr) {
Contributor

Suggested change
for (auto elementAttr : arrayAttr) {
for (Attribute elementAttr : arrayAttr) {

I think the type here is just Attribute?
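
For context, the loop under review translates an element only when it differs from the previous one. The sketch below isolates that short-circuit with plain integers standing in for attributes and constants. `translateElement` and `translateArray` are hypothetical names, and the doubling is only a placeholder for the real, more expensive translation.

```cpp
#include <cstddef>
#include <vector>

// Placeholder for an expensive attribute-to-constant translation.
// `translations` counts how often it runs.
int translateElement(int attr, std::size_t &translations) {
  ++translations;
  return attr * 2;
}

// Translates a list of elements, reusing the previous result whenever the
// current element equals the one just before it.
std::vector<int> translateArray(const std::vector<int> &attrs,
                                std::size_t &translations) {
  std::vector<int> constants;
  constants.reserve(attrs.size());
  bool havePrevious = false;
  int previousAttr = 0;
  int previousConstant = 0;
  for (int attr : attrs) {
    if (!havePrevious || attr != previousAttr) {
      previousAttr = attr;
      previousConstant = translateElement(attr, translations);
      havePrevious = true;
    }
    constants.push_back(previousConstant);
  }
  return constants;
}
```

Only runs of equal neighbors are collapsed: the input {42, 42, 42, 7, 7, 42} triggers three translations, because the final 42 follows a 7.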
