
[rfc][mlir][gpu] Add operations to extract/insert/rotate within subgroup #139048

Open · wants to merge 1 commit into main
102 changes: 102 additions & 0 deletions mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1364,6 +1364,35 @@ def GPU_ShuffleOp : GPU_Op<
];
}

def GPU_RotateOp : GPU_Op<
"rotate", [Pure, AllTypesMatch<["value", "rotateResult"]>]>,
Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width)>,
Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult)> {
let summary = "Rotate values within a subgroup.";
let description = [{
The "rotate" op moves values to a across lanes circularly (a.k.a.,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The "rotate" op moves values to a across lanes circularly (a.k.a.,
The "rotate" op moves values across lanes circularly (a.k.a.,

?

invocations, work items) within the same subgroup. The `width` argument
specifies the number of lanes that participate in the rotation, and must
be uniform across all lanes. Furthermore, the first `width` lanes of the
subgroup must be active.

Example:

```mlir
%cst1 = arith.constant 1 : i32
%width = arith.constant 16 : i32
%1 = gpu.rotate %0, %cst1, %width : f32
```

For lane 0 < `k` < 16, return the value from lane `(k - 1) % width`.
For lane k == 0, return the value from lane 15.
Comment on lines +1387 to +1388

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
For lane 0 < `k` < 16, return the value from lane `(k - 1) % width`.
For lane k == 0, return the value from lane 15.
For lane `0 < k < 16`, return the value from lane `(k - 1) % width`.
For lane `k == 0`, return the value from lane 15.

}];

let assemblyFormat = [{
$value `,` $offset `,` $width attr-dict `:` type($value)
}];
}

def GPU_BarrierOp : GPU_Op<"barrier"> {
let summary = "Synchronizes all work items of a workgroup.";
let description = [{
@@ -1919,6 +1948,79 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
}];
}

def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract",
[Pure,
TypesMatchWith<"value type matches element type of mma_matrix",
"matrix", "res",
"::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()">]>{

let summary = "Extract a value from GPU warp by invocation and indices";

let description = [{
The `gpu.subgroup_mma_extract` operation extracts a scalar value from a
`!gpu.mma_matrix` held by an invocation in a subgroup.

This operation takes a `!gpu.mma_matrix` as its first operand, which is the
source matrix distributed across the subgroup. The op returns the scalar value
stored in the current invocation. If an invocation holds multiple packed
values, use `indices` to select the element to extract.

Example:

```mlir
%c0 = arith.constant 0 : index
%val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
```
}];

let arguments = (ins GPU_MMAMatrix:$matrix, Variadic<Index>:$indices);

let results = (outs AnyIntegerOrFloat:$res);

let assemblyFormat = [{
$matrix`[`$indices`]` attr-dict `:` type($matrix) `->` type($res)
}];
}

def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert",
[Pure,
TypesMatchWith<"value type matches element type of mma_matrix",
"matrix", "value",
"::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()"> ]>{

let summary = "Insert a value into GPU warp by invocation and indices";

let description = [{
The `gpu.subgroup_mma_insert` operation inserts a value into a
`!gpu.mma_matrix` held by an invocation in a subgroup.

This operation takes a scalar value as its first operand and a
`!gpu.mma_matrix`, distributed across the subgroup, as its second operand. The
op inserts the scalar value held by the current invocation into the matrix. If
an invocation holds multiple packed values, use `indices` to specify the
position within the packing to insert into.

The op returns `!gpu.mma_matrix` with the updated value.

Example:

```mlir
%c0 = arith.constant 0 : index
%s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp">
-> !gpu.mma_matrix<16x16xf16, "COp">
```
}];

let arguments = (ins AnyIntegerOrFloat:$value, GPU_MMAMatrix:$matrix,
Variadic<Index>:$indices);

let results = (outs GPU_MMAMatrix:$res);

let assemblyFormat = [{
$value`,` $matrix`[`$indices`]` attr-dict `:` type($value)`,` type($matrix) `->` type($res)
}];
}

def GPU_ElementwiseOpAddF : I32EnumAttrCase<"ADDF", 0, "addf">;
def GPU_ElementwiseOpMulF : I32EnumAttrCase<"MULF", 1, "mulf">;
def GPU_ElementwiseOpSUBF : I32EnumAttrCase<"SUBF", 2, "subf">;
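
Since `$value` is typed as `AnyIntegerOrFloatOr1DVector`, the new op also accepts 1-D vector payloads. Below is a minimal usage sketch in the same fragment style as the examples in the op descriptions; the SSA names are hypothetical and only the semantics documented above are assumed.

```mlir
// Rotate a whole per-lane vector by 4 lanes within a 32-lane subgroup.
// %vec is assumed to be an existing vector<4xf32> value.
%offset = arith.constant 4 : i32
%width = arith.constant 32 : i32
%rotated = gpu.rotate %vec, %offset, %width : vector<4xf32>
```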
41 changes: 40 additions & 1 deletion mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -122,6 +122,16 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.rotate op into a spirv.GroupNonUniformRotateKHR op.
class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -458,6 +468,35 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
return success();
}

//===----------------------------------------------------------------------===//
// Rotate
//===----------------------------------------------------------------------===//

LogicalResult GPURotateConversion::matchAndRewrite(
gpu::RotateOp rotateOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Require the rotate width to be the same as the target's subgroup size,
// given that for SPIR-V non-uniform subgroup ops, we cannot select
// participating invocations.
auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
unsigned subgroupSize =
targetEnv.getAttr().getResourceLimits().getSubgroupSize();
IntegerAttr widthAttr;
if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) ||
widthAttr.getValue().getZExtValue() != subgroupSize)
return rewriter.notifyMatchFailure(
rotateOp, "rotate width and target subgroup size mismatch");

Location loc = rotateOp.getLoc();
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);

Value result = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
loc, scope, adaptor.getValue(), adaptor.getOffset(), rotateOp.getWidth());

rewriter.replaceOp(rotateOp, result);
return success();
}

//===----------------------------------------------------------------------===//
// Group ops
//===----------------------------------------------------------------------===//
@@ -733,7 +772,7 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<
GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
GPUReturnOpConversion, GPUShuffleConversion,
GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
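
To make the constraint in `GPURotateConversion` concrete, here is a hedged before/after sketch: the pattern only fires when the `width` operand is a compile-time constant equal to the target's `subgroup_size` (16 in the resource limits assumed below); otherwise it reports a match failure and the op is left untouched.

```mlir
// Assumes a target env with #spirv.resource_limits<subgroup_size = 16>.
%offset = arith.constant 8 : i32
%width = arith.constant 16 : i32
%val = arith.constant 42.0 : f32
%result = gpu.rotate %val, %offset, %width : f32

// After -convert-gpu-to-spirv the op becomes (sketch, types abbreviated):
//   spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %offset, cluster_size(%width)
// A %width of, say, 8 on the same target would not convert.
```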
63 changes: 63 additions & 0 deletions mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -111,6 +111,68 @@ struct WmmaConstantOpToSPIRVLowering final
}
};

/// Converts GPU MMA ExtractOp to CompositeExtract SPIR-V KHR/NV cooperative
/// matrix ops.
struct WmmaExtractOpToSPIRVLowering final
: OpConversionPattern<gpu::SubgroupMmaExtractOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaExtractOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value matrix = adaptor.getMatrix();
auto coopType =
getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
matrix.getType());
if (!coopType)
return rewriter.notifyMatchFailure(op, "type conversion failed");

SmallVector<int32_t> intValues;
for (Value val : op.getIndices()) {
if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
intValues.push_back(static_cast<int32_t>(constOp.value()));
} else {
return rewriter.notifyMatchFailure(op, "Indices must be constants");

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return rewriter.notifyMatchFailure(op, "Indices must be constants");
return rewriter.notifyMatchFailure(op, "indices must be constants");

}
}

Type elementType = coopType.getElementType();
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
return success();
}
};

/// Converts GPU MMA InsertOp to CompositeInsert SPIR-V KHR/NV cooperative
/// matrix ops.
struct WmmaInsertOpToSPIRVLowering final
: OpConversionPattern<gpu::SubgroupMmaInsertOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaInsertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value value = adaptor.getValue();
Value matrix = adaptor.getMatrix();
auto coopType = getTypeConverter()->convertType(matrix.getType());
if (!coopType)
return rewriter.notifyMatchFailure(op, "type conversion failed");

SmallVector<int32_t> intValues;
for (Value val : op.getIndices()) {
if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
intValues.push_back(static_cast<int32_t>(constOp.value()));
} else {
return rewriter.notifyMatchFailure(op, "Indices must be constants");
}
}

rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
return success();
}
};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the default case.
struct WmmaElementwiseOpToSPIRVDefaultLowering final
@@ -296,6 +358,7 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
MLIRContext *context = patterns.getContext();
patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
// Give the following patterns higher benefit to prevail over the default one.
patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
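
Both WMMA patterns above only fire when every entry of `indices` is produced by `arith.constant`, since the values are folded into the `i32` array attribute of the SPIR-V composite ops; dynamic indices make the patterns report a match failure. A minimal sketch of the expected rewrite (fragment style, `%m` assumed to be an existing accumulator matrix):

```mlir
%c0 = arith.constant 0 : index
%elem = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf16, "COp"> -> f16
%upd = gpu.subgroup_mma_insert %elem, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">

// After conversion (sketch):
//   %elem = spirv.CompositeExtract %m[0 : i32] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
//   %upd = spirv.CompositeInsert %elem, %m[0 : i32] : f16 into !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
```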
25 changes: 25 additions & 0 deletions mlir/test/Conversion/GPUToSPIRV/rotate.mlir
@@ -0,0 +1,25 @@
// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s

module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>, #spirv.resource_limits<subgroup_size = 16>>
} {

gpu.module @kernels {
// CHECK-LABEL: spirv.func @rotate()
gpu.func @rotate() kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [4, 4, 1]>} {
// CHECK: %[[CST8_I32:.*]] = spirv.Constant 8 : i32
// CHECK: %[[CST16_I32:.*]] = spirv.Constant 16 : i32
// CHECK: %[[CST_F32:.*]] = spirv.Constant 4.200000e+01 : f32
%offset = arith.constant 8 : i32
%width = arith.constant 16 : i32
%val = arith.constant 42.0 : f32

// CHECK: spirv.GroupNonUniformRotateKHR <Subgroup>, %[[CST_F32]], %[[CST8_I32]], cluster_size(%[[CST16_I32]])
%result = gpu.rotate %val, %offset, %width : f32
gpu.return
}
}

}
@@ -93,6 +93,33 @@ module attributes {
gpu.return
}

// CHECK-LABEL: spirv.func @gpu_wmma_extract_op
// CHECK-SAME: %[[ARG0:.+]]: !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
gpu.func @gpu_wmma_extract_op(%m: !gpu.mma_matrix<16x16xf32, "AOp">,
%ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
// CHECK: spirv.CompositeExtract %[[ARG0]][0 : i32] : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
%c0 = arith.constant 0 : index
%val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
memref.store %val, %ptr[%c0, %c0] : memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
gpu.return
}

// CHECK-LABEL: spirv.func @gpu_wmma_insert_op
// CHECK-SAME: %[[ARG0:.+]]: f16
// CHECK-SAME: %[[ARG1:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
gpu.func @gpu_wmma_insert_op(%val: f16,
%m: !gpu.mma_matrix<16x16xf16, "COp">,
%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
// CHECK: spirv.CompositeInsert %[[ARG0]], %[[ARG1]][0 : i32] : f16 into !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%c0 = arith.constant 0 : index
%s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
gpu.subgroup_mma_store_matrix %s0, %ptr[%c0,%c0] {leadDimension = 16 : index} :
!gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
gpu.return
}

// CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>