[mlir][vector] Support direct broadcast conversion (LLVM & SPIRV) #148027

Open · wants to merge 2 commits into main
25 changes: 14 additions & 11 deletions mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -153,7 +153,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,

if (inVecType.getShape().empty()) {
Value zerodSplat =
rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
Value scalarExt =
@@ -166,7 +166,7 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,

VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
outType.getElementType());
Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);

if (inVecType.getRank() > 1) {
inVecType = VectorType::get(SmallVector<int64_t>{numElements},
@@ -315,7 +315,7 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,

VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
outVecType.getElementType());
Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);

if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -383,7 +383,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
int64_t numElements = outVecType.getNumElements();
Value zero = rewriter.createOrFold<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
Value result =
rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -478,8 +479,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
VectorType extScaleResultType = VectorType::get(opWidth, outType);

if (!outVecType) {
Value inCast =
rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
Value inCast = rewriter.create<vector::BroadcastOp>(
loc, VectorType::get(1, inType), in);
// TODO: replace this with non-packed ScaledExtOp
Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
loc, extScaleResultType, inCast, scale, 0);
@@ -509,7 +510,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,

Value zero = rewriter.create<arith::ConstantOp>(
loc, outType, rewriter.getFloatAttr(outType, 0.0));
Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
Value result =
rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
@@ -523,7 +525,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,

VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);

for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
i < blockSize;
@@ -587,7 +589,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,

if (!outVecType) {
Type inVecType = VectorType::get(1, inType);
Value inCast = rewriter.create<vector::SplatOp>(loc, inVecType, in);
Value inCast = rewriter.create<vector::BroadcastOp>(loc, inVecType, in);
// TODO: replace this with non-packed ScaledTruncOp
Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
@@ -616,7 +618,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,

int64_t blockSize = computeProduct(ratio);

Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
Value result =
rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
@@ -630,7 +633,7 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,

VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);

for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
i < blockSize;
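Every change in this file swaps a scalar `vector.splat` (of a zero constant or of the scalar input) for the equivalent scalar-to-vector `vector.broadcast`; only the op used to materialize the value changes. A minimal sketch of the equivalence, with illustrative types not taken from the patch's tests:

```mlir
// Before: accumulator materialized with vector.splat.
func.func @acc_splat(%zero: f32) -> vector<4xf32> {
  %acc = vector.splat %zero : vector<4xf32>
  return %acc : vector<4xf32>
}

// After: the same value expressed as a broadcast from a scalar.
func.func @acc_broadcast(%zero: f32) -> vector<4xf32> {
  %acc = vector.broadcast %zero : f32 to vector<4xf32>
  return %acc : vector<4xf32>
}
```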
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -792,7 +792,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
op.getLoc(), vectorType.getElementType(),
rewriter.getZeroAttr(vectorType.getElementType()));
Value result =
rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
rewriter.create<vector::BroadcastOp>(op.getLoc(), vectorType, fill);

bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();

94 changes: 61 additions & 33 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1408,7 +1408,7 @@ struct VectorScalableExtractOpLowering
/// ```
/// is rewritten into:
/// ```
/// %r = splat %f0: vector<2x4xf32>
/// %r = vector.broadcast %f0 : f32 to vector<2x4xf32>
/// %va = vector.extractvalue %a[0] : vector<2x4xf32>
/// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
/// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
@@ -1441,7 +1441,7 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
auto elemType = vType.getElementType();
Value zero = rewriter.create<arith::ConstantOp>(
loc, elemType, rewriter.getZeroAttr(elemType));
Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
Value desc = rewriter.create<vector::BroadcastOp>(loc, vType, zero);
for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
@@ -1583,7 +1583,7 @@ class VectorCreateMaskOpConversion
/*isScalable=*/true));
auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
adaptor.getOperands()[0]);
Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
Value bounds = rewriter.create<BroadcastOp>(loc, indices.getType(), bound);
Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
indices, bounds);
rewriter.replaceOp(op, comp);
@@ -1767,63 +1767,77 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
}
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 0-d and 1-d vector result types are lowered.
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
/// A broadcast of a scalar is lowered to an insertelement + a shufflevector
/// operation. Only broadcasts to 0-d and 1-d vectors are lowered by this
/// pattern, the higher rank cases are handled by another pattern.
struct VectorBroadcastScalarToLowRankLowering
: public ConvertOpToLLVMPattern<vector::BroadcastOp> {
using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
matchAndRewrite(vector::BroadcastOp broadcast, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType resultType = cast<VectorType>(splatOp.getType());
if (isa<VectorType>(broadcast.getSourceType()))
return rewriter.notifyMatchFailure(
broadcast, "broadcast from vector type not handled");

VectorType resultType = broadcast.getType();
if (resultType.getRank() > 1)
return failure();
return rewriter.notifyMatchFailure(broadcast,
"broadcast to 2+-d handled elsewhere");

// First insert it into a poison vector so we can shuffle it.
auto vectorType = typeConverter->convertType(splatOp.getType());
auto vectorType = typeConverter->convertType(broadcast.getType());
Value poison =
rewriter.create<LLVM::PoisonOp>(splatOp.getLoc(), vectorType);
rewriter.create<LLVM::PoisonOp>(broadcast.getLoc(), vectorType);
auto zero = rewriter.create<LLVM::ConstantOp>(
splatOp.getLoc(),
broadcast.getLoc(),
typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));

// For 0-d vector, we simply do `insertelement`.
if (resultType.getRank() == 0) {
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
splatOp, vectorType, poison, adaptor.getInput(), zero);
broadcast, vectorType, poison, adaptor.getSource(), zero);
return success();
}

// For 1-d vector, we additionally do a `vectorshuffle`.
auto v = rewriter.create<LLVM::InsertElementOp>(
splatOp.getLoc(), vectorType, poison, adaptor.getInput(), zero);
broadcast.getLoc(), vectorType, poison, adaptor.getSource(), zero);

int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
SmallVector<int32_t> zeroValues(width, 0);

// Shuffle the value across the desired number of elements.
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, poison,
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
zeroValues);
return success();
}
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 2+-d vector result types are lowered by the
/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
/// The broadcast of a scalar is lowered to an insertelement + a shufflevector
/// operation. Only broadcasts to 2+-d vector result types are lowered by this
/// pattern, the 1-d case is handled by another pattern. Broadcasts from vectors
/// are not converted to LLVM, only broadcasts from scalars are.
struct VectorBroadcastScalarToNdLowering
: public ConvertOpToLLVMPattern<BroadcastOp> {
using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
matchAndRewrite(BroadcastOp broadcast, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType resultType = splatOp.getType();
if (isa<VectorType>(broadcast.getSourceType()))
return rewriter.notifyMatchFailure(
broadcast, "broadcast from vector type not handled");

VectorType resultType = broadcast.getType();
if (resultType.getRank() <= 1)
return failure();
return rewriter.notifyMatchFailure(
broadcast, "broadcast to 1-d or 0-d handled elsewhere");

// First insert it into an undef vector so we can shuffle it.
auto loc = splatOp.getLoc();
auto loc = broadcast.getLoc();
auto vectorTypeInfo =
LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
@@ -1834,26 +1848,26 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
// Construct returned value.
Value desc = rewriter.create<LLVM::PoisonOp>(loc, llvmNDVectorTy);

// Construct a 1-D vector with the splatted value that we insert in all the
// places within the returned descriptor.
// Construct a 1-D vector with the broadcasted value that we insert in all
// the places within the returned descriptor.
Value vdesc = rewriter.create<LLVM::PoisonOp>(loc, llvm1DVectorTy);
auto zero = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
adaptor.getInput(), zero);
adaptor.getSource(), zero);

// Shuffle the value across the desired number of elements.
int64_t width = resultType.getDimSize(resultType.getRank() - 1);
SmallVector<int32_t> zeroValues(width, 0);
v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);

// Iterate of linear index, convert to coords space and insert splatted 1-D
// vector in each position.
// Iterate of linear index, convert to coords space and insert broadcasted
// 1-D vector in each position.
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
});
rewriter.replaceOp(splatOp, desc);
rewriter.replaceOp(broadcast, desc);
return success();
}
};
@@ -2035,6 +2049,19 @@ struct VectorScalableStepOpLowering
}
};

/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
/// `vector.broadcast` through other patterns.
struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
adaptor.getInput());
return success();
}
};

} // namespace

void mlir::vector::populateVectorRankReducingFMAPattern(
@@ -2063,7 +2090,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
VectorSplatOpLowering, VectorSplatNdOpLowering,
VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
VectorBroadcastScalarToNdLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering,
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
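The renamed patterns keep the existing splat lowering strategy, now keyed off `vector.broadcast` restricted to scalar sources. A hand-written sketch of the 0-d/1-d path that `VectorBroadcastScalarToLowRankLowering` implements, illustrative rather than copied from the conversion tests:

```mlir
// Input to --convert-vector-to-llvm:
func.func @broadcast_scalar(%s: f32) -> vector<4xf32> {
  %v = vector.broadcast %s : f32 to vector<4xf32>
  return %v : vector<4xf32>
}

// Roughly the LLVM-dialect sequence the pattern builds:
//   %poison = llvm.mlir.poison : vector<4xf32>
//   %c0     = llvm.mlir.constant(0 : i32) : i32
//   %ins    = llvm.insertelement %s, %poison[%c0 : i32] : vector<4xf32>
//   %v      = llvm.shufflevector %ins, %poison [0, 0, 0, 0] : vector<4xf32>
```

For 2-d and higher results, `VectorBroadcastScalarToNdLowering` builds the same 1-D shuffled vector once and inserts it at every position of the aggregate descriptor; broadcasts from vector sources are left for `vector`-dialect patterns to unroll first.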
10 changes: 5 additions & 5 deletions mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -444,7 +444,7 @@ struct Strategy<TransferReadOp> {
Location loc = xferOp.getLoc();
auto bufferType = dyn_cast<ShapedType>(buffer.getType());
auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
auto vec = b.create<vector::BroadcastOp>(loc, vecType, xferOp.getPadding());
b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);

return Value();
@@ -1261,8 +1261,8 @@ struct UnrollTransferReadConversion
if (auto insertOp = getInsertOp(xferOp))
return insertOp.getDest();
Location loc = xferOp.getLoc();
return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
xferOp.getPadding());
return rewriter.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
xferOp.getPadding());
}

/// If the result of the TransferReadOp has exactly one user, which is a
@@ -1583,8 +1583,8 @@ struct Strategy1d<TransferReadOp> {
static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
// Inititalize vector with padding value.
Location loc = xferOp.getLoc();
return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
xferOp.getPadding());
return b.create<vector::BroadcastOp>(loc, xferOp.getVectorType(),
xferOp.getPadding());
}
};

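In the SCF lowering, every rewritten site initializes the transfer-read result, its buffer element, or the loop state with the padding value. A sketch of that initialization, with an assumed buffer shape for illustration:

```mlir
// The padding value is broadcast and stored as the initial buffer element.
func.func @init_with_padding(%pad: f32, %buffer: memref<4xvector<8xf32>>, %i: index) {
  %init = vector.broadcast %pad : f32 to vector<8xf32>
  memref.store %init, %buffer[%i] : memref<4xvector<8xf32>>
  return
}
```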
39 changes: 29 additions & 10 deletions mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -79,6 +79,20 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
}
};

// Convert `vector.splat` to `vector.broadcast`. There is a path from
// `vector.broadcast` to SPIRV via other patterns.
struct VectorSplatToBroadcast final
: public OpConversionPattern<vector::SplatOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
adaptor.getInput());
return success();
}
};

struct VectorBitcastConvert final
: public OpConversionPattern<vector::BitCastOp> {
using OpConversionPattern::OpConversionPattern;
@@ -556,22 +570,27 @@ struct VectorReductionFloatMinMax final
}
};

class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
class VectorScalarBroadcastPattern final
: public OpConversionPattern<vector::BroadcastOp> {
public:
using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (isa<VectorType>(op.getSourceType())) {
return rewriter.notifyMatchFailure(
op, "only conversion of 'broadcast from scalar' is supported");
}
Type dstType = getTypeConverter()->convertType(op.getType());
if (!dstType)
return failure();
if (isa<spirv::ScalarType>(dstType)) {
rewriter.replaceOp(op, adaptor.getInput());
rewriter.replaceOp(op, adaptor.getSource());
} else {
auto dstVecType = cast<VectorType>(dstType);
SmallVector<Value, 4> source(dstVecType.getNumElements(),
adaptor.getInput());
adaptor.getSource());
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
source);
}
@@ -1089,11 +1108,11 @@ void mlir::populateVectorToSPIRVPatterns(
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
VectorStepOpConvert>(typeConverter, patterns.getContext(),
PatternBenefit(1));
VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert,
VectorShuffleOpConvert, VectorInterleaveOpConvert,
VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern,
VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>(
typeConverter, patterns.getContext(), PatternBenefit(1));

// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
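On the SPIR-V side, a scalar broadcast still lowers to a single `spirv.CompositeConstruct` with the scalar repeated per lane (or to the scalar itself when the result type converts to a SPIR-V scalar). A sketch of the intended mapping, with illustrative types:

```mlir
// Input to --convert-vector-to-spirv:
func.func @broadcast_scalar(%s: f32) -> vector<4xf32> {
  %v = vector.broadcast %s : f32 to vector<4xf32>
  return %v : vector<4xf32>
}

// Roughly the replacement the pattern builds:
//   %v = spirv.CompositeConstruct %s, %s, %s, %s : (f32, f32, f32, f32) -> vector<4xf32>
```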