Skip to content

CUDA: fix crash on large batch size for MoE models #13384

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions ggml/src/ggml-cuda/getrows.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ static __global__ void k_get_rows(
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {

const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2;
const int i10 = blockIdx.x;
const int i11 = blockIdx.z / ne12;
const int i12 = blockIdx.z % ne12;

if (i00 >= ne00) {
return;
Expand Down Expand Up @@ -46,10 +47,11 @@ static __global__ void k_get_rows_float(
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {

const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
const int i00 = blockIdx.y * blockDim.x + threadIdx.x;
const int i10 = blockIdx.x;
const int i11 = blockIdx.z / ne12;
const int i12 = blockIdx.z % ne12;

if (i00 >= ne00) {
return;
Expand Down Expand Up @@ -94,8 +96,8 @@ static void get_rows_cuda_q(
const size_t nb1, const size_t nb2, const size_t nb3,
cudaStream_t stream) {
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
const dim3 block_nums(ne10, block_num_y, ne11*ne12);

// strides in elements
// const size_t s0 = nb0 / sizeof(dst_t);
Expand Down Expand Up @@ -127,8 +129,8 @@ static void get_rows_cuda_float(
const size_t nb1, const size_t nb2, const size_t nb3,
cudaStream_t stream) {
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
const dim3 block_nums(ne10, block_num_y, ne11*ne12);

// strides in elements
// const size_t s0 = nb0 / sizeof(dst_t);
Expand Down
Loading