diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index e1cf843de1a65..2db5b4ab0f09c 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -122,6 +122,7 @@ void ggml_cuda_mul_mat_q( const int64_t s13 = src1->nb[3] / ts_src1; quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream); + CUDA_CHECK(cudaGetLastError()); } const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); @@ -205,6 +206,7 @@ void ggml_cuda_mul_mat_q( const int64_t s13 = src1->nb[2] / ts_src1; quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); + CUDA_CHECK(cudaGetLastError()); } const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index cb93181455d47..a0b03a740d74c 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -56,13 +56,13 @@ static __global__ void quantize_mmq_q8_1( constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32; constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32; - const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4; + const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4; if (i0 >= ne0) { return; } - const int64_t i1 = blockIdx.y; + const int64_t i1 = blockIdx.x; const int64_t i2 = blockIdx.z % ne2; const int64_t i3 = blockIdx.z / ne2; @@ -75,8 +75,8 @@ static __global__ void quantize_mmq_q8_1( block_q8_1_mmq * y = (block_q8_1_mmq *) vy; - const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel - const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y; // block index in channel + const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel + const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel const int64_t iqs = i0 % (4*QK8_1); // quant index in block // Load 4 floats per thread and calculate max. abs. value between them: @@ -166,8 +166,9 @@ void quantize_mmq_q8_1_cuda( GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(ne0 % (4*QK8_1) == 0); - const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); - const dim3 num_blocks(block_num_x, ne1, ne2*ne3); + // ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid: + const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); + const dim3 num_blocks(ne1, block_num_y, ne2*ne3); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1); switch (mmq_get_q8_1_ds_layout(type_src0)) { case MMQ_Q8_1_DS_LAYOUT_D4: