Skip to content

Added onnx export support for OptionalColumnTransform #4454

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -128,50 +128,56 @@ public OnnxNode CreateNode(string opType, string input, string output, string na
/// </summary>
/// <param name="value">The float number which is going to be added</param>
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(float value, string name = null);
public abstract string AddInitializer(float value, string name = null, bool makeUniqueName = true);

/// <summary>
/// Call this function to declare a global long.
/// </summary>
/// <param name="value">The long number which is going to be added into the ONNX graph</param>
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(long value, string name = null);
public abstract string AddInitializer(long value, string name = null, bool makeUniqueName = true);

/// <summary>
/// Call this function to declare a global string.
/// </summary>
/// <param name="value">The string which is going to be added into the ONNX graph</param>
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(string value, string name = null);
public abstract string AddInitializer(string value, string name = null, bool makeUniqueName = true);

/// <summary>
/// Call this function to declare a global float tensor.
/// </summary>
/// <param name="values">The floats which are going to be added into the ONNX graph</param>
/// <param name="dims">The shape of the float tensor.</param>
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(IEnumerable<float> values, IEnumerable<long> dims, string name = null);
public abstract string AddInitializer(IEnumerable<float> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);

/// <summary>
/// Call this function to declare a global long tensor.
/// </summary>
/// <param name="values">The longs which are going to be added into the ONNX graph</param>
/// <param name="dims">The shape of the long tensor.</param>
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(IEnumerable<long> values, IEnumerable<long> dims, string name = null);
public abstract string AddInitializer(IEnumerable<long> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);

/// <summary>
/// Call this function to declare a global string tensor.
/// </summary>
/// <param name="values">The strings which are going to be added into the ONNX graph</param>
/// <param name="dims">The shape of the string tensor.</param>
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null);
public abstract string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);
}
}
29 changes: 15 additions & 14 deletions src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,12 @@ public string TryGetVariableName(string colName)
/// there is a collision between names in the pipeline at any point.
/// </summary>
/// <param name="colName">IDataView column name.</param>
/// <param name="makeUniqueName">Whether a unique name should be chosen for this variable.</param>
/// <returns>Unique variable name.</returns>
public string AddVariable(string colName)
public string AddVariable(string colName, bool makeUniqueName = true)
{
_host.CheckNonEmpty(colName, nameof(colName));
_columnNameMap[colName] = GetUniqueName(colName, _variableNames.Contains);
_columnNameMap[colName] = makeUniqueName ? GetUniqueName(colName, _variableNames.Contains) : colName;
_variableNames.Add(_columnNameMap[colName]);
return _columnNameMap[colName];
}
Expand Down Expand Up @@ -269,56 +270,56 @@ public override List<long> RetrieveShapeOrNull(string variableName)
}

/// Adds constant tensor into the graph.
public override string AddInitializer(float value, string name = null)
public override string AddInitializer(float value, string name = null, bool makeUniqueName = true)
{
name = AddVariable(name ?? "float");
name = AddVariable(name ?? "float", makeUniqueName);
_initializers.Add(OnnxUtils.MakeFloat(name, value));
return name;
}

public override string AddInitializer(string value, string name = null)
public override string AddInitializer(string value, string name = null, bool makeUniqueName = true)
{
name = AddVariable(name ?? "string");
name = AddVariable(name ?? "string", makeUniqueName);
_initializers.Add(OnnxUtils.MakeString(name, value));
return name;
}

public override string AddInitializer(long value, string name = null)
public override string AddInitializer(long value, string name = null, bool makeUniqueName = true)
{
name = AddVariable(name ?? "int64");
name = AddVariable(name ?? "int64", makeUniqueName);
_initializers.Add(OnnxUtils.MakeInt64(name, value));
return name;
}

