Skip to content

Enabling custom groupId column in the Ranking AutoML experiment #5246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 30, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
groupid addition option 2
  • Loading branch information
Lynx1820 committed Jun 29, 2020
commit 63e4d51a9e8fa24a97d81c11c0ec4184679a84ba
19 changes: 14 additions & 5 deletions src/Microsoft.ML.AutoML/API/RankingExperiment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ public sealed class RankingExperimentSettings : ExperimentSettings
/// <value>The default value is <see cref="RankingMetric.Ndcg" />.</value>
public RankingMetric OptimizingMetric { get; set; }

/// <summary>
/// Name for the GroupId column.
/// </summary>
/// <value>The default value is "GroupId".</value>
public string GroupIdColumnName { get; set; }

/// <summary>
/// Collection of trainers the AutoML experiment can leverage.
/// </summary>
Expand All @@ -28,6 +34,7 @@ public sealed class RankingExperimentSettings : ExperimentSettings
public ICollection<RankingTrainer> Trainers { get; }
public RankingExperimentSettings()
{
    // Start with every value declared on the RankingTrainer enum.
    Trainers = Enum.GetValues(typeof(RankingTrainer)).OfType<RankingTrainer>().ToList();

    // NDCG is the metric optimized unless the caller picks another one.
    OptimizingMetric = RankingMetric.Ndcg;

    // Column name used for the group id unless the caller overrides it.
    GroupIdColumnName = "GroupId";
}
Expand Down Expand Up @@ -68,10 +75,11 @@ public static class RankingExperimentResultExtensions
/// </summary>
/// <param name="results">Enumeration of AutoML experiment run results.</param>
/// <param name="metric">Metric to consider when selecting the best run.</param>
/// <param name="groupIdColumnName">Name for the GroupId column.</param>
/// <returns>The best experiment run.</returns>
public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg)
public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, string groupIdColumnName = "GroupId")
{
var metricsAgent = new RankingMetricsAgent(null, metric);
var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName);
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
}
Expand All @@ -81,10 +89,11 @@ public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingM
/// </summary>
/// <param name="results">Enumeration of AutoML experiment cross validation run results.</param>
/// <param name="metric">Metric to consider when selecting the best run.</param>
/// <param name="groupIdColumnName">Name for the GroupId column.</param>
/// <returns>The best experiment run.</returns>
public static CrossValidationRunDetail<RankingMetrics> Best(this IEnumerable<CrossValidationRunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg)
public static CrossValidationRunDetail<RankingMetrics> Best(this IEnumerable<CrossValidationRunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, string groupIdColumnName = "GroupId")
{
var metricsAgent = new RankingMetricsAgent(null, metric);
var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName);
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
}
Expand All @@ -103,7 +112,7 @@ public sealed class RankingExperiment : ExperimentBase<RankingMetrics, RankingEx
{
internal RankingExperiment(MLContext context, RankingExperimentSettings settings)
: base(context,
new RankingMetricsAgent(context, settings.OptimizingMetric),
new RankingMetricsAgent(context, settings.OptimizingMetric, settings.GroupIdColumnName),
new OptimizingMetricInfo(settings.OptimizingMetric),
settings,
TaskKind.Ranking,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public bool IsModelPerfect(double score)
}
}

public BinaryClassificationMetrics EvaluateMetrics(IDataView data, string labelColumn, string groupId)
public BinaryClassificationMetrics EvaluateMetrics(IDataView data, string labelColumn)
{
return _mlContext.BinaryClassification.EvaluateNonCalibrated(data, labelColumn);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ internal interface IMetricsAgent<T>

bool IsModelPerfect(double score);

T EvaluateMetrics(IDataView data, string labelColumn, string groupId);
T EvaluateMetrics(IDataView data, string labelColumn);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public bool IsModelPerfect(double score)
}
}

public MulticlassClassificationMetrics EvaluateMetrics(IDataView data, string labelColumn, string groupId)
public MulticlassClassificationMetrics EvaluateMetrics(IDataView data, string labelColumn)
{
return _mlContext.MulticlassClassification.Evaluate(data, labelColumn);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ internal class RankingMetricsAgent : IMetricsAgent<RankingMetrics>
{
private readonly MLContext _mlContext;
private readonly RankingMetric _optimizingMetric;
private readonly string _groupIdColumnName;

public RankingMetricsAgent(MLContext mlContext, RankingMetric optimizingMetric)
public RankingMetricsAgent(MLContext mlContext, RankingMetric optimizingMetric, string groupIdColumnName)
{
_mlContext = mlContext;
_optimizingMetric = optimizingMetric;
_groupIdColumnName = groupIdColumnName;
}

// Optimizing metric used: NDCG@10 and DCG@10
Expand Down Expand Up @@ -57,9 +59,9 @@ public bool IsModelPerfect(double score)
}
}

