-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Onnx Export for ValueMapping estimator #5577
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a5a7a2a
0eb8e7a
3d3f34c
7d1b86d
ba2b727
2416863
0be5f42
af7f609
e4ea4fe
13fd3c6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
using Microsoft.ML.Data; | ||
using Microsoft.ML.Data.IO; | ||
using Microsoft.ML.Internal.Utilities; | ||
using Microsoft.ML.Model.OnnxConverter; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Transforms; | ||
|
||
|
@@ -818,6 +819,8 @@ private static ValueMap CreateValueMapInvoke<TKey, TValue>(DataViewSchema.Column | |
public abstract Delegate GetGetter(DataViewRow input, int index); | ||
|
||
public abstract IDataView GetDataView(IHostEnvironment env); | ||
public abstract TKey[] GetKeys<TKey>(); | ||
public abstract TValue[] GetValues<TValue>(); | ||
} | ||
|
||
/// <summary> | ||
|
@@ -962,6 +965,16 @@ private static TValue GetVector<T>(TValue value) | |
} | ||
|
||
/// <summary>
/// Identity conversion: hands back <paramref name="value"/> unchanged.
/// </summary>
/// <typeparam name="T">Not used by the body.</typeparam>
/// <param name="value">The mapped value.</param>
/// <returns>The same <paramref name="value"/> that was passed in.</returns>
private static TValue GetValue<T>(TValue value)
{
    return value;
}
|
||
/// <summary>
/// Returns the keys of the underlying mapping as an array of <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">Element type each key is cast to.</typeparam>
/// <returns>A new array holding every key of <c>_mapping</c>, cast to <typeparamref name="T"/>.</returns>
public override T[] GetKeys<T>() => _mapping.Keys.Cast<T>().ToArray();
/// <summary>
/// Returns the values of the underlying mapping as an array of <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">Element type each value is cast to.</typeparam>
/// <returns>A new array holding every value of <c>_mapping</c>, cast to <typeparamref name="T"/>.</returns>
public override T[] GetValues<T>() => _mapping.Values.Cast<T>().ToArray();
|
||
} | ||
|
||
/// <summary> | ||
|
@@ -1012,12 +1025,13 @@ private protected override IRowMapper MakeRowMapper(DataViewSchema schema) | |
return new Mapper(this, schema, _valueMap, ColumnPairs); | ||
} | ||
|
||
private sealed class Mapper : OneToOneMapperBase | ||
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx | ||
{ | ||
private readonly DataViewSchema _inputSchema; | ||
private readonly ValueMap _valueMap; | ||
private readonly (string outputColumnName, string inputColumnName)[] _columns; | ||
private readonly ValueMappingTransformer _parent; | ||
/// <summary>
/// Reports whether this mapper can be exported to ONNX. Always answers <c>true</c>;
/// <paramref name="ctx"/> is not inspected.
/// </summary>
/// <param name="ctx">The ONNX export context (unused).</param>
/// <returns><c>true</c>, unconditionally.</returns>
public bool CanSaveOnnx(OnnxContext ctx)
{
    return true;
}
|
||
internal Mapper(ValueMappingTransformer transform, | ||
DataViewSchema inputSchema, | ||
|
@@ -1040,6 +1054,227 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b | |
return _valueMap.GetGetter(input, ColMapNewToOld[iinfo]); | ||
} | ||
|
||
/// <summary>
/// Exports every (output, input) column pair of the parent transformer to ONNX.
/// For each pair an intermediate output variable is registered and
/// <c>SaveAsOnnxCore</c> is asked to emit the nodes; when it reports failure the
/// column is removed from the context.
/// </summary>
/// <param name="ctx">The ONNX export context. Must not be null.</param>
public void SaveAsOnnx(OnnxContext ctx)
{
    // Validate ctx before the first member access on it, so a null context is
    // reported by the host's argument check rather than a NullReferenceException.
    Host.CheckValue(ctx, nameof(ctx));

    const int minimumOpSetVersion = 9;
    ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

    foreach (var columnPair in _parent.ColumnPairs)
    {
        string inputColumnName = columnPair.inputColumnName;
        string outputColumnName = columnPair.outputColumnName;

        if (!_inputSchema.TryGetColumnIndex(inputColumnName, out int colSrc))
            throw Host.ExceptSchemaMismatch(nameof(_inputSchema), "input", inputColumnName);

        // A vector input yields a vector of mapped values with the same dimensions;
        // any other input yields the plain value column type.
        DataViewType colType;
        if (_inputSchema[colSrc].Type is VectorDataViewType vectorType)
            colType = new VectorDataViewType((PrimitiveDataViewType)_parent.ValueColumnType, vectorType.Dimensions);
        else
            colType = _parent.ValueColumnType;

        // The output variable is registered before the ContainsColumn check, so it
        // exists in the context even when the input column is absent.
        string dstVariableName = ctx.AddIntermediateVariable(colType, outputColumnName);
        if (!ctx.ContainsColumn(inputColumnName))
            continue;

        // NOTE(review): on failure the *input* column is removed, not the output
        // column that was just registered -- confirm this is intended.
        if (!SaveAsOnnxCore(ctx, ctx.GetVariableName(inputColumnName), dstVariableName))
            ctx.RemoveColumn(inputColumnName, true);
    }
}
|
||
/// <summary>
/// Emits a <c>Cast</c> node that converts the source variable to <paramref name="itemType"/>,
/// then creates the <paramref name="opType"/> node that consumes the cast result, and
/// attaches the mapping keys to that node in the representation matching <paramref name="itemType"/>.
/// </summary>
/// <typeparam name="T">Raw type the keys are read as from the value map.</typeparam>
/// <param name="ctx">ONNX export context that receives the nodes and the intermediate variable.</param>
/// <param name="node">Receives the created <paramref name="opType"/> node so the caller can add value attributes.</param>
/// <param name="srcVariableName">ONNX variable holding the uncast input.</param>
/// <param name="opType">Operator name for the node consuming the cast output.</param>
/// <param name="labelEncoderOutput">ONNX variable the created node writes to.</param>
/// <param name="itemType">Type the input is cast to. Keys are attached only for text, single and int64.</param>
private void CastInputTo<T>(OnnxContext ctx, out OnnxNode node, string srcVariableName, string opType, string labelEncoderOutput, PrimitiveDataViewType itemType)
{
    // NOTE(review): RetrieveShapeOrNull may presumably return null; srcShape[1] would then throw -- confirm callers guarantee a shape.
    var srcShape = ctx.RetrieveShapeOrNull(srcVariableName);
    var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(itemType, (int)srcShape[1]), "castOutput");
    var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), "");
    castNode.AddAttribute("to", itemType.RawType);

    node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType));

    if (itemType == TextDataViewType.Instance)
    {
        // Format keys with the invariant culture: the exported model must not depend on
        // the exporting machine's locale (e.g. "1.5" must never be written as "1,5").
        var invariant = System.Globalization.CultureInfo.InvariantCulture;
        node.AddAttribute("keys_strings", Array.ConvertAll(_valueMap.GetKeys<T>(), item => Convert.ToString(item, invariant)));
    }
    else if (itemType == NumberDataViewType.Single)
    {
        node.AddAttribute("keys_floats", Array.ConvertAll(_valueMap.GetKeys<T>(), item => Convert.ToSingle(item)));
    }
    else if (itemType == NumberDataViewType.Int64)
    {
        node.AddAttribute("keys_int64s", Array.ConvertAll(_valueMap.GetKeys<T>(), item => Convert.ToInt64(item)));
    }
    // Any other itemType leaves the node without a keys attribute.
}
|
||
/// <summary>
/// Emits the ONNX <c>LabelEncoder</c> node for this value map, together with the
/// <c>Cast</c> nodes needed on its input and/or output, reading from
/// <paramref name="srcVariableName"/> and writing to <paramref name="dstVariableName"/>.
/// </summary>
/// <param name="ctx">ONNX export context that receives the nodes and intermediate variables.</param>
/// <param name="srcVariableName">ONNX variable holding the input (key) column.</param>
/// <param name="dstVariableName">ONNX variable that receives the mapped values.</param>
/// <returns>
/// <c>true</c> when the nodes were emitted; <c>false</c> when the key type, the value type,
/// or their combination (text to text) cannot be exported. Nothing is added to
/// <paramref name="ctx"/> when <c>false</c> is returned.
/// </returns>
private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName)
{
    const int minimumOpSetVersion = 9;
    ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

    const string opType = "LabelEncoder";
    var typeValue = _valueMap.ValueColumn.Type;
    var typeKey = _valueMap.KeyColumn.Type;
    var kind = typeValue.GetRawKind();

    // Decide up front whether the mapping can be exported at all. Checking only after the
    // Cast/LabelEncoder nodes have been created would leave orphaned nodes and variables
    // in the context when the method then reports failure.
    bool keySupported =
        typeKey == NumberDataViewType.Int64 || typeKey == NumberDataViewType.Int32 || typeKey == NumberDataViewType.Int16 ||
        typeKey == NumberDataViewType.UInt64 || typeKey == NumberDataViewType.UInt32 || typeKey == NumberDataViewType.UInt16 ||
        typeKey == NumberDataViewType.Single || typeKey == NumberDataViewType.Double ||
        typeKey == TextDataViewType.Instance || typeKey == BooleanDataViewType.Instance;
    bool valueSupported =
        typeValue == NumberDataViewType.Int64 || typeValue == NumberDataViewType.Int32 || typeValue == NumberDataViewType.Int16 ||
        typeValue == NumberDataViewType.UInt64 || kind == InternalDataKind.U8 ||
        typeValue == NumberDataViewType.UInt32 || kind == InternalDataKind.U4 ||
        typeValue == NumberDataViewType.UInt16 ||
        typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double ||
        typeValue == TextDataViewType.Instance || typeValue == BooleanDataViewType.Instance;
    // String -> String mappings can't be supported.
    bool textToText = typeKey == TextDataViewType.Instance && typeValue == TextDataViewType.Instance;
    if (!keySupported || !valueSupported || textToText)
        return false;

    var srcShape = ctx.RetrieveShapeOrNull(srcVariableName);

    // Single, text and int64 values are written by LabelEncoder straight into the destination.
    // Double and boolean values are emitted as floats, everything else as int64s, into an
    // intermediate variable that is cast to the real value type afterwards.
    string labelEncoderOutput;
    if (typeValue == NumberDataViewType.Single || typeValue == TextDataViewType.Instance || typeValue == NumberDataViewType.Int64)
        labelEncoderOutput = dstVariableName;
    else if (typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance)
        labelEncoderOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, (int)srcShape[1]), "LabelEncoderOutput");
    else
        labelEncoderOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, (int)srcShape[1]), "LabelEncoderOutput");

    // The LabelEncoder operator doesn't support mappings between the same type and only supports mappings between int64s, floats, and strings.
    // As a result, we need to cast most inputs and outputs. In order to avoid as many unsupported mappings, we cast keys that are of NumberDataViewType
    // to strings and values of NumberDataViewType to int64s.
    //
    // Review note carried over from the pull-request discussion: "I'm left wondering if it's possible to
    // refactor this big if-else statement into something that's reusable, but as discussed offline it might
    // not be possible as each case handles things in a particular way and you might end up re-writing the
    // if-else blocks anyway, and not saving many lines of code."

    // Int32/Int16/UInt64/UInt32/UInt16 keys are always cast: to int64 when the values are text
    // (to avoid a string -> string mapping), otherwise to strings.
    PrimitiveDataViewType smallIntegerKeyCast = typeValue is TextDataViewType
        ? (PrimitiveDataViewType)NumberDataViewType.Int64
        : TextDataViewType.Instance;
    bool valueIsFloatBacked = typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance;

    OnnxNode node;
    if (typeKey == NumberDataViewType.Int64)
    {
        // To avoid a int64 -> int64 mapping, we cast keys to strings
        // NOTE(review): a value type that is not a NumberDataViewType but is emitted as int64s
        // (raw kind U4/U8) takes the uncast branch below -- confirm that combination is intended.
        if (typeValue is NumberDataViewType)
        {
            CastInputTo<long>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
        }
        else
        {
            node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType));
            node.AddAttribute("keys_int64s", _valueMap.GetKeys<long>());
        }
    }
    else if (typeKey == NumberDataViewType.Int32)
        CastInputTo<int>(ctx, out node, srcVariableName, opType, labelEncoderOutput, smallIntegerKeyCast);
    else if (typeKey == NumberDataViewType.Int16)
        CastInputTo<short>(ctx, out node, srcVariableName, opType, labelEncoderOutput, smallIntegerKeyCast);
    else if (typeKey == NumberDataViewType.UInt64)
        CastInputTo<ulong>(ctx, out node, srcVariableName, opType, labelEncoderOutput, smallIntegerKeyCast);
    else if (typeKey == NumberDataViewType.UInt32)
        CastInputTo<uint>(ctx, out node, srcVariableName, opType, labelEncoderOutput, smallIntegerKeyCast);
    else if (typeKey == NumberDataViewType.UInt16)
        CastInputTo<ushort>(ctx, out node, srcVariableName, opType, labelEncoderOutput, smallIntegerKeyCast);
    else if (typeKey == NumberDataViewType.Single)
    {
        if (valueIsFloatBacked)
        {
            CastInputTo<float>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
        }
        else
        {
            node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType));
            node.AddAttribute("keys_floats", _valueMap.GetKeys<float>());
        }
    }
    else if (typeKey == NumberDataViewType.Double)
    {
        if (valueIsFloatBacked)
            CastInputTo<double>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
        else
            CastInputTo<double>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Single);
    }
    else if (typeKey == TextDataViewType.Instance)
    {
        node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType));
        node.AddAttribute("keys_strings", _valueMap.GetKeys<ReadOnlyMemory<char>>());
    }
    else if (typeKey == BooleanDataViewType.Instance)
    {
        if (valueIsFloatBacked)
        {
            // Boolean keys are cast to strings and matched against "1"/"0"
            // (each key converted to a byte, then to its string form).
            var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(TextDataViewType.Instance, (int)srcShape[1]), "castOutput");
            var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), "");
            var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.String).ToType();
            castNode.AddAttribute("to", t);
            node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType));
            var values = Array.ConvertAll(_valueMap.GetKeys<bool>(), item => Convert.ToString(Convert.ToByte(item)));
            node.AddAttribute("keys_strings", values);
        }
        else
            CastInputTo<bool>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Single);
    }
    else
        return false; // Not reachable after the support check above; kept so 'node' is definitely assigned.

    // Appends the Cast that converts the LabelEncoder's intermediate output to the real value type.
    void CastEncoderOutputToValueType()
    {
        var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), "");
        castNode.AddAttribute("to", typeValue.RawType);
    }

    if (typeValue == NumberDataViewType.Int64)
    {
        node.AddAttribute("values_int64s", _valueMap.GetValues<long>());
    }
    else if (typeValue == NumberDataViewType.Int32)
    {
        node.AddAttribute("values_int64s", _valueMap.GetValues<int>().Select(item => Convert.ToInt64(item)));
        CastEncoderOutputToValueType();
    }
    else if (typeValue == NumberDataViewType.Int16)
    {
        node.AddAttribute("values_int64s", _valueMap.GetValues<short>().Select(item => Convert.ToInt64(item)));
        CastEncoderOutputToValueType();
    }
    else if (typeValue == NumberDataViewType.UInt64 || kind == InternalDataKind.U8)
    {
        node.AddAttribute("values_int64s", _valueMap.GetValues<ulong>().Select(item => Convert.ToInt64(item)));
        CastEncoderOutputToValueType();
    }
    else if (typeValue == NumberDataViewType.UInt32 || kind == InternalDataKind.U4)
    {
        node.AddAttribute("values_int64s", _valueMap.GetValues<uint>().Select(item => Convert.ToInt64(item)));
        CastEncoderOutputToValueType();
    }
    else if (typeValue == NumberDataViewType.UInt16)
    {
        node.AddAttribute("values_int64s", _valueMap.GetValues<ushort>().Select(item => Convert.ToInt64(item)));
        CastEncoderOutputToValueType();
    }
    else if (typeValue == NumberDataViewType.Single)
    {
        node.AddAttribute("values_floats", _valueMap.GetValues<float>());
    }
    else if (typeValue == NumberDataViewType.Double)
    {
        node.AddAttribute("values_floats", _valueMap.GetValues<double>().Select(item => Convert.ToSingle(item)));
        CastEncoderOutputToValueType();
    }
    else if (typeValue == TextDataViewType.Instance)
    {
        node.AddAttribute("values_strings", _valueMap.GetValues<ReadOnlyMemory<char>>());
    }
    else if (typeValue == BooleanDataViewType.Instance)
    {
        node.AddAttribute("values_floats", _valueMap.GetValues<bool>().Select(item => Convert.ToSingle(item)));
        CastEncoderOutputToValueType();
    }
    else
        return false; // Not reachable after the support check above.

    //Unknown keys should map to 0
    node.AddAttribute("default_int64", 0);
    node.AddAttribute("default_string", "");
    node.AddAttribute("default_float", 0f);
    return true;
}
|
||
protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() | ||
{ | ||
var result = new DataViewSchema.DetachedColumn[_columns.Length]; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do these need to be public? Seems they're only used by Onnx conversion, so might prefer to make them private.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The ValueMap class itself is only used by the ValueMapping estimator, but since the map it uses is not populated until runtime, I keep the methods abstract. I keep the methods public so they can be used by the Mapper class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, ok. Well, as long as the GetKeys new method isn't accessible to end users, it's all right.