
Commit 23e8fa5

Add the option for the macro and note (Dao-AILab#893)
1 parent 3e9414f commit 23e8fa5

1 file changed: 8 additions, 1 deletion


csrc/flash_attn/src/softmax.h

@@ -78,7 +78,14 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor
             // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
             // max * log_2(e)) This allows the compiler to use the ffma
             // instruction instead of fadd and fmul separately.
-            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
+            // The following macro will disable the use of fma.
+            // See: https://github.com/pytorch/pytorch/issues/121558 for more details
+            // This macro is set in PyTorch and not FlashAttention
+            #ifdef UNFUSE_FMA
+                tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
+            #else
+                tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
+            #endif
         }
     }
 }
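
Note on the change: rewriting exp(x - max) as exp2(x * log2(e) - max * log2(e)) lets the compiler contract the multiply and subtract into a single ffma instruction, which rounds once instead of twice. Because the fused and unfused results can differ in the last bit, PyTorch's build can define UNFUSE_FMA (see the linked issue) to switch to __fmul_rn, an explicitly rounded multiply that the compiler is not allowed to contract. Below is a minimal, self-contained CUDA sketch of the same pattern; the kernel scale_apply_exp2_sketch and its host driver are illustrative only, not part of FlashAttention, while __fmul_rn and exp2f are standard CUDA device functions.

// Sketch of the fused vs. unfused exp2 path (hypothetical example, not
// FlashAttention code). nvcc contracts a*b - c into ffma by default
// (--fmad=true); build with -DUNFUSE_FMA to take the unfused branch.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void scale_apply_exp2_sketch(float *x, int n, float scale, float max_scaled) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
#ifdef UNFUSE_FMA
    // __fmul_rn is an explicitly rounded multiply that the compiler may not
    // contract into an FMA, so the multiply and subtract round separately.
    x[i] = exp2f(__fmul_rn(x[i], scale) - max_scaled);
#else
    // The multiply-subtract may be contracted into one ffma instruction,
    // which rounds once and can differ from the two-step result in the last bit.
    x[i] = exp2f(x[i] * scale - max_scaled);
#endif
}

int main() {
    const int n = 4;
    float h[n] = {0.f, 1.f, 2.f, 3.f};
    float *d;
    cudaMalloc(&d, n * sizeof(float));
    cudaMemcpy(d, h, n * sizeof(float), cudaMemcpyHostToDevice);
    // scale = log2(e): exp(x - max) becomes exp2(x * log2(e) - max * log2(e)).
    const float kLog2e = 1.4426950408889634f;
    scale_apply_exp2_sketch<<<1, 32>>>(d, n, kLog2e, 3.f * kLog2e);  // max = 3
    cudaMemcpy(h, d, n * sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; ++i) printf("exp(%d - 3) = %f\n", i, h[i]);
    cudaFree(d);
    return 0;
}

Compiling the sketch plain (nvcc sketch.cu) allows the fused ffma path; nvcc -DUNFUSE_FMA sketch.cu forces the separately rounded one, matching the macro's intent in the diff above.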
