-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][ArmSME] Audit ArmSME load/store ops #139573
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ArmSME] Audit ArmSME load/store ops #139573
Conversation
This patch updates the following arm_sme ops to require that input and output element types match:
* `arm_sme.tile_load`, `arm_sme.tile_store`, `arm_sme.load_tile_slice`, `arm_sme.store_tile_slice`.

In addition, it ensures that the base memref operand for `tile_load` and `tile_store` is always rank-2, aligning with the semantics of Arm SME tiles (always rank-2).

This change is effectively a follow-up to llvm#135151:
* "[mlir][vector] Tighten the semantics of vector.{load|store}"

The patch also updates `createLoadStoreForOverTileSlices` in ArmSMEToSCF.cpp to fail when processing invalid tile stores like the following:

```mlir
arm_sme.tile_store %arg0, %arg1[%c0] : memref<?x4xi8>, vector<[4]x[4]xi32>
```

This particular change fixes llvm#118769. As noted in the TODO, we should further extend op verification logic — I plan to address that in a follow-up patch.
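@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir
Author: Andrzej Warzyński (banach-space)
Changes: This patch updates the following arm_sme ops to require that input and output element types match: `arm_sme.tile_load`, `arm_sme.tile_store`, `arm_sme.load_tile_slice`, `arm_sme.store_tile_slice`.
In addition, it ensures that the base memref operand for `tile_load` and `tile_store` is always rank-2, aligning with the semantics of Arm SME tiles (always rank-2).
The patch also updates `createLoadStoreForOverTileSlices` in ArmSMEToSCF.cpp to fail when processing invalid tile stores like the following: `arm_sme.tile_store %arg0, %arg1[%c0] : memref<?x4xi8>, vector<[4]x[4]xi32>`. This particular change fixes #118769. As noted in the TODO, we should further extend op verification logic in a follow-up patch. Full diff: https://github.com/llvm/llvm-project/pull/139573.diff 3 Files Affected: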
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 6fd992afbf043..2f083b55d4904 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -317,6 +317,7 @@ def CopyTileOp : ArmSME_Op<"copy_tile", [
def TileLoadOp : ArmSME_Op<"tile_load", [
ArmSMETileOpInterface,
AttrSizedOperandSegments,
+ AllElementTypesMatch<["result", "base"]>,
OptionalTypesMatchWith<
"padding type matches element type of result",
"result", "padding",
@@ -369,7 +370,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
```
}];
let arguments = (ins
- Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
+ Arg<MemRefRankOf<[AnyType], [2]>, "the reference to load from", [MemRead]>:$base,
Variadic<Index>:$indices,
Optional<AnyType>:$padding, Optional<AnyVectorOfNonZeroRank>:$mask,
ArmSME_TileSliceLayoutAttr:$layout
@@ -407,6 +408,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
def TileStoreOp : ArmSME_Op<"tile_store", [
ArmSMETileOpInterface,
AttrSizedOperandSegments,
+ AllElementTypesMatch<["valueToStore", "base"]>,
HasMatchingMaskTypeConstraint<"valueToStore", "mask">,
]> {
let summary = "Tile store operation";
@@ -443,7 +445,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
```
}];
let arguments = (ins SMETile:$valueToStore,
- Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
+ Arg<MemRefRankOf<[AnyType], [2]>, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices, Optional<AnyVectorOfNonZeroRank>:$mask,
ArmSME_TileSliceLayoutAttr:$layout
);
@@ -473,6 +475,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
ArmSMETileOpInterface,
+ AllElementTypesMatch<["tile", "base"]>,
AllTypesMatch<["tile", "result"]>, TileSliceMaskConstraint<"result", "mask">
]> {
let summary = "Tile slice load and update operation";
@@ -535,6 +538,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
ArmSMETileOpInterface,
+ AllElementTypesMatch<["tile", "base"]>,
TileSliceMaskConstraint<"tile", "mask">
]> {
let summary = "Tile slice store operation";
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 630414030d98b..458628c29c6ac 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -33,20 +33,15 @@ SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
Value tileSliceIndex,
Value tileSliceNumElts, Location loc,
PatternRewriter &rewriter) {
- assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
+ assert(rank == 2 && "memref has unexpected rank!");
SmallVector<Value, 2> outIndices;
auto tileSliceOffset = tileSliceIndex;
- if (rank == 1)
- tileSliceOffset =
- rewriter.create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
auto baseIndexPlusTileSliceOffset =
rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
outIndices.push_back(baseIndexPlusTileSliceOffset);
-
- if (rank == 2)
- outIndices.push_back(indices[1]);
+ outIndices.push_back(indices[1]);
return outIndices;
}
@@ -60,6 +55,10 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
makeLoopBody) {
PatternRewriter::InsertionGuard guard(rewriter);
+ // TODO: This case should be captured and rejected by a verifier.
+ if (memrefIndices.size() != 2)
+ return rewriter.notifyMatchFailure(loc, "invalid number of indices");
+
auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
auto vscale =
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 700b2412ff7a7..8c5a098a0c785 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -50,7 +50,7 @@ func.func @arm_sme_get_tile__bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
// -----
-func.func @arm_sme_insert_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
+func.func @arm_sme_insert_tile_slice_i8__bad_vector_length(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
%c0 = arith.constant 0 : index
// expected-error@+1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}}
%0 = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] : vector<[8]xi8> into vector<[16]x[16]xi8>
@@ -59,23 +59,40 @@ func.func @arm_sme_insert_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8
// -----
-func.func @arm_sme_insert_tile_slice_f32__bad_vector_type(%vector : vector<[8]xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> {
+func.func @arm_sme_insert_tile_slice_f32__bad_vector_length(%vector : vector<[8]xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> {
%c0 = arith.constant 0 : index
// expected-error@+1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}}
%0 = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] : vector<[8]xf32> into vector<[4]x[4]xf32>
return %0 : vector<[4]x[4]xf32>
}
+// -----
+
+func.func @arm_sme_insert_tile_slice__bad_element_type(%vector : vector<[4]xf64>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}}
+ %0 = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] : vector<[4]xf64> into vector<[4]x[4]xf32>
+ return %0 : vector<[4]x[4]xf32>
+}
+
//===----------------------------------------------------------------------===//
// arm_sme.extract_tile_slice
//===----------------------------------------------------------------------===//
// -----
-func.func @arm_sme_extract_tile_slice__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> {
+func.func @arm_sme_extract_tile_slice__bad_result_length(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf32> {
+ // expected-error@+1 {{op failed to verify that type of 'result' matches type of 'tile' slice}}
+ %0 = arm_sme.extract_tile_slice %tile[%tile_slice_index] : vector<[2]xf32> from vector<[4]x[4]xf32>
+ return %0 : vector<[2]xf32>
+}
+
+// -----
+
+func.func @arm_sme_extract_tile_slice__bad_result_element_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]xf64> {
// expected-error@+1 {{op failed to verify that type of 'result' matches type of 'tile' slice}}
- %0 = arm_sme.extract_tile_slice %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32>
- return %0 : vector<[2]xf64>
+ %0 = arm_sme.extract_tile_slice %tile[%tile_slice_index] : vector<[4]xf64> from vector<[4]x[4]xf32>
+ return %0 : vector<[4]xf64>
}
//===----------------------------------------------------------------------===//
@@ -111,6 +128,24 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
return
}
+// -----
+
+func.func @arm_sme_tile_load__bad_memref_rank(%src : memref<?xf64>, %pad : f64) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{op operand #0 must be 2D memref of any type values, but got 'memref<?xf64>'}}
+ %tile = arm_sme.tile_load %src[%c0], %pad, : memref<?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load__bad_element_type(%src : memref<?x?xi32>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{failed to verify that all of {result, base} have same element type}}
+ %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[16]x[16]xi8>
+ return
+}
+
//===----------------------------------------------------------------------===//
// arm_sme.load_tile_slice
//===----------------------------------------------------------------------===//
@@ -124,6 +159,15 @@ func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask :
return
}
+// -----
+
+func.func @arm_sme_load_tile_slice__bad_element_type(%src : memref<?x?xi32>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{op failed to verify that all of {tile, base} have same element type}}
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi32>, vector<[16]xi1>, vector<[16]x[16]xi8>
+ return
+}
+
//===----------------------------------------------------------------------===//
// arm_sme.tile_store
//===----------------------------------------------------------------------===//
@@ -138,6 +182,24 @@ func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask
return
}
+// -----
+
+func.func @arm_sme_tile_store__bad_memref_rank(%tile : vector<[16]x[16]xi8>, %dest : memref<?xi8>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{op operand #1 must be 2D memref of any type values, but got 'memref<?xi8>'}}
+ arm_sme.tile_store %tile, %dest[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_store__bad_element_type(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi32>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{op failed to verify that all of {valueToStore, base} have same element type}}
+ arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[16]x[16]xi8>
+ return
+}
+
//===----------------------------------------------------------------------===//
// arm_sme.store_tile_slice
//===----------------------------------------------------------------------===//
@@ -152,6 +214,15 @@ func.func @arm_sme_store_tile_slice__bad_mask_type(%tile : vector<[16]x[16]xi8>,
return
}
+// -----
+
+func.func @arm_sme_store_tile_slice__bad_element_type(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi32>) -> () {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{op failed to verify that all of {tile, base} have same element type}}
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0, %c0] : memref<?x?xi32>, vector<[16]xi1>, vector<[16]x[16]xi8>
+ return
+}
+
//===----------------------------------------------------------------------===//
// arm_sme.outerproduct
//===----------------------------------------------------------------------===//
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
LLVM Buildbot has detected a new failure on one of its builders. Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/13056. Here is the relevant piece of the build log for reference:
|
This patch updates the following ArmSME ops to require that input and
output element types match: `arm_sme.tile_load`, `arm_sme.tile_store`,
`arm_sme.load_tile_slice`, `arm_sme.store_tile_slice`.

In addition, it ensures that the base memref operand for `tile_load` and
`tile_store` is always rank-2, aligning with the semantics of Arm SME
tiles (always rank-2). This change is effectively a follow-up to #135151:
"[mlir][vector] Tighten the semantics of vector.{load|store}".

The patch also updates `createLoadStoreForOverTileSlices` in
ArmSMEToSCF.cpp to fail when processing invalid tile stores like the
following:

    arm_sme.tile_store %arg0, %arg1[%c0] : memref<?x4xi8>, vector<[4]x[4]xi32>

This particular change fixes #118769. As noted in the TODO, we should
further extend op verification logic — I plan to address that in a
follow-up patch.