
CUDA: FA support for Deepseek (Ampere or newer) #13306


Merged

Conversation

JohannesGaessler (Collaborator) commented May 4, 2025

This PR adds FlashAttention CUDA support for Deepseek models.

  • The comparatively large head sizes of 576 and 512 have been difficult to fit into registers + SRAM; I had to make several trade-offs between speed and memory use to make the kernel work. In particular, for Deepseek I'm storing the Q values in SRAM instead of registers, I'm using only a single pipeline stage for data loading (instead of 2), and I load and process the KV data in batches. That way the kernel just barely fits in 99 kiB SRAM + 64k registers. Unfortunately that means the kernel won't work on Turing, which only has 64 kiB SRAM and lacks asynchronous data loading (which reduces register use). Maybe there's still some way I can make it work, but I'm leaving that for a future PR. Volta or older is not supported due to a lack of the necessary tensor core instructions.
  • Luckily, because the MLA implementation translates to GQA with 16 Q heads per KV head, it is possible to use the same kernel for batch size 1 and batch sizes >> 1, so only one kernel is needed to cover both prompt processing and token generation. Unfortunately this kernel doesn't properly support KV cache quantization, so performance with a quantized KV cache will be poor.
  • I implemented the ability to use different head sizes for K and V more generally. Are there other models that can piggyback off of this? (I don't have a good overview.)
  • It may be worthwhile to explore the newly added options for lower memory use for other models as well; particularly the Gemma models with their comparatively large heads could benefit (left to future PRs).
  • I swapped two of the dimensions for a matrix multiplication in llama-graph.cpp to trigger a more efficient CUDA code path. In principle something like this could also be done as an automatic optimization in either the CUDA backend or the compute graph. It may also make sense to add some function like ggml_permute_data_layout that directly permutes the ne and nb values of a tensor. As long as all backends support non-contiguous outputs that would save you from having to add ggml_cont.
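
As a purely illustrative sketch of what such a helper could look like (ggml has no ggml_permute_data_layout today; the name, signature, and in-place semantics here are assumptions, not an existing API):

```cpp
#include "ggml.h"

// Hypothetical helper: swap the ne/nb metadata of two axes in place so that
// the op producing this tensor writes its output directly in the permuted
// layout, which would make a follow-up ggml_cont unnecessary as long as the
// backend supports non-contiguous outputs.
static void ggml_permute_data_layout(struct ggml_tensor * t, int axis0, int axis1) {
    const int64_t ne_tmp = t->ne[axis0];
    const size_t  nb_tmp = t->nb[axis0];

    t->ne[axis0] = t->ne[axis1];
    t->nb[axis0] = t->nb[axis1];
    t->ne[axis1] = ne_tmp;
    t->nb[axis1] = nb_tmp;
}
```
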
Performance changes
| GPU | model | n_ubatch | test | t/s no FA | t/s FA | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 3090 | deepseek2 16B Q4_0 | 1 | pp16384 | 74.23 | 143.64 | 1.94 |
| RTX 3090 | deepseek2 16B Q4_0 | 2 | pp16384 | 100.17 | 141.32 | 1.41 |
| RTX 3090 | deepseek2 16B Q4_0 | 4 | pp16384 | 171.83 | 227.60 | 1.32 |
| RTX 3090 | deepseek2 16B Q4_0 | 8 | pp16384 | 268.07 | 340.07 | 1.27 |
| RTX 3090 | deepseek2 16B Q4_0 | 16 | pp16384 | 369.53 | 467.17 | 1.26 |
| RTX 3090 | deepseek2 16B Q4_0 | 32 | pp16384 | 693.76 | 743.57 | 1.07 |
| RTX 3090 | deepseek2 16B Q4_0 | 64 | pp16384 | 1020.42 | 1063.40 | 1.04 |
| RTX 3090 | deepseek2 16B Q4_0 | 128 | pp16384 | 1502.64 | 1562.13 | 1.04 |
| RTX 3090 | deepseek2 16B Q4_0 | 256 | pp16384 | 2017.41 | 2066.18 | 1.02 |
| RTX 3090 | deepseek2 16B Q4_0 | 512 | pp16384 | 2343.85 | 2399.70 | 1.02 |
| RTX 3090 | deepseek2 16B Q4_0 | 1024 | pp16384 | 2515.60 | 2566.45 | 1.02 |
| RTX 3090 | deepseek2 16B Q4_0 | 2048 | pp16384 | 2545.05 | 2741.79 | 1.08 |
| RTX 4090 | deepseek2 16B Q4_0 | 1 | pp16384 | 162.60 | 193.94 | 1.19 |
| RTX 4090 | deepseek2 16B Q4_0 | 2 | pp16384 | 164.72 | 185.95 | 1.13 |
| RTX 4090 | deepseek2 16B Q4_0 | 4 | pp16384 | 295.38 | 317.07 | 1.07 |
| RTX 4090 | deepseek2 16B Q4_0 | 8 | pp16384 | 485.03 | 520.84 | 1.07 |
| RTX 4090 | deepseek2 16B Q4_0 | 16 | pp16384 | 730.50 | 778.03 | 1.07 |
| RTX 4090 | deepseek2 16B Q4_0 | 32 | pp16384 | 1169.94 | 1302.99 | 1.11 |
| RTX 4090 | deepseek2 16B Q4_0 | 64 | pp16384 | 1829.60 | 1939.27 | 1.06 |
| RTX 4090 | deepseek2 16B Q4_0 | 128 | pp16384 | 2696.55 | 2986.22 | 1.11 |
| RTX 4090 | deepseek2 16B Q4_0 | 256 | pp16384 | 3552.00 | 4167.87 | 1.17 |
| RTX 4090 | deepseek2 16B Q4_0 | 512 | pp16384 | 4178.90 | 5026.71 | 1.20 |
| RTX 4090 | deepseek2 16B Q4_0 | 1024 | pp16384 | 4292.91 | 5489.14 | 1.28 |
| RTX 4090 | deepseek2 16B Q4_0 | 2048 | pp16384 | 4207.91 | 5462.85 | 1.30 |

github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs), python (python script changes), and ggml (changes relating to the ggml tensor library for machine learning) labels on May 4, 2025
Panchovix commented May 5, 2025

Just tested on DeepSeek V3 0324 (Q2_K_XL) and it works fine, so you can use MLA + FA. I'm offloading ~110 GB to CPU RAM and the rest to GPU (of a 255 GB model), and this saves a lot of GPU memory. I get a bit less PP, but I guess that's because the CPU is slower with FA? But faster gen speed.

Loading with the following (PC with Ryzen 7 7800X3D, 192 GB RAM at 6000 MHz, Fedora 42):

./llama-server -m '/GGUFs/DeepSeek-V3-0324-UD-Q2_K_XL-merged.gguf' -c 16384 --no-mmap --no-warmup -v -ngl 99 --override-tensor 'blk\.(2[5-9]|[3-6][0-9])\..*_exps\.=CPU' --override-tensor 'blk\.([1-6])\..*_exps\.=CUDA0' --override-tensor 'blk\.([7-9]|1[0])\..*_exps\.=CUDA1' --override-tensor 'blk\.(1[1-5])\..*_exps\.=CUDA2' --override-tensor 'blk\.(1[6-9]|2[0-4])\..*_exps\.=CUDA3' -fa

When not using -fa

prompt eval time = 38919.92 ms / 1528 tokens ( 25.47 ms per token, 39.26 tokens per second)
eval time = 57175.47 ms / 471 tokens ( 121.39 ms per token, 8.24 tokens per second)

When using -fa

prompt eval time =   88134.76 ms /  3252 tokens (   27.10 ms per token,    36.90 tokens per second)
       eval time =   46872.06 ms /   417 tokens (  112.40 ms per token,     8.90 tokens per second)

But if we move the regex a bit to use more tensors on GPU (as we can now fit more because the buffers weigh a lot less)

./llama-server -m '/GGUFs/DeepSeek-V3-0324-UD-Q2_K_XL-merged.gguf' -c 16384 --no-mmap --no-warmup -ngl 99 --override-tensor 'blk\.(2[7-9]|[3-6][0-9])\..*_exps\.=CPU' --override-tensor 'blk\.([1-6])\..*_exps\.=CUDA0' --override-tensor 'blk\.([7-9]|1[0])\..*_exps\.=CUDA1' --override-tensor 'blk\.(1[1-6])\..*_exps\.=CUDA2' --override-tensor 'blk\.(1[7-9]|2[0-6])\..*_exps\.=CUDA3' -fa

I get

prompt eval time =   84468.91 ms /  3252 tokens (   25.97 ms per token,    38.50 tokens per second)
       eval time =   37112.76 ms /   343 tokens (  108.20 ms per token,     9.24 tokens per second)

Which I still think has room for improvement, as some GPUs have >4GB left, but it works as a quick test. Also I think it is using a slower GPU for PP (saturated at PCI-E 4.0 X8) instead of my faster GPU (at PCI-E 5.0 X8). Will check if I can change the GPU that does PP.

slaren (Member) commented May 5, 2025

I see lower performance with fa enabled at low contexts, but it improves with larger contexts.

Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6, VMM: yes

| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| deepseek2 16B Q4_0 | 8.29 GiB | 15.71 B | CUDA | 99 | 0 | tg128 | 200.94 ± 1.50 |
| deepseek2 16B Q4_0 | 8.29 GiB | 15.71 B | CUDA | 99 | 0 | tg128 @ d1024 | 175.40 ± 1.13 |
| deepseek2 16B Q4_0 | 8.29 GiB | 15.71 B | CUDA | 99 | 0 | tg128 @ d2048 | 158.07 ± 1.49 |
| deepseek2 16B Q4_0 | 8.29 GiB | 15.71 B | CUDA | 99 | 0 | tg128 @ d4096 | 137.65 ± 1.36 |
| deepseek2 16B Q4_0 | 8.29 GiB | 15.71 B | CUDA | 99 | 1 | tg128 | 161.17 ± 2.67 |
| deepseek2 16B Q4_0 | 8.29 GiB | 15.71 B | CUDA | 99 | 1 | tg128 @ d1024 | 151.07 ± 1.85 |
| deepseek2 16B Q4_0 | 8.29 GiB | 15.71 B | CUDA | 99 | 1 | tg128 @ d2048 | 153.45 ± 1.76 |
| deepseek2 16B Q4_0 | 8.29 GiB | 15.71 B | CUDA | 99 | 1 | tg128 @ d4096 | 147.69 ± 2.19 |

build: d19838e (5276)

Panchovix commented May 5, 2025

Changed the device that processes PP and speeds are pretty good. This is with DeepSeek V3 0324 UD_Q2_K_XL.

prompt eval time =   49257.75 ms /  3252 tokens (   15.15 ms per token,    66.02 tokens per second)
       eval time =   46322.14 ms /   436 tokens (  106.24 ms per token,     9.41 tokens per second)

It seems to saturate X8 5.0 (26-27 GiB/s), but not X16 5.0 (tops about 28-29 GiB/s), so I guess there is a limitation somewhere.

Hope this can be merged, as the latest updates are pretty good as well!

JohannesGaessler (Collaborator, PR author) commented May 6, 2025

I was testing the new config options for Gemma and noticed that for non-Deepseek models the kernel in this PR was 5-10% slower than the one on master (meaning the runtime of the kernel itself, not end-to-end performance). I dug around a bit and as it turns out the CUDA compiler is unable to actually unroll the loops for loading KV data. I unrolled the loops manually which ends up being ugly but I don't know how else to do it. There are similar loops for loading Q and storing VKQ that could in principle be given the same treatment but there is no measurable difference to the kernel runtime - those loops are executed once each and not once per 32/64 tokens.
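
For illustration, a minimal sketch of the kind of manual unrolling described here; the names, types, and 64-token chunking are hypothetical and not the actual kernel code:

```cpp
#include <cuda_fp16.h>

// Hypothetical helper standing in for one iteration of the KV-loading loop;
// the template parameter makes the chunk offset a compile-time constant.
template <int k0>
static __device__ __forceinline__ void load_kv_chunk(const half2 * K_h2, half2 * tile_K) {
    // ... copy the chunk starting at offset k0 into shared memory ...
}

// The problematic form is a loop with runtime bounds, which nvcc cannot
// unroll even with #pragma unroll:
//
//     for (int k0 = k0_start; k0 < k0_stop; k0 += 64) { /* load chunk k0 */ }
//
// Manual unrolling instead: one guarded call per possible chunk offset
// (this sketch assumes the bounds are multiples of 64), so that every call
// site is known at compile time.
static __device__ void load_kv(const half2 * K_h2, half2 * tile_K,
                               const int k0_start, const int k0_stop) {
    if (k0_start <=   0 &&   0 < k0_stop) load_kv_chunk<  0>(K_h2, tile_K);
    if (k0_start <=  64 &&  64 < k0_stop) load_kv_chunk< 64>(K_h2, tile_K);
    if (k0_start <= 128 && 128 < k0_stop) load_kv_chunk<128>(K_h2, tile_K);
    // ... repeated for the remaining offsets ...
}
```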

Panchovix commented May 7, 2025

Just an extra comment (sorry for so many!) but this PR is huge for PP performance if you increase ubatch thanks to the saved VRAM from the smaller buffers.

As I posted above, with default ubatch 512, PP is 66 t/s

With ubatch 1024

prompt eval time =   34965.38 ms /  3565 tokens (    9.81 ms per token,   101.96 tokens per second)
       eval time =   45389.59 ms /   416 tokens (  109.11 ms per token,     9.17 tokens per second)

With ubatch 1536

prompt eval time =   28097.73 ms /  3565 tokens (    7.88 ms per token,   126.88 tokens per second)
       eval time =   43426.93 ms /   404 tokens (  107.49 ms per token,     9.30 tokens per second)

This is a 25.7% increase over -ub 1024, a 92.4% increase over -ub 512, and a 225% increase over -ub 512 on PCI-E 4.0 X8.

This wouldn't be possible without this commit, so really, really thanks!

EDIT:

Improved it a little more now.

prompt eval time =   25414.11 ms /  3565 tokens (    7.13 ms per token,   140.28 tokens per second)
      eval time =   38079.82 ms /   344 tokens (  110.70 ms per token,     9.03 tokens per second)

This is by using 2 fewer layers on GPU but increasing -ub to 2048 and -b to 2560. Just impressive.

jukofyork (Collaborator) commented May 7, 2025

This is a very useful PR for my setup where I have all but the non-shared experts in VRAM:

  1. I get around 25% token generation improvement (~4 tokens/s --> ~5 tokens/s). This is around what I got with the old "2D view" method before the MLA PR got merged, and I am still getting 4.85 tokens/s generation after reading in ~55k tokens of context!
  2. This allows me to up the ubatch size massively (using 4096 atm), which means I no longer have to set const int min_batch_size = 999999999 in ggml/src/ggml-cuda/ggml-cuda.cu to work around the very slow transfer from PCI-e 3 x16 to VRAM. Using ubatch = 4096 I've gone from ~25 tokens/s to ~50 tokens/s prompt processing speed for a ~55k-token test!

It would be very "MLA-specific", but it's worth noting that the upper 512 elements of each K is the same as V (the first 64 elements hold the RoPE-ed values, and it must be the first 64 to work with the existing context shifting code):

  • I don't really know anything about CUDA, but from what you said in the old MLA thread about CUDA registers and their relation to CPU cache-thrashing, I wonder if you could get a significant boost by using these instead of having to access the copy placed in V?
  • If it did work, then we could pass through a nullptr for V and this would reduce the KV-cache size by around 47% (1-576/(576+512)).
  • Alternatively, we could pass a strided view of K in place of V? This would require less/no changes to the KV-cache code?

jukofyork (Collaborator)

@Panchovix Thanks for your post on Reddit! It was only after reading it that I also tried the increased-ubatch idea and got a massive increase too!

JohannesGaessler (Collaborator, PR author)

It would be very "MLA-specific", but it's worth noting that the upper 512 elements of each K is the same as V (the first 64 elements hold the RoPE-ed values, and it must be the first 64 to work with the existing context shifting code)

Thank you, I misremembered how the v_mla matrix is used. I thought it was applied prior to writing the values to the KV cache but it is in fact being applied after FlashAttention. In that case the kernel could already take advantage of this without deduplicating the KV cache. For each iteration the kernel is currently loading 576 K values and 512 V values which are explicitly stored in SRAM. You could then just re-use the K data for V and skip the second load. The problem is that on consumer GPUs I'm already at the limit of the SRAM per streaming multiprocessor and I have to load K and V in batches in order to make the kernel fit. So as it is, you would still need to load half of the V data and you'd save only ~25% of the I/O. So this optimization would either be limited to A100/H100/B100 which all have way more SRAM or I would have to drop the number of Q columns that each CUDA block works on - but that would be counterproductive because it would overall lead to 50% more total I/O. I'll try to think of a way to properly take advantage of this; optimizations to deduplicate the KV cache can be done in parallel since you could just pass a pointer into the K cache as V and the code should still work correctly.
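
As a rough illustration of the "pass a pointer into the K cache as V" idea (the tensor names, shapes, and surrounding code are hypothetical, not the actual llama.cpp cache layout):

```cpp
#include "ggml.h"

// Sketch: given a hypothetical MLA K cache with 576 elements per token
// (the first 64 hold the RoPE-ed part, the remaining 512 are identical to V),
// build a strided view into K that could be passed in place of a separate V
// cache, leaving the K data deduplicated.
static struct ggml_tensor * build_v_view_from_k(struct ggml_context * ctx,
                                                struct ggml_tensor  * k_cache,
                                                int64_t               n_kv) {
    return ggml_view_2d(ctx, k_cache,
            512, n_kv,                        // 512 V elements per cached token
            k_cache->nb[1],                   // keep K's 576-element row stride
            64*ggml_element_size(k_cache));   // skip the leading 64 RoPE-ed elements
}
```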

slaren (Member) commented May 8, 2025

@JohannesGaessler Is the drop of performance with deepseek lite at low contexts expected?

JohannesGaessler (Collaborator, PR author)

It's expected in the sense that I had to make a lot of tradeoffs to make the kernel work at all. Compared to e.g. LLaMA with a head size of 128, the FA kernel for Deepseek just isn't as performant, so it only becomes faster once the non-FA implementation becomes slow enough. There's probably still optimization headroom but I don't know how much until I try.

JohannesGaessler (Collaborator, PR author)

Actually, there is a reason why the FA kernel performs poorly for Deepseek in particular at short contexts: I forgot that with the MLA implementation Deepseek effectively has only a single KV head. The granularity with which KV slices are assigned to an SM is 256 tokens, so there just aren't enough tasks to get good GPU utilization on e.g. an RTX 3090 Ti with 84 SMs.
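
To put rough numbers on this: with a single KV head and a 256-token assignment granularity, a 2048-token context yields only 2048 / 256 = 8 KV slices to distribute, and even a 16384-token context yields only 64, so at short contexts most of the 84 SMs on an RTX 3090 Ti have nothing to work on.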

Comment on lines +6 to +14
static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) {
#ifdef CP_ASYNC_AVAILABLE
return __cvta_generic_to_shared(generic_ptr);
#else
GGML_UNUSED(generic_ptr);
NO_DEVICE_CODE;
return 0;
#endif // CP_ASYNC_AVAILABLE
}
Member

Since there is no fallback, why not avoid compiling the kernels that need this intrinsic in the first place?

Collaborator Author

In terms of development it's more convenient for me if potential breakage is encapsulated in an API such as this. That way, if I need to do a git bisect of my WIP commits later on there is less risk of having to deal with code that doesn't compile on specific hardware.
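
For context, the helper above converts a generic pointer into the 32-bit shared-memory address that the cp.async PTX instructions expect. A minimal usage sketch (hypothetical, not the exact ggml code; assumes compute capability 8.0+ and a 16-byte-aligned, 16-byte copy):

```cpp
// Asynchronously copy 16 bytes from global to shared memory via cp.async.
static __device__ __forceinline__ void cp_async_16B(void * dst_shared, const void * src_global) {
    const unsigned int dst = ggml_cuda_cvta_generic_to_shared(dst_shared);
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
                 : : "r"(dst), "l"(src_global));
}
```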

Comment on lines 109 to 110
// The compiler is unable to unroll loops with the k0_start == k0_stop condition.
// Therefore, write functions for the loop iterations and unroll the loops manually.
Member

You could avoid some code duplication using the Unroll template from the AMX implementation:

// Forced unrolling

Collaborator Author

Thank you, this is a good solution. For this to work in device code, CUDA needs to be compiled with the --extended-lambda flag. The flag was added with CUDA 8 and should be unproblematic; HIP and MUSA seem to work without modification.

Member

The integral_constant used to pass the loop index and the auto parameters of the lambda are important to ensure that the argument is constexpr; otherwise you are still relying on the compiler to remove the parameter.
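
For reference, a minimal sketch of the forced-unrolling pattern being discussed, modeled loosely on the AMX Unroll helper (details may differ from the actual ggml code):

```cpp
#include <type_traits>

// Forced unrolling: recursively instantiate the functor for indices 0..n-1,
// passing each index as a std::integral_constant so that inside the lambda
// decltype(i)::value is a compile-time constant.
template <int n>
struct Unroll {
    template <typename Func, typename... Args>
    inline void operator()(const Func & f, Args... args) const {
        Unroll<n - 1>{}(f, args...);
        f(std::integral_constant<int, n - 1>{}, args...);
    }
};

template <>
struct Unroll<1> {
    template <typename Func, typename... Args>
    inline void operator()(const Func & f, Args... args) const {
        f(std::integral_constant<int, 0>{}, args...);
    }
};

// Usage sketch: the auto parameter receives the integral_constant, so the
// index can be used where a compile-time constant is required
// (load_chunk is a hypothetical templated helper).
//
//     Unroll<4>{}([&](auto i) {
//         load_chunk<decltype(i)::value>(/* ... */);
//     });
```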

Collaborator Author

In this particular case it seems like the compiler can still do the correct optimizations - I'll include a fix the next time I make a CUDA PR.

Panchovix

Again wanted to mention how awesome this PR is.

It let me load DeepSeek V3 0324 Q3_K_XL (3.53 BPW) with 64K context on a consumer PC with 192GB RAM + 128GB VRAM. The first generation is slow, but from then on subsequent messages work fine.

prompt eval time =   52240.35 ms /  2625 tokens (   19.90 ms per token,    50.25 tokens per second)
       eval time =   48155.50 ms /   365 tokens (  131.93 ms per token,     7.58 tokens per second)

If tests are needed for a merge I can test!

@JohannesGaessler JohannesGaessler merged commit 0cf6725 into ggml-org:master May 9, 2025
46 checks passed
JohannesGaessler (Collaborator, PR author)

I'm merging the PR as-is; any potential performance optimizations I'll do in a follow-up PR.

gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request May 9, 2025
* origin/master: (39 commits)
server : vision support via libmtmd (ggml-org#12898)
sycl : implementation of reordered Q4_0 MMVQ for Intel GPUs (ggml-org#12858)
metal : optimize MoE for large batches (ggml-org#13388)
CUDA: FA support for Deepseek (Ampere or newer) (ggml-org#13306)
llama : do not crash if there is no CPU backend (ggml-org#13395)
CUDA: fix crash on large batch size for MoE models (ggml-org#13384)
imatrix : Add --parse-special for enabling parsing of special tokens in imatrix calculation (ggml-org#13389)
llama-run: add support for downloading models from ModelScope (ggml-org#13370)
mtmd : fix batch_view for m-rope (ggml-org#13397)
llama : one-off chat template fix for Mistral-Small-2503 (ggml-org#13398)
rpc : add rpc_msg_set_tensor_hash_req (ggml-org#13353)
vulkan: Allow up to 4096 elements for mul_mat_id row_ids (ggml-org#13326)
server : (webui) rename has_multimodal --> modalities (ggml-org#13393)
ci : limit write permission to only the release step + fixes (ggml-org#13392)
mtmd : Expose helper_decode_image_chunk (ggml-org#13366)
server : (webui) fix a very small misalignment (ggml-org#13387)
server : (webui) revamp the input area, plus many small UI improvements (ggml-org#13365)
convert : support rope_scaling type and rope_type (ggml-org#13349)
mtmd : fix the calculation of n_tokens for smolvlm (ggml-org#13381)
context : allow cache-less context for embeddings (ggml-org#13108)
...
Dampfinchen commented May 9, 2025

Sadly, with this PR it seems Flash Attention has been broken on Turing: the model outputs gibberish (tested with Qwen 3 MoE) with partial offloading, and probably with full offloading as well (can't test that with this model).

Output b5299 with ./llama-cli -m "Qwen3-30B-A3B-UD-Q4_K_XL.gguf" -fa -ngl 10 -p "Hello /no_think"

Hello! How can I assist you today? 😊

Output b5331 with -fa and -ngl 10

\\/\\/\\/ ‫\\/cerrCheckedChangeListener xen xen

Output b5331 without fa

Hello! How can I assist you today? 😊

JohannesGaessler (Collaborator, PR author)

Should be fixed by #13415.

Labels
ggml (changes relating to the ggml tensor library for machine learning), Nvidia GPU (Issues specific to Nvidia GPUs), python (python script changes)
Development

Successfully merging this pull request may close these issues.

Eval bug: ggml_cuda_compute_forward: MUL_MAT failed when using FA + MLA on DeepSeekv3 0324, on mixed CPU + GPU
6 participants