From d19838e9b9acc13bd1dc753d8d499123ac61926a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 30 Apr 2025 17:05:34 +0200 Subject: [PATCH 1/4] CUDA: FA support for Deepseek (Ampere or newer) --- ggml/src/ggml-cuda/fattn-common.cuh | 26 +- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 891 +++++++++++------- ggml/src/ggml-cuda/fattn-tile-f16.cu | 4 +- ggml/src/ggml-cuda/fattn-tile-f32.cu | 4 +- ggml/src/ggml-cuda/fattn-vec-f16.cuh | 2 +- ggml/src/ggml-cuda/fattn-vec-f32.cuh | 2 +- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 2 +- ggml/src/ggml-cuda/fattn.cu | 116 ++- ggml/src/ggml-cuda/ggml-cuda.cu | 12 +- ...ttn-mma-f16-instance-ncols1_1-ncols2_16.cu | 5 + ...attn-mma-f16-instance-ncols1_1-ncols2_8.cu | 12 +- ...ttn-mma-f16-instance-ncols1_16-ncols2_1.cu | 12 +- ...ttn-mma-f16-instance-ncols1_16-ncols2_2.cu | 12 +- ...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 12 +- ...ttn-mma-f16-instance-ncols1_2-ncols2_16.cu | 5 + ...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 12 +- ...attn-mma-f16-instance-ncols1_2-ncols2_8.cu | 12 +- ...ttn-mma-f16-instance-ncols1_32-ncols2_1.cu | 12 +- ...ttn-mma-f16-instance-ncols1_32-ncols2_2.cu | 12 +- ...ttn-mma-f16-instance-ncols1_4-ncols2_16.cu | 5 + ...attn-mma-f16-instance-ncols1_4-ncols2_2.cu | 12 +- ...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 12 +- ...attn-mma-f16-instance-ncols1_4-ncols2_8.cu | 12 +- ...ttn-mma-f16-instance-ncols1_64-ncols2_1.cu | 12 +- ...attn-mma-f16-instance-ncols1_8-ncols2_1.cu | 12 +- ...attn-mma-f16-instance-ncols1_8-ncols2_2.cu | 12 +- ...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 12 +- ...attn-mma-f16-instance-ncols1_8-ncols2_8.cu | 12 +- .../template-instances/generate_cu_files.py | 21 +- src/llama-graph.cpp | 11 + 30 files changed, 786 insertions(+), 512 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 56121705bdf3f..588f26cc959cd 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -516,7 +516,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { nullptr; } -template // D == head size +template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { @@ -665,13 +665,13 @@ static void on_no_fattn_vec_case(const int D) { fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n"); GGML_ABORT("fatal error"); } else { - fprintf(stderr, "Unsupported KV type combination for head_size 256.\n"); + fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D); fprintf(stderr, "Only f16 is supported.\n"); GGML_ABORT("fatal error"); } } -template +template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE @@ -691,7 +691,7 @@ void launch_fattn( GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); GGML_ASSERT(!mask || mask->ne[1] >= 
GGML_PAD(Q->ne[1], 16) && - "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); @@ -752,10 +752,13 @@ void launch_fattn( const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; const dim3 block_dim(warp_size, nwarps, 1); + int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. + CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); + dim3 blocks_num; if (stream_k) { // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. - const int max_blocks = 2*nsm; + const int max_blocks = max_blocks_per_sm*nsm; const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks; const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves); @@ -767,14 +770,11 @@ void launch_fattn( blocks_num.y = 1; blocks_num.z = 1; - dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float)); + dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); } else { GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. - int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. - CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); - // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave: parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1); @@ -851,19 +851,19 @@ void launch_fattn( if (stream_k) { if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. - const dim3 block_dim_combine(D, 1, 1); + const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; - flash_attn_stream_k_fixup + flash_attn_stream_k_fixup <<>> ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); } } else if (parallel_blocks > 1) { - const dim3 block_dim_combine(D, 1, 1); + const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z); const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); - flash_attn_combine_results + flash_attn_combine_results <<>> (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); } diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 04804a15c9d51..068f3f75b6d26 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -13,104 +13,214 @@ typedef tile<16, 16, float> tile_C_KQ_16; typedef tile<16, 4, half2> tile_C_VKQ; typedef tile<16, 8, half2> tile_C_VKQ_16; -template +// Config options for specific head sizes. +// Should not affect results, only speed/register pressure/shared memory use. +// +// nbatch_fa: number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators. +// nwarps_max: maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory). +// Q_in_reg: whether the Q values should be kept permanently in registers. 
+// nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading. +// nbatch_K2: number of K half2 values in direction of DKQ to load in parallel. +// nbatch_V2: number of V half2 values in direction of DV to load in parallel. +// nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel. + +template +struct fattn_mma_f16_config; + +template <> +struct fattn_mma_f16_config< 64, 64> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + static constexpr int nbatch_K2 = 32; + static constexpr int nbatch_V2 = 32; + static constexpr int nbatch_combine = 32; +}; + +template <> +struct fattn_mma_f16_config< 80, 80> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + static constexpr int nbatch_K2 = 40; + static constexpr int nbatch_V2 = 40; + static constexpr int nbatch_combine = 40; +}; + +template <> +struct fattn_mma_f16_config< 96, 96> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + static constexpr int nbatch_K2 = 48; + static constexpr int nbatch_V2 = 48; + static constexpr int nbatch_combine = 48; +}; + +template <> +struct fattn_mma_f16_config<112, 112> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + static constexpr int nbatch_K2 = 56; + static constexpr int nbatch_V2 = 56; + static constexpr int nbatch_combine = 56; +}; + +template <> +struct fattn_mma_f16_config<128, 128> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + static constexpr int nbatch_K2 = 64; + static constexpr int nbatch_V2 = 64; + static constexpr int nbatch_combine = 64; +}; + +template <> +struct fattn_mma_f16_config<256, 256> { + static constexpr int nbatch_fa = 32; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + static constexpr int nbatch_K2 = 128; + static constexpr int nbatch_V2 = 128; + static constexpr int nbatch_combine = 128; +}; + +template <> +struct fattn_mma_f16_config<576, 512> { + static constexpr int nbatch_fa = 32; + static constexpr int nwarps_max = 8; + static constexpr bool Q_in_reg = false; + static constexpr int nstages_target = 1; + static constexpr int nbatch_K2 = 160; + static constexpr int nbatch_V2 = 128; + static constexpr int nbatch_combine = 128; +}; + +template static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( - const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) { - constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) { - // If cp.async is available, load up to the highest power of 2 in D asynchronously: -#ifdef CP_ASYNC_AVAILABLE - static_assert(D >= 64 && D < 512, "bad D"); - constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128); + // K/V data is loaded with decreasing granularity for D for better memory bandwidth. 
+ // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. + + if (use_cp_async) { + const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV); - const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV); + constexpr int preload = 64; + constexpr int h2_per_chunk = 16/sizeof(half2); + + const int chunks_per_row = D2 / h2_per_chunk; - constexpr int preload = 64; - constexpr int h2_per_chunk = 16/sizeof(half2); - constexpr int chunks_per_row = k0_sync_start / h2_per_chunk; - constexpr int stride_i = WARP_SIZE / chunks_per_row; #pragma unroll - for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row); - const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk; + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4, WARP_SIZE/8, WARP_SIZE/16}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); + const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; - cp_async_cg_16(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k); - } -#else - constexpr int k0_sync_start = 0; -#endif // CP_ASYNC_AVAILABLE - static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start"); + if (k0_start == k0_stop) { + continue; + } - // If D is not a power of 2, the rest is loaded synchronously. - // K/V data is loaded with decreasing granularity for D for better memory bandwidth. - static_assert(KQ_per_iter % (4*nwarps) == 0, "out of bounds"); #pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); - const int stride_i = WARP_SIZE / stride_k; + for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); - if (k0_start == k0_stop || k0_stop <= k0_sync_start) { - continue; + if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { + break; + } + +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + cp_async_cg_16(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk); + } + } } + } else { + static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds"); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); + const int k0_stop = D2 - D2 % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; + + if (k0_start == k0_stop || k0_stop <= 0) { + continue; + } #pragma unroll - for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { + break; + } #pragma unroll - for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? 
threadIdx.x : threadIdx.x % stride_k); + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - tile_KV[i*D2_padded + k] = KV[i*stride_KV + k]; + tile_KV[i*stride_tile + k] = KV[i*stride_KV + k]; + } } } } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) { - static_assert(KQ_per_iter == 2*WARP_SIZE || KQ_per_iter == WARP_SIZE, "bad KQ_per_iter"); -#ifdef CP_ASYNC_AVAILABLE - constexpr int preload = KQ_per_iter * sizeof(half); - constexpr int cols_per_warp = 8*WARP_SIZE/KQ_per_iter; - constexpr int stride_j = nwarps * cols_per_warp; + static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter"); - const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask); + if (use_cp_async) { + constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64; + constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; + constexpr int stride_j = nwarps * cols_per_warp; + + const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask); #pragma unroll - for (int j0 = 0; j0 < ncols1; j0 += stride_j) { - const int j = j0 + threadIdx.y*cols_per_warp + - (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/8)); + for (int j0 = 0; j0 < ncols1; j0 += stride_j) { + const int j = j0 + threadIdx.y*cols_per_warp + + (nbatch_fa == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp)); - if (j0 + stride_j > ncols1 && j >= ncols1) { - break; - } + if (j0 + stride_j > ncols1 && j >= ncols1) { + break; + } - const int i = 4 * (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x % (WARP_SIZE/4) : threadIdx.x % (WARP_SIZE/8)); + const int i = 4 * (threadIdx.x % (nbatch_fa/8)); - cp_async_cg_16(tile_mask_32 + j*(KQ_per_iter*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + cp_async_cg_16(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + } + return; } -#else - constexpr int cols_per_warp = 2*WARP_SIZE/KQ_per_iter; + + constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; constexpr int stride_j = nwarps * cols_per_warp; #pragma unroll for (int j0 = 0; j0 < ncols1; j0 += stride_j) { - const int j = j0 + threadIdx.y*cols_per_warp + (KQ_per_iter == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/2)); + const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp)); if (j0 + stride_j > ncols1 && j >= ncols1) { break; } - const int i = KQ_per_iter == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/2); + const int i = nbatch_fa == 2*WARP_SIZE ? 
threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp); - tile_mask[j*(KQ_per_iter/2 + 4) + i] = mask_h2[j*stride_mask + i]; + tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i]; } -#endif // CP_ASYNC_AVAILABLE } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -123,9 +233,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float logit_softcap, const int ne01, const int ne02, - const int stride_KV, + const int stride_K, + const int stride_V, const int stride_mask, const int jt, + half2 * const __restrict__ tile_Q, half2 * const __restrict__ tile_K, half2 * const __restrict__ tile_V, half2 * const __restrict__ tile_mask, @@ -135,59 +247,107 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( float * const __restrict__ KQ_rowsum, const int kb0) { #ifdef NEW_MMA_AVAILABLE + typedef fattn_mma_f16_config c; + +#ifdef CP_ASYNC_AVAILABLE + constexpr int nstages = c::nstages_target; +#else + constexpr int nstages = 0; +#endif // CP_ASYNC_AVAILABLE + constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. - constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. - const int k_VKQ_0 = kb0 * KQ_per_iter; - tile_C_KQ KQ_C[KQ_per_iter/(np*tile_C_KQ::I) * ntiles]; + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = c::nbatch_K2 + 4; + constexpr int stride_tile_V = c::nbatch_V2 + 4; + + const int k_VKQ_0 = kb0 * c::nbatch_fa; + tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles]; // Use wide variants of tiles if ntiles >= 2. tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; -#ifdef CP_ASYNC_AVAILABLE - cp_async_wait_all(); - __syncthreads(); - flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); -#else - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); + if constexpr (nstages > 1) { + static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); + constexpr bool use_cp_async = true; + cp_async_wait_all(); + __syncthreads(); + flash_attn_ext_f16_load_tile + (V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V); + } else { + constexpr bool use_cp_async = nstages == 1; + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); + } } - flash_attn_ext_f16_load_tile(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV); - __syncthreads(); -#endif // CP_ASYNC_AVAILABLE - // Calculate tile of KQ: #pragma unroll - for (int i_KQ_00 = 0; i_KQ_00 < KQ_per_iter; i_KQ_00 += np*tile_A::I) { - const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; + for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) { + const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? 
k0_start + c::nbatch_K2 : DKQ/2; + const int k0_diff = k0_stop - k0_start; + + if (nstages <= 1) { + constexpr bool use_cp_async = nstages == 1; + flash_attn_ext_f16_load_tile + (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K); + if (use_cp_async) { + cp_async_wait_all(); + } + __syncthreads(); + } + + // Calculate tile of KQ: + if constexpr (c::Q_in_reg) { #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) { - tile_A K_A; - load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded); - if (ntiles == 1) { - mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); - } else { + for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { + tile_A K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); + if (ntiles == 1) { + mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of KQ_C is column-major => swap A and B. + mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); + } + } + } + } + } else { + static_assert(ntiles == 2, "ntiles != 2 not implemented"); +#pragma unroll + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { + load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); + +#pragma unroll + for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; + + tile_A K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); + // Wide version of KQ_C is column-major => swap A and B. - mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); + mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A); } } } - } -#ifndef CP_ASYNC_AVAILABLE - __syncthreads(); // Only needed if tile_K == tile_V. -#endif // CP_ASYNC_AVAILABLE + if (nstages <= 1) { + __syncthreads(); // Only needed if tile_K == tile_V. + } + } if (use_logit_softcap) { - static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int i = 0; i < KQ_per_iter/(np*tile_C_KQ::I) * ntiles; ++i) { + for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) { #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); @@ -205,7 +365,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( if (ntiles == 1) { if (ncols2 > 1 || mask_h2) { #pragma unroll - for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ::I) { + for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) { const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { @@ -213,16 +373,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2; KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope * - __half2float(((const half *) tile_mask)[j*(KQ_per_iter + 8) + i]); + __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]); } } } // Calculate softmax for each KQ column using the current max. value. 
// The divisor is stored in KQ_rowsum and will be applied at the end. - static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]); @@ -238,10 +398,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); - + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]); @@ -252,7 +411,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } else { // ntiles > 1 if (ncols2 > 1 || mask_h2) { #pragma unroll - for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ_16::J) { + for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) { const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J; #pragma unroll for (int t = 0; t < ntiles/2; ++t) { @@ -261,7 +420,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2; const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2; - const float2 tmp = __half22float2(tile_mask[j*(KQ_per_iter/2 + 4) + i]); + const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]); const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t; KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x; KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y; @@ -272,9 +431,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // Calculate softmax for each KQ column using the current max. value. // The divisor is stored in KQ_rowsum and will be applied at the end. 
- static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { #pragma unroll for (int t = 0; t < ntiles/2; ++t) { #pragma unroll @@ -294,9 +453,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - static_assert(KQ_per_iter % (np*tile_C_KQ_16::J) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { #pragma unroll for (int t = 0; t < ntiles/2; ++t) { #pragma unroll @@ -325,7 +484,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( if (ntiles == 1) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); #pragma unroll - for (int i = 0; i < D/tile_C_VKQ::I; ++i) { + for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { #pragma unroll for (int l = 0; l < tile_C_VKQ::ne; ++l) { VKQ_C[i].x[l] *= KQ_max_scale_h2; @@ -336,7 +495,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int col = 0; col < cols_per_thread; ++col) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); #pragma unroll - for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) { + for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { #pragma unroll for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; @@ -347,16 +506,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } // Convert KQ C tiles into B tiles for VKQ calculation: - tile_B B[KQ_per_iter/(np*2*tile_B::J) * ntiles]; + tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles]; tile_B_16 * B_16 = (tile_B_16 *) B; - static_assert(KQ_per_iter % (np*2*tile_B::J) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size"); if (ntiles == 1) { #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*2*tile_B::J); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) { B[k] = get_transposed(get_half2(KQ_C[k])); } } else { - for (int k = 0; k < KQ_per_iter/(np*2*tile_B_16::J); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) { #pragma unroll for (int t = 0; t < ntiles/2; ++t) { B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]); @@ -364,52 +523,67 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } -#ifdef CP_ASYNC_AVAILABLE - // Preload K tile for next iteration: - cp_async_wait_all(); - __syncthreads(); - if (!last_iter) { - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask(mask_h2 + (k_VKQ_0 + KQ_per_iter)/2, tile_mask, stride_mask); + if (nstages > 1) { + // Preload K tile for next iteration: + constexpr bool use_cp_async = true; + cp_async_wait_all(); + __syncthreads(); + if (!last_iter) { + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask + (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile + (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K); } - flash_attn_ext_f16_load_tile(K_h2 + (k_VKQ_0 + KQ_per_iter)*stride_KV, tile_K, stride_KV); } -#else - flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); - __syncthreads(); -#endif // CP_ASYNC_AVAILABLE - // Calculate VKQ tile: #pragma unroll - for (int i_VKQ_0 = 0; 
i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) { - static_assert((KQ_per_iter/2) % (np*tile_A::J) == 0, "bad loop size"); + for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) { + const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV; + const int i0_diff = i0_stop - i0_start; + + if (nstages == 1) { + constexpr bool use_cp_async = nstages == 1; + flash_attn_ext_f16_load_tile + (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); + if (use_cp_async) { + cp_async_wait_all(); + } + __syncthreads(); + } + + // Calculate VKQ tile: +#pragma unroll + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) { + static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size"); #pragma unroll - for (int k00 = 0; k00 < KQ_per_iter/2; k00 += np*tile_A::J) { - const int k0 = k00 + (threadIdx.y % np)*tile_A::J; + for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) { + const int k0 = k00 + (threadIdx.y % np)*tile_A::J; - tile_A A; - load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded); - if (ntiles == 1) { - mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); - } else { + tile_A A; + load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + if (ntiles == 1) { + mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + } else { #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - // Wide version of VKQ_C is column-major => swap A and B. - mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of VKQ_C is column-major => swap A and B. + mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); + } } } } - } - -#ifndef CP_ASYNC_AVAILABLE - __syncthreads(); // Only needed if tile_K == tile_V. -#endif // CP_ASYNC_AVAILABLE + if (nstages <= 1) { + __syncthreads(); // Only needed if tile_K == tile_V. + } + } #else GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_KV); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B); @@ -419,7 +593,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #endif // NEW_MMA_AVAILABLE } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -434,7 +608,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int ne02, const int stride_Q1, const int stride_Q2, - const int stride_KV, + const int stride_K, + const int stride_V, const int stride_mask, const int jt, const int kb0_start, @@ -442,6 +617,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #ifdef NEW_MMA_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. 
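
The flash_attn_ext_f16_iter device function above is one chunk of the standard online-softmax flash-attention update: compute the raw KQ scores for nbatch_fa KV rows, fold them into the running row maximum, rescale the existing rowsum and VKQ accumulators by exp(old_max - new_max), and accumulate the exp-weighted V rows; the division by the rowsum happens only once at the very end of the tile. For reference, a minimal scalar sketch of that update for a single query row follows. The function name, the plain loops and the float buffers are illustrative only and do not correspond to identifiers or the tile/MMA layout used by the kernel.

#include <cmath>

// Scalar sketch of the per-chunk online-softmax update (illustrative, not kernel code).
static void online_softmax_chunk(
        const float * KQ_chunk,   // nbatch_fa raw K*Q dot products for one query row
        const float * V,          // [nbatch_fa][DV] value rows of this chunk
        int nbatch_fa, int DV,
        float & KQ_max, float & KQ_rowsum, float * VKQ_acc /* [DV], unnormalized */) {
    // 1. Update the running maximum with this chunk.
    float KQ_max_new = KQ_max;
    for (int i = 0; i < nbatch_fa; ++i) {
        KQ_max_new = fmaxf(KQ_max_new, KQ_chunk[i]);
    }
    // 2. Rescale the previous rowsum and accumulator to the new maximum.
    const float KQ_max_scale = expf(KQ_max - KQ_max_new);
    KQ_rowsum *= KQ_max_scale;
    for (int d = 0; d < DV; ++d) {
        VKQ_acc[d] *= KQ_max_scale;
    }
    KQ_max = KQ_max_new;
    // 3. Accumulate the exp-weighted V rows; normalization by KQ_rowsum is deferred to the end.
    for (int i = 0; i < nbatch_fa; ++i) {
        const float p = expf(KQ_chunk[i] - KQ_max);
        KQ_rowsum += p;
        for (int d = 0; d < DV; ++d) {
            VKQ_acc[d] += p * V[i*DV + d];
        }
    }
}
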
+ typedef fattn_mma_f16_config c; + +#ifdef CP_ASYNC_AVAILABLE + constexpr int nstages = c::nstages_target; +#else + constexpr int nstages = 0; +#endif // CP_ASYNC_AVAILABLE + constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; @@ -449,22 +632,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); - static_assert(D % nwarps == 0, "bad D"); - static_assert(KQ_per_iter % nwarps == 0, "bad KQ_per_iter"); + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = c::nbatch_K2 + 4; + constexpr int stride_tile_V = c::nbatch_V2 + 4; - constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; - // Temporary shared buffer for loading K/V data with KQ_per_iter*D logical elements: - extern __shared__ half2 tile_K[]; -#ifdef CP_ASYNC_AVAILABLE - half2 * tile_V = tile_K + KQ_per_iter*D2_padded; -#else - half2 * tile_V = tile_K; -#endif // CP_ASYNC_AVAILABLE - half2 * tile_mask = tile_V + KQ_per_iter*D2_padded; + extern __shared__ half2 tile_Q[]; + half2 * tile_K = c::Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q; + half2 * tile_V = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K; + half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max; - tile_B Q_B[D/(2*tile_B::J) * ntiles]; - tile_C_VKQ VKQ_C[D/tile_C_VKQ::I * ntiles]; + tile_B Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles]; + tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I * ntiles]; tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; @@ -476,13 +656,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( KQ_max[col] = -FLT_MAX/2.0f; } - // Temporarily load Q data into tile_K, will be loaded into registers afterwards. + // Load Q data into tile_Q, either temporarily or permanently. + // Q in registers is faster, but register pressure is the biggest bottleneck. // The loading is done with decreasing granularity for D for better memory bandwidth. const half2 scale_h2 = make_half2(scale, scale); #pragma unroll for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k); + const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k); const int stride_jc = WARP_SIZE / stride_k; if (k0_start == k0_stop) { @@ -506,14 +687,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k]; - tile_K[jc*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); + tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y); } } else { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? 
threadIdx.x : threadIdx.x % stride_k); - tile_K[jc*D2_padded + k] = make_half2(0.0f, 0.0f); + tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f); } } } @@ -521,18 +702,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); - { + if (c::Q_in_reg) { const int j0 = (threadIdx.y / np) * cols_per_warp; #pragma unroll - for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { + for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) { if (ntiles == 1) { - load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded); + load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); } else { #pragma unroll for (int t = 0; t < ntiles/2; ++t) { load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t], - tile_K + (j0 + t*tile_B_16::I)*D2_padded + k0, D2_padded); + tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q); } } } @@ -540,35 +721,37 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); - // Preload mask and K data for first iteration when using cp_async: -#ifdef CP_ASYNC_AVAILABLE - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask(mask_h2 + kb0_start*KQ_per_iter/2, tile_mask, stride_mask); + // Preload mask and K data for first iteration when using cp_async with multiple stages: + if constexpr (nstages > 1) { + static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); + constexpr bool use_cp_async = true; + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask + (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile + (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K); } - flash_attn_ext_f16_load_tile(K_h2 + kb0_start*KQ_per_iter*stride_KV, tile_K, stride_KV); -#endif // CP_ASYNC_AVAILABLE // Iterate over ne11 == previous tokens: for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { constexpr bool last_iter = false; - flash_attn_ext_f16_iter + flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); } { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. constexpr bool last_iter = true; - flash_attn_ext_f16_iter + flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); + ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); } - // With cp_async there is no __syncthreads at the end of the iter, + // With multi-stage loading there is no __syncthreads at the end of the iter, // there can be a race condition on shared memory access for combining/writing back results. -#ifdef CP_ASYNC_AVAILABLE - if (nwarps*cols_per_warp > KQ_per_iter) { + if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) { __syncthreads(); } -#endif // CP_ASYNC_AVAILABLE // Finally, sum up partial KQ rowsums. // The partial sums are spread across 8/4 threads each, does not need full reduce. @@ -584,38 +767,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } - // Write VKQ accumulators to shared memory in column-major format. 
- // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. - // Also for np > 1 the combination is done via these values in shared memory. - if (ntiles == 1) { - const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data -#pragma unroll - for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { - const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. + // Combine VKQ accumulator values if np > 1. + // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. + // So also write VKQ accumulators to shared memory in column-major format if np == 1. -#pragma unroll - for (int l = 0; l < tile_B::ne; ++l) { - const int k = k0 + tile_B::get_j(l); - - tile_K[jc_cwd*D2_padded + k] = B.x[l]; - } - } - } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; -#pragma unroll - for (int k0 = 0; k0 < D/2; k0 += tile_C_VKQ_16::J) { -#pragma unroll - for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { - const int j = j0 + tile_C_VKQ_16::get_i(l); - const int k = k0 + tile_C_VKQ_16::get_j(l); - - tile_K[j*D2_padded + k] = VKQ_C_16[k0/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; - } - } - } - } + constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4; + constexpr int tile_stride = nbatch_combine + 4; + static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine"); if constexpr (ntiles == 1) { const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset @@ -624,7 +782,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. - ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; + ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; } __syncthreads(); @@ -649,7 +807,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) { // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. - ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; + ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; } __syncthreads(); @@ -676,11 +834,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1; const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); - float2 * const meta_ptr = ((float2 *) tile_K) + jc_meta*(D2_padded/2) + D/4; + float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2; float2 meta[nmeta]; #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) { - meta[imeta] = meta_ptr[imeta * WARP_SIZE * D2_padded/2]; + meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2]; } float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps. @@ -720,7 +878,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( for (int imeta = 0; imeta < nmeta; ++imeta) { if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) { // Combined KQ max scale + rowsum. 
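
The per-warp meta values handled here (KQ maximum and rowsum stored in the row padding) are merged with the usual log-sum-exp rule: take the larger maximum, rescale each partial state by exp(m_i - m_new), then add the rescaled rowsums and accumulators. The same rule underlies flash_attn_stream_k_fixup and flash_attn_combine_results for partial results written to VRAM. A standalone sketch with illustrative names (not kernel identifiers):

#include <algorithm>
#include <cmath>
#include <vector>

// One partial flash-attention state for a single query row:
// running maximum m, rowsum s, and the unnormalized accumulator acc[DV].
struct PartialState {
    float m;
    float s;
    std::vector<float> acc;
};

// Merge two partial states computed with different running maxima (illustrative).
static PartialState merge(const PartialState & a, const PartialState & b) {
    PartialState out;
    out.m = std::max(a.m, b.m);
    const float scale_a = std::exp(a.m - out.m);
    const float scale_b = std::exp(b.m - out.m);
    out.s = a.s*scale_a + b.s*scale_b;
    out.acc.resize(a.acc.size()); // assumes a.acc.size() == b.acc.size()
    for (size_t d = 0; d < out.acc.size(); ++d) {
        out.acc[d] = a.acc[d]*scale_a + b.acc[d]*scale_b;
    }
    return out; // the final output row is out.acc[d] / out.s
}
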
- meta_ptr[imeta * WARP_SIZE * D2_padded/2] = make_float2(KQ_cms[imeta], KQ_crs); + meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs); } } @@ -736,88 +894,118 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } - if (np > 1) { - __syncthreads(); - } - - if (np == 1 || threadIdx.y % np == 0) { - // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums. - // The values after that are for the partial results of the individual blocks. - float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2)); +#pragma unroll + for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { + if (ntiles == 1) { + const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data +#pragma unroll + for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) { + const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. #pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); - const int stride_jc = WARP_SIZE / stride_k; + for (int l = 0; l < tile_B::ne; ++l) { + const int k = k0 + tile_B::get_j(l); - if (k0_start == k0_stop) { - continue; + tile_Q[jc_cwd*tile_stride + k] = B.x[l]; + } } - + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; +#pragma unroll + for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) { #pragma unroll - for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { - const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { + const int j = j0 + tile_C_VKQ_16::get_i(l); + const int k = k0 + tile_C_VKQ_16::get_j(l); - if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { - break; + tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; + } } + } + } - const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp; + __syncthreads(); - const int j_dst = jc_dst / ncols2; - const int c_dst = jc_dst % ncols2; + if (np == 1 || threadIdx.y % np == 0) { + // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums. + // The values after that are for the partial results of the individual blocks. + float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2)); - if (!is_fixup && jt*ncols1 + j_dst >= ne01) { +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k); + const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { continue; } - const float * meta_j = (const float *) tile_K + jc_tile_K*D2_padded + D/2; #pragma unroll - for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { + const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 
0 : threadIdx.x / stride_k); - float2 dstk_val = make_float2(0.0f, 0.0f); -#pragma unroll - for (int ip = 0; ip < np; ++ip) { - const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * D2_padded + 0]; - const float2 dstk_val_add = __half22float2(tile_K[(jc_tile_K + ip*cols_per_warp) * D2_padded + k]); - dstk_val.x += dstk_val_add.x*KQ_crs; - dstk_val.y += dstk_val_add.y*KQ_crs; + if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { + break; } - if (!needs_fixup && !is_fixup) { - const float KQ_rowsum_j = meta_j[1]; - dstk_val.x /= KQ_rowsum_j; - dstk_val.y /= KQ_rowsum_j; + const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp; + + const int j_dst = jc_dst / ncols2; + const int c_dst = jc_dst % ncols2; + + if (!is_fixup && jt*ncols1 + j_dst >= ne01) { + continue; } - if (is_fixup) { - dstk_fixup_data[jc_dst*(D/2) + k] = dstk_val; - } else { - dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(D/2) + k] = dstk_val; + const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine; +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + float2 dstk_val = make_float2(0.0f, 0.0f); +#pragma unroll + for (int ip = 0; ip < np; ++ip) { + const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * tile_stride + 0]; + const float2 dstk_val_add = __half22float2(tile_Q[(jc_tile_K + ip*cols_per_warp) * tile_stride + k]); + dstk_val.x += dstk_val_add.x*KQ_crs; + dstk_val.y += dstk_val_add.y*KQ_crs; + } + + if (!needs_fixup && !is_fixup) { + const float KQ_rowsum_j = meta_j[1]; + dstk_val.x /= KQ_rowsum_j; + dstk_val.y /= KQ_rowsum_j; + } + + if (is_fixup) { + dstk_fixup_data[jc_dst*(DV/2) + k00 + k] = dstk_val; + } else { + dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(DV/2) + k00 + k] = dstk_val; + } } } } } - } - - if (np > 1) { - __syncthreads(); + if (np > 1) { + __syncthreads(); + } } #else GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1); - GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_KV); GGML_UNUSED(stride_mask); + GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } -template -__launch_bounds__(nwarps*WARP_SIZE, 2) +template +__launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -857,24 +1045,27 @@ static __global__ void flash_attn_ext_f16( #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) // Skip unused kernel variants for faster compilation: - if (use_logit_softcap && !(D == 128 || D == 256)) { + if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { NO_DEVICE_CODE; return; } - static_assert(FATTN_KQ_STRIDE % KQ_per_iter == 0, "bad KQ_per_iter"); + typedef fattn_mma_f16_config c; + + static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config::nbatch_fa == 0, "bad nbatch_fa"); const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. 
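
With ncols2 > 1 the kernel packs ncols2 query heads that share one K/V head into a single tile, which is why the K and V base pointers further down are indexed with channel*ncols2 / gqa_ratio instead of the query head index. A small host-side sketch of that mapping; the head counts are assumed for illustration only:

#include <cassert>
#include <cstdio>

int main() {
    const int ne02 = 128, ne12 = 8, ncols2 = 16; // assumed Q/KV head counts and tile width
    const int gqa_ratio = ne02 / ne12;           // Q heads per K/V head
    // All ncols2 Q heads of a tile share one K/V head iff ncols2 divides gqa_ratio.
    assert(gqa_ratio % ncols2 == 0);

    for (int channel = 0; channel < ne02/ncols2; ++channel) {
        const int head_q0 = channel*ncols2;      // first Q head handled by this tile
        const int head_kv = head_q0 / gqa_ratio; // K/V head shared by the whole tile
        printf("tile %2d: Q heads [%3d, %3d) -> KV head %d\n",
               channel, head_q0, head_q0 + ncols2, head_kv);
    }
    return 0;
}
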
const int stride_Q1 = nb01 / sizeof(float2); const int stride_Q2 = nb02 / sizeof(float2); - const int stride_KV = nb11 / sizeof(half2); + const int stride_K = nb11 / sizeof(half2); + const int stride_V = nb21 / sizeof(half2); const int stride_mask = nb31 / sizeof(half2); const int iter_k = ne11 / FATTN_KQ_STRIDE; const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - constexpr int kb_niter = FATTN_KQ_STRIDE / KQ_per_iter; // Number of kernel iterations per assigned KQ slice. + constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice. // kbc == k block continuous, current index in continuous ijk space. int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; @@ -893,9 +1084,9 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape + const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; - float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; @@ -905,14 +1096,14 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } kbc += iter_k; @@ -931,9 +1122,9 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape + const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; - float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; @@ -942,9 +1133,9 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. 
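
The stream-k scheduling above flattens (head group, query tile, KV block) into a single index kbc that is split evenly across gridDim.x blocks; a block whose range does not start or stop on a query-tile boundary writes partial results that flash_attn_stream_k_fixup merges afterwards. The decomposition in the sketch below is an illustrative reconstruction with assumed sizes, not a copy of the kernel's exact index math:

#include <cstdio>

int main() {
    // Assumed example sizes:
    const int iter_k  = 8;  // ne11 / FATTN_KQ_STRIDE -> KV blocks per query tile
    const int iter_j  = 3;  // ceil(ne01 / ncols1)    -> query tiles
    const int ngroups = 4;  // ne02 / ncols2          -> head groups
    const int nblocks = 10; // gridDim.x

    const int ntotal = iter_k*iter_j*ngroups;
    for (int b = 0; b < nblocks; ++b) {
        const int kbc      = (b + 0)*ntotal / nblocks; // first flattened index of this block
        const int kbc_stop = (b + 1)*ntotal / nblocks; // one past the last index

        // Reconstructed decomposition of the first index (illustrative):
        const int channel = kbc / (iter_k*iter_j);   // head group
        const int jt      = (kbc / iter_k) % iter_j; // query tile
        const int kb0     = kbc % iter_k;            // first KV block; != 0 means a partial tile
        printf("block %2d: kbc=[%2d,%2d) channel=%d jt=%d kb0=%d\n",
               b, kbc, kbc_stop, channel, jt, kb0);
    }
    return 0;
}
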
constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); @@ -960,28 +1151,42 @@ static __global__ void flash_attn_ext_f16( #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) } -template +template void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + + typedef fattn_mma_f16_config c; + + constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2; + constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2; + constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine; + + const int nstages = cp_async_available(cc) ? c::nstages_target : 0; + constexpr int ncols = ncols1 * ncols2; - constexpr int KQ_per_iter = D <= 128 && ncols1 <= 64 ? 64 : 32; - constexpr int nwarps = (KQ_per_iter == 32 && ncols <= 16) ? 2 : 4; - constexpr int ntiles = ncols <= 8 ? 1 : (ncols <= 64 ? 2 : 4); + constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp. constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int nwarps_max_x = ncols / cols_per_warp; + constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; + constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; - static_assert(D % tile_B::J == 0, "bad D"); + static_assert(DKQ % tile_B::J == 0, "bad DKQ"); + static_assert(DV % tile_A::J == 0, "bad DV"); static_assert(ncols % cols_per_warp == 0, "bad ncols"); - const ggml_tensor * KQV = dst; - const int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; + const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); + const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); + const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); - const int KQ_shared_rows = cp_async_available(cc) ? 2*KQ_per_iter : KQ_per_iter; + const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage; - const size_t nbytes_shared_KV = KQ_shared_rows * (D + 8) * sizeof(half); - const size_t nbytes_shared_mask = ncols1 * (KQ_per_iter + 8) * sizeof(half); - const size_t nbytes_shared_combine = nwarps*cols_per_warp * (D + 8) * sizeof(half); - - const size_t nbytes_shared_total = std::max(nbytes_shared_KV + nbytes_shared_mask, nbytes_shared_combine); + const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ? 
+ std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) : + nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask); float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); @@ -989,59 +1194,73 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; + +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; + +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) } - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true); } -#define DECL_FATTN_MMA_F16_CASE(D, ncols1, ncols2) \ - template void ggml_cuda_flash_attn_ext_mma_f16_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ - -#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(D, ncols) \ - extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/1, 1); \ - extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/2, 2); \ - extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \ - extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \ - -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8) - -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16) - -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32) - -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64) - -// Kernels with ncols == 128 are only 4% faster due to register pressure. 
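
The shared-memory budget computed above is what limits the Deepseek head sizes to Ampere or newer. As a rough worked example, assume the fattn_mma_f16_config<576, 512> constants from this patch together with ncols1 = 4, ncols2 = 16 (ncols = 64); with Q_in_reg == false and a single pipeline stage the Q, K/V and mask tiles come to roughly 96 kB, well beyond the default 48 KiB dynamic shared-memory limit, hence the cudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, ...) calls above. The arithmetic below is purely illustrative and ignores the combine buffer, which is smaller for this configuration:

#include <algorithm>
#include <cstdio>

int main() {
    // Constants from fattn_mma_f16_config<576, 512>; ncols1/ncols2 are assumed for illustration.
    const int DKQ = 576;
    const int ncols1 = 4, ncols2 = 16, ncols = ncols1*ncols2;
    const int nbatch_fa = 32, nbatch_K2 = 160, nbatch_V2 = 128;
    const size_t h2 = 4; // sizeof(half2)

    const size_t KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * h2; // K and V reuse one buffer
    const size_t Q         = ncols  * (DKQ/2 + 4) * h2; // Q_in_reg == false -> Q stays in shared memory
    const size_t mask      = ncols1 * (nbatch_fa/2 + 4) * h2;

    printf("KV %zu B + Q %zu B + mask %zu B = %zu B (~%zu KiB)\n",
           KV_1stage, Q, mask, KV_1stage + Q + mask, (KV_1stage + Q + mask)/1024);
    return 0;
}
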
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory. +#define DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2) \ + template void ggml_cuda_flash_attn_ext_mma_f16_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(DKQ, DV, ncols) \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 1, 1); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 2, 2); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 4, 4); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 8, 8); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/16, 16); \ + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 8) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 16) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 32) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) + +// The number of viable configurations for Deepseek is very limited: +extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index e0039e1755c2b..9283560d5c4ee 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -307,7 +307,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); } break; case 128: { @@ -315,7 +315,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); } break; default: { diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index fcb6f848fe003..32673adb57fc1 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -318,7 +318,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int 
nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); } break; case 128: { @@ -326,7 +326,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); } break; default: { diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index e17d2d0e4fb39..ef0addc1dbc70 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -315,7 +315,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; constexpr size_t nbytes_shared = 0; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); } template diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index d42ddca49f66d..7064675d5ab3f 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; constexpr size_t nbytes_shared = 0; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); } template diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index bc21b27a0cc07..c5668adb152b2 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -490,7 +490,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm fattn_kernel = flash_attn_ext_f16< D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>; } - launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size); + launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size); } void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 7a2d1e45365af..9c5c803d02bc7 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -8,58 +8,32 @@ #include "fattn-wmma-f16.cuh" #include "fattn.cuh" -template +template static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; - if (Q->ne[1] <= 8/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); - return; + if constexpr (ncols2 <= 8) { + if (Q->ne[1] <= 8/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } } if (Q->ne[1] <= 16/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } if (Q->ne[1] <= 32/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); 
return; } - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); -} - -template -static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst); - break; - case 80: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst); - break; - case 112: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst); - break; - case 256: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst); - break; - default: - GGML_ABORT("fatal error"); - break; - } + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); } -static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +template +static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; @@ -68,27 +42,79 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg float max_bias = 0.0f; memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); - const float use_gqa_opt = mask && max_bias == 0.0f; + const bool use_gqa_opt = mask && max_bias == 0.0f; GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; if (use_gqa_opt && gqa_ratio % 8 == 0) { - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } - if (use_gqa_opt && gqa_ratio == 4) { - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst); + if (use_gqa_opt && gqa_ratio % 4 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } - if (use_gqa_opt && gqa_ratio == 2) { - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst); + if (use_gqa_opt && gqa_ratio % 2 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); +} + +static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + + switch (Q->ne[0]) { + case 64: + GGML_ASSERT(V->ne[0] == 64); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64, 64>(ctx, dst); + break; + case 80: + GGML_ASSERT(V->ne[0] == 80); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80, 80>(ctx, dst); + break; + case 96: + GGML_ASSERT(V->ne[0] == 96); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96, 96>(ctx, dst); + break; + case 112: + GGML_ASSERT(V->ne[0] == 112); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst); + break; + case 128: + GGML_ASSERT(V->ne[0] == 128); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst); + break; + case 256: + GGML_ASSERT(V->ne[0] == 256); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); + break; + case 576: { + // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels. 
+ GGML_ASSERT(V->ne[0] == 512); + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + const bool use_gqa_opt = mask && max_bias == 0.0f; + GGML_ASSERT(use_gqa_opt); + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + GGML_ASSERT(gqa_ratio % 16 == 0); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + } break; + default: + GGML_ABORT("fatal error"); + break; + } } #define FATTN_VEC_F16_CASE(D, type_K, type_V) \ @@ -299,7 +325,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion; - const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0; + const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0; if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { if (prec == GGML_PREC_DEFAULT) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 9fb2134f98d3d..4c668632bd27d 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3206,16 +3206,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return false; #endif // FLASH_ATTN_AVAILABLE if (op->src[1]->ne[0] != op->src[2]->ne[0]) { - // different head sizes of K and V are not supported yet - return false; + const int cc = ggml_cuda_info().devices[dev_ctx->device].cc; + if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) { + return false; + } + const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2]; + return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0; } if (op->src[0]->ne[0] == 192) { return false; } - if (op->src[0]->ne[0] == 576) { - // DeepSeek MLA - return false; - } if (op->src[0]->ne[3] != 1) { return false; } diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu new file mode 100644 index 0000000000000..fb26abeb0dab3 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu index 80108615ac8fb..dc16829021f90 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 1, 8); -DECL_FATTN_MMA_F16_CASE(80, 1, 8); -DECL_FATTN_MMA_F16_CASE(96, 1, 8); -DECL_FATTN_MMA_F16_CASE(112, 1, 8); -DECL_FATTN_MMA_F16_CASE(128, 1, 8); -DECL_FATTN_MMA_F16_CASE(256, 1, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 1, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu index 66161c0abeb42..9d3cfd8edf74b 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 16, 1); -DECL_FATTN_MMA_F16_CASE(80, 16, 1); -DECL_FATTN_MMA_F16_CASE(96, 16, 1); -DECL_FATTN_MMA_F16_CASE(112, 16, 1); -DECL_FATTN_MMA_F16_CASE(128, 16, 1); -DECL_FATTN_MMA_F16_CASE(256, 16, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu index ee88c72aa0470..2e1883af40ed2 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 16, 2); -DECL_FATTN_MMA_F16_CASE(80, 16, 2); -DECL_FATTN_MMA_F16_CASE(96, 16, 2); -DECL_FATTN_MMA_F16_CASE(112, 16, 2); -DECL_FATTN_MMA_F16_CASE(128, 16, 2); -DECL_FATTN_MMA_F16_CASE(256, 16, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu index d888a5a423abf..2074e954a32f0 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 16, 4); -DECL_FATTN_MMA_F16_CASE(80, 16, 4); -DECL_FATTN_MMA_F16_CASE(96, 16, 4); -DECL_FATTN_MMA_F16_CASE(112, 16, 4); -DECL_FATTN_MMA_F16_CASE(128, 16, 4); -DECL_FATTN_MMA_F16_CASE(256, 16, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 
4); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu new file mode 100644 index 0000000000000..f011a208cd270 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu index d93a2d08ed7ea..24c64cf000fec 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 2, 4); -DECL_FATTN_MMA_F16_CASE(80, 2, 4); -DECL_FATTN_MMA_F16_CASE(96, 2, 4); -DECL_FATTN_MMA_F16_CASE(112, 2, 4); -DECL_FATTN_MMA_F16_CASE(128, 2, 4); -DECL_FATTN_MMA_F16_CASE(256, 2, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 2, 4); +DECL_FATTN_MMA_F16_CASE(80, 80, 2, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu index 617464c94568f..163b1d939e49d 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 2, 8); -DECL_FATTN_MMA_F16_CASE(80, 2, 8); -DECL_FATTN_MMA_F16_CASE(96, 2, 8); -DECL_FATTN_MMA_F16_CASE(112, 2, 8); -DECL_FATTN_MMA_F16_CASE(128, 2, 8); -DECL_FATTN_MMA_F16_CASE(256, 2, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 2, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu index 970d2b68696e9..0543532ea3479 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 32, 1); -DECL_FATTN_MMA_F16_CASE(80, 32, 1); -DECL_FATTN_MMA_F16_CASE(96, 32, 1); -DECL_FATTN_MMA_F16_CASE(112, 32, 1); -DECL_FATTN_MMA_F16_CASE(128, 32, 1); -DECL_FATTN_MMA_F16_CASE(256, 32, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 32, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 32, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 32, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 32, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 32, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 32, 1); diff --git 
a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu index 65cd377c395e1..407b6cf4c7020 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 32, 2); -DECL_FATTN_MMA_F16_CASE(80, 32, 2); -DECL_FATTN_MMA_F16_CASE(96, 32, 2); -DECL_FATTN_MMA_F16_CASE(112, 32, 2); -DECL_FATTN_MMA_F16_CASE(128, 32, 2); -DECL_FATTN_MMA_F16_CASE(256, 32, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 32, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 32, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu new file mode 100644 index 0000000000000..f5fd0e2369cf2 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu index f4a8bf3489963..5e46685024b84 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 4, 2); -DECL_FATTN_MMA_F16_CASE(80, 4, 2); -DECL_FATTN_MMA_F16_CASE(96, 4, 2); -DECL_FATTN_MMA_F16_CASE(112, 4, 2); -DECL_FATTN_MMA_F16_CASE(128, 4, 2); -DECL_FATTN_MMA_F16_CASE(256, 4, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu index de191a8ab66fd..1ada657f194c4 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 4, 4); -DECL_FATTN_MMA_F16_CASE(80, 4, 4); -DECL_FATTN_MMA_F16_CASE(96, 4, 4); -DECL_FATTN_MMA_F16_CASE(112, 4, 4); -DECL_FATTN_MMA_F16_CASE(128, 4, 4); -DECL_FATTN_MMA_F16_CASE(256, 4, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 4); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu index e8cb0e1b31238..bad296b4141e0 100644 --- 
a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 4, 8); -DECL_FATTN_MMA_F16_CASE(80, 4, 8); -DECL_FATTN_MMA_F16_CASE(96, 4, 8); -DECL_FATTN_MMA_F16_CASE(112, 4, 8); -DECL_FATTN_MMA_F16_CASE(128, 4, 8); -DECL_FATTN_MMA_F16_CASE(256, 4, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu index a532e96296b79..0d7a9c728537d 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 1); -DECL_FATTN_MMA_F16_CASE(80, 64, 1); -DECL_FATTN_MMA_F16_CASE(96, 64, 1); -DECL_FATTN_MMA_F16_CASE(112, 64, 1); -DECL_FATTN_MMA_F16_CASE(128, 64, 1); -DECL_FATTN_MMA_F16_CASE(256, 64, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 64, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 64, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 64, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 64, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 64, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 64, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu index bf25181aa7653..9d5a9976f0ed1 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 8, 1); -DECL_FATTN_MMA_F16_CASE(80, 8, 1); -DECL_FATTN_MMA_F16_CASE(96, 8, 1); -DECL_FATTN_MMA_F16_CASE(112, 8, 1); -DECL_FATTN_MMA_F16_CASE(128, 8, 1); -DECL_FATTN_MMA_F16_CASE(256, 8, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu index 378c132e658e0..a6e6f093dcb24 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 8, 2); -DECL_FATTN_MMA_F16_CASE(80, 8, 2); -DECL_FATTN_MMA_F16_CASE(96, 8, 2); -DECL_FATTN_MMA_F16_CASE(112, 8, 2); -DECL_FATTN_MMA_F16_CASE(128, 8, 2); -DECL_FATTN_MMA_F16_CASE(256, 8, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu 
b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu index 372641be9a0fe..86d4ffae27c28 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 8, 4); -DECL_FATTN_MMA_F16_CASE(80, 8, 4); -DECL_FATTN_MMA_F16_CASE(96, 8, 4); -DECL_FATTN_MMA_F16_CASE(112, 8, 4); -DECL_FATTN_MMA_F16_CASE(128, 8, 4); -DECL_FATTN_MMA_F16_CASE(256, 8, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 4); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu index 9ff5968b6abeb..680a13ca6de58 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 8, 8); -DECL_FATTN_MMA_F16_CASE(80, 8, 8); -DECL_FATTN_MMA_F16_CASE(96, 8, 8); -DECL_FATTN_MMA_F16_CASE(112, 8, 8); -DECL_FATTN_MMA_F16_CASE(128, 8, 8); -DECL_FATTN_MMA_F16_CASE(256, 8, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index dd373a09d26f3..3428113dc8fd2 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -18,7 +18,7 @@ """ -SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n" +SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n" TYPES_MMQ = [ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", @@ -57,18 +57,21 @@ def get_head_sizes(type_k, type_v): with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f: f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v)) -for ncols in [8, 16, 32, 64, 128]: - for ncols2 in [1, 2, 4, 8]: +for ncols in [8, 16, 32, 64]: + for ncols2 in [1, 2, 4, 8, 16]: + if ncols2 > ncols: + continue ncols1 = ncols // ncols2 - if ncols == 128: - continue # Too much register pressure. with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f: f.write(SOURCE_FATTN_MMA_START) - for head_size in [64, 80, 96, 112, 128, 256]: - if ncols == 128 and head_size == 256: - continue # Needs too much shared memory. 
- f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size)) + for head_size_kq in [64, 80, 96, 112, 128, 256, 576]: + if head_size_kq != 576 and ncols2 == 16: + continue + if head_size_kq == 576 and ncols2 != 16: + continue + head_size_v = head_size_kq if head_size_kq != 576 else 512 + f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) for type in TYPES_MMQ: with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 0da4e7d2b0547..1e7c925d31551 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1227,8 +1227,19 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); if (v_mla) { +#if 0 + // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens. + // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient. cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens); cur = ggml_mul_mat(ctx0, v_mla, cur); +#else + // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1. + // The permutations are noops and only change how the tensor data is interpreted. + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_mul_mat(ctx0, v_mla, cur); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs. +#endif } cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); From 187054a70c97f1444170ebceb2d8f6804cfd0ead Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 6 May 2025 13:07:47 +0200 Subject: [PATCH 2/4] wrap __cvta_generic_to_shared for HIP --- ggml/src/ggml-cuda/cp-async.cuh | 11 +++++++++++ ggml/src/ggml-cuda/fattn-mma-f16.cuh | 4 ++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/cp-async.cuh b/ggml/src/ggml-cuda/cp-async.cuh index ecb659997ba51..63d0c482ff727 100644 --- a/ggml/src/ggml-cuda/cp-async.cuh +++ b/ggml/src/ggml-cuda/cp-async.cuh @@ -2,6 +2,17 @@ #include "common.cuh" + +static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) { +#ifdef CP_ASYNC_AVAILABLE + return __cvta_generic_to_shared(generic_ptr); +#else + GGML_UNUSED(generic_ptr); + NO_DEVICE_CODE; + return 0; +#endif // CP_ASYNC_AVAILABLE +} + // Copies data from global to shared memory, cg == cache global. // Both the src and dst pointers must be aligned to 16 bit. // Shared memory uses 32 bit addressing, the pointer is passed as unsigned int. diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 068f3f75b6d26..0c39624cc2412 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -112,7 +112,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. 
if (use_cp_async) { - const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV); + const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); constexpr int preload = 64; constexpr int h2_per_chunk = 16/sizeof(half2); @@ -186,7 +186,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; constexpr int stride_j = nwarps * cols_per_warp; - const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask); + const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask); #pragma unroll for (int j0 = 0; j0 < ncols1; j0 += stride_j) { From dd05446579291ba41ae1c10527370b7a64388c74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 6 May 2025 21:26:05 +0200 Subject: [PATCH 3/4] fix loop unrolling for KV data load --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 133 ++++++++++++++++----------- 1 file changed, 78 insertions(+), 55 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 0c39624cc2412..ef96c99f8a87a 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -104,75 +104,100 @@ struct fattn_mma_f16_config<576, 512> { static constexpr int nbatch_combine = 128; }; -template -static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( - const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) { +// ------------------------------------------------------------------------------------------------------------------ - // K/V data is loaded with decreasing granularity for D for better memory bandwidth. - // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. +// The compiler is unable to unroll loops with the k0_start == k0_stop condition. +// Therefore, write functions for the loop iterations and unroll the loops manually. - if (use_cp_async) { - const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_tile_async_loop_iter_async( + const half2 * const __restrict__ KV, const unsigned int tile_KV_32, const int chunks_per_row, const int stride_KV) { + constexpr int preload = 64; + constexpr int h2_per_chunk = 16/sizeof(half2); - constexpr int preload = 64; - constexpr int h2_per_chunk = 16/sizeof(half2); + const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); + const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; - const int chunks_per_row = D2 / h2_per_chunk; + if (k0_start == k0_stop) { + return; + } #pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4, WARP_SIZE/8, WARP_SIZE/16}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); - const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); - const int stride_i = WARP_SIZE / stride_k; + for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); - if (k0_start == k0_stop) { - continue; - } + if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { + break; + } #pragma unroll - for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 
0 : threadIdx.x / stride_k); + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { - break; - } + cp_async_cg_16(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk); + } + } +} + +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_tile_async_loop_iter_sync( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); + const int k0_stop = D2 - D2 % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + return; + } #pragma unroll - for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); - cp_async_cg_16(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk); - } - } + if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { + break; } - } else { - static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds"); + #pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); - const int k0_stop = D2 - D2 % (1*stride_k); - const int stride_i = WARP_SIZE / stride_k; + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - if (k0_start == k0_stop || k0_stop <= 0) { - continue; - } + tile_KV[i*stride_tile + k] = KV[i*stride_KV + k]; + } + } +} -#pragma unroll - for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) { - if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { - break; - } + // K/V data is loaded with decreasing granularity for D for better memory bandwidth. + // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. -#pragma unroll - for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? 
threadIdx.x : threadIdx.x % stride_k); + if (use_cp_async) { + const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); + constexpr int h2_per_chunk = 16/sizeof(half2); + const int chunks_per_row = D2 / h2_per_chunk; - tile_KV[i*stride_tile + k] = KV[i*stride_KV + k]; - } - } - } + flash_attn_ext_f16_load_tile_async_loop_iter_async + (KV, tile_KV_32, chunks_per_row, stride_KV); + flash_attn_ext_f16_load_tile_async_loop_iter_async + (KV, tile_KV_32, chunks_per_row, stride_KV); + flash_attn_ext_f16_load_tile_async_loop_iter_async + (KV, tile_KV_32, chunks_per_row, stride_KV); + flash_attn_ext_f16_load_tile_async_loop_iter_async + (KV, tile_KV_32, chunks_per_row, stride_KV); + flash_attn_ext_f16_load_tile_async_loop_iter_async + (KV, tile_KV_32, chunks_per_row, stride_KV); + } else { + static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds"); + flash_attn_ext_f16_load_tile_async_loop_iter_sync + (KV, tile_KV, D2, stride_KV); + flash_attn_ext_f16_load_tile_async_loop_iter_sync + (KV, tile_KV, D2, stride_KV); + flash_attn_ext_f16_load_tile_async_loop_iter_sync + (KV, tile_KV, D2, stride_KV); } } @@ -848,10 +873,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset >= WARP_SIZE) { - continue; + if (offset < WARP_SIZE) { + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); } - KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); } float KQ_cms[nmeta]; // KQ combine max scale per warp. @@ -867,10 +891,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset >= WARP_SIZE) { - continue; + if (offset < WARP_SIZE) { + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } - KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } // Write back combined meta data: From fe2b775ab5d959c9e383e0b5dd5f3e2126a74a60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 9 May 2025 11:05:59 +0200 Subject: [PATCH 4/4] do loop unrolling via C++ template --- ggml/src/ggml-cuda/CMakeLists.txt | 2 +- ggml/src/ggml-cuda/common.cuh | 19 ++++ ggml/src/ggml-cuda/fattn-mma-f16.cuh | 130 +++++++++++---------------- 3 files changed, 74 insertions(+), 77 deletions(-) diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index f3cfdeaef513a..d05fcf22886ec 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -100,7 +100,7 @@ if (CUDAToolkit_FOUND) set(CUDA_CXX_FLAGS "") - set(CUDA_FLAGS -use_fast_math) + set(CUDA_FLAGS -use_fast_math -extended-lambda) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8") # Options are: diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 2ea014e6476a7..d56f746c8a7f2 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -300,6 +300,25 @@ static __device__ void no_device_code( #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.") #endif // __CUDA_ARCH__ +// The compiler is always able to unroll loops if they contain continue expressions. +// In such cases loop unrolling can still be achieved via recursion: +template +struct ggml_cuda_unroll { + template + __device__ void operator()(const Func & f, Args... 
args) const { + f(n - 1, args...); + ggml_cuda_unroll{}(f, args...); + } +}; + +template <> +struct ggml_cuda_unroll<1> { + template + __device__ void operator()(const Func & f, Args... args) const { + f(0, args...); + } +}; + template static __device__ __forceinline__ int warp_reduce_sum(int x) { #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index ef96c99f8a87a..2b6bdc30c024b 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -106,98 +106,76 @@ struct fattn_mma_f16_config<576, 512> { // ------------------------------------------------------------------------------------------------------------------ -// The compiler is unable to unroll loops with the k0_start == k0_stop condition. -// Therefore, write functions for the loop iterations and unroll the loops manually. +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) { -template -static __device__ __forceinline__ void flash_attn_ext_f16_load_tile_async_loop_iter_async( - const half2 * const __restrict__ KV, const unsigned int tile_KV_32, const int chunks_per_row, const int stride_KV) { - constexpr int preload = 64; - constexpr int h2_per_chunk = 16/sizeof(half2); + // K/V data is loaded with decreasing granularity for D for better memory bandwidth. + // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. - const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); - const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); - const int stride_i = WARP_SIZE / stride_k; + if (use_cp_async) { + constexpr int preload = 64; + constexpr int h2_per_chunk = 16/sizeof(half2); + const int chunks_per_row = D2 / h2_per_chunk; - if (k0_start == k0_stop) { - return; - } + const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); -#pragma unroll - for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + auto load = [&] __device__ (const int n) { + const int stride_k = WARP_SIZE >> n; + const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); + const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; - if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { - break; - } + if (k0_start == k0_stop) { + return; + } #pragma unroll - for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - - cp_async_cg_16(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk); - } - } -} - -template -static __device__ __forceinline__ void flash_attn_ext_f16_load_tile_async_loop_iter_sync( - const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) { - const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); - const int k0_stop = D2 - D2 % (1*stride_k); - const int stride_i = WARP_SIZE / stride_k; + for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 
0 : threadIdx.x / stride_k); - if (k0_start == k0_stop) { - return; - } + if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { + break; + } #pragma unroll - for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { - break; - } + cp_async_cg_16(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk); + } + } + }; + ggml_cuda_unroll<5>{}(load); + } else { + static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds"); + auto load = [&] __device__ (const int n) { + const int stride_k = WARP_SIZE >> n; + const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); + const int k0_stop = D2 - D2 % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + return; + } #pragma unroll - for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - - tile_KV[i*stride_tile + k] = KV[i*stride_KV + k]; - } - } -} + for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); -template -static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( - const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) { - - // K/V data is loaded with decreasing granularity for D for better memory bandwidth. - // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. + if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { + break; + } - if (use_cp_async) { - const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); - constexpr int h2_per_chunk = 16/sizeof(half2); - const int chunks_per_row = D2 / h2_per_chunk; +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - flash_attn_ext_f16_load_tile_async_loop_iter_async - (KV, tile_KV_32, chunks_per_row, stride_KV); - flash_attn_ext_f16_load_tile_async_loop_iter_async - (KV, tile_KV_32, chunks_per_row, stride_KV); - flash_attn_ext_f16_load_tile_async_loop_iter_async - (KV, tile_KV_32, chunks_per_row, stride_KV); - flash_attn_ext_f16_load_tile_async_loop_iter_async - (KV, tile_KV_32, chunks_per_row, stride_KV); - flash_attn_ext_f16_load_tile_async_loop_iter_async - (KV, tile_KV_32, chunks_per_row, stride_KV); - } else { - static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds"); - flash_attn_ext_f16_load_tile_async_loop_iter_sync - (KV, tile_KV, D2, stride_KV); - flash_attn_ext_f16_load_tile_async_loop_iter_sync - (KV, tile_KV, D2, stride_KV); - flash_attn_ext_f16_load_tile_async_loop_iter_sync - (KV, tile_KV, D2, stride_KV); + tile_KV[i*stride_tile + k] = KV[i*stride_KV + k]; + } + } + }; + ggml_cuda_unroll<3>{}(load); } }
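
Editor's note on the loop-unrolling change above: the rewritten flash_attn_ext_f16_load_tile folds the per-stride_k loop body into a lambda and walks the stride values through ggml_cuda_unroll, whose compile-time recursion keeps every "iteration" visible to the compiler even though the body contains an early return. The stand-alone sketch below reproduces that pattern outside the ggml tree so it can be compiled and run on its own; the helper and kernel names (unroll_static, fill_rows) are invented for illustration, and the template parameter lists are spelled out explicitly here because several of them appear stripped in this extract of the diff.

// Minimal sketch of recursion-based loop unrolling in the spirit of ggml_cuda_unroll (patch 4).
// Build with: nvcc -o unroll_demo unroll_demo.cu
#include <cstdio>
#include <cuda_runtime.h>

// Calls f(n-1), f(n-2), ..., f(0); each call is a distinct instantiation, so the
// "iterations" remain unrollable even when the body returns early.
template <int n>
struct unroll_static {
    template <typename Func, typename... Args>
    __device__ void operator()(const Func & f, Args... args) const {
        f(n - 1, args...);
        unroll_static<n - 1>{}(f, args...);
    }
};

template <>
struct unroll_static<1> {
    template <typename Func, typename... Args>
    __device__ void operator()(const Func & f, Args... args) const {
        f(0, args...);
    }
};

// Toy kernel: iteration n derives a stride from n (32, 16, 8, 4), mirroring how the patch
// iterates over decreasing copy granularities, and fills row n of a [4][ncols] buffer with it.
// The lambda is left unannotated so stock nvcc settings suffice; the patch marks its lambdas
// __device__ and therefore adds -extended-lambda to CUDA_FLAGS in CMakeLists.txt.
__global__ void fill_rows(float * data, const int ncols) {
    auto body = [&](const int n) {
        const int stride = warpSize >> n;
        for (int col = threadIdx.x; col < ncols; col += warpSize) {
            data[n*ncols + col] = (float) stride;
        }
    };
    unroll_static<4>{}(body);
}

int main() {
    const int ncols = 64;
    float * d_data = nullptr;
    cudaMalloc(&d_data, 4*ncols*sizeof(float));
    fill_rows<<<1, 32>>>(d_data, ncols);

    // Read back the first element of each row; expected output: 32 16 8 4.
    float strides[4];
    for (int n = 0; n < 4; ++n) {
        cudaMemcpy(&strides[n], d_data + n*ncols, sizeof(float), cudaMemcpyDeviceToHost);
    }
    printf("strides seen by the kernel: %g %g %g %g\n", strides[0], strides[1], strides[2], strides[3]);

    cudaFree(d_data);
    return 0;
}

In the patch itself the lambda body is the K/V copy loop shown above, and ggml_cuda_unroll<5>{} / ggml_cuda_unroll<3>{} replace the hand-written flash_attn_ext_f16_load_tile_async_loop_iter_* helpers introduced in patch 3; making the base case the <1> specialization guarantees the functor is invoked at least once.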