Skip to content

[mlir][MemRef] Use specialized index ops to fold expand/collapse_shape #138930

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

krzysz00
Copy link
Contributor

@krzysz00 krzysz00 commented May 7, 2025

This PR updates the FoldMemRefAliasOps to use affine.linearize_index and affine.delinearize_index to perform the index computations needed to fold a memref.expand_shape or memref.collapse_shape into its consumers, respectively.

This also loosens some limitations of the pass:

  1. The existing output_shape argument to memref.expand_shape is now used, eliminating the need to re-infer this shape or call memref.dim.
  2. Because we're using affine.delinearize_index, the restriction that each group in a memref.collapse_shape can only have one dynamic dimension is removed.

@krzysz00
Copy link
Contributor Author

krzysz00 commented May 7, 2025

@llvmbot
Copy link
Member

llvmbot commented May 7, 2025

@llvm/pr-subscribers-mlir-memref

Author: Krzysztof Drewniak (krzysz00)

Changes

This PR updates the FoldMemRefAliasOps to use affine.linearize_index and affine.delinearize_index to perform the index computations needed to fold a memref.expand_shape or memref.collapse_shape into its consumers, respectively.

This also loosens some limitations of the pass:

  1. The existing output_shape argument to memref.expand_shape is now used, eliminating the need to re-infer this shape or call memref.dim.
  2. Because we're using affine.delinearize_index, the restriction that each group in a memref.collapse_shape can only have one dynamic dimension is removed.

Patch is 31.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138930.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+10-4)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+49-120)
  • (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+64-65)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index d6d8161d3117b..f34b5b46cab50 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1342,14 +1342,14 @@ def MemRef_ReinterpretCastOp
     according to specified offsets, sizes, and strides.
 
     ```mlir
-    %result1 = memref.reinterpret_cast %arg0 to 
+    %result1 = memref.reinterpret_cast %arg0 to
       offset: [9],
       sizes: [4, 4],
       strides: [16, 2]
     : memref<8x8xf32, strided<[8, 1], offset: 0>> to
       memref<4x4xf32, strided<[16, 2], offset: 9>>
 
-    %result2 = memref.reinterpret_cast %result1 to 
+    %result2 = memref.reinterpret_cast %result1 to
       offset: [0],
       sizes: [2, 2],
       strides: [4, 2]
@@ -1755,6 +1755,12 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
         OpBuilder &b, Location loc, MemRefType expandedType,
         ArrayRef<ReassociationIndices> reassociation,
         ArrayRef<OpFoldResult> inputShape);
+
+    // Return a vector with all the static and dynamic values in the output shape.
+    SmallVector<OpFoldResult> getMixedOutputShape() {
+      OpBuilder builder(getContext());
+      return ::mlir::getMixedValues(getStaticOutputShape(), getOutputShape(), builder);
+    }
   }];
 
   let hasVerifier = 1;
@@ -1873,7 +1879,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
   let summary = "store operation";
   let description = [{
     The `store` op stores an element into a memref at the specified indices.
-    
+
     The number of indices must match the rank of the memref. The indices must
     be in-bounds: `0 <= idx < dim_size`
 
@@ -2025,7 +2031,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     Unlike the `reinterpret_cast`, the values are relative to the strided
     memref of the input (`%result1` in this case) and not its
     underlying memory.
-    
+
     Example 2:
 
     ```mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index e4fb3f9bb87ed..2acb90613e5d1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -59,92 +59,28 @@ using namespace mlir;
 ///
 /// %2 = load %0[6 * i1 + i2, %i3] :
 ///          memref<12x42xf32>
-static LogicalResult
-resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
-                                memref::ExpandShapeOp expandShapeOp,
-                                ValueRange indices,
-                                SmallVectorImpl<Value> &sourceIndices) {
-  // Record the rewriter context for constructing ops later.
-  MLIRContext *ctx = rewriter.getContext();
-
-  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
-  // This is done for the purpose of inferring the output shape via
-  // `inferExpandOutputShape` which will in turn be used for suffix product
-  // calculation later.
-  SmallVector<OpFoldResult> srcShape;
-  MemRefType srcType = expandShapeOp.getSrcType();
-
-  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
-    if (srcType.isDynamicDim(i)) {
-      srcShape.push_back(
-          rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
-              .getResult());
-    } else {
-      srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
-    }
-  }
-
-  auto outputShape = inferExpandShapeOutputShape(
-      rewriter, loc, expandShapeOp.getResultType(),
-      expandShapeOp.getReassociationIndices(), srcShape);
-  if (!outputShape.has_value())
-    return failure();
+static LogicalResult resolveSourceIndicesExpandShape(
+    Location loc, PatternRewriter &rewriter,
+    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
 
   // Traverse all reassociation groups to determine the appropriate indices
   // corresponding to each one of them post op folding.
-  for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
-    assert(!groups.empty() && "association indices groups cannot be empty");
-    // Flag to indicate the presence of dynamic dimensions in current
-    // reassociation group.
-    int64_t groupSize = groups.size();
-
-    // Group output dimensions utilized in this reassociation group for suffix
-    // product calculation.
-    SmallVector<OpFoldResult> sizesVal(groupSize);
-    for (int64_t i = 0; i < groupSize; ++i) {
-      sizesVal[i] = (*outputShape)[groups[i]];
+  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+    if (groupSize == 1) {
+      sourceIndices.push_back(indices[group[0]]);
+      continue;
     }
-
-    // Calculate suffix product of relevant output dimension sizes.
-    SmallVector<OpFoldResult> suffixProduct =
-        memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
-
-    // Create affine expression variables for dimensions and symbols in the
-    // newly constructed affine map.
-    SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
-    bindDimsList<AffineExpr>(ctx, dims);
-    bindSymbolsList<AffineExpr>(ctx, symbols);
-
-    // Linearize binded dimensions and symbols to construct the resultant
-    // affine expression for this indice.
-    AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
-
-    // Record the load index corresponding to each dimension in the
-    // reassociation group. These are later supplied as operands to the affine
-    // map used for calulating relevant index post op folding.
-    SmallVector<OpFoldResult> dynamicIndices(groupSize);
-    for (int64_t i = 0; i < groupSize; i++)
-      dynamicIndices[i] = indices[groups[i]];
-
-    // Supply suffix product results followed by load op indices as operands
-    // to the map.
-    SmallVector<OpFoldResult> mapOperands;
-    llvm::append_range(mapOperands, suffixProduct);
-    llvm::append_range(mapOperands, dynamicIndices);
-
-    // Creating maximally folded and composed affine.apply composes better
-    // with other transformations without interleaving canonicalization
-    // passes.
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc,
-        AffineMap::get(/*numDims=*/groupSize,
-                       /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
-        mapOperands);
-
-    // Push index value in the op post folding corresponding to this
-    // reassociation group.
-    sourceIndices.push_back(
-        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+    SmallVector<OpFoldResult> groupBasis =
+        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
+    SmallVector<Value> groupIndices =
+        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
+    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
+        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
+    sourceIndices.push_back(collapsedIndex);
   }
   return success();
 }
@@ -167,49 +103,34 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                   memref::CollapseShapeOp collapseShapeOp,
                                   ValueRange indices,
                                   SmallVectorImpl<Value> &sourceIndices) {
-  int64_t cnt = 0;
-  SmallVector<OpFoldResult> dynamicIndices;
-  for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
-    assert(!groups.empty() && "association indices groups cannot be empty");
-    dynamicIndices.push_back(indices[cnt++]);
-    int64_t groupSize = groups.size();
-
-    // Calculate suffix product for all collapse op source dimension sizes
-    // except the most major one of each group.
-    // We allow the most major source dimension to be dynamic but enforce all
-    // others to be known statically.
-    SmallVector<int64_t> sizes(groupSize, 1);
-    for (int64_t i = 1; i < groupSize; ++i) {
-      sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
-      if (sizes[i] == ShapedType::kDynamic)
-        return failure();
-    }
-    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
-
-    // Derive the index values along all dimensions of the source corresponding
-    // to the index wrt to collapsed shape op output.
-    auto d0 = rewriter.getAffineDimExpr(0);
-    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
-
-    // Construct the AffineApplyOp for each delinearizingExpr.
-    for (int64_t i = 0; i < groupSize; i++) {
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, loc,
-          AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
-                         delinearizingExprs[i]),
-          dynamicIndices);
-      sourceIndices.push_back(
-          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+  MemRefType sourceType = collapseShapeOp.getSrcType();
+  // Note: collapse_shape requires a strided memref, we can do this.
+  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+      loc, collapseShapeOp.getSrc());
+  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
+  for (auto [index, group] :
+       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+
+    if (groupSize == 1) {
+      sourceIndices.push_back(index);
+      continue;
     }
-    dynamicIndices.clear();
+
+    SmallVector<OpFoldResult> basis =
+        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        loc, index, basis, /*hasOuterBound=*/true);
+    llvm::append_range(sourceIndices, delinearize.getResults());
   }
   if (collapseShapeOp.getReassociationIndices().empty()) {
     auto zeroAffineMap = rewriter.getConstantAffineMap(0);
     int64_t srcRank =
         cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
+    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
     for (int64_t i = 0; i < srcRank; i++) {
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, zeroAffineMap, dynamicIndices);
       sourceIndices.push_back(
           getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
     }
@@ -513,8 +434,12 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
+  // memref.load and affine.load guarantee that indexes start inbounds
+  // while the vector operations don't. This impacts if our linearization
+  // is `disjoint`
   if (failed(resolveSourceIndicesExpandShape(
-          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
       .Case([&](affine::AffineLoadOp op) {
@@ -676,8 +601,12 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
+  // memref.store and affine.store guarantee that indexes start inbounds
+  // while the vector operations don't. This impacts if our linearization
+  // is `disjoint`
   if (failed(resolveSourceIndicesExpandShape(
-          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
       .Case([&](affine::AffineStoreOp op) {
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index a27fbf26e13d8..106652623933f 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -408,7 +408,6 @@ func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x3
 
 // -----
 
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 6 + s1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 {
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
@@ -416,14 +415,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0
   %1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32>
   return %1 : f32
 }
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (2, 6)
 // CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG3]]] : memref<12x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
 // CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_collapse_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg0 : memref<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -431,15 +428,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg
   %1 = affine.load %0[%arg1, %arg2] : memref<12x32xf32>
   return %1 : f32
 }
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<2x6x32xf32>
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<2x6x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
 // CHECK-LABEL: @fold_dynamic_size_collapse_shape_with_affine_load
 // CHECK-SAME: (%[[ARG0:.*]]: memref<?x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
 func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -447,14 +441,28 @@ func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x
   %1 = affine.load %0[%arg1, %arg2] : memref<?x32xf32>
   return %1 : f32
 }
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<?x6x32xf32>
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x6x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * 6 + s1 * 3 + s2)>
+// CHECK-LABEL: @fold_fully_dynamic_size_collapse_shape_with_affine_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_fully_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : index) -> f32 {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x?x?xf32> into memref<?x?xf32>
+  %1 = affine.load %0[%arg1, %arg2] : memref<?x?xf32>
+  return %1 : f32
+}
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, %[[SIZES]]#1)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x?x?xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+
+// -----
+
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
 // CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 {
@@ -462,7 +470,7 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
   %1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32>
   return %1 : f32
 }
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]], %[[ARG3]]] by (2, 2, 3)
 // CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG4]]] : memref<12x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
@@ -476,7 +484,10 @@ func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x
   %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return %0 : f32
 }
-// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
 // CHECK-NEXT: return %[[VAL1]] : f32
 
 // -----
@@ -490,14 +501,16 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
   memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return
 }
-// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
 // CHECK-NEXT: return
 
 // -----
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
 // CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
 // CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
 func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
@@ -513,21 +526,20 @@ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc:
   }
   return
 }
