Skip to content

Commit 1b72c22

Browse files
jan-wassenbergcopybara-github
authored andcommitted
Refactor Gemma ctor and improve pool NUMA support
Gemma receives a MatMulEnv arg, with comment on lifetime Split threading into topology so the latter can be used in allocator Add AllocClasses() for non-POD (ThreadPool) Support binding pool to NUMA node Update threading_test with latency measurements Also update Highway version. PiperOrigin-RevId: 736904748
1 parent 1b1b63d commit 1b72c22

31 files changed

+911
-688
lines changed

BUILD.bazel

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,13 @@ cc_library(
2929
],
3030
)
3131

32+
# Split from :threading to break a circular dependency with :allocator.
3233
cc_library(
33-
name = "threading",
34-
srcs = ["util/threading.cc"],
35-
hdrs = ["util/threading.h"],
34+
name = "topology",
35+
srcs = ["util/topology.cc"],
36+
hdrs = ["util/topology.h"],
3637
deps = [
37-
":basics",
38-
# Placeholder for container detection, do not remove
3938
"@highway//:hwy",
40-
"@highway//:thread_pool",
4139
"@highway//:topology",
4240
],
4341
)
@@ -48,32 +46,54 @@ cc_library(
4846
hdrs = ["util/allocator.h"],
4947
deps = [
5048
":basics",
51-
":threading",
49+
":topology",
5250
"@highway//:hwy",
5351
"@highway//:thread_pool",
52+
"@highway//:topology",
5453
],
5554
)
5655

5756
cc_library(
58-
name = "test_util",
59-
hdrs = ["util/test_util.h"],
57+
name = "threading",
58+
srcs = ["util/threading.cc"],
59+
hdrs = ["util/threading.h"],
6060
deps = [
61+
":allocator",
62+
":basics",
63+
":topology",
64+
# Placeholder for container detection, do not remove
6165
"@highway//:hwy",
62-
"@highway//:hwy_test_util",
63-
"@highway//:stats",
66+
"@highway//:thread_pool",
67+
"@highway//:topology",
6468
],
6569
)
6670

6771
cc_test(
6872
name = "threading_test",
6973
srcs = ["util/threading_test.cc"],
7074
deps = [
75+
":allocator",
76+
":basics",
7177
":threading",
7278
"@googletest//:gtest_main",
79+
"@highway//:auto_tune",
7380
"@highway//:hwy",
7481
"@highway//:hwy_test_util",
7582
"@highway//:nanobenchmark",
83+
"@highway//:robust_statistics",
84+
"@highway//:stats",
7685
"@highway//:thread_pool",
86+
"@highway//:timer",
87+
],
88+
)
89+
90+
cc_library(
91+
name = "test_util",
92+
hdrs = ["util/test_util.h"],
93+
deps = [
94+
"@highway//:hwy",
95+
"@highway//:hwy_test_util",
96+
"@highway//:stats",
7797
],
7898
)
7999

