Skip to content

Fixes NER to correctly expand/shrink the labels #6928

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 6, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fixes from PR comments
  • Loading branch information
michaelgsharp committed Jan 4, 2024
commit 9fd094c5cbb1eda27b5e48c0f46b74613e97444c
8 changes: 0 additions & 8 deletions src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ public sealed class EnglishRoberta : Model
private readonly IReadOnlyDictionary<char, char> _unicodeToByte;
private readonly string[] _charToString;
private readonly Cache<string, IReadOnlyList<Token>> _cache;
private const char StartChar = (char)(' ' + 256);

/// <summary>
/// Construct tokenizer object to use with the English Robert model.
Expand Down Expand Up @@ -593,13 +592,6 @@ public override bool IsValidChar(char ch)
{
return _byteToUnicode.ContainsKey(ch);
}

public override bool IsFirstTokenInWord(string token)
{
if (token == null)
return true;
return token.Length != 0 && token[0] == _startChar;
}
}

/// <summary>
Expand Down
8 changes: 0 additions & 8 deletions src/Microsoft.ML.Tokenizers/Model/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,5 @@ public abstract class Model
/// <param name="ch"></param>
/// <returns></returns>
public abstract bool IsValidChar(char ch);

/// <summary>
/// Returns if the first character of the token is part of the actual word or not.
/// </summary>
/// <param name="token"></param>
/// <returns></returns>
public abstract bool IsFirstTokenInWord(string token);
}

}
4 changes: 3 additions & 1 deletion src/Microsoft.ML.TorchSharp/NasBert/NasBertTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,9 @@ private IList<int> PrepInputTokens(ref ReadOnlyMemory<char> sentence1, ref ReadO
getSentence1(ref sentence1);
if (getSentence2 == default)
{
return InitTokenArray.Concat(tokenizer.EncodeToConverted(sentence1.ToString())).ToList();
List<int> newList = new List<int>(tokenizer.EncodeToConverted(sentence1.ToString()));
newList.Insert(0, 0);
return newList;
}
else
{
Expand Down
10 changes: 7 additions & 3 deletions src/Microsoft.ML.TorchSharp/NasBert/NerTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ namespace Microsoft.ML.TorchSharp.NasBert
///
public class NerTrainer : NasBertTrainer<VBuffer<uint>, TargetType>
{
private const char StartChar = (char)(' ' + 256);

public class NerOptions : NasBertOptions
{
public NerOptions()
Expand Down Expand Up @@ -109,6 +111,8 @@ private protected override TorchSharpBaseTransformer<VBuffer<uint>, TargetType>
return new NerTransformer(host, options as NasBertOptions, model as NasBertModel, labelColumn);
}

internal static bool TokenStartsWithSpace(string token) => token is null || (token.Length != 0 && token[0] == StartChar);

private protected class Trainer : NasBertTrainerBase
{
private const string ModelUrlString = "models/pretrained_NasBert_14M_encoder.tsm";
Expand Down Expand Up @@ -172,7 +176,7 @@ private protected override torch.Tensor PrepareRowTensor(ref VBuffer<uint> targe
var newValues = targetEditor.Values;
for (var i = 0; i < encoding.Tokens.Count; i++)
{
if (Tokenizer.Model.IsFirstTokenInWord(encoding.Tokens[i]))
if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i]))
{
newValues[i] = target.GetItemOrDefault(++targetIndex);
}
Expand Down Expand Up @@ -382,7 +386,7 @@ private void CondenseOutput(ref VBuffer<UInt32> dst, string sentence, Tokenizer
// Figure out actual count of output tokens
for (var i = 0; i < encoding.Tokens.Count; i++)
{
if (tokenizer.Model.IsFirstTokenInWord(encoding.Tokens[i]))
if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i]))
{
targetIndex++;
}
Expand All @@ -396,7 +400,7 @@ private void CondenseOutput(ref VBuffer<UInt32> dst, string sentence, Tokenizer

for (var i = 1; i < encoding.Tokens.Count; i++)
{
if (tokenizer.Model.IsFirstTokenInWord(encoding.Tokens[i]))
if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i]))
{
newValues[targetIndex++] = (uint)prediction[i];
}
Expand Down