Skip to content

Tree based trainers implement ICanGetSummaryAsIDataView #3892

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jul 2, 2019
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
working on alternative schema
  • Loading branch information
artidoro committed Jul 1, 2019
commit 58d4125f0f54a9f7f56e25d714cfe1762c10da5f
86 changes: 52 additions & 34 deletions src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3415,55 +3415,73 @@ IDataView ICanGetSummaryAsIDataView.GetSummaryDataView(RoleMappedSchema schema)
var builder = new ArrayDataViewBuilder(Host);

var trees = TrainedTreeEnsemble.Trees;
var numberOfRows = trees.Select(tree => tree.NumberOfNodes).Sum();

// Bias column. It will be repeated for every tree.
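// Tree id column: for every node row, the index of the tree that node belongs to.
// NOTE(review): the AddColumn call below registers this List<int> under nameof(RegressionTreeEnsemble.Bias)
// with NumberDataViewType.Double — the same name the Bias column uses; confirm the intended column name and type.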
var treeId = new List<int>();
for (int i = 0; i < trees.Count; i++)
treeId.AddRange(Enumerable.Repeat(i, trees[i].NumberOfNodes));

builder.AddColumn(nameof(RegressionTreeEnsemble.Bias), NumberDataViewType.Double, treeId);

// Bias column. This is a repeated value for all trees.
builder.AddColumn(nameof(RegressionTreeEnsemble.Bias), NumberDataViewType.Double,
new[] { TrainedTreeEnsemble.Bias });
Enumerable.Repeat(TrainedTreeEnsemble.Bias, numberOfRows));

// TreeWeights column.
builder.AddColumn(nameof(RegressionTreeEnsemble.TreeWeights), NumberDataViewType.Double,
new[] { TrainedTreeEnsemble.TreeWeights.ToArray() });
var treeWeights = new List<double>();
for (int i = 0; i < trees.Count; i ++)
treeWeights.AddRange(Enumerable.Repeat(TrainedTreeEnsemble.TreeWeights[i], trees[i].NumberOfNodes));

builder.AddColumn(nameof(RegressionTreeEnsemble.TreeWeights), NumberDataViewType.Double, treeWeights);

// LeftChild column.
var leftChild = new List<int>();
for (int i = 0; i < trees.Count; i++)
{
string currentTree = $"_Tree_{i}";
leftChild.AddRange(trees[i].LeftChild.AsEnumerable());

builder.AddColumn(nameof(RegressionTree.LeftChild), NumberDataViewType.Int32, leftChild);

// LeftChild column.
builder.AddColumn(nameof(RegressionTree.LeftChild) + currentTree, NumberDataViewType.Int32,
new[] { trees[i].LeftChild.ToArray() });
// RightChild column.
var rightChild = new List<int>();
for (int i = 0; i < trees.Count; i++)
rightChild.AddRange(trees[i].RightChild.AsEnumerable());

builder.AddColumn(nameof(RegressionTree.RightChild), NumberDataViewType.Int32, rightChild);

// NumericalSplitFeatureIndexes column.
var numericalSplitFeatureIndexes = new List<int>();
for (int i = 0; i < trees.Count; i++)
numericalSplitFeatureIndexes.AddRange(trees[i].NumericalSplitFeatureIndexes.AsEnumerable());

// RightChild column.
builder.AddColumn(nameof(RegressionTree.RightChild) + currentTree, NumberDataViewType.Int32,
new[] { trees[i].RightChild.ToArray() });
builder.AddColumn(nameof(RegressionTree.NumericalSplitFeatureIndexes), NumberDataViewType.Int32, numericalSplitFeatureIndexes);

// NumericalSplitFeatureIndexes column.
builder.AddColumn(nameof(RegressionTree.NumericalSplitFeatureIndexes) + currentTree, NumberDataViewType.Int32,
new[] { trees[i].NumericalSplitFeatureIndexes.ToArray() });
// NumericalSplitThresholds column.
var numericalSplitThresholds = new List<float>();
for (int i = 0; i < trees.Count; i++)
numericalSplitThresholds.AddRange(trees[i].NumericalSplitThresholds.AsEnumerable());

// NumericalSplitThresholds column.
builder.AddColumn(nameof(RegressionTree.NumericalSplitThresholds) + currentTree, NumberDataViewType.Single,
new[] { trees[i].NumericalSplitThresholds.ToArray() });
builder.AddColumn(nameof(RegressionTree.NumericalSplitThresholds), NumberDataViewType.Single, numericalSplitThresholds);

// CategoricalSplitFlags column.
builder.AddColumn(nameof(RegressionTree.CategoricalSplitFlags) + currentTree, BooleanDataViewType.Instance,
new[] { trees[i].CategoricalSplitFlags.ToArray() });
// CategoricalSplitFlags column.
builder.AddColumn(nameof(RegressionTree.CategoricalSplitFlags), BooleanDataViewType.Instance,
new[] { trees[i].CategoricalSplitFlags.ToArray() });

// LeafValues column.
builder.AddColumn(nameof(RegressionTree.LeafValues) + currentTree, NumberDataViewType.Double,
new[] { trees[i].LeafValues.ToArray() });
// LeafValues column.
builder.AddColumn(nameof(RegressionTree.LeafValues), NumberDataViewType.Double,
new[] { trees[i].LeafValues.ToArray() });

// SplitGains column.
builder.AddColumn(nameof(RegressionTree.SplitGains) + currentTree, NumberDataViewType.Double,
new[] { trees[i].SplitGains.ToArray() });
// SplitGains column.
builder.AddColumn(nameof(RegressionTree.SplitGains), NumberDataViewType.Double,
new[] { trees[i].SplitGains.ToArray() });

// NumberOfLeaves column.
builder.AddColumn(nameof(RegressionTree.NumberOfLeaves) + currentTree, NumberDataViewType.Int32,
new[] { trees[i].NumberOfLeaves });
// NumberOfLeaves column.
builder.AddColumn(nameof(RegressionTree.NumberOfLeaves), NumberDataViewType.Int32,
new[] { trees[i].NumberOfLeaves });

// NumberOfNodes column.
builder.AddColumn(nameof(RegressionTree.NumberOfNodes) + currentTree, NumberDataViewType.Int32,
new[] { trees[i].NumberOfNodes });
}
// NumberOfNodes column.
builder.AddColumn(nameof(RegressionTree.NumberOfNodes), NumberDataViewType.Int32,
new[] { trees[i].NumberOfNodes });

// REVIEW: Should these two be exposed, if so how?
// public IReadOnlyList<int> GetCategoricalSplitFeaturesAt(int nodeIndex)
Expand Down