Skip to content

Commit 072e61d

Browse files
committed
PR feedback.
1 parent 542b845 commit 072e61d

File tree

14 files changed

+96
-73
lines changed

14 files changed

+96
-73
lines changed

ZBaselines/Common/EntryPoints/core_manifest.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,7 +1437,7 @@
14371437
"Kind": "Struct",
14381438
"Fields": [
14391439
{
1440-
"Name": "Model",
1440+
"Name": "PredictorModel",
14411441
"Type": "PredictorModel",
14421442
"Desc": "The predictor model",
14431443
"Required": false,
@@ -3055,7 +3055,7 @@
30553055
"Kind": "Struct",
30563056
"Fields": [
30573057
{
3058-
"Name": "Model",
3058+
"Name": "PredictorModel",
30593059
"Type": "PredictorModel",
30603060
"Desc": "The predictor model",
30613061
"Required": false,

src/Microsoft.ML.PipelineInference/PipelinePattern.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ public Experiment CreateTrainTestExperiment(IDataView trainData, IDataView testD
152152
},
153153
Outputs =
154154
{
155-
Model = finalOutput
155+
PredictorModel = finalOutput
156156
},
157157
PipelineId = UniqueId.ToString("N"),
158158
Kind = MacroUtils.TrainerKindApiValue<Models.MacroUtilsTrainerKinds>(trainerKind),
@@ -189,7 +189,7 @@ public Models.TrainTestEvaluator.Output AddAsTrainTest(Var<IDataView> trainData,
189189
},
190190
Outputs =
191191
{
192-
Model = finalOutput
192+
PredictorModel = finalOutput
193193
},
194194
TrainingData = trainData,
195195
TestingData = testData,

src/Microsoft.ML/CSharpApi.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2214,7 +2214,7 @@ public sealed partial class CrossValidationMacroSubGraphOutput
22142214
/// <summary>
22152215
/// The predictor model
22162216
/// </summary>
2217-
public Var<Microsoft.ML.Runtime.EntryPoints.IPredictorModel> Model { get; set; } = new Var<Microsoft.ML.Runtime.EntryPoints.IPredictorModel>();
2217+
public Var<Microsoft.ML.Runtime.EntryPoints.IPredictorModel> PredictorModel { get; set; } = new Var<Microsoft.ML.Runtime.EntryPoints.IPredictorModel>();
22182218

22192219
/// <summary>
22202220
/// The transform model
@@ -3400,7 +3400,7 @@ public sealed partial class TrainTestMacroSubGraphOutput
34003400
/// <summary>
34013401
/// The predictor model
34023402
/// </summary>
3403-
public Var<Microsoft.ML.Runtime.EntryPoints.IPredictorModel> Model { get; set; } = new Var<Microsoft.ML.Runtime.EntryPoints.IPredictorModel>();
3403+
public Var<Microsoft.ML.Runtime.EntryPoints.IPredictorModel> PredictorModel { get; set; } = new Var<Microsoft.ML.Runtime.EntryPoints.IPredictorModel>();
34043404

34053405
/// <summary>
34063406
/// Transform model

src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44

55
using Microsoft.ML.Runtime;
66
using Microsoft.ML.Runtime.Data;
7-
using Microsoft.ML.Runtime.EntryPoints;
87
using Microsoft.ML.Transforms;
9-
using System.Collections.Generic;
108
using System.Linq;
119

1210
namespace Microsoft.ML.Models
@@ -70,7 +68,7 @@ public BinaryClassificationMetrics Evaluate(PredictionModel model, ILearningPipe
7068

7169
var metric = BinaryClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix);
7270

73-
Contracts.Assert(metric.Count == 1);
71+
Contracts.Check(metric.Count == 1);
7472

7573
return metric[0];
7674
}

src/Microsoft.ML/Models/BinaryClassificationMetrics.cs

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ private BinaryClassificationMetrics()
1919
{
2020
}
2121

22-
internal static List<BinaryClassificationMetrics> FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix, int skipRows = 0)
22+
internal static List<BinaryClassificationMetrics> FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix, int confusionMatriceStartIndex = 0)
2323
{
2424
Contracts.AssertValue(env);
2525
env.AssertValue(overallMetrics);
@@ -28,41 +28,40 @@ internal static List<BinaryClassificationMetrics> FromMetrics(IHostEnvironment e
2828
var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);
2929
var enumerator = metricsEnumerable.GetEnumerator();
3030

31-
while (skipRows-- >= 0)
31+
if (!enumerator.MoveNext())
3232
{
33-
if (!enumerator.MoveNext())
34-
{
35-
throw env.Except("The overall RegressionMetrics didn't have sufficient rows.");
36-
}
33+
throw env.Except("The overall RegressionMetrics didn't have sufficient rows.");
3734
}
3835

3936
List<BinaryClassificationMetrics> metrics = new List<BinaryClassificationMetrics>();
4037
var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator();
38+
39+
int Index = 0;
4140
do
4241
{
4342
SerializationClass metric = enumerator.Current;
4443

45-
if (!confusionMatrices.MoveNext())
44+
if (Index++ >= confusionMatriceStartIndex && !confusionMatrices.MoveNext())
4645
{
4746
throw env.Except("Confusion matrices didn't have enough matrices.");
4847
}
4948

50-
metrics.Add(
49+
metrics.Add(
5150
new BinaryClassificationMetrics()
52-
{
53-
Auc = metric.Auc,
54-
Accuracy = metric.Accuracy,
55-
PositivePrecision = metric.PositivePrecision,
56-
PositiveRecall = metric.PositiveRecall,
57-
NegativePrecision = metric.NegativePrecision,
58-
NegativeRecall = metric.NegativeRecall,
59-
LogLoss = metric.LogLoss,
60-
LogLossReduction = metric.LogLossReduction,
61-
Entropy = metric.Entropy,
62-
F1Score = metric.F1Score,
63-
Auprc = metric.Auprc,
64-
ConfusionMatrix = confusionMatrices.Current,
65-
});
51+
{
52+
Auc = metric.Auc,
53+
Accuracy = metric.Accuracy,
54+
PositivePrecision = metric.PositivePrecision,
55+
PositiveRecall = metric.PositiveRecall,
56+
NegativePrecision = metric.NegativePrecision,
57+
NegativeRecall = metric.NegativeRecall,
58+
LogLoss = metric.LogLoss,
59+
LogLossReduction = metric.LogLossReduction,
60+
Entropy = metric.Entropy,
61+
F1Score = metric.F1Score,
62+
Auprc = metric.Auprc,
63+
ConfusionMatrix = confusionMatrices.Current,
64+
});
6665

6766
} while (enumerator.MoveNext());
6867

src/Microsoft.ML/Models/ClassificationEvaluator.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
using Microsoft.ML.Runtime;
66
using Microsoft.ML.Runtime.Data;
77
using Microsoft.ML.Transforms;
8-
using System.Collections.Generic;
98
using System.Linq;
109

1110
namespace Microsoft.ML.Models
@@ -70,7 +69,7 @@ public ClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLo
7069

7170
var metric = ClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix);
7271

73-
Contracts.Assert(metric.Count == 1);
72+
Contracts.Check(metric.Count == 1);
7473

7574
return metric[0];
7675
}

