-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][vector] Support direct broadcast conversion (LLVM & SPIRV) #148027
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][vector] Support direct broadcast conversion (LLVM & SPIRV) #148027
Conversation
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: James Newling (newling) Changes: Add conversion for broadcast from scalar for LLVM and SPIRV. Also some miscellaneous replacements of vector.splat with vector.broadcast in VectorToGPU and ArithToAMDGPU. Part of the deprecation of vector.splat. RFC: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4 Patch is 37.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148027.diff 11 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index cf9bb3a000050..156c679c5039e 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -153,7 +153,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
if (inVecType.getShape().empty()) {
Value zerodSplat =
- rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
+ rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
Value scalarExt =
@@ -166,7 +166,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
outType.getElementType());
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+ Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
if (inVecType.getRank() > 1) {
inVecType = VectorType::get(SmallVector<int64_t>{numElements},
@@ -315,7 +315,7 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
outVecType.getElementType());
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+ Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -383,7 +383,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
int64_t numElements = outVecType.getNumElements();
Value zero = rewriter.createOrFold<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+ Value result =
+ rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -478,8 +479,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
VectorType extScaleResultType = VectorType::get(opWidth, outType);
if (!outVecType) {
- Value inCast =
- rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+ Value inCast = rewriter.create<vector::BroadcastOp>(
+ loc, VectorType::get(1, inType), in);
// TODO: replace this with non-packed ScaledExtOp
Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
loc, extScaleResultType, inCast, scale, 0);
@@ -509,7 +510,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
Value zero = rewriter.create<arith::ConstantOp>(
loc, outType, rewriter.getFloatAttr(outType, 0.0));
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+ Value result =
+ rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
@@ -523,7 +525,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
- rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+ rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
i < blockSize;
@@ -587,7 +589,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
if (!outVecType) {
Type inVecType = VectorType::get(1, inType);
- Value inCast = rewriter.create<vector::SplatOp>(loc, inVecType, in);
+ Value inCast = rewriter.create<vector::BroadcastOp>(loc, inVecType, in);
// TODO: replace this with non-packed ScaledTruncOp
Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
@@ -616,7 +618,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
int64_t blockSize = computeProduct(ratio);
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+ Value result =
+ rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
@@ -630,7 +633,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
- rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+ rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
i < blockSize;
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 18adaa793787c..9a8eb72d72925 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -792,7 +792,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
op.getLoc(), vectorType.getElementType(),
rewriter.getZeroAttr(vectorType.getElementType()));
Value result =
- rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
+ rewriter.create<vector::BroadcastOp>(op.getLoc(), vectorType, fill);
bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 501d98862672d..76afb3b6f256d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1408,7 +1408,7 @@ struct VectorScalableExtractOpLowering
/// ```
/// is rewritten into:
/// ```
-/// %r = splat %f0: vector<2x4xf32>
+/// %r = vector.broadcast %f0 : f32 to vector<2x4xf32>
/// %va = vector.extractvalue %a[0] : vector<2x4xf32>
/// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
/// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
@@ -1441,7 +1441,7 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
auto elemType = vType.getElementType();
Value zero = rewriter.create<arith::ConstantOp>(
loc, elemType, rewriter.getZeroAttr(elemType));
- Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
+ Value desc = rewriter.create<vector::BroadcastOp>(loc, vType, zero);
for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
@@ -1583,7 +1583,7 @@ class VectorCreateMaskOpConversion
/*isScalable=*/true));
auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
adaptor.getOperands()[0]);
- Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+ Value bounds = rewriter.create<BroadcastOp>(loc, indices.getType(), bound);
Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
indices, bounds);
rewriter.replaceOp(op, comp);
@@ -1767,63 +1767,79 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
}
};
-/// The Splat operation is lowered to an insertelement + a shufflevector
-/// operation. Splat to only 0-d and 1-d vector result types are lowered.
-struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
- using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
+/// A broadcast of a scalar is lowered to an insertelement + a shufflevector
+/// operation. Only broadcasts to 0-d and 1-d vectors are lowered by this
+/// pattern, the higher rank cases are handled by another pattern.
+struct VectorBroadcastScalarToLowRankLowering
+ : public ConvertOpToLLVMPattern<vector::BroadcastOp> {
+ using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+ matchAndRewrite(vector::BroadcastOp broadcast, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- VectorType resultType = cast<VectorType>(splatOp.getType());
+
+ if (isa<VectorType>(broadcast.getSourceType()))
+ return rewriter.notifyMatchFailure(
+ broadcast, "broadcast from vector type not handled");
+
+ VectorType resultType = broadcast.getType();
if (resultType.getRank() > 1)
- return failure();
+ return rewriter.notifyMatchFailure(broadcast,
+ "broadcast to 2+-d handled elsewhere");
// First insert it into a poison vector so we can shuffle it.
- auto vectorType = typeConverter->convertType(splatOp.getType());
+ auto vectorType = typeConverter->convertType(broadcast.getType());
Value poison =
- rewriter.create<LLVM::PoisonOp>(splatOp.getLoc(), vectorType);
+ rewriter.create<LLVM::PoisonOp>(broadcast.getLoc(), vectorType);
auto zero = rewriter.create<LLVM::ConstantOp>(
- splatOp.getLoc(),
+ broadcast.getLoc(),
typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
// For 0-d vector, we simply do `insertelement`.
if (resultType.getRank() == 0) {
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- splatOp, vectorType, poison, adaptor.getInput(), zero);
+ broadcast, vectorType, poison, adaptor.getSource(), zero);
return success();
}
// For 1-d vector, we additionally do a `vectorshuffle`.
auto v = rewriter.create<LLVM::InsertElementOp>(
- splatOp.getLoc(), vectorType, poison, adaptor.getInput(), zero);
+ broadcast.getLoc(), vectorType, poison, adaptor.getSource(), zero);
- int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
+ int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
SmallVector<int32_t> zeroValues(width, 0);
// Shuffle the value across the desired number of elements.
- rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, poison,
+ rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
zeroValues);
return success();
}
};
-/// The Splat operation is lowered to an insertelement + a shufflevector
-/// operation. Splat to only 2+-d vector result types are lowered by the
-/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
-struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
- using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
+/// The broadcast of a scalar is lowered to an insertelement + a shufflevector
+/// operation. Only broadcasts to 2+-d vector result types are lowered by this
+/// pattern, the 1-d case is handled by another pattern. Broadcasts from vectors
+/// are not converted to LLVM, only broadcasts from scalars are.
+struct VectorBroadcastScalarToNdLowering
+ : public ConvertOpToLLVMPattern<BroadcastOp> {
+ using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
+ matchAndRewrite(BroadcastOp broadcast, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- VectorType resultType = splatOp.getType();
+
+ if (isa<VectorType>(broadcast.getSourceType()))
+ return rewriter.notifyMatchFailure(
+ broadcast, "broadcast from vector type not handled");
+
+ VectorType resultType = broadcast.getType();
if (resultType.getRank() <= 1)
- return failure();
+ return rewriter.notifyMatchFailure(
+ broadcast, "broadcast to 1-d or 0-d handled elsewhere");
// First insert it into an undef vector so we can shuffle it.
- auto loc = splatOp.getLoc();
+ auto loc = broadcast.getLoc();
auto vectorTypeInfo =
LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
@@ -1834,26 +1850,26 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
// Construct returned value.
Value desc = rewriter.create<LLVM::PoisonOp>(loc, llvmNDVectorTy);
- // Construct a 1-D vector with the splatted value that we insert in all the
- // places within the returned descriptor.
+ // Construct a 1-D vector with the broadcasted value that we insert in all
+ // the places within the returned descriptor.
Value vdesc = rewriter.create<LLVM::PoisonOp>(loc, llvm1DVectorTy);
auto zero = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
- adaptor.getInput(), zero);
+ adaptor.getSource(), zero);
// Shuffle the value across the desired number of elements.
int64_t width = resultType.getDimSize(resultType.getRank() - 1);
SmallVector<int32_t> zeroValues(width, 0);
v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
- // Iterate of linear index, convert to coords space and insert splatted 1-D
- // vector in each position.
+ // Iterate of linear index, convert to coords space and insert broadcasted
+ // 1-D vector in each position.
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
});
- rewriter.replaceOp(splatOp, desc);
+ rewriter.replaceOp(broadcast, desc);
return success();
}
};
@@ -2035,6 +2051,19 @@ struct VectorScalableStepOpLowering
}
};
+/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
+/// `vector.broadcast` through other patterns.
+struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+ LogicalResult
+ matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
+ adaptor.getInput());
+ return success();
+ }
+};
+
} // namespace
void mlir::vector::populateVectorRankReducingFMAPattern(
@@ -2063,7 +2092,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
- VectorSplatOpLowering, VectorSplatNdOpLowering,
+ VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
+ VectorBroadcastScalarToNdLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering,
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index ed51b2126dcdd..43732f58a4e0a 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -444,7 +444,7 @@ struct Strategy<TransferReadOp> {
Location loc = xferOp.getLoc();
auto bufferType = dyn_cast<ShapedType>(buffer.getType());
auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
- auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
+ auto vec = b.create<vector::BroadcastOp>(loc, vecType, xferOp.getPadding());
b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
return Value();
@@ -1261,8 +1261,8 @@ struct UnrollTransferReadConversion
if (auto insertOp = getInsertOp(xferOp))
return insertOp.getDest();
Location loc = xferOp.getLoc();
- return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
- xferOp.getPadding());
+ return rewriter.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
+ xferOp.getPadding());
}
/// If the result of the TransferReadOp has exactly one user, which is a
@@ -1583,8 +1583,8 @@ struct Strategy1d<TransferReadOp> {
static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
// Inititalize vector with padding value.
Location loc = xferOp.getLoc();
- return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
- xferOp.getPadding());
+ return b.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
+ xferOp.getPadding());
}
};
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 21d8e1d9f1156..de35015b81108 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -79,6 +79,20 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
}
};
+// Convert `vector.splat` to `vector.broadcast`. There is a path from
+// `vector.broadcast` to SPIRV via other patterns.
+struct VectorSplatToBroadcast final
+ : public OpConversionPattern<vector::SplatOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
+ adaptor.getInput());
+ return success();
+ }
+};
+
struct VectorBitcastConvert final
: public OpConversionPattern<vector::BitCastOp> {
using OpConversionPattern::OpConversionPattern;
@@ -556,22 +570,28 @@ struct VectorReductionFloatMinMax final
}
};
-class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
+class VectorScalarBroadcastPattern final
+ : public OpConversionPattern<vector::BroadcastOp> {
public:
- using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
+ using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
+ matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+
+ if (isa<VectorType>(op.getSourceType())) {
+ return rewriter.notifyMatchFailure(
+ op, "only conversion of 'broadcast from scalar' is supported");
+ }
Type dstType = getTypeConverter()->convertType(op.getType());
if (!dstType)
return failure();
if (isa<spirv::ScalarType>(dstType)) {
- rewriter.replaceOp(op, adaptor.getInput());
+ rewriter.replaceOp(op, adaptor.getSource());
} else {
auto dstVecType = cast<VectorType>(dstType);
SmallVector<Value, 4> source(dstVecType.getNumElements(),
- adaptor.getInput());
+ adaptor.getSource());
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
source);
}
@@ -1089,11 +1109,11 @@ void mlir::populateVectorToSPIRVPatterns(
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorRed...
[truncated]
|
@llvm/pr-subscribers-mlir-gpu Author: James Newling (newling) Changes: Add conversion for broadcast from scalar for LLVM and SPIRV. Also some miscellaneous replacements of vector.splat with vector.broadcast in VectorToGPU and ArithToAMDGPU. Part of the deprecation of vector.splat. RFC: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4 Patch is 37.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148027.diff 11 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index cf9bb3a000050..156c679c5039e 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -153,7 +153,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
if (inVecType.getShape().empty()) {
Value zerodSplat =
- rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
+ rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
Value scalarExt =
@@ -166,7 +166,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
outType.getElementType());
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+ Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
if (inVecType.getRank() > 1) {
inVecType = VectorType::get(SmallVector<int64_t>{numElements},
@@ -315,7 +315,7 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
outVecType.getElementType());
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+ Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -383,7 +383,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
int64_t numElements = outVecType.getNumElements();
Value zero = rewriter.createOrFold<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+ Value result =
+ rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -478,8 +479,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
VectorType extScaleResultType = VectorType::get(opWidth, outType);
if (!outVecType) {
- Value inCast =
- rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+ Value inCast = rewriter.create<vector::BroadcastOp>(
+ loc, VectorType::get(1, inType), in);
// TODO: replace this with non-packed ScaledExtOp
Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
loc, extScaleResultType, inCast, scale, 0);
@@ -509,7 +510,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
Value zero = rewriter.create<arith::ConstantOp>(
loc, outType, rewriter.getFloatAttr(outType, 0.0));
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+ Value result =
+ rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
@@ -523,7 +525,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
- rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+ rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
i < blockSize;
@@ -587,7 +589,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
if (!outVecType) {
Type inVecType = VectorType::get(1, inType);
- Value inCast = rewriter.create<vector::SplatOp>(loc, inVecType, in);
+ Value inCast = rewriter.create<vector::BroadcastOp>(loc, inVecType, in);
// TODO: replace this with non-packed ScaledTruncOp
Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
@@ -616,7 +618,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
int64_t blockSize = computeProduct(ratio);
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+ Value result =
+ rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
@@ -630,7 +633,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
- rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+ rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
i < blockSize;
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 18adaa793787c..9a8eb72d72925 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -792,7 +792,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
op.getLoc(), vectorType.getElementType(),
rewriter.getZeroAttr(vectorType.getElementType()));
Value result =
- rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
+ rewriter.create<vector::BroadcastOp>(op.getLoc(), vectorType, fill);
bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 501d98862672d..76afb3b6f256d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1408,7 +1408,7 @@ struct VectorScalableExtractOpLowering
/// ```
/// is rewritten into:
/// ```
-/// %r = splat %f0: vector<2x4xf32>
+/// %r = vector.broadcast %f0 : f32 to vector<2x4xf32>
/// %va = vector.extractvalue %a[0] : vector<2x4xf32>
/// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
/// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
@@ -1441,7 +1441,7 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
auto elemType = vType.getElementType();
Value zero = rewriter.create<arith::ConstantOp>(
loc, elemType, rewriter.getZeroAttr(elemType));
- Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
+ Value desc = rewriter.create<vector::BroadcastOp>(loc, vType, zero);
for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
@@ -1583,7 +1583,7 @@ class VectorCreateMaskOpConversion
/*isScalable=*/true));
auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
adaptor.getOperands()[0]);
- Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+ Value bounds = rewriter.create<BroadcastOp>(loc, indices.getType(), bound);
Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
indices, bounds);
rewriter.replaceOp(op, comp);
@@ -1767,63 +1767,79 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
}
};
-/// The Splat operation is lowered to an insertelement + a shufflevector
-/// operation. Splat to only 0-d and 1-d vector result types are lowered.
-struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
- using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
+/// A broadcast of a scalar is lowered to an insertelement + a shufflevector
+/// operation. Only broadcasts to 0-d and 1-d vectors are lowered by this
+/// pattern, the higher rank cases are handled by another pattern.
+struct VectorBroadcastScalarToLowRankLowering
+ : public ConvertOpToLLVMPattern<vector::BroadcastOp> {
+ using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+ matchAndRewrite(vector::BroadcastOp broadcast, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- VectorType resultType = cast<VectorType>(splatOp.getType());
+
+ if (isa<VectorType>(broadcast.getSourceType()))
+ return rewriter.notifyMatchFailure(
+ broadcast, "broadcast from vector type not handled");
+
+ VectorType resultType = broadcast.getType();
if (resultType.getRank() > 1)
- return failure();
+ return rewriter.notifyMatchFailure(broadcast,
+ "broadcast to 2+-d handled elsewhere");
// First insert it into a poison vector so we can shuffle it.
- auto vectorType = typeConverter->convertType(splatOp.getType());
+ auto vectorType = typeConverter->convertType(broadcast.getType());
Value poison =
- rewriter.create<LLVM::PoisonOp>(splatOp.getLoc(), vectorType);
+ rewriter.create<LLVM::PoisonOp>(broadcast.getLoc(), vectorType);
auto zero = rewriter.create<LLVM::ConstantOp>(
- splatOp.getLoc(),
+ broadcast.getLoc(),
typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
// For 0-d vector, we simply do `insertelement`.
if (resultType.getRank() == 0) {
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- splatOp, vectorType, poison, adaptor.getInput(), zero);
+ broadcast, vectorType, poison, adaptor.getSource(), zero);
return success();
}
// For 1-d vector, we additionally do a `vectorshuffle`.
auto v = rewriter.create<LLVM::InsertElementOp>(
- splatOp.getLoc(), vectorType, poison, adaptor.getInput(), zero);
+ broadcast.getLoc(), vectorType, poison, adaptor.getSource(), zero);
- int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
+ int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
SmallVector<int32_t> zeroValues(width, 0);
// Shuffle the value across the desired number of elements.
- rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, poison,
+ rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
zeroValues);
return success();
}
};
-/// The Splat operation is lowered to an insertelement + a shufflevector
-/// operation. Splat to only 2+-d vector result types are lowered by the
-/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
-struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
- using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
+/// The broadcast of a scalar is lowered to an insertelement + a shufflevector
+/// operation. Only broadcasts to 2+-d vector result types are lowered by this
+/// pattern, the 1-d case is handled by another pattern. Broadcasts from vectors
+/// are not converted to LLVM, only broadcasts from scalars are.
+struct VectorBroadcastScalarToNdLowering
+ : public ConvertOpToLLVMPattern<BroadcastOp> {
+ using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
+ matchAndRewrite(BroadcastOp broadcast, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- VectorType resultType = splatOp.getType();
+
+ if (isa<VectorType>(broadcast.getSourceType()))
+ return rewriter.notifyMatchFailure(
+ broadcast, "broadcast from vector type not handled");
+
+ VectorType resultType = broadcast.getType();
if (resultType.getRank() <= 1)
- return failure();
+ return rewriter.notifyMatchFailure(
+ broadcast, "broadcast to 1-d or 0-d handled elsewhere");
// First insert it into an undef vector so we can shuffle it.
- auto loc = splatOp.getLoc();
+ auto loc = broadcast.getLoc();
auto vectorTypeInfo =
LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
@@ -1834,26 +1850,26 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
// Construct returned value.
Value desc = rewriter.create<LLVM::PoisonOp>(loc, llvmNDVectorTy);
- // Construct a 1-D vector with the splatted value that we insert in all the
- // places within the returned descriptor.
+ // Construct a 1-D vector with the broadcasted value that we insert in all
+ // the places within the returned descriptor.
Value vdesc = rewriter.create<LLVM::PoisonOp>(loc, llvm1DVectorTy);
auto zero = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
- adaptor.getInput(), zero);
+ adaptor.getSource(), zero);
// Shuffle the value across the desired number of elements.
int64_t width = resultType.getDimSize(resultType.getRank() - 1);
SmallVector<int32_t> zeroValues(width, 0);
v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
- // Iterate of linear index, convert to coords space and insert splatted 1-D
- // vector in each position.
+    // Iterate over the linear index, convert to coords space and insert the
+    // broadcasted 1-D vector in each position.
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
});
- rewriter.replaceOp(splatOp, desc);
+ rewriter.replaceOp(broadcast, desc);
return success();
}
};
@@ -2035,6 +2051,19 @@ struct VectorScalableStepOpLowering
}
};
+/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
+/// `vector.broadcast` through other patterns.
+struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+ LogicalResult
+ matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
+ adaptor.getInput());
+ return success();
+ }
+};
+
} // namespace
void mlir::vector::populateVectorRankReducingFMAPattern(
@@ -2063,7 +2092,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
- VectorSplatOpLowering, VectorSplatNdOpLowering,
+ VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
+ VectorBroadcastScalarToNdLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering,
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index ed51b2126dcdd..43732f58a4e0a 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -444,7 +444,7 @@ struct Strategy<TransferReadOp> {
Location loc = xferOp.getLoc();
auto bufferType = dyn_cast<ShapedType>(buffer.getType());
auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
- auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
+ auto vec = b.create<vector::BroadcastOp>(loc, vecType, xferOp.getPadding());
b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
return Value();
@@ -1261,8 +1261,8 @@ struct UnrollTransferReadConversion
if (auto insertOp = getInsertOp(xferOp))
return insertOp.getDest();
Location loc = xferOp.getLoc();
- return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
- xferOp.getPadding());
+ return rewriter.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
+ xferOp.getPadding());
}
/// If the result of the TransferReadOp has exactly one user, which is a
@@ -1583,8 +1583,8 @@ struct Strategy1d<TransferReadOp> {
static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
// Inititalize vector with padding value.
Location loc = xferOp.getLoc();
- return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
- xferOp.getPadding());
+ return b.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
+ xferOp.getPadding());
}
};
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 21d8e1d9f1156..de35015b81108 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -79,6 +79,20 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
}
};
+// Convert `vector.splat` to `vector.broadcast`. There is a path from
+// `vector.broadcast` to SPIRV via other patterns.
+struct VectorSplatToBroadcast final
+ : public OpConversionPattern<vector::SplatOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
+ adaptor.getInput());
+ return success();
+ }
+};
+
struct VectorBitcastConvert final
: public OpConversionPattern<vector::BitCastOp> {
using OpConversionPattern::OpConversionPattern;
@@ -556,22 +570,28 @@ struct VectorReductionFloatMinMax final
}
};
-class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
+class VectorScalarBroadcastPattern final
+ : public OpConversionPattern<vector::BroadcastOp> {
public:
- using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
+ using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
+ matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+
+ if (isa<VectorType>(op.getSourceType())) {
+ return rewriter.notifyMatchFailure(
+ op, "only conversion of 'broadcast from scalar' is supported");
+ }
Type dstType = getTypeConverter()->convertType(op.getType());
if (!dstType)
return failure();
if (isa<spirv::ScalarType>(dstType)) {
- rewriter.replaceOp(op, adaptor.getInput());
+ rewriter.replaceOp(op, adaptor.getSource());
} else {
auto dstVecType = cast<VectorType>(dstType);
SmallVector<Value, 4> source(dstVecType.getNumElements(),
- adaptor.getInput());
+ adaptor.getSource());
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
source);
}
@@ -1089,11 +1109,11 @@ void mlir::populateVectorToSPIRVPatterns(
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorRed...
[truncated]
|
@llvm/pr-subscribers-backend-amdgpu Author: James Newling (newling) ChangesAdd conversion for broadcast from scalar for LLVM and SPIRV. Also some miscellaneous replacements of vector.splat with vector.broadcast in VectorToGPU and ArithToAMDGPU. Part of deprecation of vector.splat RFC: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4 Patch is 37.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148027.diff 11 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index cf9bb3a000050..156c679c5039e 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -153,7 +153,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
if (inVecType.getShape().empty()) {
Value zerodSplat =
- rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
+ rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
Value scalarExt =
@@ -166,7 +166,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
outType.getElementType());
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+ Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
if (inVecType.getRank() > 1) {
inVecType = VectorType::get(SmallVector<int64_t>{numElements},
@@ -315,7 +315,7 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
outVecType.getElementType());
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+ Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -383,7 +383,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
int64_t numElements = outVecType.getNumElements();
Value zero = rewriter.createOrFold<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+ Value result =
+ rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -478,8 +479,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
VectorType extScaleResultType = VectorType::get(opWidth, outType);
if (!outVecType) {
- Value inCast =
- rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+ Value inCast = rewriter.create<vector::BroadcastOp>(
+ loc, VectorType::get(1, inType), in);
// TODO: replace this with non-packed ScaledExtOp
Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
loc, extScaleResultType, inCast, scale, 0);
@@ -509,7 +510,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
Value zero = rewriter.create<arith::ConstantOp>(
loc, outType, rewriter.getFloatAttr(outType, 0.0));
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+ Value result =
+ rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
@@ -523,7 +525,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
- rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+ rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
i < blockSize;
@@ -587,7 +589,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
if (!outVecType) {
Type inVecType = VectorType::get(1, inType);
- Value inCast = rewriter.create<vector::SplatOp>(loc, inVecType, in);
+ Value inCast = rewriter.create<vector::BroadcastOp>(loc, inVecType, in);
// TODO: replace this with non-packed ScaledTruncOp
Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
@@ -616,7 +618,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
int64_t blockSize = computeProduct(ratio);
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+ Value result =
+ rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
@@ -630,7 +633,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
- rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+ rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
i < blockSize;
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 18adaa793787c..9a8eb72d72925 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -792,7 +792,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
op.getLoc(), vectorType.getElementType(),
rewriter.getZeroAttr(vectorType.getElementType()));
Value result =
- rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
+ rewriter.create<vector::BroadcastOp>(op.getLoc(), vectorType, fill);
bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 501d98862672d..76afb3b6f256d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1408,7 +1408,7 @@ struct VectorScalableExtractOpLowering
/// ```
/// is rewritten into:
/// ```
-/// %r = splat %f0: vector<2x4xf32>
+/// %r = vector.broadcast %f0 : f32 to vector<2x4xf32>
/// %va = vector.extractvalue %a[0] : vector<2x4xf32>
/// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
/// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
@@ -1441,7 +1441,7 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
auto elemType = vType.getElementType();
Value zero = rewriter.create<arith::ConstantOp>(
loc, elemType, rewriter.getZeroAttr(elemType));
- Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
+ Value desc = rewriter.create<vector::BroadcastOp>(loc, vType, zero);
for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
@@ -1583,7 +1583,7 @@ class VectorCreateMaskOpConversion
/*isScalable=*/true));
auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
adaptor.getOperands()[0]);
- Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+ Value bounds = rewriter.create<BroadcastOp>(loc, indices.getType(), bound);
Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
indices, bounds);
rewriter.replaceOp(op, comp);
@@ -1767,63 +1767,79 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
}
};
-/// The Splat operation is lowered to an insertelement + a shufflevector
-/// operation. Splat to only 0-d and 1-d vector result types are lowered.
-struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
- using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
+/// A broadcast of a scalar is lowered to an insertelement + a shufflevector
+/// operation. Only broadcasts to 0-d and 1-d vectors are lowered by this
+/// pattern, the higher rank cases are handled by another pattern.
+struct VectorBroadcastScalarToLowRankLowering
+ : public ConvertOpToLLVMPattern<vector::BroadcastOp> {
+ using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+ matchAndRewrite(vector::BroadcastOp broadcast, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- VectorType resultType = cast<VectorType>(splatOp.getType());
+
+ if (isa<VectorType>(broadcast.getSourceType()))
+ return rewriter.notifyMatchFailure(
+ broadcast, "broadcast from vector type not handled");
+
+ VectorType resultType = broadcast.getType();
if (resultType.getRank() > 1)
- return failure();
+ return rewriter.notifyMatchFailure(broadcast,
+ "broadcast to 2+-d handled elsewhere");
// First insert it into a poison vector so we can shuffle it.
- auto vectorType = typeConverter->convertType(splatOp.getType());
+ auto vectorType = typeConverter->convertType(broadcast.getType());
Value poison =
- rewriter.create<LLVM::PoisonOp>(splatOp.getLoc(), vectorType);
+ rewriter.create<LLVM::PoisonOp>(broadcast.getLoc(), vectorType);
auto zero = rewriter.create<LLVM::ConstantOp>(
- splatOp.getLoc(),
+ broadcast.getLoc(),
typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
// For 0-d vector, we simply do `insertelement`.
if (resultType.getRank() == 0) {
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- splatOp, vectorType, poison, adaptor.getInput(), zero);
+ broadcast, vectorType, poison, adaptor.getSource(), zero);
return success();
}
// For 1-d vector, we additionally do a `vectorshuffle`.
auto v = rewriter.create<LLVM::InsertElementOp>(
- splatOp.getLoc(), vectorType, poison, adaptor.getInput(), zero);
+ broadcast.getLoc(), vectorType, poison, adaptor.getSource(), zero);
- int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
+ int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
SmallVector<int32_t> zeroValues(width, 0);
// Shuffle the value across the desired number of elements.
- rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, poison,
+ rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
zeroValues);
return success();
}
};
-/// The Splat operation is lowered to an insertelement + a shufflevector
-/// operation. Splat to only 2+-d vector result types are lowered by the
-/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
-struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
- using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
+/// The broadcast of a scalar is lowered to an insertelement + a shufflevector
+/// operation. Only broadcasts to 2+-d vector result types are lowered by this
+/// pattern, the 1-d case is handled by another pattern. Broadcasts from vectors
+/// are not converted to LLVM, only broadcasts from scalars are.
+struct VectorBroadcastScalarToNdLowering
+ : public ConvertOpToLLVMPattern<BroadcastOp> {
+ using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
+ matchAndRewrite(BroadcastOp broadcast, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- VectorType resultType = splatOp.getType();
+
+ if (isa<VectorType>(broadcast.getSourceType()))
+ return rewriter.notifyMatchFailure(
+ broadcast, "broadcast from vector type not handled");
+
+ VectorType resultType = broadcast.getType();
if (resultType.getRank() <= 1)
- return failure();
+ return rewriter.notifyMatchFailure(
+ broadcast, "broadcast to 1-d or 0-d handled elsewhere");
// First insert it into an undef vector so we can shuffle it.
- auto loc = splatOp.getLoc();
+ auto loc = broadcast.getLoc();
auto vectorTypeInfo =
LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
@@ -1834,26 +1850,26 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
// Construct returned value.
Value desc = rewriter.create<LLVM::PoisonOp>(loc, llvmNDVectorTy);
- // Construct a 1-D vector with the splatted value that we insert in all the
- // places within the returned descriptor.
+ // Construct a 1-D vector with the broadcasted value that we insert in all
+ // the places within the returned descriptor.
Value vdesc = rewriter.create<LLVM::PoisonOp>(loc, llvm1DVectorTy);
auto zero = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
- adaptor.getInput(), zero);
+ adaptor.getSource(), zero);
// Shuffle the value across the desired number of elements.
int64_t width = resultType.getDimSize(resultType.getRank() - 1);
SmallVector<int32_t> zeroValues(width, 0);
v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
- // Iterate of linear index, convert to coords space and insert splatted 1-D
- // vector in each position.
+    // Iterate over the linear index, convert to coords space and insert the
+    // broadcasted 1-D vector in each position.
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
});
- rewriter.replaceOp(splatOp, desc);
+ rewriter.replaceOp(broadcast, desc);
return success();
}
};
@@ -2035,6 +2051,19 @@ struct VectorScalableStepOpLowering
}
};
+/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
+/// `vector.broadcast` through other patterns.
+struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+ LogicalResult
+ matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
+ adaptor.getInput());
+ return success();
+ }
+};
+
} // namespace
void mlir::vector::populateVectorRankReducingFMAPattern(
@@ -2063,7 +2092,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
- VectorSplatOpLowering, VectorSplatNdOpLowering,
+ VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
+ VectorBroadcastScalarToNdLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering,
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index ed51b2126dcdd..43732f58a4e0a 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -444,7 +444,7 @@ struct Strategy<TransferReadOp> {
Location loc = xferOp.getLoc();
auto bufferType = dyn_cast<ShapedType>(buffer.getType());
auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
- auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
+ auto vec = b.create<vector::BroadcastOp>(loc, vecType, xferOp.getPadding());
b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
return Value();
@@ -1261,8 +1261,8 @@ struct UnrollTransferReadConversion
if (auto insertOp = getInsertOp(xferOp))
return insertOp.getDest();
Location loc = xferOp.getLoc();
- return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
- xferOp.getPadding());
+ return rewriter.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
+ xferOp.getPadding());
}
/// If the result of the TransferReadOp has exactly one user, which is a
@@ -1583,8 +1583,8 @@ struct Strategy1d<TransferReadOp> {
static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
// Inititalize vector with padding value.
Location loc = xferOp.getLoc();
- return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
- xferOp.getPadding());
+ return b.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
+ xferOp.getPadding());
}
};
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 21d8e1d9f1156..de35015b81108 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -79,6 +79,20 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
}
};
+// Convert `vector.splat` to `vector.broadcast`. There is a path from
+// `vector.broadcast` to SPIRV via other patterns.
+struct VectorSplatToBroadcast final
+ : public OpConversionPattern<vector::SplatOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
+ adaptor.getInput());
+ return success();
+ }
+};
+
struct VectorBitcastConvert final
: public OpConversionPattern<vector::BitCastOp> {
using OpConversionPattern::OpConversionPattern;
@@ -556,22 +570,28 @@ struct VectorReductionFloatMinMax final
}
};
-class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
+class VectorScalarBroadcastPattern final
+ : public OpConversionPattern<vector::BroadcastOp> {
public:
- using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
+ using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
+ matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+
+ if (isa<VectorType>(op.getSourceType())) {
+ return rewriter.notifyMatchFailure(
+ op, "only conversion of 'broadcast from scalar' is supported");
+ }
Type dstType = getTypeConverter()->convertType(op.getType());
if (!dstType)
return failure();
if (isa<spirv::ScalarType>(dstType)) {
- rewriter.replaceOp(op, adaptor.getInput());
+ rewriter.replaceOp(op, adaptor.getSource());
} else {
auto dstVecType = cast<VectorType>(dstType);
SmallVector<Value, 4> source(dstVecType.getNumElements(),
- adaptor.getInput());
+ adaptor.getSource());
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
source);
}
@@ -1089,11 +1109,11 @@ void mlir::populateVectorToSPIRVPatterns(
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorRed...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SPIR-V LGTM
ConversionPatternRewriter &rewriter) const override { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: spurious newline
Add conversion for broadcast from scalar for LLVM and SPIRV. Also some miscellaneous replacements of vector.splat with vector.broadcast in VectorToGPU and ArithToAMDGPU.
Part of deprecation of vector.splat
RFC: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4
More complete deprecation: #147818