Skip to content

Export to ONNX and cross-platform command-line tool to script ML.NET training and inference #248

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Jun 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
1b78603
Export to ONNX and Maml cross-platform executable
codemzs May 28, 2018
1508cd4
misc.
codemzs May 29, 2018
c6763f3
PR feedback.
codemzs May 30, 2018
b04ef49
PR feedback.
codemzs May 30, 2018
eaa8e3a
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs May 30, 2018
de02e1e
resolve merge issues.
codemzs May 30, 2018
6f4434e
cleanup.
codemzs May 30, 2018
6c97416
cleanup.
codemzs May 30, 2018
534fcd1
PR feedback.
codemzs May 30, 2018
caebed0
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs May 30, 2018
692b526
update test baselines.
codemzs May 30, 2018
6c92033
cleanup.
codemzs May 30, 2018
e466157
cleanup.
codemzs May 30, 2018
5a04795
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs May 31, 2018
8ef9b3f
PR feedback.
codemzs May 31, 2018
64cdb80
PR feedback.
codemzs May 31, 2018
3276dd3
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Jun 1, 2018
c6bc1c6
PR feedback.
codemzs Jun 1, 2018
17a738a
update baselines and regenerate csharp APIs.
codemzs Jun 1, 2018
5090d24
Add link to the commit in ONNX MD file.
codemzs Jun 1, 2018
4cfed38
PR feedback.
codemzs Jun 1, 2018
0a13ad7
cleanup.
codemzs Jun 1, 2018
5bea824
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Jun 4, 2018
2ab729f
Add missing attributes to ONNX model.
codemzs Jun 4, 2018
faf528c
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Jun 5, 2018
3dfd81f
cleanup.
codemzs Jun 5, 2018
1865825
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Jun 6, 2018
3562331
add more commands.
codemzs Jun 6, 2018
d20f1a4
cleanup.
codemzs Jun 6, 2018
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions Microsoft.ML.sln
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ MinimumVisualStudioVersion = 10.0.40219.1
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Core", "src\Microsoft.ML.Core\Microsoft.ML.Core.csproj", "{A6CA6CC6-5D7C-4D7F-A0F5-35E14B383B0A}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{09EADF06-BE25-4228-AB53-95AE3E15B530}"
ProjectSection(SolutionItems) = preProject
src\Microsoft.ML.Commands\Microsoft.ML.Commands.csproj = src\Microsoft.ML.Commands\Microsoft.ML.Commands.csproj
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "test", "test", "{AED9C836-31E3-4F3F-8ABC-929555D3F3C4}"
EndProject
Expand All @@ -30,8 +33,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.KMeansClusteri
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.PCA", "src\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj", "{58E06735-1129-4DD5-86E0-6BBFF049AAD9}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Maml", "src\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj", "{D956E291-F6E5-4474-9023-91793F45ABEB}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Api", "src\Microsoft.ML.Api\Microsoft.ML.Api.csproj", "{2F636A2C-062C-49F4-85F3-60DCADAB6A43}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Tests", "test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj", "{64BC22D3-1E76-41EF-94D8-C79E471FF2DD}"
Expand Down Expand Up @@ -104,6 +105,10 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.Parquet", "Mic
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Benchmarks", "test\Microsoft.ML.Benchmarks\Microsoft.ML.Benchmarks.csproj", "{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Maml", "src\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj", "{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Console", "src\Microsoft.ML.Console\Microsoft.ML.Console.csproj", "{362A98CF-FBF7-4EBB-A11B-990BBF845B15}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -158,10 +163,6 @@ Global
{58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Debug|Any CPU.Build.0 = Debug|Any CPU
{58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Release|Any CPU.ActiveCfg = Release|Any CPU
{58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Release|Any CPU.Build.0 = Release|Any CPU
{D956E291-F6E5-4474-9023-91793F45ABEB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{D956E291-F6E5-4474-9023-91793F45ABEB}.Debug|Any CPU.Build.0 = Debug|Any CPU
{D956E291-F6E5-4474-9023-91793F45ABEB}.Release|Any CPU.ActiveCfg = Release|Any CPU
{D956E291-F6E5-4474-9023-91793F45ABEB}.Release|Any CPU.Build.0 = Release|Any CPU
{2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug|Any CPU.Build.0 = Debug|Any CPU
{2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Release|Any CPU.ActiveCfg = Release|Any CPU
Expand Down Expand Up @@ -202,6 +203,14 @@ Global
{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Debug|Any CPU.Build.0 = Debug|Any CPU
{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Release|Any CPU.ActiveCfg = Release|Any CPU
{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Release|Any CPU.Build.0 = Release|Any CPU
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Debug|Any CPU.Build.0 = Debug|Any CPU
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Release|Any CPU.ActiveCfg = Release|Any CPU
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Release|Any CPU.Build.0 = Release|Any CPU
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Debug|Any CPU.Build.0 = Debug|Any CPU
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Release|Any CPU.ActiveCfg = Release|Any CPU
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand All @@ -219,7 +228,6 @@ Global
{7288C084-11C0-43BE-AC7F-45DCFEAEEBF6} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{F1CAE3AB-4F86-4BC0-BBA8-C4A58E7E8A4A} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{58E06735-1129-4DD5-86E0-6BBFF049AAD9} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{D956E291-F6E5-4474-9023-91793F45ABEB} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{2F636A2C-062C-49F4-85F3-60DCADAB6A43} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{64BC22D3-1E76-41EF-94D8-C79E471FF2DD} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{FDA2FD2C-A708-43AC-A941-4D941B0853BF} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
Expand All @@ -236,6 +244,8 @@ Global
{DEC8F776-49F7-4D87-836C-FE4DC057D08C} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
{6C95FC87-F5F2-4EEF-BB97-567F2F5DD141} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{362A98CF-FBF7-4EBB-A11B-990BBF845B15} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
Expand Down
11 changes: 11 additions & 0 deletions src/Microsoft.ML.Console/Console.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace Microsoft.ML.Runtime.Tools.Console
{
public static class Console
{
public static int Main(string[] args) => Maml.Main(args);
}
}
20 changes: 20 additions & 0 deletions src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DefineConstants>CORECLR</DefineConstants>
<IncludeInPackage>Microsoft.ML</IncludeInPackage>
<TargetFramework>netcoreapp2.0</TargetFramework>
<OutputType>Exe</OutputType>
<AssemblyName>MML</AssemblyName>
<StartupObject>Microsoft.ML.Runtime.Tools.Console.Console</StartupObject>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
<ProjectReference Include="..\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj" />
<ProjectReference Include="..\Microsoft.ML.PipelineInference\Microsoft.ML.PipelineInference.csproj" />
</ItemGroup>

</Project>
18 changes: 9 additions & 9 deletions src/Microsoft.ML.Data/Commands/DataCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,38 +20,38 @@ public static class DataCommand
{
public abstract class ArgumentsBase
{
[Argument(ArgumentType.Multiple, HelpText = "The data loader", ShortName = "loader", SortOrder = 1, NullName = "<Auto>")]
[Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "The data loader", ShortName = "loader", SortOrder = 1, NullName = "<Auto>")]
public SubComponent<IDataLoader, SignatureDataLoader> Loader;

[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The data file", ShortName = "data", SortOrder = 0)]
public string DataFile;

[Argument(ArgumentType.AtMostOnce, HelpText = "Model file to save", ShortName = "out")]
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Model file to save", ShortName = "out")]
public string OutputModelFile;

[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "Model file to load", ShortName = "in", SortOrder = 90)]
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, IsInputFileName = true, HelpText = "Model file to load", ShortName = "in", SortOrder = 90)]
public string InputModelFile;

[Argument(ArgumentType.Multiple, HelpText = "Load transforms from model file?", ShortName = "loadTrans", SortOrder = 91)]
[Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Load transforms from model file?", ShortName = "loadTrans", SortOrder = 91)]
public bool? LoadTransforms;

[Argument(ArgumentType.AtMostOnce, HelpText = "Random seed", ShortName = "seed", SortOrder = 101)]
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Random seed", ShortName = "seed", SortOrder = 101)]
public int? RandomSeed;

[Argument(ArgumentType.AtMostOnce, HelpText = "Verbose?", ShortName = "v", Hide = true)]
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Verbose?", ShortName = "v", Hide = true)]
public bool? Verbose;

[Argument(ArgumentType.AtMostOnce, HelpText = "The web server to publish the RESTful API", Hide = true)]
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "The web server to publish the RESTful API", Hide = true)]
public ServerChannel.IServerFactory Server;

// This is actually an advisory value. The implementations themselves are responsible for
// determining what they consider appropriate, and the actual heuristics is a bit more
// complex than just this.
[Argument(ArgumentType.LastOccurenceWins,
[Argument(ArgumentType.LastOccurenceWins, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly,
HelpText = "Desired degree of parallelism in the data pipeline", ShortName = "n")]
public int? Parallel;

[Argument(ArgumentType.Multiple, HelpText = "Transform", ShortName = "xf")]
[Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Transform", ShortName = "xf")]
public KeyValuePair<string, SubComponent<IDataTransform, SignatureDataTransform>>[] Transform;
}

Expand Down
11 changes: 9 additions & 2 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@ public sealed class OnnxContext
private readonly HashSet<string> _variableMap;
private readonly HashSet<string> _nodeNames;
private readonly string _name;
private readonly string _producerName;
private readonly IHost _host;
private readonly string _domain;
private readonly string _producerVersion;
private readonly long _modelVersion;

public OnnxContext(IHostEnvironment env, string name, string domain)
public OnnxContext(IHostEnvironment env, string name, string producerName,
string producerVersion, long modelVersion, string domain)
{
Contracts.CheckValue(env, nameof(env));
Contracts.CheckValue(name, nameof(name));
Expand All @@ -41,6 +45,9 @@ public OnnxContext(IHostEnvironment env, string name, string domain)
_variableMap = new HashSet<string>();
_nodeNames = new HashSet<string>();
_name = name;
_producerName = producerName;
_producerVersion = producerVersion;
_modelVersion = modelVersion;
_domain = domain;
}

Expand Down Expand Up @@ -234,6 +241,6 @@ public void AddInputVariable(ColumnType type, string colName)
/// Makes the ONNX model based on the context.
/// </summary>
public ModelProto MakeModel()
=> OnnxUtils.MakeModel(_nodes, _name, _name, _domain, _inputs, _outputs, _intermediateValues);
=> OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues);
}
}
13 changes: 10 additions & 3 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ private static AttributeProto MakeAttribute(string key, IEnumerable<GraphProto>

private static AttributeProto MakeAttribute(string key, bool value) => MakeAttribute(key, value ? 1 : 0);

public static NodeProto MakeNode(string opType, List<string> inputs, List<string> outputs, string name)
public static NodeProto MakeNode(string opType, List<string> inputs, List<string> outputs, string name, string domain = null)
{
Contracts.CheckNonEmpty(opType, nameof(opType));
Contracts.CheckValue(inputs, nameof(inputs));
Expand All @@ -165,7 +165,7 @@ public static NodeProto MakeNode(string opType, List<string> inputs, List<string
node.Input.Add(inputs);
node.Output.Add(outputs);
node.Name = name;
node.Domain = "ai.onnx.ml";
node.Domain = domain ?? "ai.onnx.ml";
return node;
}

Expand Down Expand Up @@ -251,7 +251,8 @@ public NodeProtoWrapper(NodeProto node)
}
}

public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, string name, string domain, List<ModelArgs> inputs,
public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, string name,
string domain, string producerVersion, long modelVersion, List<ModelArgs> inputs,
List<ModelArgs> outputs, List<ModelArgs> intermediateValues)
{
Contracts.CheckValue(nodes, nameof(nodes));
Expand All @@ -261,10 +262,16 @@ public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, s
Contracts.CheckNonEmpty(producerName, nameof(producerName));
Contracts.CheckNonEmpty(name, nameof(name));
Contracts.CheckNonEmpty(domain, nameof(domain));
Contracts.CheckNonEmpty(producerVersion, nameof(producerVersion));

var model = new ModelProto();
model.Domain = domain;
model.ProducerName = producerName;
model.ProducerVersion = producerVersion;
model.IrVersion = (long)UniversalModelFormat.Onnx.Version.IrVersion;
model.ModelVersion = modelVersion;
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx.ml", Version = 1 });
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx", Version = 6 });
model.Graph = new GraphProto();
var graph = model.Graph;
graph.Node.Add(nodes);
Expand Down
Loading