Skip to content

Added RankingEvaluatorOptions and removed the truncation limit. #4081

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 58 commits into from
Oct 10, 2019
Merged
Changes from 1 commit
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
8749a10
Merge pull request #1 from dotnet/master
harishsk Jun 20, 2019
fe25bf6
Fixed build errors resulting from upgrade to VS2019 compilers
harishsk Jun 21, 2019
cb446be
Added additional message describing the previous fix
harishsk Jun 21, 2019
b5ee220
Merge pull request #2 from dotnet/master
harishsk Jun 21, 2019
b9a7471
Merge pull request #3 from dotnet/master
harishsk Jun 24, 2019
80e238d
Merge pull request #4 from dotnet/master
harishsk Jun 27, 2019
2ef424d
Merge pull request #5 from dotnet/master
harishsk Jul 9, 2019
3958f01
Merge pull request #6 from dotnet/master
harishsk Aug 7, 2019
56d4595
Fixed #3993
harishsk Aug 7, 2019
00bc7ef
Merge pull request #7 from dotnet/master
harishsk Aug 14, 2019
d0462f1
Merge pull request #8 from dotnet/master
harishsk Aug 16, 2019
87cefbc
Merge branch 'master' into bugfix_3993
harishsk Aug 16, 2019
c3a908b
Reverted previous change and added a separate class to control evalua…
harishsk Aug 21, 2019
c0a430a
Merge pull request #9 from dotnet/master
harishsk Aug 21, 2019
0b55903
Syncing upstream fork (#10)
harishsk Aug 30, 2019
56983d5
Syncing upstream fork (#11)
harishsk Aug 30, 2019
3382d1d
Merge pull request #13 from dotnet/master
harishsk Sep 6, 2019
8ca5d01
Merge remote-tracking branch 'upstream/master'
harishsk Sep 10, 2019
4ac459e
Added unit test for ranking evaluation with options
harishsk Sep 11, 2019
8f20ea4
Merge remote-tracking branch 'upstream/master'
harishsk Sep 11, 2019
f9f9e1d
Changed visibility of OutputGroupSummary to internal until we expose …
harishsk Sep 11, 2019
21cb8f3
Added a unit test for maml ranking evaluation
harishsk Sep 12, 2019
138f201
Merge remote-tracking branch 'upstream/master'
harishsk Sep 13, 2019
55e3460
Merge remote-tracking branch 'upstream/master'
harishsk Sep 13, 2019
e43bba3
Merge remote-tracking branch 'upstream/master'
harishsk Sep 16, 2019
421d713
Merge branch 'master' of ssh://github.com/harishsk/machinelearning
harishsk Sep 16, 2019
4f4f81c
Merge remote-tracking branch 'upstream/master'
harishsk Sep 17, 2019
89082a5
Merge branch 'master' into bugfix_3993
harishsk Sep 17, 2019
f167af8
Merge branch 'master' of ssh://github.com/harishsk/machinelearning
harishsk Sep 17, 2019
0d4d34f
getting rid of maxTruncationLevel
Lynx1820 Sep 18, 2019
6cd2f15
removed unnecessary imports
Lynx1820 Sep 18, 2019
1424ab3
removed old comment
Lynx1820 Sep 18, 2019
3ee03ca
removed old comment
Lynx1820 Sep 18, 2019
34b7a91
Merge remote-tracking branch 'upstream/master'
harishsk Sep 19, 2019
5539127
Merge remote-tracking branch 'upstream/master'
harishsk Sep 19, 2019
02053a6
adding a maml baseline test
Lynx1820 Sep 23, 2019
35ad3c0
reformatted maml test and moved to where all the other maml tests are
Lynx1820 Sep 25, 2019
0eb3e2b
Reverted back some spacing. Accidentally reverted some changes in Eva…
Lynx1820 Sep 25, 2019
a3291b1
added more relaxed precision
Lynx1820 Sep 25, 2019
37af437
Merge remote-tracking branch 'upstream/master'
harishsk Sep 26, 2019
68f1f35
baseline precision set 1
Lynx1820 Sep 26, 2019
5b90a34
added more helpful debugging comment
Lynx1820 Sep 26, 2019
0efe238
temp
Lynx1820 Sep 27, 2019
b6584aa
Merge remote-tracking branch 'upstream/master'
harishsk Sep 27, 2019
7d47832
Merge branch 'master' of ssh://github.com/harishsk/machinelearning
harishsk Sep 27, 2019
0e99776
Merge branch 'master' into bugfix_3993
harishsk Sep 27, 2019
20a4490
Revert "temp"
Lynx1820 Sep 30, 2019
0d111f4
Revert "added more relaxed precision"
Lynx1820 Sep 30, 2019
72d1a4d
using fastRankRanking to test RankingEvaluator instead of lightgbm
Lynx1820 Sep 30, 2019
ea9ebed
Merge branch 'bugfix_3993' of git://github.com/harishsk/machinelearni…
Lynx1820 Sep 30, 2019
013be4f
reverted some indenting and LightGBM imports
Lynx1820 Sep 30, 2019
d2ae365
removed the stratification flag to get real numbers
Lynx1820 Oct 7, 2019
a9e6db8
testcase change added
Lynx1820 Oct 7, 2019
5855f99
Added the Bestfriend attribute back
Lynx1820 Oct 7, 2019
724bb12
Changed implementation of discount map computation to minimize mutabl…
harishsk Oct 9, 2019
d009f55
Added doc strings
harishsk Oct 9, 2019
8f7b6cd
Fixed bug in creation of fixed discount map
harishsk Oct 9, 2019
30d56a0
Merge remote-tracking branch 'upstream/master' into bugfix_3993
harishsk Oct 10, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 39 additions & 34 deletions src/Microsoft.ML.Data/Evaluators/RankingEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
Expand Down Expand Up @@ -73,7 +74,6 @@ public RankingEvaluator(IHostEnvironment env, RankingEvaluatorOptions options)

_truncationLevel = options.DcgTruncationLevel;
_groupSummary = options.OutputGroupSummary;
RankingUtils.SetTruncationLevel(options.DcgTruncationLevel);

var labelGains = new List<Double>();
string[] gains = options.LabelGains.Split(',');
Expand Down Expand Up @@ -289,6 +289,7 @@ public sealed class Counters
private readonly List<short> _queryLabels;
private readonly List<Single> _queryOutputs;
private readonly Double[] _labelGains;
private readonly Double[] _discountMap;

public bool GroupSummary { get { return _groupNdcg != null; } }

Expand Down Expand Up @@ -350,6 +351,8 @@ public Counters(Double[] labelGains, int truncationLevel, bool groupSummary)
Contracts.AssertValue(labelGains);

TruncationLevel = truncationLevel;
_discountMap = RankingUtils.GetDiscountMap(truncationLevel);

_sumDcgAtN = new Double[TruncationLevel];
_sumNdcgAtN = new Double[TruncationLevel];

Expand All @@ -375,15 +378,15 @@ public void Update(short label, Single output)

public void UpdateGroup(Single weight)
{
RankingUtils.QueryMaxDcg(_labelGains, TruncationLevel, _queryLabels, _queryOutputs, _groupMaxDcgCur);
RankingUtils.QueryMaxDcg(_labelGains, TruncationLevel, _discountMap, _queryLabels, _queryOutputs, _groupMaxDcgCur);
if (_groupMaxDcg != null)
{
var maxDcg = new Double[TruncationLevel];
Array.Copy(_groupMaxDcgCur, maxDcg, TruncationLevel);
_groupMaxDcg.Add(maxDcg);
}

RankingUtils.QueryDcg(_labelGains, TruncationLevel, _queryLabels, _queryOutputs, _groupDcgCur);
RankingUtils.QueryDcg(_labelGains, TruncationLevel, _discountMap, _queryLabels, _queryOutputs, _groupDcgCur);
if (_groupDcg != null)
{
var groupDcg = new Double[TruncationLevel];
Expand Down Expand Up @@ -686,6 +689,7 @@ private void SlotNamesGetter(int iinfo, ref VBuffer<ReadOnlyMemory<char>> dst)

private readonly Bindings _bindings;
private readonly int _truncationLevel;
private readonly Double[] _discountMap;
private readonly Double[] _labelGains;

public Transform(IHostEnvironment env, IDataView input, string labelCol, string scoreCol, string groupCol,
Expand All @@ -697,6 +701,7 @@ public Transform(IHostEnvironment env, IDataView input, string labelCol, string
Host.CheckValue(labelGains, nameof(labelGains));

_truncationLevel = truncationLevel;
_discountMap = RankingUtils.GetDiscountMap(_truncationLevel);
_labelGains = labelGains;
_bindings = new Bindings(Host, Source.Schema, true, LabelCol, ScoreCol, GroupCol, _truncationLevel);
}
Expand Down Expand Up @@ -802,9 +807,9 @@ protected override void ProcessExample(RowCursorState state, short label, Single
protected override void UpdateState(RowCursorState state)
{
// Calculate the current group DCG, NDCG and MaxDcg.
RankingUtils.QueryMaxDcg(_labelGains, _truncationLevel, state.QueryLabels, state.QueryOutputs,
RankingUtils.QueryMaxDcg(_labelGains, _truncationLevel, _discountMap, state.QueryLabels, state.QueryOutputs,
state.MaxDcgCur);
RankingUtils.QueryDcg(_labelGains, _truncationLevel, state.QueryLabels, state.QueryOutputs, state.DcgCur);
RankingUtils.QueryDcg(_labelGains, _truncationLevel, _discountMap, state.QueryLabels, state.QueryOutputs, state.DcgCur);
for (int t = 0; t < _truncationLevel; t++)
{
Double ndcg = state.MaxDcgCur[t] > 0 ? state.DcgCur[t] / state.MaxDcgCur[t] : 0;
Expand Down Expand Up @@ -948,41 +953,41 @@ private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleM

internal static class RankingUtils
{
private static readonly object _lock = new object();
private static volatile Double[] _discountMap;
private static volatile int _maxTruncationLevel = 0;
// Truncation levels are typically less than 100, so we maintain a fixed discount map of size 100.
// If a truncation level greater than 100 is required, we build a new map and return that instead.
private const int FixedDiscountMapSize = 100;
private static Double[] _discountMapFixed;

public static Double[] DiscountMap
private static Double[] GetDiscountMapCore(int truncationLevel)
{
get
{
return _discountMap;
}
var discountMap = new Double[FixedDiscountMapSize];
Copy link
Member

@eerhardt eerhardt Oct 9, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should use the truncationLevel passed into the function. #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops. Sorry. Fixed with the latest checkin.


In reply to: 333276567 [](ancestors = 333276567)


for (int i = 0; i < discountMap.Length; i++)
discountMap[i] = 1 / Math.Log(2 + i);

return discountMap;
}
/// <summary>
/// Reallocates discountMap for the largest truncationLevel seen so far
/// </summary>
public static void SetTruncationLevel(int truncationLevel)
{
lock (_lock) {
if (truncationLevel > _maxTruncationLevel)
{
_maxTruncationLevel = truncationLevel;

var discountMap = new Double[_maxTruncationLevel];
for (int i = 0; i < discountMap.Length; i++)
{
discountMap[i] = 1 / Math.Log(2 + i);
}
_discountMap = discountMap;
}
public static Double[] GetDiscountMap(int truncationLevel)
{
var discountMap = _discountMapFixed;
if (discountMap == null)
{
discountMap = GetDiscountMapCore(truncationLevel);
Copy link
Member

@eerhardt eerhardt Oct 9, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First check whether the requested level is small enough for the “fixed” map; if it is, use the cached one. If the cached one hasn’t been created yet, create it using ‘FixedDiscountMapSize’. #Resolved

Interlocked.CompareExchange(ref _discountMapFixed, discountMap, null);
discountMap = _discountMapFixed;
}

if (truncationLevel <= discountMap.Length)
return discountMap;

return GetDiscountMapCore(truncationLevel);
}

/// <summary>
/// Calculates natural-log-based max DCG at all truncation levels from 1 to truncationLevel.
/// </summary>
public static void QueryMaxDcg(Double[] labelGains, int truncationLevel,
public static void QueryMaxDcg(Double[] labelGains, int truncationLevel, Double[] discountMap,
List<short> queryLabels, List<Single> queryOutputs, Double[] groupMaxDcgCur)
{
Contracts.Assert(Utils.Size(groupMaxDcgCur) == truncationLevel);
Expand All @@ -1007,21 +1012,21 @@ public static void QueryMaxDcg(Double[] labelGains, int truncationLevel,
while (labelCounts[topLabel] == 0)
topLabel--;

groupMaxDcgCur[0] = labelGains[topLabel] * DiscountMap[0];
groupMaxDcgCur[0] = labelGains[topLabel] * discountMap[0];
labelCounts[topLabel]--;
for (int t = 1; t < maxTrunc; t++)
{
while (labelCounts[topLabel] == 0)
topLabel--;
groupMaxDcgCur[t] = groupMaxDcgCur[t - 1] + labelGains[topLabel] * DiscountMap[t];
groupMaxDcgCur[t] = groupMaxDcgCur[t - 1] + labelGains[topLabel] * discountMap[t];
labelCounts[topLabel]--;
}
for (int t = maxTrunc; t < truncationLevel; t++)
groupMaxDcgCur[t] = groupMaxDcgCur[t - 1];
}
}

public static void QueryDcg(Double[] labelGains, int truncationLevel,
public static void QueryDcg(Double[] labelGains, int truncationLevel, Double[] discountMap,
List<short> queryLabels, List<Single> queryOutputs, Double[] groupDcgCur)
{
// calculate the permutation
Expand All @@ -1034,7 +1039,7 @@ public static void QueryDcg(Double[] labelGains, int truncationLevel,
Double dcg = 0;
for (int t = 0; t < count; ++t)
{
dcg = dcg + labelGains[queryLabels[permutation[t]]] * DiscountMap[t];
dcg = dcg + labelGains[queryLabels[permutation[t]]] * discountMap[t];
groupDcgCur[t] = dcg;
}
for (int t = count; t < truncationLevel; ++t)
Expand Down