Skip to content

Added onnx export support for several multiclass classifiers #4462

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -978,13 +978,35 @@ private bool SaveAsOnnxCore(OnnxContext ctx, string[] outputs, string featureCol
{
// Validate the export context before emitting any nodes.
Host.CheckValue(ctx, nameof(ctx));

// Name of the intermediate Int64 variable that will receive the classifier's label output.
string predictedLabelInt64 = null;
// Name of the final label column; it becomes the output of the Cast node emitted below.
string predictedLabelUint32 = null;
// REVIEW: What is the right way to get the name of the predicted column?
// Locate the predicted-label entry in outputs and redirect it to a new Int64 intermediate.
// Note that the caller's outputs array is modified in place.
for (int i = 0; i < outputs.Length; i++)
{
if (outputs[i] != DefaultColumnNames.PredictedLabel)
continue;
predictedLabelUint32 = DefaultColumnNames.PredictedLabel;
predictedLabelInt64 = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "PredictedLabelInt64", true);
outputs[i] = predictedLabelInt64;
break;
}

// predictedLabelInt64 is still null here if no entry of outputs matched the predicted-label name.
Host.CheckValue(predictedLabelInt64, nameof(predictedLabelInt64));

// Emit the LinearClassifier node; its label output now goes to the Int64 intermediate.
string opType = "LinearClassifier";
var node = ctx.CreateNode(opType, new[] { featureColumn }, outputs, ctx.GetNodeName(opType));
node.AddAttribute("post_transform", GetOnnxPostTransform());
node.AddAttribute("multi_class", true);
node.AddAttribute("coefficients", Weights.SelectMany(w => w.DenseValues()));
node.AddAttribute("intercepts", Biases);
// NOTE(review): "classlabels_ints" is added twice below with different ranges (0-based, then 1-based).
// This text appears to come from a diff view, so the first line is presumably the removed version and
// the second its replacement - confirm against the real file that only one of them is present.
node.AddAttribute("classlabels_ints", Enumerable.Range(0, NumberOfClasses).Select(x => (long)x));
node.AddAttribute("classlabels_ints", Enumerable.Range(1, NumberOfClasses).Select(x => (long)x));

// Onnx outputs an Int64, but ML.NET outputs UInt32. So cast the Onnx output here
opType = "Cast";
var castNode = ctx.CreateNode(opType, predictedLabelInt64, predictedLabelUint32, ctx.GetNodeName(opType), "");
// Target element type of the Cast, derived from DataKind.UInt32.
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
castNode.AddAttribute("to", t);

return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
Expand Down Expand Up @@ -244,7 +245,8 @@ public sealed class OneVersusAllModelParameters :
IValueMapper,
ICanSaveInSourceCode,
ICanSaveInTextFormat,
ISingleCanSavePfa
ISingleCanSavePfa,
ISingleCanSaveOnnx
{
internal const string LoaderSignature = "OVAExec";
internal const string RegistrationName = "OVAPredictor";
Expand Down Expand Up @@ -490,7 +492,11 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
}
}

private abstract class ImplBase : ISingleCanSavePfa
/// <summary>
/// Reports whether this model can be exported to ONNX by forwarding the question to the
/// underlying implementation object.
/// </summary>
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
    return _impl.CanSaveOnnx(ctx);
}

/// <summary>
/// Exports this model to ONNX by forwarding the call, with all arguments unchanged, to the
/// underlying implementation object.
/// </summary>
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
    return _impl.SaveAsOnnx(ctx, outputNames, featureColumn);
}

private abstract class ImplBase : ISingleCanSavePfa, ISingleCanSaveOnnx
{
public OutputFormula OutputFormula;
public abstract DataViewType InputType { get; }
Expand All @@ -499,6 +505,10 @@ private abstract class ImplBase : ISingleCanSavePfa
public abstract ValueMapper<VBuffer<float>, VBuffer<float>> GetMapper();
public abstract JToken SaveAsPfa(BoundPfaContext ctx, JToken input);

/// <summary>
/// Returns true only when every underlying predictor can be exported to ONNX.
/// </summary>
/// <remarks>
/// The export path casts each predictor to <see cref="ISingleCanSaveOnnx"/>, so that is the
/// interface tested here. A predictor that does not implement it makes the whole model
/// non-exportable instead of failing later during export.
/// </remarks>
public bool CanSaveOnnx(OnnxContext ctx)
    => Predictors.All(pred => pred is ISingleCanSaveOnnx onnxPred && onnxPred.CanSaveOnnx(ctx));
Copy link
Member

@ganik ganik Nov 19, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(pred as ICanSaveOnnx)? [](start = 79, length = 23)

What happens if it can't cast? #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the cast fails, `as` returns null. The question mark after the ')' makes the call a null-conditional one, so `CanSaveOnnx` is not invoked and the expression evaluates to null; comparing null with `true` yields false, so the overall function returns false.


