diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index cf9bb3a000050..156c679c5039e 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -153,7 +153,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
 
   if (inVecType.getShape().empty()) {
     Value zerodSplat =
-        rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
+        rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
     Value scalarIn =
         rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
     Value scalarExt =
@@ -166,7 +166,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
   VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                       outType.getElementType());
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
 
   if (inVecType.getRank() > 1) {
     inVecType = VectorType::get(SmallVector<int64_t>{numElements},
@@ -315,7 +315,7 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
   VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                       outVecType.getElementType());
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
 
   if (inVectorTy.getRank() > 1) {
     inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -383,7 +383,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   int64_t numElements = outVecType.getNumElements();
   Value zero = rewriter.createOrFold<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+  Value result =
+      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
 
   if (inVectorTy.getRank() > 1) {
     inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -478,8 +479,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
   VectorType extScaleResultType = VectorType::get(opWidth, outType);
 
   if (!outVecType) {
-    Value inCast =
-        rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+    Value inCast = rewriter.create<vector::BroadcastOp>(
+        loc, VectorType::get(1, inType), in);
     // TODO: replace this with non-packed ScaledExtOp
     Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
         loc, extScaleResultType, inCast, scale, 0);
@@ -509,7 +510,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
 
   Value zero = rewriter.create<arith::ConstantOp>(
       loc, outType, rewriter.getFloatAttr(outType, 0.0));
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+  Value result =
+      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
 
   for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
     SmallVector<int64_t> strides(offsets.size(), 1);
@@ -523,7 +525,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
 
     VectorType blockResultType = VectorType::get(blockSize, outType);
     Value blockResult =
-        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
     for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
          i < blockSize;
@@ -587,7 +589,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   if (!outVecType) {
     Type inVecType = VectorType::get(1, inType);
-    Value inCast = rewriter.create<vector::SplatOp>(loc, inVecType, in);
+    Value inCast = rewriter.create<vector::BroadcastOp>(loc, inVecType, in);
     // TODO: replace this with non-packed ScaledTruncOp
     Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
         loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
@@ -616,7 +618,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   int64_t blockSize = computeProduct(ratio);
 
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+  Value result =
+      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
 
   for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
     SmallVector<int64_t> strides(offsets.size(), 1);
@@ -630,7 +633,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
     VectorType blockResultType = VectorType::get(blockSize, outType);
     Value blockResult =
-        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
     for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
          i < blockSize;
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 18adaa793787c..9a8eb72d72925 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -792,7 +792,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
       op.getLoc(), vectorType.getElementType(),
       rewriter.getZeroAttr(vectorType.getElementType()));
   Value result =
-      rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
+      rewriter.create<vector::BroadcastOp>(op.getLoc(), vectorType, fill);
 
   bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 501d98862672d..345da05b99e54 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1408,7 +1408,7 @@ struct VectorScalableExtractOpLowering
 /// ```
 /// is rewritten into:
 /// ```
-///  %r = splat %f0: vector<2x4xf32>
+///  %r = vector.broadcast %f0 : f32 to vector<2x4xf32>
 ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
 ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
 ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
@@ -1441,7 +1441,7 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
     auto elemType = vType.getElementType();
     Value zero = rewriter.create<arith::ConstantOp>(
         loc, elemType, rewriter.getZeroAttr(elemType));
-    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
+    Value desc = rewriter.create<vector::BroadcastOp>(loc, vType, zero);
     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
       Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
       Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
@@ -1583,7 +1583,7 @@ class VectorCreateMaskOpConversion
                                    /*isScalable=*/true));
     auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
                                                  adaptor.getOperands()[0]);
-    Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+    Value bounds = rewriter.create<BroadcastOp>(loc, indices.getType(), bound);
     Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 indices, bounds);
     rewriter.replaceOp(op, comp);
@@ -1767,63 +1767,77 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp>
   }
 };
 
-/// The Splat operation is lowered to an insertelement + a shufflevector
-/// operation. Splat to only 0-d and 1-d vector result types are lowered.
-struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
-  using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
+/// A broadcast of a scalar is lowered to an insertelement + a shufflevector
+/// operation. Only broadcasts to 0-d and 1-d vectors are lowered by this
+/// pattern; the higher-rank cases are handled by another pattern.
+struct VectorBroadcastScalarToLowRankLowering
+    : public ConvertOpToLLVMPattern<vector::BroadcastOp> {
+  using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+  matchAndRewrite(vector::BroadcastOp broadcast, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType resultType = cast<VectorType>(splatOp.getType());
+    if (isa<VectorType>(broadcast.getSourceType()))
+      return rewriter.notifyMatchFailure(
+          broadcast, "broadcast from vector type not handled");
+
+    VectorType resultType = broadcast.getType();
     if (resultType.getRank() > 1)
-      return failure();
+      return rewriter.notifyMatchFailure(broadcast,
+                                         "broadcast to 2+-d handled elsewhere");
 
     // First insert it into a poison vector so we can shuffle it.
-    auto vectorType = typeConverter->convertType(splatOp.getType());
+    auto vectorType = typeConverter->convertType(broadcast.getType());
     Value poison =
-        rewriter.create<LLVM::PoisonOp>(splatOp.getLoc(), vectorType);
+        rewriter.create<LLVM::PoisonOp>(broadcast.getLoc(), vectorType);
     auto zero = rewriter.create<LLVM::ConstantOp>(
-        splatOp.getLoc(),
+        broadcast.getLoc(),
         typeConverter->convertType(rewriter.getIntegerType(32)),
         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
 
     // For 0-d vector, we simply do `insertelement`.
     if (resultType.getRank() == 0) {
       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
-          splatOp, vectorType, poison, adaptor.getInput(), zero);
+          broadcast, vectorType, poison, adaptor.getSource(), zero);
       return success();
     }
 
     // For 1-d vector, we additionally do a `vectorshuffle`.
     auto v = rewriter.create<LLVM::InsertElementOp>(
-        splatOp.getLoc(), vectorType, poison, adaptor.getInput(), zero);
+        broadcast.getLoc(), vectorType, poison, adaptor.getSource(), zero);
 
-    int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
+    int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
     SmallVector<int32_t> zeroValues(width, 0);
 
     // Shuffle the value across the desired number of elements.
-    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, poison,
+    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
                                                        zeroValues);
     return success();
   }
 };
 
-/// The Splat operation is lowered to an insertelement + a shufflevector
-/// operation. Splat to only 2+-d vector result types are lowered by the
-/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
-struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
-  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
+/// The broadcast of a scalar is lowered to an insertelement + a shufflevector
+/// operation. Only broadcasts to 2+-d vector result types are lowered by this
+/// pattern; the 1-d case is handled by another pattern. Broadcasts from
+/// vectors are not converted to LLVM; only broadcasts from scalars are.
+struct VectorBroadcastScalarToNdLowering
+    : public ConvertOpToLLVMPattern<BroadcastOp> {
+  using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
+  matchAndRewrite(BroadcastOp broadcast, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType resultType = splatOp.getType();
+    if (isa<VectorType>(broadcast.getSourceType()))
+      return rewriter.notifyMatchFailure(
+          broadcast, "broadcast from vector type not handled");
+
+    VectorType resultType = broadcast.getType();
     if (resultType.getRank() <= 1)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          broadcast, "broadcast to 1-d or 0-d handled elsewhere");
 
     // First insert it into an undef vector so we can shuffle it.
-    auto loc = splatOp.getLoc();
+    auto loc = broadcast.getLoc();
     auto vectorTypeInfo =
         LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
     auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
@@ -1834,26 +1848,26 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
 
     // Construct returned value.
     Value desc = rewriter.create<LLVM::PoisonOp>(loc, llvmNDVectorTy);
 
-    // Construct a 1-D vector with the splatted value that we insert in all the
-    // places within the returned descriptor.
+    // Construct a 1-D vector with the broadcasted value that we insert in all
+    // the places within the returned descriptor.
     Value vdesc = rewriter.create<LLVM::PoisonOp>(loc, llvm1DVectorTy);
     auto zero = rewriter.create<LLVM::ConstantOp>(
         loc, typeConverter->convertType(rewriter.getIntegerType(32)),
         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
     Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
-                                                     adaptor.getInput(), zero);
+                                                     adaptor.getSource(), zero);
 
     // Shuffle the value across the desired number of elements.
     int64_t width = resultType.getDimSize(resultType.getRank() - 1);
     SmallVector<int32_t> zeroValues(width, 0);
     v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
 
-    // Iterate of linear index, convert to coords space and insert splatted 1-D
-    // vector in each position.
+    // Iterate over the linear index, convert to coords space and insert the
+    // broadcasted 1-D vector in each position.
     nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
       desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
     });
-    rewriter.replaceOp(splatOp, desc);
+    rewriter.replaceOp(broadcast, desc);
     return success();
   }
 };
@@ -2035,6 +2049,19 @@ struct VectorScalableStepOpLowering
   }
 };
 
+/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
+/// `vector.broadcast` through other patterns.
+struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
+  using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
+                                                     adaptor.getInput());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateVectorRankReducingFMAPattern(
@@ -2063,7 +2090,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorInsertOpConversion, VectorPrintOpConversion,
                VectorTypeCastOpConversion, VectorScaleOpConversion,
                VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
-               VectorSplatOpLowering, VectorSplatNdOpLowering,
+               VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
+               VectorBroadcastScalarToNdLowering,
                VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
                MaskedReductionOpConversion, VectorInterleaveOpLowering,
                VectorDeinterleaveOpLowering, VectorFromElementsLowering,
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index ed51b2126dcdd..43732f58a4e0a 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -444,7 +444,7 @@ struct Strategy<TransferWriteOp> {
     Location loc = xferOp.getLoc();
     auto bufferType = dyn_cast<MemRefType>(buffer.getType());
     auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
-    auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
+    auto vec = b.create<vector::BroadcastOp>(loc, vecType, xferOp.getPadding());
     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
 
     return Value();
@@ -1261,8 +1261,8 @@ struct UnrollTransferReadConversion
     if (auto insertOp = getInsertOp(xferOp))
      return insertOp.getDest();
     Location loc = xferOp.getLoc();
-    return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
-                                            xferOp.getPadding());
+    return rewriter.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
+                                                xferOp.getPadding());
   }
 
   /// If the result of the TransferReadOp has exactly one user, which is a
@@ -1583,8 +1583,8 @@ struct Strategy1d<TransferReadOp> {
   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
     // Inititalize vector with padding value.
     Location loc = xferOp.getLoc();
-    return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
-                                     xferOp.getPadding());
+    return b.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
+                                         xferOp.getPadding());
   }
 };
 
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 21d8e1d9f1156..750ce85049409 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -79,6 +79,20 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp>
   }
 };
 
+// Convert `vector.splat` to `vector.broadcast`. There is a path from
+// `vector.broadcast` to SPIRV via other patterns.
+struct VectorSplatToBroadcast final
+    : public OpConversionPattern<vector::SplatOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
+                                                     adaptor.getInput());
+    return success();
+  }
+};
+
 struct VectorBitcastConvert final
     : public OpConversionPattern<vector::BitCastOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -556,22 +570,27 @@ struct VectorReductionFloatMinMax final
   }
 };
 
-class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
+class VectorScalarBroadcastPattern final
+    : public OpConversionPattern<vector::BroadcastOp> {
 public:
-  using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
+  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
+  matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (isa<VectorType>(op.getSourceType())) {
+      return rewriter.notifyMatchFailure(
+          op, "only conversion of 'broadcast from scalar' is supported");
+    }
     Type dstType = getTypeConverter()->convertType(op.getType());
     if (!dstType)
       return failure();
     if (isa<spirv::ScalarType>(dstType)) {
-      rewriter.replaceOp(op, adaptor.getInput());
+      rewriter.replaceOp(op, adaptor.getSource());
     } else {
       auto dstVecType = cast<VectorType>(dstType);
       SmallVector<Value, 4> source(dstVecType.getNumElements(),
-                                   adaptor.getInput());
+                                   adaptor.getSource());
       rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
                                                                source);
     }
@@ -1089,11 +1108,11 @@ void mlir::populateVectorToSPIRVPatterns(
       VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
-      VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-      VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
-      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
-      VectorStepOpConvert>(typeConverter, patterns.getContext(),
-                           PatternBenefit(1));
+      VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert,
+      VectorShuffleOpConvert, VectorInterleaveOpConvert,
+      VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern,
+      VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>(
+      typeConverter, patterns.getContext(), PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
   // than the generic one that extracts all elements.
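
Note: a minimal sketch, not part of the patch, of what the two
VectorSplatToBroadcast patterns above compose with. `vector.splat` is no
longer lowered directly; it is first rewritten to `vector.broadcast`, which
the existing scalar-broadcast patterns then lower. For the SPIR-V path the
stages look roughly as follows (SSA names are illustrative; the final form
matches the spirv.CompositeConstruct test expectations below):

  // Input:
  %v = vector.splat %f : vector<4xf32>
  // After VectorSplatToBroadcast:
  %v = vector.broadcast %f : f32 to vector<4xf32>
  // After VectorScalarBroadcastPattern (scalar sources only):
  %v = spirv.CompositeConstruct %f, %f, %f, %f : (f32, f32, f32, f32) -> vector<4xf32>
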
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
index 095f3e575eca8..b98045195f8cf 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
@@ -230,9 +230,10 @@ func.func @conversion_broadcast_odd(%in: vector<6xf8E5M2>, %scale: vector<2xf8E8
 }
 
 // -----
-// CHECK-LABEL: @conversion_splat
+
+// CHECK-LABEL: @conversion_broadcast
 // CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
-// CHECK-DAG: %[[SCALE_SPLAT:.+]] = vector.splat %arg1 : vector<4xf8E8M0FNU>
+// CHECK-DAG: %[[SCALE_SPLAT:.+]] = vector.broadcast %arg1 : f8E8M0FNU to vector<4xf8E8M0FNU>
 // CHECK-DAG: %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_SPLAT]] : vector<4xf8E8M0FNU> to vector<4xf32>
 // CHECK-DAG: %[[SCALE_SCALAR:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<4xf32>
 // CHECK: %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
@@ -242,8 +243,8 @@ func.func @conversion_broadcast_odd(%in: vector<6xf8E5M2>, %scale: vector<2xf8E8
 // CHECK-NEXT: %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK1]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32>
 // CHECK-NEXT: %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1]], %[[ACCUM_A]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
 // CHECK-NEXT: return %[[FINAL_RESULT]] : vector<4xf32>
-func.func @conversion_splat(%in: vector<4xf8E5M2>, %scale: f8E8M0FNU) -> vector<4xf32> {
-  %splat = vector.splat %scale : vector<4xf8E8M0FNU>
+func.func @conversion_broadcast(%in: vector<4xf8E5M2>, %scale: f8E8M0FNU) -> vector<4xf32> {
+  %splat = vector.broadcast %scale : f8E8M0FNU to vector<4xf8E8M0FNU>
   %ext = arith.scaling_extf %in, %splat : vector<4xf8E5M2>, vector<4xf8E8M0FNU> to vector<4xf32>
   return %ext : vector<4xf32>
 }
@@ -252,7 +253,7 @@
 
 // CHECK-LABEL: @conversion_scalar
 // CHECK: %[[SCALE_F32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
-// CHECK-NEXT: %[[SPLAT_IN:.+]] = vector.splat %arg0 : vector<1xf8E5M2>
+// CHECK-NEXT: %[[SPLAT_IN:.+]] = vector.broadcast %arg0 : f8E5M2 to vector<1xf8E5M2>
 // CHECK-NEXT: %[[PACKED_EXT:.+]] = amdgpu.scaled_ext_packed %[[SPLAT_IN]][0], %[[SCALE_F32]] : vector<1xf8E5M2> to vector<2xf32>
 // CHECK-NEXT: %[[RESULT:.+]] = vector.extract %[[PACKED_EXT]][0] : f32 from vector<2xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
index 0519050c5ecc4..488e75cbb1843 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
@@ -159,9 +159,9 @@ func.func @conversion_broadcast_odd(%in: vector<6xf32>, %scale: vector<2xf8E8M0F
 
 // -----
 
-// CHECK-LABEL: @conversion_splat
+// CHECK-LABEL: @conversion_broadcast
 // CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4xf8E5M2>
-// CHECK-DAG: %[[SCALE_SPLAT:.+]] = vector.splat %arg1 : vector<4xf8E8M0FNU>
+// CHECK-DAG: %[[SCALE_SPLAT:.+]] = vector.broadcast %arg1 : f8E8M0FNU to vector<4xf8E8M0FNU>
 // CHECK-DAG: %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_SPLAT]] : vector<4xf8E8M0FNU> to vector<4xf32>
 // CHECK-DAG: %[[SCALE_SCALAR:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<4xf32>
 // CHECK: %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
@@ -173,8 +173,8 @@ func.func @conversion_broadcast_odd(%in: vector<6xf32>, %scale: vector<2xf8E8M0F
 // CHECK-NEXT: %[[OUT_CHUNK1:.+]] = vector.extract_strided_slice %[[PACKED1]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
 // CHECK-NEXT: %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1]], %[[ACCUM_A]] {offsets = [2], strides = [1]} : vector<2xf8E5M2> into vector<4xf8E5M2>
 // CHECK-NEXT: return %[[FINAL_RESULT]] : vector<4xf8E5M2>
-func.func @conversion_splat(%in: vector<4xf32>, %scale: f8E8M0FNU) -> vector<4xf8E5M2> {
-  %splat = vector.splat %scale : vector<4xf8E8M0FNU>
+func.func @conversion_broadcast(%in: vector<4xf32>, %scale: f8E8M0FNU) -> vector<4xf8E5M2> {
+  %splat = vector.broadcast %scale : f8E8M0FNU to vector<4xf8E8M0FNU>
   %ext = arith.scaling_truncf %in, %splat : vector<4xf32>, vector<4xf8E8M0FNU> to vector<4xf8E5M2>
   return %ext : vector<4xf8E5M2>
 }
@@ -183,7 +183,7 @@
 
 // CHECK-LABEL: @conversion_scalar
 // CHECK: %[[SCALE_F32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
-// CHECK-NEXT: %[[SPLAT_IN:.+]] = vector.splat %arg0 : vector<1xf32>
+// CHECK-NEXT: %[[SPLAT_IN:.+]] = vector.broadcast %arg0 : f32 to vector<1xf32>
 // CHECK-NEXT: %[[PACKED_TRUNC:.+]] = amdgpu.packed_scaled_trunc %[[SPLAT_IN]] into undef[0], %[[SCALE_F32]]
 // CHECK-NEXT: %[[RESULT:.+]] = vector.extract %[[PACKED_TRUNC]][0]
 // CHECK-NEXT: return %[[RESULT]] : f8E5M2
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index 510f7a2d94c9e..fb14feb8442b0 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -198,7 +198,28 @@ func.func @splat(%f : f32) -> vector<4xf32> {
 // CHECK-SAME: (%[[A:.+]]: f32)
 // CHECK: spirv.ReturnValue %[[A]] : f32
 func.func @splat_size1_vector(%f : f32) -> vector<1xf32> {
-  %splat = vector.splat %f : vector<1xf32>
+  %bc = vector.broadcast %f : f32 to vector<1xf32>
+  return %bc : vector<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scalar_broadcast
+// CHECK-SAME: (%[[A:.+]]: f32)
+// CHECK: %[[VAL:.+]] = spirv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]]
+// CHECK: spirv.ReturnValue %[[VAL]] : vector<4xf32>
+func.func @scalar_broadcast(%f : f32) -> vector<4xf32> {
+  %bc = vector.broadcast %f : f32 to vector<4xf32>
+  return %bc : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scalar_broadcast_size1_vector
+// CHECK-SAME: (%[[A:.+]]: f32)
+// CHECK: spirv.ReturnValue %[[A]] : f32
+func.func @scalar_broadcast_size1_vector(%f : f32) -> vector<1xf32> {
+  %splat = vector.broadcast %f : f32 to vector<1xf32>
   return %splat : vector<1xf32>
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index c03d67fdc33fa..77158353bd7d4 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -130,6 +130,52 @@ func.func @broadcast_vec1d_from_vec1d_scalable(%arg0: vector<[2]xf32>) -> vector
 
 // -----
 
+// CHECK-LABEL: @broadcast_scalar_0d
+// CHECK-SAME: %[[ELT:.*]]: f32
+func.func @broadcast_scalar_0d(%elt: f32) -> vector<f32> {
+  %v = vector.broadcast %elt : f32 to vector<f32>
+  return %v : vector<f32>
+}
+// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.poison : vector<1xf32>
+// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ELT]], %[[UNDEF]][%[[ZERO]] : i32] : vector<1xf32>
+// CHECK-NEXT: %[[VCAST:[0-9]+]] = builtin.unrealized_conversion_cast %[[V]] : vector<1xf32> to vector<f32>
+// CHECK-NEXT: return %[[VCAST]] : vector<f32>
+
+// -----
+
+// CHECK-LABEL: @broadcast_scalar
+// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<4xf32>
+// CHECK-SAME: %[[ELT:[0-9a-zA-Z]+]]: f32
+func.func @broadcast_scalar(%vec: vector<4xf32>, %elt: f32) -> vector<4xf32> {
+  %vb = vector.broadcast %elt : f32 to vector<4xf32>
+  %r = arith.mulf %vec, %vb : vector<4xf32>
+  return %r : vector<4xf32>
+}
+// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.poison : vector<4xf32>
+// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ELT]], %[[UNDEF]][%[[ZERO]] : i32] : vector<4xf32>
+// CHECK-NEXT: %[[SPLAT:[0-9]+]] = llvm.shufflevector %[[V]], %[[UNDEF]] [0, 0, 0, 0]
+// CHECK-NEXT: %[[SCALE:[0-9]+]] = arith.mulf %[[VEC]], %[[SPLAT]] : vector<4xf32>
+// CHECK-NEXT: return %[[SCALE]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: @broadcast_scalar_scalable
+// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<[4]xf32>
+// CHECK-SAME: %[[ELT:[0-9a-zA-Z]+]]: f32
+func.func @broadcast_scalar_scalable(%vec: vector<[4]xf32>, %elt: f32) -> vector<[4]xf32> {
+  %vb = vector.broadcast %elt : f32 to vector<[4]xf32>
+  %r = arith.mulf %vec, %vb : vector<[4]xf32>
+  return %r : vector<[4]xf32>
+}
+// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.poison : vector<[4]xf32>
+// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ELT]], %[[UNDEF]][%[[ZERO]] : i32] : vector<[4]xf32>
+// CHECK-NEXT: %[[SPLAT:[0-9]+]] = llvm.shufflevector %[[V]], %[[UNDEF]] [0, 0, 0, 0]
+// CHECK-NEXT: %[[SCALE:[0-9]+]] = arith.mulf %[[VEC]], %[[SPLAT]] : vector<[4]xf32>
+// CHECK-NEXT: return %[[SCALE]] : vector<[4]xf32>
+
 //===----------------------------------------------------------------------===//
 // vector.shuffle
 //===----------------------------------------------------------------------===//
@@ -2241,51 +2287,16 @@ func.func @compress_store_op_index(%arg0: memref<?xindex>, %arg1: vector<11xi1>,
 // vector.splat
 //===----------------------------------------------------------------------===//
 
+// vector.splat should be converted to vector.broadcast, which in turn is converted to LLVM.
+
 // CHECK-LABEL: @splat_0d
-// CHECK-SAME: %[[ELT:.*]]: f32
-func.func @splat_0d(%elt: f32) -> vector<f32> {
-  %v = vector.splat %elt : vector<f32>
-  return %v : vector<f32>
-}
-// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.poison : vector<1xf32>
-// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ELT]], %[[UNDEF]][%[[ZERO]] : i32] : vector<1xf32>
-// CHECK-NEXT: %[[VCAST:[0-9]+]] = builtin.unrealized_conversion_cast %[[V]] : vector<1xf32> to vector<f32>
-// CHECK-NEXT: return %[[VCAST]] : vector<f32>
-
-// -----
-
-// CHECK-LABEL: @splat
-// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<4xf32>
-// CHECK-SAME: %[[ELT:[0-9a-zA-Z]+]]: f32
-func.func @splat(%vec: vector<4xf32>, %elt: f32) -> vector<4xf32> {
-  %vb = vector.splat %elt : vector<4xf32>
-  %r = arith.mulf %vec, %vb : vector<4xf32>
-  return %r : vector<4xf32>
-}
-// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.poison : vector<4xf32>
-// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ELT]], %[[UNDEF]][%[[ZERO]] : i32] : vector<4xf32>
-// CHECK-NEXT: %[[SPLAT:[0-9]+]] = llvm.shufflevector %[[V]], %[[UNDEF]] [0, 0, 0, 0]
-// CHECK-NEXT: %[[SCALE:[0-9]+]] = arith.mulf %[[VEC]], %[[SPLAT]] : vector<4xf32>
-// CHECK-NEXT: return %[[SCALE]] : vector<4xf32>
-
-// -----
-
-// CHECK-LABEL: @splat_scalable
-// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<[4]xf32>
-// CHECK-SAME: %[[ELT:[0-9a-zA-Z]+]]: f32
-func.func @splat_scalable(%vec: vector<[4]xf32>, %elt: f32) -> vector<[4]xf32> {
-  %vb = vector.splat %elt : vector<[4]xf32>
-  %r = arith.mulf %vec, %vb : vector<[4]xf32>
-  return %r : vector<[4]xf32>
+// CHECK-NOT: splat
+// CHECK: return
+func.func @splat_0d(%elt: f32) -> (vector<f32>, vector<4xf32>, vector<[4]xf32>) {
+  %a = vector.splat %elt : vector<f32>
+  %b = vector.splat %elt : vector<4xf32>
+  %c = vector.splat %elt : vector<[4]xf32>
+  return %a, %b, %c : vector<f32>, vector<4xf32>, vector<[4]xf32>
 }
-// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.poison : vector<[4]xf32>
-// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ELT]], %[[UNDEF]][%[[ZERO]] : i32] : vector<[4]xf32>
-// CHECK-NEXT: %[[SPLAT:[0-9]+]] = llvm.shufflevector %[[V]], %[[UNDEF]] [0, 0, 0, 0]
-// CHECK-NEXT: %[[SCALE:[0-9]+]] = arith.mulf %[[VEC]], %[[SPLAT]] : vector<[4]xf32>
-// CHECK-NEXT: return %[[SCALE]] : vector<[4]xf32>
 
 // -----
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 64e51f5554628..f5ad8579dfc7d 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -182,20 +182,21 @@ func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
   %0 = vector.broadcast %arg0 : vector<f32> to vector<3x2xf32>
   return %0 : vector<3x2xf32>
 }
+
 // CHECK-LABEL: @broadcast_vec2d_from_vec0d(
-// CHECK-SAME: %[[A:.*]]: vector<f32>)
-// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
-// CHECK: %[[T1:.*]] = ub.poison : vector<3x2xf32>
-// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32>
-// CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
-// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
-// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
-// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][1] : !llvm.array<3 x vector<2xf32>>
-// CHECK: %[[T9:.*]] = llvm.insertvalue %[[T6]], %[[T8]][2] : !llvm.array<3 x vector<2xf32>>
-// CHECK: %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T9]] : !llvm.array<3 x vector<2xf32>> to vector<3x2xf32>
-// CHECK: return %[[T10]] : vector<3x2xf32>
+// CHECK-SAME: %[[ARG_RANK0:.*]]: vector<f32>)
+// CHECK-DAG: %[[ARG_RANK1:.*]] = builtin.unrealized_conversion_cast %[[ARG_RANK0]] : vector<f32> to vector<1xf32>
+// CHECK-DAG: %[[ZERO_64:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK-DAG: %[[ARG_SCALAR:.*]] = llvm.extractelement %[[ARG_RANK1]][%[[ZERO_64]] : i64] : vector<1xf32>
+// CHECK-DAG: %[[UB_POISON:.*]] = ub.poison : vector<3x2xf32>
+// CHECK-DAG: %[[FULL_POISON:.*]] = builtin.unrealized_conversion_cast %[[UB_POISON]] {{.*}} !llvm.array<3 x vector<2xf32>>
+// CHECK: %[[PART_RANK1:.*]] = llvm.insertelement %[[ARG_SCALAR]]
+// CHECK: %[[FULL_RANK1:.*]] = llvm.shufflevector %[[PART_RANK1]]
+// CHECK: %[[INSERT1:.*]] = llvm.insertvalue %[[FULL_RANK1]], %[[FULL_POISON]][0] : !llvm.array<3 x vector<2xf32>>
+// CHECK: %[[INSERT2:.*]] = llvm.insertvalue %[[FULL_RANK1]], %[[INSERT1]][1] : !llvm.array<3 x vector<2xf32>>
+// CHECK: %[[INSERT3:.*]] = llvm.insertvalue %[[FULL_RANK1]], %[[INSERT2]][2] : !llvm.array<3 x vector<2xf32>>
+// CHECK: %[[FINAL:.*]] = builtin.unrealized_conversion_cast %[[INSERT3]] : !llvm.array<3 x vector<2xf32>> to vector<3x2xf32>
+// CHECK: return %[[FINAL]] : vector<3x2xf32>
 
 // -----
 
@@ -1517,7 +1518,7 @@ func.func @constant_mask_2d() -> vector<4x4xi1> {
 }
 
 // CHECK-LABEL: func @constant_mask_2d
-// CHECK: %[[VAL_0:.*]] = arith.constant 
+// CHECK: %[[VAL_0:.*]] = arith.constant
 // CHECK-SAME{LITERAL}: dense<[[true, true, false, false], [true, true, false, false], [false, false, false, false], [false, false, false, false]]> : vector<4x4xi1>
 // CHECK: return %[[VAL_0]] : vector<4x4xi1>
 
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 99ab0e1dc4eef..e69050ac423c5 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -579,7 +579,7 @@ func.func @fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<1xf
 // CHECK: %[[VAL:.+]] = spirv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]]
 // CHECK: return %[[VAL]]
 func.func @splat(%f : f32) -> vector<4xf32> {
-  %splat = vector.splat %f : vector<4xf32>
+  %splat = vector.broadcast %f : f32 to vector<4xf32>
   return %splat : vector<4xf32>
 }
 
@@ -590,7 +590,7 @@ func.func @splat(%f : f32) -> vector<4xf32> {
 // CHECK: %[[VAL:.+]] = builtin.unrealized_conversion_cast %[[A]]
 // CHECK: return %[[VAL]]
 func.func @splat_size1_vector(%f : f32) -> vector<1xf32> {
-  %splat = vector.splat %f : vector<1xf32>
+  %splat = vector.broadcast %f : f32 to vector<1xf32>
   return %splat : vector<1xf32>
 }
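
For reference, a minimal sketch, not part of the patch, of the LLVM-side
staging these patterns compose; it matches the CHECK lines in
vector-to-llvm-interface.mlir above (SSA names are illustrative):

  // Input:
  %v = vector.splat %f : vector<4xf32>
  // After VectorSplatToBroadcast:
  %v = vector.broadcast %f : f32 to vector<4xf32>
  // After VectorBroadcastScalarToLowRankLowering:
  %p = llvm.mlir.poison : vector<4xf32>
  %c0 = llvm.mlir.constant(0 : i32) : i32
  %e = llvm.insertelement %f, %p[%c0 : i32] : vector<4xf32>
  %v = llvm.shufflevector %e, %p [0, 0, 0, 0] : vector<4xf32>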