Tree-based featurization #3812

Merged
merged 27 commits into from
Jun 26, 2019
Add a sample
wschin committed Jun 6, 2019
commit 49fe1d713ebbf0d5f55eeab2c696daee1a665c89
@@ -0,0 +1,131 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.FastTree;

namespace Samples.Dynamic.Transforms.TreeFeaturization
{
    public static class PretrainedTreeEnsembleFeaturizationWithOptions
    {
        public static void Example()
        {
            // Number of data points in the generated data set.
            int dataPointCount = 200;
            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations and as the source of randomness.
            // Setting the seed to a fixed number in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(dataPointCount).ToList();

            // Convert the list of data points to an IDataView object, which the ML.NET API can consume.
            var dataView = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define input and output columns of tree-based featurizer.
            string labelColumnName = nameof(DataPoint.Label);
            string featureColumnName = nameof(DataPoint.Features);
            string treesColumnName = nameof(TransformedDataPoint.Trees);
            string leavesColumnName = nameof(TransformedDataPoint.Leaves);
            string pathsColumnName = nameof(TransformedDataPoint.Paths);

            // Define a tree model whose trees will be extracted to construct a tree featurizer.
            var trainer = mlContext.BinaryClassification.Trainers.FastTree(
                new FastTreeBinaryTrainer.Options
                {
                    NumberOfThreads = 1,
                    NumberOfTrees = 1,
                    NumberOfLeaves = 4,
                    MinimumExampleCountPerLeaf = 1,
                    FeatureColumnName = featureColumnName,
                    LabelColumnName = labelColumnName
                });

            // Train the defined tree model.
            var model = trainer.Fit(dataView);
            var predicted = model.Transform(dataView);

            // Define the configuration of tree-based featurizer.
            var options = new PretrainedTreeFeaturizationEstimator.Options()
            {
                InputColumnName = featureColumnName,
                ModelParameters = model.Model.SubModel, // Pretrained tree model.
                TreesColumnName = treesColumnName,
                LeavesColumnName = leavesColumnName,
                PathsColumnName = pathsColumnName
            };

            // Fit the created featurizer. It doesn't perform actual training because a pretrained model is provided.
            var treeFeaturizer = mlContext.Transforms.FeaturizeByPretrainTreeEnsemble(options).Fit(dataView);

            // Apply TreeEnsembleFeaturizer to the input data.
            var transformed = treeFeaturizer.Transform(dataView);

            // Convert IDataView object to a list. Each element in the resulting list corresponds to a row in the IDataView.
            var transformedDataPoints = mlContext.Data.CreateEnumerable<TransformedDataPoint>(transformed, false).ToList();

            // Print out the transformation of the first 3 data points.
            for (int i = 0; i < 3; ++i)
            {
                var dataPoint = dataPoints[i];
                var transformedDataPoint = transformedDataPoints[i];
                Console.WriteLine($"The original feature vector [{String.Join(",", dataPoint.Features)}] is transformed to three different tree-based feature vectors:");
                Console.WriteLine($" Trees' output values: [{String.Join(",", transformedDataPoint.Trees)}].");
                Console.WriteLine($" Leaf IDs' 0-1 representation: [{String.Join(",", transformedDataPoint.Leaves)}].");
                Console.WriteLine($" Path IDs' 0-1 representation: [{String.Join(",", transformedDataPoint.Paths)}].");
            }

            // Expected output:
            // The original feature vector [0.8173254,0.7680227,0.5581612] is transformed to three different tree-based feature vectors:
            //  Trees' output values: [0.4172185].
            //  Leaf IDs' 0-1 representation: [1,0,0,0].
            //  Path IDs' 0-1 representation: [1,1,1].
            // The original feature vector [0.7588848,1.106027,0.6421779] is transformed to three different tree-based feature vectors:
            //  Trees' output values: [-1].
            //  Leaf IDs' 0-1 representation: [0,0,1,0].
            //  Path IDs' 0-1 representation: [1,1,0].
            // The original feature vector [0.2737045,0.2919063,0.4673147] is transformed to three different tree-based feature vectors:
            //  Trees' output values: [0.4172185].
            //  Leaf IDs' 0-1 representation: [1,0,0,0].
            //  Path IDs' 0-1 representation: [1,1,1].
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int seed = 0)
        {
            var random = new Random(seed);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                var label = randomFloat() > 0.5;
                yield return new DataPoint
                {
                    Label = label,
                    // Create random features that are correlated with the label.
                    // For data points with false label, the feature values are slightly increased by adding a constant.
                    Features = Enumerable.Repeat(label, 3).Select(x => x ? randomFloat() : randomFloat() + 0.2f).ToArray()
                };
            }
        }

        // Example with label and 3 feature values. A data set is a collection of such examples.
        private class DataPoint
        {
            public bool Label { get; set; }
            [VectorType(3)]
            public float[] Features { get; set; }
        }

        // Class used to capture the output of tree-based featurization.
        private class TransformedDataPoint : DataPoint
        {
            // The i-th value is the output value of the i-th decision tree.
            public float[] Trees { get; set; }
            // The 0-1 encoding of the leaves the input feature vector falls into.
            public float[] Leaves { get; set; }
            // The 0-1 encoding of the paths the input feature vector takes to reach the leaves.
            public float[] Paths { get; set; }
        }
    }
}
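
For readers of the sample above, here is a rough, self-contained illustration of what the three output columns could encode for a single tree. It is only a sketch under stated assumptions, not ML.NET's implementation: the class `TreeEncodingSketch`, the array layout, the thresholds, and the leaf values are all hypothetical and hand-built, and it assumes the path encoding sets a 1 for every internal node visited on the way to the leaf.

```csharp
using System;

// Hypothetical toy tree for illustration only; nothing here is extracted from ML.NET.
public static class TreeEncodingSketch
{
    // Internal node i tests features[SplitFeature[i]] <= Threshold[i] (made-up values).
    private static readonly int[] SplitFeature = { 0, 1, 2 };
    private static readonly float[] Threshold = { 0.9f, 0.9f, 0.9f };

    // A child >= 0 is the index of another internal node.
    // A child < 0 is a leaf whose index is the bitwise complement (~child).
    private static readonly int[] LteChild = { 1, 2, ~0 };
    private static readonly int[] GtChild = { ~3, ~2, ~1 };

    // Output value stored at each of the 4 leaves (made-up values).
    private static readonly float[] LeafValue = { 0.4f, -0.2f, -1f, -0.5f };

    public static (float Value, float[] Leaves, float[] Paths) Encode(float[] features)
    {
        var leaves = new float[LeafValue.Length];   // one slot per leaf
        var paths = new float[SplitFeature.Length]; // one slot per internal node

        int node = 0;
        while (node >= 0)
        {
            paths[node] = 1f; // mark the internal node as visited
            node = features[SplitFeature[node]] <= Threshold[node]
                ? LteChild[node]
                : GtChild[node];
        }

        int leaf = ~node;  // decode the leaf index
        leaves[leaf] = 1f; // one-hot encoding of the reached leaf
        return (LeafValue[leaf], leaves, paths);
    }

    public static void Main()
    {
        var (value, leaves, paths) = Encode(new[] { 0.7f, 1.1f, 0.6f });
        Console.WriteLine($"Value: {value}");                        // Value: -1
        Console.WriteLine($"Leaves: [{string.Join(",", leaves)}]");  // Leaves: [0,0,1,0]
        Console.WriteLine($"Paths: [{string.Join(",", paths)}]");    // Paths: [1,1,0]
    }
}
```

With one tree of 4 leaves there are 3 internal nodes, which is why the sample's leaf vector has 4 entries and its path vector has 3. The toy thresholds were chosen only so that the encodings have the same shape as the sample's expected output.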
49 changes: 49 additions & 0 deletions src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs
@@ -443,6 +443,13 @@ public static FastForestBinaryTrainer FastForest(this BinaryClassificationCatalo
/// <param name="catalog">The context <see cref="TransformsCatalog"/> to create <see cref="PretrainedTreeFeaturizationEstimator"/>.</param>
/// <param name="options">The options to configure <see cref="PretrainedTreeFeaturizationEstimator"/>. See <see cref="PretrainedTreeFeaturizationEstimator.Options"/> and
/// <see cref="TreeEnsembleFeaturizationEstimatorBase.CommonOptions"/> for available settings.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[FeaturizeByPretrainTreeEnsemble](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TreeFeaturization/PretrainedTreeEnsembleFeaturizationWithOptions.cs)]
/// ]]>
/// </format>
/// </example>
public static PretrainedTreeFeaturizationEstimator FeaturizeByPretrainTreeEnsemble(this TransformsCatalog catalog,
PretrainedTreeFeaturizationEstimator.Options options)
{
@@ -457,6 +464,13 @@ public static PretrainedTreeFeaturizationEstimator FeaturizeByPretrainTreeEnsemb
/// <param name="catalog">The context <see cref="TransformsCatalog"/> to create <see cref="FastForestRegressionFeaturizationEstimator"/>.</param>
/// <param name="options">The options to configure <see cref="FastForestRegressionFeaturizationEstimator"/>. See <see cref="FastForestRegressionFeaturizationEstimator.Options"/> and
/// <see cref="TreeEnsembleFeaturizationEstimatorBase.CommonOptions"/> for available settings.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[FeaturizeByFastForestRegression](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TreeFeaturization/FastForestRegressionFeaturizationWithOptions.cs)]
/// ]]>
/// </format>
/// </example>
public static FastForestRegressionFeaturizationEstimator FeaturizeByFastForestRegression(this TransformsCatalog catalog,
FastForestRegressionFeaturizationEstimator.Options options)
{
@@ -471,6 +485,13 @@ public static FastForestRegressionFeaturizationEstimator FeaturizeByFastForestRe
/// <param name="catalog">The context <see cref="TransformsCatalog"/> to create <see cref="FastTreeRegressionFeaturizationEstimator"/>.</param>
/// <param name="options">The options to configure <see cref="FastTreeRegressionFeaturizationEstimator"/>. See <see cref="FastTreeRegressionFeaturizationEstimator.Options"/> and
/// <see cref="TreeEnsembleFeaturizationEstimatorBase.CommonOptions"/> for available settings.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[FeaturizeByFastTreeRegression](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TreeFeaturization/FastTreeRegressionFeaturizationWithOptions.cs)]
/// ]]>
/// </format>
/// </example>
public static FastTreeRegressionFeaturizationEstimator FeaturizeByFastTreeRegression(this TransformsCatalog catalog,
FastTreeRegressionFeaturizationEstimator.Options options)
{
@@ -485,6 +506,13 @@ public static FastTreeRegressionFeaturizationEstimator FeaturizeByFastTreeRegres
/// <param name="catalog">The context <see cref="TransformsCatalog"/> to create <see cref="FastForestBinaryFeaturizationEstimator"/>.</param>
/// <param name="options">The options to configure <see cref="FastForestBinaryFeaturizationEstimator"/>. See <see cref="FastForestBinaryFeaturizationEstimator.Options"/> and
/// <see cref="TreeEnsembleFeaturizationEstimatorBase.CommonOptions"/> for available settings.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[FeaturizeByFastForestBinary](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TreeFeaturization/FastForestBinaryFeaturizationWithOptions.cs)]
/// ]]>
/// </format>
/// </example>
public static FastForestBinaryFeaturizationEstimator FeaturizeByFastForestBinary(this TransformsCatalog catalog,
FastForestBinaryFeaturizationEstimator.Options options)
{
@@ -499,6 +527,13 @@ public static FastForestBinaryFeaturizationEstimator FeaturizeByFastForestBinary
/// <param name="catalog">The context <see cref="TransformsCatalog"/> to create <see cref="FastTreeBinaryFeaturizationEstimator"/>.</param>
/// <param name="options">The options to configure <see cref="FastTreeBinaryFeaturizationEstimator"/>. See <see cref="FastTreeBinaryFeaturizationEstimator.Options"/> and
/// <see cref="TreeEnsembleFeaturizationEstimatorBase.CommonOptions"/> for available settings.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[FeaturizeByFastTreeBinary](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TreeFeaturization/FastTreeBinaryFeaturizationWithOptions.cs)]
/// ]]>
/// </format>
/// </example>
public static FastTreeBinaryFeaturizationEstimator FeaturizeByFastTreeBinary(this TransformsCatalog catalog,
FastTreeBinaryFeaturizationEstimator.Options options)
{
@@ -513,6 +548,13 @@ public static FastTreeBinaryFeaturizationEstimator FeaturizeByFastTreeBinary(thi
/// <param name="catalog">The context <see cref="TransformsCatalog"/> to create <see cref="FastTreeRankingFeaturizationEstimator"/>.</param>
/// <param name="options">The options to configure <see cref="FastTreeRankingFeaturizationEstimator"/>. See <see cref="FastTreeRankingFeaturizationEstimator.Options"/> and
/// <see cref="TreeEnsembleFeaturizationEstimatorBase.CommonOptions"/> for available settings.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[FeaturizeByFastTreeRanking](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TreeFeaturization/FastTreeRankingFeaturizationWithOptions.cs)]
/// ]]>
/// </format>
/// </example>
public static FastTreeRankingFeaturizationEstimator FeaturizeByFastTreeRanking(this TransformsCatalog catalog,
FastTreeRankingFeaturizationEstimator.Options options)
{
@@ -527,6 +569,13 @@ public static FastTreeRankingFeaturizationEstimator FeaturizeByFastTreeRanking(t
/// <param name="catalog">The context <see cref="TransformsCatalog"/> to create <see cref="FastTreeTweedieFeaturizationEstimator"/>.</param>
/// <param name="options">The options to configure <see cref="FastTreeTweedieFeaturizationEstimator"/>. See <see cref="FastTreeTweedieFeaturizationEstimator.Options"/> and
/// <see cref="TreeEnsembleFeaturizationEstimatorBase.CommonOptions"/> for available settings.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[FeaturizeByFastTreeTweedie](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TreeFeaturization/FastTreeTweedieFeaturizationWithOptions.cs)]
/// ]]>
/// </format>
/// </example>
public static FastTreeTweedieFeaturizationEstimator FeaturizeByFastTreeTweedie(this TransformsCatalog catalog,
Contributor

@justinormont justinormont Jun 5, 2019

May want to note in its name that FastTreeTweedie is a regression method (the names of the others include their task types):

Suggested change
```
- public static FastTreeTweedieFeaturizationEstimator FeaturizeByFastTreeTweedie(this TransformsCatalog catalog,
+ public static FastTreeTweedieRegressionFeaturizationEstimator FeaturizeByFastTreeTweedieRegression(this TransformsCatalog catalog,
```
#ByDesign

Member Author

@wschin wschin Jun 5, 2019

Yes, if the model name doesn't convey the task. Since Tweedie already implies regression, we don't append Regression to any of the public Tweedie modules. This pattern can be seen in FastTreeTweedieTrainer and FastTreeTweedieModelParameters. #Resolved

FastTreeTweedieFeaturizationEstimator.Options options)
{