Skip to content

Commit 28c46fb

Browse files
authored
minor tweaks to the transformer example (#3048)
* minor tweaks to the transformer example * actually take advantage of using namespace std;
1 parent 8fdd2a6 commit 28c46fb

File tree

3 files changed

+47
-44
lines changed

3 files changed

+47
-44
lines changed

examples/slm_basic_train_ex.cpp

Lines changed: 40 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@
4949

5050
// ----------------------------------------------------------------------------------------
5151

52+
using namespace std;
53+
using namespace dlib;
54+
5255
// We treat each character as a token ID in [0..255].
5356
const int MAX_TOKEN_ID = 255;
5457
const int PAD_TOKEN = 256; // an extra "pad" token if needed
@@ -66,13 +69,13 @@ std::vector<int> char_based_tokenize(const std::string& text)
6669
}
6770

6871
// Function to shuffle samples and labels in sync
69-
void shuffle_samples_and_labels(std::vector<dlib::matrix<int, 0, 1>>& samples, std::vector<unsigned long>& labels) {
72+
void shuffle_samples_and_labels(std::vector<matrix<int, 0, 1>>& samples, std::vector<unsigned long>& labels) {
7073
std::vector<size_t> indices(samples.size());
7174
std::iota(indices.begin(), indices.end(), 0); // Fill with 0, 1, 2, ..., N-1
7275
std::shuffle(indices.begin(), indices.end(), std::default_random_engine{});
7376

7477
// Create temporary vectors to hold shuffled data
75-
std::vector<dlib::matrix<int, 0, 1>> shuffled_samples(samples.size());
78+
std::vector<matrix<int, 0, 1>> shuffled_samples(samples.size());
7679
std::vector<unsigned long> shuffled_labels(labels.size());
7780

7881
// Apply the shuffle
@@ -93,15 +96,15 @@ int main(int argc, char** argv)
9396
{
9497
try
9598
{
96-
dlib::command_line_parser parser;
99+
command_line_parser parser;
97100
parser.add_option("train", "Train a small transformer on the built-in Shakespeare text");
98101
parser.add_option("generate", "Generate text from a previously trained model (needs shakespeare_prompt)");
99102
parser.add_option("learning-rate", "Set the learning rate for training (default: 1e-4)", 1);
100103
parser.add_option("batch-size", "Set the mini-batch size for training (default: 64)", 1);
101104
parser.add_option("generation-length", "Set the length of generated text (default: 400)", 1);
102-
parser.add_option("alpha", "Set the initial learning rate for Adam optimizer (default: 0.004)", 1);
103-
parser.add_option("beta1", "Set the decay rate for the first moment estimate (default: 0.9)", 1);
104-
parser.add_option("beta2", "Set the decay rate for the second moment estimate (default: 0.999)", 1);
105+
parser.add_option("alpha", "Set the weight decay for Adam optimizer (default: 0.004)", 1);
106+
parser.add_option("beta1", "Set the first moment coefficient (default: 0.9)", 1);
107+
parser.add_option("beta2", "Set the second moment coefficient (default: 0.999)", 1);
105108
parser.add_option("max-samples", "Set the maximum number of training samples (default: 50000)", 1);
106109
parser.add_option("shuffle", "Shuffle training sequences and labels before training (default: false)");
107110
parser.parse(argc, argv);
@@ -122,7 +125,7 @@ int main(int argc, char** argv)
122125
const size_t max_samples = get_option(parser, "max-samples",50000); // Default maximum number of training samples
123126

124127
// We define a minimal config for demonstration
125-
const long vocab_size = 257; // 0..255 for chars + 1 pad token
128+
const long vocab_size = MAX_TOKEN_ID + 1 + 1; // 256 for chars + 1 pad token
126129
const long num_layers = 3;
127130
const long num_heads = 4;
128131
const long embedding_dim = 64;
@@ -136,8 +139,8 @@ int main(int argc, char** argv)
136139
embedding_dim,
137140
max_seq_len,
138141
use_squeezing,
139-
dlib::gelu,
140-
dlib::dropout_10
142+
gelu,
143+
dropout_10
141144
>;
142145

143146
// For GPU usage (if any), set gpus = {0} for a single GPU, etc.
@@ -151,7 +154,7 @@ int main(int argc, char** argv)
151154
// ----------------------------------------------------------------------------------------
152155
if (parser.option("train"))
153156
{
154-
std::cout << "=== TRAIN MODE ===\n";
157+
cout << "=== TRAIN MODE ===\n";
155158

156159
// 1) Prepare training data (simple approach)
157160
// We will store characters from shakespeare_text into a vector
@@ -160,7 +163,7 @@ int main(int argc, char** argv)
160163
auto full_tokens = char_based_tokenize(shakespeare_text);
161164
if (full_tokens.empty())
162165
{
163-
std::cerr << "ERROR: The Shakespeare text is empty. Please provide a valid training text.\n";
166+
cerr << "ERROR: The Shakespeare text is empty. Please provide a valid training text.\n";
164167
return 0;
165168
}
166169

@@ -170,18 +173,18 @@ int main(int argc, char** argv)
170173
: 0;
171174

172175
// Display the size of the training text and the number of sequences
173-
std::cout << "Training text size: " << full_tokens.size() << " characters\n";
174-
std::cout << "Maximum number of sequences: " << max_sequences << "\n";
176+
cout << "Training text size: " << full_tokens.size() << " characters\n";
177+
cout << "Maximum number of sequences: " << max_sequences << "\n";
175178

176179
// Check if the text is too short
177180
if (max_sequences == 0)
178181
{
179-
std::cerr << "ERROR: The Shakespeare text is too short for training. It must contain at least "
182+
cerr << "ERROR: The Shakespeare text is too short for training. It must contain at least "
180183
<< (max_seq_len + 1) << " characters.\n";
181184
return 0;
182185
}
183186

184-
std::vector<dlib::matrix<int, 0, 1>> samples;
187+
std::vector<matrix<int, 0, 1>> samples;
185188
std::vector<unsigned long> labels;
186189

187190
// Let's create a training set of about (N) samples from the text
@@ -190,7 +193,7 @@ int main(int argc, char** argv)
190193
const size_t N = (max_sequences < max_samples) ? max_sequences : max_samples;
191194
for (size_t start = 0; start < N; ++start)
192195
{
193-
dlib::matrix<int, 0, 1> seq(max_seq_len, 1);
196+
matrix<int, 0, 1> seq(max_seq_len, 1);
194197
for (long t = 0; t < max_seq_len; ++t)
195198
seq(t, 0) = full_tokens[start + t];
196199
samples.push_back(seq);
@@ -200,18 +203,18 @@ int main(int argc, char** argv)
200203
// Shuffle samples and labels if the --shuffle option is enabled
201204
if (parser.option("shuffle"))
202205
{
203-
std::cout << "Shuffling training sequences and labels...\n";
206+
cout << "Shuffling training sequences and labels...\n";
204207
shuffle_samples_and_labels(samples, labels);
205208
}
206209

207210
// 3) Construct the network in training mode
208211
using net_type = my_transformer_cfg::network_type<true>;
209212
net_type net;
210-
if (dlib::file_exists(model_file))
211-
dlib::deserialize(model_file) >> net;
213+
if (file_exists(model_file))
214+
deserialize(model_file) >> net;
212215

213216
// 4) Create dnn_trainer
214-
dlib::dnn_trainer<net_type, dlib::adam> trainer(net, dlib::adam(alpha, beta1, beta2), gpus);
217+
dnn_trainer<net_type, adam> trainer(net, adam(alpha, beta1, beta2), gpus);
215218
trainer.set_learning_rate(learning_rate);
216219
trainer.set_min_learning_rate(1e-6);
217220
trainer.set_mini_batch_size(batch_size);
@@ -229,41 +232,41 @@ int main(int argc, char** argv)
229232
if (predicted[i] == labels[i])
230233
correct++;
231234
double accuracy = (double)correct / labels.size();
232-
std::cout << "Training accuracy (on this sample set): " << accuracy << "\n";
235+
cout << "Training accuracy (on this sample set): " << accuracy << "\n";
233236

234237
// 7) Save the model
235238
net.clean();
236-
dlib::serialize(model_file) << net;
237-
std::cout << "Model saved to " << model_file << "\n";
239+
serialize(model_file) << net;
240+
cout << "Model saved to " << model_file << "\n";
238241
}
239242

240243
// ----------------------------------------------------------------------------------------
241244
// Generate mode
242245
// ----------------------------------------------------------------------------------------
243246
if (parser.option("generate"))
244247
{
245-
std::cout << "=== GENERATE MODE ===\n";
248+
cout << "=== GENERATE MODE ===\n";
246249
// 1) Load the trained model
247250
using net_infer = my_transformer_cfg::network_type<false>;
248251
net_infer net;
249-
if (dlib::file_exists(model_file))
252+
if (file_exists(model_file))
250253
{
251-
dlib::deserialize(model_file) >> net;
252-
std::cout << "Loaded model from " << model_file << "\n";
254+
deserialize(model_file) >> net;
255+
cout << "Loaded model from " << model_file << "\n";
253256
}
254257
else
255258
{
256-
std::cerr << "Error: model file not found. Please run --train first.\n";
259+
cerr << "Error: model file not found. Please run --train first.\n";
257260
return 0;
258261
}
259-
std::cout << my_transformer_cfg::model_info::describe() << std::endl;
260-
std::cout << "Model parameters: " << count_parameters(net) << std::endl << std::endl;
262+
cout << my_transformer_cfg::model_info::describe() << endl;
263+
cout << "Model parameters: " << count_parameters(net) << endl << endl;
261264

262265
// 2) Get the prompt from the included slm_data.h
263266
std::string prompt_text = shakespeare_prompt;
264267
if (prompt_text.empty())
265268
{
266-
std::cerr << "No prompt found in slm_data.h.\n";
269+
cerr << "No prompt found in slm_data.h.\n";
267270
return 0;
268271
}
269272
// If prompt is longer than max_seq_len, we keep only the first window
@@ -274,7 +277,7 @@ int main(int argc, char** argv)
274277
const auto prompt_tokens = char_based_tokenize(prompt_text);
275278

276279
// Put into a dlib matrix
277-
dlib::matrix<int, 0, 1> input_seq(max_seq_len, 1);
280+
matrix<int, 0, 1> input_seq(max_seq_len, 1);
278281
// Fill with pad if prompt is shorter than max_seq_len
279282
for (long i = 0; i < max_seq_len; ++i)
280283
{
@@ -284,7 +287,7 @@ int main(int argc, char** argv)
284287
input_seq(i, 0) = PAD_TOKEN;
285288
}
286289

287-
std::cout << "\nInitial prompt:\n" << prompt_text << " (...)\n\n\nGenerated text:\n" << prompt_text;
290+
cout << "\nInitial prompt:\n" << prompt_text << " (...)\n\n\nGenerated text:\n" << prompt_text;
288291

289292
// 3) Generate new text
290293
// We'll predict one character at a time, then shift the window
@@ -293,22 +296,22 @@ int main(int argc, char** argv)
293296
const int next_char = net(input_seq); // single inference
294297

295298
// Print the generated character
296-
std::cout << static_cast<char>(std::min(next_char, MAX_TOKEN_ID)) << std::flush;
299+
cout << static_cast<char>(std::min(next_char, MAX_TOKEN_ID)) << flush;
297300

298301
// Shift left by 1
299302
for (long i = 0; i < max_seq_len - 1; ++i)
300303
input_seq(i, 0) = input_seq(i + 1, 0);
301304
input_seq(max_seq_len - 1, 0) = std::min(next_char, MAX_TOKEN_ID);
302305
}
303306

304-
std::cout << "\n\n(end of generation)\n";
307+
cout << "\n\n(end of generation)\n";
305308
}
306309

307310
return 0;
308311
}
309-
catch (std::exception& e)
312+
catch (exception& e)
310313
{
311-
std::cerr << "Exception thrown: " << e.what() << std::endl;
314+
cerr << "Exception thrown: " << e.what() << endl;
312315
return 1;
313316
}
314317
}

examples/slm_data.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include <algorithm>
77

88
// Utility function to concatenate text parts
9-
std::string concatenateTexts(const std::vector<std::string>& texts) {
9+
inline std::string concatenateTexts(const std::vector<std::string>& texts) {
1010
std::string result;
1111
for (const auto& text : texts) {
1212
result += text;
@@ -590,4 +590,4 @@ And you shall understand from me her mind.
590590
591591
)";
592592

593-
#endif // SlmData_H
593+
#endif // SlmData_H

examples/slm_defs.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -214,11 +214,11 @@ namespace transformer
214214
template<bool is_training>
215215
using network_type = std::conditional_t<is_training,
216216
classification_head<USE_SQUEEZING, activation_func, VOCAB_SIZE, EMBEDDING_DIM,
217-
repeat<NUM_LAYERS, t_transformer_block,
218-
positional_embeddings<VOCAB_SIZE, EMBEDDING_DIM, input<matrix<int, 0, 1>>>>>,
217+
repeat<NUM_LAYERS, t_transformer_block,
218+
positional_embeddings<VOCAB_SIZE, EMBEDDING_DIM, input<matrix<int, 0, 1>>>>>,
219219
classification_head<USE_SQUEEZING, activation_func, VOCAB_SIZE, EMBEDDING_DIM,
220-
repeat<NUM_LAYERS, i_transformer_block,
221-
positional_embeddings<VOCAB_SIZE, EMBEDDING_DIM, input<matrix<int, 0, 1>>>>>
220+
repeat<NUM_LAYERS, i_transformer_block,
221+
positional_embeddings<VOCAB_SIZE, EMBEDDING_DIM, input<matrix<int, 0, 1>>>>>
222222
>;
223223

224224
/**
@@ -283,4 +283,4 @@ namespace transformer
283283
*/
284284
}
285285

286-
#endif // SlmNet_H
286+
#endif // SlmNet_H

0 commit comments

Comments
 (0)