Skip to content

Added onnx export functionality for MissingValueIndicatorTransformer #4194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Sep 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ private static TensorProto.Types.DataType ConvertToTensorProtoType(Type rawType)
var dataType = TensorProto.Types.DataType.Undefined;

if (rawType == typeof(bool))
dataType = TensorProto.Types.DataType.Float;
dataType = TensorProto.Types.DataType.Bool;
else if (rawType == typeof(ReadOnlyMemory<char>))
dataType = TensorProto.Types.DataType.String;
else if (rawType == typeof(sbyte))
Expand Down Expand Up @@ -305,7 +305,7 @@ public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, s
model.IrVersion = (long)OnnxCSharpToProtoWrapper.Version.IrVersion;
model.ModelVersion = modelVersion;
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx.ml", Version = 1 });
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 7 });
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 9 });
model.Graph = new GraphProto();
var graph = model.Graph;
graph.Node.Add(nodes);
Expand Down
43 changes: 42 additions & 1 deletion src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;

Expand Down Expand Up @@ -140,7 +141,7 @@ private protected override void SaveModel(ModelSaveContext ctx)

private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);

private sealed class Mapper : OneToOneMapperBase
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
{
private readonly MissingValueIndicatorTransformer _parent;
private readonly ColInfo[] _infos;
Expand Down Expand Up @@ -426,6 +427,46 @@ private void FillValues(int srcLength, ref VBuffer<bool> dst, List<int> indices,
dst = editor.Commit();
}
}

public bool CanSaveOnnx(OnnxContext ctx) => true;

public void SaveAsOnnx(OnnxContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));

for (int iinfo = 0; iinfo < _infos.Length; ++iinfo)
{
ColInfo info = _infos[iinfo];
string inputColumnName = info.InputColumnName;
if (!ctx.ContainsColumn(inputColumnName))
{
ctx.RemoveColumn(info.Name, false);
continue;
}

if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(inputColumnName),
ctx.AddIntermediateVariable(_infos[iinfo].OutputType, info.Name)))
{
ctx.RemoveColumn(info.Name, true);
}
}
}

private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
var inputType = _infos[iinfo].InputType;
Type rawType = (inputType is VectorDataViewType vectorType) ? vectorType.ItemType.RawType : inputType.RawType;

if (rawType != typeof(float))
return false;

string opType;
opType = "IsNaN";
var isNaNOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "IsNaNOutput", true);
var nanNode = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), "");

