Skip to content

Commit d3b70b5

Browse files
authored
An elaborate series of changes that are, incredibly, actually related. (dotnet#1563)
* Move IModelCombiner out of Core to Ensemble since it clearly belongs there, not in Core. * Remove dependency of Ensemble on FastTree. * Remove learners in Ensemble having defaults of FastTree or indeed any learner. (Incidentally: fixes dotnet#682.) * Rename *FastTree* Ensemble to TreeEnsemble, so as to avoid namespace/type collisions with that type and Ensemble namespace. * Add dependency of FastTree to Ensemble project so something there can implement TreeEnsembleCombiner. * Resolve circular dependency of FastTree -> Ensemble -> StandardLearners -> Legacy -> FastTree by removing Legacy as dependency of StandardLearners, since no project we intend to keep should depend on Legacy. * Move Legacy specific infrastructure that somehow was in StandardLearners over to Legacy. * Fix documentation in StandardLearners that was incorrectly referring to the Legacy pipelines and types directly, since in reality they have nothing to do with the types in Legacy.
1 parent d8402cd commit d3b70b5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+115
-121
lines changed

src/Microsoft.ML.Core/Prediction/ITrainer.cs

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System.Collections.Generic;
65
using Microsoft.ML.Runtime.Data;
76

87
namespace Microsoft.ML.Runtime
@@ -94,12 +93,4 @@ public static IPredictor Train(this ITrainer trainer, RoleMappedData trainData)
9493
public static TPredictor Train<TPredictor>(this ITrainer<TPredictor> trainer, RoleMappedData trainData) where TPredictor : IPredictor
9594
=> trainer.Train(new TrainContext(trainData));
9695
}
97-
98-
/// <summary>
99-
/// An interface that combines multiple predictors into a single predictor.
100-
/// </summary>
101-
public interface IModelCombiner
102-
{
103-
IPredictor CombineModels(IEnumerable<IPredictor> models);
104-
}
10596
}

src/Microsoft.ML.Core/Properties/AssemblyInfo.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using Microsoft.ML;
77

88
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TestFramework" + PublicKey.TestValue)]
9+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Predictor.Tests" + PublicKey.TestValue)]
910

1011
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Legacy" + PublicKey.Value)]
1112
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Maml" + PublicKey.Value)]

src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
1111
<ProjectReference Include="..\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
1212
<ProjectReference Include="..\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj" />
13-
<ProjectReference Include="..\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj" />
1413
</ItemGroup>
1514

1615
</Project>

src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,8 @@
88
using Microsoft.ML.Runtime.Data;
99
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
1010
using Microsoft.ML.Runtime.EntryPoints;
11-
using Microsoft.ML.Trainers.FastTree;
1211
using Microsoft.ML.Runtime.Internal.Internallearn;
1312
using Microsoft.ML.Runtime.Internal.Utilities;
14-
using Microsoft.ML.Runtime.Learners;
1513
using Microsoft.ML.Runtime.Model;
1614

1715
[assembly: LoadableClass(typeof(MultiStacking), typeof(MultiStacking.Arguments), typeof(SignatureCombiner),
@@ -50,17 +48,6 @@ public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerF
5048
internal override IComponentFactory<ITrainer<TVectorPredictor>> GetPredictorFactory() => BasePredictorType;
5149

5250
public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiStacking(env, this);
53-
54-
public Arguments()
55-
{
56-
// REVIEW: Perhaps we can have a better non-parametetric learner.
57-
BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
58-
env => new Ova(env, new Ova.Arguments()
59-
{
60-
PredictorType = ComponentFactoryUtils.CreateFromFunction(
61-
e => new FastTreeBinaryClassificationTrainer(e, DefaultColumnNames.Label, DefaultColumnNames.Features))
62-
}));
63-
}
6451
}
6552

6653
public MultiStacking(IHostEnvironment env, Arguments args)

src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@
44
using System;
55
using Microsoft.ML.Runtime;
66
using Microsoft.ML.Runtime.CommandLine;
7-
using Microsoft.ML.Runtime.Data;
87
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
98
using Microsoft.ML.Runtime.EntryPoints;
10-
using Microsoft.ML.Trainers.FastTree;
119
using Microsoft.ML.Runtime.Internal.Internallearn;
1210
using Microsoft.ML.Runtime.Model;
1311

@@ -47,12 +45,6 @@ public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerF
4745

4846
internal override IComponentFactory<ITrainer<TScalarPredictor>> GetPredictorFactory() => BasePredictorType;
4947

