Skip to content

Commit 6be61dc

Browse files
committed
Address comments
1 parent 4aa9757 commit 6be61dc

File tree

6 files changed

+52
-49
lines changed

6 files changed

+52
-49
lines changed

src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,20 @@ namespace Microsoft.ML
1313
/// </summary>
1414
public sealed class ModelOperationsCatalog
1515
{
16+
/// <summary>
17+
/// This is a best friend because an extension method defined in another assembly needs this field.
18+
/// </summary>
19+
[BestFriend]
1620
internal IHostEnvironment Environment { get; }
1721

1822
public ExplainabilityTransforms Explainability { get; }
1923

20-
public PortabilityTransforms Portability { get; }
21-
2224
internal ModelOperationsCatalog(IHostEnvironment env)
2325
{
2426
Contracts.AssertValue(env);
2527
Environment = env;
2628

2729
Explainability = new ExplainabilityTransforms(this);
28-
Portability = new PortabilityTransforms(this);
2930
}
3031

3132
public abstract class SubCatalogBase
@@ -62,17 +63,6 @@ internal ExplainabilityTransforms(ModelOperationsCatalog owner) : base(owner)
6263
}
6364
}
6465

65-
/// <summary>
66-
/// The catalog of model protability operations. Member function of this classes are able to convert the associated object to a protable format,
67-
/// so that the fitted pipeline can easily be depolyed to other platforms. Currently, the only supported format is ONNX (https://github.com/onnx/onnx).
68-
/// </summary>
69-
public sealed class PortabilityTransforms : SubCatalogBase
70-
{
71-
internal PortabilityTransforms(ModelOperationsCatalog owner) : base(owner)
72-
{
73-
}
74-
}
75-
7666
/// <summary>
7767
/// Create a prediction engine for one-time prediction.
7868
/// </summary>
Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,37 @@
1-
using System.Collections.Generic;
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Generic;
26
using Microsoft.ML.Core.Data;
37
using Microsoft.ML.Data;
48
using Microsoft.ML.Model.Onnx;
59
using Microsoft.ML.UniversalModelFormat.Onnx;
610

711
namespace Microsoft.ML
812
{
9-
public static class ProtabilityCatalog
13+
public static class OnnxExportExtensions
1014
{
1115
/// <summary>
1216
/// Convert the specified <see cref="ITransformer"/> to ONNX format. Note that ONNX uses Google's Protobuf so the returned value is a Protobuf object.
1317
/// </summary>
14-
/// <param name="catalog">A field in <see cref="MLContext"/> which this function associated with.</param>
18+
/// <param name="catalog">The class that <see cref="ConvertToOnnx(ModelOperationsCatalog, ITransformer, IDataView)"/> attached to.</param>
1519
/// <param name="transform">The <see cref="ITransformer"/> that will be converted into ONNX format.</param>
1620
/// <param name="inputData">The input of the specified transform.</param>
17-
/// <returns></returns>
18-
public static ModelProto ConvertToOnnx(this ModelOperationsCatalog.PortabilityTransforms catalog, ITransformer transform, IDataView inputData)
21+
/// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
22+
public static ModelProto ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData)
1923
{
20-
var env = new MLContext(seed: 1);
21-
var ctx = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "com.microsoft", OnnxVersion.Stable);
24+
var env = catalog.Environment;
25+
var ctx = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "ai.onnx.ml", OnnxVersion.Stable);
2226
var outputData = transform.Transform(inputData);
2327
IDataView root = null;
2428
IDataView sink = null;
2529
LinkedList<ITransformCanSaveOnnx> transforms = null;
26-
using (var ch = (env as IChannelProvider).Start("ONNX conversion"))
30+
using (var ch = env.Start("ONNX conversion"))
31+
{
2732
SaveOnnxCommand.GetPipe(ctx, ch, outputData, out root, out sink, out transforms);
28-
29-
return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, root, sink, transforms, null, null);
33+
return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, ch, root, sink, transforms, null, null);
34+
}
3035
}
3136
}
3237
}

src/Microsoft.ML.Onnx/SaveOnnxCommand.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ public override void Run()
116116

