Eval bug: b5335 breaks flash attention on 4070 #13430

Closed
steampunque opened this issue May 10, 2025 · 20 comments · Fixed by #13438

@steampunque

Name and Version

b5335 server

Operating systems

Linux

GGML backends

CUDA

Hardware

4070

Models

any (tested with Qwen3 8B)

Problem description & steps to reproduce

Gibberish is generated when FA is turned on.

The problem goes away after making the following change in the CUDA source file:

fattn-mma-f16.cuh

line 550 at b5335

//constexpr bool use_cp_async = nstages == 1;
constexpr bool use_cp_async = 0;
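
For context on what this flag does: use_cp_async appears to select, at compile time, between the asynchronous cp.async load path and a plain synchronous copy into shared memory. A minimal sketch of that gating pattern, purely illustrative and not the actual fattn-mma-f16.cuh code (the helper name load_tile and its layout are made up):

#include <cuda_pipeline.h>

// Toy illustration of compile-time gating: forcing use_cp_async to false
// makes the kernel compile the synchronous fallback instead of cp.async.
template <bool use_cp_async>
__device__ void load_tile(float * dst_shared, const float * src_global, int i) {
    if constexpr (use_cp_async) {
        // Asynchronous global->shared copy (Ampere+ cp.async path).
        __pipeline_memcpy_async(dst_shared + i, src_global + i, sizeof(float));
    } else {
        // Plain synchronous copy, as the workaround above forces.
        dst_shared[i] = src_global[i];
    }
}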

First Bad Commit

Unknown

Relevant log output

flash attention on:

bash-5.1$ lm Hello
郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦郦^Cbash-5.1$ 

flash attention off:

bash-5.1$ 
bash-5.1$ 
bash-5.1$ 
bash-5.1$ lm Hello
<think>
Okay, the user said "Hello". I need to respond appropriately. Since it's a greeting, I should acknowledge it and offer assistance. Let me keep it friendly and open-ended. Maybe ask how I can help them today. That way, they know I'm here to assist with any questions or tasks they might have. I should make sure the response is welcoming and not too formal. Let me check for any typos or errors. Alright, that should work.
</think>
@JohannesGaessler
Collaborator

Does this patch also fix the issue?

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 64fb4ff4c..fc231e97e 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -216,7 +216,7 @@ typedef float2 dfloat2;
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
-#define CP_ASYNC_AVAILABLE
+// #define CP_ASYNC_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
 #if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
@@ -258,7 +258,7 @@ static bool new_mma_available(const int cc) {
 }
 
 static bool cp_async_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+    return false;
 }
 
 static constexpr __device__ int ggml_cuda_get_physical_warp_size() {

@steampunque
Author

Does this patch also fix the issue?

Looks OK with that patch too (with my patch removed first).

@JohannesGaessler
Collaborator

I was unable to reproduce the issue, but is it fixed by #13438?

@steampunque
Author

I was unable to reproduce the issue, but is it fixed by #13438?

Looks solid now. Most likely my 4070 wasn't getting done in time compared to the CPU thread that was spawning the async copy. Nice find, and thanks again for your fast responses!

@JohannesGaessler
Collaborator

This is not a race condition between the CPU and the GPU; it's a race condition between threads on one of the streaming multiprocessors on the GPU. Whether that race condition actually manifests as a bug is not guaranteed, though (I was using a 4090).
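
To illustrate the kind of intra-block hazard being described, here is a toy kernel (not the llama.cpp FlashAttention code; the kernel name, tile size, and layout are made up): each thread starts an asynchronous global-to-shared copy, and only after every thread has waited on the pipeline and passed a block-wide barrier is it safe to read elements copied by other threads. Dropping either the wait or the barrier turns the final read into a data race confined to one streaming multiprocessor.

#include <cuda_pipeline.h>

// Launch with 256 threads per block; src/dst hold gridDim.x * 256 floats.
__global__ void tile_shuffle(const float * __restrict__ src, float * __restrict__ dst) {
    __shared__ float tile[256];

    // Each thread asynchronously copies one element into shared memory.
    __pipeline_memcpy_async(&tile[threadIdx.x], &src[blockIdx.x * 256 + threadIdx.x], sizeof(float));
    __pipeline_commit();

    __pipeline_wait_prior(0);   // wait until this thread's async copy has landed
    __syncthreads();            // make every thread's copy visible to the whole block

    // Reads an element written by a different thread; without the two lines
    // above, this is a race between threads of the same block.
    dst[blockIdx.x * 256 + threadIdx.x] = tile[255 - threadIdx.x];
}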

@steampunque
Author

OK. It might be driver version + hardware related, hard to say. I guess I just drew the short straw and my hardware + driver combo happened to trigger the bug.

@steampunque
Author

@JohannesGaessler Unfortunately there is still a problem with flash attention on, even after this fix. Generation is fine for a while, then if too many tokens get generated it starts to produce gibberish:

b5342 with FA on, generation:

.
.
Year 10: Elsie is 29 years old

  • Elsie stores 10 more apples, making a total of 27 + 10 = 37 apples.
  • = = = = = = = = %D = visually impaired
  • = = 평소처럼 it rats of 평소처럼 there are = = = = 평소처럼 = fast .
    .
    .

b5342 with FA off, generation (no problems):

.
.
Year 10: Elsie is 29 years old

  • Elsie stores 10 more apples, making a total of 27 + 10 = 37 apples.
  • Rats steal 2 apples, leaving 37 - 2 = 35 apples.
  • Elsie eats 5 apples, leaving 35 - 5 = 30 apples.

Year 11: Elsie is 30 years old
.
.

The model is Llama 4 Scout, with the MoE experts fully offloaded to the CPU and the rest of the tensors, including the MoE shared expert, on the GPU.

b5279 doesn't show this problem, but I do not know which commit caused the FA gibberish problem to suddenly start.
The issue should probably be reopened.

@JohannesGaessler
Collaborator

Please tell me the exact model and command that you're using, then.

@steampunque
Author

steampunque commented May 11, 2025

Please tell me the exact model and command that you're using, then.

It fails the same way with the Q3_K_H model that you downloaded (after a certain number of tokens it starts generating garbage, always at the same point).

This prompt can be used:

Count from 1 to 200. Spell out numbers on each new line.

With FA on, garbage after 159:

One Hundred Fifty-Six
One Hundred Fifty-Seven
One Hundred Fiftyight
One Hundred Fifty-Nine
One Hundred Fifty Hundred Tally goes on
One One Hundred Fifty Hundred Ninety Fifty
One hundred eighty
Ninety fifty
Nin692/80
One Hundred Thirty-Five One Hundred Ninلسط 141
� Two
One Hundred Two
0

With FA off, it will go to 200 correctly.

The problem I am having is that it only does this when I am speculating it with Llama 3.2 1B, and upstream llama.cpp will not speculate this model since it doesn't support translating between the vocabularies of the speculator and the target, which my downstream does. The unique thing happening during speculation is that during generation the target is not evaluating with batch size 1; it is evaluating with batch sizes of 4 to 5. So something in the short batch sizes seems to have been broken by recent changes. It works fine with FA off, or with b5279. I don't know the commit that broke it, but I suspect the same one that caused all the async bugs to start happening.

I will see if I can find a way to trigger the bug with upstream. If there is a way to force it to do decodes with short batch sizes it might help expose the problem.

UPDATE:

Problem does not appear to be related to async copy. I applied the patch above to globally turn off async copy on release b5347 and the identical problem happens. The identical failure in the identical place also happens with the KV cache set to F16 or q8_0. So something else has changed since b5279 which is causing short batch size decodes to fail consistently after a certain number of tokens get evaluated.

UPDATE 2:

Problem does not happen with Qwen3-30B-A3B speculated with Qwen3-0.6B with a similar config (experts offloaded to CPU, rest of tensors on GPU). This model does not have shared experts like Llama 4 Scout.

@JohannesGaessler
Collaborator

I ran up to 1000 questions from MMLU, MMLU-Pro, GPQA, and GSM8K on a variety of models, once on a commit prior to my changes to MMQ and FlashAttention and once after those changes.

Before: 45150/103440 correct answers.
After: 44535/103440 correct answers.

I included the LLaMA 4 quant you reported issues with; all other models are at FP16 precision.

Before, LLaMA 4 q3_K_H only: 3729/6688
After, LLaMA 4 q3_K_H only: 3601/6688
Before, FP16 only: 41421/96544
After, FP16 only: 40934/96544

This does look like there could be a problem.

@steampunque
Author

steampunque commented May 11, 2025

I ran up to 1000 questions from MMLU, MMLU-Pro, GPQA, and GSM8K on a variety of models, once on a commit prior to my changes to MMQ and FlashAttention and once after those changes.

Before: 45150/103440 correct answers.
After: 44535/103440 correct answers.

I included the LLaMA 4 quant you reported issues with; all other models are at FP16 precision.

Before, LLaMA 4 q3_K_H only: 3729/6688
After, LLaMA 4 q3_K_H only: 3601/6688
Before, FP16 only: 41421/96544
After, FP16 only: 40934/96544

This does look like there could be a problem.

I am thinking there is some interaction between the SWA mask and FA on Llama 4 Scout. Qwen3 30B MoE is rock solid with an identical CPU/GPU offload config, while Llama 4 is crapping out after a number of tokens go by, which points to something going on with SWA + FA in the newer code (it recently started to give gibberish after a certain number of tokens; b5279 never gave the gibberish, but I think it was still running degraded, similar to your findings).

@JohannesGaessler
Collaborator

I noticed that the scores for MMLU-Pro in particular got worse. My suspicion was that models would produce comparatively longer answers so that the bug only appears if the context is sufficiently long. On my RTX 3090 I get bad perplexity results with

export model_name=phi_4_mini_instruct-4b && export quantization=f16
./ppl --file wikitext-2-raw/wiki.test.raw -ngl 99 --model models/opt/${model_name}-${quantization}.gguf --chunks 1 -ub 32 -c 8192 -fa

The PR that introduced the problem is #13438. The problem is that the __syncthreads instruction I added is not always executed by all threads in a CUDA block, so the points at which the threads synchronize with each other can become misaligned. I'll make a fix tomorrow; right now I'm too tired.

@steampunque
Author

steampunque commented May 11, 2025

I noticed that the scores for MMLU-Pro in particular got worse. My suspicion was that models would produce comparatively longer answers

Yeah, that is what I found with the latest code. It was doing OK until the gens got too long, then it went into complete meltdown with gibberish output, but I hadn't seen that complete meltdown until recently. Maybe it was degrading with the longer gens, though. Soft degradation is much harder to detect given the statistical nature of the underlying models.

so that the bug only appears if the context is sufficiently long. On my RTX 3090 I get bad perplexity results with

export model_name=phi_4_mini_instruct-4b && export quantization=f16
./ppl --file wikitext-2-raw/wiki.test.raw -ngl 99 --model models/opt/${model_name}-${quantization}.gguf --chunks 1 -ub 32 -c 8192 -fa

The PR that introduced the problem is #13438. The problem is that the __syncthreads instruction I added is not always executed by all threads in a CUDA block, so the points at which the threads synchronize with each other can become misaligned. I'll make a fix tomorrow; right now I'm too tired.

Sounds great, I look forward to testing it tomorrow. It's strange that the hard fail point in my tests with the latest code was always at a precise number of tokens: 740 good generated tokens, then gibberish. With a prompt length of 26, the KV cache would be at 766 total tokens at the fail point in my test.

@JohannesGaessler
Collaborator

The problem is that a call to __syncthreads is uncoalesced, so some of the threads get out of sync with the rest. But this happens at the end of a KV chunk that a CUDA block is working on. If the context is sufficiently short, then the CUDA blocks don't do any additional work after the first chunk and the misaligned synchronizations don't matter.
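
A toy kernel that isolates this failure mode (this is not the actual FlashAttention kernel; the names, the 128-thread block, and the loop structure are made up for illustration). The barrier inside the branch is only reached by part of the block, which is undefined behavior: from that point on the two halves of the block can wait on different barriers, so one half can run ahead into the next chunk while the other half is still reading the current one. With a single chunk there is no next iteration to race against, which matches the observation that short contexts are unaffected.

// Launch with a single block of 128 threads; in/out hold n_chunks * 128 ints.
__global__ void chunked_reverse_bad(const int * __restrict__ in, int * __restrict__ out, int n_chunks) {
    __shared__ int buf[128];

    for (int chunk = 0; chunk < n_chunks; ++chunk) {
        buf[threadIdx.x] = in[chunk * 128 + threadIdx.x];

        if (threadIdx.x < 64) {
            __syncthreads();    // BAD: only half of the block executes this barrier
        }

        // Threads that skipped the barrier can reach the end-of-loop barrier
        // and start the next chunk's store while the other half still reads buf.
        out[chunk * 128 + threadIdx.x] = buf[127 - threadIdx.x];

        __syncthreads();        // intended to keep iterations from overlapping
    }

    // Fix: hoist the first __syncthreads() out of the branch so that every
    // thread of the block executes exactly the same sequence of barriers.
}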

@JohannesGaessler
Collaborator

The problem should be fixed by #13469.

@steampunque
Author

steampunque commented May 12, 2025

The problem should be fixed by #13469.

It's looking good now; thanks for all your work on debugging this. I'll rerun the benches and update the earlier issue to point to this one to help keep track of the fixes, as I think this fix was related to that earlier problem, not this later problem, unless I am not understanding correctly.

count.txt=Count from 1 to 200. Spell out numbers on each new line.

(Note: Qwen3 30B needs this to work: count.txt= Count from 1 to 200. Spell out numbers on each new line, as in One\n Two\n ...)

Llama 4 Scout Q3_K_H, speculated with Llama 3.2 1B Instruct using my downstream server, experts offloaded to CPU using -ot exps=CPU and the rest of the tensors on the 4070, FA on, F16 KV, llama.cpp b5353:

bash-5.1$ lm count.txt
One
Two
Three
Four
Five
Six
Seven
Eight
Nine
Ten
Eleven
Twelve
Thirteen
Fourteen
Fifteen
Sixteen
Seventeen
Eighteen
Nineteen
Twenty
Twenty-One
Twenty-Two
Twenty-Three
Twenty-Four
Twenty-Five
Twenty-Six
Twenty-Seven
Twenty-Eight
Twenty-Nine
Thirty
Thirty-One
Thirty-Two
Thirty-Three
Thirty-Four
Thirty-Five
Thirty-Six
Thirty-Seven
Thirty-Eight
Thirty-Nine
Forty
Forty-One
Forty-Two
Forty-Three
Forty-Four
Forty-Five
Forty-Six
Forty-Seven
Forty-Eight
Forty-Nine
Fifty
Fifty-One
Fifty-Two
Fifty-Three
Fifty-Four
Fifty-Five
Fifty-Six
Fifty-Seven
Fifty-Eight
Fifty-Nine
Sixty
Sixty-One
Sixty-Two
Sixty-Three
Sixty-Four
Sixty-Five
Sixty-Six
Sixty-Seven
Sixty-Eight
Sixty-Nine
Seventy
Seventy-One
Seventy-Two
Seventy-Three
Seventy-Four
Seventy-Five
Seventy-Six
Seventy-Seven
Seventy-Eight
Seventy-Nine
Eighty
Eighty-One
Eighty-Two
Eighty-Three
Eighty-Four
Eighty-Five
Eighty-Six
Eighty-Seven
Eighty-Eight
Eighty-Nine
Ninety
Ninety-One
Ninety-Two
Ninety-Three
Ninety-Four
Ninety-Five
Ninety-Six
Ninety-Seven
Ninety-Eight
Ninety-Nine
One Hundred
One Hundred One
One Hundred Two
One Hundred Three
One Hundred Four
One Hundred Five
One Hundred Six
One Hundred Seven
One Hundred Eight
One Hundred Nine
One Hundred Ten
One Hundred Eleven
One Hundred Twelve
One Hundred Thirteen
One Hundred Fourteen
One Hundred Fifteen
One Hundred Sixteen
One Hundred Seventeen
One Hundred Eighteen
One Hundred Nineteen
One Hundred Twenty
One Hundred Twenty-One
One Hundred Twenty-Two
One Hundred Twenty-Three
One Hundred Twenty-Four
One Hundred Twenty-Five
One Hundred Twenty-Six
One Hundred Twenty-Seven
One Hundred Twenty-Eight
One Hundred Twenty-Nine
One Hundred Thirty
One Hundred Thirty-One
One Hundred Thirty-Two
One Hundred Thirty-Three
One Hundred Thirty-Four
One Hundred Thirty-Five
One Hundred Thirty-Six
One Hundred Thirty-Seven
One Hundred Thirty-Eight
One Hundred Thirty-Nine
One Hundred Forty
One Hundred Forty-One
One Hundred Forty-Two
One Hundred Forty-Three
One Hundred Forty-Four
One Hundred Forty-Five
One Hundred Forty-Six
One Hundred Forty-Seven
One Hundred Forty-Eight
One Hundred Forty-Nine
One Hundred Fifty
One Hundred Fifty-One
One Hundred Fifty-Two
One Hundred Fifty-Three
One Hundred Fifty-Four
One Hundred Fifty-Five
One Hundred Fifty-Six
One Hundred Fifty-Seven
One Hundred Fifty-Eight
One Hundred Fifty-Nine
One Hundred Sixty
One Hundred Sixty-One
One Hundred Sixty-Two
One Hundred Sixty-Three
One Hundred Sixty-Four
One Hundred Sixty-Five
One Hundred Sixty-Six
One Hundred Sixty-Seven
One Hundred Sixty-Eight
One Hundred Sixty-Nine
One Hundred Seventy
One Hundred Seventy-One
One Hundred Seventy-Two
One Hundred Seventy-Three
One Hundred Seventy-Four
One Hundred Seventy-Five
One Hundred Seventy-Six
One Hundred Seventy-Seven
One Hundred Seventy-Eight
One Hundred Seventy-Nine
One Hundred Eighty
One Hundred Eighty-One
One Hundred Eighty-Two
One Hundred Eighty-Three
One Hundred Eighty-Four
One Hundred Eighty-Five
One Hundred Eighty-Six
One Hundred Eighty-Seven
One Hundred Eighty-Eight
One Hundred Eighty-Nine
One Hundred Ninety
One Hundred Ninety-One
One Hundred Ninety-Two
One Hundred Ninety-Three
One Hundred Ninety-Four
One Hundred Ninety-Five
One Hundred Ninety-Six
One Hundred Ninety-Seven
One Hundred Ninety-Eight
One Hundred Ninety-Nine
Two Hundred

@JohannesGaessler
Collaborator

I'll rerun the benches and update the earlier issue to point to this one to help keep track of the fixes, as I think this fix was related to that earlier problem, not this later problem, unless I am not understanding correctly.

This was about FlashAttention, not quantized MoE models.

@steampunque
Author

I'll rerun the benches and update the earlier issue to point to this one to help keep track of the fixes, as I think this fix was related to that earlier problem, not this later problem, unless I am not understanding correctly.

This was about FlashAttention, not quantized MoE models.

Right, this issue should be linked to your latest PR, if I am understanding correctly?

#13287

@JohannesGaessler
Collaborator

The issue that I fixed was caused by me making changes to the FlashAttention code and it could in principle affect all models. Other possible issues could be due to me making changes to the matrix multiplication code ("MMQ") which should affect quantized models or quantized MoE models in particular. Generally speaking I can never guarantee that there are no bugs, only that I fixed all of the bugs that I found so far.

@steampunque
Author

The issue that I fixed was caused by me making changes to the FlashAttention code and it could in principle affect all models. Other possible issues could be due to me making changes to the matrix multiplication code ("MMQ") which should affect quantized models or quantized MoE models in particular. Generally speaking I can never guarantee that there are no bugs, only that I fixed all of the bugs that I found so far.

OK, this is my understanding of the sequence.

Original issue:

#13287

First fix for the original issue, which left some unresolved question marks in performance:

#13438

Fix for issue discovered in #13438: #13469

(generation looks good, re-checking for performance)