50-
public Arguments()
51-
{
52-
BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
53-
env => new FastTreeRegressionTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features));
54-
}
55-
5648
public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new RegressionStacking(env, this);
5749
}
5850

src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55
using System;
66
using Microsoft.ML.Runtime;
77
using Microsoft.ML.Runtime.CommandLine;
8-
using Microsoft.ML.Runtime.Data;
98
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
109
using Microsoft.ML.Runtime.EntryPoints;
11-
using Microsoft.ML.Trainers.FastTree;
1210
using Microsoft.ML.Runtime.Internal.Internallearn;
1311
using Microsoft.ML.Runtime.Model;
1412

@@ -45,12 +43,6 @@ public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFacto
4543

4644
internal override IComponentFactory<ITrainer<TScalarPredictor>> GetPredictorFactory() => BasePredictorType;
4745

48-
public Arguments()
49-
{
50-
BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
51-
env => new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features));
52-
}
53-
5446
public IBinaryOutputCombiner CreateComponent(IHostEnvironment env) => new Stacking(env, this);
5547
}
5648

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Generic;
6+
using Microsoft.ML.Runtime;
7+
8+
namespace Microsoft.ML.Runtime.Ensemble
9+
{
10+
/// <summary>
11+
/// An interface that combines multiple predictors into a single predictor.
12+
/// </summary>
13+
public interface IModelCombiner
14+
{
15+
IPredictor CombineModels(IEnumerable<IPredictor> models);
16+
}
17+
}

src/Microsoft.ML.FastTree/FastTree.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
5555
{
5656
protected readonly TArgs Args;
5757
protected readonly bool AllowGC;
58-
protected Ensemble TrainedEnsemble;
58+
protected TreeEnsemble TrainedEnsemble;
5959
protected int FeatureCount;
6060
protected RoleMappedData ValidData;
6161
protected IParallelTraining ParallelTraining;
@@ -76,7 +76,7 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
7676
protected double[] InitValidScores;
7777
protected double[][] InitTestScores;
7878
//protected int Iteration;
79-
protected Ensemble Ensemble;
79+
protected TreeEnsemble Ensemble;
8080

8181
protected bool HasValidSet => ValidSet != null;
8282

@@ -478,7 +478,7 @@ protected bool AreSamplesWeighted(IChannel ch)
478478

479479
private void InitializeEnsemble()
480480
{
481-
Ensemble = new Ensemble();
481+
Ensemble = new TreeEnsemble();
482482
}
483483

484484
/// <summary>
@@ -914,7 +914,7 @@ internal abstract class DataConverter
914914
/// of features we actually trained on. This can be null in the event that no filtering
915915
/// occurred.
916916
/// </summary>
917-
/// <seealso cref="Ensemble.RemapFeatures"/>
917+
/// <seealso cref="TreeEnsemble.RemapFeatures"/>
918918
public int[] FeatureMap;
919919

920920
protected readonly IHost Host;
@@ -2810,7 +2810,7 @@ public abstract class FastTreePredictionWrapper :
28102810
ISingleCanSaveOnnx
28112811
{
28122812
//The below two properties are necessary for tree Visualizer
2813-
public Ensemble TrainedEnsemble { get; }
2813+
public TreeEnsemble TrainedEnsemble { get; }
28142814
public int NumTrees => TrainedEnsemble.NumTrees;
28152815

28162816
// Inner args is used only for documentation purposes when saving comments to INI files.
@@ -2834,7 +2834,7 @@ public abstract class FastTreePredictionWrapper :
28342834
public bool CanSavePfa => true;
28352835
public bool CanSaveOnnx(OnnxContext ctx) => true;
28362836

2837-
protected FastTreePredictionWrapper(IHostEnvironment env, string name, Ensemble trainedEnsemble, int numFeatures, string innerArgs)
2837+
protected FastTreePredictionWrapper(IHostEnvironment env, string name, TreeEnsemble trainedEnsemble, int numFeatures, string innerArgs)
28382838
: base(env, name)
28392839
{
28402840
Host.CheckValue(trainedEnsemble, nameof(trainedEnsemble));
@@ -2871,7 +2871,7 @@ protected FastTreePredictionWrapper(IHostEnvironment env, string name, ModelLoad
28712871
if (ctx.Header.ModelVerWritten >= VerCategoricalSplitSerialized)
28722872
categoricalSplits = true;
28732873

2874-
TrainedEnsemble = new Ensemble(ctx, usingDefaultValues, categoricalSplits);
2874+
TrainedEnsemble = new TreeEnsemble(ctx, usingDefaultValues, categoricalSplits);
28752875
MaxSplitFeatIdx = FindMaxFeatureIndex(TrainedEnsemble);
28762876

28772877
InnerArgs = ctx.LoadStringOrNull();
@@ -3264,7 +3264,7 @@ public void GetFeatureWeights(ref VBuffer<Float> weights)
32643264
bldr.GetResult(ref weights);
32653265
}
32663266

3267-
private static int FindMaxFeatureIndex(Ensemble ensemble)
3267+
private static int FindMaxFeatureIndex(TreeEnsemble ensemble)
32683268
{
32693269
int ifeatMax = 0;
32703270
for (int i = 0; i < ensemble.NumTrees; i++)

src/Microsoft.ML.FastTree/FastTreeClassification.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ private static VersionInfo GetVersionInfo()
6868

6969
protected override uint VerCategoricalSplitSerialized => 0x00010005;
7070

71-
internal FastTreeBinaryPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs)
71+
internal FastTreeBinaryPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
7272
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
7373
{
7474
}

src/Microsoft.ML.FastTree/FastTreeRanking.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1114,7 +1114,7 @@ private static VersionInfo GetVersionInfo()
11141114

11151115
protected override uint VerCategoricalSplitSerialized => 0x00010005;
11161116

1117-
internal FastTreeRankingPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs)
1117+
internal FastTreeRankingPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
11181118
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
11191119
{
11201120
}

src/Microsoft.ML.FastTree/FastTreeRegression.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ private static VersionInfo GetVersionInfo()
462462

463463
protected override uint VerCategoricalSplitSerialized => 0x00010005;
464464

465-
internal FastTreeRegressionPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs)
465+
internal FastTreeRegressionPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
466466
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
467467
{
468468
}

src/Microsoft.ML.FastTree/FastTreeTweedie.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ private static VersionInfo GetVersionInfo()
465465

466466
protected override uint VerCategoricalSplitSerialized => 0x00010003;
467467

468-
internal FastTreeTweediePredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs)
468+
internal FastTreeTweediePredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
469469
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
470470
{
471471
}

src/Microsoft.ML.FastTree/Microsoft.ML.FastTree.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
1212
<ProjectReference Include="..\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.csproj" />
1313
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
14+
<ProjectReference Include="..\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.csproj" />
1415
</ItemGroup>
1516

1617
</Project>

src/Microsoft.ML.FastTree/RandomForestClassification.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ private static VersionInfo GetVersionInfo()
7878
/// </summary>
7979
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
8080

81-
public FastForestClassificationPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs)
81+
public FastForestClassificationPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
8282
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
8383
{ }
8484

src/Microsoft.ML.FastTree/RandomForestRegression.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ private static VersionInfo GetVersionInfo()
5959

6060
protected override uint VerCategoricalSplitSerialized => 0x00010006;
6161

62-
public FastForestRegressionPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs, int samplesCount)
62+
public FastForestRegressionPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs, int samplesCount)
6363
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
6464
{
6565
_quantileSampleCount = samplesCount;

src/Microsoft.ML.FastTree/Training/BaggingProvider.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public int GetBagCount(int numTrees, int bagSize)
7575
// Divides output values of leaves to bag count.
7676
// This brings back the final scores generated by model on a same
7777
// range as when we didn't use bagging
78-
public void ScaleEnsembleLeaves(int numTrees, int bagSize, Ensemble ensemble)
78+
public void ScaleEnsembleLeaves(int numTrees, int bagSize, TreeEnsemble ensemble)
7979
{
8080
int bagCount = GetBagCount(numTrees, bagSize);
8181
for (int t = 0; t < ensemble.NumTrees; t++)

src/Microsoft.ML.FastTree/Training/EnsembleCompression/IEnsembleCompressor.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ public interface IEnsembleCompressor<TLabel>
1414

1515
void SetTreeScores(int idx, double[] scores);
1616

17-
bool Compress(IChannel ch, Ensemble ensemble, double[] trainScores, int bestIteration, int maxTreesAfterCompression);
17+
bool Compress(IChannel ch, TreeEnsemble ensemble, double[] trainScores, int bestIteration, int maxTreesAfterCompression);
1818

19-
Ensemble GetCompressedEnsemble();
19+
TreeEnsemble GetCompressedEnsemble();
2020
}
2121
}

src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoBasedEnsembleCompressor.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public class LassoBasedEnsembleCompressor : IEnsembleCompressor<short>
5151