src/Microsoft.ML/Models/ClassificationMetrics.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,27 +18,27 @@ private ClassificationMetrics()
1818
{
1919
}
2020

21-
internal static List<ClassificationMetrics> FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix, int skipRows = 0)
21+
internal static List<ClassificationMetrics> FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix,
22+
int confusionMatriceStartIndex = 0)
2223
{
2324
Contracts.AssertValue(env);
2425
env.AssertValue(overallMetrics);
2526
env.AssertValue(confusionMatrix);
2627

2728
var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);
2829
var enumerator = metricsEnumerable.GetEnumerator();
29-
while (skipRows-- >= 0)
30+
if (!enumerator.MoveNext())
3031
{
31-
if (!enumerator.MoveNext())
32-
{
33-
throw env.Except("The overall RegressionMetrics didn't have sufficient rows.");
34-
}
32+
throw env.Except("The overall RegressionMetrics didn't have sufficient rows.");
3533
}
36-
34+
3735
List<ClassificationMetrics> metrics = new List<ClassificationMetrics>();
3836
var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator();
37+
38+
int Index = 0;
3939
do
4040
{
41-
if (!confusionMatrices.MoveNext())
41+
if (Index++ >= confusionMatriceStartIndex && !confusionMatrices.MoveNext())
4242
{
4343
throw env.Except("Confusion matrices didn't have enough matrices.");
4444
}

src/Microsoft.ML/Models/CrossValidator.cs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ public sealed partial class CrossValidator
1717
/// </summary>
1818
/// <typeparam name="TInput">Class type that represents input schema.</typeparam>
1919
/// <typeparam name="TOutput">Class type that represents prediction schema.</typeparam>
20-
/// <param name="pipeline">Machine learning pipeline that contain may contain loader, transforms and at least one trainer.</param>
21-
/// <returns>List containning metrics and predictor model for each fold</returns>
20+
/// <param name="pipeline">Machine learning pipeline may contain loader, transforms and at least one trainer.</param>
21+
/// <returns>List containing metrics and predictor model for each fold</returns>
2222
public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(LearningPipeline pipeline)
2323
where TInput : class
2424
where TOutput : class, new()
@@ -127,23 +127,20 @@ public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(Lea
127127
cvOutput.BinaryClassificationMetrics = BinaryClassificationMetrics.FromMetrics(
128128
environment,
129129
experiment.GetOutput(crossValidateOutput.OverallMetrics),
130-
experiment.GetOutput(crossValidateOutput.ConfusionMatrix),
131-
2);
130+
experiment.GetOutput(crossValidateOutput.ConfusionMatrix), 2);
132131
}
133132
else if(Kind == MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer)
134133
{
135134
cvOutput.ClassificationMetrics = ClassificationMetrics.FromMetrics(
136135
environment,
137136
experiment.GetOutput(crossValidateOutput.OverallMetrics),
138-
experiment.GetOutput(crossValidateOutput.ConfusionMatrix),
139-
2);
137+
experiment.GetOutput(crossValidateOutput.ConfusionMatrix), 2);
140138
}
141139
else if (Kind == MacroUtilsTrainerKinds.SignatureRegressorTrainer)
142140
{
143141
cvOutput.RegressionMetrics = RegressionMetrics.FromOverallMetrics(
144142
environment,
145-
experiment.GetOutput(crossValidateOutput.OverallMetrics),
146-
2);
143+
experiment.GetOutput(crossValidateOutput.OverallMetrics));
147144
}
148145
else
149146
{

src/Microsoft.ML/Models/RegressionMetrics.cs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,16 @@ private RegressionMetrics()
1919
{
2020
}
2121

22-
internal static List<RegressionMetrics> FromOverallMetrics(IHostEnvironment env, IDataView overallMetrics, int skipRows = 0)
22+
internal static List<RegressionMetrics> FromOverallMetrics(IHostEnvironment env, IDataView overallMetrics)
2323
{
2424
Contracts.AssertValue(env);
2525
env.AssertValue(overallMetrics);
2626

2727
var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);
2828
var enumerator = metricsEnumerable.GetEnumerator();
29-
while (skipRows-- >= 0)
29+
if (!enumerator.MoveNext())
3030
{
31-
if (!enumerator.MoveNext())
32-
{
33-
throw env.Except("The overall RegressionMetrics didn't have sufficient rows.");
34-
}
31+
throw env.Except("The overall RegressionMetrics didn't have sufficient rows.");
3532
}
3633

3734
List<RegressionMetrics> metrics = new List<RegressionMetrics>();

src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public sealed class SubGraphInput
3030
public sealed class SubGraphOutput
3131
{
3232
[Argument(ArgumentType.AtMostOnce, HelpText = "The predictor model", SortOrder = 1)]
33-
public Var<IPredictorModel> Model;
33+
public Var<IPredictorModel> PredictorModel;
3434

3535
[Argument(ArgumentType.AtMostOnce, HelpText = "The transform model", SortOrder = 2)]
3636
public Var<ITransformModel> TransformModel;
@@ -203,15 +203,15 @@ public static CommonOutputs.MacroOutput<Output> CrossValidate(
203203
VarName = mapping[input.Inputs.Data.VarName]
204204
};
205205

206-
if (input.Outputs.Model != null && mapping.ContainsKey(input.Outputs.Model.VarName))
206+
if (input.Outputs.PredictorModel != null && mapping.ContainsKey(input.Outputs.PredictorModel.VarName))
207207
{
208-
args.Outputs.Model = new Var<IPredictorModel>
208+
args.Outputs.PredictorModel = new Var<IPredictorModel>
209209
{
210-
VarName = mapping[input.Outputs.Model.VarName]
210+
VarName = mapping[input.Outputs.PredictorModel.VarName]
211211
};
212212
}
213213
else
214-
args.Outputs.Model = null;
214+
args.Outputs.PredictorModel = null;
215215

216216
if (input.Outputs.TransformModel != null && mapping.ContainsKey(input.Outputs.TransformModel.VarName))
217217
{

src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public sealed class SubGraphInput
2525
public sealed class SubGraphOutput
2626
{
2727
[Argument(ArgumentType.AtMostOnce, HelpText = "The predictor model", SortOrder = 1)]
28-
public Var<IPredictorModel> Model;
28+
public Var<IPredictorModel> PredictorModel;
2929

3030
[Argument(ArgumentType.AtMostOnce, HelpText = "Transform model", SortOrder = 2)]
3131
public Var<ITransformModel> TransformModel;
@@ -129,7 +129,7 @@ public static CommonOutputs.MacroOutput<Output> TrainTest(
129129
subGraphRunContext.RemoveVariable(dataVariable);
130130

131131
// Change the subgraph to use the model variable as output.
132-
varName = input.Outputs.UseTransformModel ? input.Outputs.TransformModel.VarName : input.Outputs.Model.VarName;
132+
varName = input.Outputs.UseTransformModel ? input.Outputs.TransformModel.VarName : input.Outputs.PredictorModel.VarName;
133133
if (!subGraphRunContext.TryGetVariable(varName, out dataVariable))
134134
throw env.Except($"Invalid variable name '{varName}'.");
135135

test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ public void TestCrossValidationMacro()
319319
TransformModel = null
320320
};
321321
crossValidate.Inputs.Data = nop.Data;
322-
crossValidate.Outputs.Model = modelCombineOutput.PredictorModel;
322+
crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
323323
var crossValidateOutput = experiment.Add(crossValidate);
324324

325325
experiment.Compile();
@@ -410,7 +410,7 @@ public void TestCrossValidationMacroWithMultiClass()
410410
TransformModel = null
411411
};
412412
crossValidate.Inputs.Data = nop.Data;
413-
crossValidate.Outputs.Model = modelCombineOutput.PredictorModel;
413+
crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
414414
var crossValidateOutput = experiment.Add(crossValidate);
415415

416416
experiment.Compile();
@@ -541,7 +541,7 @@ public void TestCrossValidationMacroWithStratification()
541541
StratificationColumn = "Strat"
542542
};
543543
crossValidate.Inputs.Data = nop.Data;
544-
crossValidate.Outputs.Model = modelCombineOutput.PredictorModel;
544+
crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
545545
var crossValidateOutput = experiment.Add(crossValidate);
546546
experiment.Compile();
547547
experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));

