Skip to content

Commit 684a044

Browse files
theraysmithcopybara-github
authored andcommitted
Reduced parallelism for TransposeQ, making each thread read and write within its own cache lines
PiperOrigin-RevId: 814241032
1 parent 1424466 commit 684a044

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

gemma/flash_attention.cc

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,22 +61,30 @@ static constexpr size_t kNFx8HTileSize = 8;
6161
static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
6262
const size_t qbatch_size, ThreadingContext& ctx) {
6363
static const auto zone = ctx.profiler.AddZone("Gen.Attention.TransposeQ");
64+
// Group floats by the number of floats in a cache line.
65+
const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float);
6466
const size_t num_heads = q.Cols() / q_t.Rows();
6567
const size_t batch_size = q.Rows() / qbatch_size;
6668
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
6769
PROFILER_ZONE3(ctx.profiler, worker, zone);
68-
float* HWY_RESTRICT qt_row = q_t.Row(task);
69-
for (size_t qi = 0; qi < qbatch_size; ++qi)
70-
for (size_t h = 0; h < num_heads; ++h) {
71-
for (size_t b = 0; b < batch_size; ++b) {
72-
qt_row[(qi * num_heads + h) * batch_size + b] =
73-
q.Row(b * qbatch_size + qi)[h * q_t.Rows() + task];
70+
for (size_t lane = 0; lane < kNF; ++lane) {
71+
size_t q_row = task * kNF + lane;
72+
if (q_row >= q_t.Rows()) break;
73+
float* HWY_RESTRICT qt_row = q_t.Row(q_row);
74+
for (size_t qi = 0; qi < qbatch_size; ++qi) {
75+
for (size_t h = 0; h < num_heads; ++h) {
76+
for (size_t b = 0; b < batch_size; ++b) {
77+
qt_row[(qi * num_heads + h) * batch_size + b] =
78+
q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row];
79+
}
7480
}
7581
}
82+
}
7683
};
7784
{
7885
// Better than kFlat.
79-
ParallelFor(ParallelismStrategy::kHierarchical, q_t.Rows(), ctx,
86+
size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF);
87+
ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx,
8088
/*cluster_idx=*/0, func);
8189
}
8290
}

0 commit comments

Comments
 (0)