@@ -340,11 +340,21 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
340340// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT.
341341static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2 (
342342 const float * HWY_RESTRICT a, size_t size) {
343- float total = 0 .f ;
344- for (size_t i = 0 ; i < size; ++i) {
345- total += a[i] * a[i];
343+ const hn::ScalableTag<float > d;
344+ const size_t N = hn::Lanes (d);
345+ HWY_DASSERT (size >= 2 * N);
346+ HWY_DASSERT (size % (2 * N) == 0 );
347+
348+ auto sum0 = hn::Zero (d);
349+ auto sum1 = hn::Zero (d);
350+ for (size_t i = 0 ; i <= size - 2 * N; i += 2 * N) {
351+ const auto a0 = LoadU (d, a + i);
352+ sum0 = MulAdd (a0, a0, sum0);
353+ const auto a1 = LoadU (d, a + i + N);
354+ sum1 = MulAdd (a1, a1, sum1);
346355 }
347- return total;
356+
357+ return ReduceSum (d, Add (sum0, sum1));
348358}
349359
350360static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm (
@@ -362,12 +372,30 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
362372static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm (
363373 const float * HWY_RESTRICT x, const hwy::bfloat16_t * HWY_RESTRICT weight,
364374 float * HWY_RESTRICT out, size_t size) {
365- constexpr float eps = 1e-6f ;
366- float ss = SquaredL2 (x, size);
367- ss = 1 .0f / sqrtf (ss / StaticCast<float >(size) + eps);
368- for (size_t j = 0 ; j < size; j++) {
369- // Note 1.0f centering here
370- out[j] = (1 .0f + hwy::F32FromBF16 (weight[j])) * (ss * x[j]);
375+ namespace hn = hwy::HWY_NAMESPACE;
376+
377+ constexpr float kEps = 1e-6f ;
378+ constexpr size_t kUnrollSize = 2 ;
379+
380+ const hn::ScalableTag<hwy::bfloat16_t > dbf;
381+ const hn::Repartition<float , decltype (dbf)> df32;
382+ const size_t N32 = hn::Lanes (df32);
383+
384+ const float ss = SquaredL2 (x, size);
385+ const auto vss =
386+ hn::Set (df32, 1 .0f / sqrtf (ss / StaticCast<float >(size) + kEps ));
387+
388+ HWY_DASSERT (size % (kUnrollSize * MaxLanes (df32)) == 0 );
389+ for (size_t i = 0 ; i < size; i += kUnrollSize * N32) {
390+ const hn::Vec<decltype (dbf)> w16 = hn::LoadU (dbf, weight + i);
391+ const auto w0 = hn::PromoteLowerTo (df32, w16);
392+ const auto w1 = hn::PromoteUpperTo (df32, w16);
393+ const auto m0 = hn::Mul (vss, hn::LoadU (df32, x + i));
394+ const auto m1 = hn::Mul (vss, hn::LoadU (df32, x + i + N32));
395+
396+ // (1+weight) * m = m + weight*m = one FMA.
397+ hn::StoreU (hn::MulAdd (m0, w0, m0), df32, out + i);
398+ hn::StoreU (hn::MulAdd (m1, w1, m1), df32, out + i + N32);
371399 }
372400}
373401
0 commit comments