@@ -61,22 +61,30 @@ static constexpr size_t kNFx8HTileSize = 8;
 static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
                        const size_t qbatch_size, ThreadingContext& ctx) {
   static const auto zone = ctx.profiler.AddZone("Gen.Attention.TransposeQ");
+  // Group floats by the number of floats in a cache line.
+  const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float);
   const size_t num_heads = q.Cols() / q_t.Rows();
   const size_t batch_size = q.Rows() / qbatch_size;
   const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
     PROFILER_ZONE3(ctx.profiler, worker, zone);
-    float* HWY_RESTRICT qt_row = q_t.Row(task);
-    for (size_t qi = 0; qi < qbatch_size; ++qi)
-      for (size_t h = 0; h < num_heads; ++h) {
-        for (size_t b = 0; b < batch_size; ++b) {
-          qt_row[(qi * num_heads + h) * batch_size + b] =
-              q.Row(b * qbatch_size + qi)[h * q_t.Rows() + task];
+    for (size_t lane = 0; lane < kNF; ++lane) {
+      size_t q_row = task * kNF + lane;
+      if (q_row >= q_t.Rows()) break;
+      float* HWY_RESTRICT qt_row = q_t.Row(q_row);
+      for (size_t qi = 0; qi < qbatch_size; ++qi) {
+        for (size_t h = 0; h < num_heads; ++h) {
+          for (size_t b = 0; b < batch_size; ++b) {
+            qt_row[(qi * num_heads + h) * batch_size + b] =
+                q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row];
+          }
         }
       }
+    }
   };
   {
     // Better than kFlat.
-    ParallelFor(ParallelismStrategy::kHierarchical, q_t.Rows(), ctx,
+    size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF);
+    ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx,
                 /*cluster_idx=*/0, func);
   }
 }
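
For illustration, here is a minimal standalone sketch of the same task-coarsening idea outside of the gemma.cpp/Highway code above: instead of one parallel task per output row, each task covers a cache line's worth of floats (kNF rows), and the final task bails out once it runs past the matrix. The 64-byte cache-line size, the local DivCeil helper, the serial task loop standing in for ParallelFor, and the names TransposeBlocked/rows/cols are assumptions for this sketch, not part of the patch.

#include <cstddef>
#include <cstdio>
#include <vector>

// Stand-in for hwy::DivCeil; illustrative only.
static size_t DivCeil(size_t a, size_t b) { return (a + b - 1) / b; }

// Transpose a rows x cols row-major matrix into out (cols x rows),
// handling kNF output rows per task, as the patched TransposeQ does.
static void TransposeBlocked(const std::vector<float>& in, size_t rows,
                             size_t cols, std::vector<float>& out) {
  constexpr size_t kLineBytes = 64;                   // assumed cache-line size
  constexpr size_t kNF = kLineBytes / sizeof(float);  // floats per cache line
  const size_t num_tasks = DivCeil(cols, kNF);
  for (size_t task = 0; task < num_tasks; ++task) {   // ParallelFor in the real code
    for (size_t lane = 0; lane < kNF; ++lane) {
      const size_t out_row = task * kNF + lane;
      if (out_row >= cols) break;                     // partial final task
      for (size_t r = 0; r < rows; ++r) {
        out[out_row * rows + r] = in[r * cols + out_row];
      }
    }
  }
}

int main() {
  const size_t rows = 3, cols = 20;
  std::vector<float> in(rows * cols), out(cols * rows);
  for (size_t i = 0; i < in.size(); ++i) in[i] = static_cast<float>(i);
  TransposeBlocked(in, rows, cols, out);
  // in[1][0] == 20.0f, which lands at out[0][1].
  std::printf("out[0][1] = %.0f\n", out[0 * rows + 1]);
  return 0;
}

Coarsening the task count this way gives each worker a contiguous cache line of output rows, so fewer tasks are scheduled and adjacent workers are less likely to write to the same cache line.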