
Fixes NER to correctly expand/shrink the labels #6928


Merged
merged 5 commits on Jan 6, 2024
2 changes: 0 additions & 2 deletions src/Microsoft.ML.Tokenizers/Model/Model.cs
@@ -66,7 +66,5 @@ public abstract class Model
/// <param name="ch"></param>
/// <returns></returns>
public abstract bool IsValidChar(char ch);

}

}
14 changes: 9 additions & 5 deletions src/Microsoft.ML.TorchSharp/NasBert/NasBertTrainer.cs
@@ -240,7 +240,7 @@ private protected override torch.Tensor PrepareBatchTensor(ref List<Tensor> inpu
return DataUtils.CollateTokens(inputTensors, Tokenizer.RobertaModel().PadIndex, device: Device);
}

private protected override torch.Tensor PrepareRowTensor()
private protected override torch.Tensor PrepareRowTensor(ref TLabelCol target)
{
ReadOnlyMemory<char> sentence1 = default;
Sentence1Getter(ref sentence1);
@@ -494,7 +494,8 @@ private protected abstract class NasBertMapper : TorchSharpBaseMapper

private static readonly FuncInstanceMethodInfo1<NasBertMapper, DataViewSchema.DetachedColumn, Delegate> _makeLabelAnnotationGetter
= FuncInstanceMethodInfo1<NasBertMapper, DataViewSchema.DetachedColumn, Delegate>.Create(target => target.GetLabelAnnotations<int>);

internal static readonly int[] InitTokenArray = new[] { 0 /* InitToken */ };
internal static readonly int[] SeperatorTokenArray = new[] { 2 /* SeperatorToken */ };

public NasBertMapper(TorchSharpBaseTransformer<TLabelCol, TTargetsCol> parent, DataViewSchema inputSchema) :
base(parent, inputSchema)
@@ -583,13 +584,16 @@ private IList<int> PrepInputTokens(ref ReadOnlyMemory<char> sentence1, ref ReadO
getSentence1(ref sentence1);
if (getSentence2 == default)
{
return new[] { 0 /* InitToken */ }.Concat(tokenizer.EncodeToConverted(sentence1.ToString())).ToList();
List<int> newList = new List<int>(tokenizer.EncodeToConverted(sentence1.ToString()));
// 0 Is the init token and must be at the beginning.
newList.Insert(0, 0);
return newList;
}
else
{
getSentence2(ref sentence2);
return new[] { 0 /* InitToken */ }.Concat(tokenizer.EncodeToConverted(sentence1.ToString()))
.Concat(new[] { 2 /* SeperatorToken */ }).Concat(tokenizer.EncodeToConverted(sentence2.ToString())).ToList();
return InitTokenArray.Concat(tokenizer.EncodeToConverted(sentence1.ToString()))
.Concat(SeperatorTokenArray).Concat(tokenizer.EncodeToConverted(sentence2.ToString())).ToList();
}
}

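Note on the NasBertTrainer changes above: InitTokenArray and SeperatorTokenArray cache the two single-element marker arrays instead of allocating them on every call, and the single-sentence path now builds the list directly and inserts the init token at position 0. The resulting id layout is unchanged: [InitToken 0] + sentence1 ids, or [0] + sentence1 ids + [SeperatorToken 2] + sentence2 ids for a pair. A stand-alone sketch of that layout; EncodeIds below is a made-up stand-in for tokenizer.EncodeToConverted and the ids are invented:

using System;
using System.Collections.Generic;
using System.Linq;

class InputLayoutSketch
{
    private static readonly int[] InitTokenArray = new[] { 0 /* InitToken */ };
    private static readonly int[] SeperatorTokenArray = new[] { 2 /* SeperatorToken */ };

    // Stand-in for tokenizer.EncodeToConverted(sentence): made-up ids, one per word.
    private static IReadOnlyList<int> EncodeIds(string sentence) =>
        sentence.Split(' ').Select(w => 1000 + w.Length).ToList();

    private static List<int> PrepInputTokens(string sentence1, string sentence2 = null)
    {
        if (sentence2 is null)
        {
            // Single sentence: sentence ids with the init token inserted at the front.
            var newList = new List<int>(EncodeIds(sentence1));
            newList.Insert(0, 0);
            return newList;
        }

        // Sentence pair: [InitToken] + sentence1 ids + [SeperatorToken] + sentence2 ids.
        return InitTokenArray.Concat(EncodeIds(sentence1))
            .Concat(SeperatorTokenArray).Concat(EncodeIds(sentence2)).ToList();
    }

    static void Main()
    {
        Console.WriteLine(string.Join(" ", PrepInputTokens("the movie was great")));
        // 0 1003 1005 1003 1005
        Console.WriteLine(string.Join(" ", PrepInputTokens("the movie was great", "I agree")));
        // 0 1003 1005 1003 1005 2 1001 1005
    }
}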
85 changes: 77 additions & 8 deletions src/Microsoft.ML.TorchSharp/NasBert/NerTrainer.cs
@@ -19,6 +19,7 @@
using Microsoft.ML.TorchSharp.NasBert.Models;
using TorchSharp;
using static Microsoft.ML.TorchSharp.NasBert.NasBertTrainer;
using static TorchSharp.torch;

[assembly: LoadableClass(typeof(NerTransformer), null, typeof(SignatureLoadModel),
NerTransformer.UserName, NerTransformer.LoaderSignature)]
@@ -61,6 +62,8 @@ namespace Microsoft.ML.TorchSharp.NasBert
///
public class NerTrainer : NasBertTrainer<VBuffer<uint>, TargetType>
{
private const char StartChar = (char)(' ' + 256);

public class NerOptions : NasBertOptions
{
public NerOptions()
@@ -69,6 +72,7 @@ public NerOptions()
EncoderOutputDim = 384;
EmbeddingDim = 128;
Arches = new int[] { 15, 16, 14, 0, 0, 0, 15, 16, 14, 0, 0, 0, 17, 14, 15, 0, 0, 0, 17, 14, 15, 0, 0, 0 };
TaskType = BertTaskType.NamedEntityRecognition;
}
}
internal NerTrainer(IHostEnvironment env, NerOptions options) : base(env, options)
@@ -93,7 +97,6 @@ internal NerTrainer(IHostEnvironment env,
BatchSize = batchSize,
MaxEpoch = maxEpochs,
ValidationSet = validationSet,
TaskType = BertTaskType.NamedEntityRecognition
})
{
}
@@ -108,9 +111,12 @@ private protected override TorchSharpBaseTransformer<VBuffer<uint>, TargetType>
return new NerTransformer(host, options as NasBertOptions, model as NasBertModel, labelColumn);
}

internal static bool TokenStartsWithSpace(string token) => token is null || (token.Length != 0 && token[0] == StartChar);

private protected class Trainer : NasBertTrainerBase
{
private const string ModelUrlString = "models/pretrained_NasBert_14M_encoder.tsm";
internal static readonly int[] ZeroArray = new int[] { 0 /* InitToken */};
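For context on the constants used here: StartChar (defined further up as (char)(' ' + 256)) is code point 288, i.e. U+0120, the 'Ġ'-style character that RoBERTa/GPT-2 byte-level vocabularies prepend to a token that begins with a space, so TokenStartsWithSpace effectively asks whether a sub-token starts a new word. ZeroArray holds the init token id 0 that is prepended to every row. A minimal stand-alone check (the token strings are illustrative):

using System;

class StartCharSketch
{
    // Same value as NerTrainer.StartChar: ' ' + 256 == 288 == U+0120 ('Ġ').
    private const char StartChar = (char)(' ' + 256);

    // Same test as NerTrainer.TokenStartsWithSpace.
    private static bool StartsWord(string token) =>
        token is null || (token.Length != 0 && token[0] == StartChar);

    static void Main()
    {
        Console.WriteLine((int)StartChar);             // 288
        Console.WriteLine(StartsWord("\u0120lives"));  // True  -> begins a new word
        Console.WriteLine(StartsWord("York"));         // False -> continuation of the previous word
    }
}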

public Trainer(TorchSharpBaseTrainer<VBuffer<uint>, TargetType> parent, IChannel ch, IDataView input) : base(parent, ch, input, ModelUrlString)
{
@@ -155,6 +161,40 @@ private protected override torch.Tensor CreateTargetsTensor(ref List<TargetType>
return torch.tensor(targetArray, device: Device);
}

private protected override torch.Tensor PrepareRowTensor(ref VBuffer<uint> target)
{
ReadOnlyMemory<char> sentenceRom = default;
Sentence1Getter(ref sentenceRom);
var sentence = sentenceRom.ToString();
Tensor t;
var encoding = Tokenizer.Encode(sentence);

if (target.Length != encoding.Tokens.Count)
{
var targetIndex = 0;
var targetEditor = VBufferEditor.Create(ref target, encoding.Tokens.Count);
var newValues = targetEditor.Values;
for (var i = 0; i < encoding.Tokens.Count; i++)
{
if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i]))
{
newValues[i] = target.GetItemOrDefault(++targetIndex);
}
else
{
newValues[i] = target.GetItemOrDefault(targetIndex);
}
}
target = targetEditor.Commit();
}
t = torch.tensor((ZeroArray).Concat(Tokenizer.RobertaModel().IdsToOccurrenceRanks(encoding.Ids)).ToList(), device: Device);

if (t.NumberOfElements > 512)
t = t.slice(0, 0, 512, 1);

return t;
}
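The new PrepareRowTensor above is the "expand" half of the fix: the caller supplies one label per word, but the tokenizer can split a word into several sub-tokens, so the label buffer is resized to the token count and every continuation sub-token repeats its word's label. A stand-alone sketch of that rule using plain arrays in place of VBuffer and the tokenizer (the token strings and label ids below are made up for illustration):

using System;
using System.Collections.Generic;

class LabelExpansionSketch
{
    private const char StartChar = (char)(' ' + 256); // U+0120 marker: token begins a new word

    private static bool TokenStartsWithSpace(string token) =>
        token is null || (token.Length != 0 && token[0] == StartChar);

    // Mirrors the loop in PrepareRowTensor: a token that begins a new word advances to the
    // next word label; a continuation token repeats the current word label.
    private static uint[] ExpandLabels(IReadOnlyList<string> tokens, IReadOnlyList<uint> wordLabels)
    {
        var expanded = new uint[tokens.Count];
        var targetIndex = 0;
        for (var i = 0; i < tokens.Count; i++)
        {
            if (TokenStartsWithSpace(tokens[i]))
                targetIndex++;
            expanded[i] = targetIndex < wordLabels.Count ? wordLabels[targetIndex] : default;
        }
        return expanded;
    }

    static void Main()
    {
        // "John visits NewYork" -> 3 word labels, but "NewYork" tokenizes into two pieces.
        var tokens = new[] { "John", "\u0120visits", "\u0120New", "York" };
        var wordLabels = new uint[] { 2 /* PER */, 1 /* O */, 3 /* LOC */ };
        Console.WriteLine(string.Join(", ", ExpandLabels(tokens, wordLabels)));
        // Output: 2, 1, 3, 3 -- the LOC label is repeated for both pieces of "NewYork".
    }
}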

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private protected override int GetNumCorrect(torch.Tensor predictions, torch.Tensor targets)
{
@@ -334,6 +374,41 @@ private protected override Delegate CreateGetter(DataViewRow input, int iinfo, T

}

private void CondenseOutput(ref VBuffer<UInt32> dst, string sentence, Tokenizer tokenizer, TensorCacher outputCacher)
{
var pre = tokenizer.PreTokenizer.PreTokenize(sentence);
TokenizerResult encoding = tokenizer.Encode(sentence);

var argmax = (outputCacher as BertTensorCacher).Result.argmax(-1);
var prediction = argmax.ToArray<long>();

var targetIndex = 0;
// Figure out actual count of output tokens
for (var i = 0; i < encoding.Tokens.Count; i++)
{
if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i]))
{
targetIndex++;
}
}

var editor = VBufferEditor.Create(ref dst, targetIndex + 1);
var newValues = editor.Values;
targetIndex = 0;

newValues[targetIndex++] = (uint)prediction[0];

for (var i = 1; i < encoding.Tokens.Count; i++)
{
if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i]))
{
newValues[targetIndex++] = (uint)prediction[i];
}
}

dst = editor.Commit();
}
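CondenseOutput above is the matching "shrink" half: predictions come back per sub-token, and the getter keeps the first slot plus one prediction for each token that begins a new word, so the emitted label vector lines up with the words the caller passed in rather than with the tokenizer's sub-tokens. A stand-alone sketch with made-up per-token predictions, using the same word-boundary test as the previous sketch:

using System;
using System.Collections.Generic;

class CondensePredictionsSketch
{
    private const char StartChar = (char)(' ' + 256);

    private static bool TokenStartsWithSpace(string token) =>
        token is null || (token.Length != 0 && token[0] == StartChar);

    // Mirrors the shape of CondenseOutput: slot 0 is copied through, then every token that
    // begins a new word contributes one more label; continuation tokens are dropped.
    private static uint[] Condense(IReadOnlyList<string> tokens, IReadOnlyList<long> predictions)
    {
        var wordStarts = 0;
        for (var i = 0; i < tokens.Count; i++)
        {
            if (TokenStartsWithSpace(tokens[i]))
                wordStarts++;
        }

        var condensed = new uint[wordStarts + 1];
        var slot = 0;
        condensed[slot++] = (uint)predictions[0];
        for (var i = 1; i < tokens.Count; i++)
        {
            if (TokenStartsWithSpace(tokens[i]))
                condensed[slot++] = (uint)predictions[i];
        }
        return condensed;
    }

    static void Main()
    {
        var tokens = new[] { "John", "\u0120visits", "\u0120New", "York" };
        var perTokenPredictions = new long[] { 2, 1, 3, 3 };   // made-up per-token model output
        Console.WriteLine(string.Join(", ", Condense(tokens, perTokenPredictions)));
        // Output: 2, 1, 3 -- back to one label per word ("John", "visits", "NewYork").
    }
}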

private Delegate MakePredictedLabelGetter(DataViewRow input, IChannel ch, TensorCacher outputCacher)
{
ValueGetter<ReadOnlyMemory<char>> getSentence1 = default;
@@ -353,13 +428,7 @@ private Delegate MakePredictedLabelGetter(DataViewRow input, IChannel ch, Tensor
var argmax = (outputCacher as BertTensorCacher).Result.argmax(-1);
var prediction = argmax.ToArray<long>();

var editor = VBufferEditor.Create(ref dst, prediction.Length - 1);
for (int i = 1; i < prediction.Length; i++)
{
editor.Values[i - 1] = (uint)prediction[i];
}

dst = editor.Commit();
CondenseOutput(ref dst, sentence1.ToString(), tokenizer, outputCacher);
};

return classification;
6 changes: 3 additions & 3 deletions src/Microsoft.ML.TorchSharp/TorchSharpBaseTrainer.cs
@@ -238,9 +238,9 @@ private bool ValidateStep(DataViewRowCursor cursor,
cursorValid = cursor.MoveNext();
if (cursorValid)
{
inputTensors.Add(PrepareRowTensor());
TLabelCol target = default;
labelGetter(ref target);
inputTensors.Add(PrepareRowTensor(ref target));
targets.Add(AddToTargets(target));
}
else
@@ -312,9 +312,9 @@ private bool TrainStep(IHost host,
cursorValid = cursor.MoveNext();
if (cursorValid)
{
inputTensors.Add(PrepareRowTensor());
TLabelCol target = default;
labelGetter(ref target);
inputTensors.Add(PrepareRowTensor(ref target));
targets.Add(AddToTargets(target));
}
else
@@ -343,7 +343,7 @@ private bool TrainStep(IHost host,

private protected abstract void RunModelAndBackPropagate(ref List<Tensor> inputTensorm, ref Tensor targetsTensor);

private protected abstract torch.Tensor PrepareRowTensor();
private protected abstract torch.Tensor PrepareRowTensor(ref TLabelCol target);
private protected abstract torch.Tensor PrepareBatchTensor(ref List<Tensor> inputTensors, Device device);

[MethodImpl(MethodImplOptions.AggressiveInlining)]
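The base-trainer change is plumbing for the above: the label is now read from the cursor before the row tensor is built and passed by reference into PrepareRowTensor, which lets the NER override resize it before it is added to the batch targets. A condensed sketch of that call order, with the ML.NET delegate and tensor types replaced by simple stand-ins:

using System.Collections.Generic;

// Stand-in for the real label getter delegate, just to show the call order.
delegate void LabelGetter<T>(ref T value);

abstract class RowPreparerSketch<TLabelCol>
{
    // Mirrors the new abstract signature: the label travels with the row so an
    // override may rewrite it while preparing the features.
    protected abstract float[] PrepareRowTensor(ref TLabelCol target);
    protected abstract int AddToTargets(TLabelCol target);

    public void Step(LabelGetter<TLabelCol> labelGetter,
                     List<float[]> inputTensors, List<int> targets)
    {
        TLabelCol target = default;
        labelGetter(ref target);                        // 1. read the label first
        inputTensors.Add(PrepareRowTensor(ref target)); // 2. build the row; may resize target
        targets.Add(AddToTargets(target));              // 3. the (possibly resized) label is stored
    }
}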