public override string AddInitializer(IEnumerable<float> values, IEnumerable<long> dims, string name = null)
public override string AddInitializer(IEnumerable<float> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
{
_host.CheckValue(values, nameof(values));
if (dims != null)
_host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size");

name = AddVariable(name ?? "floats");
name = AddVariable(name ?? "floats", makeUniqueName);
_initializers.Add(OnnxUtils.MakeFloats(name, values, dims));
return name;
}

public override string AddInitializer(IEnumerable<long> values, IEnumerable<long> dims, string name = null)
public override string AddInitializer(IEnumerable<long> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
{
_host.CheckValue(values, nameof(values));
if (dims != null)
_host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size");

name = AddVariable(name ?? "int64s");
name = AddVariable(name ?? "int64s", makeUniqueName);
_initializers.Add(OnnxUtils.MakeInt64s(name, values, dims));
return name;
}

public override string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null)
public override string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
{
_host.CheckValue(values, nameof(values));
if (dims != null)
_host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size");

name = AddVariable(name ?? "strings");
name = AddVariable(name ?? "strings", makeUniqueName);
_initializers.Add(OnnxUtils.MakeStrings(name, values, dims));
return name;
}
Expand Down
98 changes: 44 additions & 54 deletions src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ public sealed class OnnxModelInfo
/// </summary>
public List<string> OutputNames { get; }
/// <summary>
/// InitializerNames[i] is the name of the i-th initializer in <see cref="InitializersInfo"/>.
/// </summary>
public List<string> InitializerNames { get; }
/// <summary>
/// Inputs of the containing <see cref="OnnxModel"/>.
/// </summary>
public OnnxVariableInfo[] InputsInfo { get; }
Expand All @@ -46,12 +50,19 @@ public sealed class OnnxModelInfo
/// </summary>
public OnnxVariableInfo[] OutputsInfo { get; }

public OnnxModelInfo(IEnumerable<OnnxVariableInfo> inputsInfo, IEnumerable<OnnxVariableInfo> outputsInfo)
/// <summary>
/// Initializers of the containing <see cref="OnnxModel"/>.
/// </summary>
public OnnxVariableInfo[] InitializersInfo { get; }

public OnnxModelInfo(IEnumerable<OnnxVariableInfo> inputsInfo, IEnumerable<OnnxVariableInfo> outputsInfo, IEnumerable<OnnxVariableInfo> initializersInfo)
{
InputNames = inputsInfo.Select(val => val.Name).ToList();
InputsInfo = inputsInfo.ToArray();
OutputNames = outputsInfo.Select(val => val.Name).ToList();
OutputsInfo = outputsInfo.ToArray();
InitializerNames = initializersInfo.Select(val => val.Name).ToList();
InitializersInfo = initializersInfo.ToArray();
}

/// <summary>
Expand All @@ -60,10 +71,16 @@ public OnnxModelInfo(IEnumerable<OnnxVariableInfo> inputsInfo, IEnumerable<OnnxV
public OnnxVariableInfo GetInput(string name)
{
var index = InputNames.IndexOf(name);
if (index < 0)
throw Contracts.ExceptParamValue(name, nameof(name), $"Input tensor, {name}, does not exist in the ONNX model. " +
$"Available input names are [{string.Join(",", InputNames)}].");
return InputsInfo[index];
if (index >= 0)
return InputsInfo[index];

index = InitializerNames.IndexOf(name);
if (index >= 0)
return InitializersInfo[index];

// If we don't find the name among the inputs, we try to find it among the initializers; reaching this point means both lookups failed.
throw Contracts.ExceptParamValue(name, nameof(name), $"Input tensor, {name}, does not exist in the ONNX model. " +
$"Available input names are [{string.Join(",", InputNames)}]. Available initializers are [{string.Join(",", InitializerNames)}]");
}

/// <summary>
Expand Down Expand Up @@ -180,8 +197,12 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
var inputTypePool = new Dictionary<string, DataViewType>();
foreach (var valueInfo in model.Graph.Input)
inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
var outputTypePool = new Dictionary<string, DataViewType>();

var initializerTypePool = new Dictionary<string, DataViewType>();
foreach (var valueInfo in model.Graph.Initializer)
initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);

var outputTypePool = new Dictionary<string, DataViewType>();
// Build casters which maps NamedOnnxValue to .NET objects.
var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
foreach (var valueInfo in model.Graph.Output)
Expand All @@ -190,60 +211,31 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType);
}