@@ -104,6 +124,7 @@ cc_library(
104124
":allocator",
105125
":basics",
106126
":threading",
127+
":topology",
107128
"//compression:compress",
108129
"@highway//:algo",
109130
"@highway//:bit_set",
@@ -113,7 +134,6 @@ cc_library(
113134
"@highway//:nanobenchmark",
114135
"@highway//:profiler",
115136
"@highway//:thread_pool",
116-
"@highway//:topology",
117137
"@highway//hwy/contrib/sort:vqsort",
118138
],
119139
)
@@ -128,11 +148,11 @@ cc_test(
128148
tags = ["ops_tests"],
129149
deps = [
130150
":allocator",
151+
":app",
131152
":ops",
132153
":test_util",
133154
":threading",
134155
"@googletest//:gtest_main", # buildcleaner: keep
135-
"//:app",
136156
"//compression:compress",
137157
"//compression:test_util",
138158
"@highway//:hwy",
@@ -154,11 +174,12 @@ cc_test(
154174
tags = ["ops_tests"],
155175
deps = [
156176
":allocator",
177+
":app",
157178
":common",
158179
":ops",
159180
":test_util",
181+
":threading",
160182
"@googletest//:gtest_main", # buildcleaner: keep
161-
"//:app",
162183
"//compression:compress",
163184
"@highway//:hwy",
164185
"@highway//:hwy_test_util",
@@ -405,6 +426,7 @@ cc_library(
405426
":cross_entropy",
406427
":gemma_lib",
407428
":kv_cache",
429+
":ops",
408430
":threading",
409431
# Placeholder for internal dep, do not remove.,
410432
"@google_benchmark//:benchmark",
@@ -464,13 +486,13 @@ cc_binary(
464486
":benchmark_helper",
465487
":common",
466488
":gemma_lib",
489+
":ops",
467490
":threading",
468491
# Placeholder for internal dep, do not remove.,
469492
"//compression:sfp",
470493
"//paligemma:image",
471494
"@highway//:hwy",
472495
"@highway//:profiler",
473-
"@highway//:thread_pool",
474496
],
475497
)
476498

@@ -634,13 +656,12 @@ cc_test(
634656
":backprop",
635657
":backprop_scalar",
636658
":common",
637-
":gemma_lib",
638659
":ops",
639660
":prompt",
640661
":sampler",
662+
":threading",
641663
":weights",
642664
"@googletest//:gtest_main",
643-
"//:threading",
644665
"//compression:compress",
645666
"@highway//:hwy",
646667
"@highway//:hwy_test_util",

CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
2222
set(CMAKE_CXX_STANDARD_REQUIRED ON)
2323
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
2424

25-
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG f2209b911c74019e85d0b7a7a2833c9a2e1b7995 EXCLUDE_FROM_ALL)
25+
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06 EXCLUDE_FROM_ALL)
2626
FetchContent_MakeAvailable(highway)
2727

2828
## Note: absl needs to be installed by sentencepiece. This will only happen if
@@ -108,6 +108,8 @@ set(SOURCES
108108
util/test_util.h
109109
util/threading.cc
110110
util/threading.h
111+
util/topology.cc
112+
util/topology.h
111113
)
112114

113115
if(NOT CMAKE_BUILD_TYPE)

MODULE.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
1818
# Require a more recent version.
1919
git_override(
2020
module_name = "highway",
21-
commit = "f2209b911c74019e85d0b7a7a2833c9a2e1b7995",
21+
commit = "c5bebf84ad01edec97e336f5c97ca4e0df6b4d06",
2222
remote = "https://github.com/google/highway",
2323
)
2424

backprop/backward_test.cc

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
#include "ops/ops.h"
3636
#include "util/threading.h"
3737
#include "hwy/base.h"
38-
#include "hwy/contrib/thread_pool/thread_pool.h"
3938

4039
// clang-format off
4140
#undef HWY_TARGET_INCLUDE
@@ -59,9 +58,9 @@ void TestMatMulVJP() {
5958
static const size_t kRows = 8;
6059
static const size_t kCols = 64;
6160
static const size_t kTokens = 5;
62-
gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
63-
BoundedSlice(0, 8));
64-
Allocator::Init(pools.Topology());
61+
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
62+
Allocator::Init(topology);
63+
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
6564
std::mt19937 gen(42);
6665
MatStorageT<float> weights("weights", kRows, kCols);
6766
MatStorageT<float> x("x", kTokens, kCols);
@@ -105,9 +104,9 @@ void TestMultiHeadMatMulVJP() {
105104
static const size_t kCols = 16;
106105
static const size_t kHeads = 4;
107106
static const size_t kTokens = 3;
108-
gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
109-
BoundedSlice(0, 8));
110-
Allocator::Init(pools.Topology());
107+
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
108+
Allocator::Init(topology);
109+
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
111110
std::mt19937 gen(42);
112111
MatStorageT<float> weights("weights", kRows, kCols * kHeads);
113112
MatStorageT<float> x("x", kTokens, kCols * kHeads);
@@ -150,9 +149,9 @@ void TestMultiHeadMatMulVJP() {
150149
void TestRMSNormVJP() {
151150
static const size_t K = 2;
152151
static const size_t N = 64;
153-
gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
154-
BoundedSlice(0, 8));
155-
Allocator::Init(pools.Topology());
152+
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
153+
Allocator::Init(topology);
154+
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
156155
std::mt19937 gen(42);
157156
MatStorageT<float> weights("weights", N, 1);
158157
MatStorageT<float> x("x", K, N);
@@ -216,9 +215,9 @@ static ModelConfig TestConfig() {
216215

217216
void TestEndToEnd() {
218217
std::mt19937 gen(42);
219-
gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
220-
BoundedSlice(0, 1));
221-
Allocator::Init(pools.Topology());
218+
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1));
219+
Allocator::Init(topology);
220+
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
222221
ModelConfig config = TestConfig();
223222
WeightsWrapper<float> weights(config);
224223
WeightsWrapper<float> grad(config);

backprop/optimize_test.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@
4141
namespace gcpp {
4242

4343
TEST(OptimizeTest, GradientDescent) {
44-
NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
45-
BoundedSlice(0, 1));
46-
Allocator::Init(pools.Topology());
44+
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1));
45+
Allocator::Init(topology);
46+
NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
47+
MatMulEnv env(topology, pools);
4748
hwy::ThreadPool& pool = pools.Pool();
4849
std::mt19937 gen(42);
4950

@@ -66,7 +67,7 @@ TEST(OptimizeTest, GradientDescent) {
6667
config.layer_configs[0].qkv_dim,
6768
config.layer_configs[0].post_qk == PostQKType::HalfRope);
6869

69-
Gemma gemma(GemmaTokenizer(), info, pools);
70+
Gemma gemma(GemmaTokenizer(), info, env);
7071

7172
const auto generate = [&](const std::vector<int>& prompt) {
7273
std::vector<int> reply;

compression/blob_compare.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,9 @@ void ReadAndCompareBlobs(const char* path1, const char* path2) {
202202
if (!CompareKeys(reader1, reader2)) return;
203203

204204
// Single allocation, avoid initializing the memory.
205-
NestedPools pools(0);
206-
Allocator::Init(pools.Topology());
205+
BoundedTopology topology;
206+
Allocator::Init(topology);
207+
NestedPools pools(topology);
207208
const size_t total_bytes = TotalBytes(reader1) + TotalBytes(reader2);
208209
BytePtr all_blobs = hwy::AllocateAligned<uint8_t>(total_bytes);
209210
size_t pos = 0;

evals/benchmark_helper.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,9 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
5656

5757
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
5858
const AppArgs& app)
59-
: pools_(CreatePools(app)) {
60-
Allocator::Init(pools_.Topology());
59+
: topology_(CreateTopology(app)),
60+
pools_(CreatePools(topology_, app)),
61+
env_(topology_, pools_) {
6162
InferenceArgs mutable_inference = inference;
6263
AbortIfInvalidArgs(mutable_inference);
6364
LoaderArgs mutable_loader = loader;
@@ -66,7 +67,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
6667
fprintf(stderr, "Skipping model load because: %s\n", err);
6768
} else {
6869
fprintf(stderr, "Loading model...\n");
69-
model_ = AllocateGemma(mutable_loader, pools_);
70+
model_ = AllocateGemma(mutable_loader, env_);
7071
// Only allocate one for starters because GenerateBatch might not be called.
7172
kv_caches_.resize(1);
7273
kv_caches_[0] = KVCache::Create(model_->GetModelConfig(),
@@ -236,7 +237,7 @@ std::string CacheString() {
236237
}
237238

238239
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
239-
NestedPools& pools) {
240+
const BoundedTopology& topology, NestedPools& pools) {
240241
loader.Print(app.verbosity);
241242
inference.Print(app.verbosity);
242243
app.Print(app.verbosity);
@@ -255,7 +256,7 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
255256
"Compiled config : %s\n"
256257
"Weight Type : %s\n"
257258
"EmbedderInput Type : %s\n",
258-
dt, cpu100, pools.TopologyString(), pools.PinString(),
259+
dt, cpu100, topology.TopologyString(), pools.PinString(),
259260
CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
260261
hwy::VectorBytes() * 8, CompiledConfig(),
261262
StringFromType(loader.Info().weight), TypeName<EmbedderInputT>());

evals/benchmark_helper.h

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <vector>
2525

2626
#include "gemma/gemma.h"
27+
#include "ops/matmul.h"
2728
#include "util/app.h"
2829
#include "util/threading.h"
2930
#include "hwy/base.h"
@@ -105,23 +106,20 @@ class GemmaEnv {
105106
KVCache& MutableKVCache() { return kv_caches_[0]; }
106107

107108
private:
108-
// Thread pool for running inference.
109-
NestedPools pools_;
110-
// Random number generator.
111-
std::mt19937 gen_;
112-
// The model to run inference on.
109+
BoundedTopology topology_;
110+
NestedPools pools_; // Thread pool.
111+
MatMulEnv env_;
112+
std::mt19937 gen_; // Random number generator.
113113
std::unique_ptr<Gemma> model_;
114-
// KV caches, same number as query batch.
115-
std::vector<KVCache> kv_caches_;
116-
// Runtime config for inference.
114+
std::vector<KVCache> kv_caches_; // Same number as query batch.
117115
RuntimeConfig runtime_config_;
118116
};
119117

120118
// Logs the inference speed in tokens/sec.
121119
void LogSpeedStats(double time_start, size_t total_tokens);
122120

123121
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
124-
NestedPools& pools);
122+
const BoundedTopology& topology, NestedPools& pools);
125123
void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
126124

127125
} // namespace gcpp

examples/hello_world/BUILD.bazel

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ cc_binary(
1313
# Placeholder for internal dep, do not remove.,
1414
"//:app",
1515
"//:args",
16-
"//:common",
1716
"//:gemma_lib",
1817
"//:threading",
1918
"//:tokenizer",

examples/hello_world/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ project(hello_world)
1717
set(CMAKE_CXX_STANDARD_REQUIRED ON)
1818

1919
include(FetchContent)
20-
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG f2209b911c74019e85d0b7a7a2833c9a2e1b7995)
20+
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06)
2121
FetchContent_MakeAvailable(highway)
2222
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
2323
FetchContent_MakeAvailable(sentencepiece)

0 commit comments

Comments
 (0)