Skip to content

Commit 4635a86

Browse files
authored
Tokenizer's Interfaces Cleanup (#7001)
* Tokenizer's Interfaces Cleanup * Address the feedback * Optimization
1 parent 64523e8 commit 4635a86

File tree

11 files changed

+470
-226
lines changed

11 files changed

+470
-226
lines changed

src/Microsoft.ML.Tokenizers/Model/BPE.cs

+58-41
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st
9595

9696
(Dictionary<string, int>? vocab1, Vec<(string, string)> merges) = ReadFile(vocabFile, mergesFile);
9797
Vocab = vocab1 ?? new Dictionary<string, int>();
98+
Cache = new Cache<string, Word>();
9899

99100
VocabReverse = new();
100101

@@ -146,23 +147,33 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st
146147
/// Tokenize a sequence string to a list of tokens.
147148
/// </summary>
148149
/// <param name="sequence">The sequence to tokenize.</param>
150+
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
149151
/// <returns>The list of tokens generated from the sequence tokenization.</returns>
150-
public override IReadOnlyList<Token> Tokenize(string sequence)
152+
public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialToken = false)
151153
{
152154
if (sequence.Length == 0)
153155
{
154156
return EmptyTokensList;
155157
}
156158

157-
if (!Dropout.HasValue)
158-
{
159-
return TokenizeWithCache(sequence);
160-
}
159+
return TokenizeWithCache(sequence);
160+
}
161161

162-
Word word = MergeWord(sequence);
162+
/// <summary>
163+
/// Tokenize a split sequence string to a list of Ids and add them to the accumulatedIds list.
164+
/// </summary>
165+
/// <param name="sequence">The split sequence to tokenize.</param>
166+
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
167+
/// <param name="accumulatedIds">The list of accumulated tokenized Ids.</param>
168+
public override void TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds) => TokenizeToIdsWithCache(sequence, accumulatedIds);
163169

164-
return WordToTokens(ref word);
165-
}
170+
/// <summary>
171+
/// Get the number of tokens that the input sequence will be encoded to.
172+
/// </summary>
173+
/// <param name="sequence">The text to tokenize.</param>
174+
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
175+
/// <returns>The number of tokens that the input sequence will be encoded to.</returns>
176+
public override int CountTokens(string sequence, bool isSpecialToken) => TokenizeToIdsWithCache(sequence, null);
166177

167178
/// <summary>
168179
/// Map the token to tokenized Id.
@@ -195,14 +206,6 @@ public override IReadOnlyList<Token> Tokenize(string sequence)
195206
return null;
196207
}
197208

198-
/// <summary>
199-
/// Map the tokenized Id to the token.
200-
/// </summary>
201-
/// <param name="id">The Id to map to the token.</param>
202-
/// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the decoding.</param>
203-
/// <returns>The mapped token of the Id.</returns>
204-
public override string? IdToString(int id, bool skipSpecialTokens = false) => throw new NotImplementedException();
205-
206209
/// <summary>
207210
/// Gets the dictionary mapping tokens to Ids.
208211
/// </summary>
@@ -332,7 +335,7 @@ internal string CharToString(char c)
332335

333336
internal Word MergeWord(string w)
334337
{
335-
Word word = Word.WithCapacity((int)w.Length);
338+
Word word = Word.WithCapacity(w.Length);
336339
(int Id, int Len)? unk = null;
337340
int i = 0;
338341

@@ -344,7 +347,7 @@ internal Word MergeWord(string w)
344347
if (Char.IsHighSurrogate(w[i]) && i < w.Length - 1 && Char.IsLowSurrogate(w[i + 1]))
345348
{
346349
length = 2;
347-
s = w.Substring(i, (int)length);
350+
s = w.Substring(i, length);
348351
}
349352
else
350353
{
@@ -403,7 +406,7 @@ internal Word MergeWord(string w)
403406
}
404407
}
405408

406-
i += (int)length;
409+
i += length;
407410
}
408411

409412
if (unk.HasValue)
@@ -415,45 +418,59 @@ internal Word MergeWord(string w)
415418
return word;
416419
}
417420

