Skip to content

StopWordsRemovingEstimator export to Onnx #5279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 85 additions & 4 deletions src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
using Microsoft.ML.Data.IO;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms.Text;

Expand Down Expand Up @@ -343,14 +344,16 @@ private static Stream GetResourceFileStreamOrNull(StopWordsRemovingEstimator.Lan
return assembly.GetManifestResourceStream($"{assembly.GetName().Name}.Text.StopWords.{lang.ToString()}.txt");
}

private sealed class Mapper : MapperBase
private sealed class Mapper : MapperBase, ISaveAsOnnx
{
private readonly DataViewType[] _types;
private readonly StopWordsRemovingTransformer _parent;
private readonly int[] _languageColumns;
private readonly bool?[] _resourcesExist;
private readonly Dictionary<int, int> _colMapNewToOld;

public bool CanSaveOnnx(OnnxContext ctx) => true;

public Mapper(StopWordsRemovingTransformer parent, DataViewSchema inputSchema)
: base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent)
{
Expand Down Expand Up @@ -438,6 +441,45 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
return del;
}

/// <summary>
/// Exports each configured column pair to the ONNX graph. Pairs whose source
/// variable is absent from the ONNX context are skipped.
/// </summary>
/// <param name="ctx">The ONNX export context that receives the generated nodes.</param>
public void SaveAsOnnx(OnnxContext ctx)
{
    // The Squeeze/Unsqueeze "axes" attribute emitted below requires at least opset 9.
    const int minimumOpSetVersion = 9;
    ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

    for (int iinfo = 0; iinfo < _parent.ColumnPairs.Length; iinfo++)
    {
        var pair = _parent.ColumnPairs[iinfo];
        var srcVariableName = ctx.GetVariableName(pair.inputColumnName);
        if (!ctx.ContainsColumn(srcVariableName))
            continue;

        var dstVariableName = ctx.AddIntermediateVariable(_types[iinfo], pair.outputColumnName);
        SaveAsOnnxCore(ctx, iinfo, srcVariableName, dstVariableName);
    }
}

// Note: Since StringNormalizer only accepts inputs of shape [C] or [1,C], we temporarily squeeze the
// batch dimension, which may exceed 1, and restore it after the stopwords are removed.
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
    // Squeeze the leading batch dimension so the input matches StringNormalizer's expected shape.
    var opType = "Squeeze";
    var squeezeOutput = ctx.AddIntermediateVariable(_types[iinfo], "SqueezeOutput", true);
    var node = ctx.CreateNode(opType, srcVariableName, squeezeOutput, ctx.GetNodeName(opType), "");
    node.AddAttribute("axes", new long[] { 0 });

    // StringNormalizer drops any token that appears in the "stopwords" attribute.
    opType = "StringNormalizer";
    var stringNormalizerOutput = ctx.AddIntermediateVariable(_types[iinfo], "StringNormalizerOutput", true);
    node = ctx.CreateNode(opType, squeezeOutput, stringNormalizerOutput, ctx.GetNodeName(opType), "");

    // Resolve the language for this column.
    // NOTE(review): langToUse/lang are not consumed below because StopWords[iinfo] is already
    // built for the resolved language — confirm whether this call is still needed for its
    // validation side effects, or can be removed.
    var langToUse = _parent._columns[iinfo].Language;
    var lang = default(ReadOnlyMemory<char>);
    UpdateLanguage(ref langToUse, null, ref lang);

    // Reuse the projected stopword list instead of evaluating the Select expression twice.
    var words = StopWords[iinfo].Select(item => Convert.ToString(item.Value));
    node.AddAttribute("stopwords", words);

    // Restore the batch dimension squeezed away above. The Unsqueeze node writes directly to
    // the destination variable, so no extra intermediate variable is needed here.
    opType = "Unsqueeze";
    node = ctx.CreateNode(opType, stringNormalizerOutput, dstVariableName, ctx.GetNodeName(opType), "");
    node.AddAttribute("axes", new long[] { 0 });
}