5252
private Dataset _trainSet;
5353
private short[] _labels;
54-
private Ensemble _compressedEnsemble;
54+
private TreeEnsemble _compressedEnsemble;
5555
private int[] _sampleObservationIndices;
5656
private Random _rnd;
5757

@@ -458,9 +458,9 @@ private LassoFit GetLassoFit(IChannel ch, int maxAllowedFeaturesPerModel)
458458
return fit;
459459
}
460460

461-
private Ensemble GetEnsembleFromSolution(LassoFit fit, int solutionIdx, Ensemble originalEnsemble)
461+
private TreeEnsemble GetEnsembleFromSolution(LassoFit fit, int solutionIdx, TreeEnsemble originalEnsemble)
462462
{
463-
Ensemble ensemble = new Ensemble();
463+
TreeEnsemble ensemble = new TreeEnsemble();
464464

465465
int weightsCount = fit.NumberOfWeights[solutionIdx];
466466
for (int i = 0; i < weightsCount; i++)
@@ -534,7 +534,7 @@ private unsafe void LoadTargets(double[] trainScores, int bestIteration)
534534
}
535535
}
536536

537-
public bool Compress(IChannel ch, Ensemble ensemble, double[] trainScores, int bestIteration, int maxTreesAfterCompression)
537+
public bool Compress(IChannel ch, TreeEnsemble ensemble, double[] trainScores, int bestIteration, int maxTreesAfterCompression)
538538
{
539539
LoadTargets(trainScores, bestIteration);
540540

@@ -552,7 +552,7 @@ public bool Compress(IChannel ch, Ensemble ensemble, double[] trainScores, int b
552552
return true;
553553
}
554554

555-
public Ensemble GetCompressedEnsemble()
555+
public TreeEnsemble GetCompressedEnsemble()
556556
{
557557
return _compressedEnsemble;
558558
}

src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/AcceleratedGradientDescent.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace Microsoft.ML.Trainers.FastTree.Internal
99
//Accelerated gradient descent score tracker
1010
public class AcceleratedGradientDescent : GradientDescent
1111
{
12-
public AcceleratedGradientDescent(Ensemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
12+
public AcceleratedGradientDescent(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
1313
: base(ensemble, trainData, initTrainScores, gradientWrapper)
1414
{
1515
UseFastTrainingScoresUpdate = false;

src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/ConjugateGradientDescent.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public class ConjugateGradientDescent : GradientDescent
1313
private double[] _currentGradient;
1414
private double[] _currentDk;
1515

16-
public ConjugateGradientDescent(Ensemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
16+
public ConjugateGradientDescent(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
1717
: base(ensemble, trainData, initTrainScores, gradientWrapper)
1818
{
1919
_currentDk = new double[trainData.NumDocs];

src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/GradientDescent.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public class GradientDescent : OptimizationAlgorithm
2222
private double[] _droppedScores;
2323
private double[] _scores;
2424

25-
public GradientDescent(Ensemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
25+
public GradientDescent(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
2626
: base(ensemble, trainData, initTrainScores)
2727
{
2828
_gradientWrapper = gradientWrapper;

src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/NoOptimizationAlgorithm.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public class RandomForestOptimizer : GradientDescent
1414
{
1515
private IGradientAdjuster _gradientWrapper;
1616
// REVIEW: When the FastTree appliation is decoupled with tree learner and boosting logic, this class should be removed.
17-
public RandomForestOptimizer(Ensemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
17+
public RandomForestOptimizer(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
1818
: base(ensemble, trainData, initTrainScores, gradientWrapper)
1919
{
2020
_gradientWrapper = gradientWrapper;

src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/OptimizationAlgorithm.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public abstract class OptimizationAlgorithm
2626
public delegate void PreScoreUpdateHandler(IChannel ch);
2727
public PreScoreUpdateHandler PreScoreUpdateEvent;
2828

29-
public Ensemble Ensemble;
29+
public TreeEnsemble Ensemble;
3030

3131
public ScoreTracker TrainingScores;
3232
public List<ScoreTracker> TrackedScores;
@@ -37,7 +37,7 @@ public abstract class OptimizationAlgorithm
3737
public Random DropoutRng;
3838
public bool UseFastTrainingScoresUpdate;
3939

40-
public OptimizationAlgorithm(Ensemble ensemble, Dataset trainData, double[] initTrainScores)
40+
public OptimizationAlgorithm(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores)
4141
{
4242
Ensemble = ensemble;
4343
TrainingScores = ConstructScoreTracker("train", trainData, initTrainScores);

0 commit comments

Comments
 (0)