-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Added RankingEvaluatorOptions and removed the truncation limit. #4081
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
58 commits
Select commit
Hold shift + click to select a range
8749a10
Merge pull request #1 from dotnet/master
harishsk fe25bf6
Fixed build errors resulting from upgrade to VS2019 compilers
harishsk cb446be
Added additional message describing the previous fix
harishsk b5ee220
Merge pull request #2 from dotnet/master
harishsk b9a7471
Merge pull request #3 from dotnet/master
harishsk 80e238d
Merge pull request #4 from dotnet/master
harishsk 2ef424d
Merge pull request #5 from dotnet/master
harishsk 3958f01
Merge pull request #6 from dotnet/master
harishsk 56d4595
Fixed #3993
harishsk 00bc7ef
Merge pull request #7 from dotnet/master
harishsk d0462f1
Merge pull request #8 from dotnet/master
harishsk 87cefbc
Merge branch 'master' into bugfix_3993
harishsk c3a908b
Reverted previous change and added a separate class to control evalua…
harishsk c0a430a
Merge pull request #9 from dotnet/master
harishsk 0b55903
Syncing upstream fork (#10)
harishsk 56983d5
Syncing upstream fork (#11)
harishsk 3382d1d
Merge pull request #13 from dotnet/master
harishsk 8ca5d01
Merge remote-tracking branch 'upstream/master'
harishsk 4ac459e
Added unit test for ranking evaluation with options
harishsk 8f20ea4
Merge remote-tracking branch 'upstream/master'
harishsk f9f9e1d
Changed visibility of OutputGroupSummary to internal until we expose …
harishsk 21cb8f3
Added a unit test for maml ranking evaluation
harishsk 138f201
Merge remote-tracking branch 'upstream/master'
harishsk 55e3460
Merge remote-tracking branch 'upstream/master'
harishsk e43bba3
Merge remote-tracking branch 'upstream/master'
harishsk 421d713
Merge branch 'master' of ssh://github.com/harishsk/machinelearning
harishsk 4f4f81c
Merge remote-tracking branch 'upstream/master'
harishsk 89082a5
Merge branch 'master' into bugfix_3993
harishsk f167af8
Merge branch 'master' of ssh://github.com/harishsk/machinelearning
harishsk 0d4d34f
getting rid of maxTruncationLevel
Lynx1820 6cd2f15
removed unnecessary imports
Lynx1820 1424ab3
removed old comment
Lynx1820 3ee03ca
removed old comment
Lynx1820 34b7a91
Merge remote-tracking branch 'upstream/master'
harishsk 5539127
Merge remote-tracking branch 'upstream/master'
harishsk 02053a6
adding a maml baseline test
Lynx1820 35ad3c0
reformatted maml test and moved to where all the other maml tests are
Lynx1820 0eb3e2b
Reverted back some spacing. Accidentally reverted some changes in Eva…
Lynx1820 a3291b1
added more relaxed precision
Lynx1820 37af437
Merge remote-tracking branch 'upstream/master'
harishsk 68f1f35
baseline precision set 1
Lynx1820 5b90a34
added more helpful debugging comment
Lynx1820 0efe238
temp
Lynx1820 b6584aa
Merge remote-tracking branch 'upstream/master'
harishsk 7d47832
Merge branch 'master' of ssh://github.com/harishsk/machinelearning
harishsk 0e99776
Merge branch 'master' into bugfix_3993
harishsk 20a4490
Revert "temp"
Lynx1820 0d111f4
Revert "added more relaxed precision"
Lynx1820 72d1a4d
using fastRankRanking to test RankingEvaluator instead of lightgbm
Lynx1820 ea9ebed
Merge branch 'bugfix_3993' of git://github.com/harishsk/machinelearni…
Lynx1820 013be4f
reverted some indenting and LightGBM imports
Lynx1820 d2ae365
removed the stratification flag to get real numbers
Lynx1820 a9e6db8
testcase change added
Lynx1820 5855f99
Added the Bestfriend attribute back
Lynx1820 724bb12
Changed implementation of discount map computation to minimize mutabl…
harishsk d009f55
Added doc strings
harishsk 8f7b6cd
Fixed bug in creation of fixed discount map
harishsk 30d56a0
Merge remote-tracking branch 'upstream/master' into bugfix_3993
harishsk File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
using System.Linq; | ||
using System.Text; | ||
using System.Text.RegularExpressions; | ||
using System.Threading; | ||
using Microsoft.ML; | ||
using Microsoft.ML.CommandLine; | ||
using Microsoft.ML.Data; | ||
|
@@ -73,7 +74,6 @@ public RankingEvaluator(IHostEnvironment env, RankingEvaluatorOptions options) | |
|
||
_truncationLevel = options.DcgTruncationLevel; | ||
_groupSummary = options.OutputGroupSummary; | ||
RankingUtils.SetTruncationLevel(options.DcgTruncationLevel); | ||
|
||
var labelGains = new List<Double>(); | ||
string[] gains = options.LabelGains.Split(','); | ||
|
@@ -289,6 +289,7 @@ public sealed class Counters | |
private readonly List<short> _queryLabels; | ||
private readonly List<Single> _queryOutputs; | ||
private readonly Double[] _labelGains; | ||
private readonly Double[] _discountMap; | ||
|
||
public bool GroupSummary { get { return _groupNdcg != null; } } | ||
|
||
|
@@ -350,6 +351,8 @@ public Counters(Double[] labelGains, int truncationLevel, bool groupSummary) | |
Contracts.AssertValue(labelGains); | ||
|
||
TruncationLevel = truncationLevel; | ||
_discountMap = RankingUtils.GetDiscountMap(truncationLevel); | ||
|
||
_sumDcgAtN = new Double[TruncationLevel]; | ||
_sumNdcgAtN = new Double[TruncationLevel]; | ||
|
||
|
@@ -375,15 +378,15 @@ public void Update(short label, Single output) | |
|
||
public void UpdateGroup(Single weight) | ||
{ | ||
RankingUtils.QueryMaxDcg(_labelGains, TruncationLevel, _queryLabels, _queryOutputs, _groupMaxDcgCur); | ||
RankingUtils.QueryMaxDcg(_labelGains, TruncationLevel, _discountMap, _queryLabels, _queryOutputs, _groupMaxDcgCur); | ||
if (_groupMaxDcg != null) | ||
{ | ||
var maxDcg = new Double[TruncationLevel]; | ||
Array.Copy(_groupMaxDcgCur, maxDcg, TruncationLevel); | ||
_groupMaxDcg.Add(maxDcg); | ||
} | ||
|
||
RankingUtils.QueryDcg(_labelGains, TruncationLevel, _queryLabels, _queryOutputs, _groupDcgCur); | ||
RankingUtils.QueryDcg(_labelGains, TruncationLevel, _discountMap, _queryLabels, _queryOutputs, _groupDcgCur); | ||
if (_groupDcg != null) | ||
{ | ||
var groupDcg = new Double[TruncationLevel]; | ||
|
@@ -686,6 +689,7 @@ private void SlotNamesGetter(int iinfo, ref VBuffer<ReadOnlyMemory<char>> dst) | |
|
||
private readonly Bindings _bindings; | ||
private readonly int _truncationLevel; | ||
private readonly Double[] _discountMap; | ||
private readonly Double[] _labelGains; | ||
|
||
public Transform(IHostEnvironment env, IDataView input, string labelCol, string scoreCol, string groupCol, | ||
|
@@ -697,6 +701,7 @@ public Transform(IHostEnvironment env, IDataView input, string labelCol, string | |
Host.CheckValue(labelGains, nameof(labelGains)); | ||
|
||
_truncationLevel = truncationLevel; | ||
_discountMap = RankingUtils.GetDiscountMap(_truncationLevel); | ||
_labelGains = labelGains; | ||
_bindings = new Bindings(Host, Source.Schema, true, LabelCol, ScoreCol, GroupCol, _truncationLevel); | ||
} | ||
|
@@ -802,9 +807,9 @@ protected override void ProcessExample(RowCursorState state, short label, Single | |
protected override void UpdateState(RowCursorState state) | ||
{ | ||
// Calculate the current group DCG, NDCG and MaxDcg. | ||
RankingUtils.QueryMaxDcg(_labelGains, _truncationLevel, state.QueryLabels, state.QueryOutputs, | ||
RankingUtils.QueryMaxDcg(_labelGains, _truncationLevel, _discountMap, state.QueryLabels, state.QueryOutputs, | ||
state.MaxDcgCur); | ||
RankingUtils.QueryDcg(_labelGains, _truncationLevel, state.QueryLabels, state.QueryOutputs, state.DcgCur); | ||
RankingUtils.QueryDcg(_labelGains, _truncationLevel, _discountMap, state.QueryLabels, state.QueryOutputs, state.DcgCur); | ||
for (int t = 0; t < _truncationLevel; t++) | ||
{ | ||
Double ndcg = state.MaxDcgCur[t] > 0 ? state.DcgCur[t] / state.MaxDcgCur[t] : 0; | ||
|
@@ -948,41 +953,41 @@ private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleM | |
|
||
internal static class RankingUtils | ||
{ | ||
private static readonly object _lock = new object(); | ||
private static volatile Double[] _discountMap; | ||
private static volatile int _maxTruncationLevel = 0; | ||
// Truncation levels are typically less than 100. So we maintain a fixed discount map of size 100 | ||
// If truncation level greater than 100 is required, we build a new one and return that. | ||
private const int FixedDiscountMapSize = 100; | ||
private static Double[] _discountMapFixed; | ||
|
||
public static Double[] DiscountMap | ||
private static Double[] GetDiscountMapCore(int truncationLevel) | ||
{ | ||
get | ||
{ | ||
return _discountMap; | ||
} | ||
var discountMap = new Double[FixedDiscountMapSize]; | ||
|
||
for (int i = 0; i < discountMap.Length; i++) | ||
discountMap[i] = 1 / Math.Log(2 + i); | ||
|
||
return discountMap; | ||
} | ||
/// <summary> | ||
/// Reallocates discountMap for the largest truncationLevel seen so far | ||
/// </summary> | ||
public static void SetTruncationLevel(int truncationLevel) | ||
{ | ||
lock (_lock) { | ||
if (truncationLevel > _maxTruncationLevel) | ||
{ | ||
_maxTruncationLevel = truncationLevel; | ||
|
||
var discountMap = new Double[_maxTruncationLevel]; | ||
for (int i = 0; i < discountMap.Length; i++) | ||
{ | ||
discountMap[i] = 1 / Math.Log(2 + i); | ||
} | ||
_discountMap = discountMap; | ||
} | ||
public static Double[] GetDiscountMap(int truncationLevel) | ||
{ | ||
var discountMap = _discountMapFixed; | ||
if (discountMap == null) | ||
{ | ||
discountMap = GetDiscountMapCore(truncationLevel); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. First check if the requested level is small enough for the “fixed” map, if it is small enough, use the cached one. If the cached one hasn’t been created yet, then create it using ‘FixedDiscountMapSize’. #Resolved |
||
Interlocked.CompareExchange(ref _discountMapFixed, discountMap, null); | ||
discountMap = _discountMapFixed; | ||
} | ||
|
||
if (truncationLevel <= discountMap.Length) | ||
return discountMap; | ||
|
||
return GetDiscountMapCore(truncationLevel); | ||
} | ||
|
||
/// <summary> | ||
/// Calculates natural-based max DCG at all truncations from 1 to truncationLevel. | ||
/// </summary> | ||
public static void QueryMaxDcg(Double[] labelGains, int truncationLevel, | ||
public static void QueryMaxDcg(Double[] labelGains, int truncationLevel, Double[] discountMap, | ||
List<short> queryLabels, List<Single> queryOutputs, Double[] groupMaxDcgCur) | ||
{ | ||
Contracts.Assert(Utils.Size(groupMaxDcgCur) == truncationLevel); | ||
|
@@ -1007,21 +1012,21 @@ public static void QueryMaxDcg(Double[] labelGains, int truncationLevel, | |
while (labelCounts[topLabel] == 0) | ||
topLabel--; | ||
|
||
groupMaxDcgCur[0] = labelGains[topLabel] * DiscountMap[0]; | ||
groupMaxDcgCur[0] = labelGains[topLabel] * discountMap[0]; | ||
labelCounts[topLabel]--; | ||
for (int t = 1; t < maxTrunc; t++) | ||
{ | ||
while (labelCounts[topLabel] == 0) | ||
topLabel--; | ||
groupMaxDcgCur[t] = groupMaxDcgCur[t - 1] + labelGains[topLabel] * DiscountMap[t]; | ||
groupMaxDcgCur[t] = groupMaxDcgCur[t - 1] + labelGains[topLabel] * discountMap[t]; | ||
labelCounts[topLabel]--; | ||
} | ||
for (int t = maxTrunc; t < truncationLevel; t++) | ||
groupMaxDcgCur[t] = groupMaxDcgCur[t - 1]; | ||
} | ||
} | ||
|
||
public static void QueryDcg(Double[] labelGains, int truncationLevel, | ||
public static void QueryDcg(Double[] labelGains, int truncationLevel, Double[] discountMap, | ||
List<short> queryLabels, List<Single> queryOutputs, Double[] groupDcgCur) | ||
{ | ||
// calculate the permutation | ||
|
@@ -1034,7 +1039,7 @@ public static void QueryDcg(Double[] labelGains, int truncationLevel, | |
Double dcg = 0; | ||
for (int t = 0; t < count; ++t) | ||
{ | ||
dcg = dcg + labelGains[queryLabels[permutation[t]]] * DiscountMap[t]; | ||
dcg = dcg + labelGains[queryLabels[permutation[t]]] * discountMap[t]; | ||
groupDcgCur[t] = dcg; | ||
} | ||
for (int t = count; t < truncationLevel; ++t) | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should use the truncationLevel passed into the function. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oops. Sorry. Fixed with the latest checkin.
In reply to: 333276567 [](ancestors = 333276567)