-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Support for Categorical features in CalculateFeatureContribution of LightGBM #5018
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
b7fa547
fc58840
6b713f4
a2d8779
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1515,28 +1515,59 @@ public void AppendFeatureContributions(in VBuffer<float> src, BufferBuilder<floa | |
int node = 0; | ||
while (node >= 0) | ||
{ | ||
int ifeat = SplitFeatures[node]; | ||
var val = src.GetItemOrDefault(ifeat); | ||
val = GetFeatureValue(val, node); | ||
int otherWay; | ||
if (val <= RawThresholds[node]) | ||
if (CategoricalSplit[node]) | ||
{ | ||
Contracts.Assert(CategoricalSplitFeatures != null); | ||
|
||
int newNode = LteChild[node]; | ||
otherWay = GtChild[node]; | ||
node = LteChild[node]; | ||
foreach (var index in CategoricalSplitFeatures[node]) | ||
{ | ||
float fv = GetFeatureValue(src.GetItemOrDefault(index), node); | ||
if (fv > 0.0f) | ||
{ | ||
newNode = GtChild[node]; | ||
otherWay = LteChild[node]; | ||
break; | ||
} | ||
} | ||
|
||
// What if we went the other way? | ||
var ghostLeaf = GetLeafFrom(in src, otherWay); | ||
var ghostOutput = GetOutput(ghostLeaf); | ||
|
||
// If the ghost got a smaller output, the contribution of the categorical features is positive, so | ||
// the contribution is true minus ghost. | ||
foreach(var ifeat in CategoricalSplitFeatures[node]) | ||
contributions.AddFeature(ifeat, (float)(trueOutput - ghostOutput)); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Code-wise, I think this is the correct way to find which features are involved in the categorical split of this node (as it is done in a similar way here). And I tried to make this analogous to how feature contribution is already being calculated for non-categorical features (here). But I don't know if this is the "mathematically correct" way of calculating feature contributions for categorical features. I can think about a couple of alternatives to this, but I wouldn't know which one to choose. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So I've updated this. I am still not sure the updated version is correct. But I think it's the closest to how feature contribution is calculated for numerical feature splits, and considering that for other cases (FastTree, Gam, etc.) categorical features are treated the same as any other feature (ignoring the fact they're categorical) when calculating feature contribution. |
||
|
||
node = newNode; | ||
} | ||
else | ||
{ | ||
otherWay = LteChild[node]; | ||
node = GtChild[node]; | ||
} | ||
int ifeat = SplitFeatures[node]; | ||
var val = src.GetItemOrDefault(ifeat); | ||
val = GetFeatureValue(val, node); | ||
if (val <= RawThresholds[node]) | ||
{ | ||
otherWay = GtChild[node]; | ||
node = LteChild[node]; | ||
} | ||
else | ||
{ | ||
otherWay = LteChild[node]; | ||
node = GtChild[node]; | ||
} | ||
|
||
// What if we went the other way? | ||
var ghostLeaf = GetLeafFrom(in src, otherWay); | ||
var ghostOutput = GetOutput(ghostLeaf); | ||
// What if we went the other way? | ||
var ghostLeaf = GetLeafFrom(in src, otherWay); | ||
var ghostOutput = GetOutput(ghostLeaf); | ||
|
||
// If the ghost got a smaller output, the contribution of the feature is positive, so | ||
// the contribution is true minus ghost. | ||
contributions.AddFeature(ifeat, (float)(trueOutput - ghostOutput)); | ||
// If the ghost got a smaller output, the contribution of the feature is positive, so | ||
// the contribution is true minus ghost. | ||
contributions.AddFeature(ifeat, (float)(trueOutput - ghostOutput)); | ||
} | ||
} | ||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
#@ TextLoader{ | ||
#@ sep=tab | ||
#@ col=VendorId:TX:0 | ||
#@ col=RateCode:R4:1 | ||
#@ col=PassengerCount:R4:2 | ||
#@ col=PassengerCount:R4:3 | ||
#@ col=TripTime:R4:4 | ||
#@ col=TripTime:R4:5 | ||
#@ col=TripDistance:R4:6 | ||
#@ col=TripDistance:R4:7 | ||
#@ col=PaymentType:TX:8 | ||
#@ col=FareAmount:R4:9 | ||
#@ col=Label:R4:10 | ||
#@ col=VendorIdEncoded:U4[1]:11 | ||
#@ col=VendorIdEncoded:R4:12-12 | ||
#@ col=RateCodeEncoded:U4[2]:13 | ||
#@ col=RateCodeEncoded:R4:14-15 | ||
#@ col=PaymentTypeEncoded:U4[3]:16 | ||
#@ col=PaymentTypeEncoded:R4:17-19 | ||
#@ col=Features:R4:20-28 | ||
#@ col=FeatureContributions:R4:29-37 | ||
#@ col=FeatureContributions:R4:38-46 | ||
#@ col=FeatureContributions:R4:47-55 | ||
#@ col=FeatureContributions:R4:56-64 | ||
#@ } | ||
CMT 1 1 0.7088812 1271 1.64874518 3.8 1.0118916 CRD 17.5 17.5 0 1 0 1 0 0 1 0 0 1 1 0 1 0 0 0.7088812 1.64874518 1.0118916 36 4:0.107879594 7:0.725665748 8:1 15:-1 24:-0.0418495 26:1 33:-0.370121539 35:8.844109 | ||
CMT 1 1 0.7088812 474 0.6148743 1.5 0.3994309 CRD 8 8 0 1 0 1 0 0 1 0 0 1 1 0 1 0 0 0.7088812 0.6148743 0.3994309 36 4:1 15:-0.0364986733 16:-0.847436965 17:-1 22:0.011381451 26:-1 31:0.115415707 35:-10.1406841 | ||
CMT 1 1 0.7088812 637 0.8263184 1.4 0.372802168 CRD 8.5 8.5 0 1 0 1 0 0 1 0 0 1 1 0 1 0 0 0.7088812 0.8263184 0.372802168 36 4:1 15:-0.0366709046 16:-0.5593253 17:-1 22:0.0182117485 26:-1 31:0.183812216 35:-10.0930576 | ||
CMT 1 1 0.7088812 181 0.234793767 0.6 0.159772366 CSH 4.5 4.5 0 1 0 1 0 1 0 1 0 1 1 0 0 1 0 0.7088812 0.234793767 0.159772366 36 6:1 13:-0.293414325 16:-0.7202999 17:-1 24:0.0291313324 26:-1 33:0.33991462 35:-11.6683512 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks to me like this doc was wrong, as it's inconsistent with what the lines below say:
machinelearning/src/Microsoft.ML.FastTree/RegressionTree.cs
Lines 76 to 80 in 8660ecc
machinelearning/src/Microsoft.ML.FastTree/RegressionTree.cs
Lines 100 to 106 in 8660ecc