return true;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@
"name": "PredictedLabel0",
"type": {
"tensorType": {
"elemType": "FLOAT",
"elemType": "BOOL",
"shape": {
"dim": [
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@
"name": "PredictedLabel0",
"type": {
"tensorType": {
"elemType": "FLOAT",
"elemType": "BOOL",
"shape": {
"dim": [
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@
"name": "Label",
"type": {
"tensorType": {
"elemType": "FLOAT",
"elemType": "BOOL",
"shape": {
"dim": [
{
Expand Down Expand Up @@ -470,7 +470,7 @@
"name": "Label0",
"type": {
"tensorType": {
"elemType": "FLOAT",
"elemType": "BOOL",
"shape": {
"dim": [
{
Expand Down Expand Up @@ -542,7 +542,7 @@
"name": "PredictedLabel0",
"type": {
"tensorType": {
"elemType": "FLOAT",
"elemType": "BOOL",
"shape": {
"dim": [
{
Expand Down
163 changes: 163 additions & 0 deletions test/BaselineOutput/Common/Onnx/Transforms/IndicateMissingValues.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
{
"irVersion": "3",
"producerName": "ML.NET",
"producerVersion": "##VERSION##",
"domain": "machinelearning.dotnet",
"graph": {
"node": [
{
"input": [
"Features"
],
"output": [
"MissingIndicator"
],
"name": "IsNaN",
"opType": "IsNaN"
},
{
"input": [
"MissingIndicator"
],
"output": [
"MissingIndicator0"
],
"name": "Cast",
"opType": "Cast",
"attribute": [
{
"name": "to",
"i": "6",
"type": "INT"
}
]
},
{
"input": [
"Features"
],
"output": [
"Features0"
],
"name": "Identity",
"opType": "Identity"
},
{
"input": [
"MissingIndicator0"
],
"output": [
"MissingIndicator1"
],
"name": "Identity0",
"opType": "Identity"
}
],
"name": "model",
"input": [
{
"name": "Features",
"type": {
"tensorType": {
"elemType": "FLOAT",
"shape": {
"dim": [
{
"dimValue": "1"
},
{
"dimValue": "3"
}
]
}
}
}
}
],
"output": [
{
"name": "Features0",
"type": {
"tensorType": {
"elemType": "FLOAT",
"shape": {
"dim": [
{
"dimValue": "1"
},
{
"dimValue": "3"
}
]
}
}
}
},
{
"name": "MissingIndicator1",
"type": {
"tensorType": {
"elemType": "INT32",
"shape": {
"dim": [
{
"dimValue": "1"
},
{
"dimValue": "3"
}
]
}
}
}
}
],
"valueInfo": [
{
"name": "MissingIndicator",
"type": {
"tensorType": {
"elemType": "BOOL",
"shape": {
"dim": [
{
"dimValue": "1"
},
{
"dimValue": "3"
}
]
}
}
}
},
{
"name": "MissingIndicator0",
"type": {
"tensorType": {
"elemType": "INT32",
"shape": {
"dim": [
{
"dimValue": "1"
},
{
"dimValue": "3"
}
]
}
}
}
}
]
},
"opsetImport": [
{
"domain": "ai.onnx.ml",
"version": "1"
},
{
"version": "9"
}
]
}
66 changes: 66 additions & 0 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ public OnnxConversionTest(ITestOutputHelper output) : base(output)
{
}

private bool IsOnnxRuntimeSupported()
{
return Environment.Is64BitProcess && (!RuntimeInformation.IsOSPlatform(OSPlatform.Linux) || AttributeHelpers.CheckLibcVersionGreaterThanMinimum(new System.Version(2, 23)));
}

/// <summary>
/// In this test, we convert a trained <see cref="TransformerChain"/> into ONNX <see cref="ModelProto"/> file and then
/// call <see cref="OnnxScoringEstimator"/> to evaluate that file. The outputs of <see cref="OnnxScoringEstimator"/> are checked against the original
Expand Down Expand Up @@ -740,6 +745,67 @@ public void OnnxTypeConversionTest()
}
}

private class TransformedDataPoint : DataPoint, IEquatable<TransformedDataPoint>
{
[VectorType(3)]
public int[] MissingIndicator { get; set; }

public bool Equals(TransformedDataPoint other)
{
return Enumerable.SequenceEqual(MissingIndicator, other.MissingIndicator);
}
}

[Fact]
void IndicateMissingValuesOnnxConversionTest()
{
var mlContext = new MLContext(seed: 1);

var samples = new List<DataPoint>()
{
new DataPoint() { Features = new float[3] {1, 1, 0}, },
new DataPoint() { Features = new float[3] {0, float.NaN, 1}, },
new DataPoint() { Features = new float[3] {-1, float.NaN, float.PositiveInfinity}, },
};
var dataView = mlContext.Data.LoadFromEnumerable(samples);

// IsNaN outputs a binary tensor. Support for this has been added in the latest version
// of Onnxruntime, but that hasn't been released yet.
// So we need to convert its type to Int32 until then.
// ConvertType part of the pipeline can be removed once we pick up a new release of the Onnx runtime

var pipeline = mlContext.Transforms.IndicateMissingValues(new[] { new InputOutputColumnPair("MissingIndicator", "Features"), })
.Append(mlContext.Transforms.Conversion.ConvertType("MissingIndicator", outputKind: DataKind.Int32));

var model = pipeline.Fit(dataView);
var transformedData = model.Transform(dataView);
var mlnetData = mlContext.Data.CreateEnumerable<TransformedDataPoint>(transformedData, false);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "Transforms");
var onnxFileName = "IndicateMissingValues.onnx";
var onnxTextName = "IndicateMissingValues.txt";
var onnxModelPath = GetOutputPath(onnxFileName);
var onnxTextPath = GetOutputPath(subDir, onnxTextName);

SaveOnnxModel(onnxModel, onnxModelPath, onnxTextPath);

// Compare results produced by ML.NET and ONNX's runtime.
if (IsOnnxRuntimeSupported())
{
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);
CompareSelectedVectorColumns<int>(model.LastTransformer.ColumnPairs[0].outputColumnName, outputNames[1], transformedData, onnxResult);
}

CheckEquality(subDir, onnxTextName, parseOption: NumberParseOption.UseSingle);
Done();
}

private void CreateDummyExamplesToMakeComplierHappy()
{
var dummyExample = new BreastCancerFeatureVector() { Features = null };
Expand Down