@@ -95,6 +95,7 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st
 
             (Dictionary<string, int>? vocab1, Vec<(string, string)> merges) = ReadFile(vocabFile, mergesFile);
             Vocab = vocab1 ?? new Dictionary<string, int>();
+            Cache = new Cache<string, Word>();
 
             VocabReverse = new();
 
@@ -146,23 +147,33 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st
         /// Tokenize a sequence string to a list of tokens.
         /// </summary>
         /// <param name="sequence">The sequence to tokenize.</param>
+        /// <param name="isSpecialToken">Indicate if the token is a special token.</param>
         /// <returns>The list of tokens generated from the sequence tokenization.</returns>
-        public override IReadOnlyList<Token> Tokenize(string sequence)
+        public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialToken = false)
         {
             if (sequence.Length == 0)
             {
                 return EmptyTokensList;
             }
 
-            if (!Dropout.HasValue)
-            {
-                return TokenizeWithCache(sequence);
-            }
+            return TokenizeWithCache(sequence);
+        }
 
-            Word word = MergeWord(sequence);
+        /// <summary>
+        /// Tokenize a split sequence string to a list of Ids and add them to the accumulatedIds list.
+        /// </summary>
+        /// <param name="sequence">The sequence to split.</param>
+        /// <param name="isSpecialToken">Indicate if the token is a special token.</param>
+        /// <param name="accumulatedIds">The list of accumulated tokenized Ids.</param>
+        public override void TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds) => TokenizeToIdsWithCache(sequence, accumulatedIds);
 
-            return WordToTokens(ref word);
-        }
+        /// <summary>
+        /// Get the number of tokens that the input sequence will be encoded to.
+        /// </summary>
+        /// <param name="sequence">The text to tokenize.</param>
+        /// <param name="isSpecialToken">Indicate if the token is special token.</param>
+        /// <returns>The number of tokens that the input sequence will be encoded to.</returns>
+        public override int CountTokens(string sequence, bool isSpecialToken) => TokenizeToIdsWithCache(sequence, null);
 
         /// <summary>
         /// Map the token to tokenized Id.
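Note on the hunk above: all three public entry points now funnel into the cached path. Tokenize materializes Token objects, while TokenizeToIds and CountTokens both delegate to TokenizeToIdsWithCache, the latter passing a null accumulator so only the symbol count comes back. A minimal caller-side sketch follows; the vocab/merges file paths and the sample text are hypothetical, and it assumes only the public surface shown in this diff plus the Microsoft.ML.Tokenizers namespace, not output from the commit itself.

    // Sketch only: file paths and sample text are hypothetical.
    using System;
    using System.Collections.Generic;
    using Microsoft.ML.Tokenizers;

    class BpeUsageSketch
    {
        static void Main()
        {
            Bpe bpe = new Bpe("vocab.json", "merges.txt");

            // Full tokens through the cached merge path.
            IReadOnlyList<Token> tokens = bpe.Tokenize("hello", isSpecialToken: false);

            // Ids only, appended to a caller-owned list without creating Token objects.
            List<int> ids = new();
            bpe.TokenizeToIds("hello", isSpecialToken: false, ids);

            // Count only: same cached Word, null accumulator, just the symbol count.
            int count = bpe.CountTokens("hello", isSpecialToken: false);

            Console.WriteLine($"{tokens.Count} tokens, {ids.Count} ids, count = {count}");
        }
    }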
@@ -195,14 +206,6 @@ public override IReadOnlyList<Token> Tokenize(string sequence)
             return null;
         }
 
-        /// <summary>
-        /// Map the tokenized Id to the token.
-        /// </summary>
-        /// <param name="id">The Id to map to the token.</param>
-        /// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the decoding.</param>
-        /// <returns>The mapped token of the Id.</returns>
-        public override string? IdToString(int id, bool skipSpecialTokens = false) => throw new NotImplementedException();
-
         /// <summary>
         /// Gets the dictionary mapping tokens to Ids.
         /// </summary>
@@ -332,7 +335,7 @@ internal string CharToString(char c)
 
         internal Word MergeWord(string w)
         {
-            Word word = Word.WithCapacity((int)w.Length);
+            Word word = Word.WithCapacity(w.Length);
             (int Id, int Len)? unk = null;
             int i = 0;
 
@@ -344,7 +347,7 @@ internal Word MergeWord(string w)
                 if (Char.IsHighSurrogate(w[i]) && i < w.Length - 1 && Char.IsLowSurrogate(w[i + 1]))
                 {
                     length = 2;
-                    s = w.Substring(i, (int)length);
+                    s = w.Substring(i, length);
                 }
                 else
                 {
@@ -403,7 +406,7 @@ internal Word MergeWord(string w)
                     }
                 }
 
-                i += (int)length;
+                i += length;
             }
 
             if (unk.HasValue)
@@ -415,45 +418,59 @@ internal Word MergeWord(string w)
             return word;
         }
 
-        // internal Word.Enumerator WordToTokens(Word word) => word.GetIterator(VocabReverse);
-        internal List<Token> WordToTokens(ref Word word)
+        internal List<Token> WordToTokens(ref Word word) => word.ToTokens(VocabReverse);
+
+        internal List<Token> TokenizeWithCache(string sequence)
         {
-            List<Token> tokens = new(word.SymbolsCount);
+            Word word;
+            if (Cache is not null)
+            {
+                if (Cache.TryGet(sequence, out word))
+                {
+                    return WordToTokens(ref word);
+                }
 
-            foreach (Token token in word.GetIterator(VocabReverse))
+                word = MergeWord(sequence);
+                Cache.Set(sequence, word);
+            }
+            else
             {
-                tokens.Add(token);
+                word = MergeWord(sequence);
             }
 
-            return tokens;
+            return WordToTokens(ref word);
         }
 
-        internal List<Token> TokenizeWithCache(string sequence)
+        internal int WordToIds(ref Word word, IList<int>? accumulatedIds)
         {
-            if (Cache is not null)
+            if (accumulatedIds is not null)
             {
-                Word? hit = Cache.Get(sequence);
-                if (hit.HasValue)
-                {
-                    Word w = hit.Value;
-                    return WordToTokens(ref w);
-                }
+                word.PopulateIds(accumulatedIds);
             }
 
-            Word word = MergeWord(sequence);
-            List<Token> tokens = WordToTokens(ref word);
+            return word.SymbolsCount;
+        }
+
+        internal int TokenizeToIdsWithCache(string sequence, IList<int>? accumulatedIds)
+        {
+            Word word;
 
             if (Cache is not null)
             {
+                if (Cache.TryGet(sequence, out Word hit))
+                {
+                    return WordToIds(ref hit, accumulatedIds);
+                }
+
+                word = MergeWord(sequence);
                 Cache.Set(sequence, word);
             }
+            else
+            {
+                word = MergeWord(sequence);
+            }
 
-            return tokens;
-        }
-
-        public override bool IsValidChar(char ch)
-        {
-            throw new NotImplementedException();
+            return WordToIds(ref word, accumulatedIds);
         }
 
         internal static readonly List<Token> EmptyTokensList = new();
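For reference, the two internal entry points above share one try-get / merge / set control flow around the Cache field, and WordToIds doubles as the counting path by accepting a null accumulator. The standalone sketch below mirrors that flow with a plain Dictionary standing in for the library's Cache<string, Word> and an int[] standing in for Word; it illustrates the control flow only and is not code from this commit.

    // Illustration of the try-get / merge / set flow used by TokenizeWithCache and
    // TokenizeToIdsWithCache; Dictionary and int[] are stand-ins, not library types.
    using System.Collections.Generic;

    class CachedMergeSketch
    {
        private readonly Dictionary<string, int[]>? _cache = new();

        public int CountOrPopulate(string sequence, IList<int>? accumulatedIds)
        {
            int[]? word;
            if (_cache is not null)
            {
                if (_cache.TryGetValue(sequence, out word))   // cache hit: reuse the merged result
                {
                    return Emit(word, accumulatedIds);
                }

                word = Merge(sequence);                       // stands in for MergeWord(sequence)
                _cache[sequence] = word;                      // stands in for Cache.Set(sequence, word)
            }
            else
            {
                word = Merge(sequence);
            }

            return Emit(word, accumulatedIds);
        }

        // Mirrors WordToIds: populate only when a list is supplied, always return the count.
        private static int Emit(int[] word, IList<int>? accumulatedIds)
        {
            if (accumulatedIds is not null)
            {
                foreach (int id in word)
                {
                    accumulatedIds.Add(id);
                }
            }

            return word.Length;
        }

        // Toy merge: one "id" per character (the real code runs BPE merges).
        private static int[] Merge(string sequence)
        {
            int[] ids = new int[sequence.Length];
            for (int i = 0; i < sequence.Length; i++)
            {
                ids[i] = sequence[i];
            }
            return ids;
        }
    }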