diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index f6c3c6a61afb6..79bf87ccd34af 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5573,13 +5573,11 @@ LogicalResult ShapeCastOp::verify() { return success(); } -namespace { - /// Return true if `transpose` does not permute a pair of non-unit dims. /// By `order preserving` we mean that the flattened versions of the input and /// output vectors are (numerically) identical. In other words `transpose` is /// effectively a shape cast. -bool isOrderPreserving(TransposeOp transpose) { +static bool isOrderPreserving(TransposeOp transpose) { ArrayRef permutation = transpose.getPermutation(); VectorType sourceType = transpose.getSourceVectorType(); ArrayRef inShape = sourceType.getShape(); @@ -5599,8 +5597,6 @@ bool isOrderPreserving(TransposeOp transpose) { return true; } -} // namespace - OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { VectorType resultType = getType(); @@ -5997,18 +5993,22 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) { if (llvm::dyn_cast_if_present(adaptor.getVector())) return ub::PoisonAttr::get(getContext()); - // Eliminate identity transpose ops. This happens when the dimensions of the - // input vector remain in their original order after the transpose operation. - ArrayRef perm = getPermutation(); - - // Check if the permutation of the dimensions contains sequential values: - // {0, 1, 2, ...}. - for (int64_t i = 0, e = perm.size(); i < e; i++) { - if (perm[i] != i) - return {}; - } + // Eliminate identity transposes, and more generally any transposes that + // preserves the shape without permuting elements. + // + // Examples of what to fold: + // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8> + // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8> + // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8> + // + // Example of what NOT to fold: + // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8> + // + if (getSourceVectorType() == getResultVectorType() && + isOrderPreserving(*this)) + return getVector(); - return getVector(); + return {}; } LogicalResult vector::TransposeOp::verify() { diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 99f0850000a16..974f4506a2ef0 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -450,28 +450,6 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>, // ----- -// CHECK-LABEL: transpose_1D_identity -// CHECK-SAME: ([[ARG:%.*]]: vector<4xf32>) -func.func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> { - // CHECK-NOT: transpose - %0 = vector.transpose %arg, [0] : vector<4xf32> to vector<4xf32> - // CHECK-NEXT: return [[ARG]] - return %0 : vector<4xf32> -} - -// ----- - -// CHECK-LABEL: transpose_2D_identity -// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>) -func.func @transpose_2D_identity(%arg : vector<4x3xf32>) -> vector<4x3xf32> { - // CHECK-NOT: transpose - %0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32> - // CHECK-NEXT: return [[ARG]] - return %0 : vector<4x3xf32> -} - -// ----- - // CHECK-LABEL: transpose_3D_identity // CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>) func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> { diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir index 7d8daec4dcba7..c84aea6609665 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir @@ -1,6 +1,10 @@ // RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s -// This file contains some canonicalizations tests involving vector.transpose. +// This file contains some tests of canonicalizations and foldings involving vector.transpose. + +// +--------------------------------------------------------------------------- +// Tests of FoldTransposeBroadcast: transpose(broadcast) -> broadcast +// +--------------------------------------------------------------------------- // CHECK-LABEL: func @transpose_scalar_broadcast1 // CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>) @@ -248,3 +252,47 @@ func.func @negative_transpose_of_shape_cast(%arg : vector<6xi8>) -> vector<2x3xi %1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8> return %1 : vector<2x3xi8> } + +// ----- + +// +----------------------------------- +// Tests of TransposeOp::fold +// +----------------------------------- + +// CHECK-LABEL: transpose_1D_identity +// CHECK-SAME: [[ARG:%.*]]: vector<4xf32> +// CHECK-NEXT: return [[ARG]] +func.func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> { + %0 = vector.transpose %arg, [0] : vector<4xf32> to vector<4xf32> + return %0 : vector<4xf32> +} + +// ----- + +// CHECK-LABEL: transpose_2D_identity +// CHECK-SAME: [[ARG:%.*]]: vector<4x3xf32> +// CHECK-NEXT: return [[ARG]] +func.func @transpose_2D_identity(%arg : vector<4x3xf32>) -> vector<4x3xf32> { + %0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32> + return %0 : vector<4x3xf32> +} + +// ----- + +// CHECK-LABEL: transpose_shape_and_order_preserving +// CHECK-SAME: [[ARG:%.*]]: vector<6x1x1x4xi8> +// CHECK-NEXT: return [[ARG]] +func.func @transpose_shape_and_order_preserving(%arg : vector<6x1x1x4xi8>) -> vector<6x1x1x4xi8> { + %0 = vector.transpose %arg, [0, 2, 1, 3] : vector<6x1x1x4xi8> to vector<6x1x1x4xi8> + return %0 : vector<6x1x1x4xi8> +} + +// ----- + +// CHECK-LABEL: negative_transpose_fold +// CHECK: [[TRANSP:%.*]] = vector.transpose +// CHECK: return [[TRANSP]] +func.func @negative_transpose_fold(%arg : vector<2x2xi8>) -> vector<2x2xi8> { + %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8> + return %0 : vector<2x2xi8> +} diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir index 83395504e8c74..a730f217f027d 100644 --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -65,13 +65,15 @@ func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> return %0 : vector<1x8x8xf32> } -// CHECK-LABEL: func @transpose1023_1x1x8x8xf32( -func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8x8xf32> { - // Note the single 2-D extract/insert pair since 2 and 3 are not transposed! - // CHECK: vector.extract {{.*}}[0, 0] : vector<8x8xf32> from vector<1x1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x8xf32> into vector<1x1x8x8xf32> - %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<1x1x8x8xf32> to vector<1x1x8x8xf32> - return %0 : vector<1x1x8x8xf32> +// CHECK-LABEL: func @transpose1023_2x1x8x4xf32( +func.func @transpose1023_2x1x8x4xf32(%arg0: vector<2x1x8x4xf32>) -> vector<1x2x8x4xf32> { + // Note the 2-D extract/insert pair since dimensions 2 and 3 are not transposed! + // CHECK: vector.extract {{.*}}[0, 0] : vector<8x4xf32> from vector<2x1x8x4xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x4xf32> into vector<1x2x8x4xf32> + // CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8x4xf32> from vector<2x1x8x4xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8x4xf32> into vector<1x2x8x4xf32> + %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<2x1x8x4xf32> to vector<1x2x8x4xf32> + return %0 : vector<1x2x8x4xf32> } /// Scalable dim should not be unrolled.