Skip to content

Add Ranking AutoML Sample #852

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jan 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions samples/csharp/common/AutoML/ConsoleHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,26 @@ public static void PrintBinaryClassificationMetrics(string name, BinaryClassific
public static void PrintMulticlassClassificationMetrics(string name, MulticlassClassificationMetrics metrics)
{
Console.WriteLine($"************************************************************");
Console.WriteLine($"* Metrics for {name} multi-class classification model ");
Console.WriteLine($"* Metrics for {name} multi-class classification model ");
Console.WriteLine($"*-----------------------------------------------------------");
Console.WriteLine($" MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
Console.WriteLine($" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
Console.WriteLine($" MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value from 0 and 1, where closer to 1.0 is better");
Console.WriteLine($" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value from 0 and 1, where closer to 1.0 is better");
Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
Console.WriteLine($" LogLoss for class 1 = {metrics.PerClassLogLoss[0]:0.####}, the closer to 0, the better");
Console.WriteLine($" LogLoss for class 2 = {metrics.PerClassLogLoss[1]:0.####}, the closer to 0, the better");
Console.WriteLine($" LogLoss for class 3 = {metrics.PerClassLogLoss[2]:0.####}, the closer to 0, the better");
Console.WriteLine($"************************************************************");
}

public static void PrintRankingMetrics(string name, RankingMetrics metrics, uint optimizationMetricTruncationLevel)
{
Console.WriteLine($"************************************************************");
Console.WriteLine($"* Metrics for {name} ranking model ");
Console.WriteLine($"*-----------------------------------------------------------");
Console.WriteLine($" Normalized Discounted Cumulative Gain (NDCG@{optimizationMetricTruncationLevel}) = {metrics?.NormalizedDiscountedCumulativeGains?[(int)optimizationMetricTruncationLevel - 1] ?? double.NaN:0.####}, a value from 0 and 1, where closer to 1.0 is better");
Console.WriteLine($" Discounted Cumulative Gain (DCG@{optimizationMetricTruncationLevel}) = {metrics?.DiscountedCumulativeGains?[(int)optimizationMetricTruncationLevel - 1] ?? double.NaN:0.####}");
}

public static void ShowDataViewInConsole(MLContext mlContext, IDataView dataView, int numberOfRows = 4)
{
string msg = string.Format("Show data in DataView: Showing {0} rows with the columns", numberOfRows.ToString());
Expand Down Expand Up @@ -89,6 +98,11 @@ internal static void PrintIterationMetrics(int iteration, string trainerName, Re
CreateRow($"{iteration,-4} {trainerName,-35} {metrics?.RSquared ?? double.NaN,8:F4} {metrics?.MeanAbsoluteError ?? double.NaN,13:F2} {metrics?.MeanSquaredError ?? double.NaN,12:F2} {metrics?.RootMeanSquaredError ?? double.NaN,8:F2} {runtimeInSeconds.Value,9:F1}", Width);
}

internal static void PrintIterationMetrics(int iteration, string trainerName, RankingMetrics metrics, double? runtimeInSeconds)
{
CreateRow($"{iteration,-4} {trainerName,-15} {metrics?.NormalizedDiscountedCumulativeGains[0] ?? double.NaN,9:F4} {metrics?.NormalizedDiscountedCumulativeGains[2] ?? double.NaN,9:F4} {metrics?.NormalizedDiscountedCumulativeGains[9] ?? double.NaN,9:F4} {metrics?.DiscountedCumulativeGains[9] ?? double.NaN,9:F4} {runtimeInSeconds.Value,9:F1}", Width);
}

internal static void PrintIterationException(Exception ex)
{
Console.WriteLine($"Exception during AutoML iteration: {ex}");
Expand All @@ -109,6 +123,11 @@ internal static void PrintRegressionMetricsHeader()
CreateRow($"{"",-4} {"Trainer",-35} {"RSquared",8} {"Absolute-loss",13} {"Squared-loss",12} {"RMS-loss",8} {"Duration",9}", Width);
}

internal static void PrintRankingMetricsHeader()
{
CreateRow($"{"",-4} {"Trainer",-15} {"NDCG@1",9} {"NDCG@3",9} {"NDCG@10",9} {"DCG@10",9} {"Duration",9}", Width);
}

private static void CreateRow(string message, int width)
{
Console.WriteLine("|" + message.PadRight(width - 2) + "|");
Expand Down Expand Up @@ -239,10 +258,10 @@ private void AppendTableRow(ICollection<string[]> tableRows,

tableRows.Add(new[]
{
columnName,
GetColumnDataType(columnName),
columnPurpose
});
columnName,
GetColumnDataType(columnName),
columnPurpose
});
}

private void AppendTableRows(ICollection<string[]> tableRows,
Expand Down
23 changes: 23 additions & 0 deletions samples/csharp/common/AutoML/ProgressHandlers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,27 @@ public void Report(RunDetail<RegressionMetrics> iterationResult)
}
}
}

public class RankingExperimentProgressHandler : IProgress<RunDetail<RankingMetrics>>
{
private int _iterationIndex;

public void Report(RunDetail<RankingMetrics> iterationResult)
{
if (_iterationIndex++ == 0)
{
ConsoleHelper.PrintRankingMetricsHeader();
}

if (iterationResult.Exception != null)
{
ConsoleHelper.PrintIterationException(iterationResult.Exception);
}
else
{
ConsoleHelper.PrintIterationMetrics(_iterationIndex, iterationResult.TrainerName,
iterationResult.ValidationMetrics, iterationResult.RuntimeInSeconds);
}
}
}
}
Loading