Skip to content

Onnx Export for ValueMapping estimator #5577

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 236 additions & 1 deletion src/Microsoft.ML.Data/Transforms/ValueMapping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;

Expand Down Expand Up @@ -818,6 +819,8 @@ private static ValueMap CreateValueMapInvoke<TKey, TValue>(DataViewSchema.Column
public abstract Delegate GetGetter(DataViewRow input, int index);

public abstract IDataView GetDataView(IHostEnvironment env);
public abstract TKey[] GetKeys<TKey>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these need to be public? Seems they're only used by Onnx conversion, so might prefer to make them private.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ValueMap class itself is only used by the ValueMapping estimator, but since the map used is not populated until runtime, I kept the methods abstract. I kept the methods public so they can be used by the mapper class.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, ok. Well, as long as the GetKeys new method isn't accessible to end users, it's all right.

public abstract TValue[] GetValues<TValue>();
}

/// <summary>
Expand Down Expand Up @@ -962,6 +965,16 @@ private static TValue GetVector<T>(TValue value)
}

/// <summary>
/// Returns <paramref name="value"/> unchanged. The type parameter <typeparamref name="T"/> is not
/// used by the body.
/// </summary>
private static TValue GetValue<T>(TValue value) => value;

/// <summary>
/// Returns the keys of the underlying mapping as an array of <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">Element type each key is cast to.</typeparam>
/// <returns>A new array containing every key of the mapping, cast to <typeparamref name="T"/>.</returns>
public override T[] GetKeys<T>() => _mapping.Keys.Cast<T>().ToArray();
/// <summary>
/// Returns the values of the underlying mapping as an array of <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">Element type each value is cast to.</typeparam>
/// <returns>A new array containing every value of the mapping, cast to <typeparamref name="T"/>.</returns>
public override T[] GetValues<T>() => _mapping.Values.Cast<T>().ToArray();

}

/// <summary>
Expand Down Expand Up @@ -1012,12 +1025,13 @@ private protected override IRowMapper MakeRowMapper(DataViewSchema schema)
return new Mapper(this, schema, _valueMap, ColumnPairs);
}

private sealed class Mapper : OneToOneMapperBase
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
{
private readonly DataViewSchema _inputSchema;
private readonly ValueMap _valueMap;
private readonly (string outputColumnName, string inputColumnName)[] _columns;
private readonly ValueMappingTransformer _parent;
/// <summary>
/// Always returns <c>true</c>; <paramref name="ctx"/> is not inspected.
/// </summary>
public bool CanSaveOnnx(OnnxContext ctx) => true;

internal Mapper(ValueMappingTransformer transform,
DataViewSchema inputSchema,
Expand All @@ -1040,6 +1054,227 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
return _valueMap.GetGetter(input, ColMapNewToOld[iinfo]);
}

/// <summary>
/// Exports each input/output column pair of the parent transformer into the ONNX graph
/// tracked by <paramref name="ctx"/>.
/// </summary>
/// <param name="ctx">The ONNX export context; must not be null.</param>
/// <remarks>
/// For every column pair the output type is the parent's value column type, wrapped in a vector
/// with the input's dimensions when the input column is a vector. Pairs whose input column is not
/// tracked by <paramref name="ctx"/> are skipped. A schema-mismatch exception is thrown when an
/// input column is absent from the input schema.
/// </remarks>
public void SaveAsOnnx(OnnxContext ctx)
{
    // Validate the context before the first member access on it; checking afterwards
    // would never fire because a null ctx would already have thrown.
    Host.CheckValue(ctx, nameof(ctx));

    const int minimumOpSetVersion = 9;
    ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

    for (int iinfo = 0; iinfo < _parent.ColumnPairs.Length; ++iinfo)
    {
        string inputColumnName = _parent.ColumnPairs[iinfo].inputColumnName;
        string outputColumnName = _parent.ColumnPairs[iinfo].outputColumnName;

        if (!_inputSchema.TryGetColumnIndex(inputColumnName, out int colSrc))
            throw Host.ExceptSchemaMismatch(nameof(_inputSchema), "input", inputColumnName);

        // Vector inputs yield a vector of mapped values with the same dimensions.
        var type = _inputSchema[colSrc].Type;
        DataViewType colType;
        if (type is VectorDataViewType vectorType)
            colType = new VectorDataViewType((PrimitiveDataViewType)_parent.ValueColumnType, vectorType.Dimensions);
        else
            colType = _parent.ValueColumnType;

        // Nothing upstream produced this input in the ONNX graph, so do not register an
        // output variable that no node would ever write.
        if (!ctx.ContainsColumn(inputColumnName))
            continue;

        // Resolve the source variable BEFORE registering the output variable.
        // NOTE(review): registering the output presumably re-points the column name at the new
        // variable, so an in-place mapping (output name == input name) would otherwise read
        // its own output — confirm against OnnxContext.
        string srcVariableName = ctx.GetVariableName(inputColumnName);
        string dstVariableName = ctx.AddIntermediateVariable(colType, outputColumnName);

        // On failure, drop the output column and the variable created just above; the input
        // column belongs to an upstream stage and must stay available to other consumers.
        if (!SaveAsOnnxCore(ctx, srcVariableName, dstVariableName))
            ctx.RemoveColumn(outputColumnName, true);
    }
}

/// <summary>
/// Emits a "Cast" node that converts the source variable to <paramref name="itemType"/>, then
/// creates the <paramref name="opType"/> node that consumes the cast result and attaches the
/// mapping keys to it in the attribute matching <paramref name="itemType"/>.
/// </summary>
/// <typeparam name="T">The type used to retrieve the keys from the value map.</typeparam>
/// <param name="ctx">The ONNX export context the nodes are added to.</param>
/// <param name="node">Receives the created <paramref name="opType"/> node.</param>
/// <param name="srcVariableName">Name of the ONNX variable to cast.</param>
/// <param name="opType">Operator type of the node that consumes the cast output.</param>
/// <param name="labelEncoderOutput">Name of the output variable of the created node.</param>
/// <param name="itemType">Target item type of the cast; selects keys_strings, keys_floats or keys_int64s.</param>
private void CastInputTo<T>(OnnxContext ctx, out OnnxNode node, string srcVariableName, string opType, string labelEncoderOutput, PrimitiveDataViewType itemType)
{
    // NOTE(review): the indexer below assumes the retrieved shape is non-null and has at least
    // two dimensions — confirm that every caller guarantees this.
    var srcShape = ctx.RetrieveShapeOrNull(srcVariableName);
    var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(itemType, (int)srcShape[1]), "castOutput");
    var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), "");
    castNode.AddAttribute("to", itemType.RawType);

    node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType));
    if (itemType == TextDataViewType.Instance)
    {
        // Format keys with the invariant culture: the exported model must not depend on the
        // locale of the exporting machine (e.g. "1,5" vs "1.5" for floating-point keys).
        node.AddAttribute("keys_strings", Array.ConvertAll(_valueMap.GetKeys<T>(),
            item => Convert.ToString(item, System.Globalization.CultureInfo.InvariantCulture)));
    }
    else if (itemType == NumberDataViewType.Single)
        node.AddAttribute("keys_floats", Array.ConvertAll(_valueMap.GetKeys<T>(), item => Convert.ToSingle(item)));
    else if (itemType == NumberDataViewType.Int64)
        node.AddAttribute("keys_int64s", Array.ConvertAll(_valueMap.GetKeys<T>(), item => Convert.ToInt64(item)));
}

