-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][vector] Additional transpose folding #138347
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
76c06a7
to
f8b561d
Compare
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: James Newling (newling) ChangesFold transpose with unit-dimensions. Seen in the wild:
This transpose can be folded because (1) it preserves the shape and (2) the shuffled dims are unit extent. Also addresses comment about static vs anonymous namespace: #135841 (comment) Full diff: https://github.com/llvm/llvm-project/pull/138347.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 0f96442bc3756..dc30741daccd4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5575,13 +5575,11 @@ LogicalResult ShapeCastOp::verify() {
return success();
}
-namespace {
-
/// Return true if `transpose` does not permute a pair of non-unit dims.
/// By `order preserving` we mean that the flattened versions of the input and
/// output vectors are (numerically) identical. In other words `transpose` is
/// effectively a shape cast.
-bool isOrderPreserving(TransposeOp transpose) {
+static bool isOrderPreserving(TransposeOp transpose) {
ArrayRef<int64_t> permutation = transpose.getPermutation();
VectorType sourceType = transpose.getSourceVectorType();
ArrayRef<int64_t> inShape = sourceType.getShape();
@@ -5601,8 +5599,6 @@ bool isOrderPreserving(TransposeOp transpose) {
return true;
}
-} // namespace
-
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
VectorType resultType = getType();
@@ -5999,18 +5995,22 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
return ub::PoisonAttr::get(getContext());
- // Eliminate identity transpose ops. This happens when the dimensions of the
- // input vector remain in their original order after the transpose operation.
- ArrayRef<int64_t> perm = getPermutation();
-
- // Check if the permutation of the dimensions contains sequential values:
- // {0, 1, 2, ...}.
- for (int64_t i = 0, e = perm.size(); i < e; i++) {
- if (perm[i] != i)
- return {};
- }
+ // Eliminate identity transposes, and more generally any transposes that
+ // preserves the shape without permuting elements.
+ //
+ // Examples of what to fold:
+ // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
+ // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
+ // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
+ //
+ // Example of what NOT to fold:
+ // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
+ //
+ if (getSourceVectorType() == getResultVectorType() &&
+ isOrderPreserving(*this))
+ return getVector();
- return getVector();
+ return {};
}
LogicalResult vector::TransposeOp::verify() {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 99f0850000a16..974f4506a2ef0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -450,28 +450,6 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
// -----
-// CHECK-LABEL: transpose_1D_identity
-// CHECK-SAME: ([[ARG:%.*]]: vector<4xf32>)
-func.func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {
- // CHECK-NOT: transpose
- %0 = vector.transpose %arg, [0] : vector<4xf32> to vector<4xf32>
- // CHECK-NEXT: return [[ARG]]
- return %0 : vector<4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: transpose_2D_identity
-// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
-func.func @transpose_2D_identity(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
- // CHECK-NOT: transpose
- %0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
- // CHECK-NEXT: return [[ARG]]
- return %0 : vector<4x3xf32>
-}
-
-// -----
-
// CHECK-LABEL: transpose_3D_identity
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index 91ee0d335ecca..aa227c4628d42 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -247,3 +247,36 @@ func.func @negative_shape_cast_transpose(%arg : vector<6xi8>) -> vector<2x3xi8>
%1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
return %1 : vector<2x3xi8>
}
+
+// -----
+
+// Test of transpose folding
+// CHECK-LABEL: transpose_1D_identity
+// CHECK-SAME: [[ARG:%.*]]: vector<4xf32>
+// CHECK-NEXT: return [[ARG]]
+func.func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {
+ %0 = vector.transpose %arg, [0] : vector<4xf32> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// -----
+
+// Test of transpose folding
+// CHECK-LABEL: transpose_2D_identity
+// CHECK-SAME: [[ARG:%.*]]: vector<4x3xf32>
+// CHECK-NEXT: return [[ARG]]
+func.func @transpose_2D_identity(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
+ %0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
+ return %0 : vector<4x3xf32>
+}
+
+// -----
+
+// Test of transpose folding
+// CHECK-LABEL: transpose_shape_and_order_preserving
+// CHECK-SAME: [[ARG:%.*]]: vector<6x1x1x4xi8>
+// CHECK-NEXT: return [[ARG]]
+func.func @transpose_shape_and_order_preserving(%arg : vector<6x1x1x4xi8>) -> vector<6x1x1x4xi8> {
+ %0 = vector.transpose %arg, [0, 2, 1, 3] : vector<6x1x1x4xi8> to vector<6x1x1x4xi8>
+ return %0 : vector<6x1x1x4xi8>
+}
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 83395504e8c74..a730f217f027d 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -65,13 +65,15 @@ func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32>
return %0 : vector<1x8x8xf32>
}
-// CHECK-LABEL: func @transpose1023_1x1x8x8xf32(
-func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8x8xf32> {
- // Note the single 2-D extract/insert pair since 2 and 3 are not transposed!
- // CHECK: vector.extract {{.*}}[0, 0] : vector<8x8xf32> from vector<1x1x8x8xf32>
- // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x8xf32> into vector<1x1x8x8xf32>
- %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<1x1x8x8xf32> to vector<1x1x8x8xf32>
- return %0 : vector<1x1x8x8xf32>
+// CHECK-LABEL: func @transpose1023_2x1x8x4xf32(
+func.func @transpose1023_2x1x8x4xf32(%arg0: vector<2x1x8x4xf32>) -> vector<1x2x8x4xf32> {
+ // Note the 2-D extract/insert pair since dimensions 2 and 3 are not transposed!
+ // CHECK: vector.extract {{.*}}[0, 0] : vector<8x4xf32> from vector<2x1x8x4xf32>
+ // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x4xf32> into vector<1x2x8x4xf32>
+ // CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8x4xf32> from vector<2x1x8x4xf32>
+ // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8x4xf32> into vector<1x2x8x4xf32>
+ %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<2x1x8x4xf32> to vector<1x2x8x4xf32>
+ return %0 : vector<1x2x8x4xf32>
}
/// Scalable dim should not be unrolled.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've left a couple of minor comments, but nothing major and approving as is - this is clearly an improvement. LGTM!
Btw, this is a very nice clean-up!
// CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8x4xf32> from vector<2x1x8x4xf32> | ||
// CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8x4xf32> into vector<1x2x8x4xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, these were missing before?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've changed the type of the (result) vector from <1x1x..> to <1x2x..>
Why? Because now the <1x1x...> case gets folded, so we end up with return %arg0
. Based on the original comment in the test
"Note the single 2-D extract/insert pair since 2 and 3 are not transposed!"
I assume changing the sizes of dims 0/1 retains the original goal of the test.
|
||
// ----- | ||
|
||
// Test of transpose folding |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have just noticed that you add this note to every test - thanks! Nit suggestion - why not add block comments, e.g.:
llvm-project/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
Lines 3 to 5 in 0009a17
//----------------------------------------------------------------------------- | |
// [Patterns: TransferWriteDropUnitDimsPattern, TransferReadeDropUnitDimsPattern] | |
//----------------------------------------------------------------------------- |
The actual format is secondary - my main suggestion would be to avoid duplicating the same comment and to create "sections" to group similar tests together. Especially given that this file keeps growing 😅
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While working on #139706, I've noticed that these comments were incorrect/flipped (please double check). I prepared #139699 to address that.
I am happy to re-factor the format (yours is more descriptive), or wait for you to update things instead. Please let me know if you have a preference :)
EDIT: Feel free to land this and I will update my PR using your format.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've landed #139699, so you will have to rebase. Sorry about the noise :(
Btw, before merging my PR I updated the block comments to match the format proposed here. I hope that that's OK.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Totally fine, I actually expected you to land it with your new format.. I should have communicated better yesterday.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about adding a negative test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Signed-off-by: James Newling <[email protected]>
98f1d06
to
d92bf2c
Compare
Fold transpose with unit-dimensions. Seen in the wild:
This transpose can be folded because (1) it preserves the shape and (2) the shuffled dims are unit extent.
Also addresses comment about static vs anonymous namespace: #135841 (comment)