Skip to content

Commit 8b01fc5

Browse files
authored
Combine multiple tree ensemble models into a single tree ensemble (dotnet#364)
* Add a way to create a single tree ensemble model from multiple tree ensemble models. * Address PR comments, and fix bugs in serializing/deserializing RegressionTrees. * Address PR comments.
1 parent 09f7c66 commit 8b01fc5

File tree

6 files changed

+311
-39
lines changed

6 files changed

+311
-39
lines changed

src/Microsoft.ML.Core/Prediction/ITrainer.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ public interface IModelCombiner<TModel, TPredictor>
7474
TPredictor CombineModels(IEnumerable<TModel> models);
7575
}
7676

77+
public delegate void SignatureModelCombiner(PredictionKind kind);
78+
7779
/// <summary>
7880
/// Weakly typed interface for a trainer "session" that produces a predictor.
7981
/// </summary>

src/Microsoft.ML.FastTree/Microsoft.ML.FastTree.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
<Compile Include="TreeEnsemble\Ensemble.cs" />
5858
<Compile Include="TreeEnsemble\QuantileRegressionTree.cs" />
5959
<Compile Include="TreeEnsemble\RegressionTree.cs" />
60+
<Compile Include="TreeEnsemble\TreeEnsembleCombiner.cs" />
6061
<Compile Include="Training\Applications\GradientWrappers.cs" />
6162
<Compile Include="Training\Applications\ObjectiveFunction.cs" />
6263
<Compile Include="Training\BaggingProvider.cs" />

src/Microsoft.ML.FastTree/TreeEnsemble/RegressionTree.cs

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -105,22 +105,21 @@ public RegressionTree(byte[] buffer, ref int position)
105105
LteChild = buffer.ToIntArray(ref position);
106106
GtChild = buffer.ToIntArray(ref position);
107107
SplitFeatures = buffer.ToIntArray(ref position);
108-
int[] categoricalNodeIndices = buffer.ToIntArray(ref position);
109-
CategoricalSplit = GetCategoricalSplitFromIndices(categoricalNodeIndices);
110-
if (categoricalNodeIndices?.Length > 0)
108+
byte[] categoricalSplitAsBytes = buffer.ToByteArray(ref position);
109+
CategoricalSplit = categoricalSplitAsBytes.Select(b => b > 0).ToArray();
110+
if (CategoricalSplit.Any(b => b))
111111
{
112112
CategoricalSplitFeatures = new int[NumNodes][];
113113
CategoricalSplitFeatureRanges = new int[NumNodes][];
114-
foreach (var index in categoricalNodeIndices)
114+
for (int index = 0; index < NumNodes; index++)
115115
{
116-
Contracts.Assert(CategoricalSplit[index]);
117-
118116
CategoricalSplitFeatures[index] = buffer.ToIntArray(ref position);
119-
CategoricalSplitFeatureRanges[index] = buffer.ToIntArray(ref position, 2);
117+
CategoricalSplitFeatureRanges[index] = buffer.ToIntArray(ref position);
120118
}
121119
}
122120

123121
Thresholds = buffer.ToUIntArray(ref position);
122+
RawThresholds = buffer.ToFloatArray(ref position);
124123
_splitGain = buffer.ToDoubleArray(ref position);
125124
_gainPValue = buffer.ToDoubleArray(ref position);
126125
_previousLeafValue = buffer.ToDoubleArray(ref position);
@@ -144,6 +143,23 @@ private bool[] GetCategoricalSplitFromIndices(int[] indices)
144143
return categoricalSplit;
145144
}
146145