private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
OnnxNode node;
string opType = "LabelEncoder";
var labelEncoderInput = srcVariableName;
var srcShape = ctx.RetrieveShapeOrNull(srcVariableName);
var typeValue = _valueMap.ValueColumn.Type;
var typeKey = _valueMap.KeyColumn.Type;
var kind = _valueMap.ValueColumn.Type.GetRawKind();

var labelEncoderOutput = (typeValue == NumberDataViewType.Single || typeValue == TextDataViewType.Instance || typeValue == NumberDataViewType.Int64) ? dstVariableName :
(typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance) ? ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, (int)srcShape[1]), "LabelEncoderOutput") :
ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, (int) srcShape[1]), "LabelEncoderOutput");

// The LabelEncoder operator doesn't support mappings between the same type and only supports mappings between int64s, floats, and strings.
// As a result, we need to cast most inputs and outputs. In order to avoid as many unsupported mappings, we cast keys that are of NumberDataTypeView
// to strings and values of NumberDataViewType to int64s.
// String -> String mappings can't be supported.
if (typeKey == NumberDataViewType.Int64)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm left wondering if it's possible to refactor this big if-else statement into something that's reusable, but as discussed offline it might not be possible, as each case handles things in a particular way and you might end up rewriting the if-else blocks anyway without saving many lines of code.

