Skip to content

Commit 4aa8d05

Browse files
committed
Merge branch 'dev' into refactor-tidy
2 parents 60d054e + a9aa63f commit 4aa8d05

File tree

7 files changed

+70
-20
lines changed

7 files changed

+70
-20
lines changed

.clang-tidy

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
FormatStyle: file
2+
WarningsAsErrors: "*"
23
Checks: "-*,\
34
abseil-*,\
45
-abseil-string-find-startswith,\
@@ -204,3 +205,6 @@ Checks: "-*,\
204205
-readability-uppercase-literal-suffix,\
205206
-readability-use-anyofallof
206207
"
208+
CheckOptions:
209+
- { key: readability-identifier-naming.ConstexprVariableCase, value: CamelCase }
210+
- { key: readability-identifier-naming.ConstexprVariablePrefix, value: k }

BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ cc_library(
6969
deps = [
7070
":args",
7171
":transformer_ops",
72+
# "//base",
7273
"//compression:compress",
7374
"@hwy//:hwy",
7475
"@hwy//:matvec",
@@ -88,6 +89,7 @@ cc_binary(
8889
":app",
8990
":args",
9091
":gemma_lib",
92+
# "//base",
9193
"//compression:compress",
9294
"@hwy//:hwy",
9395
"@hwy//:nanobenchmark",

compression/blob_store.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ BlobError BlobReader::Open(const char* filename) {
341341
#endif
342342
if (fd_ < 0) return __LINE__;
343343

344-
#if HWY_OS_LINUX
344+
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
345345
// Doubles the readahead window, which seems slightly faster when cached.
346346
(void)posix_fadvise(fd_, 0, 0, POSIX_FADV_SEQUENTIAL);
347347
#endif

gemma.cc

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,16 @@
1919
// which we pass the filename via macro 'argument'.
2020
#undef HWY_TARGET_INCLUDE
2121
#define HWY_TARGET_INCLUDE "gemma.cc" // NOLINT
22-
#include "hwy/foreach_target.h" // IWYU pragma: keep
22+
#include "hwy/foreach_target.h" // IWYU pragma: keep
2323
// Must come after foreach_target.h to avoid redefinition errors.
2424
// copybara:import_next_line:gemma_cpp
2525
#include "compression/compress-inl.h"
2626
// copybara:import_next_line:gemma_cpp
2727
#include "ops.h"
28-
// copybara:import_next_line:gemma_cpp
2928
#include "hwy/contrib/matvec/matvec-inl.h"
3029
#include "hwy/highway.h"
3130
#include "hwy/profiler.h"
3231
#include "hwy/timer.h"
33-
#include "util/args.h" // Path
3432

3533
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
3634
// compile pass, whereas we want this defined in the first.
@@ -766,10 +764,10 @@ GemmaImpl<Config>::GemmaImpl(
766764
std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
767765
hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
768766
hwy::ThreadPool& pool)
769-
: compressed_weights(std::move(compressed_weights)),
767+
: tokenizer(std::move(tokenizer)),
768+
compressed_weights(std::move(compressed_weights)),
770769
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
771-
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
772-
tokenizer(std::move(tokenizer)) {}
770+
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()) {}
773771

774772
template <>
775773
void GemmaImpl<ConfigGemma2B>::Generate(

ops.h

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -340,11 +340,21 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
340340
// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT.
341341
static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
342342
const float* HWY_RESTRICT a, size_t size) {
343-
float total = 0.f;
344-
for (size_t i = 0; i < size; ++i) {
345-
total += a[i] * a[i];
343+
const hn::ScalableTag<float> d;
344+
const size_t N = hn::Lanes(d);
345+
HWY_DASSERT(size >= 2 * N);
346+
HWY_DASSERT(size % (2 * N) == 0);
347+
348+
auto sum0 = hn::Zero(d);
349+
auto sum1 = hn::Zero(d);
350+
for (size_t i = 0; i <= size - 2 * N; i += 2 * N) {
351+
const auto a0 = LoadU(d, a + i);
352+
sum0 = MulAdd(a0, a0, sum0);
353+
const auto a1 = LoadU(d, a + i + N);
354+
sum1 = MulAdd(a1, a1, sum1);
346355
}
347-
return total;
356+
357+
return ReduceSum(d, Add(sum0, sum1));
348358
}
349359

350360
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
@@ -362,12 +372,30 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
362372
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
363373
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
364374
float* HWY_RESTRICT out, size_t size) {
365-
constexpr float eps = 1e-6f;
366-
float ss = SquaredL2(x, size);
367-
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
368-
for (size_t j = 0; j < size; j++) {
369-
// Note 1.0f centering here
370-
out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]);
375+
namespace hn = hwy::HWY_NAMESPACE;
376+
377+
constexpr float kEps = 1e-6f;
378+
constexpr size_t kUnrollSize = 2;
379+
380+
const hn::ScalableTag<hwy::bfloat16_t> dbf;
381+
const hn::Repartition<float, decltype(dbf)> df32;
382+
const size_t N32 = hn::Lanes(df32);
383+
384+
const float ss = SquaredL2(x, size);
385+
const auto vss =
386+
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
387+
388+
HWY_DASSERT(size % (kUnrollSize * MaxLanes(df32)) == 0);
389+
for (size_t i = 0; i < size; i += kUnrollSize * N32) {
390+
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i);
391+
const auto w0 = hn::PromoteLowerTo(df32, w16);
392+
const auto w1 = hn::PromoteUpperTo(df32, w16);
393+
const auto m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
394+
const auto m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));
395+
396+
// (1+weight) * m = m + weight*m = one FMA.
397+
hn::StoreU(hn::MulAdd(m0, w0, m0), df32, out + i);
398+
hn::StoreU(hn::MulAdd(m1, w1, m1), df32, out + i + N32);
371399
}
372400
}
373401

run.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
6666
<< std::thread::hardware_concurrency() << std::endl
6767
<< "Instruction set : "
6868
<< hwy::TargetName(hwy::DispatchedTarget()) << " ("
69-
<< hwy::VectorBytes() * 8 << " bits)"
70-
<< "\n"
69+
<< hwy::VectorBytes() * 8 << " bits)" << "\n"
70+
<< "Compiled config : " << CompiledConfig() << "\n"
7171
<< "Weight Type : "
7272
<< gcpp::TypeName(gcpp::WeightT()) << "\n"
7373
<< "EmbedderInput Type : "
@@ -119,7 +119,7 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
119119
verbosity](int token, float) {
120120
++abs_pos;
121121
++current_pos;
122-
if (current_pos < prompt_size) {
122+
if (current_pos <= prompt_size) {
123123
std::cerr << "." << std::flush;
124124
} else if (token == gcpp::EOS_ID) {
125125
if (!args.multiturn) {

util/app.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,24 @@
3636

3737
namespace gcpp {
3838

39+
static inline const char* CompiledConfig() {
40+
if (HWY_IS_ASAN) {
41+
return "asan";
42+
} else if (HWY_IS_MSAN) {
43+
return "msan";
44+
} else if (HWY_IS_TSAN) {
45+
return "tsan";
46+
#if defined(HWY_IS_UBSAN)
47+
} else if (HWY_IS_UBSAN) {
48+
return "ubsan";
49+
#endif
50+
} else if (HWY_IS_DEBUG_BUILD) {
51+
return "dbg";
52+
} else {
53+
return "opt";
54+
}
55+
}
56+
3957
static inline void PinThreadToCore(size_t cpu_index) {
4058
#if HWY_OS_LINUX
4159
// Forces the thread to run on the logical processor with the same number.

0 commit comments

Comments
 (0)