private void UpdateLanguage(ref StopWordsRemovingEstimator.Language langToUse, ValueGetter<ReadOnlyMemory<char>> getLang, ref ReadOnlyMemory<char> langTxt)
{
if (getLang != null)
Expand Down Expand Up @@ -490,7 +532,7 @@ private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> a
/// | Does this estimator need to look at the data to train its parameters? | No |
/// | Input column data type | Vector of [Text](xref:Microsoft.ML.Data.TextDataViewType) |
/// | Output column data type | Variable-sized vector of [Text](xref:Microsoft.ML.Data.TextDataViewType) |
/// | Exportable to ONNX | No |
/// | Exportable to ONNX | Yes |
///
/// The resulting <xref:Microsoft.ML.Transforms.Text.StopWordsRemovingTransformer> creates a new column, named as specified in the output column name parameter,
/// and fills it with a vector of words containing all of the words in the input column **except** the predefined list of stopwords for the specified language.
Expand Down Expand Up @@ -1016,11 +1058,13 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Dat

private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);

private sealed class Mapper : OneToOneMapperBase
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
{
private readonly DataViewType[] _types;
private readonly CustomStopWordsRemovingTransformer _parent;

public bool CanSaveOnnx(OnnxContext ctx) => true;

public Mapper(CustomStopWordsRemovingTransformer parent, DataViewSchema inputSchema)
: base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), parent, inputSchema)
{
Expand Down Expand Up @@ -1084,6 +1128,43 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b

return del;
}

/// <summary>
/// Exports each configured column pair to the ONNX graph. Pairs whose source
/// variable is absent from the ONNX context are skipped.
/// </summary>
/// <param name="ctx">The ONNX export context that receives the generated nodes.</param>
public void SaveAsOnnx(OnnxContext ctx)
{
    // The Squeeze/Unsqueeze "axes" attribute emitted below requires at least opset 9.
    const int minimumOpSetVersion = 9;
    ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

    for (int iinfo = 0; iinfo < _parent.ColumnPairs.Length; iinfo++)
    {
        var pair = _parent.ColumnPairs[iinfo];
        var srcVariableName = ctx.GetVariableName(pair.inputColumnName);
        if (!ctx.ContainsColumn(srcVariableName))
            continue;

        var dstVariableName = ctx.AddIntermediateVariable(_types[iinfo], pair.outputColumnName);
        SaveAsOnnxCore(ctx, iinfo, srcVariableName, dstVariableName);
    }
}

// Note: Since StringNormalizer only accepts inputs of shape [C] or [1,C], we temporarily squeeze the
// batch dimension, which may exceed 1, and restore it after the stopwords are removed.
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
    // Squeeze the leading batch dimension so the input matches StringNormalizer's expected shape.
    var opType = "Squeeze";
    var squeezeOutput = ctx.AddIntermediateVariable(_types[iinfo], "SqueezeOutput", true);
    var node = ctx.CreateNode(opType, srcVariableName, squeezeOutput, ctx.GetNodeName(opType), "");
    node.AddAttribute("axes", new long[] { 0 });

    // StringNormalizer drops any token that appears in the "stopwords" attribute.
    opType = "StringNormalizer";
    var stringNormalizerOutput = ctx.AddIntermediateVariable(_types[iinfo], "StringNormalizerOutput", true);
    node = ctx.CreateNode(opType, squeezeOutput, stringNormalizerOutput, ctx.GetNodeName(opType), "");
    // Enumerate the custom stopwords map lazily; materializing with ToList() first is unnecessary.
    var words = _parent._stopWordsMap.Select(item => Convert.ToString(item.Value));
    node.AddAttribute("stopwords", words);

    // Restore the batch dimension squeezed away above. The Unsqueeze node writes directly to
    // the destination variable, so no extra intermediate variable is needed here.
    opType = "Unsqueeze";
    node = ctx.CreateNode(opType, stringNormalizerOutput, dstVariableName, ctx.GetNodeName(opType), "");
    node.AddAttribute("axes", new long[] { 0 });
}
}
}