{
// To avoid a int64 -> int64 mapping, we cast keys to strings
if (typeValue is NumberDataViewType)
{
CastInputTo<Int64>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
}
else
{
node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType));
node.AddAttribute("keys_int64s", _valueMap.GetKeys<Int64>());
}
}
else if (typeKey == NumberDataViewType.Int32)
{
// To avoid a string -> string mapping, we cast keys to int64s
if (typeValue is TextDataViewType)
CastInputTo<Int32>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64);
else
CastInputTo<Int32>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
}
else if (typeKey == NumberDataViewType.Int16)
{
if (typeValue is TextDataViewType)
CastInputTo<Int16>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64);
else
CastInputTo<Int16>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
}
else if (typeKey == NumberDataViewType.UInt64)
{
if (typeValue is TextDataViewType)
CastInputTo<UInt64>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64);
else
CastInputTo<UInt64>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
}
else if (typeKey == NumberDataViewType.UInt32)
{
if (typeValue is TextDataViewType)
CastInputTo<UInt32>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64);
else
CastInputTo<UInt32>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
}
else if (typeKey == NumberDataViewType.UInt16)
{
if (typeValue is TextDataViewType)
CastInputTo<UInt16>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64);
else
CastInputTo<UInt16>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
}
else if (typeKey == NumberDataViewType.Single)
{
if (typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance)
{
CastInputTo<float>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
}
else
{
node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType));
node.AddAttribute("keys_floats", _valueMap.GetKeys<float>());
}
}
else if (typeKey == NumberDataViewType.Double)
{
if (typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance)
CastInputTo<double>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
else
CastInputTo<double>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Single);
}
else if (typeKey == TextDataViewType.Instance)
{
if (typeValue == TextDataViewType.Instance)
return false;
node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType));
node.AddAttribute("keys_strings", _valueMap.GetKeys<ReadOnlyMemory<char>>());
}
else if (typeKey == BooleanDataViewType.Instance)
{
if (typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance)
{
var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(TextDataViewType.Instance, (int)srcShape[1]), "castOutput");
var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.String).ToType();
castNode.AddAttribute("to", t);
node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType));
var values = Array.ConvertAll(_valueMap.GetKeys<bool>(), item => Convert.ToString(Convert.ToByte(item)));
node.AddAttribute("keys_strings", values);
}
else
CastInputTo<bool>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Single);
}
else
return false;

if (typeValue == NumberDataViewType.Int64)
{
node.AddAttribute("values_int64s", _valueMap.GetValues<long>());
}
else if (typeValue == NumberDataViewType.Int32)
{
node.AddAttribute("values_int64s", _valueMap.GetValues<int>().Select(item => Convert.ToInt64(item)));
var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), "");
castNode.AddAttribute("to", typeValue.RawType);
}
else if (typeValue == NumberDataViewType.Int16)
{
node.AddAttribute("values_int64s", _valueMap.GetValues<short>().Select(item => Convert.ToInt64(item)));
var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), "");
castNode.AddAttribute("to", typeValue.RawType);
}
else if (typeValue == NumberDataViewType.UInt64 || kind == InternalDataKind.U8)
{
node.AddAttribute("values_int64s", _valueMap.GetValues<ulong>().Select(item => Convert.ToInt64(item)));
var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), "");
castNode.AddAttribute("to", typeValue.RawType);
}
else if (typeValue == NumberDataViewType.UInt32 || kind == InternalDataKind.U4)
{
node.AddAttribute("values_int64s", _valueMap.GetValues<uint>().Select(item => Convert.ToInt64(item)));
var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), "");
castNode.AddAttribute("to", typeValue.RawType);
}
else if (typeValue == NumberDataViewType.UInt16)
{
node.AddAttribute("values_int64s", _valueMap.GetValues<ushort>().Select(item => Convert.ToInt64(item)));
var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), "");
castNode.AddAttribute("to", typeValue.RawType);
}
else if (typeValue == NumberDataViewType.Single)
{
node.AddAttribute("values_floats", _valueMap.GetValues<float>());
}
else if (typeValue == NumberDataViewType.Double)
{
node.AddAttribute("values_floats", _valueMap.GetValues<double>().Select(item => Convert.ToSingle(item)));
var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), "");
castNode.AddAttribute("to", typeValue.RawType);
}
else if (typeValue == TextDataViewType.Instance)
{
node.AddAttribute("values_strings", _valueMap.GetValues<ReadOnlyMemory<char>>());
}
else if (typeValue == BooleanDataViewType.Instance)
{
node.AddAttribute("values_floats", _valueMap.GetValues<bool>().Select(item => Convert.ToSingle(item)));
var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), "");
castNode.AddAttribute("to", typeValue.RawType);
}
else
return false;

//Unknown keys should map to 0
node.AddAttribute("default_int64", 0);
node.AddAttribute("default_string", "");
node.AddAttribute("default_float", 0f);
return true;
}

protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
var result = new DataViewSchema.DetachedColumn[_columns.Length];
Expand Down
2 changes: 0 additions & 2 deletions test/Microsoft.ML.TestFramework/BaseTestBaseline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -674,9 +674,7 @@ private static double Round(double value, int digitsOfPrecision)
public void CompareResults(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6, bool isRightColumnOnnxScalar = false)
{
var leftColumn = left.Schema[leftColumnName];
var rightColumn = right.Schema[rightColumnName];
var leftType = leftColumn.Type.GetItemType();
var rightType = rightColumn.Type.GetItemType();

if (leftType == NumberDataViewType.SByte)
CompareSelectedColumns<sbyte>(leftColumnName, rightColumnName, left, right, isRightColumnOnnxScalar: isRightColumnOnnxScalar);
Expand Down
Loading