test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1875,7 +1875,7 @@ public void EntryPointTrainTestMacroNoTransformInput()
18751875
'Data': '$data1'
18761876
},
18771877
'Outputs': {
1878-
'Model': '$model'
1878+
'PredictorModel': '$model'
18791879
}
18801880
},
18811881
'Outputs': {
@@ -1980,7 +1980,7 @@ public void EntryPointTrainTestMacro()
19801980
'Data': '$data1'
19811981
},
19821982
'Outputs': {
1983-
'Model': '$model'
1983+
'PredictorModel': '$model'
19841984
}
19851985
},
19861986
'Outputs': {
@@ -2108,7 +2108,7 @@ public void EntryPointChainedTrainTestMacros()
21082108
'Data': '$data1'
21092109
},
21102110
'Outputs': {
2111-
'Model': '$model'
2111+
'PredictorModel': '$model'
21122112
}
21132113
},
21142114
'Outputs': {
@@ -2141,7 +2141,7 @@ public void EntryPointChainedTrainTestMacros()
21412141
'Data': '$data4'
21422142
},
21432143
'Outputs': {
2144-
'Model': '$model2'
2144+
'PredictorModel': '$model2'
21452145
}
21462146
},
21472147
'Outputs': {
@@ -2274,7 +2274,7 @@ public void EntryPointChainedCrossValMacros()
22742274
'Data': '$data6'
22752275
},
22762276
'Outputs': {
2277-
'Model': '$model'
2277+
'PredictorModel': '$model'
22782278
}
22792279
},
22802280
'Outputs': {
@@ -2336,7 +2336,7 @@ public void EntryPointChainedCrossValMacros()
23362336
'Data': '$data4'
23372337
},
23382338
'Outputs': {
2339-
'Model': '$model2'
2339+
'PredictorModel': '$model2'
23402340
}
23412341
},
23422342
'Outputs': {

0 commit comments

Comments
 (0)