From 819f6c59f5a2a43d465403be3059db0c6a261578 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 8 May 2025 16:12:27 +0200 Subject: [PATCH] CUDA: fix crash on large batch size for MoE models --- ggml/src/ggml-cuda/getrows.cu | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index ea8bf69160996..963e4d03dd77b 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -10,10 +10,11 @@ static __global__ void k_get_rows( /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2; - const int i10 = blockDim.y*blockIdx.y + threadIdx.y; - const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; - const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2; + const int i10 = blockIdx.x; + const int i11 = blockIdx.z / ne12; + const int i12 = blockIdx.z % ne12; if (i00 >= ne00) { return; @@ -46,10 +47,11 @@ static __global__ void k_get_rows_float( /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - const int i00 = blockIdx.x*blockDim.x + threadIdx.x; - const int i10 = blockDim.y*blockIdx.y + threadIdx.y; - const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; - const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i00 = blockIdx.y * blockDim.x + threadIdx.x; + const int i10 = blockIdx.x; + const int i11 = blockIdx.z / ne12; + const int i12 = blockIdx.z % ne12; if (i00 >= ne00) { return; @@ -94,8 +96,8 @@ static void get_rows_cuda_q( const size_t nb1, const size_t nb2, const size_t nb3, cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); - const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); - const dim3 block_nums(block_num_x, ne10, ne11*ne12); + const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); + const dim3 block_nums(ne10, block_num_y, ne11*ne12); // strides in elements // const size_t s0 = nb0 / sizeof(dst_t); @@ -127,8 +129,8 @@ static void get_rows_cuda_float( const size_t nb1, const size_t nb2, const size_t nb3, cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); - const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; - const dim3 block_nums(block_num_x, ne10, ne11*ne12); + const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; + const dim3 block_nums(ne10, block_num_y, ne11*ne12); // strides in elements // const size_t s0 = nb0 / sizeof(dst_t);