+// CHECK-NEXT:   %[[C0:.*]] = arith.constant 0...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 7, 2025

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

This PR updates the FoldMemRefAliasOps to use affine.linearize_index and affine.delinearize_index to perform the index computations needed to fold a memref.expand_shape or memref.collapse_shape into its consumers, respectively.

This also loosens some limitations of the pass:

  1. The existing output_shape argument to memref.expand_shape is now used, eliminating the need to re-infer this shape or call memref.dim.
  2. Because we're using affine.delinearize_index, the restriction that each group in a memref.collapse_shape can only have one dynamic dimension is removed.

Patch is 31.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138930.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+10-4)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+49-120)
  • (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+64-65)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index d6d8161d3117b..f34b5b46cab50 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1342,14 +1342,14 @@ def MemRef_ReinterpretCastOp
     according to specified offsets, sizes, and strides.
 
     ```mlir
-    %result1 = memref.reinterpret_cast %arg0 to 
+    %result1 = memref.reinterpret_cast %arg0 to
       offset: [9],
       sizes: [4, 4],
       strides: [16, 2]
     : memref<8x8xf32, strided<[8, 1], offset: 0>> to
       memref<4x4xf32, strided<[16, 2], offset: 9>>
 
-    %result2 = memref.reinterpret_cast %result1 to 
+    %result2 = memref.reinterpret_cast %result1 to
       offset: [0],
       sizes: [2, 2],
       strides: [4, 2]
@@ -1755,6 +1755,12 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
         OpBuilder &b, Location loc, MemRefType expandedType,
         ArrayRef<ReassociationIndices> reassociation,
         ArrayRef<OpFoldResult> inputShape);
+
+    // Return a vector with all the static and dynamic values in the output shape.
+    SmallVector<OpFoldResult> getMixedOutputShape() {
+      OpBuilder builder(getContext());
+      return ::mlir::getMixedValues(getStaticOutputShape(), getOutputShape(), builder);
+    }
   }];
 
   let hasVerifier = 1;
@@ -1873,7 +1879,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
   let summary = "store operation";
   let description = [{
     The `store` op stores an element into a memref at the specified indices.
-    
+
     The number of indices must match the rank of the memref. The indices must
     be in-bounds: `0 <= idx < dim_size`
 
@@ -2025,7 +2031,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     Unlike the `reinterpret_cast`, the values are relative to the strided
     memref of the input (`%result1` in this case) and not its
     underlying memory.
-    
+
     Example 2:
 
     ```mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index e4fb3f9bb87ed..2acb90613e5d1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -59,92 +59,28 @@ using namespace mlir;
 ///
 /// %2 = load %0[6 * i1 + i2, %i3] :
 ///          memref<12x42xf32>