418-
// internal Word.Enumerator WordToTokens(Word word) => word.GetIterator(VocabReverse);
419-
internal List<Token> WordToTokens(ref Word word)
421+
internal List<Token> WordToTokens(ref Word word) => word.ToTokens(VocabReverse);
422+
423+
internal List<Token> TokenizeWithCache(string sequence)
420424
{
421-
List<Token> tokens = new(word.SymbolsCount);
425+
Word word;
426+
if (Cache is not null)
427+
{
428+
if (Cache.TryGet(sequence, out word))
429+
{
430+
return WordToTokens(ref word);
431+
}
422432

423-
foreach (Token token in word.GetIterator(VocabReverse))
433+
word = MergeWord(sequence);
434+
Cache.Set(sequence, word);
435+
}
436+
else
424437
{
425-
tokens.Add(token);
438+
word = MergeWord(sequence);
426439
}
427440

428-
return tokens;
441+
return WordToTokens(ref word);
429442
}
430443

431-
internal List<Token> TokenizeWithCache(string sequence)
444+
internal int WordToIds(ref Word word, IList<int>? accumulatedIds)
432445
{
433-
if (Cache is not null)
446+
if (accumulatedIds is not null)
434447
{
435-
Word? hit = Cache.Get(sequence);
436-
if (hit.HasValue)
437-
{
438-
Word w = hit.Value;
439-
return WordToTokens(ref w);
440-
}
448+
word.PopulateIds(accumulatedIds);
441449
}
442450

443-
Word word = MergeWord(sequence);
444-
List<Token> tokens = WordToTokens(ref word);
451+
return word.SymbolsCount;
452+
}
453+
454+
internal int TokenizeToIdsWithCache(string sequence, IList<int>? accumulatedIds)
455+
{
456+
Word word;
445457

446458
if (Cache is not null)
447459
{
460+
if (Cache.TryGet(sequence, out Word hit))
461+
{
462+
return WordToIds(ref hit, accumulatedIds);
463+
}
464+
465+
word = MergeWord(sequence);
448466
Cache.Set(sequence, word);
449467
}
468+
else
469+
{
470+
word = MergeWord(sequence);
471+
}
450472

451-
return tokens;
452-
}
453-
454-
public override bool IsValidChar(char ch)
455-
{
456-
throw new NotImplementedException();
473+
return WordToIds(ref word, accumulatedIds);
457474
}
458475

459476
internal static readonly List<Token> EmptyTokensList = new();

src/Microsoft.ML.Tokenizers/Model/Cache.cs

+7-12
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99

1010
namespace Microsoft.ML.Tokenizers
1111
{
12-
internal sealed class Cache<TKey, TValue> where TKey : notnull
12+
internal sealed class Cache<TKey, TValue> where TKey : notnull where TValue : notnull
1313
{
1414
internal Cache() : this(Bpe.DefaultCacheCapacity) { }
1515

1616
internal Cache(int capacity)
1717
{
1818
Capacity = capacity;
19-
Map = new Dictionary<TKey, TValue>((int)Capacity);
19+
Map = new Dictionary<TKey, TValue>(Capacity);
2020
}
2121

2222
private readonly ReaderWriterLockSlim _cacheLock = new ReaderWriterLockSlim();
@@ -25,7 +25,7 @@ internal Cache(int capacity)
2525

2626
internal int Capacity { get; set; }
2727

28-
internal void Fresh() => Map = new Dictionary<TKey, TValue>((int)Capacity);
28+
internal void Fresh() => Map = new Dictionary<TKey, TValue>(Capacity);
2929

3030
internal void Clear()
3131
{
@@ -56,27 +56,22 @@ internal List<TValue> GetValues(IEnumerable<TKey> keys)
5656
return values;
5757
}
5858

59-
internal TValue? Get(TKey key)
59+
internal bool TryGet(TKey key, out TValue value)
6060
{
6161
_cacheLock.EnterReadLock();
6262
try
6363
{
64-
if (Map.TryGetValue(key, out TValue? value))
65-
{
66-
return value;
67-
}
64+
return Map.TryGetValue(key, out value!);
6865
}
6966
finally { _cacheLock.ExitReadLock(); }
70-
71-
return default;
7267
}
7368

74-
internal void SetValues(IEnumerable<(TKey, TValue)> enteries)
69+
internal void SetValues(IEnumerable<(TKey, TValue)> entries)
7570
{
7671
_cacheLock.EnterWriteLock();
7772
try
7873
{
79-
foreach ((TKey, TValue) entry in enteries)
74+
foreach ((TKey, TValue) entry in entries)
8075
{
8176
if (Capacity <= Map.Count)
8277
{

0 commit comments

Comments
 (0)