Skip to content

[BUG] Incorrect result when writing example 00_pvc_gemm.cpp from CUTE #389

@leslie-fang-intel

Description

@leslie-fang-intel

Describe the bug
I am trying to build a BF16 GEMM computation with CuTe (without using the Collective API), based on the example cutlass-sycl/examples/sycl/00_pvc_gemm/00_pvc_gemm.cpp.

Specifically, here is how I use the CuTe API:

// Reproducer kernel: accumulates A * B into D one (kTileM x kTileN) output tile per
// block index, using CuTe tensors and TiledMma partitioning directly (no Collective API).
//
// Template parameters:
//   TA, TB, TC, TD         element types behind ptr_A / ptr_B / ptr_C / ptr_D
//   TiledMma               tiled MMA type used to partition the tiles and run cute::gemm
//   TileShape              not referenced in the body
//   kTileM, kTileN, kTileK tile extents along M, N and K
//
// Parameters:
//   ptr_A   A, viewed as shape (M, K) with stride (K, 1)
//   ptr_B   B, viewed as shape (N, K) with stride (1, N)
//   ptr_C   never read in the body (the accumulator starts from zero)
//   ptr_D   output, viewed as shape (M, N) with stride (N, 1)
//   M, N, K problem extents
//
// NOTE(review): nothing in the body predicates or bounds-checks the tile accesses, so
// this presumably requires M, N and K to be multiples of the tile extents - confirm.
template <
  typename TA, typename TB,
  typename TC, typename TD,
  typename TiledMma, typename TileShape, int kTileM, int kTileN, int kTileK>
void my_kernel(
  const TA* ptr_A,
  const TB* ptr_B,
  const TC* ptr_C,
  TD* ptr_D,
  int M,
  int N,
  int K) {
  // A: size: (M, K) stride: (K, 1)
  // B: size: (N, K) stride: (1, N)
  // D: size: (M, N) stride: (N, 1)
  
  TiledMma tiled_mma;

  // Wrap the raw pointers in tensors carrying the shapes/strides listed above.
  cute::Tensor A = cute::make_tensor(cute::make_gmem_ptr(ptr_A), cute::make_shape(M, K), cute::make_stride(K, cute::Int<1>{}));
  cute::Tensor B = cute::make_tensor(cute::make_gmem_ptr(ptr_B), cute::make_shape(N, K), cute::make_stride(cute::Int<1>{}, N));  // stride 1 along N (labelled "Column Major" by the author)
  cute::Tensor D = cute::make_tensor(cute::make_gmem_ptr(ptr_D), cute::make_shape(M, N), cute::make_stride(N, cute::Int<1>{}));

  int ix = BlockIdxX();  // tile index along N: the launcher builds grid(grid_n, grid_m, grid_l)
  int iy = BlockIdxY();  // tile index along M

  // gA(kTileM, kTileK, num_tile_k)
  // gB(kTileN, kTileK, num_tile_k)
  // gD(kTileM, kTileN)
  // The `_` in the A/B coordinates keeps every tile along K as a trailing mode.
  cute::Tensor gA = cute::local_tile(A, cute::make_tile(cute::Int<kTileM>{}, cute::Int<kTileK>{}), cute::make_coord(iy, cute::_));
  cute::Tensor gB = cute::local_tile(B, cute::make_tile(cute::Int<kTileN>{}, cute::Int<kTileK>{}), cute::make_coord(ix, cute::_));
  cute::Tensor gD = cute::local_tile(D, cute::make_tile(cute::Int<kTileM>{}, cute::Int<kTileN>{}), cute::make_coord(iy, ix));

  // Per-thread slice of the tiled MMA, selected by the X thread index.
  auto thr_mma = tiled_mma.get_slice(int(ThreadIdxX()));

  // This thread's partitions of the block-level tiles.
  auto tAgA = thr_mma.partition_A(gA);  // (MMA, MMA_M, MMA_K, num_tile_k)
  auto tBgB = thr_mma.partition_B(gB);  // (MMA, MMA_N, MMA_K, num_tile_k)
  auto tDgD = thr_mma.partition_C(gD);  // (MMA, MMA_M, MMA_N)

  // Fragments shaped from a single K tile of A/B and from the D tile; these are the
  // operands handed to cute::gemm below.
  auto tArA = thr_mma.partition_fragment_A(gA(cute::_, cute::_, 0));  // (MMA, MMA_M, MMA_K)
  auto tBrB = thr_mma.partition_fragment_B(gB(cute::_, cute::_, 0));  // (MMA, MMA_N, MMA_K)
  auto tDrD = thr_mma.partition_fragment_C(gD(cute::_, cute::_));     // (MMA, MMA_M, MMA_N)

  // set to zero
  cute::clear(tDrD);

  // Walk the K tiles: load this thread's A/B partitions for tile `itile`, then
  // accumulate into tDrD (it is passed as both the output and the input accumulator).
  int num_tile_k = cute::size<2>(gA);
  // NOTE(review): num_tile_k is a runtime value, so this pragma is presumably only a hint.
  #pragma unroll
  for(int itile = 0; itile < num_tile_k; ++itile) {
    cute::copy(tAgA(cute::_, cute::_, cute::_, itile), tArA);
    cute::copy(tBgB(cute::_, cute::_, cute::_, itile), tBrB);

    cute::gemm(tiled_mma, tDrD, tArA, tBrB, tDrD);
  }

  // Write the accumulated fragment back to this thread's partition of the D tile.
  cute::copy(tDrD, tDgD);

}

/// Launches `my_kernel` directly through syclcompat instead of going through the
/// Gemm object's own run path.
///
/// The TiledMma type and the tile extents are taken from
/// `Gemm::GemmKernel::CollectiveMainloop`; the launch grid holds one block per
/// (kTileM x kTileN) output tile, laid out as (grid_n, grid_m, grid_l).
///
/// The process exits with status 1 when the configuration is unsupported:
///   - `options.l != 1` (batched problems are not handled), or
///   - `options.m`, `options.n` or `options.k` is not a multiple of the matching tile
///     extent. `my_kernel` performs no bounds predication, so a truncated grid would
///     silently leave the trailing rows/columns of D unwritten.
///
/// @param options  problem description; `m`, `n`, `k` and `l` are read.
/// @param args     kernel arguments; the A/B pointers come from `args.mainloop` and
///                 the C/D pointers from `args.epilogue`.
template <
  class Gemm
>
void raw_run(const Options& options, typename Gemm::GemmKernel::Arguments const& args) {
  printf("\n ---- hit the raw run ---- \n");
  // Only referenced by the commented-out params() experiment further down.
  Gemm gemm_op;
  using GemmKernel = typename Gemm::GemmKernel;
  using TiledMma = typename GemmKernel::CollectiveMainloop::TiledMma;
  using TileShape = typename GemmKernel::CollectiveMainloop::WorkgroupTileShape;

  static constexpr auto kTileM = get<0>(TileShape{});
  static constexpr auto kTileN = get<1>(TileShape{});
  static constexpr auto kTileK = get<2>(TileShape{});

  // Plain-int copies of the tile extents for the runtime checks and the message below.
  const int tile_m = static_cast<int>(kTileM);
  const int tile_n = static_cast<int>(kTileN);
  const int tile_k = static_cast<int>(kTileK);

  // Integer division below would silently drop a partial tile, and the kernel has no
  // predication to handle one, so reject such problem sizes up front.
  if (options.m % tile_m != 0 || options.n % tile_n != 0 || options.k % tile_k != 0) {
    std::cout<<"---- problem size (m, n, k) must be a multiple of the tile shape ("
             <<tile_m<<", "<<tile_n<<", "<<tile_k<<") ----"<<std::endl;
    std::exit(1);
  }

  int grid_m = options.m / kTileM;
  int grid_n = options.n / kTileN;
  int grid_l = options.l;
  if (grid_l != 1) {
    std::cout<<"---- to support the case when grid_l not equal to 1 ----"<<std::endl;
    std::exit(1);
  }
  // X indexes N tiles and Y indexes M tiles; my_kernel reads them in that order.
  dim3 grid(grid_n, grid_m, grid_l);
  dim3 const block = dim3(size(TiledMma{}));

  // <TODO> Using smem
  int smem_size = 0;
  sycl::queue q = syclcompat::get_default_queue();
  
  // submit kernel
  // Option 1:
  // syclcompat::launch<my_kernel>(grid, block, q);

  // Option 2:
  using EmptyProperties = decltype(sycl::ext::oneapi::experimental::properties());
  auto kernel_props = syclcompat::experimental::kernel_properties<EmptyProperties>{};
  syclcompat::experimental::launch_properties launch_props {
    sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size),
  };
  syclcompat::experimental::launch_policy policy{
    grid, block, launch_props, kernel_props
  };

  // auto params = gemm_op.params();

  syclcompat::experimental::launch<
    my_kernel<
      typename Gemm::ElementA, typename Gemm::ElementB,
      typename Gemm::ElementC, typename Gemm::ElementD,
      TiledMma, TileShape, kTileM, kTileN, kTileK
    >
  >(
    policy, q, args.mainloop.ptr_A, args.mainloop.ptr_B, args.epilogue.ptr_C, args.epilogue.ptr_D,
    options.m, options.n, options.k
  );

}

When using similar code on SM80 with the following TiledMMA:

  // TiledMMA type used for the SM80 comparison run. It is built from mma_atom_fp32
  // (defined outside this snippet), a 2x2x1 layout and a 32x16x16 tile; the trailing
  // comments on each argument are the author's own labels.
  using MMA_fp32 = decltype(make_tiled_mma(mma_atom_fp32{}, 
                      make_layout(cute::Shape<cute::_2, cute::_2, cute::_1>{}), // thr layout
                      cute::Tile<cute::_32, cute::_16, cute::_16>{})); // permutation

With this configuration, the GEMM passes the correctness check. For more details on how we changed the 00_pvc_gemm.cpp example, please refer to this commit: leslie-fang-intel@2716a96

Steps/Code to reproduce bug
Build and run the test_examples_00_pvc_gemm example from the branch https://github.com/leslie-fang-intel/cutlass-sycl/tree/leslie/poc_bf16_cute

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions