Skip to content

[mlir] [XeGPU] Add XeGPU workgroup to subgroup pass #139477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

nbpatel
Copy link
Contributor

@nbpatel nbpatel commented May 11, 2025

This PR adds the XeGPU workgroup (wg) to subgroup (sg) pass. The wg to sg pass transforms the xegpu wg level operations to subgroup operations based on the sg_layout and sg_data attribute. The PR adds transformation patterns for following Ops

  1. CreateNdDesc
  2. LoadNd
  3. StoreNd
  4. PrefetchNd
  5. UpdateNdOffset
  6. Dpas

@llvmbot
Copy link
Member

llvmbot commented May 11, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

This PR adds the XeGPU workgroup (wg) to subgroup (sg) pass. The wg to sg pass transforms the xegpu wg level operations to subgroup operations based on the sg_layout and sg_data attribute. The PR adds transformation patterns for following Ops

  1. CreateNdDesc
  2. LoadNd
  3. StoreNd
  4. PrefetchNd
  5. UpdateNdOffset
  6. Dpas

Patch is 32.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/139477.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td (+19-12)
  • (modified) mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h (+4)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp (+386)
  • (added) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir (+66)
  • (added) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir (+82)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 3e81f2d0ed786..bdea88cfd7022 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -6,7 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-
 #ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
 #define MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
 
@@ -18,9 +17,7 @@ def XeGPUFoldAliasOps : Pass<"xegpu-fold-alias-ops"> {
     The pass folds aliasing ops into XeGPU ops that they operate on the original
     source references.
   }];
-  let dependentDialects = [
-      "memref::MemRefDialect", "xegpu::XeGPUDialect"
-  ];
+  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect"];
 }
 
 def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
@@ -28,14 +25,24 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
   let description = [{
     The pass distributes subgroup level (SIMD) XeGPU ops to work items.
   }];
-  let dependentDialects = [
-      "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect"
-  ];
-  let options = [
-    Option<"printOnly", "print-analysis-only", "bool",
-         /*default=*/"false",
-         "Print the result of the subgroup map propagation analysis and exit.">
-  ];
+  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
+                           "vector::VectorDialect"];
+  let options = [Option<
+      "printOnly", "print-analysis-only", "bool",
+      /*default=*/"false",
+      "Print the result of the subgroup map propagation analysis and exit.">];
+}
+
+def XeGPUWgToSg : Pass<"xegpu-wg-to-sg", "::mlir::gpu::GPUModuleOp"> {
+  let summary = "Transform WorkGroup level XeGPU code to SubGroup level";
+  let description = [{
+    This transform pass distributes the workgroup level computation to
+    multiple subgroups based on the sg_layout and sg_data attributes.
+  }];
+
+  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
+                           "vector::VectorDialect", "arith::ArithDialect",
+                           "gpu::GPUDialect", "index::IndexDialect"];
 }
 
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 3e94021c7a1ea..388ba32e1eebb 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -9,6 +9,8 @@
 #ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
 
+#include "mlir/Transforms/DialectConversion.h"
+
 namespace mlir {
 class RewritePatternSet;
 
@@ -18,6 +20,8 @@ namespace xegpu {
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
 /// Appends patterns for XeGPU SIMT distribution into `patterns`.
 void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
+void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns,
+                                 ConversionTarget &target);
 
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 901e02d3c9cf5..b258921cc87fd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUFoldAliasOps.cpp
   XeGPUSubgroupDistribute.cpp
+  XeGPUWgToSg.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
new file mode 100644
index 0000000000000..5eabb04e3b858
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
@@ -0,0 +1,386 @@
+//===- XeGPUWgToSg.cpp - XeGPU WorkGroup to Subgroup Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Debug.h"
+#include <mlir/Dialect/GPU/IR/GPUDialect.h>
+#include <mlir/Dialect/Index/IR/IndexOps.h>
+#include <numeric>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUWGTOSG
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-wg-to-sg"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+using namespace mlir;
+
+namespace {
+
+// clang-format off
+/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
+/// from a workgroup descriptor. It replaces the offsets and sizes with
+/// appropriate values for the subgroup.
+/// It uses round-robin assignment to distribute the work to the subgroups.
+/// Following create_nd_desc operation:,
+///    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
+///       -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
+///           sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+/// is converted to 9 subgroup level operations based on the sg_layout & sg_data:
+///    %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
+///           !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+///
+/// The sg_layout and sg_data attributes are dropped after the pass as they are no longer needed.
+///
+/// 24x24 matrix distribution example:
+/// sg_layout = [4, 4], sg_data = [2, 2]
+/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
+/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
+///
+/// +------------------------+
+/// | 8x8 | 8x8 | 8x8 |      <- 3 tiles across
+/// |-----+-----+-----|
+/// | 8x8 | 8x8 | 8x8 |      <- 3 tiles down
+/// |-----+-----+-----|
+/// | 8x8 | 8x8 | 8x8 |
+/// +------------------------+
+///
+/// Each 8x8 tile is further subdivided among subgroups:
+/// +------------------------+
+/// | 2x2 2x2 2x2 2x2 |  <- 4 subgroups across (each handles 2 columns)
+/// | 2x2 2x2 2x2 2x2 |  <- 4 subgroups down (each handles 2 rows)
+/// | 2x2 2x2 2x2 2x2 |
+/// | 2x2 2x2 2x2 2x2 |
+/// +------------------------+
+///
+/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be 9
+/// distribution units (3x3) in total. Hence the 9 subgroup level operations.
+// clang-format on
+struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+  // Helper to extract mixed offsets into a Value array
+  SmallVector<Value> extractOffsets(ConversionPatternRewriter &rewriter,
+                                    xegpu::CreateNdDescOp op) const {
+    llvm::SmallVector<Value> offsets;
+    auto staticOffsets = op.getStaticOffsets();
+    auto dynamicOffsets = op.getOffsets();
+
+    for (size_t i = 0, j = 0; i != staticOffsets.size(); i++) {
+      if (ShapedType::isDynamic(staticOffsets[i])) {
+        offsets.push_back(dynamicOffsets[j++]);
+      } else {
+        offsets.push_back(rewriter.create<arith::ConstantIndexOp>(
+            op.getLoc(), staticOffsets[i]));
+      }
+    }
+    return offsets;
+  }
+
+  // Convert linear subgroup ID to 2D coordinates
+  // TODO: Delinearize for nD
+  SmallVector<Value> delinearizeSubgroupId(ConversionPatternRewriter &rewriter,
+                                           Location loc, Value sgID,
+                                           Value sgDimX, Value sgDimY) const {
+    return {rewriter.create<index::DivUOp>(loc, sgID, sgDimY),
+            rewriter.create<index::RemUOp>(loc, sgID, sgDimY)};
+  }
+
+  // Create a constant index value
+  Value createConstantIndex(ConversionPatternRewriter &rewriter, Location loc,
+                            int64_t value) const {
+    return rewriter.create<arith::ConstantIndexOp>(loc, value);
+  }
+
+  // Calculate offset for each subgroup
+  SmallVector<OpFoldResult>
+  calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
+                         const SmallVector<Value> &originalOffsets,
+                         const SmallVector<Value> &localOffset,
+                         const SmallVector<int64_t> &distUnitBaseAddr) const {
+
+    Value constOffsetX =
+        createConstantIndex(rewriter, loc, distUnitBaseAddr[0]);
+    Value constOffsetY =
+        createConstantIndex(rewriter, loc, distUnitBaseAddr[1]);
+
+    Value offsetX =
+        rewriter.createOrFold<index::AddOp>(loc, localOffset[0], constOffsetX);
+    Value offsetY =
+        rewriter.createOrFold<index::AddOp>(loc, localOffset[1], constOffsetY);
+
+    size_t lastDimIndex = originalOffsets.size() - 1;
+    size_t secondLastDimIndex = lastDimIndex - 1;
+
+    Value globalOffsetX = rewriter.createOrFold<index::AddOp>(
+        loc, originalOffsets[secondLastDimIndex], offsetX);
+    Value globalOffsetY = rewriter.createOrFold<index::AddOp>(
+        loc, originalOffsets[lastDimIndex], offsetY);
+
+    SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
+                                            originalOffsets.end());
+    globalOffsets[secondLastDimIndex] = globalOffsetX;
+    globalOffsets[lastDimIndex] = globalOffsetY;
+
+    return globalOffsets;
+  }
+
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *ctx = op.getContext();
+    xegpu::TensorDescType tdescTy = op.getType();
+    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+    Type elemTy = tdescTy.getElementType();
+    ArrayRef<int64_t> wgShape = tdescTy.getShape();
+    ArrayRef<int64_t> sgShape =
+        llvm::to_vector_of<int64_t>(layout.getSgData().asArrayRef());
+    ArrayRef<int64_t> sgLayout =
+        llvm::to_vector_of<int64_t>(layout.getSgLayout().asArrayRef());
+
+    // Get the subgroup ID
+    auto linearSgId = rewriter.create<gpu::SubgroupIdOp>(loc, nullptr);
+
+    // Create constants for layout dimensions
+    SmallVector<Value> sgLayoutDim(sgLayout.size());
+    SmallVector<Value> sgDataDim(sgShape.size());
+
+    for (size_t i = 0; i < sgLayout.size(); i++) {
+      sgLayoutDim[i] = createConstantIndex(rewriter, loc, sgLayout[i]);
+      sgDataDim[i] = createConstantIndex(rewriter, loc, sgShape[i]);
+    }
+
+    // Delinearize the 1D subgroup id into 2d
+    SmallVector<Value> sgIds = delinearizeSubgroupId(
+        rewriter, loc, linearSgId, sgLayoutDim[0], sgLayoutDim[1]);
+
+    // Calculate distribution unit shape and local offsets for subgroup
+    SmallVector<int64_t> distUnitShape(sgLayout.size());
+    SmallVector<Value> localOffset(sgLayout.size());
+    for (size_t i = 0; i < sgLayout.size(); i++) {
+      distUnitShape[i] = sgLayout[i] * sgShape[i];
+      localOffset[i] =
+          rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
+    }
+
+    SmallVector<Value> originalOffsets = extractOffsets(rewriter, op);
+
+    xegpu::TensorDescType newTdescTy =
+        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
+                                   layout.dropSgLayoutAndData());
+    SmallVector<Value> newCreateNdOps;
+    for (const SmallVector<int64_t> &distUnitBaseAddr :
+         StaticTileOffsetRange(wgShape, distUnitShape)) {
+      SmallVector<OpFoldResult> globalOffsets = calculateGlobalOffsets(
+          rewriter, loc, originalOffsets, localOffset, distUnitBaseAddr);
+
+      auto newCreateNdOp = rewriter.create<xegpu::CreateNdDescOp>(
+          loc, newTdescTy, op.getSource(), globalOffsets, op.getMixedSizes(),
+          op.getMixedStrides());
+      newCreateNdOps.push_back(newCreateNdOp);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
+    return success();
+  }
+};
+
+/// This pattern transforms the LoadNdOp to load subgroup data.
+struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> newLoadOps;
+    for (auto src : adaptor.getTensorDesc()) {
+      xegpu::TensorDescType tdescTy =
+          dyn_cast<xegpu::TensorDescType>(src.getType());
+      ArrayRef<int64_t> srcShape = tdescTy.getShape();
+      VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
+      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(op.getLoc(), newResTy,
+                                                        src, op->getAttrs());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return mlir::success();
+  }
+};
+
+/// This pattern transforms the StoreNdOp to store to a subgroup descriptor
+/// It creates a StoreNdOp op to store the updated values to the new subgroup
+/// src tensor descriptors.
+struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
+  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
+      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), v, t, op.getL1HintAttr(),
+                                        op.getL2HintAttr(), op.getL3HintAttr());
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
+/// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
+/// offsets of the new subgroup src tensor descriptors.
+struct WgToSgUpdateNdOffsetOp
+    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
+  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    llvm::SmallVector<Value> newUpdateTileOffsetOps;
+    for (auto tDesc : adaptor.getTensorDesc()) {
+      auto newUpdateTileOffsetOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+          op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
+          op.getConstOffsets());
+      newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
+    return success();
+  }
+};
+
+/// This pattern transforms the DpasOp to work at subgroup level.
+struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
+  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType resultTy = op.getResult().getType();
+    if (resultTy.getRank() != 2)
+      return failure();
+
+    auto originalLayout =
+        llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    if (!originalLayout)
+      return failure();
+
+    SmallVector<Value> newDpasOps;
+    size_t i = 0;
+    for (auto aVec : adaptor.getLhs()) {
+      for (auto bVec : adaptor.getRhs()) {
+
+        llvm::SmallVector<Value> operands({aVec, bVec});
+        Value tmpC;
+        if (op.getAcc()) {
+          tmpC = adaptor.getAcc()[i++];
+          operands.push_back(tmpC);
+        }
+
+        ArrayRef<int64_t> aVecShape =
+            llvm::cast<VectorType>(aVec.getType()).getShape();
+        ArrayRef<int64_t> bVecShape =
+            llvm::cast<VectorType>(bVec.getType()).getShape();
+        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
+                                           resultTy.getElementType());
+        tmpC = rewriter.create<xegpu::DpasOp>(
+            loc, resTy, operands,
+            llvm::ArrayRef<NamedAttribute>(
+                {"layout", originalLayout.dropSgLayoutAndData()}));
+        newDpasOps.push_back(tmpC);
+      }
+    }
+    rewriter.replaceOpWithMultiple(op, {newDpasOps});
+    return success();
+  }
+};
+
+/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
+struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
+  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    for (auto src : adaptor.getTensorDesc()) {
+      rewriter.create<xegpu::PrefetchNdOp>(op.getLoc(), TypeRange(), src,
+                                           op->getAttrs());
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace xegpu {
+void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns) {
+  patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
+               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
+      patterns.getContext());
+}
+} // namespace xegpu
+} // namespace mlir
+
+namespace {
+struct XeGPUWgToSgPass : public xegpu::impl::XeGPUWgToSgBase<XeGPUWgToSgPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void XeGPUWgToSgPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  RewritePatternSet patterns(ctx);
+  ConversionTarget target(*ctx);
+
+  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
+    if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
+      return createOp.getType();
+    if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
+      return loadOp.getTensorDescType();
+    if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
+      return storeOp.getTensorDescType();
+    if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
+      return updateOp.getType();
+    if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
+      return prefetchOp.getTensorDescType();
+    return xegpu::TensorDescType();
+  };
+
+  auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
+    return !layout || layout.getSgLayout() == nullptr;
+  };
+
+  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
+                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
+                               xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
+    auto tdescTy = getTensorDescType(op);
+    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
+    return isLegal(layout);
+  });
+
+  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
+    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    return isLegal(layout);
+  });
+
+  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+
+  xegpu::populateXeGPUWgToSgPatterns(patterns);
+  if (failed(
+          applyPartialConversion(getOperation(), target, std::move(patterns))))
+    return signalPassFailure();
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
new file mode 100644
index 0000000000000..de2c548ec7ebb
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt --xegpu-wg-to-sg -split-input-file %s | FileCheck %s
+
+gpu.module @test_round_robin_assignment {
+  // CHECK: test_create_nd_tdesc
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {
+      // CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_...
[truncated]

@@ -9,6 +9,8 @@
#ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H

#include "mlir/Transforms/DialectConversion.h"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this really needed here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, not needed, cleaned up. Thanks

@chencha3 chencha3 self-assigned this May 12, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants