
[mlir] Support rank-reduced extract_slice in ExtractSliceOfPadTensorSwapPattern #138921


Merged: 1 commit merged into llvm:main on May 8, 2025

Conversation

@yzhang93 (Contributor) commented May 7, 2025

This PR fixes ExtractSliceOfPadTensorSwapPattern to support rank-reducing tensor.extract_slice ops, which were previously unhandled and could cause crashes. To support this, an additional rank-reducing tensor.extract_slice is inserted after tensor.pad to reduce the result rank.
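
A minimal before/after sketch of the rewrite, mirroring the @static_rank_reduce test added in this PR:

```mlir
// Before: a rank-reducing extract_slice of a padded tensor
// (tensor<8x18x4xf32> sliced down to tensor<16x4xf32>).
%0 = tensor.pad %arg0 low[0, 2, 0] high[0, 0, 0] {
  ^bb0(%i: index, %j: index, %k: index):
    tensor.yield %pad : f32
} : tensor<8x16x4xf32> to tensor<8x18x4xf32>
%1 = tensor.extract_slice %0[0, 0, 0] [1, 16, 4] [1, 1, 1]
    : tensor<8x18x4xf32> to tensor<16x4xf32>

// After: slice only the source data that is actually read, pad the smaller
// tensor, then apply a rank-reducing extract_slice to recover the original
// result type.
%slice = tensor.extract_slice %arg0[0, 0, 0] [1, 14, 4] [1, 1, 1]
    : tensor<8x16x4xf32> to tensor<1x14x4xf32>
%padded = tensor.pad %slice low[0, 2, 0] high[0, 0, 0] {
  ^bb0(%i: index, %j: index, %k: index):
    tensor.yield %pad : f32
} : tensor<1x14x4xf32> to tensor<1x16x4xf32>
%result = tensor.extract_slice %padded[0, 0, 0] [1, 16, 4] [1, 1, 1]
    : tensor<1x16x4xf32> to tensor<16x4xf32>
```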

@llvmbot (Member) commented May 7, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Vivian Zhang (yzhang93)

Changes

This PR fixes ExtractSliceOfPadTensorSwapPattern to support rank-reducing tensor.extract_slice ops, which were previously unhandled and could cause crashes. To support this, an additional rank-reduced tensor.extract_slice is inserted after tensor.pad to reduce the result rank.
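
For context, a sketch of the swap this pattern already performed for rank-preserving slices; the shapes here are hypothetical, modeled on the existing tests in subtensor-of-padtensor.mlir:

```mlir
// Before: the extract_slice reads both the low padding (indices 0-1 along
// dim 1) and source data.
%0 = tensor.pad %arg0 low[0, 2] high[0, 0] {
  ^bb0(%i: index, %j: index):
    tensor.yield %pad : f32
} : tensor<4x5xf32> to tensor<4x7xf32>
%1 = tensor.extract_slice %0[0, 0] [4, 6] [1, 1]
    : tensor<4x7xf32> to tensor<4x6xf32>

// After: pad(extract_slice(x)) -- slice the used source data first, then
// pad. No trailing extract_slice is needed because no rank reduction occurs.
%slice = tensor.extract_slice %arg0[0, 0] [4, 4] [1, 1]
    : tensor<4x5xf32> to tensor<4x4xf32>
%2 = tensor.pad %slice low[0, 2] high[0, 0] {
  ^bb0(%i: index, %j: index):
    tensor.yield %pad : f32
} : tensor<4x4xf32> to tensor<4x6xf32>
```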


Full diff: https://github.com/llvm/llvm-project/pull/138921.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+16-3)
  • (modified) mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir (+41)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 6700b4e0c2cb6..8718c57b9e86c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1017,9 +1017,22 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
                                sliceOp.getMixedSizes(), zeroSliceGuard);
   if (failed(tilingResult))
     return failure();
-  // All shapes are static and the data source is actually used. Rewrite into
-  // pad(extract_slice(x)).
-  rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
+
+  RankedTensorType sourceType = sliceOp.getSourceType();
+  RankedTensorType resultType = sliceOp.getResultType();
+
+  // If the extract_slice is not rank-reduced, all shapes are static and the
+  // data source is actually used. Rewrite into pad(extract_slice(x)).
+  if (sourceType.getRank() == resultType.getRank()) {
+    rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
+    return success();
+  }
+
+  // Handle rank-reduced slice by creating another extract_slice op.
+  Value rankReduced = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
+
+  rewriter.replaceOp(sliceOp, rankReduced);
   return success();
 }
 
diff --git a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
index d43b9a7ac6c04..6a056bab98807 100644
--- a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
+++ b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
@@ -129,6 +129,26 @@ func.func @static_mixed_data_low_high_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
 
 // -----
 
+// CHECK-LABEL: @static_rank_reduce
+//  CHECK-SAME:   %[[ARG0:.*]]: tensor<8x16x4xf32>, %[[PADVAL:.*]]: f32
+//       CHECK:   %[[SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 14, 4] [1, 1, 1] : tensor<8x16x4xf32> to tensor<1x14x4xf32>
+//       CHECK:   %[[PADDED:.*]] = tensor.pad %[[SLICE]] low[0, 2, 0] high[0, 0, 0] {
+//       CHECK:   } : tensor<1x14x4xf32> to tensor<1x16x4xf32>
+//       CHECK:   %[[RESULT:.*]] = tensor.extract_slice %[[PADDED]][0, 0, 0] [1, 16, 4] [1, 1, 1] : tensor<1x16x4xf32> to tensor<16x4xf32>
+//       CHECK: return %[[RESULT]]
+func.func @static_rank_reduce(%arg0: tensor<8x16x4xf32>, %pad: f32)
+    -> tensor<16x4xf32> {
+  %0 = tensor.pad %arg0 low[0, 2, 0] high[0, 0, 0] {
+    ^bb0(%i: index, %j: index, %k: index):
+      tensor.yield %pad : f32
+  } : tensor<8x16x4xf32> to tensor<8x18x4xf32>
+  %1 = tensor.extract_slice %0[0, 0, 0] [1, 16, 4] [1, 1, 1]
+      : tensor<8x18x4xf32> to tensor<16x4xf32>
+  return %1 : tensor<16x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @dynamic_high_pad
 //  CHECK-SAME:     %[[ARG0:.*]]: tensor<?x5xf32>
 //   CHECK-NOT:   tensor.pad
@@ -217,6 +237,27 @@ func.func @dynamic_zero_high_padding(%arg0 : tensor<?x?xf32>, %pad : f32,
   return %1 : tensor<?x?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: @dynamic_rank_reduce
+//       CHECK:   %[[TEMP:.*]] = scf.if %{{.*}} -> (tensor<1x4xf32>) {
+//       CHECK:     tensor.generate
+//       CHECK:   } else {
+//       CHECK:     %[[SLICE:.*]] = tensor.extract_slice %{{.*}} : tensor<?x5xf32> to tensor<?x1xf32>
+//       CHECK:     tensor.pad %[[SLICE]] low[0, 0] high[%{{.*}}, 3] {
+//       CHECK:     } : tensor<?x1xf32> to tensor<1x4xf32>
+//       CHECK:   }
+//       CHECK:   %[[RESULT:.*]] = tensor.extract_slice %[[TEMP]]{{.*}} : tensor<1x4xf32> to tensor<4xf32>
+//       CHECK:   return %[[RESULT]]
+func.func @dynamic_rank_reduce(%arg0 : tensor<?x5xf32>, %s1: index, %pad : f32) -> tensor<4xf32> {
+  %0 = tensor.pad %arg0 low[0, 0] high[7, 8] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad : f32
+    } : tensor<?x5xf32> to tensor<?x13xf32>
+  %1 = tensor.extract_slice %0[2, 4] [1, 4] [1, 1] : tensor<?x13xf32> to tensor<4xf32>
+  return %1 : tensor<4xf32>
+}
+
 // -----
 // CHECK-LABEL: @nopaddim_with_dynamic_extract(
 //  CHECK-SAME:     %[[ARG0:.*]]: tensor<3x4x5xf32>

@yzhang93 merged commit 37fecfa into llvm:main on May 8, 2025. 14 checks passed.