-static LogicalResult
-resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
-                                memref::ExpandShapeOp expandShapeOp,
-                                ValueRange indices,
-                                SmallVectorImpl<Value> &sourceIndices) {
-  // Record the rewriter context for constructing ops later.
-  MLIRContext *ctx = rewriter.getContext();
-
-  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
-  // This is done for the purpose of inferring the output shape via
-  // `inferExpandOutputShape` which will in turn be used for suffix product
-  // calculation later.
-  SmallVector<OpFoldResult> srcShape;
-  MemRefType srcType = expandShapeOp.getSrcType();
-
-  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
-    if (srcType.isDynamicDim(i)) {
-      srcShape.push_back(
-          rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
-              .getResult());
-    } else {
-      srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
-    }
-  }
-
-  auto outputShape = inferExpandShapeOutputShape(
-      rewriter, loc, expandShapeOp.getResultType(),
-      expandShapeOp.getReassociationIndices(), srcShape);
-  if (!outputShape.has_value())
-    return failure();
+static LogicalResult resolveSourceIndicesExpandShape(
+    Location loc, PatternRewriter &rewriter,
+    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
 
   // Traverse all reassociation groups to determine the appropriate indices
   // corresponding to each one of them post op folding.
-  for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
-    assert(!groups.empty() && "association indices groups cannot be empty");
-    // Flag to indicate the presence of dynamic dimensions in current
-    // reassociation group.
-    int64_t groupSize = groups.size();
-
-    // Group output dimensions utilized in this reassociation group for suffix
-    // product calculation.
-    SmallVector<OpFoldResult> sizesVal(groupSize);
-    for (int64_t i = 0; i < groupSize; ++i) {
-      sizesVal[i] = (*outputShape)[groups[i]];
+  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+    if (groupSize == 1) {
+      sourceIndices.push_back(indices[group[0]]);
+      continue;
     }
-
-    // Calculate suffix product of relevant output dimension sizes.
-    SmallVector<OpFoldResult> suffixProduct =
-        memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
-
-    // Create affine expression variables for dimensions and symbols in the
-    // newly constructed affine map.
-    SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
-    bindDimsList<AffineExpr>(ctx, dims);
-    bindSymbolsList<AffineExpr>(ctx, symbols);
-
-    // Linearize binded dimensions and symbols to construct the resultant
-    // affine expression for this indice.
-    AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
-
-    // Record the load index corresponding to each dimension in the
-    // reassociation group. These are later supplied as operands to the affine
-    // map used for calulating relevant index post op folding.
-    SmallVector<OpFoldResult> dynamicIndices(groupSize);
-    for (int64_t i = 0; i < groupSize; i++)
-      dynamicIndices[i] = indices[groups[i]];
-
-    // Supply suffix product results followed by load op indices as operands
-    // to the map.
-    SmallVector<OpFoldResult> mapOperands;
-    llvm::append_range(mapOperands, suffixProduct);
-    llvm::append_range(mapOperands, dynamicIndices);
-
-    // Creating maximally folded and composed affine.apply composes better
-    // with other transformations without interleaving canonicalization
-    // passes.
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc,
-        AffineMap::get(/*numDims=*/groupSize,
-                       /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
-        mapOperands);
-
-    // Push index value in the op post folding corresponding to this
-    // reassociation group.
-    sourceIndices.push_back(
-        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+    SmallVector<OpFoldResult> groupBasis =
+        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
+    SmallVector<Value> groupIndices =
+        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
+    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
+        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
+    sourceIndices.push_back(collapsedIndex);
   }
   return success();
 }
@@ -167,49 +103,34 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                   memref::CollapseShapeOp collapseShapeOp,
                                   ValueRange indices,
                                   SmallVectorImpl<Value> &sourceIndices) {
-  int64_t cnt = 0;
-  SmallVector<OpFoldResult> dynamicIndices;
-  for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
-    assert(!groups.empty() && "association indices groups cannot be empty");
-    dynamicIndices.push_back(indices[cnt++]);
-    int64_t groupSize = groups.size();
-
-    // Calculate suffix product for all collapse op source dimension sizes
-    // except the most major one of each group.
-    // We allow the most major source dimension to be dynamic but enforce all
-    // others to be known statically.
-    SmallVector<int64_t> sizes(groupSize, 1);
-    for (int64_t i = 1; i < groupSize; ++i) {
-      sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
-      if (sizes[i] == ShapedType::kDynamic)
-        return failure();
-    }
-    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
-
-    // Derive the index values along all dimensions of the source corresponding
-    // to the index wrt to collapsed shape op output.
-    auto d0 = rewriter.getAffineDimExpr(0);
-    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
-
-    // Construct the AffineApplyOp for each delinearizingExpr.
-    for (int64_t i = 0; i < groupSize; i++) {
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, loc,
-          AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
-                         delinearizingExprs[i]),
-          dynamicIndices);
-      sourceIndices.push_back(
-          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+  MemRefType sourceType = collapseShapeOp.getSrcType();
+  // Note: collapse_shape requires a strided memref, we can do this.
+  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+      loc, collapseShapeOp.getSrc());
+  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
+  for (auto [index, group] :
+       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+
+    if (groupSize == 1) {
+      sourceIndices.push_back(index);
+      continue;
     }
-    dynamicIndices.clear();
+
+    SmallVector<OpFoldResult> basis =
+        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        loc, index, basis, /*hasOuterBound=*/true);
+    llvm::append_range(sourceIndices, delinearize.getResults());
   }
   if (collapseShapeOp.getReassociationIndices().empty()) {
     auto zeroAffineMap = rewriter.getConstantAffineMap(0);
     int64_t srcRank =
         cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
+    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
     for (int64_t i = 0; i < srcRank; i++) {
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, zeroAffineMap, dynamicIndices);
       sourceIndices.push_back(
           getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
     }
@@ -513,8 +434,12 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
+  // memref.load and affine.load guarantee that indexes start inbounds
+  // while the vector operations don't. This impacts if our linearization
+  // is `disjoint`
   if (failed(resolveSourceIndicesExpandShape(
-          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
       .Case([&](affine::AffineLoadOp op) {
@@ -676,8 +601,12 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
+  // memref.store and affine.store guarantee that indexes start inbounds
+  // while the vector operations don't. This impacts if our linearization
+  // is `disjoint`
   if (failed(resolveSourceIndicesExpandShape(
-          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
       .Case([&](affine::AffineStoreOp op) {
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index a27fbf26e13d8..106652623933f 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -408,7 +408,6 @@ func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x3
 
 // -----
 
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 6 + s1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 {
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
@@ -416,14 +415,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0
   %1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32>
   return %1 : f32
 }
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (2, 6)
 // CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG3]]] : memref<12x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
 // CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_collapse_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg0 : memref<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -431,15 +428,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg
   %1 = affine.load %0[%arg1, %arg2] : memref<12x32xf32>
   return %1 : f32
 }
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<2x6x32xf32>
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<2x6x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
 // CHECK-LABEL: @fold_dynamic_size_collapse_shape_with_affine_load
 // CHECK-SAME: (%[[ARG0:.*]]: memref<?x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
 func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -447,14 +441,28 @@ func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x
   %1 = affine.load %0[%arg1, %arg2] : memref<?x32xf32>
   return %1 : f32
 }
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<?x6x32xf32>
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x6x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * 6 + s1 * 3 + s2)>
+// CHECK-LABEL: @fold_fully_dynamic_size_collapse_shape_with_affine_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_fully_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : index) -> f32 {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x?x?xf32> into memref<?x?xf32>
+  %1 = affine.load %0[%arg1, %arg2] : memref<?x?xf32>
+  return %1 : f32
+}
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, %[[SIZES]]#1)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x?x?xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+
+// -----
+
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
 // CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 {
@@ -462,7 +470,7 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
   %1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32>
   return %1 : f32
 }
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]], %[[ARG3]]] by (2, 2, 3)
 // CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG4]]] : memref<12x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
@@ -476,7 +484,10 @@ func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x
   %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return %0 : f32
 }
-// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
 // CHECK-NEXT: return %[[VAL1]] : f32
 
 // -----
@@ -490,14 +501,16 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
   memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return
 }
-// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
 // CHECK-NEXT: return
 
 // -----
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
 // CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
 // CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
 func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
@@ -513,21 +526,20 @@ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc:
   }
   return
 }
