
[mlir][xegpu] Handle scalar uniform ops in SIMT distribution. #138593


Merged

Conversation

charithaintc
Contributor

@charithaintc charithaintc commented May 5, 2025

This PR adds support for moving scalar uniform ops (GPU index ops, constants, etc.) outside the gpu.warp_execute_on_lane_0 op. These ops do not require distribution and are safe to move out of the warp op, which also avoids adding separate distribution patterns for them. A sketch of the hoisting criterion is given after the example below.

Example:

   %1 = gpu.warp_execute_on_lane_0(%laneid) -> (index) {
     ...
     %block_id_x = gpu.block_id x
     gpu.yield %block_id_x
   }
  // use %1

To:

   %block_id_x = gpu.block_id x
   %1 = gpu.warp_execute_on_lane_0(%laneid) -> (index) {
     ...
     
     gpu.yield %block_id_x
   }
  // use %1
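
For reference, the hoisting decision boils down to a simple legality check: an op can be moved above the warp op if it produces no vector results, has no memory side effects, and all of its operands are already visible above the warp region. The C++ below is a minimal, conservative sketch of that criterion, not the upstream implementation; the helper name canHoistAboveWarpOp is invented for illustration, and the real utility additionally allows operands produced by other ops that are themselves being hoisted.

   // Illustrative sketch only; not the actual upstream logic.
   #include "mlir/Dialect/GPU/IR/GPUDialect.h"
   #include "mlir/IR/BuiltinTypes.h"
   #include "mlir/Interfaces/SideEffectInterfaces.h"
   #include "llvm/ADT/STLExtras.h"

   using namespace mlir;

   // Hypothetical helper: true if `op` is a scalar uniform op that can be
   // safely moved above the surrounding warp op.
   static bool canHoistAboveWarpOp(Operation *op,
                                   gpu::WarpExecuteOnLane0Op warpOp) {
     // Ops producing vector results need real SIMT distribution; skip them.
     if (llvm::any_of(op->getResults(),
                      [](Value v) { return isa<VectorType>(v.getType()); }))
       return false;
     // Only side-effect-free ops (index ops, scalar constants, ...) are safe.
     if (!isMemoryEffectFree(op))
       return false;
     // Every operand must be defined outside the warp op's single region.
     Region &body = warpOp->getRegion(0);
     return llvm::all_of(op->getOperands(), [&](Value v) {
       return !body.isAncestor(v.getParentRegion());
     });
   }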

@charithaintc charithaintc changed the title [mlir][xegpu] Handle GPU index ops in SIMT distribution. [mlir][xegpu] Handle scalar uniform ops in SIMT distribution. May 6, 2025
@charithaintc charithaintc requested review from adam-smnk and chencha3 May 6, 2025 02:59
@charithaintc
Contributor Author

@Garra1980 Can you take a look?

@charithaintc charithaintc marked this pull request as ready for review May 6, 2025 03:01
@llvmbot
Member

llvmbot commented May 6, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Charitha Saumya (charithaintc)

Full diff: https://github.com/llvm/llvm-project/pull/138593.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp (+5)
  • (modified) mlir/test/Dialect/XeGPU/subgroup-distribution.mlir (+48-1)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 019032f7743bf..d4e57cfb4b4a5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1008,6 +1008,11 @@ struct MoveFuncBodyToWarpExecuteOnLane0
     rewriter.setInsertionPointAfter(warpOp);
     rewriter.create<gpu::ReturnOp>(newGpuFunc.getLoc(), warpOp.getResults());
     rewriter.replaceOp(gpuFuncOp, newGpuFunc);
+    // At this point, we have moved the entire function body inside the warpOp.
+    // Now move any scalar uniform code outside of the warpOp (like GPU index
+    // ops, scalar constants, etc.). This will simplify the later lowering and
+    // avoid custom patterns for these ops.
+    vector::moveScalarUniformCode(warpOp);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
index f8f2cd55c28d0..4e50771c26283 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -xegpu-subgroup-distribute -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xegpu-subgroup-distribute -cse -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: gpu.func @store_nd_1d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>) {
@@ -160,3 +160,50 @@ gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64,
   gpu.return
 }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @gemm_loop
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
+// CHECK: %[[BLOCK_ID_X:.*]] = gpu.block_id x
+// CHECK: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
+// CHECK: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
+// CHECK: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
+// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%[[X_COORD]], %[[Y_COORD]]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+// CHECK: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: %[[T5:.*]] = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[T4]]) -> (vector<8x1xf32>) {
+// CHECK: %[[T10:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%[[K]], %[[Y_COORD]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+// CHECK: %[[T11:.*]] = xegpu.load_nd %[[T10]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
+// CHECK: %[[T12:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[X_COORD]], %[[K]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+// CHECK: %[[T13:.*]] = xegpu.load_nd %[[T12]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
+// CHECK: %[[T14:.*]] = vector.shape_cast %[[ARG4]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: %[[T15:.*]] = xegpu.dpas %[[T13]], %[[T11]], %[[T14]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
+// CHECK: %[[T16:.*]] = vector.shape_cast %[[T15]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: scf.yield %[[T16]] : vector<8x1xf32>
+// CHECK: }
+// CHECK: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: xegpu.store_nd %[[T9]], %[[T2]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+gpu.module @test {
+gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c8 = arith.constant 8 : index
+  %c1024 = arith.constant 1024 : index
+  %0 = gpu.block_id x
+  %1 = gpu.block_id y
+  %2 = arith.muli %0, %c8 : index
+  %3 = arith.muli %1, %c16 : index
+  %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+  %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+  %6 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) {
+    %7 = xegpu.create_nd_tdesc %arg0[%2, %arg3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+    %8 = xegpu.create_nd_tdesc %arg1[%arg3, %3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+    %9 = xegpu.load_nd %7 : !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16>
+    %10 = xegpu.load_nd %8 : !xegpu.tensor_desc<16x16xbf16> -> vector<16x16xbf16>
+    %11 = xegpu.dpas %9, %10, %arg4 : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
+    scf.yield %11 : vector<8x16xf32>
+  }
+  xegpu.store_nd %6, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  gpu.return
+}
+}
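
As a usage note on the C++ change above: the fix reuses the vector distribution utility instead of adding a new XeGPU-specific pattern. Below is a hedged sketch of applying the same utility from a standalone walk over a function; it assumes the utility is declared in mlir/Dialect/Vector/Transforms/VectorDistribution.h and takes a gpu::WarpExecuteOnLane0Op, and the function name hoistUniformScalars is invented for illustration.

   // Sketch only: hoist scalar uniform code for every warp op in a gpu.func.
   #include "mlir/Dialect/GPU/IR/GPUDialect.h"
   #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
   #include "llvm/ADT/SmallVector.h"

   using namespace mlir;

   static void hoistUniformScalars(gpu::GPUFuncOp funcOp) {
     // Collect first, then mutate, so the walk is not invalidated by moves.
     SmallVector<gpu::WarpExecuteOnLane0Op> warpOps;
     funcOp.walk([&](gpu::WarpExecuteOnLane0Op op) { warpOps.push_back(op); });
     for (auto warpOp : warpOps)
       vector::moveScalarUniformCode(warpOp);
   }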

@charithaintc charithaintc requested a review from fschlimb May 6, 2025 03:07
Contributor

@chencha3 chencha3 left a comment


LGTM

@Garra1980

Thanks, this looks great

@charithaintc charithaintc requested a review from mshahneo May 7, 2025 15:15
Contributor

@mshahneo mshahneo left a comment


LGTM.

@charithaintc charithaintc merged commit 7a66746 into llvm:main May 8, 2025
5 of 9 checks passed
lenary added a commit to lenary/llvm-project that referenced this pull request May 9, 2025