Expand All @@ -1098,7 +1179,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
/// | Does this estimator need to look at the data to train its parameters? | No |
/// | Input column data type | Vector of [Text](xref:Microsoft.ML.Data.TextDataViewType) |
/// | Output column data type | Vector of [Text](xref:Microsoft.ML.Data.TextDataViewType) |
/// | Exportable to ONNX | No |
/// | Exportable to ONNX | Yes |
///
/// The resulting <xref:Microsoft.ML.Transforms.Text.CustomStopWordsRemovingTransformer> creates a new column, named as specified by the output column name parameter, and
/// fills it with a vector of words containing all of the words in the input column except those given by the stopwords parameter.
Expand Down
52 changes: 50 additions & 2 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -974,8 +974,8 @@ public void OneHotHashEncodingOnnxConversionTest()
var mlContext = new MLContext();
string dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);

var dataView = ML.Data.LoadFromTextFile<BreastCancerCatFeatureExample>(dataPath);
var pipeline = ML.Transforms.Categorical.OneHotHashEncoding(new[]{
var dataView = mlContext.Data.LoadFromTextFile<BreastCancerCatFeatureExample>(dataPath);
var pipeline = mlContext.Transforms.Categorical.OneHotHashEncoding(new[]{
new OneHotHashEncodingEstimator.ColumnOptions("Output", "F3", useOrderedHashing:false),
});
var onnxFileName = "OneHotHashEncoding.onnx";
Expand Down Expand Up @@ -1343,6 +1343,54 @@ public void NgramOnnxConversionTest(
Done();
}

[Fact]
public void CustomStopWordsRemovingEstimatorOnnxTest()
{
    var mlContext = new MLContext();

    // Tokenize, then strip an explicit custom stopword list from the tokens.
    var pipeline = mlContext.Transforms.Text.TokenizeIntoWords("Words", "Text")
        .Append(mlContext.Transforms.Text.RemoveStopWords(
            "WordsWithoutStopWords", "Words", stopwords:
            new[] { "cat", "sat", "on" }));

    var samples = new List<TextData>()
    {
        new TextData(){ Text = "cat sat on mat" },
        new TextData(){ Text = "mat not fit cat" },
        new TextData(){ Text = "a cat think mat bad" },
    };
    var dataView = mlContext.Data.LoadFromEnumerable(samples);
    // No interpolation holes, so a plain string literal suffices (no '$' prefix needed).
    var onnxFileName = "CustomStopWordsRemovingEstimator.onnx";

    // Verify the ONNX export produces the same filtered tokens as the ML.NET pipeline.
    TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("WordsWithoutStopWords") });

    Done();
}

[Fact]
public void StopWordsRemovingEstimatorOnnxTest()
{
    var mlContext = new MLContext();

    // Tokenize, then strip the built-in English stopword list from the tokens.
    var pipeline = mlContext.Transforms.Text.TokenizeIntoWords("Words", "Text")
        .Append(mlContext.Transforms.Text.RemoveDefaultStopWords(
            "WordsWithoutStopWords", "Words", language:
            StopWordsRemovingEstimator.Language.English));

    var samples = new List<TextData>()
    {
        new TextData(){ Text = "a go cat sat on mat" },
        new TextData(){ Text = "a mat not fit go cat" },
        new TextData(){ Text = "cat think mat bad a" },
    };
    var dataView = mlContext.Data.LoadFromEnumerable(samples);
    // No interpolation holes, so a plain string literal suffices (no '$' prefix needed).
    var onnxFileName = "StopWordsRemovingEstimator.onnx";

    // Verify the ONNX export produces the same filtered tokens as the ML.NET pipeline.
    TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("WordsWithoutStopWords") });

    Done();
}

[Theory]
[InlineData(DataKind.Boolean)]
[InlineData(DataKind.SByte)]
Expand Down