Skip to content

Added DcgTruncationLevel to the AutoML API #5433

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Dec 12, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Changes from PR comments.
  • Loading branch information
michaelgsharp committed Dec 11, 2020
commit 70989b66e9c1683f1b3b25351f237d8f03efc0cb
18 changes: 9 additions & 9 deletions src/Microsoft.ML.AutoML/API/RankingExperiment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ public sealed class RankingExperimentSettings : ExperimentSettings
/// <value>
/// The default value is 10.
/// </value>
public int DcgTruncationLevel { get; set; }
public int OptimizationMetricTruncationLevel { get; set; }

public RankingExperimentSettings()
{
OptimizingMetric = RankingMetric.Ndcg;
Trainers = Enum.GetValues(typeof(RankingTrainer)).OfType<RankingTrainer>().ToList();
DcgTruncationLevel = 10;
OptimizationMetricTruncationLevel = 10;
}
}
public enum RankingMetric
Expand Down Expand Up @@ -78,11 +78,11 @@ public static class RankingExperimentResultExtensions
/// </summary>
/// <param name="results">Enumeration of AutoML experiment run results.</param>
/// <param name="metric">Metric to consider when selecting the best run.</param>
/// <param name="dcgTruncationLevel">Maximum truncation level for computing (N)DCG. Defaults to 3.</param>
/// <param name="optimizationMetricTruncationLevel">Maximum truncation level for computing (N)DCG. Defaults to 10.</param>
/// <returns>The best experiment run.</returns>
public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, int dcgTruncationLevel = 3)
public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, int optimizationMetricTruncationLevel = 10)
{
var metricsAgent = new RankingMetricsAgent(null, metric, dcgTruncationLevel);
var metricsAgent = new RankingMetricsAgent(null, metric, optimizationMetricTruncationLevel);
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
}
Expand All @@ -92,11 +92,11 @@ public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingM
/// </summary>
/// <param name="results">Enumeration of AutoML experiment cross validation run results.</param>
/// <param name="metric">Metric to consider when selecting the best run.</param>
/// <param name="dcgTruncationLevel">Maximum truncation level for computing (N)DCG. Defaults to 3.</param>
/// <param name="optimizationMetricTruncationLevel">Maximum truncation level for computing (N)DCG. Defaults to 10.</param>
/// <returns>The best experiment run.</returns>
public static CrossValidationRunDetail<RankingMetrics> Best(this IEnumerable<CrossValidationRunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, int dcgTruncationLevel = 3)
public static CrossValidationRunDetail<RankingMetrics> Best(this IEnumerable<CrossValidationRunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, int optimizationMetricTruncationLevel = 10)
{
var metricsAgent = new RankingMetricsAgent(null, metric, dcgTruncationLevel);
var metricsAgent = new RankingMetricsAgent(null, metric, optimizationMetricTruncationLevel);
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
}
Expand All @@ -115,7 +115,7 @@ public sealed class RankingExperiment : ExperimentBase<RankingMetrics, RankingEx
{
internal RankingExperiment(MLContext context, RankingExperimentSettings settings)
: base(context,
new RankingMetricsAgent(context, settings.OptimizingMetric, settings.DcgTruncationLevel),
new RankingMetricsAgent(context, settings.OptimizingMetric, settings.OptimizationMetricTruncationLevel),
new OptimizingMetricInfo(settings.OptimizingMetric),
settings,
TaskKind.Ranking,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@ internal class RankingMetricsAgent : IMetricsAgent<RankingMetrics>
private readonly RankingMetric _optimizingMetric;
private readonly int _dcgTruncationLevel;

public RankingMetricsAgent(MLContext mlContext, RankingMetric metric, int dcgTruncationLevel)
public RankingMetricsAgent(MLContext mlContext, RankingMetric metric, int optimizationMetricTruncationLevel)
{
_mlContext = mlContext;
_optimizingMetric = metric;
_dcgTruncationLevel = dcgTruncationLevel;

// We want to make sure we always have at least 10 results. Getting extra results adds no measurable performance
// impact, so err on the side of more.
_dcgTruncationLevel = System.Math.Max(10, 2 * optimizationMetricTruncationLevel);
}

// Optimizing metric used: NDCG@10 and DCG@10
Expand Down
6 changes: 3 additions & 3 deletions test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ public void AutoFitRankingTest()
var settings = new RankingExperimentSettings()
{
MaxExperimentTimeInSeconds = 5,
DcgTruncationLevel = 5
OptimizationMetricTruncationLevel = 5
};
var experiment = mlContext.Auto()
.CreateRankingExperiment(settings);
Expand All @@ -203,8 +203,8 @@ public void AutoFitRankingTest()
for (int i = 0; i < experimentResults.Length; i++)
{
RunDetail<RankingMetrics> bestRun = experimentResults[i].BestRun;
Assert.Equal(5, bestRun.ValidationMetrics.DiscountedCumulativeGains.Count);
Assert.Equal(5, bestRun.ValidationMetrics.NormalizedDiscountedCumulativeGains.Count);
Assert.Equal(10, bestRun.ValidationMetrics.DiscountedCumulativeGains.Count);
Assert.Equal(10, bestRun.ValidationMetrics.NormalizedDiscountedCumulativeGains.Count);
Assert.True(experimentResults[i].RunDetails.Count() > 0);
Assert.NotNull(bestRun.ValidationMetrics);
Assert.True(bestRun.ValidationMetrics.NormalizedDiscountedCumulativeGains.Last() > 0.4);
Expand Down