-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][MemRef] Use specialized index ops to fold expand/collapse_shape #138930
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][MemRef] Use specialized index ops to fold expand/collapse_shape #138930
Conversation
PRs this depends on: |
@llvm/pr-subscribers-mlir-memref Author: Krzysztof Drewniak (krzysz00) ChangesThis PR updates the FoldMemRefAliasOps to use This also loosens some limitations of the pass:
Patch is 31.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138930.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index d6d8161d3117b..f34b5b46cab50 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1342,14 +1342,14 @@ def MemRef_ReinterpretCastOp
according to specified offsets, sizes, and strides.
```mlir
- %result1 = memref.reinterpret_cast %arg0 to
+ %result1 = memref.reinterpret_cast %arg0 to
offset: [9],
sizes: [4, 4],
strides: [16, 2]
: memref<8x8xf32, strided<[8, 1], offset: 0>> to
memref<4x4xf32, strided<[16, 2], offset: 9>>
- %result2 = memref.reinterpret_cast %result1 to
+ %result2 = memref.reinterpret_cast %result1 to
offset: [0],
sizes: [2, 2],
strides: [4, 2]
@@ -1755,6 +1755,12 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
OpBuilder &b, Location loc, MemRefType expandedType,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<OpFoldResult> inputShape);
+
+ // Return a vector with all the static and dynamic values in the output shape.
+ SmallVector<OpFoldResult> getMixedOutputShape() {
+ OpBuilder builder(getContext());
+ return ::mlir::getMixedValues(getStaticOutputShape(), getOutputShape(), builder);
+ }
}];
let hasVerifier = 1;
@@ -1873,7 +1879,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
let summary = "store operation";
let description = [{
The `store` op stores an element into a memref at the specified indices.
-
+
The number of indices must match the rank of the memref. The indices must
be in-bounds: `0 <= idx < dim_size`
@@ -2025,7 +2031,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
Unlike the `reinterpret_cast`, the values are relative to the strided
memref of the input (`%result1` in this case) and not its
underlying memory.
-
+
Example 2:
```mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index e4fb3f9bb87ed..2acb90613e5d1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -59,92 +59,28 @@ using namespace mlir;
///
/// %2 = load %0[6 * i1 + i2, %i3] :
/// memref<12x42xf32>
-static LogicalResult
-resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
- memref::ExpandShapeOp expandShapeOp,
- ValueRange indices,
- SmallVectorImpl<Value> &sourceIndices) {
- // Record the rewriter context for constructing ops later.
- MLIRContext *ctx = rewriter.getContext();
-
- // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
- // This is done for the purpose of inferring the output shape via
- // `inferExpandOutputShape` which will in turn be used for suffix product
- // calculation later.
- SmallVector<OpFoldResult> srcShape;
- MemRefType srcType = expandShapeOp.getSrcType();
-
- for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
- if (srcType.isDynamicDim(i)) {
- srcShape.push_back(
- rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
- .getResult());
- } else {
- srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
- }
- }
-
- auto outputShape = inferExpandShapeOutputShape(
- rewriter, loc, expandShapeOp.getResultType(),
- expandShapeOp.getReassociationIndices(), srcShape);
- if (!outputShape.has_value())
- return failure();
+static LogicalResult resolveSourceIndicesExpandShape(
+ Location loc, PatternRewriter &rewriter,
+ memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+ SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
// Traverse all reassociation groups to determine the appropriate indices
// corresponding to each one of them post op folding.
- for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
- assert(!groups.empty() && "association indices groups cannot be empty");
- // Flag to indicate the presence of dynamic dimensions in current
- // reassociation group.
- int64_t groupSize = groups.size();
-
- // Group output dimensions utilized in this reassociation group for suffix
- // product calculation.
- SmallVector<OpFoldResult> sizesVal(groupSize);
- for (int64_t i = 0; i < groupSize; ++i) {
- sizesVal[i] = (*outputShape)[groups[i]];
+ for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
+ assert(!group.empty() && "association indices groups cannot be empty");
+ int64_t groupSize = group.size();
+ if (groupSize == 1) {
+ sourceIndices.push_back(indices[group[0]]);
+ continue;
}
-
- // Calculate suffix product of relevant output dimension sizes.
- SmallVector<OpFoldResult> suffixProduct =
- memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
-
- // Create affine expression variables for dimensions and symbols in the
- // newly constructed affine map.
- SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
- bindDimsList<AffineExpr>(ctx, dims);
- bindSymbolsList<AffineExpr>(ctx, symbols);
-
- // Linearize binded dimensions and symbols to construct the resultant
- // affine expression for this indice.
- AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
-
- // Record the load index corresponding to each dimension in the
- // reassociation group. These are later supplied as operands to the affine
- // map used for calulating relevant index post op folding.
- SmallVector<OpFoldResult> dynamicIndices(groupSize);
- for (int64_t i = 0; i < groupSize; i++)
- dynamicIndices[i] = indices[groups[i]];
-
- // Supply suffix product results followed by load op indices as operands
- // to the map.
- SmallVector<OpFoldResult> mapOperands;
- llvm::append_range(mapOperands, suffixProduct);
- llvm::append_range(mapOperands, dynamicIndices);
-
- // Creating maximally folded and composed affine.apply composes better
- // with other transformations without interleaving canonicalization
- // passes.
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc,
- AffineMap::get(/*numDims=*/groupSize,
- /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
- mapOperands);
-
- // Push index value in the op post folding corresponding to this
- // reassociation group.
- sourceIndices.push_back(
- getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+ SmallVector<OpFoldResult> groupBasis =
+ llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
+ SmallVector<Value> groupIndices =
+ llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
+ Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
+ loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
+ sourceIndices.push_back(collapsedIndex);
}
return success();
}
@@ -167,49 +103,34 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
memref::CollapseShapeOp collapseShapeOp,
ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
- int64_t cnt = 0;
- SmallVector<OpFoldResult> dynamicIndices;
- for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
- assert(!groups.empty() && "association indices groups cannot be empty");
- dynamicIndices.push_back(indices[cnt++]);
- int64_t groupSize = groups.size();
-
- // Calculate suffix product for all collapse op source dimension sizes
- // except the most major one of each group.
- // We allow the most major source dimension to be dynamic but enforce all
- // others to be known statically.
- SmallVector<int64_t> sizes(groupSize, 1);
- for (int64_t i = 1; i < groupSize; ++i) {
- sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
- if (sizes[i] == ShapedType::kDynamic)
- return failure();
- }
- SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
-
- // Derive the index values along all dimensions of the source corresponding
- // to the index wrt to collapsed shape op output.
- auto d0 = rewriter.getAffineDimExpr(0);
- SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
-
- // Construct the AffineApplyOp for each delinearizingExpr.
- for (int64_t i = 0; i < groupSize; i++) {
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc,
- AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
- delinearizingExprs[i]),
- dynamicIndices);
- sourceIndices.push_back(
- getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+ MemRefType sourceType = collapseShapeOp.getSrcType();
+ // Note: collapse_shape requires a strided memref, we can do this.
+ auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+ loc, collapseShapeOp.getSrc());
+ SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
+ for (auto [index, group] :
+ llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
+ assert(!group.empty() && "association indices groups cannot be empty");
+ int64_t groupSize = group.size();
+
+ if (groupSize == 1) {
+ sourceIndices.push_back(index);
+ continue;
}
- dynamicIndices.clear();
+
+ SmallVector<OpFoldResult> basis =
+ llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+ auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+ loc, index, basis, /*hasOuterBound=*/true);
+ llvm::append_range(sourceIndices, delinearize.getResults());
}
if (collapseShapeOp.getReassociationIndices().empty()) {
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
int64_t srcRank =
cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
+ OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
for (int64_t i = 0; i < srcRank; i++) {
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc, zeroAffineMap, dynamicIndices);
sourceIndices.push_back(
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
}
@@ -513,8 +434,12 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value> sourceIndices;
+ // memref.load and affine.load guarantee that indexes start inbounds
+ // while the vector operations don't. This impacts if our linearization
+ // is `disjoint`
if (failed(resolveSourceIndicesExpandShape(
- loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+ loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+ isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
return failure();
llvm::TypeSwitch<Operation *, void>(loadOp)
.Case([&](affine::AffineLoadOp op) {
@@ -676,8 +601,12 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value> sourceIndices;
+ // memref.store and affine.store guarantee that indexes start inbounds
+ // while the vector operations don't. This impacts if our linearization
+ // is `disjoint`
if (failed(resolveSourceIndicesExpandShape(
- storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+ storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+ isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
return failure();
llvm::TypeSwitch<Operation *, void>(storeOp)
.Case([&](affine::AffineStoreOp op) {
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index a27fbf26e13d8..106652623933f 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -408,7 +408,6 @@ func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x3
// -----
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 6 + s1)>
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 {
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
@@ -416,14 +415,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0
%1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32>
return %1 : f32
}
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (2, 6)
// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG3]]] : memref<12x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
// -----
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
// CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_collapse_shape
// CHECK-SAME: (%[[ARG0:.*]]: memref<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg0 : memref<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -431,15 +428,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg
%1 = affine.load %0[%arg1, %arg2] : memref<12x32xf32>
return %1 : f32
}
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<2x6x32xf32>
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<2x6x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
// -----
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
// CHECK-LABEL: @fold_dynamic_size_collapse_shape_with_affine_load
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -447,14 +441,28 @@ func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x
%1 = affine.load %0[%arg1, %arg2] : memref<?x32xf32>
return %1 : f32
}
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<?x6x32xf32>
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x6x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
// -----
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * 6 + s1 * 3 + s2)>
+// CHECK-LABEL: @fold_fully_dynamic_size_collapse_shape_with_affine_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_fully_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : index) -> f32 {
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x?x?xf32> into memref<?x?xf32>
+ %1 = affine.load %0[%arg1, %arg2] : memref<?x?xf32>
+ return %1 : f32
+}
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, %[[SIZES]]#1)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x?x?xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+
+// -----
+
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 {
@@ -462,7 +470,7 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
%1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32>
return %1 : f32
}
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]], %[[ARG3]]] by (2, 2, 3)
// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG4]]] : memref<12x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
@@ -476,7 +484,10 @@ func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x
%0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
return %0 : f32
}
-// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: return %[[VAL1]] : f32
// -----
@@ -490,14 +501,16 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
return
}
-// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: return
// -----
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
@@ -513,21 +526,20 @@ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc:
}
return
}
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Krzysztof Drewniak (krzysz00) ChangesThis PR updates the FoldMemRefAliasOps to use This also loosens some limitations of the pass:
Patch is 31.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138930.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index d6d8161d3117b..f34b5b46cab50 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1342,14 +1342,14 @@ def MemRef_ReinterpretCastOp
according to specified offsets, sizes, and strides.
```mlir
- %result1 = memref.reinterpret_cast %arg0 to
+ %result1 = memref.reinterpret_cast %arg0 to
offset: [9],
sizes: [4, 4],
strides: [16, 2]
: memref<8x8xf32, strided<[8, 1], offset: 0>> to
memref<4x4xf32, strided<[16, 2], offset: 9>>
- %result2 = memref.reinterpret_cast %result1 to
+ %result2 = memref.reinterpret_cast %result1 to
offset: [0],
sizes: [2, 2],
strides: [4, 2]
@@ -1755,6 +1755,12 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
OpBuilder &b, Location loc, MemRefType expandedType,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<OpFoldResult> inputShape);
+
+ // Return a vector with all the static and dynamic values in the output shape.
+ SmallVector<OpFoldResult> getMixedOutputShape() {
+ OpBuilder builder(getContext());
+ return ::mlir::getMixedValues(getStaticOutputShape(), getOutputShape(), builder);
+ }
}];
let hasVerifier = 1;
@@ -1873,7 +1879,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
let summary = "store operation";
let description = [{
The `store` op stores an element into a memref at the specified indices.
-
+
The number of indices must match the rank of the memref. The indices must
be in-bounds: `0 <= idx < dim_size`
@@ -2025,7 +2031,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
Unlike the `reinterpret_cast`, the values are relative to the strided
memref of the input (`%result1` in this case) and not its
underlying memory.
-
+
Example 2:
```mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index e4fb3f9bb87ed..2acb90613e5d1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -59,92 +59,28 @@ using namespace mlir;
///
/// %2 = load %0[6 * i1 + i2, %i3] :
/// memref<12x42xf32>
-static LogicalResult
-resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
- memref::ExpandShapeOp expandShapeOp,
- ValueRange indices,
- SmallVectorImpl<Value> &sourceIndices) {
- // Record the rewriter context for constructing ops later.
- MLIRContext *ctx = rewriter.getContext();
-
- // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
- // This is done for the purpose of inferring the output shape via
- // `inferExpandOutputShape` which will in turn be used for suffix product
- // calculation later.
- SmallVector<OpFoldResult> srcShape;
- MemRefType srcType = expandShapeOp.getSrcType();
-
- for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
- if (srcType.isDynamicDim(i)) {
- srcShape.push_back(
- rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
- .getResult());
- } else {
- srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
- }
- }
-
- auto outputShape = inferExpandShapeOutputShape(
- rewriter, loc, expandShapeOp.getResultType(),
- expandShapeOp.getReassociationIndices(), srcShape);
- if (!outputShape.has_value())
- return failure();
+static LogicalResult resolveSourceIndicesExpandShape(
+ Location loc, PatternRewriter &rewriter,
+ memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+ SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
// Traverse all reassociation groups to determine the appropriate indices
// corresponding to each one of them post op folding.
- for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
- assert(!groups.empty() && "association indices groups cannot be empty");
- // Flag to indicate the presence of dynamic dimensions in current
- // reassociation group.
- int64_t groupSize = groups.size();
-
- // Group output dimensions utilized in this reassociation group for suffix
- // product calculation.
- SmallVector<OpFoldResult> sizesVal(groupSize);
- for (int64_t i = 0; i < groupSize; ++i) {
- sizesVal[i] = (*outputShape)[groups[i]];
+ for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
+ assert(!group.empty() && "association indices groups cannot be empty");
+ int64_t groupSize = group.size();
+ if (groupSize == 1) {
+ sourceIndices.push_back(indices[group[0]]);
+ continue;
}
-
- // Calculate suffix product of relevant output dimension sizes.
- SmallVector<OpFoldResult> suffixProduct =
- memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
-
- // Create affine expression variables for dimensions and symbols in the
- // newly constructed affine map.
- SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
- bindDimsList<AffineExpr>(ctx, dims);
- bindSymbolsList<AffineExpr>(ctx, symbols);
-
- // Linearize binded dimensions and symbols to construct the resultant
- // affine expression for this indice.
- AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
-
- // Record the load index corresponding to each dimension in the
- // reassociation group. These are later supplied as operands to the affine
- // map used for calulating relevant index post op folding.
- SmallVector<OpFoldResult> dynamicIndices(groupSize);
- for (int64_t i = 0; i < groupSize; i++)
- dynamicIndices[i] = indices[groups[i]];
-
- // Supply suffix product results followed by load op indices as operands
- // to the map.
- SmallVector<OpFoldResult> mapOperands;
- llvm::append_range(mapOperands, suffixProduct);
- llvm::append_range(mapOperands, dynamicIndices);
-
- // Creating maximally folded and composed affine.apply composes better
- // with other transformations without interleaving canonicalization
- // passes.
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc,
- AffineMap::get(/*numDims=*/groupSize,
- /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
- mapOperands);
-
- // Push index value in the op post folding corresponding to this
- // reassociation group.
- sourceIndices.push_back(
- getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+ SmallVector<OpFoldResult> groupBasis =
+ llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
+ SmallVector<Value> groupIndices =
+ llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
+ Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
+ loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
+ sourceIndices.push_back(collapsedIndex);
}
return success();
}
@@ -167,49 +103,34 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
memref::CollapseShapeOp collapseShapeOp,
ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
- int64_t cnt = 0;
- SmallVector<OpFoldResult> dynamicIndices;
- for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
- assert(!groups.empty() && "association indices groups cannot be empty");
- dynamicIndices.push_back(indices[cnt++]);
- int64_t groupSize = groups.size();
-
- // Calculate suffix product for all collapse op source dimension sizes
- // except the most major one of each group.
- // We allow the most major source dimension to be dynamic but enforce all
- // others to be known statically.
- SmallVector<int64_t> sizes(groupSize, 1);
- for (int64_t i = 1; i < groupSize; ++i) {
- sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
- if (sizes[i] == ShapedType::kDynamic)
- return failure();
- }
- SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
-
- // Derive the index values along all dimensions of the source corresponding
- // to the index wrt to collapsed shape op output.
- auto d0 = rewriter.getAffineDimExpr(0);
- SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
-
- // Construct the AffineApplyOp for each delinearizingExpr.
- for (int64_t i = 0; i < groupSize; i++) {
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc,
- AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
- delinearizingExprs[i]),
- dynamicIndices);
- sourceIndices.push_back(
- getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+ MemRefType sourceType = collapseShapeOp.getSrcType();
+ // Note: collapse_shape requires a strided memref, we can do this.
+ auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+ loc, collapseShapeOp.getSrc());
+ SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
+ for (auto [index, group] :
+ llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
+ assert(!group.empty() && "association indices groups cannot be empty");
+ int64_t groupSize = group.size();
+
+ if (groupSize == 1) {
+ sourceIndices.push_back(index);
+ continue;
}
- dynamicIndices.clear();
+
+ SmallVector<OpFoldResult> basis =
+ llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+ auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+ loc, index, basis, /*hasOuterBound=*/true);
+ llvm::append_range(sourceIndices, delinearize.getResults());
}
if (collapseShapeOp.getReassociationIndices().empty()) {
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
int64_t srcRank =
cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
+ OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
for (int64_t i = 0; i < srcRank; i++) {
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc, zeroAffineMap, dynamicIndices);
sourceIndices.push_back(
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
}
@@ -513,8 +434,12 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value> sourceIndices;
+ // memref.load and affine.load guarantee that indexes start inbounds
+ // while the vector operations don't. This impacts if our linearization
+ // is `disjoint`
if (failed(resolveSourceIndicesExpandShape(
- loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+ loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+ isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
return failure();
llvm::TypeSwitch<Operation *, void>(loadOp)
.Case([&](affine::AffineLoadOp op) {
@@ -676,8 +601,12 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value> sourceIndices;
+ // memref.store and affine.store guarantee that indexes start inbounds
+ // while the vector operations don't. This impacts if our linearization
+ // is `disjoint`
if (failed(resolveSourceIndicesExpandShape(
- storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+ storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+ isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
return failure();
llvm::TypeSwitch<Operation *, void>(storeOp)
.Case([&](affine::AffineStoreOp op) {
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index a27fbf26e13d8..106652623933f 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -408,7 +408,6 @@ func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x3
// -----
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 6 + s1)>
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 {
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
@@ -416,14 +415,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0
%1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32>
return %1 : f32
}
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (2, 6)
// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG3]]] : memref<12x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
// -----
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
// CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_collapse_shape
// CHECK-SAME: (%[[ARG0:.*]]: memref<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg0 : memref<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -431,15 +428,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg
%1 = affine.load %0[%arg1, %arg2] : memref<12x32xf32>
return %1 : f32
}
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<2x6x32xf32>
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<2x6x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
// -----
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
// CHECK-LABEL: @fold_dynamic_size_collapse_shape_with_affine_load
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -447,14 +441,28 @@ func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x
%1 = affine.load %0[%arg1, %arg2] : memref<?x32xf32>
return %1 : f32
}
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<?x6x32xf32>
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x6x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
// -----
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * 6 + s1 * 3 + s2)>
+// CHECK-LABEL: @fold_fully_dynamic_size_collapse_shape_with_affine_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_fully_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : index) -> f32 {
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x?x?xf32> into memref<?x?xf32>
+ %1 = affine.load %0[%arg1, %arg2] : memref<?x?xf32>
+ return %1 : f32
+}
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, %[[SIZES]]#1)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x?x?xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+
+// -----
+
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 {
@@ -462,7 +470,7 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
%1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32>
return %1 : f32
}
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]], %[[ARG3]]] by (2, 2, 3)
// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG4]]] : memref<12x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
@@ -476,7 +484,10 @@ func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x
%0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
return %0 : f32
}
-// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: return %[[VAL1]] : f32
// -----
@@ -490,14 +501,16 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
return
}
-// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: return
// -----
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
@@ -513,21 +526,20 @@ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc:
}
return
}
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0...
[truncated]
|
Ping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes look good to me. It isn't strictly required, by given that book h of us work on the same downstream project, does this pass with the said downstream project. But this looks good to me
This PR updates the FoldMemRefAliasOps to use `affine.linearize_index` and `affine.delinearize_index` to perform the index computations needed to fold a `memref.expand_shape` or `memref.collapse_shape` into its consumers, respectively. This also loosens some limitations of the pass: 1. The existing `output_shape` argument to `memref.expand_shape` is now used, eliminating the need to re-infer this shape or call `memref.dim`. 2. Because we're using `affine.delinearize_index`, the restriction that each group in a `memref.collapse_shape` can only have one dynamic dimension is removed.
ae27c6f
to
249e426
Compare
Thanks! |
This PR updates the FoldMemRefAliasOps to use
affine.linearize_index
andaffine.delinearize_index
to perform the index computations needed to fold amemref.expand_shape
ormemref.collapse_shape
into its consumers, respectively.This also loosens some limitations of the pass:
output_shape
argument tomemref.expand_shape
is now used, eliminating the need to re-infer this shape or callmemref.dim
.affine.delinearize_index
, the restriction that each group in amemref.collapse_shape
can only have one dynamic dimension is removed.