From 737ec83a7bc87453186abaa35bc8b2f22449372e Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Fri, 25 Oct 2019 09:51:58 -0700 Subject: [PATCH 1/5] Initial work for adding onnx export support for OptionalColumnTransform --- .../Model/Onnx/OnnxContext.cs | 18 ++++--- .../OnnxContextImpl.cs | 29 +++++----- .../OptionalColumnTransform.cs | 53 ++++++++++++++++++- test/Microsoft.ML.Tests/OnnxConversionTest.cs | 43 +++++++++++++++ 4 files changed, 122 insertions(+), 21 deletions(-) diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs index 42e735e974..3efc0e63e0 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs @@ -128,24 +128,27 @@ public OnnxNode CreateNode(string opType, string input, string output, string na /// /// The float number which is going to be added /// A string used as a seed to generate this initializer's name in the ONNX graph. + /// Whether a unique name should be picked for this initializer. /// The initializer's ONNX name - public abstract string AddInitializer(float value, string name = null); + public abstract string AddInitializer(float value, string name = null, bool makeUniqueName = true); /// /// Call this function can declare a global long /// /// The long number which is going to be added into the ONNX graph /// A string used as a seed to generate this initializer's name in the ONNX graph. + /// Whether a unique name should be picked for this initializer. /// The initializer's ONNX name - public abstract string AddInitializer(long value, string name = null); + public abstract string AddInitializer(long value, string name = null, bool makeUniqueName = true); /// /// Call this function can declare a global string /// /// The string which is going to be added into the ONNX graph /// A string used as a seed to generate this initializer's name in the ONNX graph. + /// Whether a unique name should be picked for this initializer. /// The initializer's ONNX name - public abstract string AddInitializer(string value, string name = null); + public abstract string AddInitializer(string value, string name = null, bool makeUniqueName = true); /// /// Call this function can declare a global float tensor @@ -153,8 +156,9 @@ public OnnxNode CreateNode(string opType, string input, string output, string na /// The floats which are going to be added into the ONNX graph /// The shape that the floats /// A string used as a seed to generate this initializer's name in the ONNX graph. + /// Whether a unique name should be picked for this initializer. /// The initializer's ONNX name - public abstract string AddInitializer(IEnumerable values, IEnumerable dims, string name = null); + public abstract string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true); /// /// Call this function can declare a global long tensor @@ -162,8 +166,9 @@ public OnnxNode CreateNode(string opType, string input, string output, string na /// The longs which are going to be added into the ONNX graph /// The shape that the floats /// A string used as a seed to generate this initializer's name in the ONNX graph. + /// Whether a unique name should be picked for this initializer. /// The initializer's ONNX name - public abstract string AddInitializer(IEnumerable values, IEnumerable dims, string name = null); + public abstract string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true); /// /// Call this function can declare a global string tensor @@ -171,7 +176,8 @@ public OnnxNode CreateNode(string opType, string input, string output, string na /// The strings which are going to be added into the ONNX graph /// The shape that the strings /// A string used as a seed to generate this initializer's name in the ONNX graph. + /// Whether a unique name should be picked for this initializer. /// The initializer's ONNX name - public abstract string AddInitializer(IEnumerable values, IEnumerable dims, string name = null); + public abstract string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true); } } diff --git a/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs b/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs index e14d1c4489..508d419764 100644 --- a/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs +++ b/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs @@ -200,11 +200,12 @@ public string TryGetVariableName(string colName) /// there is a collision between names in the pipeline at any point. /// /// IDataView column name. + /// Whether a unique name should be chosen for this variable. /// Unique variable name. - public string AddVariable(string colName) + public string AddVariable(string colName, bool makeUniqueName = true) { _host.CheckNonEmpty(colName, nameof(colName)); - _columnNameMap[colName] = GetUniqueName(colName, _variableNames.Contains); + _columnNameMap[colName] = makeUniqueName ? GetUniqueName(colName, _variableNames.Contains) : colName; _variableNames.Add(_columnNameMap[colName]); return _columnNameMap[colName]; } @@ -269,56 +270,56 @@ public override List RetrieveShapeOrNull(string variableName) } /// Adds constant tensor into the graph. - public override string AddInitializer(float value, string name = null) + public override string AddInitializer(float value, string name = null, bool makeUniqueName = true) { - name = AddVariable(name ?? "float"); + name = AddVariable(name ?? "float", makeUniqueName); _initializers.Add(OnnxUtils.MakeFloat(name, value)); return name; } - public override string AddInitializer(string value, string name = null) + public override string AddInitializer(string value, string name = null, bool makeUniqueName = true) { - name = AddVariable(name ?? "string"); + name = AddVariable(name ?? "string", makeUniqueName); _initializers.Add(OnnxUtils.MakeString(name, value)); return name; } - public override string AddInitializer(long value, string name = null) + public override string AddInitializer(long value, string name = null, bool makeUniqueName = true) { - name = AddVariable(name ?? "int64"); + name = AddVariable(name ?? "int64", makeUniqueName); _initializers.Add(OnnxUtils.MakeInt64(name, value)); return name; } - public override string AddInitializer(IEnumerable values, IEnumerable dims, string name = null) + public override string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true) { _host.CheckValue(values, nameof(values)); if (dims != null) _host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size"); - name = AddVariable(name ?? "floats"); + name = AddVariable(name ?? "floats", makeUniqueName); _initializers.Add(OnnxUtils.MakeFloats(name, values, dims)); return name; } - public override string AddInitializer(IEnumerable values, IEnumerable dims, string name = null) + public override string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true) { _host.CheckValue(values, nameof(values)); if (dims != null) _host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size"); - name = AddVariable(name ?? "int64s"); + name = AddVariable(name ?? "int64s", makeUniqueName); _initializers.Add(OnnxUtils.MakeInt64s(name, values, dims)); return name; } - public override string AddInitializer(IEnumerable values, IEnumerable dims, string name = null) + public override string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true) { _host.CheckValue(values, nameof(values)); if (dims != null) _host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size"); - name = AddVariable(name ?? "strings"); + name = AddVariable(name ?? "strings", makeUniqueName); _initializers.Add(OnnxUtils.MakeStrings(name, values, dims)); return name; } diff --git a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs index 7c1cec1efb..26e42b9be1 100644 --- a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs +++ b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs @@ -13,6 +13,7 @@ using Microsoft.ML.Data.IO; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model.OnnxConverter; using Microsoft.ML.Runtime; using Microsoft.ML.Transforms; @@ -29,7 +30,7 @@ namespace Microsoft.ML.Transforms { /// [BestFriend] - internal sealed class OptionalColumnTransform : RowToRowMapperTransformBase + internal sealed class OptionalColumnTransform : RowToRowMapperTransformBase, ITransformCanSaveOnnx { public sealed class Arguments : TransformInputBase { @@ -498,6 +499,56 @@ private Delegate MakeGetterVec(int length) } } + public void SaveAsOnnx(OnnxContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + Host.Assert(((ICanSaveOnnx)this).CanSaveOnnx(ctx)); + + for (int iinfo = 0; iinfo < _bindings.ColumnTypes.Length; ++iinfo) + { + var columnType = _bindings.ColumnTypes[iinfo]; + string inputColumnName = Source.Schema[_bindings.SrcCols[iinfo]].Name; + if (!ctx.ContainsColumn(inputColumnName)) + { + ctx.RemoveColumn(inputColumnName, false); + continue; + } + + if (!SaveAsOnnxCore(ctx, iinfo, ctx.GetVariableName(inputColumnName), + ctx.AddIntermediateVariable(OutputSchema[_bindings.MapIinfoToCol(iinfo)].Type, inputColumnName))) + { + ctx.RemoveColumn(inputColumnName, true); + } + } + } + + public bool CanSaveOnnx(OnnxContext ctx) => true; + + private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName) + { + var columnType = _bindings.ColumnTypes[iinfo]; + string inputColumnName = Source.Schema[_bindings.SrcCols[iinfo]].Name; + + Type type = columnType.RawType; + + int size; + if (columnType is VectorDataViewType && columnType.IsKnownSizeVector()) + size = columnType.GetVectorSize(); + else + size = 1; + + if (type == typeof(float)) + ctx.AddInitializer(new float[size], new long[] { 1, size }, inputColumnName, false); + else if (type == typeof(long)) + ctx.AddInitializer(new long[size], new long[] { 1, size }, inputColumnName, false); + else if (type == typeof(string)) + ctx.AddInitializer(new string[size], new long[] { 1, size }, inputColumnName, false); + else + return false; + + return true; + } + [TlcModule.EntryPoint(Desc = Summary, Name = "Transforms.OptionalColumnCreator", UserName = UserName, diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 5e08a64fb2..c0ad80ef00 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -997,6 +997,49 @@ from weighting in weightingCriteria Done(); } + [Fact] + public void OptionalColumnOnnxTest() + { + var mlContext = new MLContext(seed: 1); + + var samples = new List() + { + new BreastCancerCatFeatureExample() { Label = false, F1 = 0.0f, F2 = "F2"}, + new BreastCancerCatFeatureExample() { Label = true, F1 = 0.1f, F2 = "F2"}, + }; + IHostEnvironment env = mlContext as IHostEnvironment; + var dataView = mlContext.Data.LoadFromEnumerable(samples); + var args = new OptionalColumnTransform.Arguments { Columns = new[] { "F1" }, Data = dataView }; + var transform = OptionalColumnTransform.MakeOptional(env, args); + + var ctx = new OnnxContextImpl(mlContext, "model", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable); + var outputData = transform.OutputData; + LinkedList transforms = null; + ModelProto onnxModel; + using (var ch = env.Start("ONNX conversion")) + { + SaveOnnxCommand.GetPipe(ctx, ch, outputData, out IDataView root, out IDataView sink, out transforms); + onnxModel = SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, ch, root, sink, transforms, null, null); + } + + var onnxFileName = "optionalcol.onnx"; + var onnxModelPath = GetOutputPath(onnxFileName); + var onnxTextFileName = "optionalcol.txt"; + var onnxTextPath = GetOutputPath(onnxTextFileName); + + SaveOnnxModel(onnxModel, onnxModelPath, onnxTextPath); + if (IsOnnxRuntimeSupported()) + { + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(dataView); + var onnxResult = onnxTransformer.Transform(dataView); + //CompareSelectedVectorColumns(model.LastTransformer.ColumnPairs[0].outputColumnName, outputNames[1], transformedData, onnxResult); + } + Done(); + } + private void CreateDummyExamplesToMakeComplierHappy() { var dummyExample = new BreastCancerFeatureVector() { Features = null }; From 7c460bfd91c0b0052d7a06bef7b55875fb77f9a5 Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Fri, 25 Oct 2019 09:56:45 -0700 Subject: [PATCH 2/5] Implemented support for optional initializers in OnnxTranformer to support OptionalColumnTransform --- src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs | 98 +++++++++---------- test/Microsoft.ML.Tests/OnnxConversionTest.cs | 2 +- 2 files changed, 45 insertions(+), 55 deletions(-) diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs index 446a48bb7d..ace64905f9 100644 --- a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs +++ b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs @@ -38,6 +38,10 @@ public sealed class OnnxModelInfo /// public List OutputNames { get; } /// + /// Initializers[i] is the name of the i-th initializer in . + /// + public List InitializerNames { get; } + /// /// Inputs of the containing . /// public OnnxVariableInfo[] InputsInfo { get; } @@ -46,12 +50,19 @@ public sealed class OnnxModelInfo /// public OnnxVariableInfo[] OutputsInfo { get; } - public OnnxModelInfo(IEnumerable inputsInfo, IEnumerable outputsInfo) + /// + /// Initializers of the containing + /// + public OnnxVariableInfo[] InitializersInfo { get; } + + public OnnxModelInfo(IEnumerable inputsInfo, IEnumerable outputsInfo, IEnumerable initializersInfo) { InputNames = inputsInfo.Select(val => val.Name).ToList(); InputsInfo = inputsInfo.ToArray(); OutputNames = outputsInfo.Select(val => val.Name).ToList(); OutputsInfo = outputsInfo.ToArray(); + InitializerNames = initializersInfo.Select(val => val.Name).ToList(); + InitializersInfo = initializersInfo.ToArray(); } /// @@ -60,10 +71,16 @@ public OnnxModelInfo(IEnumerable inputsInfo, IEnumerable= 0) + return InputsInfo[index]; + + index = InitializerNames.IndexOf(name); + if (index >= 0) + return InitializersInfo[index]; + + // If we dont find the index in the input, try find it in the initializers + throw Contracts.ExceptParamValue(name, nameof(name), $"Input tensor, {name}, does not exist in the ONNX model. " + + $"Available input names are [{string.Join(",", InputNames)}]. Available initializers are [{string.Join(",", InitializerNames)}]"); } /// @@ -180,8 +197,12 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = var inputTypePool = new Dictionary(); foreach (var valueInfo in model.Graph.Input) inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type); - var outputTypePool = new Dictionary(); + var initializerTypePool = new Dictionary(); + foreach (var valueInfo in model.Graph.Initializer) + initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType); + + var outputTypePool = new Dictionary(); // Build casters which maps NamedOnnxValue to .NET objects. var casterPool = new Dictionary>(); foreach (var valueInfo in model.Graph.Output) @@ -190,60 +211,31 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType); } - var onnxRuntimeInputInfos = new List(); - // Collect input information for this ONNX model from ONNXRuntime's perspective. - foreach (var pair in _session.InputMetadata) - { - var name = pair.Key; - var meta = pair.Value; - var dataViewType = inputTypePool[name]; - - OnnxVariableInfo info = null; - if (shapeDictionary != null && shapeDictionary.ContainsKey(name)) - { - // If user provides a shape of a specific tensor, the provided shape overwrites the corresponding one loaded from - // ONNX model file and the deduced DataViewVectorType. - - if (!CheckOnnxShapeCompatibility(shapeDictionary[name].ToList(), meta.Dimensions.ToList())) - throw Contracts.ExceptParamValue(shapeDictionary[name], nameof(shapeDictionary), - "The specified shape " + string.Join(",", shapeDictionary[name]) + - " is not compatible with the shape " + string.Join(",", meta.Dimensions) + - " loaded from the ONNX model file. Only unknown dimension can replace or " + - "be replaced by another dimension."); + var inputInfos = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null); + var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool); + var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null); - if (dataViewType is VectorDataViewType vectorType) - { - if (shapeDictionary[name].All(value => value > 0)) - dataViewType = new VectorDataViewType(vectorType.ItemType, shapeDictionary[name]); - else - dataViewType = new VectorDataViewType(vectorType.ItemType); - } + // Create a view to the used ONNX model from ONNXRuntime's perspective. + ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers); + } - info = new OnnxVariableInfo(name, shapeDictionary[name].ToList(), meta.ElementType, dataViewType, null); - } - else - { - // No user-specified shape is found, so the shape loaded from ONNX model file is used. - info = new OnnxVariableInfo(name, meta.Dimensions.ToList(), meta.ElementType, dataViewType, null); - } - onnxRuntimeInputInfos.Add(info); - } + private List GetOnnxVariablesFromMetadata(IReadOnlyDictionary nodeMetadata, + IDictionary shapeDictionary, + Dictionary typePool, + Dictionary> casterPool) + { + var onnxVariableInfos = new List(); - var onnxRuntimeOutputInfos = new List(); - // Collect output information for this ONNX model from ONNXRuntime's perspective. - foreach (var pair in _session.OutputMetadata) + foreach (var pair in nodeMetadata) { var name = pair.Key; var meta = pair.Value; - var dataViewType = outputTypePool[name]; - var caster = casterPool[name]; + var dataViewType = typePool[name]; + var caster = casterPool?[name]; OnnxVariableInfo info = null; if (shapeDictionary != null && shapeDictionary.ContainsKey(name)) { - // If user provide a shape of a specific tensor, the provided shape overwrites the corresponding one loaded from - // ONNX model file. - if (!CheckOnnxShapeCompatibility(shapeDictionary[name].ToList(), meta.Dimensions.ToList())) throw Contracts.ExceptParamValue(shapeDictionary[name], nameof(shapeDictionary), "The specified shape " + string.Join(",", shapeDictionary[name]) + @@ -267,11 +259,9 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = info = new OnnxVariableInfo(name, meta.Dimensions.ToList(), meta.ElementType, dataViewType, caster); } - onnxRuntimeOutputInfos.Add(info); + onnxVariableInfos.Add(info); } - - // Create a view to the used ONNX model from ONNXRuntime's perspective. - ModelInfo = new OnnxModelInfo(onnxRuntimeInputInfos, onnxRuntimeOutputInfos); + return onnxVariableInfos; } /// diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index c0ad80ef00..1ef88c3889 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -1035,7 +1035,7 @@ public void OptionalColumnOnnxTest() var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(dataView); var onnxResult = onnxTransformer.Transform(dataView); - //CompareSelectedVectorColumns(model.LastTransformer.ColumnPairs[0].outputColumnName, outputNames[1], transformedData, onnxResult); + CompareSelectedR4ScalarColumns(transform.Model.OutputSchema[2].Name, outputNames[1], outputData, onnxResult); } Done(); } From 8204cde8ca828cb58214f81b65f09f47f094034e Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Fri, 25 Oct 2019 09:58:23 -0700 Subject: [PATCH 3/5] Fixed handling of double values and non-long numeric types --- src/Microsoft.ML.Transforms/OptionalColumnTransform.cs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs index 26e42b9be1..40156bcc0c 100644 --- a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs +++ b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs @@ -537,9 +537,13 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, else size = 1; - if (type == typeof(float)) + // REVIEW: + // AddInitializer only supports long, float and string. + // Is it correct to cast double to float and ulong to long? + if ((type == typeof(float)) || (type == typeof(double))) ctx.AddInitializer(new float[size], new long[] { 1, size }, inputColumnName, false); - else if (type == typeof(long)) + else if ((type == typeof(long)) || (type == typeof(int)) || (type == typeof(short)) || (type == typeof(sbyte)) || + (type == typeof(ulong)) || (type == typeof(uint)) || (type == typeof(ushort)) || (type == typeof(byte))) ctx.AddInitializer(new long[size], new long[] { 1, size }, inputColumnName, false); else if (type == typeof(string)) ctx.AddInitializer(new string[size], new long[] { 1, size }, inputColumnName, false); From 8086f24316bc2b4e4f9241d3818ef7cac9f00c54 Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Thu, 7 Nov 2019 09:34:14 -0800 Subject: [PATCH 4/5] Removed redundant line --- src/Microsoft.ML.Transforms/OptionalColumnTransform.cs | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs index 40156bcc0c..f689fba2fb 100644 --- a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs +++ b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs @@ -509,10 +509,7 @@ public void SaveAsOnnx(OnnxContext ctx) var columnType = _bindings.ColumnTypes[iinfo]; string inputColumnName = Source.Schema[_bindings.SrcCols[iinfo]].Name; if (!ctx.ContainsColumn(inputColumnName)) - { - ctx.RemoveColumn(inputColumnName, false); continue; - } if (!SaveAsOnnxCore(ctx, iinfo, ctx.GetVariableName(inputColumnName), ctx.AddIntermediateVariable(OutputSchema[_bindings.MapIinfoToCol(iinfo)].Type, inputColumnName))) From 3f105849e79c5474478ebf418d86e842824020fb Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Wed, 13 Nov 2019 14:41:26 -0800 Subject: [PATCH 5/5] Updated review comment --- src/Microsoft.ML.Transforms/OptionalColumnTransform.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs index f689fba2fb..a69664df30 100644 --- a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs +++ b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs @@ -536,7 +536,8 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, // REVIEW: // AddInitializer only supports long, float and string. - // Is it correct to cast double to float and ulong to long? + // Here we are casting double to float and ulong to long. + // Fixing this would involve adding additional functions to OnnxContext. if ((type == typeof(float)) || (type == typeof(double))) ctx.AddInitializer(new float[size], new long[] { 1, size }, inputColumnName, false); else if ((type == typeof(long)) || (type == typeof(int)) || (type == typeof(short)) || (type == typeof(sbyte)) ||