-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Tree based trainers implement ICanGetSummaryAsIDataView #3892
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
6127f95
IDataView function
artidoro 4fe7683
testing a first option in one line
artidoro 58d4125
working on alternative schema
artidoro e6017b5
changed schema one node per row
artidoro 2c4fa9b
cleaning up
artidoro cc64586
lightgbm fact
artidoro 7f632ff
add entrypoint test
artidoro 567a77f
baseline output file fors other summary tests
artidoro 4e36d41
working on updating the utils
artidoro 791f31d
refactored utils class
artidoro 78f4f86
update tests to check categorical splits
artidoro a9d91e7
fixing summary test so that it trains the model and gets the summary
artidoro 79e9e16
change file open mode
artidoro b85fbf6
setting threads on fasttreetest
artidoro 60f4bd1
trying to reduce variance in test
artidoro 65885b8
refreshing baseline
artidoro File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
using Microsoft.ML.CommandLine; | ||
using Microsoft.ML.Data; | ||
using Microsoft.ML.Data.Conversion; | ||
using Microsoft.ML.FastTree.Utils; | ||
using Microsoft.ML.Internal.Internallearn; | ||
using Microsoft.ML.Internal.Utilities; | ||
using Microsoft.ML.Model; | ||
|
@@ -2791,7 +2792,6 @@ public abstract class TreeEnsembleModelParameters : | |
IPredictorWithFeatureWeights<float>, | ||
IFeatureContributionMapper, | ||
ICalculateFeatureContribution, | ||
ICanGetSummaryAsIRow, | ||
ISingleCanSavePfa, | ||
ISingleCanSaveOnnx | ||
{ | ||
|
@@ -3276,23 +3276,6 @@ internal int GetLeaf(int treeId, in VBuffer<float> features, ref List<int> path) | |
return TrainedEnsemble.GetTreeAt(treeId).GetLeaf(in features, ref path); | ||
} | ||
|
||
DataViewRow ICanGetSummaryAsIRow.GetSummaryIRowOrNull(RoleMappedSchema schema) | ||
{ | ||
var names = default(VBuffer<ReadOnlyMemory<char>>); | ||
AnnotationUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, NumFeatures, ref names); | ||
var metaBuilder = new DataViewSchema.Annotations.Builder(); | ||
metaBuilder.AddSlotNames(NumFeatures, names.CopyTo); | ||
|
||
var weights = default(VBuffer<Single>); | ||
((IHaveFeatureWeights)this).GetFeatureWeights(ref weights); | ||
var builder = new DataViewSchema.Annotations.Builder(); | ||
builder.Add<VBuffer<float>>("Gains", new VectorDataViewType(NumberDataViewType.Single, NumFeatures), weights.CopyTo, metaBuilder.ToAnnotations()); | ||
|
||
return AnnotationUtils.AnnotationsAsRow(builder.ToAnnotations()); | ||
} | ||
|
||
DataViewRow ICanGetSummaryAsIRow.GetStatsIRowOrNull(RoleMappedSchema schema) => null; | ||
|
||
private sealed class Tree : ITree<VBuffer<float>> | ||
{ | ||
private readonly InternalRegressionTree _regTree; | ||
|
@@ -3378,7 +3361,7 @@ public TreeNode(Dictionary<string, object> keyValues) | |
/// and <see cref="TreeEnsembleModelParametersBasedOnRegressionTree"/> is the type of | ||
/// <see cref="TrainedTreeEnsemble"/>. | ||
/// </summary> | ||
public abstract class TreeEnsembleModelParametersBasedOnRegressionTree : TreeEnsembleModelParameters | ||
public abstract class TreeEnsembleModelParametersBasedOnRegressionTree : TreeEnsembleModelParameters, ICanGetSummaryAsIDataView | ||
{ | ||
/// <summary> | ||
/// An ensemble of trees exposed to users. It is a wrapper on the <see langword="internal"/> | ||
|
@@ -3406,6 +3389,12 @@ private RegressionTreeEnsemble CreateTreeEnsembleFromInternalDataStructure() | |
var treeWeights = TrainedEnsemble.Trees.Select(tree => tree.Weight); | ||
return new RegressionTreeEnsemble(trees, treeWeights, TrainedEnsemble.Bias); | ||
} | ||
|
||
/// <summary> | ||
/// Used for the Summarize entrypoint. | ||
/// </summary> | ||
IDataView ICanGetSummaryAsIDataView.GetSummaryDataView(RoleMappedSchema schema) | ||
=> RegressionTreeBaseUtils.RegressionTreeEnsembleAsIDataView(Host, TrainedTreeEnsemble.Bias, TrainedTreeEnsemble.TreeWeights, TrainedTreeEnsemble.Trees); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. indent |
||
} | ||
|
||
/// <summary> | ||
|
@@ -3418,7 +3407,7 @@ private RegressionTreeEnsemble CreateTreeEnsembleFromInternalDataStructure() | |
/// and <see cref="TreeEnsembleModelParametersBasedOnRegressionTree"/> is the type of | ||
/// <see cref="TrainedTreeEnsemble"/>. | ||
/// </summary> | ||
public abstract class TreeEnsembleModelParametersBasedOnQuantileRegressionTree : TreeEnsembleModelParameters | ||
public abstract class TreeEnsembleModelParametersBasedOnQuantileRegressionTree : TreeEnsembleModelParameters, ICanGetSummaryAsIDataView | ||
{ | ||
/// <summary> | ||
/// An ensemble of trees exposed to users. It is a wrapper on the <see langword="internal"/> | ||
|
@@ -3446,5 +3435,11 @@ private QuantileRegressionTreeEnsemble CreateTreeEnsembleFromInternalDataStructu | |
var treeWeights = TrainedEnsemble.Trees.Select(tree => tree.Weight); | ||
return new QuantileRegressionTreeEnsemble(trees, treeWeights, TrainedEnsemble.Bias); | ||
} | ||
|
||
/// <summary> | ||
/// Used for the Summarize entrypoint. | ||
/// </summary> | ||
IDataView ICanGetSummaryAsIDataView.GetSummaryDataView(RoleMappedSchema schema) | ||
=> RegressionTreeBaseUtils.RegressionTreeEnsembleAsIDataView(Host, TrainedTreeEnsemble.Bias, TrainedTreeEnsemble.TreeWeights, TrainedTreeEnsemble.Trees); | ||
} | ||
} |
147 changes: 147 additions & 0 deletions
147
src/Microsoft.ML.FastTree/Utils/RegressionTreeBaseUtils.cs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using Microsoft.ML.Data; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Trainers.FastTree; | ||
|
||
namespace Microsoft.ML.FastTree.Utils | ||
{ | ||
internal class RegressionTreeBaseUtils | ||
{ | ||
/// <summary> | ||
/// Utility method used to represent a tree ensemble as an <see cref="IDataView"/>. | ||
/// Every row in the <see cref="IDataView"/> corresponds to a node in the tree ensemble. The columns are the fields for each node. | ||
/// The column TreeID specifies which tree the node belongs to. The <see cref="QuantileRegressionTree"/> gets | ||
/// special treatment since it has some additional fields (<see cref="QuantileRegressionTree.GetLeafSamplesAt(int)"/> | ||
/// and <see cref="QuantileRegressionTree.GetLeafSampleWeightsAt(int)"/>). | ||
/// </summary> | ||
public static IDataView RegressionTreeEnsembleAsIDataView(IHost host, double bias, IReadOnlyList<double> treeWeights, IReadOnlyList<RegressionTreeBase> trees) | ||
{ | ||
var builder = new ArrayDataViewBuilder(host); | ||
var numberOfRows = trees.Select(tree => tree.NumberOfNodes).Sum() + trees.Select(tree => tree.NumberOfLeaves).Sum(); | ||
|
||
var treeWeightsList = new List<double>(); | ||
var treeId = new List<int>(); | ||
var isLeaf = new List<ReadOnlyMemory<char>>(); | ||
var leftChild = new List<int>(); | ||
var rightChild = new List<int>(); | ||
var numericalSplitFeatureIndexes = new List<int>(); | ||
var numericalSplitThresholds = new List<float>(); | ||
var categoricalSplitFlags = new List<bool>(); | ||
var leafValues = new List<double>(); | ||
var splitGains = new List<double>(); | ||
var categoricalSplitFeatures = new List<VBuffer<int>>(); | ||
var categoricalCategoricalSplitFeatureRange = new List<VBuffer<int>>(); | ||
|
||
for (int i = 0; i < trees.Count; i++) | ||
{ | ||
// TreeWeights column. The TreeWeight value will be repeated for all the notes in the same tree in the IDataView. | ||
treeWeightsList.AddRange(Enumerable.Repeat(treeWeights[i], trees[i].NumberOfNodes + trees[i].NumberOfLeaves)); | ||
|
||
// Tree id indicates which tree the node belongs to. | ||
treeId.AddRange(Enumerable.Repeat(i, trees[i].NumberOfNodes + trees[i].NumberOfLeaves)); | ||
|
||
// IsLeaf column indicates if node is a leaf node. | ||
isLeaf.AddRange(Enumerable.Repeat(new ReadOnlyMemory<char>("Tree node".ToCharArray()), trees[i].NumberOfNodes)); | ||
isLeaf.AddRange(Enumerable.Repeat(new ReadOnlyMemory<char>("Leaf node".ToCharArray()), trees[i].NumberOfLeaves)); | ||
|
||
// LeftChild column. | ||
leftChild.AddRange(trees[i].LeftChild.AsEnumerable()); | ||
leftChild.AddRange(Enumerable.Repeat(0, trees[i].NumberOfLeaves)); | ||
|
||
// RightChild column. | ||
rightChild.AddRange(trees[i].RightChild.AsEnumerable()); | ||
rightChild.AddRange(Enumerable.Repeat(0, trees[i].NumberOfLeaves)); | ||
|
||
// NumericalSplitFeatureIndexes column. | ||
numericalSplitFeatureIndexes.AddRange(trees[i].NumericalSplitFeatureIndexes.AsEnumerable()); | ||
numericalSplitFeatureIndexes.AddRange(Enumerable.Repeat(0, trees[i].NumberOfLeaves)); | ||
|
||
// NumericalSplitThresholds column. | ||
numericalSplitThresholds.AddRange(trees[i].NumericalSplitThresholds.AsEnumerable()); | ||
numericalSplitThresholds.AddRange(Enumerable.Repeat(0f, trees[i].NumberOfLeaves)); | ||
|
||
// CategoricalSplitFlags column. | ||
categoricalSplitFlags.AddRange(trees[i].CategoricalSplitFlags.AsEnumerable()); | ||
categoricalSplitFlags.AddRange(Enumerable.Repeat(false, trees[i].NumberOfLeaves)); | ||
|
||
// LeafValues column. | ||
leafValues.AddRange(Enumerable.Repeat(0d, trees[i].NumberOfNodes)); | ||
leafValues.AddRange(trees[i].LeafValues.AsEnumerable()); | ||
|
||
// SplitGains column. | ||
splitGains.AddRange(trees[i].SplitGains.AsEnumerable()); | ||
splitGains.AddRange(Enumerable.Repeat(0d, trees[i].NumberOfLeaves)); | ||
|
||
for (int j = 0; j < trees[i].NumberOfNodes; j++) | ||
{ | ||
// CategoricalSplitFeatures column. | ||
var categoricalSplitFeaturesArray = trees[i].GetCategoricalSplitFeaturesAt(j).ToArray(); | ||
categoricalSplitFeatures.Add(new VBuffer<int>(categoricalSplitFeaturesArray.Length, categoricalSplitFeaturesArray)); | ||
var len = trees[i].GetCategoricalSplitFeaturesAt(j).ToArray().Length; | ||
|
||
// CategoricalCategoricalSplitFeatureRange column. | ||
var categoricalCategoricalSplitFeatureRangeArray = trees[i].GetCategoricalCategoricalSplitFeatureRangeAt(j).ToArray(); | ||
categoricalCategoricalSplitFeatureRange.Add(new VBuffer<int>(categoricalCategoricalSplitFeatureRangeArray.Length, categoricalCategoricalSplitFeatureRangeArray)); | ||
len = trees[i].GetCategoricalCategoricalSplitFeatureRangeAt(j).ToArray().Length; | ||
} | ||
|
||
categoricalSplitFeatures.AddRange(Enumerable.Repeat(new VBuffer<int>(), trees[i].NumberOfLeaves)); | ||
categoricalCategoricalSplitFeatureRange.AddRange(Enumerable.Repeat(new VBuffer<int>(), trees[i].NumberOfLeaves)); | ||
} | ||
|
||
// Bias column. This will be a repeated value for all rows in the resulting IDataView. | ||
builder.AddColumn("Bias", NumberDataViewType.Double, Enumerable.Repeat(bias, numberOfRows).ToArray()); | ||
builder.AddColumn("TreeWeights", NumberDataViewType.Double, treeWeightsList.ToArray()); | ||
builder.AddColumn("TreeID", NumberDataViewType.Int32, treeId.ToArray()); | ||
builder.AddColumn("IsLeaf", TextDataViewType.Instance, isLeaf.ToArray()); | ||
builder.AddColumn(nameof(RegressionTreeBase.LeftChild), NumberDataViewType.Int32, leftChild.ToArray()); | ||
builder.AddColumn(nameof(RegressionTreeBase.RightChild), NumberDataViewType.Int32, rightChild.ToArray()); | ||
builder.AddColumn(nameof(RegressionTreeBase.NumericalSplitFeatureIndexes), NumberDataViewType.Int32, numericalSplitFeatureIndexes.ToArray()); | ||
builder.AddColumn(nameof(RegressionTreeBase.NumericalSplitThresholds), NumberDataViewType.Single, numericalSplitThresholds.ToArray()); | ||
builder.AddColumn(nameof(RegressionTreeBase.CategoricalSplitFlags), BooleanDataViewType.Instance, categoricalSplitFlags.ToArray()); | ||
builder.AddColumn(nameof(RegressionTreeBase.LeafValues), NumberDataViewType.Double, leafValues.ToArray()); | ||
builder.AddColumn(nameof(RegressionTreeBase.SplitGains), NumberDataViewType.Double, splitGains.ToArray()); | ||
builder.AddColumn("CategoricalSplitFeatures", NumberDataViewType.Int32, categoricalSplitFeatures.ToArray()); | ||
builder.AddColumn("CategoricalCategoricalSplitFeatureRange", NumberDataViewType.Int32, categoricalCategoricalSplitFeatureRange.ToArray()); | ||
|
||
// If the input tree array is a quantile regression tree we need to add two more columns. | ||
var quantileTrees = trees as IReadOnlyList<QuantileRegressionTree>; | ||
if (quantileTrees != null) | ||
{ | ||
// LeafSamples column. | ||
var leafSamples = new List<VBuffer<double>>(); | ||
|
||
// LeafSampleWeights column. | ||
var leafSampleWeights = new List<VBuffer<double>>(); | ||
for (int i = 0; i < quantileTrees.Count; i++) | ||
{ | ||
leafSamples.AddRange(Enumerable.Repeat(new VBuffer<double>(), quantileTrees[i].NumberOfNodes)); | ||
leafSampleWeights.AddRange(Enumerable.Repeat(new VBuffer<double>(), quantileTrees[i].NumberOfNodes)); | ||
for (int j = 0; j < quantileTrees[i].NumberOfLeaves; j++) | ||
{ | ||
var leafSamplesArray = quantileTrees[i].GetLeafSamplesAt(j).ToArray(); | ||
leafSamples.Add(new VBuffer<double>(leafSamplesArray.Length, leafSamplesArray)); | ||
var len = quantileTrees[i].GetLeafSamplesAt(j).ToArray().Length; | ||
|
||
var leafSampleWeightsArray = quantileTrees[i].GetLeafSampleWeightsAt(j).ToArray(); | ||
leafSampleWeights.Add(new VBuffer<double>(leafSampleWeightsArray.Length, leafSampleWeightsArray)); | ||
len = quantileTrees[i].GetLeafSampleWeightsAt(j).ToArray().Length; | ||
} | ||
} | ||
|
||
builder.AddColumn("LeafSamples", NumberDataViewType.Double, leafSamples.ToArray()); | ||
builder.AddColumn("LeafSampleWeights", NumberDataViewType.Double, leafSampleWeights.ToArray()); | ||
} | ||
|
||
var data = builder.GetDataView(); | ||
return data; | ||
} | ||
|
||
} | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need
ICanGetSummaryAs...
?IDataView RegressionTreeEnsembleAsIDataView(..)
looks sufficient for generating a summary. #ByDesignThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's the interface that Summarize entrypoint uses to get the IDataView from a trainer.
See:
machinelearning/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs
Line 32 in 665a366
And:
machinelearning/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs
Line 49 in 665a366