Skip to content

Enabling custom groupId column in the Ranking AutoML experiment #5246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions src/Microsoft.ML.AutoML/API/RankingExperiment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ public sealed class RankingExperimentSettings : ExperimentSettings
/// <value>The default value is <see cref="RankingMetric" />.</value>
public RankingMetric OptimizingMetric { get; set; }

/// <summary>
/// Name for the GroupId column.
/// </summary>
/// <value>The default value is GroupId.</value>
public string GroupIdColumnName { get; set; }

/// <summary>
/// Collection of trainers the AutoML experiment can leverage.
/// </summary>
Expand All @@ -28,6 +34,7 @@ public sealed class RankingExperimentSettings : ExperimentSettings
public ICollection<RankingTrainer> Trainers { get; }
public RankingExperimentSettings()
{
GroupIdColumnName = "GroupId";
OptimizingMetric = RankingMetric.Ndcg;
Trainers = Enum.GetValues(typeof(RankingTrainer)).OfType<RankingTrainer>().ToList();
}
Expand Down Expand Up @@ -68,10 +75,11 @@ public static class RankingExperimentResultExtensions
/// </summary>
/// <param name="results">Enumeration of AutoML experiment run results.</param>
/// <param name="metric">Metric to consider when selecting the best run.</param>
/// <param name="groupIdColumnName">Name for the GroupId column.</param>
/// <returns>The best experiment run.</returns>
public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg)
public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, string groupIdColumnName = "GroupId")
{
var metricsAgent = new RankingMetricsAgent(null, metric);
var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName);
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
}
Expand All @@ -81,10 +89,11 @@ public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingM
/// </summary>
/// <param name="results">Enumeration of AutoML experiment cross validation run results.</param>
/// <param name="metric">Metric to consider when selecting the best run.</param>
/// <param name="groupIdColumnName">Name for the GroupId column.</param>
/// <returns>The best experiment run.</returns>
public static CrossValidationRunDetail<RankingMetrics> Best(this IEnumerable<CrossValidationRunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg)
public static CrossValidationRunDetail<RankingMetrics> Best(this IEnumerable<CrossValidationRunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, string groupIdColumnName = "GroupId")
{
var metricsAgent = new RankingMetricsAgent(null, metric);
var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName);
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
}
Expand All @@ -103,7 +112,7 @@ public sealed class RankingExperiment : ExperimentBase<RankingMetrics, RankingEx
{
internal RankingExperiment(MLContext context, RankingExperimentSettings settings)
: base(context,
new RankingMetricsAgent(context, settings.OptimizingMetric),
new RankingMetricsAgent(context, settings.OptimizingMetric, settings.GroupIdColumnName),
new OptimizingMetricInfo(settings.OptimizingMetric),
settings,
TaskKind.Ranking,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ internal class RankingMetricsAgent : IMetricsAgent<RankingMetrics>
{
private readonly MLContext _mlContext;
private readonly RankingMetric _optimizingMetric;
private readonly string _groupIdColumnName;

public RankingMetricsAgent(MLContext mlContext, RankingMetric optimizingMetric)
public RankingMetricsAgent(MLContext mlContext, RankingMetric optimizingMetric, string groupIdColumnName)
{
_mlContext = mlContext;
_optimizingMetric = optimizingMetric;
_groupIdColumnName = groupIdColumnName;
}

// Optimizing metric used: NDCG@10 and DCG@10
Expand Down Expand Up @@ -59,7 +61,7 @@ public bool IsModelPerfect(double score)

public RankingMetrics EvaluateMetrics(IDataView data, string labelColumn)
{
return _mlContext.Ranking.Evaluate(data, labelColumn);
return _mlContext.Ranking.Evaluate(data, labelColumn, _groupIdColumnName);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ private static IDictionary<string, object> BuildBasePipelineNodeProps(IEnumerabl
}

private static IDictionary<string, object> BuildLightGbmPipelineNodeProps(IEnumerable<SweepableParam> sweepParams,
string labelColumn, string weightColumn, string groupColumn = null)
string labelColumn, string weightColumn, string groupColumn)
{
Dictionary<string, object> props = null;
if (sweepParams == null || !sweepParams.Any())
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.AutoML/Utils/BestResultUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ public static RunDetail<MulticlassClassificationMetrics> GetBestRun(IEnumerable<
}

public static RunDetail<RankingMetrics> GetBestRun(IEnumerable<RunDetail<RankingMetrics>> results,
RankingMetric metric)
RankingMetric metric, string groupIdColumnName)
{
var metricsAgent = new RankingMetricsAgent(null, metric);
var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName);
var metricInfo = new OptimizingMetricInfo(metric);
return GetBestRun(results, metricsAgent, metricInfo.IsMaximizing);
}
Expand Down
4 changes: 2 additions & 2 deletions test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ public void AutoFitRankingTest()
{
string labelColumnName = "Label";
string scoreColumnName = "Score";
string groupIdColumnName = "GroupId";
string groupIdColumnName = "CustomGroupId";
string featuresColumnVectorNameA = "FeatureVectorA";
string featuresColumnVectorNameB = "FeatureVectorB";
var mlContext = new MLContext(1);
Expand All @@ -136,7 +136,7 @@ public void AutoFitRankingTest()
trainDataView = mlContext.Data.SkipRows(trainDataView, 500);
// STEP 2: Run AutoML experiment
ExperimentResult<RankingMetrics> experimentResult = mlContext.Auto()
.CreateRankingExperiment(5)
.CreateRankingExperiment(new RankingExperimentSettings() { GroupIdColumnName = "CustomGroupId", MaxExperimentTimeInSeconds = 5})
.Execute(trainDataView, testDataView,
new ColumnInformation()
{
Expand Down
24 changes: 12 additions & 12 deletions test/Microsoft.ML.AutoML.Tests/MetricsAgentsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,14 @@ public void RankingMetricsGetScoreTest()
double[] ndcg = { 0.2, 0.3, 0.4 };
double[] dcg = { 0.2, 0.3, 0.4 };
var metrics = MetricsUtil.CreateRankingMetrics(dcg, ndcg);
Assert.Equal(0.4, GetScore(metrics, RankingMetric.Dcg));
Assert.Equal(0.4, GetScore(metrics, RankingMetric.Ndcg));
Assert.Equal(0.4, GetScore(metrics, RankingMetric.Dcg, "GroupId"));
Assert.Equal(0.4, GetScore(metrics, RankingMetric.Ndcg, "GroupId"));

double[] largeNdcg = { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95 };
double[] largeDcg = { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95 };
metrics = MetricsUtil.CreateRankingMetrics(largeDcg, largeNdcg);
Assert.Equal(0.9, GetScore(metrics, RankingMetric.Dcg));
Assert.Equal(0.9, GetScore(metrics, RankingMetric.Ndcg));
Assert.Equal(0.9, GetScore(metrics, RankingMetric.Dcg, "GroupId"));
Assert.Equal(0.9, GetScore(metrics, RankingMetric.Ndcg, "GroupId"));
}

[Fact]
Expand All @@ -143,8 +143,8 @@ public void RankingMetricsNonPerfectTest()
double[] ndcg = { 0.2, 0.3, 0.4 };
double[] dcg = { 0.2, 0.3, 0.4 };
var metrics = MetricsUtil.CreateRankingMetrics(dcg, ndcg);
Assert.False(IsPerfectModel(metrics, RankingMetric.Dcg));
Assert.False(IsPerfectModel(metrics, RankingMetric.Ndcg));
Assert.False(IsPerfectModel(metrics, RankingMetric.Dcg, "GroupId"));
Assert.False(IsPerfectModel(metrics, RankingMetric.Ndcg, "GroupId"));
}

[Fact]
Expand All @@ -153,8 +153,8 @@ public void RankingMetricsPerfectTest()
double[] ndcg = { 0.2, 0.3, 1 };
double[] dcg = { 0.2, 0.3, 1 };
var metrics = MetricsUtil.CreateRankingMetrics(dcg, ndcg);
Assert.False(IsPerfectModel(metrics, RankingMetric.Dcg)); //REVIEW: No true Perfect model
Assert.True(IsPerfectModel(metrics, RankingMetric.Ndcg));
Assert.False(IsPerfectModel(metrics, RankingMetric.Dcg, "GroupId")); //REVIEW: No true Perfect model
Assert.True(IsPerfectModel(metrics, RankingMetric.Ndcg, "GroupId"));
}

[Fact]
Expand All @@ -179,9 +179,9 @@ private static double GetScore(RegressionMetrics metrics, RegressionMetric metri
return new RegressionMetricsAgent(null, metric).GetScore(metrics);
}

private static double GetScore(RankingMetrics metrics, RankingMetric metric)
private static double GetScore(RankingMetrics metrics, RankingMetric metric, string groupIdColumnName)
{
return new RankingMetricsAgent(null, metric).GetScore(metrics);
return new RankingMetricsAgent(null, metric, groupIdColumnName).GetScore(metrics);
}

private static bool IsPerfectModel(BinaryClassificationMetrics metrics, BinaryClassificationMetric metric)
Expand All @@ -202,9 +202,9 @@ private static bool IsPerfectModel(RegressionMetrics metrics, RegressionMetric m
return IsPerfectModel(metricsAgent, metrics);
}

private static bool IsPerfectModel(RankingMetrics metrics, RankingMetric metric)
private static bool IsPerfectModel(RankingMetrics metrics, RankingMetric metric, string groupIdColumnName)
{
var metricsAgent = new RankingMetricsAgent(null, metric);
var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName);
return IsPerfectModel(metricsAgent, metrics);
}

Expand Down