|
18 | 18 | #ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ |
19 | 19 | #define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ |
20 | 20 |
|
21 | | -#include <cstddef> |
| 21 | +#include <stddef.h> |
22 | 22 |
|
23 | 23 | namespace gcpp { |
24 | 24 |
|
25 | 25 | static constexpr size_t kSeqLen = 7168; |
26 | 26 |
|
27 | 27 | struct ConfigGemma7B { |
28 | | - // NOLINTBEGIN(google3-readability-class-member-naming) |
29 | | - static constexpr int seq_len = kSeqLen; |
30 | | - static constexpr int vocab_size = 256128; |
31 | | - static constexpr int n_layers = 28; |
32 | | - static constexpr int dim_model = 3072; |
33 | | - static constexpr int dim_ffw_hidden = 16 * 3072 / 2; // = 24576 |
34 | | - static constexpr int n_heads = 16; |
35 | | - static constexpr int n_kv_heads = 16; // standard MHA, no GQA or MQA |
36 | | - static constexpr int dim_qkv = 256; // query size == key size == value size |
37 | | - static constexpr int top_k = 1; |
38 | | - // NOLINTEND(google3-readability-class-member-naming) |
| 28 | + static constexpr int kSeqLen = gcpp::kSeqLen; |
| 29 | + static constexpr int kVocabSize = 256128; |
| 30 | + static constexpr int kLayers = 28; |
| 31 | + static constexpr int kModelDim = 3072; |
| 32 | + static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576 |
| 33 | + static constexpr int kHeads = 16; |
| 34 | + static constexpr int kKVHeads = 16; // standard MHA, no GQA or MQA |
| 35 | + static constexpr int kQKVDim = 256; // query size == key size == value size |
| 36 | + static constexpr int kTopK = 1; |
39 | 37 | }; |
40 | 38 |
|
41 | 39 | struct ConfigGemma2B { |
42 | | - // NOLINTBEGIN(google3-readability-class-member-naming) |
43 | | - static constexpr int seq_len = kSeqLen; |
44 | | - static constexpr int vocab_size = 256128; |
45 | | - static constexpr int n_layers = 18; |
46 | | - static constexpr int dim_model = 2048; |
47 | | - static constexpr int dim_ffw_hidden = 16 * 2048 / 2; // = 16384 |
48 | | - static constexpr int n_heads = 8; |
49 | | - static constexpr int n_kv_heads = 8; // TODO(austinvhuang): add MQA support |
50 | | - static constexpr int dim_qkv = 256; // query size == key size == value size |
51 | | - static constexpr int top_k = 1; |
52 | | - // NOLINTEND(google3-readability-class-member-naming) |
| 40 | + static constexpr int kSeqLen = gcpp::kSeqLen; |
| 41 | + static constexpr int kVocabSize = 256128; |
| 42 | + static constexpr int kLayers = 18; |
| 43 | + static constexpr int kModelDim = 2048; |
| 44 | + static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 |
| 45 | + static constexpr int kHeads = 8; |
| 46 | + static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support |
| 47 | + static constexpr int kQKVDim = 256; // query size == key size == value size |
| 48 | + static constexpr int kTopK = 1; |
53 | 49 | }; |
54 | 50 |
|
55 | 51 | } // namespace gcpp |
|
0 commit comments