@@ -55,10 +55,10 @@ __device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tenso
5555 reduce_<zero_init>(tensor, max, max_op);
5656}
5757
58- template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
58+ template <bool zero_init= true , typename Engine0, typename Layout0, typename Engine1, typename Layout1>
5959__device__ __forceinline__ void reduce_sum (Tensor<Engine0, Layout0> const & tensor, Tensor<Engine1, Layout1> &sum){
6060 SumOp<float > sum_op;
61- reduce_ (tensor, sum, sum_op);
61+ thread_reduce_<zero_init> (tensor, sum, sum_op);
6262}
6363
6464// Apply the exp to all the elements.
@@ -133,7 +133,7 @@ struct Softmax {
133133 if (Is_first) {
134134 flash::template reduce_max</* zero_init=*/ true >(scores, row_max);
135135 flash::scale_apply_exp2 (scores, row_max, softmax_scale_log2);
136- flash::reduce_sum (scores, row_sum);
136+ flash::reduce_sum< /* zero_init= */ true > (scores, row_sum);
137137 } else {
138138 Tensor scores_max_prev = make_fragment_like (row_max);
139139 cute::copy (row_max, scores_max_prev);
@@ -152,15 +152,16 @@ struct Softmax {
152152 for (int ni = 0 ; ni < size<1 >(acc_o_rowcol); ++ni) { acc_o_rowcol (mi, ni) *= scores_scale; }
153153 }
154154 flash::scale_apply_exp2 (scores, row_max, softmax_scale_log2);
155- Tensor scores_sum_cur = make_fragment_like (row_sum);
156- flash::reduce_sum (scores, scores_sum_cur);
157- #pragma unroll
158- for (int mi = 0 ; mi < size (row_sum); ++mi) { row_sum (mi) += scores_sum_cur (mi); }
155+ // We don't do the reduce across threads here since we don't need to use the row_sum.
156+ // We do that reduce at the end when we need to normalize the softmax.
157+ flash::reduce_sum</* zero_init=*/ false >(scores, row_sum);
159158 }
160159 };
161160
162161 template <bool Is_dropout=false , bool Split=false , typename Tensor0>
163162 __forceinline__ __device__ TensorT normalize_softmax_lse (Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0 ) {
163+ SumOp<float > sum_op;
164+ quad_allreduce_ (row_sum, row_sum, sum_op);
164165 TensorT lse = make_fragment_like (row_sum);
165166 Tensor acc_o_rowcol = make_tensor (acc_o.data (), flash::convert_layout_acc_rowcol (acc_o.layout ()));
166167 static_assert (decltype (size<0 >(acc_o_rowcol))::value == kNRows );
0 commit comments