146+
private bool[] GetCategoricalSplitFromBytes(byte[] indices)
147+
{
148+
bool[] categoricalSplit = new bool[NumNodes];
149+
if (indices == null)
150+
return categoricalSplit;
151+
152+
Contracts.Assert(indices.Length <= NumNodes);
153+
154+
foreach (int index in indices)
155+
{
156+
Contracts.Assert(index >= 0 && index < NumNodes);
157+
categoricalSplit[index] = true;
158+
}
159+
160+
return categoricalSplit;
161+
}
162+
147163
/// <summary>
148164
/// Create a Regression Tree object from raw tree contents.
149165
/// </summary>
@@ -192,7 +208,7 @@ internal RegressionTree(int[] splitFeatures, Double[] splitGain, Double[] gainPV
192208
LeafValues = leafValues;
193209
CategoricalSplitFeatures = categoricalSplitFeatures;
194210
CategoricalSplitFeatureRanges = new int[CategoricalSplitFeatures.Length][];
195-
for(int i= 0; i < CategoricalSplitFeatures.Length; ++i)
211+
for (int i = 0; i < CategoricalSplitFeatures.Length; ++i)
196212
{
197213
if (CategoricalSplitFeatures[i] != null && CategoricalSplitFeatures[i].Length > 0)
198214
{
@@ -500,6 +516,7 @@ public virtual int SizeInBytes()
500516
NumNodes * sizeof(int) +
501517
CategoricalSplit.Length * sizeof(bool) +
502518
Thresholds.SizeInBytes() +
519+
RawThresholds.SizeInBytes() +
503520
_splitGain.SizeInBytes() +
504521
_gainPValue.SizeInBytes() +
505522
_previousLeafValue.SizeInBytes() +
@@ -514,22 +531,22 @@ public virtual void ToByteArray(byte[] buffer, ref int position)
514531
LteChild.ToByteArray(buffer, ref position);
515532
GtChild.ToByteArray(buffer, ref position);
516533
SplitFeatures.ToByteArray(buffer, ref position);
534+
CategoricalSplit.Length.ToByteArray(buffer, ref position);
517535
foreach (var split in CategoricalSplit)
518536
Convert.ToByte(split).ToByteArray(buffer, ref position);
519537

520538
if (CategoricalSplitFeatures != null)
521539
{
522-
foreach (var splits in CategoricalSplitFeatures)
523-
splits.ToByteArray(buffer, ref position);
524-
}
525-
526-
if (CategoricalSplitFeatureRanges != null)
527-
{
528-
foreach (var ranges in CategoricalSplitFeatureRanges)
529-
ranges.ToByteArray(buffer, ref position);
540+
Contracts.AssertValue(CategoricalSplitFeatureRanges);
541+
for (int i = 0; i < CategoricalSplitFeatures.Length; i++)
542+
{
543+
CategoricalSplitFeatures[i].ToByteArray(buffer, ref position);
544+
CategoricalSplitFeatureRanges[i].ToByteArray(buffer, ref position);
545+
}
530546
}
531547

532548
Thresholds.ToByteArray(buffer, ref position);
549+
RawThresholds.ToByteArray(buffer, ref position);
533550
_splitGain.ToByteArray(buffer, ref position);
534551
_gainPValue.ToByteArray(buffer, ref position);
535552
_previousLeafValue.ToByteArray(buffer, ref position);
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Generic;
6+
using Microsoft.ML.Runtime;
7+
using Microsoft.ML.Runtime.FastTree.Internal;
8+
using Microsoft.ML.Runtime.Internal.Calibration;
9+
10+
[assembly: LoadableClass(typeof(TreeEnsembleCombiner), null, typeof(SignatureModelCombiner), "Fast Tree Model Combiner", "FastTreeCombiner")]
11+
12+
namespace Microsoft.ML.Runtime.FastTree.Internal
13+
{
14+
public sealed class TreeEnsembleCombiner : IModelCombiner<IPredictorProducing<float>, IPredictorProducing<float>>
15+
{
16+
private readonly IHost _host;
17+
private readonly PredictionKind _kind;
18+
19+
public TreeEnsembleCombiner(IHostEnvironment env, PredictionKind kind)
20+
{
21+
_host = env.Register("TreeEnsembleCombiner");
22+
switch (kind)
23+
{
24+
case PredictionKind.BinaryClassification:
25+
case PredictionKind.Regression:
26+
case PredictionKind.Ranking:
27+
_kind = kind;
28+
break;
29+
default:
30+
throw _host.ExceptUserArg(nameof(kind), $"Tree ensembles can be either of type {nameof(PredictionKind.BinaryClassification)}, " +
31+
$"{nameof(PredictionKind.Regression)} or {nameof(PredictionKind.Ranking)}");
32+
}
33+
}
34+
35+
public IPredictorProducing<float> CombineModels(IEnumerable<IPredictorProducing<float>> models)
36+
{
37+
_host.CheckValue(models, nameof(models));
38+
39+
var ensemble = new Ensemble();
40+
int modelCount = 0;
41+
int featureCount = -1;
42+
bool binaryClassifier = false;
43+
foreach (var model in models)
44+
{
45+
modelCount++;
46+
47+
var predictor = model;
48+
_host.CheckValue(predictor, nameof(models), "One of the models is null");
49+
50+
var calibrated = predictor as CalibratedPredictorBase;
51+
double paramA = 1;
52+
if (calibrated != null)
53+
{
54+
_host.Check(calibrated.Calibrator is PlattCalibrator,
55+
"Combining FastTree models can only be done when the models are calibrated with Platt calibrator");
56+
predictor = calibrated.SubPredictor;
57+
paramA = -(calibrated.Calibrator as PlattCalibrator).ParamA;
58+
}
59+
var tree = predictor as FastTreePredictionWrapper;
60+
if (tree == null)
61+
throw _host.Except("Model is not a tree ensemble");
62+
foreach (var t in tree.TrainedEnsemble.Trees)
63+
{
64+
var bytes = new byte[t.SizeInBytes()];
65+
int position = -1;
66+
t.ToByteArray(bytes, ref position);
67+
position = -1;
68+
var tNew = new RegressionTree(bytes, ref position);
69+
if (paramA != 1)
70+
{
71+
for (int i = 0; i < tNew.NumLeaves; i++)
72+
tNew.SetOutput(i, tNew.LeafValues[i] * paramA);
73+
}
74+
ensemble.AddTree(tNew);
75+
}
76+
77+
if (modelCount == 1)
78+
{
79+
binaryClassifier = calibrated != null;
80+
featureCount = tree.InputType.ValueCount;
81+
}
82+
else
83+
{
84+
_host.Check((calibrated != null) == binaryClassifier, "Ensemble contains both calibrated and uncalibrated models");
85+
_host.Check(featureCount == tree.InputType.ValueCount, "Found models with different number of features");
86+
}
87+
}
88+
89+
var scale = 1 / (double)modelCount;
90+
91+
foreach (var t in ensemble.Trees)
92+
{
93+
for (int i = 0; i < t.NumLeaves; i++)
94+
t.SetOutput(i, t.LeafValues[i] * scale);
95+
}
96+
97+
switch (_kind)
98+
{
99+
case PredictionKind.BinaryClassification:
100+
if (!binaryClassifier)
101+
return new FastTreeBinaryPredictor(_host, ensemble, featureCount, null);
102+
103+
var cali = new PlattCalibrator(_host, -1, 0);
104+
return new FeatureWeightsCalibratedPredictor(_host, new FastTreeBinaryPredictor(_host, ensemble, featureCount, null), cali);
105+
case PredictionKind.Regression:
106+
return new FastTreeRegressionPredictor(_host, ensemble, featureCount, null);
107+
case PredictionKind.Ranking:
108+
return new FastTreeRankingPredictor(_host, ensemble, featureCount, null);
109+
default:
110+
_host.Assert(false);
111+
throw _host.ExceptNotSupp();
112+
}
113+
}
114+
}
115+
}

src/Microsoft.ML.FastTree/Utils/ToByteArrayExtensions.cs

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System;
66
using System.Linq;
77
using System.Text;
8+
using Microsoft.ML.Runtime.Internal.Utilities;
89

910
namespace Microsoft.ML.Runtime.FastTree.Internal
1011
{
@@ -290,7 +291,7 @@ public static string ToString(this byte[] buffer, ref int position)
290291

291292
public static int SizeInBytes(this byte[] a)
292293
{
293-
return sizeof(int) + a.Length * sizeof(byte);
294+
return sizeof(int) + Utils.Size(a) * sizeof(byte);
294295
}
295296

296297
public static void ToByteArray(this byte[] a, byte[] buffer, ref int position)
@@ -314,7 +315,7 @@ public static byte[] ToByteArray(this byte[] buffer, ref int position)
314315

315316
public static int SizeInBytes(this short[] a)
316317
{
317-
return sizeof(int) + a.Length * sizeof(short);
318+
return sizeof(int) + Utils.Size(a) * sizeof(short);
318319
}
319320

320321
public unsafe static void ToByteArray(this short[] a, byte[] buffer, ref int position)
@@ -353,7 +354,7 @@ public unsafe static short[] ToShortArray(this byte[] buffer, ref int position)
353354

354355
public static int SizeInBytes(this ushort[] a)
355356
{
356-
return sizeof(int) + a.Length * sizeof(ushort);
357+
return sizeof(int) + Utils.Size(a) * sizeof(ushort);
357358
}
358359

359360
public unsafe static void ToByteArray(this ushort[] a, byte[] buffer, ref int position)
@@ -392,12 +393,12 @@ public unsafe static ushort[] ToUShortArray(this byte[] buffer, ref int position
392393

393394
public static int SizeInBytes(this int[] array)
394395
{
395-
return sizeof(int) + array.Length * sizeof(int);
396+
return sizeof(int) + Utils.Size(array) * sizeof(int);
396397
}
397398

398399
public unsafe static void ToByteArray(this int[] a, byte[] buffer, ref int position)
399400
{
400-
int length = a.Length;
401+
int length = Utils.Size(a);
401402
length.ToByteArray(buffer, ref position);
402403

403404
fixed (byte* tmpBuffer = buffer)
@@ -415,6 +416,9 @@ public unsafe static int[] ToIntArray(this byte[] buffer, ref int position)
415416

416417
public unsafe static int[] ToIntArray(this byte[] buffer, ref int position, int length)
417418
{
419+
if (length == 0)
420+
return null;
421+
418422
int[] a = new int[length];
419423

420424
fixed (byte* tmpBuffer = buffer)
@@ -433,7 +437,7 @@ public unsafe static int[] ToIntArray(this byte[] buffer, ref int position, int
433437

434438
public static int SizeInBytes(this uint[] array)
435439
{
436-
return sizeof(int) + array.Length * sizeof(uint);
440+
return sizeof(int) + Utils.Size(array) * sizeof(uint);
437441
}
438442

439443
public unsafe static void ToByteArray(this uint[] a, byte[] buffer, ref int position)
@@ -472,7 +476,7 @@ public unsafe static uint[] ToUIntArray(this byte[] buffer, ref int position)
472476

473477
public static int SizeInBytes(this long[] array)
474478
{
475-
return sizeof(int) + array.Length * sizeof(long);
479+
return sizeof(int) + Utils.Size(array) * sizeof(long);
476480
}
477481

478482
public unsafe static void ToByteArray(this long[] a, byte[] buffer, ref int position)
@@ -511,7 +515,7 @@ public unsafe static long[] ToLongArray(this byte[] buffer, ref int position)
511515

512516
public static int SizeInBytes(this ulong[] array)
513517
{
514-
return sizeof(int) + array.Length * sizeof(ulong);
518+
return sizeof(int) + Utils.Size(array) * sizeof(ulong);
515519
}
516520

517521
public unsafe static void ToByteArray(this ulong[] a, byte[] buffer, ref int position)
@@ -550,7 +554,7 @@ public unsafe static ulong[] ToULongArray(this byte[] buffer, ref int position)
550554

551555
public static int SizeInBytes(this MD5Hash[] array)
552556
{
553-
return sizeof(int) + array.Length * MD5Hash.SizeInBytes();
557+
return sizeof(int) + Utils.Size(array) * MD5Hash.SizeInBytes();
554558
}
555559

556560
public static void ToByteArray(this MD5Hash[] a, byte[] buffer, ref int position)
@@ -577,7 +581,7 @@ public unsafe static MD5Hash[] ToUInt128Array(this byte[] buffer, ref int positi
577581

578582
public static int SizeInBytes(this float[] array)
579583
{
580-
return sizeof(int) + array.Length * sizeof(float);
584+
return sizeof(int) + Utils.Size(array) * sizeof(float);
581585
}
582586

583587
public unsafe static void ToByteArray(this float[] a, byte[] buffer, ref int position)
@@ -616,7 +620,7 @@ public unsafe static float[] ToFloatArray(this byte[] buffer, ref int position)
616620

617621
public static int SizeInBytes(this double[] array)
618622
{
619-
return sizeof(int) + array.Length * sizeof(double);
623+
return sizeof(int) + Utils.Size(array) * sizeof(double);
620624
}
621625

622626
public unsafe static void ToByteArray(this double[] a, byte[] buffer, ref int position)
@@ -655,6 +659,8 @@ public unsafe static double[] ToDoubleArray(this byte[] buffer, ref int position
655659

656660
public static int SizeInBytes(this double[][] array)
657661
{
662+
if (Utils.Size(array) == 0)
663+
return sizeof(int);
658664
return sizeof(int) + array.Sum(x => x.SizeInBytes());
659665
}
660666

@@ -683,7 +689,7 @@ public static double[][] ToDoubleJaggedArray(this byte[] buffer, ref int positio
683689
public static long SizeInBytes(this string[] array)
684690
{
685691
long length = sizeof(int);
686-
for (int i = 0; i < array.Length; ++i)
692+
for (int i = 0; i < Utils.Size(array); ++i)
687693
{
688694
length += array[i].SizeInBytes();
689695
}
@@ -692,8 +698,8 @@ public static long SizeInBytes(this string[] array)
692698

693699
public static void ToByteArray(this string[] a, byte[] buffer, ref int position)
694700
{
695-
a.Length.ToByteArray(buffer, ref position);
696-
for (int i = 0; i < a.Length; ++i)
701+
Utils.Size(a).ToByteArray(buffer, ref position);
702+
for (int i = 0; i < Utils.Size(a); ++i)
697703
{
698704
a[i].ToByteArray(buffer, ref position);
699705
}

0 commit comments

Comments
 (0)