+// CHECK-NEXT:   %[[C0:.*]] = arith.constant 0...
[truncated]

@krzysz00
Copy link
Contributor Author

krzysz00 commented May 9, 2025

Ping

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes look good to me. It isn't strictly required, by given that book h of us work on the same downstream project, does this pass with the said downstream project. But this looks good to me

Base automatically changed from users/krzysz00/linearize-delinearize-dims to main May 13, 2025 16:13
This PR updates the FoldMemRefAliasOps to use `affine.linearize_index`
and `affine.delinearize_index` to perform the index computations
needed to fold a `memref.expand_shape` or `memref.collapse_shape` into
its consumers, respectively.

This also loosens some limitations of the pass:
1. The existing `output_shape` argument to `memref.expand_shape` is
now used, eliminating the need to re-infer this shape or call
`memref.dim`.
2. Because we're using `affine.delinearize_index`, the restriction
that each group in a `memref.collapse_shape` can only have one dynamic
dimension is removed.
@krzysz00 krzysz00 force-pushed the users/krzysz00/linearize-delinearize-alias-ops-folder branch from ae27c6f to 249e426 Compare May 13, 2025 16:16
@krzysz00 krzysz00 merged commit a891163 into main May 13, 2025
9 of 10 checks passed
@krzysz00 krzysz00 deleted the users/krzysz00/linearize-delinearize-alias-ops-folder branch May 13, 2025 18:28
@kazutakahirata
Copy link
Contributor

@krzysz00 I've landed 2534839 to fix a warning from this PR. Thanks!

@krzysz00
Copy link
Contributor Author

Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants