forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 61
Open
Labels
bug: Something isn't working
Description
Describe the bug
I am trying to build a BF16 GEMM computation with CuTe (without using the Collective API), based on the example cutlass-sycl/examples/sycl/00_pvc_gemm/00_pvc_gemm.cpp.
Specifically, here is how I use the CuTe API:
- I am using the same MMA atom as the example, `XE_8x16x16_F32BF16BF16F32_TT`, to build the TiledMMA, as in https://github.com/codeplaysoftware/cutlass-sycl/blob/531dbe66adc53b32b7126cc90d8fa0c3e6ce24d4/examples/sycl/00_pvc_gemm/00_pvc_gemm.cpp#L364-L366.
- Here is how I use CuTe to build this GEMM computation:
// my_kernel: hand-written CuTe GEMM kernel that does not use the Collective API.
// Each work-group computes one (kTileM x kTileN) tile of D = A * B. It loops over
// the K dimension in steps of kTileK and accumulates into the tDrD fragment.
//
// Template parameters:
//   TA, TB, TC, TD          element types of A, B, C and D.
//   TiledMma                tiled MMA, used both to partition the tiles and in cute::gemm.
//   TileShape               not referenced in the body (kTileM/kTileN/kTileK carry the extents).
//   kTileM, kTileN, kTileK  work-group tile extents.
// Parameters:
//   ptr_A    A viewed as (M, K) with stride (K, 1).
//   ptr_B    B viewed as (N, K) with stride (1, N), i.e. element (n, k) is at k * N + n.
//   ptr_C    never read: there is no alpha/beta epilogue, so D is just A * B.
//   ptr_D    D viewed as (M, N) with stride (N, 1).
//   M, N, K  problem sizes.
//
// NOTE(review): there is no bounds predication. M, N and K are assumed to be
// multiples of kTileM, kTileN and kTileK; otherwise partial tiles would be accessed
// out of bounds. Confirm that the caller guarantees this.
// NOTE(review): the register fragments are filled by element-wise cute::copy from
// global memory through partition_A / partition_B. On Xe, the
// XE_8x16x16_F32BF16BF16F32_TT atom may expect a different register arrangement
// (e.g. VNNI-packed B) than this produces. The 00_pvc_gemm collective presumably
// loads its operands with XE 2D block-copy atoms instead. This is a likely place
// for the PVC result to diverge from SM80; TODO confirm against the collective
// mainloop.
template <
typename TA, typename TB,
typename TC, typename TD,
typename TiledMma, typename TileShape, int kTileM, int kTileN, int kTileK>
void my_kernel(
const TA* ptr_A,
const TB* ptr_B,
const TC* ptr_C,
TD* ptr_D,
int M,
int N,
int K) {
// A: size: (M, K) stride: (K, 1)
// B: size: (N, K) stride: (1, N)
// D: size: (M, N) stride: (N, 1)
TiledMma tiled_mma;
cute::Tensor A = cute::make_tensor(cute::make_gmem_ptr(ptr_A), cute::make_shape(M, K), cute::make_stride(K, cute::Int<1>{}));
cute::Tensor B = cute::make_tensor(cute::make_gmem_ptr(ptr_B), cute::make_shape(N, K), cute::make_stride(cute::Int<1>{}, N)); // Column Major
cute::Tensor D = cute::make_tensor(cute::make_gmem_ptr(ptr_D), cute::make_shape(M, N), cute::make_stride(N, cute::Int<1>{}));
int ix = BlockIdxX(); // N DIM by define of grid(grid_n, grid_m)
int iy = BlockIdxY(); // M Dim
// gA(kTileM, kTileK, num_tile_k)
// gB(kTileN, kTileK, num_tile_k)
// gD(kTileM, kTileN)
cute::Tensor gA = cute::local_tile(A, cute::make_tile(cute::Int<kTileM>{}, cute::Int<kTileK>{}), cute::make_coord(iy, cute::_));
cute::Tensor gB = cute::local_tile(B, cute::make_tile(cute::Int<kTileN>{}, cute::Int<kTileK>{}), cute::make_coord(ix, cute::_));
cute::Tensor gD = cute::local_tile(D, cute::make_tile(cute::Int<kTileM>{}, cute::Int<kTileN>{}), cute::make_coord(iy, ix));
// This work-item's slice of the TiledMma, selected by its x index in the work-group.
auto thr_mma = tiled_mma.get_slice(int(ThreadIdxX()));
auto tAgA = thr_mma.partition_A(gA); // (MMA, MMA_M, MMA_K, num_tile_k)
auto tBgB = thr_mma.partition_B(gB); // (MMA, MMA_N, MMA_K, num_tile_k)
auto tDgD = thr_mma.partition_C(gD); // (MMA, MMA_M, MMA_N)
// Fragments shaped like one K tile of A and B, and like the D tile (accumulator).
auto tArA = thr_mma.partition_fragment_A(gA(cute::_, cute::_, 0)); // (MMA, MMA_M, MMA_K)
auto tBrB = thr_mma.partition_fragment_B(gB(cute::_, cute::_, 0)); // (MMA, MMA_N, MMA_K)
auto tDrD = thr_mma.partition_fragment_C(gD(cute::_, cute::_)); // (MMA, MMA_M, MMA_N)
// set to zero
cute::clear(tDrD);
// Number of K tiles: the third (runtime) mode of gA.
int num_tile_k = cute::size<2>(gA);
// num_tile_k is a runtime value, so this loop has no compile-time trip count.
#pragma unroll
for(int itile = 0; itile < num_tile_k; ++itile) {
// Load this K tile of A and B into the fragments, then accumulate tDrD += tArA * tBrB.
cute::copy(tAgA(cute::_, cute::_, cute::_, itile), tArA);
cute::copy(tBgB(cute::_, cute::_, cute::_, itile), tBrB);
cute::gemm(tiled_mma, tDrD, tArA, tBrB, tDrD);
}
// Store the accumulator to D (ptr_C is never read).
cute::copy(tDrD, tDgD);
}
// raw_run: launches my_kernel directly instead of going through the Gemm adapter.
// The TiledMma and work-group tile shape come from Gemm::GemmKernel::CollectiveMainloop,
// so the hand-written kernel uses the same MMA configuration as the collective-based example.
//
// Gemm     device-level GEMM type. Only its element types and its kernel's
//          TiledMma / WorkgroupTileShape are used; no Gemm object is constructed.
// options  problem sizes m, n, k, l. Only l == 1 is supported.
// args     provides ptr_A / ptr_B (mainloop) and ptr_C / ptr_D (epilogue).
//
// Exits the process with status 1 for unsupported problems. my_kernel has no bounds
// predication, so m, n and k must be exact multiples of the work-group tile shape.
// Otherwise the truncating grid computation would silently leave trailing rows or
// columns of D unwritten, and the K loop would read past the end of A and B.
template <
class Gemm
>
void raw_run(const Options& options, typename Gemm::GemmKernel::Arguments const& args) {
  printf("\n ---- hit the raw run ---- \n");
  using GemmKernel = typename Gemm::GemmKernel;
  using TiledMma = typename GemmKernel::CollectiveMainloop::TiledMma;
  using TileShape = typename GemmKernel::CollectiveMainloop::WorkgroupTileShape;
  // Plain ints: these feed my_kernel's `int` template parameters and the checks below.
  static constexpr int kTileM = get<0>(TileShape{});
  static constexpr int kTileN = get<1>(TileShape{});
  static constexpr int kTileK = get<2>(TileShape{});

  int grid_l = options.l;
  if (grid_l != 1) {
    std::cout<<"---- to support the case when grid_l not equal to 1 ----"<<std::endl;
    std::exit(1);
  }
  // my_kernel does not predicate partial tiles; reject shapes it would mishandle.
  if (options.m % kTileM != 0 || options.n % kTileN != 0 || options.k % kTileK != 0) {
    std::cout << "---- m, n and k must be multiples of the workgroup tile shape ("
              << kTileM << ", " << kTileN << ", " << kTileK << ") ----" << std::endl;
    std::exit(1);
  }
  int grid_m = options.m / kTileM;
  int grid_n = options.n / kTileN;

  // grid.x walks N tiles and grid.y walks M tiles, matching BlockIdxX()/BlockIdxY()
  // in my_kernel; one work-item per TiledMma thread.
  dim3 grid(grid_n, grid_m, grid_l);
  dim3 const block = dim3(size(TiledMma{}));
  // TODO: shared-memory staging is not used yet, so no scratch space is requested.
  int smem_size = 0;
  sycl::queue q = syclcompat::get_default_queue();

  // The launch_policy overload needs explicit (here empty) kernel properties plus
  // the work-group scratch size.
  using EmptyProperties = decltype(sycl::ext::oneapi::experimental::properties());
  auto kernel_props = syclcompat::experimental::kernel_properties<EmptyProperties>{};
  syclcompat::experimental::launch_properties launch_props {
    sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size),
  };
  syclcompat::experimental::launch_policy policy{
    grid, block, launch_props, kernel_props
  };
  syclcompat::experimental::launch<
    my_kernel<
      typename Gemm::ElementA, typename Gemm::ElementB,
      typename Gemm::ElementC, typename Gemm::ElementD,
      TiledMma, TileShape, kTileM, kTileN, kTileK
    >
  >(
    policy, q, args.mainloop.ptr_A, args.mainloop.ptr_B, args.epilogue.ptr_C, args.epilogue.ptr_D,
    options.m, options.n, options.k
  );
}
When using similar code on SM80 with the following TiledMMA:
// SM80 reference configuration: a 2x2x1 arrangement of mma_atom_fp32 atoms
// with a 32x16x16 (M x N x K) permutation tile.
using MMA_fp32 = decltype(cute::make_tiled_mma(
    mma_atom_fp32{},
    cute::make_layout(cute::Shape<cute::_2, cute::_2, cute::_1>{}),  // atom layout over M, N, K
    cute::Tile<cute::_32, cute::_16, cute::_16>{}));                 // permutation tile
The GEMM passes the correctness check. For more details on how we changed the 00_pvc_gemm.cpp example, please refer to this commit: leslie-fang-intel@2716a96
Steps/Code to reproduce bug
Build and run the example test_examples_00_pvc_gemm
using the branch https://github.com/leslie-fang-intel/cutlass-sycl/tree/leslie/poc_bf16_cute
Metadata
Assignees
Labels
bug: Something isn't working