
Add Ranking AutoML Sample #852

Merged (17 commits), Jan 1, 2021
Changes from 1 commit
Review feedback
justinormont committed Dec 18, 2020
commit 49457a318750e5f34134798cad5992dfdfa7284e
26 changes: 13 additions & 13 deletions samples/csharp/common/AutoML/ConsoleHelper.cs
@@ -44,24 +44,24 @@ public static void PrintBinaryClassificationMetrics(string name, BinaryClassific
public static void PrintMulticlassClassificationMetrics(string name, MulticlassClassificationMetrics metrics)
{
Console.WriteLine($"************************************************************");
Console.WriteLine($"* Metrics for {name} multi-class classification model ");
Console.WriteLine($"* Metrics for {name} multi-class classification model ");
Console.WriteLine($"*-----------------------------------------------------------");
Console.WriteLine($" MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
Console.WriteLine($" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
Console.WriteLine($" MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value from 0 and 1, where closer to 1.0 is better");
Console.WriteLine($" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value from 0 and 1, where closer to 1.0 is better");
Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
Console.WriteLine($" LogLoss for class 1 = {metrics.PerClassLogLoss[0]:0.####}, the closer to 0, the better");
Console.WriteLine($" LogLoss for class 2 = {metrics.PerClassLogLoss[1]:0.####}, the closer to 0, the better");
Console.WriteLine($" LogLoss for class 3 = {metrics.PerClassLogLoss[2]:0.####}, the closer to 0, the better");
Console.WriteLine($"************************************************************");
}

-public static void PrintRankingMetrics(string name, RankingMetrics metrics)
+public static void PrintRankingMetrics(string name, RankingMetrics metrics, uint optimizationMetricTruncationLevel)
{
Console.WriteLine($"************************************************************");
Console.WriteLine($"* Metrics for {name} ranking model ");
Console.WriteLine($"* Metrics for {name} ranking model ");
Console.WriteLine($"*-----------------------------------------------------------");
Console.WriteLine($" Discounted Cumulative Gain (DCG@10) = {metrics?.DiscountedCumulativeGains?[9] ?? double.NaN:0.####}");
Console.WriteLine($" Normalized Discounted Cumulative Gain (NDCG@10) = {metrics?.NormalizedDiscountedCumulativeGains?[9] ?? double.NaN:0.####}, a value between 0 and 1, the closer to 1, the better");
Console.WriteLine($" Normalized Discounted Cumulative Gain (NDCG@{optimizationMetricTruncationLevel}) = {metrics?.NormalizedDiscountedCumulativeGains?[(int)optimizationMetricTruncationLevel - 1] ?? double.NaN:0.####}, a value from 0 and 1, where closer to 1.0 is better");
Console.WriteLine($" Discounted Cumulative Gain (DCG@{optimizationMetricTruncationLevel}) = {metrics?.DiscountedCumulativeGains?[(int)optimizationMetricTruncationLevel - 1] ?? double.NaN:0.####}");
}

public static void ShowDataViewInConsole(MLContext mlContext, IDataView dataView, int numberOfRows = 4)
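
A note for reviewers on the indexing above: `NormalizedDiscountedCumulativeGains` and `DiscountedCumulativeGains` are arrays in which element k holds the metric at truncation k+1, so a truncation level t reads index t - 1, which is why the helper subtracts one. A minimal sketch of calling the updated helper; `model`, `testDataView`, and the trainer name are illustrative, and the options overload mirrors the one used later in this diff:

    uint truncationLevel = 10;  // assumed example value
    var options = new RankingEvaluatorOptions { DcgTruncationLevel = (int)truncationLevel };
    IDataView scored = model.Transform(testDataView);  // model/testDataView assumed in scope
    RankingMetrics rankingMetrics = mlContext.Ranking.Evaluate(scored, options);
    ConsoleHelper.PrintRankingMetrics("LightGbmRanking", rankingMetrics, truncationLevel);

Setting `DcgTruncationLevel` at least as high as the print level matters: the metric arrays only hold as many entries as the evaluator was asked to compute, so indexing past that would throw.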
@@ -100,7 +100,7 @@ internal static void PrintIterationMetrics(int iteration, string trainerName, Re

internal static void PrintIterationMetrics(int iteration, string trainerName, RankingMetrics metrics, double? runtimeInSeconds)
{
CreateRow($"{iteration,-4} {trainerName,-9} {metrics?.NormalizedDiscountedCumulativeGains[0] ?? double.NaN,9:F4}, {metrics?.NormalizedDiscountedCumulativeGains[2] ?? double.NaN,9:F4}, {metrics?.NormalizedDiscountedCumulativeGains[9] ?? double.NaN,9:F4} {metrics?.DiscountedCumulativeGains[9] ?? double.NaN,9:F4} {runtimeInSeconds.Value,9:F1}", Width);
CreateRow($"{iteration,-4} {trainerName,-15} {metrics?.NormalizedDiscountedCumulativeGains[0] ?? double.NaN,9:F4} {metrics?.NormalizedDiscountedCumulativeGains[2] ?? double.NaN,9:F4} {metrics?.NormalizedDiscountedCumulativeGains[9] ?? double.NaN,9:F4} {metrics?.DiscountedCumulativeGains[9] ?? double.NaN,9:F4} {runtimeInSeconds.Value,9:F1}", Width);
}

internal static void PrintIterationException(Exception ex)
@@ -125,7 +125,7 @@ internal static void PrintRegressionMetricsHeader()

internal static void PrintRankingMetricsHeader()
{
CreateRow($"{"",-4} {"Trainer",-14}, {"nDCG@1",9}, {"nDCG@3",9}, {"nDCG@10",9}, {"DCG@10",9}, {"Duration",9}", Width);
CreateRow($"{"",-4} {"Trainer",-15} {"NDCG@1",9} {"NDCG@3",9} {"NDCG@10",9} {"DCG@10",9} {"Duration",9}", Width);
}

private static void CreateRow(string message, int width)
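
On the formatting fix itself: in C# composite format strings, `{x,-15}` left-aligns x in a 15-character field and `{x,9:F4}` right-aligns a four-decimal value in 9 characters. The row and header previously used mismatched trainer-column widths (9 vs 14) plus stray commas; unifying both on 15 and dropping the commas is what lines the metric columns up under the header. A tiny standalone illustration (values invented):

    double ndcg1 = 0.4321, ndcg10 = 0.5678;
    Console.WriteLine($"{"",-4} {"Trainer",-15} {"NDCG@1",9} {"NDCG@10",9}");
    Console.WriteLine($"{1,-4} {"LightGbmRanking",-15} {ndcg1,9:F4} {ndcg10,9:F4}");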
@@ -258,10 +258,10 @@ private void AppendTableRow(ICollection<string[]> tableRows,

tableRows.Add(new[]
{
columnName,
GetColumnDataType(columnName),
columnPurpose
});
}

private void AppendTableRows(ICollection<string[]> tableRows,
@@ -1,4 +1,4 @@
using Common;
using Microsoft.ML;
using Microsoft.ML.AutoML;
using Microsoft.ML.Data;
@@ -28,23 +28,21 @@ class Program
// Runtime should allow for the sweeping to plateau, which begins near iteration 60
private static uint ExperimentTime = 600;

-private static IDataView _predictions = null;

static void Main(string[] args)
{
var mlContext = new MLContext(seed: 0);

// Create, train, evaluate and save a model
-BuildTrainEvaluateAndSaveModel(mlContext);
+(var model, var predictions) = BuildTrainEvaluateAndSaveModel(mlContext);

// Make a single test prediction loading the model from .ZIP file
-TestSinglePrediction(mlContext);
+TestSinglePrediction(mlContext, predictions);

Console.WriteLine("=============== End of process, hit any key to finish ===============");
Console.ReadKey();
}
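
Since `BuildTrainEvaluateAndSaveModel` now returns a value tuple, a quick note on the call above: `(var model, var predictions) = ...` deconstructs the `(ITransformer, IDataView)` pair into two locals, and the shorthand form is equivalent:

    // Both forms bind the (ITransformer, IDataView) tuple returned by the method.
    (var model, var predictions) = BuildTrainEvaluateAndSaveModel(mlContext);
    // var (model, predictions) = BuildTrainEvaluateAndSaveModel(mlContext);  // equivalent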

-private static ITransformer BuildTrainEvaluateAndSaveModel(MLContext mlContext)
+private static (ITransformer, IDataView) BuildTrainEvaluateAndSaveModel(MLContext mlContext)
{
// STEP 1: Download and load the data
GetData(InputPath, OutputPath, TrainDatasetPath, TrainDatasetUrl, TestDatasetUrl, TestDatasetPath,
@@ -96,69 +94,58 @@ private static ITransformer BuildTrainEvaluateAndSaveModel(MLContext mlContext)
progressHandler: progressHandler);

// Print top models found by AutoML
-Console.WriteLine();
-PrintTopModels(experimentResult);
+Console.WriteLine("\n===== Evaluating model's NDCG (on validation data) =====");
+PrintTopModels(experimentResult, experimentSettings.OptimizationMetricTruncationLevel);

+var rankingEvaluatorOptions = new RankingEvaluatorOptions
+{
+DcgTruncationLevel = Math.Min(10, (int)experimentSettings.OptimizationMetricTruncationLevel * 2)
+};

Console.WriteLine("\n===== Evaluating model's NDCG (on test data) =====");
IDataView predictions = experimentResult.BestRun.Model.Transform(testDataView);
var metrics = mlContext.Ranking.Evaluate(predictions, rankingEvaluatorOptions);
ConsoleHelper.PrintRankingMetrics(experimentResult.BestRun.TrainerName, metrics, experimentSettings.OptimizationMetricTruncationLevel);
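
The `Math.Min(10, ... * 2)` above is worth a comment: evaluating out to twice the optimization truncation level shows how the ranking holds up past the cutoff AutoML optimized for, while the cap of 10 presumably keeps the request within the evaluator's supported truncation depth (my assumption; the diff does not say). Concretely:

    // Illustrative values only: how the evaluator's truncation depth is derived.
    // OptimizationMetricTruncationLevel = 3  -> DcgTruncationLevel = Math.Min(10, 6)  = 6
    // OptimizationMetricTruncationLevel = 10 -> DcgTruncationLevel = Math.Min(10, 20) = 10
    uint optimizationLevel = 10;  // assumed example value
    int dcgTruncationLevel = Math.Min(10, (int)optimizationLevel * 2);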

+// STEP 5: Refit the model with all available data
// Re-fit best pipeline on train and validation data, to produce
// a model that is trained on as much data as is available while
// still having test data for the final estimate of how well the
// model will do in production.
Console.WriteLine("\n===== Refitting on train+valid and evaluating model's nDCG with test data =====");
Console.WriteLine("\n===== Refitting on train+valid and evaluating model's NDCG (on test data) =====");
var trainPlusValidationDataView = textLoader.Load(new MultiFileSource(TrainDatasetPath, ValidationDatasetPath));

var refitModel = experimentResult.BestRun.Estimator.Fit(trainPlusValidationDataView);

-IDataView predictionsRefitOnTrainPlusValidation = refitModel.Transform(validationDataView);

-// Setting the DCG truncation level
-var rankingEvaluatorOptions = new RankingEvaluatorOptions { DcgTruncationLevel = 10 };

+IDataView predictionsRefitOnTrainPlusValidation = refitModel.Transform(testDataView);
var metricsRefitOnTrainPlusValidation = mlContext.Ranking.Evaluate(predictionsRefitOnTrainPlusValidation, rankingEvaluatorOptions);
+ConsoleHelper.PrintRankingMetrics(experimentResult.BestRun.TrainerName, metricsRefitOnTrainPlusValidation, experimentSettings.OptimizationMetricTruncationLevel);

-ConsoleHelper.PrintRankingMetrics(experimentResult.BestRun.TrainerName, metricsRefitOnTrainPlusValidation);

-// Re-fit best pipeline on train, validation, and test data, to
+// STEP 6: Refit the model with all available data
+// Re-fit best pipeline again on train, validation, and test data, to
// produce a model that is trained on as much data as is available.
// This is the final model that can be deployed to production.
// No metrics are printed since we no longer have an independent
// scoring dataset.
Console.WriteLine("\n===== Refitting on train+valid+test to get the final model to launch to production =====");
var trainPlusValidationPlusTestDataView = textLoader.Load(new MultiFileSource(TrainDatasetPath, ValidationDatasetPath, TestDatasetPath));

-var refitModelWithValidationSet = experimentResult.BestRun.Estimator.Fit(trainPlusValidationPlusTestDataView);

-IDataView predictionsRefitOnTrainValidationPlusTest = refitModelWithValidationSet.Transform(testDataView);

-var metricsRefitOnTrainValidationPlusTest = mlContext.Ranking.Evaluate(predictionsRefitOnTrainValidationPlusTest, rankingEvaluatorOptions);

-ConsoleHelper.PrintRankingMetrics(experimentResult.BestRun.TrainerName, metricsRefitOnTrainValidationPlusTest);

-// STEP 5: Evaluate the model and print metrics
-ConsoleHelper.ConsoleWriteHeader("=============== Evaluating model's nDCG with test data ===============");
-RunDetail<RankingMetrics> bestRun = experimentResult.BestRun;

-ITransformer trainedModel = bestRun.Model;
-_predictions = trainedModel.Transform(testDataView);

-var metrics = mlContext.Ranking.Evaluate(_predictions, rankingEvaluatorOptions);

-ConsoleHelper.PrintRankingMetrics(bestRun.TrainerName, metrics);

-// STEP 6: Save/persist the trained model to a .ZIP file
-mlContext.Model.Save(trainedModel, trainDataView.Schema, ModelPath);
+var refitModelOnTrainValidTest = experimentResult.BestRun.Estimator.Fit(trainPlusValidationPlusTestDataView);

+// STEP 7: Save/persist the trained model to a .ZIP file
+mlContext.Model.Save(refitModelOnTrainValidTest, trainDataView.Schema, ModelPath);

Console.WriteLine("The model is saved to {0}", ModelPath);

-return trainedModel;
+return (refitModelOnTrainValidTest, predictionsRefitOnTrainPlusValidation);
}

-private static void TestSinglePrediction(MLContext mlContext)
+private static void TestSinglePrediction(MLContext mlContext, IDataView predictions)
{
ConsoleHelper.ConsoleWriteHeader("=============== Testing prediction engine ===============");

ITransformer trainedModel = mlContext.Model.Load(ModelPath, out var modelInputSchema);
Console.WriteLine($"=============== Loaded Model OK ===============");

// In the predictions, get the scores of the search results included in the first query (e.g. group).
-var searchQueries = mlContext.Data.CreateEnumerable<RankingPrediction>(_predictions, reuseRowObject: false);
+var searchQueries = mlContext.Data.CreateEnumerable<RankingPrediction>(predictions, reuseRowObject: false);
var firstGroupId = searchQueries.First().GroupId;
var firstGroupPredictions = searchQueries.Take(100).Where(p => p.GroupId == firstGroupId).OrderByDescending(p => p.Score).ToList();
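
`RankingPrediction` isn't defined in this diff; based on how it's consumed here (read back row by row with a `GroupId` and a `Score`), a plausible shape is sketched below. The real sample supplies its own class, so treat these names as assumptions:

    // Hypothetical POCO for reading scored rows back out of the predictions IDataView;
    // property names must match the column names in that view.
    public class RankingPrediction
    {
        public uint GroupId { get; set; }
        public float Score { get; set; }
    }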

@@ -215,14 +202,14 @@ private static void GetData(string inputPath, string outputPath, string trainDat
Console.WriteLine("===== Download is finished =====\n");
}

-private static void PrintTopModels(ExperimentResult<RankingMetrics> experimentResult)
+private static void PrintTopModels(ExperimentResult<RankingMetrics> experimentResult, uint optimizationMetricTruncationLevel)
{
-// Get top few runs ranked by nDCG
+// Get top few runs ordered by NDCG
var topRuns = experimentResult.RunDetails
-.Where(r => r.ValidationMetrics != null && !double.IsNaN(r.ValidationMetrics.NormalizedDiscountedCumulativeGains[0]))
-.OrderByDescending(r => r.ValidationMetrics.NormalizedDiscountedCumulativeGains[9]).Take(5);
+.Where(r => r.ValidationMetrics != null && !double.IsNaN(r.ValidationMetrics.NormalizedDiscountedCumulativeGains[(int)optimizationMetricTruncationLevel - 1]))
+.OrderByDescending(r => r.ValidationMetrics.NormalizedDiscountedCumulativeGains[(int)optimizationMetricTruncationLevel - 1]).Take(5);

Console.WriteLine("Top models ranked by nDCG --");
Console.WriteLine($"Top models ordered by NDCG@{optimizationMetricTruncationLevel}");
ConsoleHelper.PrintRankingMetricsHeader();
for (var i = 0; i < topRuns.Count(); i++)
{