In reply to: 348211703 [](ancestors = 348211703)


/// <summary>
/// Writes the ONNX graph for this one-versus-all implementation into <paramref name="ctx"/>.
/// </summary>
/// <param name="ctx">The ONNX export context that receives the nodes and variables.</param>
/// <param name="outputNames">Names of the output variables to produce.</param>
/// <param name="featureColumn">Name of the feature input consumed by the underlying predictors.</param>
public abstract bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn);

protected bool IsValid(IValueMapper mapper, ref VectorDataViewType inputType)
{
Contracts.AssertValueOrNull(mapper);
Expand All @@ -521,6 +531,65 @@ protected bool IsValid(IValueMapper mapper, ref VectorDataViewType inputType)
}
return true;
}

/// <summary>
/// Exports every underlying predictor to ONNX and returns, for each one, the name of the
/// variable that later nodes should read.
/// </summary>
/// <param name="ctx">The ONNX export context that receives the nodes and variables.</param>
/// <param name="featureColumn">Name of the feature input handed to each predictor's exporter.</param>
/// <param name="clipToZero">
/// When true, a Clip node with a minimum of 0 is added after each predictor's probability output
/// and the clipped variable is returned; when false the probability output is returned directly.
/// </param>
/// <returns>One variable name per predictor, in predictor order.</returns>
public string[] SaveAsOnnxPreProcess(OnnxContext ctx, string featureColumn, bool clipToZero)
{
    string[] outputs = new string[Predictors.Length];

    for (int i = 0; i < Predictors.Length; i++)
    {
        // Fail with a clear message instead of a null dereference when a predictor
        // does not provide the single-output ONNX exporter.
        var pred = Predictors[i] as ISingleCanSaveOnnx;
        Contracts.Check(pred != null, "One of the underlying predictors does not support exporting to ONNX.");

        // Each predictor gets its own label / score / probability intermediates, suffixed with its index.
        var predictorOutputNames = new[]
        {
            ctx.AddIntermediateVariable(NumberDataViewType.UInt32, $"{DefaultColumnNames.PredictedLabel}_{i}", true),
            ctx.AddIntermediateVariable(NumberDataViewType.Single, $"{DefaultColumnNames.Score}_{i}", true),
            ctx.AddIntermediateVariable(NumberDataViewType.Single, $"{DefaultColumnNames.Probability}_{i}", true),
        };
        string probabilityOutput = predictorOutputNames[2];

        pred.SaveAsOnnx(ctx, predictorOutputNames, featureColumn);

        if (!clipToZero)
        {
            outputs[i] = probabilityOutput;
            continue;
        }

        // Clamp the probability output from below at zero.
        outputs[i] = ctx.AddIntermediateVariable(NumberDataViewType.Single, $"ClipOutput_{i}", true);

        string opType = "Clip";
        var clipNode = ctx.CreateNode(opType, probabilityOutput, outputs[i], ctx.GetNodeName(opType), "");
        clipNode.AddAttribute("min", 0.0);
    }

    return outputs;
}

/// <summary>
/// Emits the tail of the one-versus-all ONNX graph: the predicted label is derived from the
/// ArgMax of <paramref name="inputName"/>, and a Max node over the same input feeds the second output.
/// </summary>
/// <param name="ctx">The ONNX export context that receives the nodes and variables.</param>
/// <param name="inputName">Name of the variable holding the combined per-class values.</param>
/// <param name="outputNames">
/// Output variable names; element 0 receives the predicted label and element 1 the Max output.
/// At least two entries are expected.
/// </param>
public void SaveAsOnnxPostProcess(OnnxContext ctx, string inputName, string[] outputNames)
{
    Contracts.Assert(outputNames.Length >= 2);

    string opType;
    opType = "ArgMax";
    // ONNX ArgMax produces int64 indices, and the result feeds the Int64 Add below,
    // so the intermediate is declared as Int64 rather than Single.
    var argMaxOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "ArgMaxOutput", true);
    var argMaxNode = ctx.CreateNode(opType, inputName, argMaxOutput, ctx.GetNodeName(opType), "");
    argMaxNode.AddAttribute("keepdims", 0);

    // Shift the index by one; presumably because predicted labels are 1-based - confirm.
    opType = "Add";
    var one = ctx.AddInitializer(1);
    var addOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "AddOutput", true);
    var addNode = ctx.CreateNode(opType, new[] { argMaxOutput, one }, new[] { addOutput }, ctx.GetNodeName(opType), "");

    // Convert the Int64 label to UInt32 and write it to the first requested output.
    opType = "Cast";
    var castToUint32Node = ctx.CreateNode(opType, addOutput, outputNames[0], ctx.GetNodeName(opType), "");
    var t2 = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
    castToUint32Node.AddAttribute("to", t2);

    // The second requested output is produced by a Max node over the same input.
    opType = "Max";
    ctx.CreateNode(opType, inputName, outputNames[1], ctx.GetNodeName(opType), "");
}
}

