|
6 | 6 | #include <regex> |
7 | 7 | #include <iostream> |
8 | 8 | #include <iterator> |
| 9 | +#include <queue> |
9 | 10 | #include <string> |
10 | 11 | #include <math.h> |
11 | 12 |
|
@@ -294,58 +295,146 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri |
294 | 295 | return tokens; |
295 | 296 | } |
296 | 297 |
|
297 | | -// TODO: Calculate this constant from the vocabulary |
298 | | -#define MAX_TOKEN_LEN 18 |
299 | | -// SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece |
300 | | -std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::string & text, bool bos) { |
301 | | - std::vector<gpt_vocab::id> res; |
302 | | - std::vector<int> score; |
303 | | - std::vector<gpt_vocab::id> prev; |
304 | | - int len = text.length(); |
305 | | - |
306 | | - score.resize(len + 1); |
307 | | - prev.resize(len + 1); |
308 | | - |
309 | | - // Forward pass |
310 | | - for (int i = 0; i < len; i++) { |
311 | | - int max_len = std::min(len - i, MAX_TOKEN_LEN); |
312 | | - for (int sub_len = 1; sub_len <= max_len; sub_len++) { |
313 | | - auto sub = text.substr(i, sub_len); |
314 | | - auto token = vocab.token_to_id.find(sub); |
315 | | - if (token != vocab.token_to_id.end()) { |
316 | | - int token_score = sub.length() * sub.length(); |
317 | | - int local_score = score[i] + token_score; |
318 | | - int next = i + sub_len; |
319 | | - if (score[next] < local_score) { |
320 | | - score[next] = local_score; |
321 | | - prev[next] = (*token).second; |
| 298 | +static size_t utf8_len(char src) { |
| 299 | + const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; |
| 300 | + uint8_t highbits = static_cast<uint8_t>(src) >> 4; |
| 301 | + return lookup[highbits]; |
| 302 | +} |
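The lookup table is indexed by the high nibble of the lead byte: 0x0–0xB (ASCII and UTF-8 continuation bytes) map to length 1, 0xC–0xD start 2-byte sequences, 0xE 3-byte, and 0xF 4-byte. Mapping continuation bytes to 1 means malformed input still advances one byte at a time rather than stalling. A minimal standalone check of that mapping (a sketch for illustration, not part of this change):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Same table as utf8_len above: index by the top 4 bits of the lead byte.
static size_t utf8_len_ref(char src) {
    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
    return lookup[static_cast<uint8_t>(src) >> 4];
}

int main() {
    assert(utf8_len_ref('a') == 1);                     // ASCII, high nibble 0x6
    assert(utf8_len_ref(static_cast<char>(0xC3)) == 2); // lead byte of a 2-byte sequence
    assert(utf8_len_ref(static_cast<char>(0xE2)) == 3); // lead byte of a 3-byte sequence
    assert(utf8_len_ref(static_cast<char>(0xF0)) == 4); // lead byte of a 4-byte sequence
    return 0;
}
```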
| 303 | + |
| 304 | +struct llama_sp_symbol { |
| 305 | + using index = int; |
| 306 | + index prev; |
| 307 | + index next; |
| 308 | + std::string_view text; |
| 309 | +}; |
| 310 | + |
| 311 | +struct llama_sp_bigram { |
| 312 | + struct comparator { |
| 313 | + bool operator()(const llama_sp_bigram & l, const llama_sp_bigram & r) const { |
| 314 | + return (l.score < r.score) || (l.score == r.score && l.left > r.left); |
| 315 | + } |
| 316 | + }; |
| 317 | + using queue_storage = std::vector<llama_sp_bigram>; |
| 318 | + using queue = std::priority_queue<llama_sp_bigram, queue_storage, comparator>; |
| 319 | + llama_sp_symbol::index left; |
| 320 | + llama_sp_symbol::index right; |
| 321 | + float score; |
| 322 | + size_t size; |
| 323 | +}; |
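`std::priority_queue` pops the element the comparator ranks greatest, so `l.score < r.score` makes this a max-heap on score, and the `l.left > r.left` clause breaks score ties in favour of the pair that starts earliest in the string. A small sketch of the resulting pop order (ids and scores are illustrative; assumes the definitions above):

```cpp
llama_sp_bigram::queue q;
q.push({ /*left=*/4, /*right=*/5, /*score=*/-1.0f, /*size=*/2 });
q.push({ /*left=*/0, /*right=*/1, /*score=*/-1.0f, /*size=*/2 });
q.push({ /*left=*/2, /*right=*/3, /*score=*/ 0.5f, /*size=*/2 });

// pop order: (2,3) first (highest score), then (0,1) before (4,5),
// because equal scores fall back to the smaller `left` index.
```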
| 324 | + |
| 325 | +struct llama_tokenizer { |
| 326 | + llama_tokenizer(const gpt_vocab & vocab): vocab_(vocab) {} |
| 327 | + |
| 328 | + void tokenize(std::string_view text, std::vector<gpt_vocab::id> & output) { |
| 329 | + // split string into utf8 chars |
| 330 | + int index = 0; |
| 331 | + while (!text.empty()) { |
| 332 | + llama_sp_symbol sym; |
| 333 | + size_t char_len = std::min(text.size(), utf8_len(text.data()[0])); |
| 334 | + sym.text = std::string_view(text.data(), char_len); |
| 335 | + sym.prev = index - 1; |
| 336 | + text.remove_prefix(char_len); |
| 337 | + sym.next = text.empty() ? -1 : index + 1; |
| 338 | + index++; |
| 339 | + symbols_.emplace_back(std::move(sym)); |
| 340 | + } |
| 341 | + |
| 342 | + // seed the work queue with every adjacent symbol pair that forms a vocabulary token. |
| 343 | + for (size_t i = 1; i < symbols_.size(); ++i) { |
| 344 | + try_add_bigram(i - 1, i); |
| 345 | + } |
| 346 | + |
| 347 | + // keep merging the highest-scoring adjacent pairs for as long as we can. |
| 348 | + while (!work_queue_.empty()) { |
| 349 | + auto bigram = work_queue_.top(); |
| 350 | + work_queue_.pop(); |
| 351 | + |
| 352 | + auto & left_sym = symbols_[bigram.left]; |
| 353 | + auto & right_sym = symbols_[bigram.right]; |
| 354 | + |
| 355 | + // if one of the symbols already got merged, skip it. |
| 356 | + if (left_sym.text.empty() || right_sym.text.empty() || |
| 357 | + left_sym.text.size() + right_sym.text.size() != bigram.size) { |
| 358 | + continue; |
| 359 | + } |
| 360 | + |
| 361 | + // merge the right sym into the left one |
| 362 | + left_sym.text = std::string_view(left_sym.text.data(), left_sym.text.size() + right_sym.text.size()); |
| 363 | + right_sym.text = std::string_view(""); |
| 364 | + |
| 365 | + // remove the right sym from the chain |
| 366 | + left_sym.next = right_sym.next; |
| 367 | + if (right_sym.next >= 0) { |
| 368 | + symbols_[right_sym.next].prev = bigram.left; |
| 369 | + } |
| 370 | + |
| 371 | + // find more substitutions |
| 372 | + try_add_bigram(left_sym.prev, bigram.left); |
| 373 | + try_add_bigram(bigram.left, left_sym.next); |
| 374 | + } |
| 375 | + |
| 376 | + for (int i = 0; i != -1; i = symbols_[i].next) { // walk the surviving symbol chain |
| 377 | + auto & symbol = symbols_[i]; |
| 378 | + auto token = vocab_.token_to_id.find(std::string(symbol.text)); |
| 379 | + |
| 380 | + if (token == vocab_.token_to_id.end()) { |
| 381 | + // output any symbols that did not form tokens as bytes. |
| 382 | + for (size_t j = 0; j < symbol.text.size(); ++j) { |
| 383 | + gpt_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3; |
| 384 | + output.push_back(token_id); |
322 | 385 | } |
| 386 | + } else { |
| 387 | + output.push_back((*token).second); |
323 | 388 | } |
324 | 389 | } |
325 | 390 | } |
326 | 391 |
|
327 | | - // Backward pass |
328 | | - int i = len; |
329 | | - while (i > 0) { |
330 | | - gpt_vocab::id token_id = prev[i]; |
331 | | - if (token_id == 0) { |
332 | | - // TODO: Return error or something more meaningful |
333 | | - printf("failed to tokenize string!\n"); |
334 | | - break; |
| 392 | +private: |
| 393 | + void try_add_bigram(int left, int right) { |
| 394 | + if (left == -1 || right == -1) { |
| 395 | + return; |
| 396 | + } |
| 397 | + |
| 398 | + std::string_view text(symbols_[left].text.data(), symbols_[left].text.size() + symbols_[right].text.size()); |
| 399 | + auto token = vocab_.token_to_id.find(std::string(text)); |
| 400 | + |
| 401 | + if (token == vocab_.token_to_id.end()) { |
| 402 | + return; |
335 | 403 | } |
336 | | - res.push_back(token_id); |
337 | | - auto token = (*vocab.id_to_token.find(token_id)).second; |
338 | | - i -= token.length(); |
| 404 | + |
| 405 | + auto score = vocab_.score.find((*token).second); |
| 406 | + |
| 407 | + if (score == vocab_.score.end()) { |
| 408 | + return; |
| 409 | + } |
| 410 | + |
| 411 | + llama_sp_bigram bigram; |
| 412 | + bigram.left = left; |
| 413 | + bigram.right = right; |
| 414 | + bigram.score = (*score).second; |
| 415 | + bigram.size = text.size(); |
| 416 | + work_queue_.push(bigram); |
339 | 417 | } |
340 | 418 |
|
341 | | - if (bos) { |
342 | | - res.push_back(1); // TODO: replace with vocab.bos |
| 419 | + const gpt_vocab & vocab_; |
| 420 | + std::vector<llama_sp_symbol> symbols_; |
| 421 | + llama_sp_bigram::queue work_queue_; |
| 422 | +}; |
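To see the merge loop end to end, here is a toy sketch (ids and scores are invented; the gpt_vocab member names follow their use in the code above). With "ll" and "llo" in the vocabulary, "hello" collapses from five single-character symbols to three before the final id walk:

```cpp
gpt_vocab toy;
auto add = [&](const std::string & piece, gpt_vocab::id id, float score) {
    toy.token_to_id[piece] = id;
    toy.id_to_token[id]    = piece;
    toy.score[id]          = score;
};
add("h", 10, -1.0f); add("e", 11, -1.0f); add("l", 12, -1.0f); add("o", 13, -1.0f);
add("ll",  14, -0.5f);  // (l,l) is the only vocabulary pair at seed time
add("llo", 15, -0.25f); // after the first merge, (ll,o) enters the queue

std::vector<gpt_vocab::id> ids;
llama_tokenizer(toy).tokenize("hello", ids); // h e l l o -> h e ll o -> h e llo
// ids == {10, 11, 15}
```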
| 423 | + |
| 424 | +std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, std::string_view text, bool bos) { |
| 425 | + llama_tokenizer tokenizer(vocab); |
| 426 | + std::vector<gpt_vocab::id> output; |
| 427 | + |
| 428 | + if (text.empty()) { |
| 429 | + return output; |
343 | 430 | } |
344 | 431 |
|
345 | | - // Pieces are in reverse order so correct that |
346 | | - std::reverse(res.begin(), res.end()); |
| 432 | + if (bos) { |
| 433 | + output.push_back(1); // TODO: replace with vocab.bos |
| 434 | + } |
347 | 435 |
|
348 | | - return res; |
| 436 | + tokenizer.tokenize(text, output); |
| 437 | + return output; |
349 | 438 | } |
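A typical call site, sketched under the assumption that the vocabulary (including the per-token scores this tokenizer relies on) has already been loaded; the path is illustrative:

```cpp
gpt_vocab vocab;
if (!gpt_vocab_init("models/llama/vocab.bin", vocab)) { // illustrative path; assumes
    fprintf(stderr, "failed to load vocab\n");          // the loader also fills vocab.score
    return 1;
}

// bos=true prepends the hardcoded id 1; bytes that match no vocabulary piece
// come out as byte-fallback ids (byte value + 3).
std::vector<gpt_vocab::id> ids = llama_tokenize(vocab, "Hello world", /*bos=*/true);
```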
350 | 439 |
|
351 | 440 | bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) { |
|