WIP on some CIR → MLIR experiment #1334


Draft: wants to merge 23 commits into base: main

Changes from 1 commit (23 commits total)
f3c6945
[CIR][Lowering][NFC] Expose cir::prepareTypeConverter()
keryell Oct 9, 2024
1013314
[CIR] Add inline interface to CIR dialect
keryell Nov 20, 2024
e775742
[CIR] Add runAtStartOfConvertCIRToMLIRPass() to ConvertCIRToMLIRPass
keryell Dec 4, 2024
8bd9a8d
[CIR] Add runAtStartOfConvertCIRToLLVMPass() to ConvertCIRToLLVMPass
keryell Dec 18, 2024
e705fad
[CIR] Lower struct/union/class to tuple
keryell Jan 14, 2025
3793fc3
[CIR][MLIR] Allow memref of tuple
keryell Jan 22, 2025
ce2133d
[CIR][MLIR] Add minimal NamedTuple to core MLIR
keryell Jan 29, 2025
762695a
[CIR][MLIR] named_tuple.cast operation
keryell Feb 10, 2025
a09a2e8
[CIR][MLIR] Add some named_tuple type introspection functions
keryell Feb 11, 2025
12626d1
[CIR] Lower cir.get_member to named_tuple + memref casts
keryell Feb 11, 2025
64a0abd
[CIR] Lower to MLIR struct with array member
keryell Feb 14, 2025
a967bc7
[CIR][Lowering] Lower arrays in class/struct/union as tensor
keryell Feb 14, 2025
9ce66f2
[CIR][Lowering] Handle pointer of pointer of struct or array
keryell Feb 15, 2025
c1ddaff
[CIR][Lowering][MLIR] Export cir::lowerArrayType()
keryell Feb 18, 2025
1097f87
[CIR][Lowering][MLIR] Lower class/struct/union to memref by default
keryell Feb 19, 2025
6c35a8f
[CIR][Lowering][MLIR] Lower cir.cast(bitcast) between !cir.ptr
keryell Feb 21, 2025
a207932
[CIR][Lowering][MLIR] Rework the !cir.array lowering
keryell Feb 22, 2025
90b70f4
[CIR][Doc] Add some top-level documentation on CIR→MLIR WIP
keryell Feb 26, 2025
15abba4
[CIR][Lower][MLIR] Handle pointer decay of higher dimensions arrays
keryell Feb 28, 2025
429b8ae
[CIR][Lowering][MLIR] Generalize the lowering of cir.ptr_stride
keryell Feb 28, 2025
9338d3f
[CIR][Lowering][MLIR] Remove a layout API
keryell Mar 21, 2025
39b162f
[CIR][Lowering][NFC] Update code to upstream cir::RecordType
keryell Apr 17, 2025
7275515
[CIR][ThroughMLIR][NFC] Fix test syntax
keryell Apr 18, 2025
[CIR][Lowering] Lower arrays in class/struct/union as tensor
Arrays in C/C++ usually have reference semantics and can be lowered to memref.
But when inside a class/struct/union, arrays have value semantics and can be
lowered as tensor.
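For example (an illustrative sketch, not taken from the patch; the lowered
types match the struct.cpp test updated below):

  float standalone[5]; // decays to a pointer in expressions: reference
                       // semantics, lowered to memref<5xf32>
  struct s {
    float d[5];        // copied along with its enclosing struct: value
                       // semantics, lowered to tensor<5xf32> in the record
  };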
keryell committed Apr 17, 2025
commit a967bc7bb3755b77aca8d4cd914e6e40b9c77b13
85 changes: 51 additions & 34 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
@@ -289,13 +289,13 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
}
};

-// Lower cir.get_member
+// Lower cir.get_member by aliasing the result memref to the member inside the
+// flattened structure as a byte array. For example
//
// clang-format off
//
// %5 = cir.get_member %1[1] {name = "b"} : !cir.ptr<!named_tuple.named_tuple<"s", [i32, f64, i8]>> -> !cir.ptr<!cir.double>
//
-// to something like
+// is lowered to something like
//
// %1 = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
// %c8 = arith.constant 8 : index
@@ -325,37 +325,30 @@ class CIRGetMemberOpLowering
// concrete datalayout, both datalayouts are the same.
auto *structLayout = dataLayout.getStructLayout(structType);

-    // Get the lowered type: memref<!named_tuple.named_tuple<>>
-    auto memref = mlir::cast<mlir::MemRefType>(adaptor.getAddr().getType());
// Alias the memref of struct to a memref of an i8 array of the same size.
const std::array linearizedSize{
static_cast<std::int64_t>(dataLayout.getTypeStoreSize(structType))};
-    auto flattenMemRef = mlir::MemRefType::get(
-        linearizedSize, mlir::IntegerType::get(memref.getContext(), 8));
+    auto flattenedMemRef = mlir::MemRefType::get(
+        linearizedSize, mlir::IntegerType::get(getContext(), 8));
// Use a special cast because normal memref cast cannot do such an extreme
// cast.
auto bytesMemRef = rewriter.create<mlir::named_tuple::CastOp>(
-        op.getLoc(), mlir::TypeRange{flattenMemRef},
+        op.getLoc(), mlir::TypeRange{flattenedMemRef},
mlir::ValueRange{adaptor.getAddr()});

+    auto pointerToMemberTypeToLower = op.getResultTy();
+    // The lowered type of the cir.ptr to the cir.struct member.
+    auto memrefToLoweredMemberType =
+        typeConverter->convertType(pointerToMemberTypeToLower);
+    // Synthesize the byte access to the right lowered type.
auto memberIndex = op.getIndex();
-    auto namedTupleType =
-        mlir::cast<mlir::named_tuple::NamedTupleType>(memref.getElementType());
-    // The lowered type of the element to access in the named_tuple.
-    auto loweredMemberType = namedTupleType.getType(memberIndex);
-    // memref.view can only cast to another memref. Wrap the target type if it
-    // is not already a memref (like with a struct with an array member)
-    mlir::MemRefType elementMemRefTy;
-    if (mlir::isa<mlir::MemRefType>(loweredMemberType))
-      elementMemRefTy = mlir::cast<mlir::MemRefType>(loweredMemberType);
-    else
-      elementMemRefTy = mlir::MemRefType::get({}, loweredMemberType);
auto offset = structLayout->getElementOffset(memberIndex);
-    // Synthesize the byte access to right lowered type.
auto byteShift =
rewriter.create<mlir::arith::ConstantIndexOp>(op.getLoc(), offset);
+    // Create the memref pointing to the flattened member location.
rewriter.replaceOpWithNewOp<mlir::memref::ViewOp>(
-        op, elementMemRefTy, bytesMemRef, byteShift, mlir::ValueRange{});
+        op, memrefToLoweredMemberType, bytesMemRef, byteShift,
+        mlir::ValueRange{});
return mlir::LogicalResult::success();
}
};
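As a sanity check on the offsets this pattern computes (a hedged sketch
assuming the x86-64 data layout used by the struct.cpp test below; it is not
part of the patch):

  #include <cstddef>

  struct s { int a; double b; char c; float d[5]; };

  static_assert(offsetof(s, a) == 0);  // view at byte offset 0
  static_assert(offsetof(s, b) == 8);  // arith.constant 8 : index
  static_assert(offsetof(s, c) == 16); // arith.constant 16 : index
  static_assert(offsetof(s, d) == 20); // arith.constant 20 : index
  static_assert(sizeof(s) == 40);      // the flattened memref<40xi8>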
@@ -1382,6 +1375,29 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
cirDataLayout);
}

+namespace {
+// Lower a cir.array either as a memref when it has reference semantics or as
+// a tensor when it has value semantics (like inside a struct or union).
+mlir::Type lowerArrayType(cir::ArrayType type, bool hasValueSemantics,
+                          mlir::TypeConverter &converter) {
+  SmallVector<int64_t> shape;
+  mlir::Type curType = type;
+  while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
+    shape.push_back(arrayType.getSize());
+    curType = arrayType.getEltType();
+  }
+  auto elementType = converter.convertType(curType);
+  // FIXME: The element type might not be converted
+  if (!elementType)
+    return nullptr;
+  // Arrays in C/C++ have reference semantics when not in a struct, so use
+  // a memref.
+  if (hasValueSemantics)
+    return mlir::RankedTensorType::get(shape, elementType);
+  return mlir::MemRefType::get(shape, elementType);
+}
+} // namespace
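For illustration (an assumed example, not part of the patch): for a member
declared float d[3][5], the while loop peels the two nested cir::ArrayType
layers outermost-first, collecting shape = {3, 5} with f32 as the element
type, so the function returns memref<3x5xf32> under reference semantics and
tensor<3x5xf32> under value semantics.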

mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
mlir::TypeConverter converter;
converter.addConversion([&](cir::PointerType type) -> mlir::Type {
@@ -1390,6 +1406,7 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
if (!ty)
return nullptr;
if (isa<cir::ArrayType>(type.getPointee()))
+      // An array is already lowered as a memref with reference semantics
return ty;
return mlir::MemRefType::get({}, ty);
});
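In other words (an assumed illustration, not from the patch): for a C++
pointer such as int (*p)[5], the pointee is already lowered to memref<5xi32>
with reference semantics, so the pointer converts to that same memref rather
than to a nested memref<memref<5xi32>>, matching the "Do not lower to a
memref of memref" expectation in the struct.cpp test below.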
@@ -1429,23 +1446,23 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
return mlir::BFloat16Type::get(type.getContext());
});
converter.addConversion([&](cir::ArrayType type) -> mlir::Type {
-    SmallVector<int64_t> shape;
-    mlir::Type curType = type;
-    while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
-      shape.push_back(arrayType.getSize());
-      curType = arrayType.getEltType();
-    }
-    auto elementType = converter.convertType(curType);
-    // FIXME: The element type might not be converted
-    if (!elementType)
-      return nullptr;
-    return mlir::MemRefType::get(shape, elementType);
+    // Arrays in C/C++ have reference semantics when not in a
+    // class/struct/union, so use a memref.
+    return lowerArrayType(type, /* hasValueSemantics */ false, converter);
});
converter.addConversion([&](cir::VectorType type) -> mlir::Type {
auto ty = converter.convertType(type.getEltType());
return mlir::VectorType::get(type.getSize(), ty);
});
converter.addConversion([&](cir::StructType type) -> mlir::Type {
+    auto convertWithValueSemanticsArray = [&](mlir::Type t) {
+      if (mlir::isa<cir::ArrayType>(t))
+        // Inside a class/struct/union, an array has value semantics and is
+        // lowered as a tensor.
+        return lowerArrayType(mlir::cast<cir::ArrayType>(t),
+                              /* hasValueSemantics */ true, converter);
+      return converter.convertType(t);
+    };
// FIXME(cir): create separate unions, struct, and classes types.
// Convert struct members.
llvm::SmallVector<mlir::Type> mlirMembers;
Expand All @@ -1454,13 +1471,13 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
// TODO(cir): This should be properly validated.
case cir::StructType::Struct:
for (auto ty : type.getMembers())
-        mlirMembers.push_back(converter.convertType(ty));
+        mlirMembers.push_back(convertWithValueSemanticsArray(ty));
break;
// Unions are lowered as only the largest member.
case cir::StructType::Union: {
auto largestMember = type.getLargestMember(dataLayout);
if (largestMember)
-        mlirMembers.push_back(converter.convertType(largestMember));
+        mlirMembers.push_back(convertWithValueSemanticsArray(largestMember));
break;
}
}
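As an assumed illustration (not from the patch): for union u { int i; double d; },
the largest member is the double, so the lowered record carries a single f64
member, and accesses to any union member are still synthesized through the
flattened byte view at their layout offset (0 for every union member).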
16 changes: 8 additions & 8 deletions clang/test/CIR/Lowering/ThroughMLIR/struct.cpp
@@ -11,40 +11,40 @@ struct s {

int main() {
s v;
-  // CHECK: %[[ALLOCA:.+]] = memref.alloca() {alignment = 8 : i64} : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>>
+  // CHECK: %[[ALLOCA:.+]] = memref.alloca() {alignment = 8 : i64} : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>>
v.a = 7;
// CHECK: %[[C_7:.+]] = arith.constant 7 : i32
-  // CHECK: %[[I8_EQUIV_A:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_A:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
// CHECK: %[[OFFSET_A:.+]] = arith.constant 0 : index
// CHECK: %[[VIEW_A:.+]] = memref.view %[[I8_EQUIV_A]][%[[OFFSET_A]]][] : memref<40xi8> to memref<i32>
// CHECK: memref.store %[[C_7]], %[[VIEW_A]][] : memref<i32>

v.b = 3.;
// CHECK: %[[C_3:.+]] = arith.constant 3.000000e+00 : f64
-  // CHECK: %[[I8_EQUIV_B:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_B:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
// CHECK: %[[OFFSET_B:.+]] = arith.constant 8 : index
// CHECK: %[[VIEW_B:.+]] = memref.view %[[I8_EQUIV_B]][%[[OFFSET_B]]][] : memref<40xi8> to memref<f64>
// CHECK: memref.store %[[C_3]], %[[VIEW_B]][] : memref<f64>

v.c = 'z';
// CHECK: %[[C_122:.+]] = arith.constant 122 : i8
-  // CHECK: %[[I8_EQUIV_C:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_C:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
// CHECK: %[[OFFSET_C:.+]] = arith.constant 16 : index
// CHECK: %[[VIEW_C:.+]] = memref.view %[[I8_EQUIV_C]][%[[OFFSET_C]]][] : memref<40xi8> to memref<i8>
// memref.store %[[C_122]], %[[VIEW_C]][] : memref<i8>

auto& a = v.d;
v.d[4] = 6.f;
// CHECK: %[[C_6:.+]] = arith.constant 6.000000e+00 : f32
-  // CHECK: %[[I8_EQUIV_D:.+]] = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_D:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
// CHECK: %[[OFFSET_D:.+]] = arith.constant 20 : index
// Do not lower to a memref of memref
-  // CHECK: %[[VIEW_D:.+]] = memref.view %3[%c20][] : memref<40xi8> to memref<5xf32>
+  // CHECK: %[[VIEW_D:.+]] = memref.view %[[I8_EQUIV_D]][%[[OFFSET_D]]][] : memref<40xi8> to memref<5xf32>
// CHECK: %[[C_4:.+]] = arith.constant 4 : i32
// CHECK: %[[I_D:.+]] = arith.index_cast %[[C_4]] : i32 to index
// CHECK: memref.store %[[C_6]], %[[VIEW_D]][%[[I_D]]] : memref<5xf32>

return v.c;
-  // CHECK: %[[I8_EQUIV_C_1:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_C_1:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
// CHECK: %[[OFFSET_C_1:.+]] = arith.constant 16 : index
// CHECK: %[[VIEW_C_1:.+]] = memref.view %[[I8_EQUIV_C_1]][%[[OFFSET_C_1]]][] : memref<40xi8> to memref<i8>
// CHECK: %[[VALUE_C:.+]] = memref.load %[[VIEW_C_1]][] : memref<i8>