[mlir][vector] Additional transpose folding #138347

Merged: 4 commits merged on May 14, 2025

Conversation

newling
Contributor

@newling newling commented May 2, 2025

Fold transposes that only permute unit dimensions. Seen in the wild:

 %0 = vector.transpose %arg, [0, 2, 1, 3] : vector<6x1x1x4xi8> to vector<6x1x1x4xi8>

This transpose can be folded because (1) it preserves the shape and (2) the only dims it permutes have unit extent.

Also addresses comment about static vs anonymous namespace: #135841 (comment)
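
For readers without an MLIR checkout, the fold condition described above can be sketched in plain C++. This is an illustration under assumptions, not the code in VectorOps.cpp: the names `transposedShape`, `keepsNonUnitDimOrder` and `canFoldTranspose` are hypothetical, shapes are plain `std::vector<int64_t>`, and scalable dimensions are ignored entirely.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Shape produced by a transpose: result dim i takes its extent from
// source dim perm[i].
Shape transposedShape(const Shape &shape, const std::vector<int64_t> &perm) {
  Shape result(perm.size());
  for (std::size_t i = 0; i < perm.size(); ++i)
    result[i] = shape[perm[i]];
  return result;
}

// True if the non-unit source dims appear in increasing order in `perm`,
// i.e. no pair of non-unit dims is swapped. Unit dims may move freely.
// (Hypothetical helper; scalable dims are not modelled.)
bool keepsNonUnitDimOrder(const Shape &shape,
                          const std::vector<int64_t> &perm) {
  int64_t lastNonUnit = -1;
  for (int64_t p : perm) {
    if (shape[p] == 1)
      continue; // moving a unit dim never reorders elements
    if (p < lastNonUnit)
      return false; // two non-unit dims were swapped
    lastNonUnit = p;
  }
  return true;
}

// Both conditions from the description: (1) same shape, (2) only unit
// dims are permuted.
bool canFoldTranspose(const Shape &shape, const std::vector<int64_t> &perm) {
  return transposedShape(shape, perm) == shape &&
         keepsNonUnitDimOrder(shape, perm);
}
```

With this sketch, `canFoldTranspose({6, 1, 1, 4}, {0, 2, 1, 3})` holds for the motivating example, while `canFoldTranspose({2, 2}, {1, 0})` does not, because the two non-unit dims are swapped.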

@newling newling force-pushed the further_vector_transpose_folding branch from 76c06a7 to f8b561d Compare May 5, 2025 23:42
@newling newling marked this pull request as ready for review May 6, 2025 06:42
@llvmbot
Member

llvmbot commented May 6, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/138347.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+16-16)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (-22)
  • (modified) mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir (+33)
  • (modified) mlir/test/Dialect/Vector/vector-transpose-lowering.mlir (+9-7)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 0f96442bc3756..dc30741daccd4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5575,13 +5575,11 @@ LogicalResult ShapeCastOp::verify() {
   return success();
 }
 
-namespace {
-
 /// Return true if `transpose` does not permute a pair of non-unit dims.
 /// By `order preserving` we mean that the flattened versions of the input and
 /// output vectors are (numerically) identical. In other words `transpose` is
 /// effectively a shape cast.
-bool isOrderPreserving(TransposeOp transpose) {
+static bool isOrderPreserving(TransposeOp transpose) {
   ArrayRef<int64_t> permutation = transpose.getPermutation();
   VectorType sourceType = transpose.getSourceVectorType();
   ArrayRef<int64_t> inShape = sourceType.getShape();
@@ -5601,8 +5599,6 @@ bool isOrderPreserving(TransposeOp transpose) {
   return true;
 }
 
-} // namespace
-
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   VectorType resultType = getType();
@@ -5999,18 +5995,22 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
     return ub::PoisonAttr::get(getContext());
 
-  // Eliminate identity transpose ops. This happens when the dimensions of the
-  // input vector remain in their original order after the transpose operation.
-  ArrayRef<int64_t> perm = getPermutation();
-
-  // Check if the permutation of the dimensions contains sequential values:
-  // {0, 1, 2, ...}.
-  for (int64_t i = 0, e = perm.size(); i < e; i++) {
-    if (perm[i] != i)
-      return {};
-  }
+  // Eliminate identity transposes, and more generally any transpose that
+  // preserves the shape without permuting elements.
+  //
+  // Examples of what to fold:
+  // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
+  // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
+  // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
+  //
+  // Example of what NOT to fold:
+  // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
+  //
+  if (getSourceVectorType() == getResultVectorType() &&
+      isOrderPreserving(*this))
+    return getVector();
 
-  return getVector();
+  return {};
 }
 
 LogicalResult vector::TransposeOp::verify() {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 99f0850000a16..974f4506a2ef0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -450,28 +450,6 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
 
 // -----
 
-// CHECK-LABEL: transpose_1D_identity
-// CHECK-SAME: ([[ARG:%.*]]: vector<4xf32>)
-func.func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {
-  // CHECK-NOT: transpose
-  %0 = vector.transpose %arg, [0] : vector<4xf32> to vector<4xf32>
-  // CHECK-NEXT: return [[ARG]]
-  return %0 : vector<4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: transpose_2D_identity
-// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
-func.func @transpose_2D_identity(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
-  // CHECK-NOT: transpose
-  %0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
-  // CHECK-NEXT: return [[ARG]]
-  return %0 : vector<4x3xf32>
-}
-
-// -----
-
 // CHECK-LABEL: transpose_3D_identity
 // CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
 func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index 91ee0d335ecca..aa227c4628d42 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -247,3 +247,36 @@ func.func @negative_shape_cast_transpose(%arg : vector<6xi8>) -> vector<2x3xi8>
   %1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
   return %1 : vector<2x3xi8>
 }
+
+// -----
+
+// Test of transpose folding
+//  CHECK-LABEL: transpose_1D_identity
+//   CHECK-SAME:   [[ARG:%.*]]: vector<4xf32>
+//   CHECK-NEXT:   return [[ARG]]
+func.func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {
+  %0 = vector.transpose %arg, [0] : vector<4xf32> to vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+// Test of transpose folding
+// CHECK-LABEL: transpose_2D_identity
+//  CHECK-SAME:    [[ARG:%.*]]: vector<4x3xf32>
+//  CHECK-NEXT:    return [[ARG]]
+func.func @transpose_2D_identity(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
+  %0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
+  return %0 : vector<4x3xf32>
+}
+
+// -----
+
+// Test of transpose folding
+// CHECK-LABEL: transpose_shape_and_order_preserving
+//  CHECK-SAME:    [[ARG:%.*]]: vector<6x1x1x4xi8>
+//  CHECK-NEXT:    return [[ARG]]
+func.func @transpose_shape_and_order_preserving(%arg : vector<6x1x1x4xi8>) -> vector<6x1x1x4xi8> {
+  %0 = vector.transpose %arg, [0, 2, 1, 3] : vector<6x1x1x4xi8> to vector<6x1x1x4xi8>
+  return %0 : vector<6x1x1x4xi8>
+}
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 83395504e8c74..a730f217f027d 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -65,13 +65,15 @@ func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32>
   return %0 : vector<1x8x8xf32>
 }
 
-// CHECK-LABEL:   func @transpose1023_1x1x8x8xf32(
-func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8x8xf32> {
-  // Note the single 2-D extract/insert pair since 2 and 3 are not transposed!
-  //      CHECK: vector.extract {{.*}}[0, 0] : vector<8x8xf32> from vector<1x1x8x8xf32>
-  // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x8xf32> into vector<1x1x8x8xf32>
-  %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<1x1x8x8xf32> to vector<1x1x8x8xf32>
-  return %0 : vector<1x1x8x8xf32>
+// CHECK-LABEL:   func @transpose1023_2x1x8x4xf32(
+func.func @transpose1023_2x1x8x4xf32(%arg0: vector<2x1x8x4xf32>) -> vector<1x2x8x4xf32> {
+  // Note the 2-D extract/insert pair since dimensions 2 and 3 are not transposed!
+  //      CHECK: vector.extract {{.*}}[0, 0] : vector<8x4xf32> from vector<2x1x8x4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x4xf32> into vector<1x2x8x4xf32>
+  // CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8x4xf32> from vector<2x1x8x4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8x4xf32> into vector<1x2x8x4xf32>
+  %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<2x1x8x4xf32> to vector<1x2x8x4xf32>
+  return %0 : vector<1x2x8x4xf32>
 }
 
 /// Scalable dim should not be unrolled.

Contributor

@banach-space banach-space left a comment

I've left a couple of minor comments, but nothing major and approving as is - this is clearly an improvement. LGTM!

Btw, this is a very nice clean-up!

Comment on lines +73 to +74
// CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8x4xf32> from vector<2x1x8x4xf32>
// CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8x4xf32> into vector<1x2x8x4xf32>
Contributor

IIUC, these were missing before?

Contributor Author

I've changed the type of the (result) vector from <1x1x..> to <1x2x..>.

Why? Because the <1x1x...> case now gets folded, so we end up with return %arg0. Based on the original comment in the test,

"Note the single 2-D extract/insert pair since 2 and 3 are not transposed!"

I assume changing the sizes of dims 0/1 retains the original goal of the test.
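
A brute-force check can make the difference between the old and new test shapes concrete. The sketch below is an illustration only: `transposeFlat` and `flatOrderUnchanged` are hypothetical helpers, data is stored row-major in a `std::vector<int>`, and the data length is assumed to equal the product of the shape. It transposes a flattened buffer and compares the result with the input.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Transpose row-major `data` of the given `shape`; result dim i is source
// dim perm[i]. Returns the row-major flattening of the result.
// Assumes data.size() equals the product of the extents in `shape`.
std::vector<int> transposeFlat(const std::vector<int> &data,
                               const std::vector<int64_t> &shape,
                               const std::vector<int64_t> &perm) {
  const int rank = static_cast<int>(shape.size());
  // Row-major strides of the source.
  std::vector<int64_t> srcStride(shape.size(), 1);
  for (int d = rank - 2; d >= 0; --d)
    srcStride[d] = srcStride[d + 1] * shape[d + 1];

  std::vector<int64_t> index(shape.size(), 0); // multi-index into the result
  std::vector<int> out;
  out.reserve(data.size());
  for (std::size_t n = 0; n < data.size(); ++n) {
    // Result index i addresses source dim perm[i].
    int64_t srcOffset = 0;
    for (int i = 0; i < rank; ++i)
      srcOffset += index[i] * srcStride[perm[i]];
    out.push_back(data[srcOffset]);
    // Advance the result multi-index in row-major order.
    for (int i = rank - 1; i >= 0; --i) {
      if (++index[i] < shape[perm[i]])
        break;
      index[i] = 0;
    }
  }
  return out;
}

// True if transposing leaves the flattened element order unchanged.
bool flatOrderUnchanged(const std::vector<int64_t> &shape,
                        const std::vector<int64_t> &perm) {
  int64_t count = 1;
  for (int64_t extent : shape)
    count *= extent;
  std::vector<int> data(static_cast<std::size_t>(count));
  std::iota(data.begin(), data.end(), 0); // distinct values 0, 1, 2, ...
  return transposeFlat(data, shape, perm) == data;
}
```

Under these assumptions, `flatOrderUnchanged` holds for both `{1, 1, 8, 8}` and `{2, 1, 8, 4}` with permutation `{1, 0, 2, 3}`. The second case still changes the vector type (2x1x8x4 to 1x2x8x4), so the shape-preservation half of the fold condition fails, which is consistent with the updated test still being lowered rather than folded.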


// -----

// Test of transpose folding
Contributor

I have just noticed that you add this note to every test - thanks! Nit suggestion - why not add block comments, e.g.:

//-----------------------------------------------------------------------------
// [Patterns: TransferWriteDropUnitDimsPattern, TransferReadeDropUnitDimsPattern]
//-----------------------------------------------------------------------------

The actual format is secondary - my main suggestion would be to avoid duplicating the same comment and to create "sections" to group similar tests together. Especially given that this file keeps growing 😅

Contributor

@banach-space banach-space May 13, 2025

While working on #139706, I've noticed that these comments were incorrect/flipped (please double check). I prepared #139699 to address that.

I am happy to re-factor the format (yours is more descriptive), or wait for you to update things instead. Please let me know if you have a preference :)

EDIT: Feel free to land this and I will update my PR using your format.

Contributor

I've landed #139699, so you will have to rebase. Sorry about the noise :(

Btw, before merging my PR I updated the block comments to match the format proposed here. I hope that that's OK.

Contributor Author

Totally fine, I actually expected you to land it with your new format. I should have communicated better yesterday.

Contributor

How about adding a negative test?

Contributor

@dcaballe dcaballe left a comment

LGTM, thanks!

@newling newling force-pushed the further_vector_transpose_folding branch from 98f1d06 to d92bf2c Compare May 14, 2025 19:25
@newling newling merged commit 21f1a61 into llvm:main May 14, 2025
11 checks passed