var onnxRuntimeInputInfos = new List<OnnxVariableInfo>();
// Collect input information for this ONNX model from ONNXRuntime's perspective.
foreach (var pair in _session.InputMetadata)
{
var name = pair.Key;
var meta = pair.Value;
var dataViewType = inputTypePool[name];

OnnxVariableInfo info = null;
if (shapeDictionary != null && shapeDictionary.ContainsKey(name))
{
// If user provides a shape of a specific tensor, the provided shape overwrites the corresponding one loaded from
// ONNX model file and the deduced DataViewVectorType.

if (!CheckOnnxShapeCompatibility(shapeDictionary[name].ToList(), meta.Dimensions.ToList()))
throw Contracts.ExceptParamValue(shapeDictionary[name], nameof(shapeDictionary),
"The specified shape " + string.Join(",", shapeDictionary[name]) +
" is not compatible with the shape " + string.Join(",", meta.Dimensions) +
" loaded from the ONNX model file. Only unknown dimension can replace or " +
"be replaced by another dimension.");
var inputInfos = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);

if (dataViewType is VectorDataViewType vectorType)
{
if (shapeDictionary[name].All(value => value > 0))
dataViewType = new VectorDataViewType(vectorType.ItemType, shapeDictionary[name]);
else
dataViewType = new VectorDataViewType(vectorType.ItemType);
}
// Create a view to the used ONNX model from ONNXRuntime's perspective.
ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);
}

info = new OnnxVariableInfo(name, shapeDictionary[name].ToList(), meta.ElementType, dataViewType, null);
}
else
{
// No user-specified shape is found, so the shape loaded from ONNX model file is used.
info = new OnnxVariableInfo(name, meta.Dimensions.ToList(), meta.ElementType, dataViewType, null);
}
onnxRuntimeInputInfos.Add(info);
}
private List<OnnxVariableInfo> GetOnnxVariablesFromMetadata(IReadOnlyDictionary<string, NodeMetadata> nodeMetadata,
IDictionary<string, int[]> shapeDictionary,
Dictionary<string, DataViewType> typePool,
Dictionary<string, Func<NamedOnnxValue, object>> casterPool)
{
var onnxVariableInfos = new List<OnnxVariableInfo>();

var onnxRuntimeOutputInfos = new List<OnnxVariableInfo>();
// Collect output information for this ONNX model from ONNXRuntime's perspective.
foreach (var pair in _session.OutputMetadata)
foreach (var pair in nodeMetadata)
{
var name = pair.Key;
var meta = pair.Value;
var dataViewType = outputTypePool[name];
var caster = casterPool[name];
var dataViewType = typePool[name];
var caster = casterPool?[name];

OnnxVariableInfo info = null;
if (shapeDictionary != null && shapeDictionary.ContainsKey(name))
{
// If the user provides a shape for a specific tensor, the provided shape overrides the corresponding one loaded from
// the ONNX model file.

if (!CheckOnnxShapeCompatibility(shapeDictionary[name].ToList(), meta.Dimensions.ToList()))
throw Contracts.ExceptParamValue(shapeDictionary[name], nameof(shapeDictionary),
"The specified shape " + string.Join(",", shapeDictionary[name]) +
Expand All @@ -267,11 +259,9 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
info = new OnnxVariableInfo(name, meta.Dimensions.ToList(), meta.ElementType, dataViewType, caster);
}

onnxRuntimeOutputInfos.Add(info);
onnxVariableInfos.Add(info);
}

// Create a view to the used ONNX model from ONNXRuntime's perspective.
ModelInfo = new OnnxModelInfo(onnxRuntimeInputInfos, onnxRuntimeOutputInfos);
return onnxVariableInfos;
}

/// <summary>
Expand Down
Loading