diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 68095b7bf5c59..612ae3fac2d77 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1364,6 +1364,35 @@ def GPU_ShuffleOp : GPU_Op<
   ];
 }
 
+def GPU_RotateOp : GPU_Op<
+    "rotate", [Pure, AllTypesMatch<["value", "rotateResult"]>]>,
+    Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width)>,
+    Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult)> {
+  let summary = "Rotate values within a subgroup.";
+  let description = [{
+    The "rotate" op moves values circularly across lanes (a.k.a., invocations,
+    work items) within the same subgroup. The `width` argument specifies the
+    number of lanes that participate in the rotation, and must be uniform
+    across all lanes. Further, the first `width` lanes of the subgroup must be
+    active.
+
+    Example:
+
+    ```mlir
+    %cst1 = arith.constant 1 : i32
+    %width = arith.constant 16 : i32
+    %1 = gpu.rotate %0, %cst1, %width : f32
+    ```
+
+    For lane `k`, returns the value from lane `(k + 1) % width`; with the
+    offset of 1 above, lane 0 receives the value from lane 1, and lane 15
+    receives the value from lane 0.
+  }];
+
+  let assemblyFormat = [{
+    $value `,` $offset `,` $width attr-dict `:` type($value)
+  }];
+}
+
 def GPU_BarrierOp : GPU_Op<"barrier"> {
   let summary = "Synchronizes all work items of a workgroup.";
   let description = [{
@@ -1919,6 +1948,79 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
   }];
 }
 
+def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract",
+    [Pure,
+     TypesMatchWith<"value type matches element type of mma_matrix",
+                    "matrix", "res",
+                    "::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()">]> {
+
+  let summary = "Extract a value from GPU warp by invocation and indices";
+
+  let description = [{
+    The `gpu.subgroup_mma_extract` operation extracts a value from an
+    `!gpu.mma_matrix` held by an invocation in a subgroup.
+
+    This operation takes `!gpu.mma_matrix` as its first operand. It is the
+    source matrix, distributed across the subgroup. The op returns the scalar
+    value stored in the current invocation. If an invocation holds multiple
+    packed values, use `indices` to specify the element to extract.
+
+    Example:
+
+    ```mlir
+    %c0 = arith.constant 0 : index
+    %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
+    ```
+  }];
+
+  let arguments = (ins GPU_MMAMatrix:$matrix, Variadic<Index>:$indices);
+
+  let results = (outs AnyIntegerOrFloat:$res);
+
+  let assemblyFormat = [{
+    $matrix`[`$indices`]` attr-dict `:` type($matrix) `->` type($res)
+  }];
+}
+
+def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert",
+    [Pure,
+     TypesMatchWith<"value type matches element type of mma_matrix",
+                    "matrix", "value",
+                    "::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()">]> {
+
+  let summary = "Insert a value into GPU warp by invocation and indices";
+
+  let description = [{
+    The `gpu.subgroup_mma_insert` operation inserts a value into an
+    `!gpu.mma_matrix` held by an invocation in a subgroup.
+
+    This operation takes a scalar value as its first operand and an
+    `!gpu.mma_matrix` as its second operand. The matrix is distributed across
+    the subgroup. The op inserts the scalar value into the portion of the
+    matrix stored in the current invocation. If an invocation holds multiple
+    packed values, use `indices` to specify the location at which to insert.
+
+    The op returns the `!gpu.mma_matrix` with the updated value.
+
+    Example:
+
+    ```mlir
+    %c0 = arith.constant 0 : index
+    %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp">
+            -> !gpu.mma_matrix<16x16xf16, "COp">
+    ```
+  }];
+
+  let arguments = (ins AnyIntegerOrFloat:$value, GPU_MMAMatrix:$matrix,
+                       Variadic<Index>:$indices);
+
+  let results = (outs GPU_MMAMatrix:$res);
+
+  let assemblyFormat = [{
+    $value`,` $matrix`[`$indices`]` attr-dict `:` type($value)`,` type($matrix) `->` type($res)
+  }];
+}
+
 def GPU_ElementwiseOpAddF  : I32EnumAttrCase<"ADDF",  0, "addf">;
 def GPU_ElementwiseOpMulF  : I32EnumAttrCase<"MULF",  1, "mulf">;
 def GPU_ElementwiseOpSUBF  : I32EnumAttrCase<"SUBF",  2, "subf">;
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 3cc64b82950b5..e96709e4b4a35 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -122,6 +122,16 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Pattern to convert a gpu.rotate op into a spirv.GroupNonUniformRotateKHR op.
+class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -458,6 +468,35 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Rotate
+//===----------------------------------------------------------------------===//
+
+LogicalResult GPURotateConversion::matchAndRewrite(
+    gpu::RotateOp rotateOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Require the rotate width to be the same as the target's subgroup size,
+  // given that for SPIR-V non-uniform subgroup ops, we cannot select the
+  // participating invocations.
+  auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
+  unsigned subgroupSize =
+      targetEnv.getAttr().getResourceLimits().getSubgroupSize();
+  IntegerAttr widthAttr;
+  if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) ||
+      widthAttr.getValue().getZExtValue() != subgroupSize)
+    return rewriter.notifyMatchFailure(
+        rotateOp, "rotate width and target subgroup size mismatch");
+
+  Location loc = rotateOp.getLoc();
+  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
+
+  Value result = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
+      loc, scope, adaptor.getValue(), adaptor.getOffset(), rotateOp.getWidth());
+
+  rewriter.replaceOp(rotateOp, result);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Group ops
 //===----------------------------------------------------------------------===//
@@ -733,7 +772,7 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                       RewritePatternSet &patterns) {
   patterns.add<
       GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
-      GPUReturnOpConversion, GPUShuffleConversion,
+      GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
       LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
       LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index df2da138d3b52..78d266693fc2a 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -111,6 +111,68 @@ struct WmmaConstantOpToSPIRVLowering final
   }
 };
 
+/// Converts GPU MMA ExtractOp to a SPIR-V CompositeExtract op on KHR/NV
+/// cooperative matrices.
+struct WmmaExtractOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaExtractOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaExtractOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value matrix = adaptor.getMatrix();
+    auto coopType =
+        getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
+            matrix.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    SmallVector<int32_t> intValues;
+    for (Value val : op.getIndices()) {
+      if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
+        intValues.push_back(static_cast<int32_t>(constOp.value()));
+      } else {
+        return rewriter.notifyMatchFailure(op, "Indices must be constants");
+      }
+    }
+
+    Type elementType = coopType.getElementType();
+    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+        op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
+    return success();
+  }
+};
+
+/// Converts GPU MMA InsertOp to a SPIR-V CompositeInsert op on KHR/NV
+/// cooperative matrices.
+struct WmmaInsertOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaInsertOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaInsertOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value value = adaptor.getValue();
+    Value matrix = adaptor.getMatrix();
+    auto coopType =
+        getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
+            matrix.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    SmallVector<int32_t> intValues;
+    for (Value val : op.getIndices()) {
+      if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
+        intValues.push_back(static_cast<int32_t>(constOp.value()));
+      } else {
+        return rewriter.notifyMatchFailure(op, "Indices must be constants");
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+        op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
+    return success();
+  }
+};
+
 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
 /// the default case.
 struct WmmaElementwiseOpToSPIRVDefaultLowering final
@@ -296,6 +358,7 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
   MLIRContext *context = patterns.getContext();
   patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
                khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
+               WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
                WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
   // Give the following patterns higher benefit to prevail over the default one.
   patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
diff --git a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
new file mode 100644
index 0000000000000..102c2fb01edb6
--- /dev/null
+++ b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>, #spirv.resource_limits<subgroup_size = 16>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL: spirv.func @rotate()
+  gpu.func @rotate() kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %[[CST8_I32:.*]] = spirv.Constant 8 : i32
+    // CHECK: %[[CST16_I32:.*]] = spirv.Constant 16 : i32
+    // CHECK: %[[CST_F32:.*]] = spirv.Constant 4.200000e+01 : f32
+    %offset = arith.constant 8 : i32
+    %width = arith.constant 16 : i32
+    %val = arith.constant 42.0 : f32
+
+    // CHECK: spirv.GroupNonUniformRotateKHR <Subgroup>, %[[CST_F32]], %[[CST8_I32]], cluster_size(%[[CST16_I32]])
+    %result = gpu.rotate %val, %offset, %width : f32
+    gpu.return
+  }
+}
+
+}
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
index 477f344b1ae5f..3e8a3b21e7e94 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
@@ -93,6 +93,33 @@ module attributes {
     gpu.return
   }
 
+  // CHECK-LABEL: spirv.func @gpu_wmma_extract_op
+  //  CHECK-SAME: %[[ARG0:.+]]: !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
+  gpu.func @gpu_wmma_extract_op(%m: !gpu.mma_matrix<16x16xf32, "AOp">,
+                                %ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+    // CHECK: spirv.CompositeExtract %[[ARG0]][0 : i32] : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
+    %c0 = arith.constant 0 : index
+    %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
+    memref.store %val, %ptr[%c0, %c0] : memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: spirv.func @gpu_wmma_insert_op
+  //  CHECK-SAME: %[[ARG0:.+]]: f16
+  //  CHECK-SAME: %[[ARG1:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+  gpu.func @gpu_wmma_insert_op(%val: f16,
+                               %m: !gpu.mma_matrix<16x16xf16, "COp">,
+                               %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+    // CHECK: spirv.CompositeInsert %[[ARG0]], %[[ARG1]][0 : i32] : f16 into !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    %c0 = arith.constant 0 : index
+    %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+    gpu.subgroup_mma_store_matrix %s0, %ptr[%c0, %c0] {leadDimension = 16 : index} :
+      !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+    gpu.return
+  }
+
   // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
   //  CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
   //  CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
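
As a usage sketch for the `gpu.rotate` op defined in GPUOps.td above (not part of the patch; the kernel and module names are illustrative, and the semantics assumed are those of SPIR-V's OpGroupNonUniformRotateKHR, i.e. lane `k` reads the value of lane `(k + offset) % width`):

```mlir
gpu.module @example_kernels {
  // Hypothetical kernel: every lane reads the value held by its neighbour.
  gpu.func @rotate_example(%arg0: f32) kernel {
    %offset = arith.constant 1 : i32
    %width = arith.constant 16 : i32
    // With offset 1 and width 16, lane 0 receives lane 1's value and
    // lane 15 receives lane 0's value.
    %rotated = gpu.rotate %arg0, %offset, %width : f32
    gpu.return
  }
}
```

When lowered with `-convert-gpu-to-spirv`, `GPURotateConversion` above additionally requires the constant `%width` to equal the target environment's subgroup size (16 in the new rotate.mlir test).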
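Similarly, the new `gpu.subgroup_mma_extract` / `gpu.subgroup_mma_insert` ops can be used as a pair to update individual elements of a fragment. A minimal sketch, where the kernel name and the assumption that index 0 addresses the only packed element are hypothetical:

```mlir
gpu.module @example_kernels {
  // Hypothetical kernel: double the accumulator element held by this invocation.
  gpu.func @scale_acc(%acc: !gpu.mma_matrix<16x16xf16, "COp">) kernel {
    %c0 = arith.constant 0 : index
    %two = arith.constant 2.0 : f16
    // Pull the packed element out of the fragment, update it, and write it back.
    %elem = gpu.subgroup_mma_extract %acc[%c0] : !gpu.mma_matrix<16x16xf16, "COp"> -> f16
    %scaled = arith.mulf %elem, %two : f16
    %updated = gpu.subgroup_mma_insert %scaled, %acc[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
    gpu.return
  }
}
```

With the patterns registered in `populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns`, the extract and insert lower to `spirv.CompositeExtract` and `spirv.CompositeInsert` on the corresponding `!spirv.coopmatrix` value, as exercised by the tests above.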