Skip to content

Commit 54313b9

Browse files
added in DcgTruncationLevel to AutoML api (dotnet#5433)
* added in DcgTruncationLevel to automl api * changed default to 10 * updated basline output * fixed failing tests and baselines * Changes from PR comments. * Update src/Microsoft.ML.AutoML/Experiment/MetricsAgents/RankingMetricsAgent.cs Co-authored-by: Justin Ormont <[email protected]> * Changes based on PR comments. * Fix ranking test. Co-authored-by: Justin Ormont <[email protected]>
1 parent 5038e81 commit 54313b9

16 files changed

+965
-794
lines changed

src/Microsoft.ML.AutoML/API/RankingExperiment.cs

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,20 @@ public sealed class RankingExperimentSettings : ExperimentSettings
2626
/// The default value is a collection auto-populated with all possible trainers (all values of <see cref="RankingTrainer" />).
2727
/// </value>
2828
public ICollection<RankingTrainer> Trainers { get; }
29+
30+
/// <summary>
31+
/// Maximum truncation level for computing (N)DCG
32+
/// </summary>
33+
/// <value>
34+
/// The default value is 10.
35+
/// </value>
36+
public uint OptimizationMetricTruncationLevel { get; set; }
37+
2938
public RankingExperimentSettings()
3039
{
3140
OptimizingMetric = RankingMetric.Ndcg;
3241
Trainers = Enum.GetValues(typeof(RankingTrainer)).OfType<RankingTrainer>().ToList();
42+
OptimizationMetricTruncationLevel = 10;
3343
}
3444
}
3545
public enum RankingMetric
@@ -68,10 +78,11 @@ public static class RankingExperimentResultExtensions
6878
/// </summary>
6979
/// <param name="results">Enumeration of AutoML experiment run results.</param>
7080
/// <param name="metric">Metric to consider when selecting the best run.</param>
81+
/// <param name="optimizationMetricTruncationLevel">Maximum truncation level for computing (N)DCG. Defaults to 10.</param>
7182
/// <returns>The best experiment run.</returns>
72-
public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg)
83+
public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, uint optimizationMetricTruncationLevel = 10)
7384
{
74-
var metricsAgent = new RankingMetricsAgent(null, metric);
85+
var metricsAgent = new RankingMetricsAgent(null, metric, optimizationMetricTruncationLevel);
7586
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
7687
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
7788
}
@@ -81,10 +92,11 @@ public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingM
8192
/// </summary>
8293
/// <param name="results">Enumeration of AutoML experiment cross validation run results.</param>
8394
/// <param name="metric">Metric to consider when selecting the best run.</param>
95+
/// <param name="optimizationMetricTruncationLevel">Maximum truncation level for computing (N)DCG. Defaults to 10.</param>
8496
/// <returns>The best experiment run.</returns>
85-
public static CrossValidationRunDetail<RankingMetrics> Best(this IEnumerable<CrossValidationRunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg)
97+
public static CrossValidationRunDetail<RankingMetrics> Best(this IEnumerable<CrossValidationRunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, uint optimizationMetricTruncationLevel = 10)
8698
{
87-
var metricsAgent = new RankingMetricsAgent(null, metric);
99+
var metricsAgent = new RankingMetricsAgent(null, metric, optimizationMetricTruncationLevel);
88100
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
89101
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
90102
}
@@ -103,7 +115,7 @@ public sealed class RankingExperiment : ExperimentBase<RankingMetrics, RankingEx
103115
{
104116
internal RankingExperiment(MLContext context, RankingExperimentSettings settings)
105117
: base(context,
106-
new RankingMetricsAgent(context, settings.OptimizingMetric),
118+
new RankingMetricsAgent(context, settings.OptimizingMetric, settings.OptimizationMetricTruncationLevel),
107119
new OptimizingMetricInfo(settings.OptimizingMetric),
108120
settings,
109121
TaskKind.Ranking,

src/Microsoft.ML.AutoML/Experiment/MetricsAgents/RankingMetricsAgent.cs

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,30 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
56
using Microsoft.ML.Data;
7+
using Microsoft.ML.Runtime;
68

79
namespace Microsoft.ML.AutoML
810
{
911
internal class RankingMetricsAgent : IMetricsAgent<RankingMetrics>
1012
{
1113
private readonly MLContext _mlContext;
1214
private readonly RankingMetric _optimizingMetric;
15+
private readonly uint _dcgTruncationLevel;
1316

14-
public RankingMetricsAgent(MLContext mlContext, RankingMetric optimizingMetric)
17+
public RankingMetricsAgent(MLContext mlContext, RankingMetric metric, uint optimizationMetricTruncationLevel)
1518
{
1619
_mlContext = mlContext;
17-
_optimizingMetric = optimizingMetric;
20+
_optimizingMetric = metric;
21+
22+
if (optimizationMetricTruncationLevel <= 0)
23+
throw _mlContext.ExceptUserArg(nameof(optimizationMetricTruncationLevel), "DCG Truncation Level must be greater than 0");
24+
25+
// We want to make sure we always report metrics for at least 10 results (e.g. NDCG@10) to the user.
26+
// Producing extra results adds no measurable performance impact, so we report at least 2x of the
27+
// user's requested optimization truncation level.
28+
_dcgTruncationLevel = optimizationMetricTruncationLevel;
1829
}
1930

2031
// Optimizing metric used: NDCG@10 and DCG@10
@@ -28,11 +39,9 @@ public double GetScore(RankingMetrics metrics)
2839
switch (_optimizingMetric)
2940
{
3041
case RankingMetric.Ndcg:
31-
return (metrics.NormalizedDiscountedCumulativeGains.Count >= 10) ? metrics.NormalizedDiscountedCumulativeGains[9] :
32-
metrics.NormalizedDiscountedCumulativeGains[metrics.NormalizedDiscountedCumulativeGains.Count - 1];
42+
return metrics.NormalizedDiscountedCumulativeGains[Math.Min(metrics.NormalizedDiscountedCumulativeGains.Count, (int)_dcgTruncationLevel) - 1];
3343
case RankingMetric.Dcg:
34-
return (metrics.DiscountedCumulativeGains.Count >= 10) ? metrics.DiscountedCumulativeGains[9] :
35-
metrics.DiscountedCumulativeGains[metrics.DiscountedCumulativeGains.Count-1];
44+
return metrics.DiscountedCumulativeGains[Math.Min(metrics.DiscountedCumulativeGains.Count, (int)_dcgTruncationLevel) - 1];
3645
default:
3746
throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric);
3847
}
@@ -59,7 +68,12 @@ public bool IsModelPerfect(double score)
5968

6069
public RankingMetrics EvaluateMetrics(IDataView data, string labelColumn, string groupIdColumn)
6170
{
62-
return _mlContext.Ranking.Evaluate(data, labelColumn, groupIdColumn);
71+
var rankingEvalOptions = new RankingEvaluatorOptions
72+
{
73+
DcgTruncationLevel = Math.Max(10, 2 * (int)_dcgTruncationLevel)
74+
};
75+
76+
return _mlContext.Ranking.Evaluate(data, rankingEvalOptions, labelColumn, groupIdColumn);
6377
}
6478
}
6579
}

src/Microsoft.ML.AutoML/Utils/BestResultUtil.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ public static RunDetail<MulticlassClassificationMetrics> GetBestRun(IEnumerable<
3535
}
3636

3737
public static RunDetail<RankingMetrics> GetBestRun(IEnumerable<RunDetail<RankingMetrics>> results,
38-
RankingMetric metric)
38+
RankingMetric metric, uint dcgTruncationLevel)
3939
{
40-
var metricsAgent = new RankingMetricsAgent(null, metric);
40+
var metricsAgent = new RankingMetricsAgent(null, metric, dcgTruncationLevel);
4141

4242
var metricInfo = new OptimizingMetricInfo(metric);
4343
return GetBestRun(results, metricsAgent, metricInfo.IsMaximizing);

src/Microsoft.ML.Data/Evaluators/RankingEvaluator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public sealed class RankingEvaluatorOptions
3535
/// Maximum truncation level for computing (N)DCG
3636
/// </value>
3737
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum truncation level for computing (N)DCG", ShortName = "t")]
38-
public int DcgTruncationLevel = 3;
38+
public int DcgTruncationLevel = 10;
3939

4040
/// <value>
4141
/// Label relevance gains
@@ -858,7 +858,7 @@ public sealed class Arguments : ArgumentsBase
858858
public string GroupIdColumn;
859859

860860
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum truncation level for computing (N)DCG", ShortName = "t")]
861-
public int DcgTruncationLevel = 3;
861+
public int DcgTruncationLevel = 10;
862862

863863
[Argument(ArgumentType.AtMostOnce, HelpText = "Label relevance gains", ShortName = "gains")]
864864
public string LabelGains = "0,3,7,15,31";

test/BaselineOutput/Common/Command/CommandCrossValidationKeyLabelWithFloatKeyValues-out.txt

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,38 +4,80 @@ Making per-feature arrays
44
Changing data from row-wise to column-wise
55
Processed 40 instances
66
Binning and forming Feature objects
7-
Reserved memory for tree learner: 10764 bytes
7+
Reserved memory for tree learner: %Number% bytes
88
Starting to train ...
99
Not training a calibrator because it is not needed.
1010
Not adding a normalizer.
1111
Making per-feature arrays
1212
Changing data from row-wise to column-wise
1313
Processed 32 instances
1414
Binning and forming Feature objects
15-
Reserved memory for tree learner: 6396 bytes
15+
Reserved memory for tree learner: %Number% bytes
1616
Starting to train ...
1717
Not training a calibrator because it is not needed.
1818
NDCG@1: 0.000000
1919
NDCG@2: 0.000000
2020
NDCG@3: 0.000000
21+
NDCG@4: 0.000000
22+
NDCG@5: 0.000000
23+
NDCG@6: 0.000000
24+
NDCG@7: 0.000000
25+
NDCG@8: 0.000000
26+
NDCG@9: 0.000000
27+
NDCG@10: 0.000000
2128
DCG@1: 0.000000
2229
DCG@2: 0.000000
2330
DCG@3: 0.000000
31+
DCG@4: 0.000000
32+
DCG@5: 0.000000
33+
DCG@6: 0.000000
34+
DCG@7: 0.000000
35+
DCG@8: 0.000000
36+
DCG@9: 0.000000
37+
DCG@10: 0.000000
2438
NDCG@1: 0.000000
2539
NDCG@2: 0.000000
2640
NDCG@3: 0.000000
41+
NDCG@4: 0.000000
42+
NDCG@5: 0.000000
43+
NDCG@6: 0.000000
44+
NDCG@7: 0.000000
45+
NDCG@8: 0.000000
46+
NDCG@9: 0.000000
47+
NDCG@10: 0.000000
2748
DCG@1: 0.000000
2849
DCG@2: 0.000000
2950
DCG@3: 0.000000
51+
DCG@4: 0.000000
52+
DCG@5: 0.000000
53+
DCG@6: 0.000000
54+
DCG@7: 0.000000
55+
DCG@8: 0.000000
56+
DCG@9: 0.000000
57+
DCG@10: 0.000000
3058

3159
OVERALL RESULTS
3260
---------------------------------------
3361
NDCG@1: 0.000000 (0.0000)
3462
NDCG@2: 0.000000 (0.0000)
3563
NDCG@3: 0.000000 (0.0000)
64+
NDCG@4: 0.000000 (0.0000)
65+
NDCG@5: 0.000000 (0.0000)
66+
NDCG@6: 0.000000 (0.0000)
67+
NDCG@7: 0.000000 (0.0000)
68+
NDCG@8: 0.000000 (0.0000)
69+
NDCG@9: 0.000000 (0.0000)
70+
NDCG@10: 0.000000 (0.0000)
3671
DCG@1: 0.000000 (0.0000)
3772
DCG@2: 0.000000 (0.0000)
3873
DCG@3: 0.000000 (0.0000)
74+
DCG@4: 0.000000 (0.0000)
75+
DCG@5: 0.000000 (0.0000)
76+
DCG@6: 0.000000 (0.0000)
77+
DCG@7: 0.000000 (0.0000)
78+
DCG@8: 0.000000 (0.0000)
79+
DCG@9: 0.000000 (0.0000)
80+
DCG@10: 0.000000 (0.0000)
3981

4082
---------------------------------------
4183
Physical memory usage(MB): %Number%

0 commit comments

Comments
 (0)