-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Export to ONNX and cross-platform command-line tool to script ML.NET training and inference #248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
1b78603
1508cd4
c6763f3
b04ef49
eaa8e3a
de02e1e
6f4434e
6c97416
534fcd1
caebed0
692b526
6c92033
e466157
5a04795
8ef9b3f
64cdb80
3276dd3
c6bc1c6
17a738a
5090d24
4cfed38
0a13ad7
5bea824
2ab729f
faf528c
3dfd81f
1865825
3562331
d20f1a4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,7 @@ | |
[assembly: LoadableClass(SaveOnnxCommand.Summary, typeof(SaveOnnxCommand), typeof(SaveOnnxCommand.Arguments), typeof(SignatureCommand), | ||
"Save ONNX", "SaveOnnx", DocName = "command/SaveOnnx.md")] | ||
|
||
[assembly: LoadableClass(typeof(void), typeof(SaveOnnxCommand), null, typeof(SignatureEntryPointModule), "SaveOnnxCommand")] | ||
[assembly: LoadableClass(typeof(void), typeof(SaveOnnxCommand), null, typeof(SignatureEntryPointModule), "SaveOnnx")] | ||
|
||
namespace Microsoft.ML.Runtime.Model.Onnx | ||
{ | ||
|
@@ -41,24 +41,24 @@ public sealed class Arguments : DataCommand.ArgumentsBase | |
[Argument(ArgumentType.AtMostOnce, HelpText = "The 'domain' property in the output ONNX.", NullName = "<Auto>", SortOrder = 4)] | ||
public string Domain; | ||
|
||
[Argument(ArgumentType.AtMostOnce, HelpText = "Comma delimited list of input column names to drop", ShortName = "idrop", SortOrder = 5)] | ||
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Comma delimited list of input column names to drop", ShortName = "idrop", SortOrder = 5)] | ||
public string InputsToDrop; | ||
|
||
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Array of input column names to drop", SortOrder = 6)] | ||
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Array of input column names to drop", Name = nameof(InputsToDrop), SortOrder = 6)] | ||
public string[] InputsToDropArray; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Awkward name to use. Please use |
||
|
||
[Argument(ArgumentType.AtMostOnce, HelpText = "Comma delimited list of output column names to drop", ShortName = "odrop", SortOrder = 7)] | ||
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Comma delimited list of output column names to drop", ShortName = "odrop", SortOrder = 7)] | ||
public string OutputsToDrop; | ||
|
||
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Array of output column names to drop", SortOrder = 8)] | ||
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Array of output column names to drop", Name = nameof(OutputsToDrop), SortOrder = 8)] | ||
public string[] OutputsToDropArray; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Same comments on inputs apply to outputs. #Closed |
||
|
||
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", ShortName = "pred", SortOrder = 9)] | ||
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", ShortName = "pred", SortOrder = 9)] | ||
public bool? LoadPredictor; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a reason to have this in entry-point land? #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not if we are passing in ITransformModel that contains the predictor. In reply to: 191819078 [](ancestors = 191819078) |
||
|
||
[Argument(ArgumentType.Required, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 10)] | ||
[Argument(ArgumentType.Required, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 10)] | ||
|
||
public IPredictorModel Model; | ||
public ITransformModel Model; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There appears to be a newline gap here. We tend to not put whitespace between the attribute and the field, as we see elsewhere here. #Closed |
||
} | ||
|
||
private readonly string _outputModelPath; | ||
|
@@ -68,7 +68,7 @@ public sealed class Arguments : DataCommand.ArgumentsBase | |
private readonly bool? _loadPredictor; | ||
private readonly HashSet<string> _inputsToDrop; | ||
private readonly HashSet<string> _outputsToDrop; | ||
private readonly IPredictorModel _model; | ||
private readonly ITransformModel _model; | ||
|
||
public SaveOnnxCommand(IHostEnvironment env, Arguments args) | ||
: base(env, args, LoadName) | ||
|
@@ -83,20 +83,12 @@ public SaveOnnxCommand(IHostEnvironment env, Arguments args) | |
_name = args.Name; | ||
|
||
_loadPredictor = args.LoadPredictor; | ||
_inputsToDrop = args.InputsToDropArray != null ? CreateDropMap(args.InputsToDropArray) : CreateDropMap(args.InputsToDrop); | ||
_outputsToDrop = args.OutputsToDropArray != null ? CreateDropMap(args.OutputsToDropArray) : CreateDropMap(args.OutputsToDrop); | ||
_inputsToDrop = CreateDropMap(args.InputsToDropArray ?? args.InputsToDrop?.Split(',')); | ||
_outputsToDrop = CreateDropMap(args.OutputsToDropArray ?? args.OutputsToDrop?.Split(',')); | ||
_domain = args.Domain; | ||
_model = args.Model; | ||
} | ||
|
||
private static HashSet<string> CreateDropMap(string toDrop) | ||
{ | ||
if (string.IsNullOrWhiteSpace(toDrop)) | ||
return new HashSet<string>(); | ||
|
||
return new HashSet<string>(toDrop.Split(',')); | ||
} | ||
|
||
private static HashSet<string> CreateDropMap(string[] toDrop) | ||
{ | ||
if (toDrop == null) | ||
|
@@ -140,8 +132,8 @@ private void GetPipe(IChannel ch, IDataView end, out IDataView source, out IData | |
|
||
private void Run(IChannel ch) | ||
{ | ||
IDataLoader loader = null; ; | ||
IPredictor rawPred; | ||
IDataLoader loader = null; | ||
IPredictor rawPred = null; | ||
IDataView view; | ||
RoleMappedSchema trainSchema = null; | ||
|
||
|
@@ -161,12 +153,7 @@ private void Run(IChannel ch) | |
view = loader; | ||
} | ||
else | ||
{ | ||
view = _model.TransformModel.View; | ||
rawPred = _model?.Predictor; | ||
if (rawPred != null) | ||
trainSchema = _model.GetTrainingSchema(Host); | ||
} | ||
view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema)); | ||
|
||
// Get the transform chain. | ||
IDataView source; | ||
|
@@ -276,7 +263,6 @@ private void Run(IChannel ch) | |
|
||
public sealed class Output | ||
{ | ||
//REVIEW: Would be nice to include ONNX protobuf model here but code generator needs an upgrade. | ||
} | ||
|
||
//REVIEW: Ideally there is no need to define this input class and just reuse the Argument class from SaveONNX command | ||
|
@@ -302,12 +288,8 @@ public sealed class Input | |
[Argument(ArgumentType.AtMostOnce, HelpText = "Array of output column names to drop", SortOrder = 6)] | ||
public string[] OutputsToDrop; | ||
|
||
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", ShortName = "pred", SortOrder = 7)] | ||
public bool? LoadPredictor; | ||
|
||
[Argument(ArgumentType.Required, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 8)] | ||
|
||
public IPredictorModel Model; | ||
[Argument(ArgumentType.Required, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 7)] | ||
public ITransformModel Model; | ||
} | ||
|
||
|
||
|
@@ -321,7 +303,6 @@ public static Output Apply(IHostEnvironment env, Input input) | |
args.Domain = input.Domain; | ||
args.InputsToDropArray = input.InputsToDrop; | ||
args.OutputsToDropArray = input.OutputsToDrop; | ||
args.LoadPredictor = input.LoadPredictor; | ||
args.Model = input.Model; | ||
|
||
var cmd = new SaveOnnxCommand(env, args); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
namespace Microsoft.ML.Runtime.Tools.Console | ||
{ | ||
public static class Console | ||
{ | ||
public static int Main(string[] args) | ||
{ | ||
string all = string.Join(" ", args); | ||
return Maml.MainAll(all); | ||
} | ||
|
||
public static unsafe int MainRaw(char* psz) | ||
{ | ||
string args = new string(psz); | ||
return Maml.MainAll(args); | ||
} | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Set visibility on this. #Closed