Skip to content

Commit 0f6a4b4

Browse files
committed
[WIP] quality tweaks - for constants, defer float cast and use double for intermediate computations, add model to EOT token
1 parent 5b9d8a9 commit 0f6a4b4

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

gemma.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
295295
static constexpr size_t kModelDim =
296296
gcpp::Activations<TConfig, kBatchSize>::kModelDim;
297297
static constexpr size_t kHeads = TConfig::kHeads;
298-
const float kQueryScale = 1.0 / sqrtf(static_cast<float>(kQKVDim));
298+
static const float kQueryScale = static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
299299

300300
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
301301
// linear projections to QKV
@@ -418,7 +418,7 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
418418
hwy::ThreadPool& inner_pool) {
419419
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
420420
static constexpr size_t kModelDim = TConfig::kModelDim;
421-
static const float kEmbScaling = sqrtf(static_cast<float>(kModelDim));
421+
static const float kEmbScaling = static_cast<float>(sqrt(static_cast<double>(kModelDim)));
422422

423423
pool.Run(
424424
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
@@ -473,7 +473,7 @@ void Transformer(int token, size_t pos,
473473
static constexpr size_t kLayers = TConfig::kLayers;
474474
static constexpr size_t kModelDim = TConfig::kModelDim;
475475

476-
static const float kEmbScaling = sqrtf(static_cast<float>(kModelDim));
476+
static const float kEmbScaling = static_cast<float>(sqrt(static_cast<double>(kModelDim)));
477477

478478
Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
479479
activations.x.data(), kModelDim);

run.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
186186
if (abs_pos > 0) {
187187
// Prepend "<end_of_turn>" token if this is a multi-turn dialogue
188188
// continuation.
189-
prompt_string = "<end_of_turn>\n" + prompt_string;
189+
prompt_string = "<end_of_turn>model\n" + prompt_string;
190190
}
191191
}
192192

0 commit comments

Comments
 (0)