private sealed class ImplRaw : ImplBase
Expand Down Expand Up @@ -586,6 +655,21 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input)
JObject jobj = null;
return jobj.AddReturn("type", PfaUtils.Type.Array(PfaUtils.Type.Double)).AddReturn("new", rootObjects);
}

/// <summary>
/// Exports the raw one-versus-all model: the clipped per-predictor outputs are concatenated
/// and handed to the shared post-processing step.
/// </summary>
public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
    // Per-predictor outputs, clipped at zero.
    string[] perClassOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, true);

    // Join the per-predictor outputs into a single variable along axis 0.
    const string concatOp = "Concat";
    string joined = ctx.AddIntermediateVariable(NumberDataViewType.Single, "ConcatOutput", true);
    var joinNode = ctx.CreateNode(concatOp, perClassOutputs, new[] { joined }, ctx.GetNodeName(concatOp), "");
    joinNode.AddAttribute("axis", 0);

    base.SaveAsOnnxPostProcess(ctx, joined, outputNames);
    return true;
}
}

private sealed class ImplDist : ImplBase
Expand Down Expand Up @@ -699,6 +783,51 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input)
var factorVar = ctx.DeclareVar(null, PfaUtils.Call("/", 1.0, PfaUtils.Call("a.sum", resultVar)));
return PfaUtils.Call("la.scale", resultVar, factorVar);
}

/// <summary>
/// Exports the normalizing one-versus-all model: each clipped per-predictor output is divided
/// by the sum of all of them, and the results are concatenated and handed to the shared
/// post-processing step.
/// </summary>
/// <remarks>
/// To avoid dividing by zero, 1 is added to the divisor only when the sum is exactly zero.
/// Casting the sum to bool yields "sum is non-zero", so a Not node is required before that flag
/// is converted back to a float; without it the divisor would be sum + 1 for non-zero sums and 0
/// for a zero sum.
/// </remarks>
/// <param name="ctx">The ONNX export context that receives the nodes and variables.</param>
/// <param name="outputNames">Output variable names; at least two entries are expected.</param>
/// <param name="featureColumn">Name of the feature input consumed by the underlying predictors.</param>
/// <returns>Always true.</returns>
public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
    Contracts.Assert(outputNames.Length >= 2);

    string opType;
    var probabilityOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, true);

    // Total of the clipped per-predictor outputs.
    opType = "Sum";
    var sumOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "SumOfScores", true);
    var sumNode = ctx.CreateNode(opType, probabilityOutputs, new[] { sumOutput }, ctx.GetNodeName(opType), "");

    // Float -> bool: true exactly when the sum is non-zero.
    opType = "Cast";
    var isSumNonZero = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "IsSumNonZero", true);
    var castNode = ctx.CreateNode(opType, sumOutput, isSumNonZero, ctx.GetNodeName(opType), "");
    var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Boolean).ToType();
    castNode.AddAttribute("to", t);

    // Invert the flag so that it is true exactly when the sum is zero.
    opType = "Not";
    var isSumZero = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "IsSumZero", true);
    ctx.CreateNode(opType, isSumNonZero, isSumZero, ctx.GetNodeName(opType), "");

    // Bool -> float: 1 when the sum is zero, 0 otherwise. The variable holds a float,
    // so it is declared as Single to match the Cast target.
    opType = "Cast";
    var castIsZeroSumToFloat = ctx.AddIntermediateVariable(NumberDataViewType.Single, "IsSumZeroAsFloat", true);
    var castIsZeroSumToFloatNode = ctx.CreateNode(opType, isSumZero, castIsZeroSumToFloat, ctx.GetNodeName(opType), "");
    var t1 = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
    castIsZeroSumToFloatNode.AddAttribute("to", t1);

    // Divisor: the sum itself when it is non-zero, 1 when it is zero.
    opType = "Sum";
    var sumOutputNonZero = ctx.AddIntermediateVariable(NumberDataViewType.Single, "SumOfScoresNonZero", true);
    var sumOutputNonZeroNode = ctx.CreateNode(opType, new[] { sumOutput, castIsZeroSumToFloat },
        new[] { sumOutputNonZero }, ctx.GetNodeName(opType), "");

    // Normalize each per-predictor output by the guarded sum.
    string[] divOutputs = new string[Predictors.Length];
    for (int i = 0; i < Predictors.Length; i++)
    {
        opType = "Div";
        divOutputs[i] = ctx.AddIntermediateVariable(NumberDataViewType.Single, $"DivOutput_{i}", true);
        ctx.CreateNode(opType, new[] { probabilityOutputs[i], sumOutputNonZero }, new[] { divOutputs[i] }, ctx.GetNodeName(opType), "");
    }

    // Join the normalized outputs into a single variable along axis 0.
    opType = "Concat";
    var concatOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "ConcatOutput", true);
    var concatNode = ctx.CreateNode(opType, divOutputs, new[] { concatOutput }, ctx.GetNodeName(opType), "");
    concatNode.AddAttribute("axis", 0);

    base.SaveAsOnnxPostProcess(ctx, concatOutput, outputNames);

    return true;
}
}

private sealed class ImplSoftmax : ImplBase
Expand Down Expand Up @@ -768,6 +897,36 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input)
{
throw new NotImplementedException("Softmax's PFA exporter is not implemented yet.");
}

/// <summary>
/// Exports the softmax one-versus-all model: the unclipped per-predictor outputs are
/// concatenated, exponentiated, divided by the sum of the exponentials, and then handed to the
/// shared post-processing step.
/// </summary>
public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
    Contracts.Assert(outputNames.Length >= 2);

    // No clipping here: the per-predictor outputs are used as they are.
    string[] perClassOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, false);

    // Join the per-predictor outputs into a single variable along axis 0.
    const string concatOp = "Concat";
    string joined = ctx.AddIntermediateVariable(NumberDataViewType.Single, "ConcatOutput", true);
    var joinNode = ctx.CreateNode(concatOp, perClassOutputs, new[] { joined }, ctx.GetNodeName(concatOp), "");
    joinNode.AddAttribute("axis", 0);

    // Element-wise exponential.
    const string expOp = "Exp";
    string exponentials = ctx.AddIntermediateVariable(NumberDataViewType.Single, "ExpOutput", true);
    ctx.CreateNode(expOp, joined, exponentials, ctx.GetNodeName(expOp), "");

    // Sum of the exponentials, with the reduced dimension dropped.
    const string reduceOp = "ReduceSum";
    string total = ctx.AddIntermediateVariable(NumberDataViewType.Single, "SumOutput", true);
    var reduceNode = ctx.CreateNode(reduceOp, exponentials, total, ctx.GetNodeName(reduceOp), "");
    reduceNode.AddAttribute("keepdims", 0);

    // Normalize each exponential by the total.
    const string divOp = "Div";
    string normalized = ctx.AddIntermediateVariable(NumberDataViewType.Single, "DivOutput", true);
    ctx.CreateNode(divOp, new[] { exponentials, total }, new[] { normalized }, ctx.GetNodeName(divOp), "");

    base.SaveAsOnnxPostProcess(ctx, normalized, outputNames);
    return true;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@
"Features0"
],
"output": [
"PredictedLabel",
"PredictedLabelInt64",
"Score"
],
"name": "LinearClassifier",
Expand Down Expand Up @@ -239,7 +239,6 @@
{
"name": "classlabels_ints",
"ints": [
"0",
"1",
"2",
"3",
Expand All @@ -248,13 +247,31 @@
"6",
"7",
"8",
"9"
"9",
"10"
],
"type": "INTS"
}
],
"domain": "ai.onnx.ml"
},
{
"input": [
"PredictedLabelInt64"
],
"output": [
"PredictedLabel"
],
"name": "Cast0",
"opType": "Cast",
"attribute": [
{
"name": "to",
"i": "12",
"type": "INT"
}
]
},
{
"input": [
"Label0"
Expand Down
Loading