-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Fix a bug in Tree leaf featurizer entry point, and add a test for it. #131
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
96703e3
13de95d
86d2a15
431ca8e
fb7e5e5
eec515e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,13 +6,13 @@ | |
using System.Collections.Generic; | ||
using System.IO; | ||
using System.Linq; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Api; | ||
using Microsoft.ML.Runtime.Core.Tests.UnitTests; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Data.IO; | ||
using Microsoft.ML.Runtime.EntryPoints; | ||
using Microsoft.ML.Runtime.EntryPoints.JsonUtils; | ||
using Microsoft.ML.Runtime.FastTree; | ||
using Microsoft.ML.Runtime.Internal.Utilities; | ||
using Microsoft.ML.Runtime.Learners; | ||
using Newtonsoft.Json; | ||
|
@@ -2521,5 +2521,70 @@ public void EntryPointPrepareLabelConvertPredictedLabel() | |
} | ||
} | ||
} | ||
|
||
[Fact] | ||
public void EntryPointTreeLeafFeaturizer() | ||
{ | ||
var dataPath = GetDataPath(@"adult.tiny.with-schema.txt"); | ||
var inputFile = new SimpleFileHandle(Env, dataPath, false, false); | ||
var dataView = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFile }).Data; | ||
var cat = Categorical.CatTransformDict(Env, new CategoricalTransform.Arguments() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Would it be easier to use the LearningPipeline for the test here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This file contains tests for entrypoints. Learning pipeline has its own file with its own test cases. I would not create a learning pipeline test here for this scenario, instead I would create this scenario in learning pipeline test file. I believe this pipeline can be created using learning pipeline with much fewer lines. We should try to add this scenario using learning pipeline under "Scenario" folder in the test folder. This way the users will see how tree leaf featurizer is used. In reply to: 188444049 [](ancestors = 188444049) |
||
{ | ||
Data = dataView, | ||
Column = new[] { new CategoricalTransform.Column { Name = "Categories", Source = "Categories" } } | ||
}); | ||
var concat = SchemaManipulation.ConcatColumns(Env, new ConcatTransform.Arguments() | ||
{ | ||
Data = cat.OutputData, | ||
Column = new[] { new ConcatTransform.Column { Name = "Features", Source = new[] { "Categories", "NumericFeatures" } } } | ||
}); | ||
|
||
var fastTree = FastTree.FastTree.TrainBinary(Env, new FastTreeBinaryClassificationTrainer.Arguments | ||
{ | ||
FeatureColumn = "Features", | ||
NumTrees = 5, | ||
NumLeaves = 4, | ||
LabelColumn = DefaultColumnNames.Label, | ||
TrainingData = concat.OutputData | ||
}); | ||
|
||
var combine = ModelOperations.CombineModels(Env, new ModelOperations.PredictorModelInput() | ||
{ | ||
PredictorModel = fastTree.PredictorModel, | ||
TransformModels = new[] { cat.Model, concat.Model } | ||
}); | ||
|
||
var treeLeaf = TreeFeaturize.Featurizer(Env, new TreeEnsembleFeaturizerTransform.ArgumentsForEntryPoint | ||
{ | ||
Data = dataView, | ||
PredictorModel = combine.PredictorModel | ||
}); | ||
|
||
var view = treeLeaf.OutputData; | ||
Assert.True(view.Schema.TryGetColumnIndex("Trees", out int treesCol)); | ||
Assert.True(view.Schema.TryGetColumnIndex("Leaves", out int leavesCol)); | ||
Assert.True(view.Schema.TryGetColumnIndex("Paths", out int pathsCol)); | ||
VBuffer<float> treeValues = default(VBuffer<float>); | ||
VBuffer<float> leafIndicators = default(VBuffer<float>); | ||
VBuffer<float> pathIndicators = default(VBuffer<float>); | ||
using (var curs = view.GetRowCursor(c => c == treesCol || c == leavesCol || c == pathsCol)) | ||
{ | ||
var treesGetter = curs.GetGetter<VBuffer<float>>(treesCol); | ||
var leavesGetter = curs.GetGetter<VBuffer<float>>(leavesCol); | ||
var pathsGetter = curs.GetGetter<VBuffer<float>>(pathsCol); | ||
while (curs.MoveNext()) | ||
{ | ||
treesGetter(ref treeValues); | ||
leavesGetter(ref leafIndicators); | ||
pathsGetter(ref pathIndicators); | ||
|
||
Assert.Equal(5, treeValues.Length); | ||
Assert.Equal(5, treeValues.Count); | ||
Assert.Equal(20, leafIndicators.Length); | ||
Assert.Equal(5, leafIndicators.Count); | ||
Assert.Equal(15, pathIndicators.Length); | ||
} | ||
} | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If someone were to fit in a predictor that did implement
IValueMapper
(which most do), but that is nevertheless not a tree, what do we expect would happen? From my reading of the code it would just apply a generic scorer. Should we at least verify, in some fashion, that it is the type of predictor we expect beyond being merely aIValueMapper
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Though mayhap, even if I'm right about that it should be addressed as a separate issue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi Tom, thanks for reviewing this change. If the predictor is an IValueMapper but is not a tree, then the constructor of TreeEnsembleFeaturizerBindableMapper below will fail.
In reply to: 187707032 [](ancestors = 187707032)