117117
internal static void GetPipe(OnnxContextImpl ctx, IChannel ch, IDataView end, out IDataView source, out IDataView trueEnd, out LinkedList<ITransformCanSaveOnnx> transforms)
118118
{
119-
Contracts.AssertValue(end);
119+
ch.AssertValue(end);
120120

121121
source = trueEnd = (end as CompositeDataLoader)?.View ?? end;
122122
IDataTransform transform = source as IDataTransform;
@@ -136,10 +136,10 @@ internal static void GetPipe(OnnxContextImpl ctx, IChannel ch, IDataView end, ou
136136
transform = (source = transform.Source) as IDataTransform;
137137
}
138138

139-
Contracts.AssertValue(source);
139+
ch.AssertValue(source);
140140
}
141141

142-
internal static ModelProto ConvertTransformListToOnnxModel(OnnxContextImpl ctx, IDataView inputData, IDataView outputData,
142+
internal static ModelProto ConvertTransformListToOnnxModel(OnnxContextImpl ctx, IChannel ch, IDataView inputData, IDataView outputData,
143143
LinkedList<ITransformCanSaveOnnx> transforms, HashSet<string> inputColumnNamesToDrop=null, HashSet<string> outputColumnNamesToDrop=null)
144144
{
145145
inputColumnNamesToDrop = inputColumnNamesToDrop ?? new HashSet<string>();
@@ -158,7 +158,10 @@ internal static ModelProto ConvertTransformListToOnnxModel(OnnxContextImpl ctx,
158158

159159
// Create graph nodes, outputs and intermediate values.
160160
foreach (var trans in transforms)
161+
{
162+
ch.Assert(trans.CanSaveOnnx(ctx));
161163
trans.SaveAsOnnx(ctx);
164+
}
162165

163166
// Add graph outputs.
164167
for (int i = 0; i < outputData.Schema.Count; ++i)
@@ -255,7 +258,7 @@ private void Run(IChannel ch)
255258
nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but one was not present.");
256259
}
257260

258-
var model = ConvertTransformListToOnnxModel(ctx, source, end, transforms, _inputsToDrop, _outputsToDrop);
261+
var model = ConvertTransformListToOnnxModel(ctx, ch, source, end, transforms, _inputsToDrop, _outputsToDrop);
259262

260263
using (var file = Host.CreateOutputFile(_outputModelPath))
261264
using (var stream = file.CreateWriteStream())

test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"irVersion": "3",
33
"producerName": "ML.NET",
44
"producerVersion": "##VERSION##",
5-
"domain": "com.microsoft",
5+
"domain": "ai.onnx.ml",
66
"graph": {
77
"node": [
88
{

test/BaselineOutput/Common/Onnx/Regression/Adult/SimplePipeline.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"irVersion": "3",
33
"producerName": "ML.NET",
44
"producerVersion": "##VERSION##",
5-
"domain": "com.microsoft",
5+
"domain": "ai.onnx.ml",
66
"graph": {
77
"node": [
88
{

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.IO;
1+
using System;
2+
using System.IO;
23
using System.Linq;
34
using System.Runtime.InteropServices;
45
using System.Text.RegularExpressions;
@@ -31,7 +32,7 @@ public OnnxConversionTest(ITestOutputHelper output) : base(output)
3132
/// <summary>
3233
/// In this test, we convert a trained <see cref="TransformerChain"/> into ONNX <see cref="UniversalModelFormat.Onnx.ModelProto"/> file and then
3334
/// call <see cref="OnnxScoringEstimator"/> to evaluate that file. The outputs of <see cref="OnnxScoringEstimator"/> are checked against the original
34-
/// ML.NET model's outputs.
35+
/// ML.NET model's outputs.
3536
/// </summary>
3637
[Fact]
3738
public void SimpleEndToEndOnnxConversionTest()
@@ -52,12 +53,12 @@ public void SimpleEndToEndOnnxConversionTest()
5253
var transformedData = model.Transform(data);
5354

5455
// Step 2: Convert ML.NET model to ONNX format and save it as a file.
55-
var onnxModel = mlContext.Model.Portability.ConvertToOnnx(model, data);
56+
var onnxModel = mlContext.Model.ConvertToOnnx(model, data);
5657
var onnxFileName = "model.onnx";
5758
var onnxModelPath = GetOutputPath(onnxFileName);
5859
SaveOnnxModel(onnxModel, onnxModelPath, null);
5960

60-
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
61+
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && Environment.Is64BitProcess)
6162
{
6263
// Step 3: Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
6364
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
@@ -70,7 +71,8 @@ public void SimpleEndToEndOnnxConversionTest()
7071
CompareSelectedR4ScalarColumns("Score", "Score0", transformedData, onnxResult, 2);
7172
}
7273

73-
// Step 5: Check ONNX model's text format.
74+
// Step 5: Check ONNX model's text format. This test will be not necessary if Step 3 and Step 4 can run on Linux and
75+
// Mac to support cross-platform tests.
7476
var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "Regression", "Adult");
7577
var onnxTextName = "SimplePipeline.txt";
7678
var onnxTextPath = GetOutputPath(subDir, onnxTextName);
@@ -86,11 +88,6 @@ private class BreastCancerFeatureVector
8688
public float[] Features;
8789
}
8890

89-
private void CreateDummyExamplesToMakeComplierHappy()
90-
{
91-
var dummyExample = new BreastCancerFeatureVector() { Features = null };
92-
}
93-
9491
[Fact]
9592
public void KmeansOnnxConversionTest()
9693
{
@@ -102,24 +99,24 @@ public void KmeansOnnxConversionTest()
10299
// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
103100
var data = mlContext.Data.ReadFromTextFile<BreastCancerFeatureVector>(dataPath,
104101
hasHeader: true,
105-
separatorChar: '\t' );
102+
separatorChar: '\t');
106103

107104
var pipeline = mlContext.Transforms.Normalize("Features").
108105
Append(mlContext.Clustering.Trainers.KMeans(features: "Features", advancedSettings: settings =>
109106
{
110-
settings.MaxIterations = 1;
111-
settings.K = 4;
112-
settings.NumThreads = 1;
113-
settings.InitAlgorithm = Trainers.KMeans.KMeansPlusPlusTrainer.InitAlgorithm.KMeansPlusPlus;
107+
settings.MaxIterations = 1;
108+
settings.K = 4;
109+
settings.NumThreads = 1;
110+
settings.InitAlgorithm = Trainers.KMeans.KMeansPlusPlusTrainer.InitAlgorithm.KMeansPlusPlus;
114111
}));
115112

116113
var model = pipeline.Fit(data);
117114
var transformedData = model.Transform(data);
118115

119-
var onnxModel = mlContext.Model.Portability.ConvertToOnnx(model, data);
116+
var onnxModel = mlContext.Model.ConvertToOnnx(model, data);
120117

121118
// Compare results produced by ML.NET and ONNX's runtime.
122-
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
119+
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && Environment.Is64BitProcess)
123120
{
124121
var onnxFileName = "model.onnx";
125122
var onnxModelPath = GetOutputPath(onnxFileName);
@@ -134,6 +131,9 @@ public void KmeansOnnxConversionTest()
134131
CompareSelectedR4VectorColumns("Score", "Score0", transformedData, onnxResult, 3);
135132
}
136133

134+
// Check ONNX model's text format. We save the produced ONNX model as a text file and compare it against
135+
// the associated file in ML.NET repo. Such a comparison can be retired if ONNXRuntime ported to ML.NET
136+
// can support Linux and Mac.
137137
var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "Cluster", "BreastCancer");
138138
var onnxTextName = "Kmeans.txt";
139139
var onnxTextPath = GetOutputPath(subDir, onnxTextName);
@@ -142,7 +142,12 @@ public void KmeansOnnxConversionTest()
142142
Done();
143143
}
144144

145-
private void CompareSelectedR4VectorColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision=6)
145+
private void CreateDummyExamplesToMakeComplierHappy()
146+
{
147+
var dummyExample = new BreastCancerFeatureVector() { Features = null };
148+
}
149+
150+
private void CompareSelectedR4VectorColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6)
146151
{
147152
var leftColumnIndex = left.Schema[leftColumnName].Index;
148153
var rightColumnIndex = right.Schema[rightColumnName].Index;
@@ -166,7 +171,7 @@ private void CompareSelectedR4VectorColumns(string leftColumnName, string rightC
166171
}
167172
}
168173

169-
private void CompareSelectedR4ScalarColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision=6)
174+
private void CompareSelectedR4ScalarColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6)
170175
{
171176
var leftColumnIndex = left.Schema[leftColumnName].Index;
172177
var rightColumnIndex = right.Schema[rightColumnName].Index;

0 commit comments

Comments
 (0)