[mlir][vector] Support direct broadcast conversion (LLVM & SPIRV) #148027

Open
newling wants to merge 2 commits into main

Conversation

@newling (Contributor) commented Jul 10, 2025

Add conversions of vector.broadcast from a scalar to LLVM and SPIR-V. Also replace some miscellaneous uses of vector.splat with vector.broadcast in VectorToGPU and ArithToAMDGPU.

Part of the deprecation of vector.splat.

RFC: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4
More complete deprecation: #147818
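
In short, the new patterns lower a broadcast from a scalar to insertelement + shufflevector on the LLVM side and to a composite construct on the SPIR-V side, with vector.splat first rewritten to vector.broadcast. A rough sketch of the intent (illustrative names, not the exact FileCheck lines from the updated tests):

  %v = vector.broadcast %s : f32 to vector<4xf32>

  // LLVM dialect, roughly:
  %p  = llvm.mlir.poison : vector<4xf32>
  %c0 = llvm.mlir.constant(0 : i32) : i32
  %e  = llvm.insertelement %s, %p[%c0 : i32] : vector<4xf32>
  %r  = llvm.shufflevector %e, %p [0, 0, 0, 0] : vector<4xf32>

  // SPIR-V dialect, roughly:
  %r = spirv.CompositeConstruct %s, %s, %s, %s : (f32, f32, f32, f32) -> vector<4xf32>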

@llvmbot (Member) commented Jul 10, 2025

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

Changes

Add conversions of vector.broadcast from a scalar to LLVM and SPIR-V. Also replace some miscellaneous uses of vector.splat with vector.broadcast in VectorToGPU and ArithToAMDGPU.

Part of the deprecation of vector.splat.

RFC: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4
More complete deprecation: #147818


Patch is 37.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148027.diff

11 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+14-11)
  • (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+1-1)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+63-33)
  • (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+5-5)
  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+30-10)
  • (modified) mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir (+6-5)
  • (modified) mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir (+5-5)
  • (modified) mlir/test/Conversion/ConvertToSPIRV/vector.mlir (+22-1)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (+54-43)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+15-14)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+2-2)
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index cf9bb3a000050..156c679c5039e 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -153,7 +153,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
 
   if (inVecType.getShape().empty()) {
     Value zerodSplat =
-        rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
+        rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
     Value scalarIn =
         rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
     Value scalarExt =
@@ -166,7 +166,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
 
   VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                       outType.getElementType());
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
 
   if (inVecType.getRank() > 1) {
     inVecType = VectorType::get(SmallVector<int64_t>{numElements},
@@ -315,7 +315,7 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
 
   VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                       outVecType.getElementType());
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
 
   if (inVectorTy.getRank() > 1) {
     inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -383,7 +383,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   int64_t numElements = outVecType.getNumElements();
   Value zero = rewriter.createOrFold<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+  Value result =
+      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
 
   if (inVectorTy.getRank() > 1) {
     inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -478,8 +479,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
   VectorType extScaleResultType = VectorType::get(opWidth, outType);
 
   if (!outVecType) {
-    Value inCast =
-        rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+    Value inCast = rewriter.create<vector::BroadcastOp>(
+        loc, VectorType::get(1, inType), in);
     // TODO: replace this with non-packed ScaledExtOp
     Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
         loc, extScaleResultType, inCast, scale, 0);
@@ -509,7 +510,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
 
   Value zero = rewriter.create<arith::ConstantOp>(
       loc, outType, rewriter.getFloatAttr(outType, 0.0));
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+  Value result =
+      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
 
   for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
     SmallVector<int64_t> strides(offsets.size(), 1);
@@ -523,7 +525,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
 
     VectorType blockResultType = VectorType::get(blockSize, outType);
     Value blockResult =
-        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
     for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
          i < blockSize;
@@ -587,7 +589,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   if (!outVecType) {
     Type inVecType = VectorType::get(1, inType);
-    Value inCast = rewriter.create<vector::SplatOp>(loc, inVecType, in);
+    Value inCast = rewriter.create<vector::BroadcastOp>(loc, inVecType, in);
     // TODO: replace this with non-packed ScaledTruncOp
     Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
         loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
@@ -616,7 +618,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   int64_t blockSize = computeProduct(ratio);
 
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+  Value result =
+      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
 
   for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
     SmallVector<int64_t> strides(offsets.size(), 1);
@@ -630,7 +633,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
     VectorType blockResultType = VectorType::get(blockSize, outType);
     Value blockResult =
-        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
     for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
          i < blockSize;
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 18adaa793787c..9a8eb72d72925 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -792,7 +792,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
       op.getLoc(), vectorType.getElementType(),
       rewriter.getZeroAttr(vectorType.getElementType()));
   Value result =
-      rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
+      rewriter.create<vector::BroadcastOp>(op.getLoc(), vectorType, fill);
 
   bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 501d98862672d..76afb3b6f256d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1408,7 +1408,7 @@ struct VectorScalableExtractOpLowering
 /// ```
 /// is rewritten into:
 /// ```
-///  %r = splat %f0: vector<2x4xf32>
+///  %r = vector.broadcast %f0 : f32 to vector<2x4xf32>
 ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
 ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
 ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
@@ -1441,7 +1441,7 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
     auto elemType = vType.getElementType();
     Value zero = rewriter.create<arith::ConstantOp>(
         loc, elemType, rewriter.getZeroAttr(elemType));
-    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
+    Value desc = rewriter.create<vector::BroadcastOp>(loc, vType, zero);
     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
       Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
       Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
@@ -1583,7 +1583,7 @@ class VectorCreateMaskOpConversion
                                  /*isScalable=*/true));
     auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
                                                  adaptor.getOperands()[0]);
-    Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+    Value bounds = rewriter.create<BroadcastOp>(loc, indices.getType(), bound);
     Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 indices, bounds);
     rewriter.replaceOp(op, comp);
@@ -1767,63 +1767,79 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
   }
 };
 
-/// The Splat operation is lowered to an insertelement + a shufflevector
-/// operation. Splat to only 0-d and 1-d vector result types are lowered.
-struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
-  using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
+/// A broadcast of a scalar is lowered to an insertelement + a shufflevector
+/// operation. Only broadcasts to 0-d and 1-d vectors are lowered by this
+/// pattern, the higher rank cases are handled by another pattern.
+struct VectorBroadcastScalarToLowRankLowering
+    : public ConvertOpToLLVMPattern<vector::BroadcastOp> {
+  using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+  matchAndRewrite(vector::BroadcastOp broadcast, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType resultType = cast<VectorType>(splatOp.getType());
+
+    if (isa<VectorType>(broadcast.getSourceType()))
+      return rewriter.notifyMatchFailure(
+          broadcast, "broadcast from vector type not handled");
+
+    VectorType resultType = broadcast.getType();
     if (resultType.getRank() > 1)
-      return failure();
+      return rewriter.notifyMatchFailure(broadcast,
+                                         "broadcast to 2+-d handled elsewhere");
 
     // First insert it into a poison vector so we can shuffle it.
-    auto vectorType = typeConverter->convertType(splatOp.getType());
+    auto vectorType = typeConverter->convertType(broadcast.getType());
     Value poison =
-        rewriter.create<LLVM::PoisonOp>(splatOp.getLoc(), vectorType);
+        rewriter.create<LLVM::PoisonOp>(broadcast.getLoc(), vectorType);
     auto zero = rewriter.create<LLVM::ConstantOp>(
-        splatOp.getLoc(),
+        broadcast.getLoc(),
         typeConverter->convertType(rewriter.getIntegerType(32)),
         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
 
     // For 0-d vector, we simply do `insertelement`.
     if (resultType.getRank() == 0) {
       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
-          splatOp, vectorType, poison, adaptor.getInput(), zero);
+          broadcast, vectorType, poison, adaptor.getSource(), zero);
       return success();
     }
 
     // For 1-d vector, we additionally do a `vectorshuffle`.
     auto v = rewriter.create<LLVM::InsertElementOp>(
-        splatOp.getLoc(), vectorType, poison, adaptor.getInput(), zero);
+        broadcast.getLoc(), vectorType, poison, adaptor.getSource(), zero);
 
-    int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
+    int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
     SmallVector<int32_t> zeroValues(width, 0);
 
     // Shuffle the value across the desired number of elements.
-    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, poison,
+    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
                                                        zeroValues);
     return success();
   }
 };
 
-/// The Splat operation is lowered to an insertelement + a shufflevector
-/// operation. Splat to only 2+-d vector result types are lowered by the
-/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
-struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
-  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
+/// The broadcast of a scalar is lowered to an insertelement + a shufflevector
+/// operation. Only broadcasts to 2+-d vector result types are lowered by this
+/// pattern, the 1-d case is handled by another pattern. Broadcasts from vectors
+/// are not converted to LLVM, only broadcasts from scalars are.
+struct VectorBroadcastScalarToNdLowering
+    : public ConvertOpToLLVMPattern<BroadcastOp> {
+  using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
+  matchAndRewrite(BroadcastOp broadcast, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType resultType = splatOp.getType();
+
+    if (isa<VectorType>(broadcast.getSourceType()))
+      return rewriter.notifyMatchFailure(
+          broadcast, "broadcast from vector type not handled");
+
+    VectorType resultType = broadcast.getType();
     if (resultType.getRank() <= 1)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          broadcast, "broadcast to 1-d or 0-d handled elsewhere");
 
     // First insert it into an undef vector so we can shuffle it.
-    auto loc = splatOp.getLoc();
+    auto loc = broadcast.getLoc();
     auto vectorTypeInfo =
         LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
     auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
@@ -1834,26 +1850,26 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
     // Construct returned value.
     Value desc = rewriter.create<LLVM::PoisonOp>(loc, llvmNDVectorTy);
 
-    // Construct a 1-D vector with the splatted value that we insert in all the
-    // places within the returned descriptor.
+    // Construct a 1-D vector with the broadcasted value that we insert in all
+    // the places within the returned descriptor.
     Value vdesc = rewriter.create<LLVM::PoisonOp>(loc, llvm1DVectorTy);
     auto zero = rewriter.create<LLVM::ConstantOp>(
         loc, typeConverter->convertType(rewriter.getIntegerType(32)),
         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
     Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
-                                                     adaptor.getInput(), zero);
+                                                     adaptor.getSource(), zero);
 
     // Shuffle the value across the desired number of elements.
     int64_t width = resultType.getDimSize(resultType.getRank() - 1);
     SmallVector<int32_t> zeroValues(width, 0);
     v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
 
-    // Iterate of linear index, convert to coords space and insert splatted 1-D
-    // vector in each position.
+    // Iterate of linear index, convert to coords space and insert broadcasted
+    // 1-D vector in each position.
     nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
       desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
     });
-    rewriter.replaceOp(splatOp, desc);
+    rewriter.replaceOp(broadcast, desc);
     return success();
   }
 };
@@ -2035,6 +2051,19 @@ struct VectorScalableStepOpLowering
   }
 };
 
+/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
+/// `vector.broadcast` through other patterns.
+struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
+                                                     adaptor.getInput());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateVectorRankReducingFMAPattern(
@@ -2063,7 +2092,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorInsertOpConversion, VectorPrintOpConversion,
                VectorTypeCastOpConversion, VectorScaleOpConversion,
                VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
-               VectorSplatOpLowering, VectorSplatNdOpLowering,
+               VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
+               VectorBroadcastScalarToNdLowering,
                VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
                MaskedReductionOpConversion, VectorInterleaveOpLowering,
                VectorDeinterleaveOpLowering, VectorFromElementsLowering,
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index ed51b2126dcdd..43732f58a4e0a 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -444,7 +444,7 @@ struct Strategy<TransferReadOp> {
     Location loc = xferOp.getLoc();
     auto bufferType = dyn_cast<ShapedType>(buffer.getType());
     auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
-    auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
+    auto vec = b.create<vector::BroadcastOp>(loc, vecType, xferOp.getPadding());
     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
 
     return Value();
@@ -1261,8 +1261,8 @@ struct UnrollTransferReadConversion
     if (auto insertOp = getInsertOp(xferOp))
       return insertOp.getDest();
     Location loc = xferOp.getLoc();
-    return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
-                                            xferOp.getPadding());
+    return rewriter.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
+                                                xferOp.getPadding());
   }
 
   /// If the result of the TransferReadOp has exactly one user, which is a
@@ -1583,8 +1583,8 @@ struct Strategy1d<TransferReadOp> {
   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
     // Inititalize vector with padding value.
     Location loc = xferOp.getLoc();
-    return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
-                                     xferOp.getPadding());
+    return b.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
+                                         xferOp.getPadding());
   }
 };
 
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 21d8e1d9f1156..de35015b81108 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -79,6 +79,20 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
   }
 };
 
+// Convert `vector.splat` to `vector.broadcast`. There is a path from
+// `vector.broadcast` to SPIRV via other patterns.
+struct VectorSplatToBroadcast final
+    : public OpConversionPattern<vector::SplatOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
+                                                     adaptor.getInput());
+    return success();
+  }
+};
+
 struct VectorBitcastConvert final
     : public OpConversionPattern<vector::BitCastOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -556,22 +570,28 @@ struct VectorReductionFloatMinMax final
   }
 };
 
-class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
+class VectorScalarBroadcastPattern final
+    : public OpConversionPattern<vector::BroadcastOp> {
 public:
-  using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
+  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
+  matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+
+    if (isa<VectorType>(op.getSourceType())) {
+      return rewriter.notifyMatchFailure(
+          op, "only conversion of 'broadcast from scalar' is supported");
+    }
     Type dstType = getTypeConverter()->convertType(op.getType());
     if (!dstType)
       return failure();
     if (isa<spirv::ScalarType>(dstType)) {
-      rewriter.replaceOp(op, adaptor.getInput());
+      rewriter.replaceOp(op, adaptor.getSource());
     } else {
       auto dstVecType = cast<VectorType>(dstType);
       SmallVector<Value, 4> source(dstVecType.getNumElements(),
-                                   adaptor.getInput());
+                                   adaptor.getSource());
       rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
                                                                source);
     }
@@ -1089,11 +1109,11 @@ void mlir::populateVectorToSPIRVPatterns(
       VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorRed...
[truncated]

@llvmbot (Member) commented Jul 10, 2025

@llvm/pr-subscribers-mlir-gpu


@llvmbot (Member) commented Jul 10, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: James Newling (newling)

Changes

Add conversion for broadcast from scalar for LLVM and SPIRV. Also some miscellaneous replacements of vector.splat with vector.broadcast in VectorToGPU and ArithToAMDGPU.

Part of deprecation of vector.splat

RFC: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4
More complete deprecation: #147818


Patch is 37.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148027.diff

11 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+14-11)
  • (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+1-1)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+63-33)
  • (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+5-5)
  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+30-10)
  • (modified) mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir (+6-5)
  • (modified) mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir (+5-5)
  • (modified) mlir/test/Conversion/ConvertToSPIRV/vector.mlir (+22-1)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (+54-43)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+15-14)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+2-2)
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index cf9bb3a000050..156c679c5039e 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -153,7 +153,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
 
   if (inVecType.getShape().empty()) {
     Value zerodSplat =
-        rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
+        rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
     Value scalarIn =
         rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
     Value scalarExt =
@@ -166,7 +166,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
 
   VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                       outType.getElementType());
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
 
   if (inVecType.getRank() > 1) {
     inVecType = VectorType::get(SmallVector<int64_t>{numElements},
@@ -315,7 +315,7 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
 
   VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                       outVecType.getElementType());
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
 
   if (inVectorTy.getRank() > 1) {
     inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -383,7 +383,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   int64_t numElements = outVecType.getNumElements();
   Value zero = rewriter.createOrFold<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+  Value result =
+      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
 
   if (inVectorTy.getRank() > 1) {
     inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -478,8 +479,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
   VectorType extScaleResultType = VectorType::get(opWidth, outType);
 
   if (!outVecType) {
-    Value inCast =
-        rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+    Value inCast = rewriter.create<vector::BroadcastOp>(
+        loc, VectorType::get(1, inType), in);
     // TODO: replace this with non-packed ScaledExtOp
     Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
         loc, extScaleResultType, inCast, scale, 0);
@@ -509,7 +510,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
 
   Value zero = rewriter.create<arith::ConstantOp>(
       loc, outType, rewriter.getFloatAttr(outType, 0.0));
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+  Value result =
+      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
 
   for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
     SmallVector<int64_t> strides(offsets.size(), 1);
@@ -523,7 +525,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
 
     VectorType blockResultType = VectorType::get(blockSize, outType);
     Value blockResult =
-        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
     for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
          i < blockSize;
@@ -587,7 +589,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   if (!outVecType) {
     Type inVecType = VectorType::get(1, inType);
-    Value inCast = rewriter.create<vector::SplatOp>(loc, inVecType, in);
+    Value inCast = rewriter.create<vector::BroadcastOp>(loc, inVecType, in);
     // TODO: replace this with non-packed ScaledTruncOp
     Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
         loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
@@ -616,7 +618,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   int64_t blockSize = computeProduct(ratio);
 
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+  Value result =
+      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
 
   for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
     SmallVector<int64_t> strides(offsets.size(), 1);
@@ -630,7 +633,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
     VectorType blockResultType = VectorType::get(blockSize, outType);
     Value blockResult =
-        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
     for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
          i < blockSize;
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 18adaa793787c..9a8eb72d72925 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -792,7 +792,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
       op.getLoc(), vectorType.getElementType(),
       rewriter.getZeroAttr(vectorType.getElementType()));
   Value result =
-      rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
+      rewriter.create<vector::BroadcastOp>(op.getLoc(), vectorType, fill);
 
   bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 501d98862672d..76afb3b6f256d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1408,7 +1408,7 @@ struct VectorScalableExtractOpLowering
 /// ```
 /// is rewritten into:
 /// ```
-///  %r = splat %f0: vector<2x4xf32>
+///  %r = vector.broadcast %f0 : f32 to vector<2x4xf32>
 ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
 ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
 ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
@@ -1441,7 +1441,7 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
     auto elemType = vType.getElementType();
     Value zero = rewriter.create<arith::ConstantOp>(
         loc, elemType, rewriter.getZeroAttr(elemType));
-    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
+    Value desc = rewriter.create<vector::BroadcastOp>(loc, vType, zero);
     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
       Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
       Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
@@ -1583,7 +1583,7 @@ class VectorCreateMaskOpConversion
                                  /*isScalable=*/true));
     auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
                                                  adaptor.getOperands()[0]);
-    Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+    Value bounds = rewriter.create<BroadcastOp>(loc, indices.getType(), bound);
     Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 indices, bounds);
     rewriter.replaceOp(op, comp);
@@ -1767,63 +1767,79 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
   }
 };
 
-/// The Splat operation is lowered to an insertelement + a shufflevector
-/// operation. Splat to only 0-d and 1-d vector result types are lowered.
-struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
-  using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
+/// A broadcast of a scalar is lowered to an insertelement + a shufflevector
+/// operation. Only broadcasts to 0-d and 1-d vectors are lowered by this
+/// pattern, the higher rank cases are handled by another pattern.
+struct VectorBroadcastScalarToLowRankLowering
+    : public ConvertOpToLLVMPattern<vector::BroadcastOp> {
+  using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+  matchAndRewrite(vector::BroadcastOp broadcast, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType resultType = cast<VectorType>(splatOp.getType());
+
+    if (isa<VectorType>(broadcast.getSourceType()))
+      return rewriter.notifyMatchFailure(
+          broadcast, "broadcast from vector type not handled");
+
+    VectorType resultType = broadcast.getType();
     if (resultType.getRank() > 1)
-      return failure();
+      return rewriter.notifyMatchFailure(broadcast,
+                                         "broadcast to 2+-d handled elsewhere");
 
     // First insert it into a poison vector so we can shuffle it.
-    auto vectorType = typeConverter->convertType(splatOp.getType());
+    auto vectorType = typeConverter->convertType(broadcast.getType());
     Value poison =
-        rewriter.create<LLVM::PoisonOp>(splatOp.getLoc(), vectorType);
+        rewriter.create<LLVM::PoisonOp>(broadcast.getLoc(), vectorType);
     auto zero = rewriter.create<LLVM::ConstantOp>(
-        splatOp.getLoc(),
+        broadcast.getLoc(),
         typeConverter->convertType(rewriter.getIntegerType(32)),
         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
 
     // For 0-d vector, we simply do `insertelement`.
     if (resultType.getRank() == 0) {
       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
-          splatOp, vectorType, poison, adaptor.getInput(), zero);
+          broadcast, vectorType, poison, adaptor.getSource(), zero);
       return success();
     }
 
     // For 1-d vector, we additionally do a `vectorshuffle`.
     auto v = rewriter.create<LLVM::InsertElementOp>(
-        splatOp.getLoc(), vectorType, poison, adaptor.getInput(), zero);
+        broadcast.getLoc(), vectorType, poison, adaptor.getSource(), zero);
 
-    int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
+    int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
     SmallVector<int32_t> zeroValues(width, 0);
 
     // Shuffle the value across the desired number of elements.
-    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, poison,
+    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
                                                        zeroValues);
     return success();
   }
 };
 
-/// The Splat operation is lowered to an insertelement + a shufflevector
-/// operation. Splat to only 2+-d vector result types are lowered by the
-/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
-struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
-  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
+/// The broadcast of a scalar is lowered to an insertelement + a shufflevector
+/// operation. Only broadcasts to 2+-d vector result types are lowered by this
+/// pattern, the 1-d case is handled by another pattern. Broadcasts from vectors
+/// are not converted to LLVM, only broadcasts from scalars are.
+struct VectorBroadcastScalarToNdLowering
+    : public ConvertOpToLLVMPattern<BroadcastOp> {
+  using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
+  matchAndRewrite(BroadcastOp broadcast, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType resultType = splatOp.getType();
+
+    if (isa<VectorType>(broadcast.getSourceType()))
+      return rewriter.notifyMatchFailure(
+          broadcast, "broadcast from vector type not handled");
+
+    VectorType resultType = broadcast.getType();
     if (resultType.getRank() <= 1)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          broadcast, "broadcast to 1-d or 0-d handled elsewhere");
 
     // First insert it into an undef vector so we can shuffle it.
-    auto loc = splatOp.getLoc();
+    auto loc = broadcast.getLoc();
     auto vectorTypeInfo =
         LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
     auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
@@ -1834,26 +1850,26 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
     // Construct returned value.
     Value desc = rewriter.create<LLVM::PoisonOp>(loc, llvmNDVectorTy);
 
-    // Construct a 1-D vector with the splatted value that we insert in all the
-    // places within the returned descriptor.
+    // Construct a 1-D vector with the broadcasted value that we insert in all
+    // the places within the returned descriptor.
     Value vdesc = rewriter.create<LLVM::PoisonOp>(loc, llvm1DVectorTy);
     auto zero = rewriter.create<LLVM::ConstantOp>(
         loc, typeConverter->convertType(rewriter.getIntegerType(32)),
         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
     Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
-                                                     adaptor.getInput(), zero);
+                                                     adaptor.getSource(), zero);
 
     // Shuffle the value across the desired number of elements.
     int64_t width = resultType.getDimSize(resultType.getRank() - 1);
     SmallVector<int32_t> zeroValues(width, 0);
     v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
 
-    // Iterate of linear index, convert to coords space and insert splatted 1-D
-    // vector in each position.
+    // Iterate over linear index, convert to coords space and insert broadcasted
+    // 1-D vector in each position.
     nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
       desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
     });
-    rewriter.replaceOp(splatOp, desc);
+    rewriter.replaceOp(broadcast, desc);
     return success();
   }
 };
@@ -2035,6 +2051,19 @@ struct VectorScalableStepOpLowering
   }
 };
 
+/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
+/// `vector.broadcast` through other patterns.
+struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
+                                                     adaptor.getInput());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateVectorRankReducingFMAPattern(
@@ -2063,7 +2092,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorInsertOpConversion, VectorPrintOpConversion,
                VectorTypeCastOpConversion, VectorScaleOpConversion,
                VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
-               VectorSplatOpLowering, VectorSplatNdOpLowering,
+               VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
+               VectorBroadcastScalarToNdLowering,
                VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
                MaskedReductionOpConversion, VectorInterleaveOpLowering,
                VectorDeinterleaveOpLowering, VectorFromElementsLowering,
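
As a rough illustration of the two scalar-broadcast lowerings above (the SSA names and shapes here are illustrative, not taken from the tests in this patch), a broadcast to a 1-d vector is expected to become an insert into a poison vector followed by a shuffle:

```mlir
// Input: broadcast of a scalar to a 1-d vector.
%b = vector.broadcast %s : f32 to vector<4xf32>

// Sketch of the lowered form produced by the low-rank pattern:
%poison = llvm.mlir.poison : vector<4xf32>
%c0 = llvm.mlir.constant(0 : i32) : i32
%ins = llvm.insertelement %s, %poison[%c0 : i32] : vector<4xf32>
%res = llvm.shufflevector %ins, %poison [0, 0, 0, 0] : vector<4xf32>
```

For 2+-d results, the Nd pattern builds the same 1-d broadcast once and then inserts it at every position of the ND descriptor via llvm.insertvalue.
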
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index ed51b2126dcdd..43732f58a4e0a 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -444,7 +444,7 @@ struct Strategy<TransferReadOp> {
     Location loc = xferOp.getLoc();
     auto bufferType = dyn_cast<ShapedType>(buffer.getType());
     auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
-    auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
+    auto vec = b.create<vector::BroadcastOp>(loc, vecType, xferOp.getPadding());
     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
 
     return Value();
@@ -1261,8 +1261,8 @@ struct UnrollTransferReadConversion
     if (auto insertOp = getInsertOp(xferOp))
       return insertOp.getDest();
     Location loc = xferOp.getLoc();
-    return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
-                                            xferOp.getPadding());
+    return rewriter.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
+                                                xferOp.getPadding());
   }
 
   /// If the result of the TransferReadOp has exactly one user, which is a
@@ -1583,8 +1583,8 @@ struct Strategy1d<TransferReadOp> {
   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
     // Inititalize vector with padding value.
     Location loc = xferOp.getLoc();
-    return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
-                                     xferOp.getPadding());
+    return b.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
+                                         xferOp.getPadding());
   }
 };
 
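
The VectorToSCF changes above only swap the op used to materialize the padding vector; for a scalar source the two forms are equivalent. A minimal sketch with illustrative names and types:

```mlir
// Previously: splat of the padding value.
%pad = vector.splat %padding : vector<4xf32>

// Now: broadcast from the scalar padding value.
%pad = vector.broadcast %padding : f32 to vector<4xf32>
```
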
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 21d8e1d9f1156..de35015b81108 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -79,6 +79,20 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
   }
 };
 
+// Convert `vector.splat` to `vector.broadcast`. There is a path from
+// `vector.broadcast` to SPIRV via other patterns.
+struct VectorSplatToBroadcast final
+    : public OpConversionPattern<vector::SplatOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
+                                                     adaptor.getInput());
+    return success();
+  }
+};
+
 struct VectorBitcastConvert final
     : public OpConversionPattern<vector::BitCastOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -556,22 +570,28 @@ struct VectorReductionFloatMinMax final
   }
 };
 
-class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
+class VectorScalarBroadcastPattern final
+    : public OpConversionPattern<vector::BroadcastOp> {
 public:
-  using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
+  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
+  matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+
+    if (isa<VectorType>(op.getSourceType())) {
+      return rewriter.notifyMatchFailure(
+          op, "only conversion of 'broadcast from scalar' is supported");
+    }
     Type dstType = getTypeConverter()->convertType(op.getType());
     if (!dstType)
       return failure();
     if (isa<spirv::ScalarType>(dstType)) {
-      rewriter.replaceOp(op, adaptor.getInput());
+      rewriter.replaceOp(op, adaptor.getSource());
     } else {
       auto dstVecType = cast<VectorType>(dstType);
       SmallVector<Value, 4> source(dstVecType.getNumElements(),
-                                   adaptor.getInput());
+                                   adaptor.getSource());
       rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
                                                                source);
     }
@@ -1089,11 +1109,11 @@ void mlir::populateVectorToSPIRVPatterns(
       VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorRed...
[truncated]

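On the SPIR-V side, the VectorScalarBroadcastPattern shown above either forwards the scalar directly (when the result type converts to a SPIR-V scalar) or repeats it as the constituents of a spirv.CompositeConstruct; `vector.splat` now reaches this path through the new VectorSplatToBroadcast pattern. A minimal sketch, assuming a 4-element f32 vector and illustrative names:

```mlir
// Input: broadcast of a scalar to a vector.
%b = vector.broadcast %s : f32 to vector<4xf32>

// Sketch of the expected SPIR-V form: one constituent per element.
%c = spirv.CompositeConstruct %s, %s, %s, %s : (f32, f32, f32, f32) -> vector<4xf32>
```
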
@kuhar (Member) left a comment:

SPIR-V LGTM

Inline review comment on the line `ConversionPatternRewriter &rewriter) const override {`:

nit: spurious newline
