CUDA: FA support for Deepseek (Ampere or newer) #13306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit

do loop unrolling via C++ template
JohannesGaessler committed May 9, 2025
commit fe2b775ab5d959c9e383e0b5dd5f3e2126a74a60
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/CMakeLists.txt
@@ -100,7 +100,7 @@ if (CUDAToolkit_FOUND)

    set(CUDA_CXX_FLAGS "")

-    set(CUDA_FLAGS -use_fast_math)
+    set(CUDA_FLAGS -use_fast_math -extended-lambda)

    if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
        # Options are:
19 changes: 19 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -300,6 +300,25 @@ static __device__ void no_device_code(
#define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
#endif // __CUDA_ARCH__

+// The compiler is not always able to unroll loops if they contain continue expressions.
+// In such cases loop unrolling can still be achieved via recursion:
+template <int n>
+struct ggml_cuda_unroll {
+    template <typename Func, typename... Args>
+    __device__ void operator()(const Func & f, Args... args) const {
+        f(n - 1, args...);
+        ggml_cuda_unroll<n - 1>{}(f, args...);
+    }
+};
+
+template <>
+struct ggml_cuda_unroll<1> {
+    template <typename Func, typename... Args>
+    __device__ void operator()(const Func & f, Args... args) const {
+        f(0, args...);
+    }
+};

template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_sum(int x) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
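
For context on how the new helper is used: ggml_cuda_unroll<N>{}(f, args...) expands into N explicit calls f(N-1, args...), ..., f(0, args...), so a "loop body" with an early return (the analogue of a continue) is still fully unrolled at compile time. The snippet below is a minimal, self-contained sketch and not part of this PR; the file name, kernel, lambda, and buffer names are illustrative only. It reuses the ggml_cuda_unroll definition added above. Note that the lambdas in fattn-mma-f16.cuh carry an explicit __device__ annotation inside device code, which is why -extended-lambda is added to CUDA_FLAGS in the CMakeLists.txt change; the toy lambda below omits the annotation and needs no extra flag.

// unroll_demo.cu -- illustrative only, not part of ggml.
// Shows that ggml_cuda_unroll<N>{}(f) calls f(N-1), f(N-2), ..., f(0),
// each as a separate statement, even when the body returns early.
#include <cstdio>
#include <cuda_runtime.h>

template <int n>
struct ggml_cuda_unroll {
    template <typename Func, typename... Args>
    __device__ void operator()(const Func & f, Args... args) const {
        f(n - 1, args...);
        ggml_cuda_unroll<n - 1>{}(f, args...);
    }
};

template <>
struct ggml_cuda_unroll<1> {
    template <typename Func, typename... Args>
    __device__ void operator()(const Func & f, Args... args) const {
        f(0, args...);
    }
};

__global__ void demo_kernel(int * out) {
    // The early return inside the body plays the role of a "continue".
    auto body = [&](const int n) {
        if (n == 2) {
            return; // this "iteration" is skipped
        }
        out[n] = 10*n;
    };
    ggml_cuda_unroll<5>{}(body); // expands to body(4), body(3), body(2), body(1), body(0)
}

int main() {
    int * d_out = nullptr;
    int h_out[5] = {0};
    cudaMalloc(&d_out, sizeof(h_out));
    cudaMemset(d_out, 0, sizeof(h_out));
    demo_kernel<<<1, 1>>>(d_out);
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    for (int i = 0; i < 5; ++i) {
        printf("out[%d] = %d\n", i, h_out[i]); // prints 0, 10, 0 (skipped), 30, 40
    }
    cudaFree(d_out);
    return 0;
}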
130 changes: 54 additions & 76 deletions ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -106,98 +106,76 @@ struct fattn_mma_f16_config<576, 512> {

// ------------------------------------------------------------------------------------------------------------------

-// The compiler is unable to unroll loops with the k0_start == k0_stop condition.
-// Therefore, write functions for the loop iterations and unroll the loops manually.
-
-template<int stride_tile, int nwarps, int nbatch_fa, int stride_k>
-static __device__ __forceinline__ void flash_attn_ext_f16_load_tile_async_loop_iter_async(
-        const half2 * const __restrict__ KV, const unsigned int tile_KV_32, const int chunks_per_row, const int stride_KV) {
-    constexpr int preload = 64;
-    constexpr int h2_per_chunk = 16/sizeof(half2);
-
-    const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
-    const int k0_stop  =                             chunks_per_row - chunks_per_row % (1*stride_k);
-    const int stride_i = WARP_SIZE / stride_k;
-
-    if (k0_start == k0_stop) {
-        return;
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
-        const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
-
-        if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
-            break;
-        }
-
-#pragma unroll
-        for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-            const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
-
-            cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
-        }
-    }
-}
-
-template<int stride_tile, int nwarps, int nbatch_fa, int stride_k>
-static __device__ __forceinline__ void flash_attn_ext_f16_load_tile_async_loop_iter_sync(
-        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
-    const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
-    const int k0_stop  =                             D2 - D2 % (1*stride_k);
-    const int stride_i = WARP_SIZE / stride_k;
-
-    if (k0_start == k0_stop) {
-        return;
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
-        const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
-
-        if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
-            break;
-        }
-
-#pragma unroll
-        for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-            const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
-
-            tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
-        }
-    }
-}
-
-template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
-static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
-        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
-
-    // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
-    // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
-
-    if (use_cp_async) {
-        const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
-        constexpr int h2_per_chunk = 16/sizeof(half2);
-        const int chunks_per_row = D2 / h2_per_chunk;
-
-        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE>
-            (KV, tile_KV_32, chunks_per_row, stride_KV);
-        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/2>
-            (KV, tile_KV_32, chunks_per_row, stride_KV);
-        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/4>
-            (KV, tile_KV_32, chunks_per_row, stride_KV);
-        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/8>
-            (KV, tile_KV_32, chunks_per_row, stride_KV);
-        flash_attn_ext_f16_load_tile_async_loop_iter_async<stride_tile, nwarps, nbatch_fa, WARP_SIZE/16>
-            (KV, tile_KV_32, chunks_per_row, stride_KV);
-    } else {
-        static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
-        flash_attn_ext_f16_load_tile_async_loop_iter_sync<stride_tile, nwarps, nbatch_fa, WARP_SIZE>
-            (KV, tile_KV, D2, stride_KV);
-        flash_attn_ext_f16_load_tile_async_loop_iter_sync<stride_tile, nwarps, nbatch_fa, WARP_SIZE/2>
-            (KV, tile_KV, D2, stride_KV);
-        flash_attn_ext_f16_load_tile_async_loop_iter_sync<stride_tile, nwarps, nbatch_fa, WARP_SIZE/4>
-            (KV, tile_KV, D2, stride_KV);
-    }
-}

+template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
+        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
+
+    // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
+    // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
+
+    if (use_cp_async) {
+        constexpr int preload = 64;
+        constexpr int h2_per_chunk = 16/sizeof(half2);
+        const int chunks_per_row = D2 / h2_per_chunk;
+
+        const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
+
+        auto load = [&] __device__ (const int n) {
+            const int stride_k = WARP_SIZE >> n;
+            const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
+            const int k0_stop  =                             chunks_per_row - chunks_per_row % (1*stride_k);
+            const int stride_i = WARP_SIZE / stride_k;
+
+            if (k0_start == k0_stop) {
+                return;
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
+                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+                if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
+                    break;
+                }
+
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
+                }
+            }
+        };
+        ggml_cuda_unroll<5>{}(load);
+    } else {
+        static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
+        auto load = [&] __device__ (const int n) {
+            const int stride_k = WARP_SIZE >> n;
+            const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
+            const int k0_stop  =                             D2 - D2 % (1*stride_k);
+            const int stride_i = WARP_SIZE / stride_k;
+
+            if (k0_start == k0_stop) {
+                return;
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
+                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+                if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
+                    break;
+                }
+
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
+                }
+            }
+        };
+        ggml_cuda_unroll<3>{}(load);
+    }
+}
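
To make the k0_start/k0_stop arithmetic in the lambdas above easier to follow, here is a small host-only sketch, not part of this PR. The values WARP_SIZE = 32 and chunks_per_row = 72 are example assumptions, the latter corresponding to the K head size 576 from fattn_mma_f16_config<576, 512> (D2 = 288 half2 at 4 half2 per 16-byte cp.async chunk). Each unroll step n halves stride_k and covers the tail of the row that the previous, coarser step could not; steps whose range is empty return early, which is exactly the case ggml_cuda_unroll exists to keep unrollable.

// granularity_demo.cu -- illustrative host-only sketch, not part of ggml.
// Prints which 16-byte chunks of one K/V row each unrolled pass covers.
#include <cstdio>

int main() {
    const int warp_size      = 32; // WARP_SIZE in ggml
    const int chunks_per_row = 72; // example: D2 = 288 half2, 4 half2 per chunk

    for (int n = 0; n < 5; ++n) {  // mirrors ggml_cuda_unroll<5>{}(load)
        const int stride_k = warp_size >> n;
        const int k0_start = stride_k == warp_size ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
        const int k0_stop  =                             chunks_per_row - chunks_per_row % (1*stride_k);

        if (k0_start == k0_stop) {
            printf("n=%d stride_k=%2d: empty range, returns early\n", n, stride_k);
            continue;
        }
        printf("n=%d stride_k=%2d: chunks [%2d, %2d)\n", n, stride_k, k0_start, k0_stop);
    }
    // Output: n=0 covers [ 0, 64), n=2 covers [64, 72); n=1, n=3, n=4 are empty.
    // Together the passes cover every chunk exactly once: the bulk of the row is
    // loaded with a full warp per row (stride_k == 32) and the remainder with
    // fewer threads per row, i.e. with decreasing granularity along D.
    return 0;
}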
