Support for Categorical features in CalculateFeatureContribution of LightGBM #5018

Merged 4 commits on Apr 21, 2020.

2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/RegressionTree.cs
@@ -57,7 +57,7 @@ public abstract class RegressionTreeBase
         /// (2) the categorical features indexed by <see cref="GetCategoricalCategoricalSplitFeatureRangeAt(int)"/>'s
         /// returned value with nodeIndex=i is NOT a sub-set of <see cref="GetCategoricalSplitFeaturesAt(int)"/> with
         /// nodeIndex=i.
-        /// Note that the case (1) happens only when <see cref="CategoricalSplitFlags"/>[i] is true and otherwise (2)
+        /// Note that the case (1) happens only when <see cref="CategoricalSplitFlags"/>[i] is false and otherwise (2)

Member Author commented:
It looks to me like this doc was wrong, as it's inconsistent with what the lines below say:

/// <summary>
/// <see cref="NumericalSplitFeatureIndexes"/>[i] is the feature index used the splitting function of the
/// i-th node. This value is valid only if <see cref="CategoricalSplitFlags"/>[i] is false.
/// </summary>
public IReadOnlyList<int> NumericalSplitFeatureIndexes => _numericalSplitFeatureIndexes;

/// <summary>
/// Return categorical thresholds used at node indexed by nodeIndex. If the considered input feature does NOT
/// match any of the values returned by <see cref="GetCategoricalSplitFeaturesAt(int)"/>, we call it a
/// less-than-threshold event and therefore <see cref="LeftChild"/>[nodeIndex] is the child node that input
/// should go next. The returned value is valid only if <see cref="CategoricalSplitFlags"/>[nodeIndex] is true.
/// </summary>
public IReadOnlyList<int> GetCategoricalSplitFeaturesAt(int nodeIndex)

         /// occurs. A non-negative returned value means a node (i.e., not a leaf); for example, 2 means the 3rd node in
         /// the underlying <see cref="_tree"/>. A negative returned value means a leaf; for example, -1 stands for the
         /// <see langword="~"/>(-1)-th leaf in the underlying <see cref="_tree"/>. Note that <see langword="~"/> is the
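
To make the corrected rule concrete, here is a minimal sketch of how a caller could branch on CategoricalSplitFlags when inspecting a trained tree. It relies only on the members quoted in the comment above (CategoricalSplitFlags, NumericalSplitFeatureIndexes, GetCategoricalSplitFeaturesAt); the NumberOfNodes property, the namespace, and the printing are assumptions for illustration, not documented usage.

using System;
using Microsoft.ML.Trainers.FastTree;

internal static class TreeSplitInspector
{
    // Per the member docs quoted above: NumericalSplitFeatureIndexes[i] is valid only when
    // CategoricalSplitFlags[i] is false, and GetCategoricalSplitFeaturesAt(i) only when it is true.
    public static void PrintSplits(RegressionTreeBase tree)
    {
        for (int i = 0; i < tree.NumberOfNodes; i++)
        {
            if (tree.CategoricalSplitFlags[i])
            {
                // Categorical split: the bundled feature slots drive the decision.
                var catFeatures = tree.GetCategoricalSplitFeaturesAt(i);
                Console.WriteLine($"Node {i}: categorical split over feature slots [{string.Join(", ", catFeatures)}]");
            }
            else
            {
                // Numerical split: a single feature is compared against a threshold.
                Console.WriteLine($"Node {i}: numerical split on feature slot {tree.NumericalSplitFeatureIndexes[i]}");
            }
        }
    }
}
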
85 changes: 70 additions & 15 deletions src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs
@@ -1515,28 +1515,83 @@ public void AppendFeatureContributions(in VBuffer<float> src, BufferBuilder<floa
             int node = 0;
             while (node >= 0)
             {
-                int ifeat = SplitFeatures[node];
-                var val = src.GetItemOrDefault(ifeat);
-                val = GetFeatureValue(val, node);
                 int otherWay;
-                if (val <= RawThresholds[node])
+                if (CategoricalSplit[node])
                 {
-                    otherWay = GtChild[node];
-                    node = LteChild[node];
+                    Contracts.Assert(CategoricalSplitFeatures != null);
+                    bool match = false;
+                    int selectedIndex = -1;
+                    int newNode = 0;
+                    foreach (var index in CategoricalSplitFeatures[node])
+                    {
+                        float fv = GetFeatureValue(src.GetItemOrDefault(index), node);
+                        if (fv > 0.0f)
+                        {
+                            match = true;
+                            selectedIndex = index; // We only expect at most one match
+                            break;
+                        }
+                    }
+
+                    // If the ghost got a smaller output, the contribution of the categorical features is positive, so
+                    // the contribution is true minus ghost.
+                    if (match)
+                    {
+                        newNode = GtChild[node];
+                        otherWay = LteChild[node];
+
+                        var ghostLeaf = GetLeafFrom(in src, otherWay);
+                        var ghostOutput = GetOutput(ghostLeaf);
+                        var diff = (float)(trueOutput - ghostOutput);
+                        foreach (var index in CategoricalSplitFeatures[node])
+                        {
+                            if (index == selectedIndex) // this index caused the input to go to the GtChild
+                                contributions.AddFeature(index, diff);
+                            else // All of the others wouldn't cause it
+                                contributions.AddFeature(index, -diff);
+                        }
+                    }
+                    else
+                    {
+                        newNode = LteChild[node];
+                        otherWay = GtChild[node];
+
+                        var ghostLeaf = GetLeafFrom(in src, otherWay);
+                        var ghostOutput = GetOutput(ghostLeaf);
+                        var diff = (float)(trueOutput - ghostOutput);
+
+                        // None of the indices caused the input to go to the GtChild,
+                        // So all of them caused it to go to the Lte.
+                        foreach (var index in CategoricalSplitFeatures[node])
+                            contributions.AddFeature(index, diff);
+                    }
+
+                    node = newNode;
                 }
                 else
                 {
-                    otherWay = LteChild[node];
-                    node = GtChild[node];
-                }
+                    int ifeat = SplitFeatures[node];
+                    var val = src.GetItemOrDefault(ifeat);
+                    val = GetFeatureValue(val, node);
+                    if (val <= RawThresholds[node])
+                    {
+                        otherWay = GtChild[node];
+                        node = LteChild[node];
+                    }
+                    else
+                    {
+                        otherWay = LteChild[node];
+                        node = GtChild[node];
+                    }
 
-                // What if we went the other way?
-                var ghostLeaf = GetLeafFrom(in src, otherWay);
-                var ghostOutput = GetOutput(ghostLeaf);
+                    // What if we went the other way?
+                    var ghostLeaf = GetLeafFrom(in src, otherWay);
+                    var ghostOutput = GetOutput(ghostLeaf);
 
-                // If the ghost got a smaller output, the contribution of the feature is positive, so
-                // the contribution is true minus ghost.
-                contributions.AddFeature(ifeat, (float)(trueOutput - ghostOutput));
+                    // If the ghost got a smaller output, the contribution of the feature is positive, so
+                    // the contribution is true minus ghost.
+                    contributions.AddFeature(ifeat, (float)(trueOutput - ghostOutput));
+                }
             }
         }
     }
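For readers who want the new rule in isolation, below is a toy, self-contained sketch of the contribution logic this hunk adds for a categorical node. It deliberately avoids the internal tree types (CategoricalSplitFeatures, GetLeafFrom, GetOutput); the two subtree outputs are passed in directly, so treat it as an illustration of the ghost-path arithmetic rather than the production code path.

using System;
using System.Collections.Generic;

internal static class GhostContributionSketch
{
    // categoricalFeatureIndices: the one-hot slots bundled into this categorical split.
    // input: the feature vector, dense for simplicity.
    // gtOutput / lteOutput: the outputs reached through the GtChild / LteChild subtrees
    // (in the real code these come from GetLeafFrom followed by GetOutput).
    public static Dictionary<int, float> Contributions(
        int[] categoricalFeatureIndices, float[] input, float gtOutput, float lteOutput)
    {
        var contributions = new Dictionary<int, float>();

        // At most one of the bundled one-hot slots is expected to be set.
        int selected = Array.FindIndex(categoricalFeatureIndices, i => input[i] > 0f);

        if (selected >= 0)
        {
            // The input goes to GtChild; the ghost path is LteChild. The slot that fired gets
            // (true - ghost); its siblings, which would have sent the input the other way,
            // get the negated difference.
            float diff = gtOutput - lteOutput;
            foreach (var i in categoricalFeatureIndices)
                contributions[i] = i == categoricalFeatureIndices[selected] ? diff : -diff;
        }
        else
        {
            // No slot fired: the input goes to LteChild, the ghost path is GtChild, and every
            // bundled slot shares the same (true - ghost) difference.
            float diff = lteOutput - gtOutput;
            foreach (var i in categoricalFeatureIndices)
                contributions[i] = diff;
        }
        return contributions;
    }
}

For example, with one-hot slots {2, 3, 4}, slot 3 set, a GtChild output of 5.0 and an LteChild output of 3.0, slot 3 receives +2.0 and slots 2 and 4 receive -2.0 each, mirroring the loop over CategoricalSplitFeatures[node] in the diff above.
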
@@ -0,0 +1,29 @@
#@ TextLoader{
#@ sep=tab
#@ col=VendorId:TX:0
#@ col=RateCode:R4:1
#@ col=PassengerCount:R4:2
#@ col=PassengerCount:R4:3
#@ col=TripTime:R4:4
#@ col=TripTime:R4:5
#@ col=TripDistance:R4:6
#@ col=TripDistance:R4:7
#@ col=PaymentType:TX:8
#@ col=FareAmount:R4:9
#@ col=Label:R4:10
#@ col=VendorIdEncoded:U4[1]:11
#@ col=VendorIdEncoded:R4:12-12
#@ col=RateCodeEncoded:U4[2]:13
#@ col=RateCodeEncoded:R4:14-15
#@ col=PaymentTypeEncoded:U4[3]:16
#@ col=PaymentTypeEncoded:R4:17-19
#@ col=Features:R4:20-28
#@ col=FeatureContributions:R4:29-37
#@ col=FeatureContributions:R4:38-46
#@ col=FeatureContributions:R4:47-55
#@ col=FeatureContributions:R4:56-64
#@ }
CMT 1 1 0.7088812 1271 1.64874518 3.8 1.0118916 CRD 17.5 17.5 0 1 0 1 0 0 1 0 0 1 1 0 1 0 0 0.7088812 1.64874518 1.0118916 36 4:0.107879594 7:0.725665748 8:1 15:-1 24:-0.0418495 26:1 33:-0.370121539 35:8.844109
CMT 1 1 0.7088812 474 0.6148743 1.5 0.3994309 CRD 8 8 0 1 0 1 0 0 1 0 0 1 1 0 1 0 0 0.7088812 0.6148743 0.3994309 36 4:1 15:-0.0364986733 16:-0.847436965 17:-1 22:0.011381451 26:-1 31:0.115415707 35:-10.1406841
CMT 1 1 0.7088812 637 0.8263184 1.4 0.372802168 CRD 8.5 8.5 0 1 0 1 0 0 1 0 0 1 1 0 1 0 0 0.7088812 0.8263184 0.372802168 36 4:1 15:-0.0366709046 16:-0.5593253 17:-1 22:0.0182117485 26:-1 31:0.183812216 35:-10.0930576
CMT 1 1 0.7088812 181 0.234793767 0.6 0.159772366 CSH 4.5 4.5 0 1 0 1 0 1 0 1 0 1 1 0 0 1 0 0.7088812 0.234793767 0.159772366 36 6:1 13:-0.293414325 16:-0.7202999 17:-1 24:0.0291313324 26:-1 33:0.33991462 35:-11.6683512
57 changes: 57 additions & 0 deletions test/Microsoft.ML.Tests/FeatureContributionTests.cs
@@ -12,6 +12,7 @@
using Microsoft.ML.TestFramework.Attributes;
using Microsoft.ML.TestFrameworkCommon.Attributes;
using Microsoft.ML.Trainers;
using Microsoft.ML.Trainers.LightGbm;
using Xunit;
using Xunit.Abstractions;

@@ -52,6 +53,12 @@ public void TestLightGbmRegression()
TestFeatureContribution(ML.Regression.Trainers.LightGbm(), GetSparseDataset(numberOfInstances: 100), "LightGbmRegression");
}

[LightGBMFact]
public void TestLightGbmRegressionWithCategoricalSplit()
{
TestFeatureContribution(ML.Regression.Trainers.LightGbm(new LightGbmRegressionTrainer.Options() { UseCategoricalSplit = true }), GetOneHotEncodedData(numberOfInstances: 100), "LightGbmRegressionWithCategoricalSplit");
}

[Fact]
public void TestFastTreeRegression()
{
@@ -377,5 +384,55 @@ private enum TaskType
Ranking,
Clustering
}

public class TaxiTrip
{
[LoadColumn(0)]
public string VendorId;

[LoadColumn(1)]
public float RateCode;

[LoadColumn(2)]
public float PassengerCount;

[LoadColumn(3)]
public float TripTime;

[LoadColumn(4)]
public float TripDistance;

[LoadColumn(5)]
public string PaymentType;

[LoadColumn(6)]
public float FareAmount;
}

/// <summary>
/// Returns a DataView with a Features column that includes one-hot encoded data.
/// </summary>
private IDataView GetOneHotEncodedData(int numberOfInstances = 100)
{
var trainDataPath = GetDataPath("taxi-fare-train.csv");
IDataView trainingDataView = ML.Data.LoadFromTextFile<TaxiTrip>(trainDataPath, hasHeader: true, separatorChar: ',');

var vendorIdEncoded = "VendorIdEncoded";
var rateCodeEncoded = "RateCodeEncoded";
var paymentTypeEncoded = "PaymentTypeEncoded";

var dataProcessPipeline = ML.Transforms.CopyColumns(outputColumnName: DefaultColumnNames.Label, inputColumnName: nameof(TaxiTrip.FareAmount))
.Append(ML.Transforms.Categorical.OneHotEncoding(outputColumnName: vendorIdEncoded, inputColumnName: nameof(TaxiTrip.VendorId)))
.Append(ML.Transforms.Categorical.OneHotEncoding(outputColumnName: rateCodeEncoded, inputColumnName: nameof(TaxiTrip.RateCode)))
.Append(ML.Transforms.Categorical.OneHotEncoding(outputColumnName: paymentTypeEncoded, inputColumnName: nameof(TaxiTrip.PaymentType)))
.Append(ML.Transforms.NormalizeMeanVariance(outputColumnName: nameof(TaxiTrip.PassengerCount)))
.Append(ML.Transforms.NormalizeMeanVariance(outputColumnName: nameof(TaxiTrip.TripTime)))
.Append(ML.Transforms.NormalizeMeanVariance(outputColumnName: nameof(TaxiTrip.TripDistance)))
.Append(ML.Transforms.Concatenate(DefaultColumnNames.Features, vendorIdEncoded, rateCodeEncoded, paymentTypeEncoded,
nameof(TaxiTrip.PassengerCount), nameof(TaxiTrip.TripTime), nameof(TaxiTrip.TripDistance)));

var someRows = ML.Data.TakeRows(trainingDataView, numberOfInstances);
return dataProcessPipeline.Fit(someRows).Transform(someRows);
}
}
}
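
As a usage-level companion to the new test, here is a hypothetical sketch (not part of this PR) of how the categorical contributions could be surfaced from the trained LightGBM model. It reuses GetOneHotEncodedData and the ML context from the test class above; the test name, the CalculateFeatureContribution overload and its normalize parameter, and the GetColumn call are assumptions about the public API, and the snippet would additionally need using System and using System.Linq.

[LightGBMFact]
public void SketchLightGbmCategoricalFeatureContributions() // hypothetical name
{
    var data = GetOneHotEncodedData(numberOfInstances: 100);

    var trainer = ML.Regression.Trainers.LightGbm(
        new LightGbmRegressionTrainer.Options() { UseCategoricalSplit = true });
    var model = trainer.Fit(data);

    // Score first, then append the feature contribution calculator on top of the trained predictor.
    var scoredData = model.Transform(data);
    var outputData = ML.Transforms.CalculateFeatureContribution(model, normalize: false)
        .Fit(scoredData)
        .Transform(scoredData);

    // Each row now carries a FeatureContributions vector aligned with the Features slots,
    // including the one-hot VendorId/RateCode/PaymentType slots handled by the new categorical path.
    var contributions = outputData.GetColumn<float[]>("FeatureContributions");
    foreach (var row in contributions.Take(3))
        Console.WriteLine(string.Join(", ", row));
}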