public RankingMetrics EvaluateMetrics(IDataView data, string labelColumn, string groupIdColumn)
public RankingMetrics EvaluateMetrics(IDataView data, string labelColumn)
{
return _mlContext.Ranking.Evaluate(data, labelColumn, groupIdColumn);
return _mlContext.Ranking.Evaluate(data, labelColumn, _groupIdColumnName);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public bool IsModelPerfect(double score)
}
}

public RegressionMetrics EvaluateMetrics(IDataView data, string labelColumn, string groupId)
public RegressionMetrics EvaluateMetrics(IDataView data, string labelColumn)
{
return _mlContext.Regression.Evaluate(data, labelColumn);
}
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.AutoML/Experiment/Runners/CrossValRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ public CrossValRunner(MLContext context,
IMetricsAgent<TMetrics> metricsAgent,
IEstimator<ITransformer> preFeaturizer,
ITransformer[] preprocessorTransforms,
string labelColumn,
string groupIdColumn,
string labelColumn,
IChannel logger)
{
_context = context;
Expand All @@ -55,7 +55,7 @@ public CrossValRunner(MLContext context,
{
var modelFileInfo = RunnerUtil.GetModelFileInfo(modelDirectory, iterationNum, i + 1);
var trainResult = RunnerUtil.TrainAndScorePipeline(_context, pipeline, _trainDatasets[i], _validDatasets[i],
_labelColumn, _groupIdColumn, _metricsAgent, _preprocessorTransforms?[i], modelFileInfo, _modelInputSchema, _logger);
_labelColumn, _metricsAgent, _preprocessorTransforms?[i], modelFileInfo, _modelInputSchema, _logger);
trainResults.Add(new SuggestedPipelineTrainResult<TMetrics>(trainResult.model, trainResult.metrics, trainResult.exception, trainResult.score));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public CrossValSummaryRunner(MLContext context,
{
var modelFileInfo = RunnerUtil.GetModelFileInfo(modelDirectory, iterationNum, i + 1);
var trainResult = RunnerUtil.TrainAndScorePipeline(_context, pipeline, _trainDatasets[i], _validDatasets[i],
_labelColumn, _groupIdColumn,_metricsAgent, _preprocessorTransforms?.ElementAt(i), modelFileInfo, _modelInputSchema,
_labelColumn, _metricsAgent, _preprocessorTransforms?.ElementAt(i), modelFileInfo, _modelInputSchema,
_logger);
trainResults.Add(trainResult);
}
Expand Down
3 changes: 1 addition & 2 deletions src/Microsoft.ML.AutoML/Experiment/Runners/RunnerUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ public static (ModelContainer model, TMetrics metrics, Exception exception, doub
IDataView trainData,
IDataView validData,
string labelColumn,
string groupId,
IMetricsAgent<TMetrics> metricsAgent,
ITransformer preprocessorTransform,
FileInfo modelFileInfo,
Expand All @@ -29,7 +28,7 @@ public static (ModelContainer model, TMetrics metrics, Exception exception, doub
var model = estimator.Fit(trainData);

var scoredData = model.Transform(validData);
var metrics = metricsAgent.EvaluateMetrics(scoredData, labelColumn, groupId);
var metrics = metricsAgent.EvaluateMetrics(scoredData, labelColumn);
var score = metricsAgent.GetScore(metrics);

if (preprocessorTransform != null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public TrainValidateRunner(MLContext context,
{
var modelFileInfo = GetModelFileInfo(modelDirectory, iterationNum);
var trainResult = RunnerUtil.TrainAndScorePipeline(_context, pipeline, _trainData, _validData,
_labelColumn, _groupIdColumn, _metricsAgent, _preprocessorTransform, modelFileInfo, _modelInputSchema, _logger);
_labelColumn, _metricsAgent, _preprocessorTransform, modelFileInfo, _modelInputSchema, _logger);
var suggestedPipelineRunDetail = new SuggestedPipelineRunDetail<TMetrics>(pipeline,
trainResult.score,
trainResult.exception == null,
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.AutoML/Utils/BestResultUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ public static RunDetail<MulticlassClassificationMetrics> GetBestRun(IEnumerable<
}

public static RunDetail<RankingMetrics> GetBestRun(IEnumerable<RunDetail<RankingMetrics>> results,
RankingMetric metric)
RankingMetric metric, string groupIdColumnName)
{
var metricsAgent = new RankingMetricsAgent(null, metric);
var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName);
var metricInfo = new OptimizingMetricInfo(metric);
return GetBestRun(results, metricsAgent, metricInfo.IsMaximizing);
}
Expand Down
2 changes: 1 addition & 1 deletion test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ public void AutoFitRankingTest()
trainDataView = mlContext.Data.SkipRows(trainDataView, 500);
// STEP 2: Run AutoML experiment
ExperimentResult<RankingMetrics> experimentResult = mlContext.Auto()
.CreateRankingExperiment(5)
.CreateRankingExperiment(new RankingExperimentSettings() { GroupIdColumnName = "CustomGroupId", MaxExperimentTimeInSeconds = 5})
.Execute(trainDataView, testDataView,
new ColumnInformation()
{
Expand Down
24 changes: 12 additions & 12 deletions test/Microsoft.ML.AutoML.Tests/MetricsAgentsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,14 @@ public void RankingMetricsGetScoreTest()
double[] ndcg = { 0.2, 0.3, 0.4 };
double[] dcg = { 0.2, 0.3, 0.4 };
var metrics = MetricsUtil.CreateRankingMetrics(dcg, ndcg);
Assert.Equal(0.4, GetScore(metrics, RankingMetric.Dcg));
Assert.Equal(0.4, GetScore(metrics, RankingMetric.Ndcg));
Assert.Equal(0.4, GetScore(metrics, RankingMetric.Dcg, "GroupId"));
Assert.Equal(0.4, GetScore(metrics, RankingMetric.Ndcg, "GroupId"));

double[] largeNdcg = { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95 };
double[] largeDcg = { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95 };
metrics = MetricsUtil.CreateRankingMetrics(largeDcg, largeNdcg);
Assert.Equal(0.9, GetScore(metrics, RankingMetric.Dcg));
Assert.Equal(0.9, GetScore(metrics, RankingMetric.Ndcg));
Assert.Equal(0.9, GetScore(metrics, RankingMetric.Dcg, "GroupId"));
Assert.Equal(0.9, GetScore(metrics, RankingMetric.Ndcg, "GroupId"));
}

[Fact]
Expand All @@ -143,8 +143,8 @@ public void RankingMetricsNonPerfectTest()
double[] ndcg = { 0.2, 0.3, 0.4 };
double[] dcg = { 0.2, 0.3, 0.4 };
var metrics = MetricsUtil.CreateRankingMetrics(dcg, ndcg);
Assert.False(IsPerfectModel(metrics, RankingMetric.Dcg));
Assert.False(IsPerfectModel(metrics, RankingMetric.Ndcg));
Assert.False(IsPerfectModel(metrics, RankingMetric.Dcg, "GroupId"));
Assert.False(IsPerfectModel(metrics, RankingMetric.Ndcg, "GroupId"));
}

[Fact]
Expand All @@ -153,8 +153,8 @@ public void RankingMetricsPerfectTest()
double[] ndcg = { 0.2, 0.3, 1 };
double[] dcg = { 0.2, 0.3, 1 };
var metrics = MetricsUtil.CreateRankingMetrics(dcg, ndcg);
Assert.False(IsPerfectModel(metrics, RankingMetric.Dcg)); //REVIEW: No true Perfect model
Assert.True(IsPerfectModel(metrics, RankingMetric.Ndcg));
Assert.False(IsPerfectModel(metrics, RankingMetric.Dcg, "GroupId")); //REVIEW: No true Perfect model
Assert.True(IsPerfectModel(metrics, RankingMetric.Ndcg, "GroupId"));
}

[Fact]
Expand All @@ -179,9 +179,9 @@ private static double GetScore(RegressionMetrics metrics, RegressionMetric metri
return new RegressionMetricsAgent(null, metric).GetScore(metrics);
}

private static double GetScore(RankingMetrics metrics, RankingMetric metric)
private static double GetScore(RankingMetrics metrics, RankingMetric metric, string groupIdColumnName)
{
return new RankingMetricsAgent(null, metric).GetScore(metrics);
return new RankingMetricsAgent(null, metric, groupIdColumnName).GetScore(metrics);
}

private static bool IsPerfectModel(BinaryClassificationMetrics metrics, BinaryClassificationMetric metric)
Expand All @@ -202,9 +202,9 @@ private static bool IsPerfectModel(RegressionMetrics metrics, RegressionMetric m
return IsPerfectModel(metricsAgent, metrics);
}

private static bool IsPerfectModel(RankingMetrics metrics, RankingMetric metric)
private static bool IsPerfectModel(RankingMetrics metrics, RankingMetric metric, string groupIdColumnName)
{
var metricsAgent = new RankingMetricsAgent(null, metric);
var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName);
return IsPerfectModel(metricsAgent, metrics);
}

Expand Down