
Commit b32efb1

Don't need to reduce row_sum during online softmax
1 parent f45bbb4 commit b32efb1

File tree

1 file changed (+8 / -7)


csrc/flash_attn/src/softmax.h

Lines changed: 8 additions & 7 deletions
@@ -55,10 +55,10 @@ __device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tenso
     reduce_<zero_init>(tensor, max, max_op);
 }
 
-template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
 __device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
     SumOp<float> sum_op;
-    reduce_(tensor, sum, sum_op);
+    thread_reduce_<zero_init>(tensor, sum, sum_op);
 }
 
 // Apply the exp to all the elements.
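With the new zero_init parameter, reduce_sum folds only the calling thread's own fragment into the running sum and performs no cross-thread exchange at all; with zero_init=false, repeated calls accumulate instead of overwriting. A rough sketch of those semantics using plain arrays (thread_reduce_sum_sketch is a hypothetical stand-in, not the CUTE implementation in this file):

// Per-thread row reduction only: fold this thread's fragment into `sum`,
// with no shuffle-based exchange between threads.
template <bool zero_init, int kRows, int kCols>
__device__ void thread_reduce_sum_sketch(const float (&tensor)[kRows][kCols],
                                         float (&sum)[kRows]) {
    #pragma unroll
    for (int mi = 0; mi < kRows; ++mi) {
        if (zero_init) { sum[mi] = 0.f; }  // zero_init=false: keep the running total
        #pragma unroll
        for (int ni = 0; ni < kCols; ++ni) { sum[mi] += tensor[mi][ni]; }
    }
}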
@@ -133,7 +133,7 @@ struct Softmax {
         if (Is_first) {
             flash::template reduce_max</*zero_init=*/true>(scores, row_max);
             flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
-            flash::reduce_sum(scores, row_sum);
+            flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
         } else {
             Tensor scores_max_prev = make_fragment_like(row_max);
             cute::copy(row_max, scores_max_prev);
@@ -152,15 +152,16 @@ struct Softmax {
                 for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
             }
             flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
-            Tensor scores_sum_cur = make_fragment_like(row_sum);
-            flash::reduce_sum(scores, scores_sum_cur);
-            #pragma unroll
-            for (int mi = 0; mi < size(row_sum); ++mi) { row_sum(mi) += scores_sum_cur(mi); }
+            // We don't do the reduce across threads here since we don't need to use the row_sum.
+            // We do that reduce at the end when we need to normalize the softmax.
+            flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
         }
     };
 
     template<bool Is_dropout=false, bool Split=false, typename Tensor0>
     __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
+        SumOp<float> sum_op;
+        quad_allreduce_(row_sum, row_sum, sum_op);
         TensorT lse = make_fragment_like(row_sum);
         Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
         static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
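Why deferring the reduce is safe: addition is associative and commutative, so each thread can fold its share of every key block into a private row_sum (the thread_reduce_ with zero_init=false above), and the threads that share a row only need to exchange partials once, in normalize_softmax_lse, instead of once per block. A minimal, self-contained CUDA sketch of that pattern follows; the names deferred_row_sum and quad_allreduce_sum are hypothetical stand-ins, not the FlashAttention implementation:

// Sketch of the deferred-reduce pattern: four threads share one row; each
// keeps a private partial sum across all iterations, and the cross-thread
// butterfly allreduce runs exactly once at the end.
#include <cstdio>
#include <cuda_runtime.h>

// Butterfly allreduce over 4 lanes, in the spirit of quad_allreduce_.
__device__ float quad_allreduce_sum(float x) {
    #pragma unroll
    for (int offset = 2; offset > 0; offset /= 2) {
        x += __shfl_xor_sync(0xf, x, offset);  // mask 0xf: lanes 0-3 participate
    }
    return x;
}

__global__ void deferred_row_sum(const float *data, int n_iters,
                                 int elems_per_thread, float *out) {
    const int lane = threadIdx.x;  // 4 threads cooperate on one row
    float row_sum = 0.f;           // per-thread partial sum
    for (int it = 0; it < n_iters; ++it) {
        // Thread-local accumulation only (the reduce_sum</*zero_init=*/false> step).
        const float *chunk = data + (it * 4 + lane) * elems_per_thread;
        for (int e = 0; e < elems_per_thread; ++e) { row_sum += chunk[e]; }
        // Deliberately no cross-thread reduce here: the total isn't needed yet.
    }
    // One cross-thread reduce at the end (the normalize_softmax_lse step).
    row_sum = quad_allreduce_sum(row_sum);
    if (lane == 0) { *out = row_sum; }
}

int main() {
    const int n_iters = 3, elems = 2, total = n_iters * 4 * elems;
    float h[total], expected = 0.f;
    for (int i = 0; i < total; ++i) { h[i] = 0.5f * i; expected += h[i]; }

    float *d_in, *d_out;
    cudaMalloc(&d_in, total * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h, total * sizeof(float), cudaMemcpyHostToDevice);

    deferred_row_sum<<<1, 4>>>(d_in, n_iters, elems, d_out);

    float result;
    cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("deferred sum = %g, expected = %g\n", result, expected);

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}

The saving is one shuffle exchange per key block per row; because the partial sums commute, the single allreduce at the end produces the same total as reducing after every block.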
