[mlir] Support rank-reduced extract_slice in ExtractSliceOfPadTensorSwapPattern #138921
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: Vivian Zhang (yzhang93)

Changes

This PR fixes `ExtractSliceOfPadTensorSwapPattern` to support rank-reducing `tensor.extract_slice` ops, which were previously unhandled and could cause crashes. To support this, an additional `tensor.extract_slice` is inserted after `tensor.pad` to reduce the result rank.

Full diff: https://github.com/llvm/llvm-project/pull/138921.diff

2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
- (modified) mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 6700b4e0c2cb6..8718c57b9e86c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1017,9 +1017,22 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
sliceOp.getMixedSizes(), zeroSliceGuard);
if (failed(tilingResult))
return failure();
- // All shapes are static and the data source is actually used. Rewrite into
- // pad(extract_slice(x)).
- rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
+
+ RankedTensorType sourceType = sliceOp.getSourceType();
+ RankedTensorType resultType = sliceOp.getResultType();
+
+ // If the extract_slice is not rank-reduced, all shapes are static and the
+ // data source is actually used. Rewrite into pad(extract_slice(x)).
+ if (sourceType.getRank() == resultType.getRank()) {
+ rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
+ return success();
+ }
+
+ // Handle rank-reduced slice by creating another extract_slice op.
+ Value rankReduced = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
+
+ rewriter.replaceOp(sliceOp, rankReduced);
return success();
}
diff --git a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
index d43b9a7ac6c04..6a056bab98807 100644
--- a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
+++ b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
@@ -129,6 +129,26 @@ func.func @static_mixed_data_low_high_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
// -----
+// CHECK-LABEL: @static_rank_reduce
+// CHECK-SAME: %[[ARG0:.*]]: tensor<8x16x4xf32>, %[[PADVAL:.*]]: f32
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 14, 4] [1, 1, 1] : tensor<8x16x4xf32> to tensor<1x14x4xf32>
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[SLICE]] low[0, 2, 0] high[0, 0, 0] {
+// CHECK: } : tensor<1x14x4xf32> to tensor<1x16x4xf32>
+// CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[PADDED]][0, 0, 0] [1, 16, 4] [1, 1, 1] : tensor<1x16x4xf32> to tensor<16x4xf32>
+// CHECK: return %[[RESULT]]
+func.func @static_rank_reduce(%arg0: tensor<8x16x4xf32>, %pad: f32)
+ -> tensor<16x4xf32> {
+ %0 = tensor.pad %arg0 low[0, 2, 0] high[0, 0, 0] {
+ ^bb0(%i: index, %j: index, %k: index):
+ tensor.yield %pad : f32
+ } : tensor<8x16x4xf32> to tensor<8x18x4xf32>
+ %1 = tensor.extract_slice %0[0, 0, 0] [1, 16, 4] [1, 1, 1]
+ : tensor<8x18x4xf32> to tensor<16x4xf32>
+ return %1 : tensor<16x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: @dynamic_high_pad
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x5xf32>
// CHECK-NOT: tensor.pad
@@ -217,6 +237,27 @@ func.func @dynamic_zero_high_padding(%arg0 : tensor<?x?xf32>, %pad : f32,
return %1 : tensor<?x?xf32>
}
+// -----
+
+// CHECK-LABEL: @dynamic_rank_reduce
+// CHECK: %[[TEMP:.*]] = scf.if %{{.*}} -> (tensor<1x4xf32>) {
+// CHECK: tensor.generate
+// CHECK: } else {
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %{{.*}} : tensor<?x5xf32> to tensor<?x1xf32>
+// CHECK: tensor.pad %[[SLICE]] low[0, 0] high[%{{.*}}, 3] {
+// CHECK: } : tensor<?x1xf32> to tensor<1x4xf32>
+// CHECK: }
+// CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[TEMP]]{{.*}} : tensor<1x4xf32> to tensor<4xf32>
+// CHECK: return %[[RESULT]]
+func.func @dynamic_rank_reduce(%arg0 : tensor<?x5xf32>, %s1: index, %pad : f32) -> tensor<4xf32> {
+ %0 = tensor.pad %arg0 low[0, 0] high[7, 8] {
+ ^bb0(%arg1: index, %arg2: index):
+ tensor.yield %pad : f32
+ } : tensor<?x5xf32> to tensor<?x13xf32>
+ %1 = tensor.extract_slice %0[2, 4] [1, 4] [1, 1] : tensor<?x13xf32> to tensor<4xf32>
+ return %1 : tensor<4xf32>
+}
+
// -----
// CHECK-LABEL: @nopaddim_with_dynamic_extract(
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4x5xf32>
This PR fixes `ExtractSliceOfPadTensorSwapPattern` to support rank-reducing `tensor.extract_slice` ops, which were previously unhandled and could cause crashes. To support this, an additional `tensor.extract_slice` is inserted after `tensor.pad` to reduce the result rank.
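For illustration, here is roughly how the rewrite plays out on the static test case added in this PR (abridged from the new CHECK lines; the pad body is elided):

```mlir
// Before: a rank-reducing extract_slice of a padded tensor
// (the result drops the leading unit dimension).
%0 = tensor.pad %arg0 low[0, 2, 0] high[0, 0, 0] {
^bb0(%i: index, %j: index, %k: index):
  tensor.yield %pad : f32
} : tensor<8x16x4xf32> to tensor<8x18x4xf32>
%1 = tensor.extract_slice %0[0, 0, 0] [1, 16, 4] [1, 1, 1]
    : tensor<8x18x4xf32> to tensor<16x4xf32>

// After: slice the source, pad the slice, then insert a second
// extract_slice that performs the rank reduction.
%slice  = tensor.extract_slice %arg0[0, 0, 0] [1, 14, 4] [1, 1, 1]
    : tensor<8x16x4xf32> to tensor<1x14x4xf32>
%padded = tensor.pad %slice low[0, 2, 0] high[0, 0, 0] { ... }
    : tensor<1x14x4xf32> to tensor<1x16x4xf32>
%result = tensor.extract_slice %padded[0, 0, 0] [1, 16, 4] [1, 1, 1]
    : tensor<1x16x4xf32> to tensor<16x4xf32>
```

The second `extract_slice` covers the whole padded tensor and only drops the unit dimension, so it simply casts the result back to the rank-reduced type of the original op.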