@@ -27,11 +27,12 @@ import org.apache.spark.SparkContext._
27
27
import org .jblas ._
28
28
import org .apache .spark .rdd .RDD
29
29
import org .apache .spark .mllib .regression .LabeledPoint
30
- import org .apache .spark .mllib .tree .impurity .{Entropy , Gini }
30
+ import org .apache .spark .mllib .tree .impurity .{Entropy , Gini , Variance }
31
31
import org .apache .spark .mllib .tree .model .Filter
32
32
import org .apache .spark .mllib .tree .configuration .Strategy
33
33
import org .apache .spark .mllib .tree .configuration .Algo ._
34
34
import scala .collection .mutable
35
+ import org .apache .spark .mllib .tree .configuration .FeatureType ._
35
36
36
37
class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
37
38
@@ -56,7 +57,6 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
56
57
assert(bins.length== 2 )
57
58
assert(splits(0 ).length== 99 )
58
59
assert(bins(0 ).length== 100 )
59
- // println(splits(1)(98))
60
60
}
61
61
62
62
test(" split and bin calculation for categorical variables" ){
@@ -69,13 +69,71 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
69
69
assert(bins.length== 2 )
70
70
assert(splits(0 ).length== 99 )
71
71
assert(bins(0 ).length== 100 )
72
- println(splits(0 )(0 ))
73
- println(splits(0 )(1 ))
74
- println(bins(0 )(0 ))
75
- println(splits(1 )(0 ))
76
- println(splits(1 )(1 ))
77
- println(bins(1 )(0 ))
78
- // TODO: Add asserts
72
+
73
+ // Checking splits
74
+
75
+ assert(splits(0 )(0 ).feature == 0 )
76
+ assert(splits(0 )(0 ).threshold == Double .MinValue )
77
+ assert(splits(0 )(0 ).featureType == Categorical )
78
+ assert(splits(0 )(0 ).categories.length == 1 )
79
+ assert(splits(0 )(0 ).categories.contains(1.0 ))
80
+
81
+
82
+ assert(splits(0 )(1 ).feature == 0 )
83
+ assert(splits(0 )(1 ).threshold == Double .MinValue )
84
+ assert(splits(0 )(1 ).featureType == Categorical )
85
+ assert(splits(0 )(1 ).categories.length == 2 )
86
+ assert(splits(0 )(1 ).categories.contains(1.0 ))
87
+ assert(splits(0 )(1 ).categories.contains(0.0 ))
88
+
89
+ assert(splits(0 )(2 ) == null )
90
+
91
+ assert(splits(1 )(0 ).feature == 1 )
92
+ assert(splits(1 )(0 ).threshold == Double .MinValue )
93
+ assert(splits(1 )(0 ).featureType == Categorical )
94
+ assert(splits(1 )(0 ).categories.length == 1 )
95
+ assert(splits(1 )(0 ).categories.contains(0.0 ))
96
+
97
+
98
+ assert(splits(1 )(1 ).feature == 1 )
99
+ assert(splits(1 )(1 ).threshold == Double .MinValue )
100
+ assert(splits(1 )(1 ).featureType == Categorical )
101
+ assert(splits(1 )(1 ).categories.length == 2 )
102
+ assert(splits(1 )(1 ).categories.contains(1.0 ))
103
+ assert(splits(1 )(1 ).categories.contains(0.0 ))
104
+
105
+ assert(splits(1 )(2 ) == null )
106
+
107
+
108
+ // Checks bins
109
+
110
+ assert(bins(0 )(0 ).category == 1.0 )
111
+ assert(bins(0 )(0 ).lowSplit.categories.length == 0 )
112
+ assert(bins(0 )(0 ).highSplit.categories.length == 1 )
113
+ assert(bins(0 )(0 ).highSplit.categories.contains(1.0 ))
114
+
115
+ assert(bins(0 )(1 ).category == 0.0 )
116
+ assert(bins(0 )(1 ).lowSplit.categories.length == 1 )
117
+ assert(bins(0 )(1 ).lowSplit.categories.contains(1.0 ))
118
+ assert(bins(0 )(1 ).highSplit.categories.length == 2 )
119
+ assert(bins(0 )(1 ).highSplit.categories.contains(1.0 ))
120
+ assert(bins(0 )(1 ).highSplit.categories.contains(0.0 ))
121
+
122
+ assert(bins(0 )(2 ).category == Double .MaxValue )
123
+
124
+ assert(bins(1 )(0 ).category == 0.0 )
125
+ assert(bins(1 )(0 ).lowSplit.categories.length == 0 )
126
+ assert(bins(1 )(0 ).highSplit.categories.length == 1 )
127
+ assert(bins(1 )(0 ).highSplit.categories.contains(0.0 ))
128
+
129
+ assert(bins(1 )(1 ).category == 1.0 )
130
+ assert(bins(1 )(1 ).lowSplit.categories.length == 1 )
131
+ assert(bins(1 )(1 ).lowSplit.categories.contains(0.0 ))
132
+ assert(bins(1 )(1 ).highSplit.categories.length == 2 )
133
+ assert(bins(1 )(1 ).highSplit.categories.contains(0.0 ))
134
+ assert(bins(1 )(1 ).highSplit.categories.contains(1.0 ))
135
+
136
+ assert(bins(1 )(2 ).category == Double .MaxValue )
79
137
80
138
}
81
139
@@ -85,29 +143,106 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
85
143
val rdd = sc.parallelize(arr)
86
144
val strategy = new Strategy (Classification ,Gini ,3 ,100 ,categoricalFeaturesInfo = Map (0 -> 3 , 1 -> 3 ))
87
145
val (splits, bins) = DecisionTree .findSplitsBins(rdd,strategy)
88
- assert(splits.length== 2 )
89
- assert(bins.length== 2 )
90
- assert(splits(0 ).length== 99 )
91
- assert(bins(0 ).length== 100 )
92
- println(splits(0 )(0 ))
93
- println(splits(0 )(1 ))
94
- println(splits(0 )(2 ))
95
- println(bins(0 )(0 ))
96
- println(bins(0 )(1 ))
97
- println(bins(0 )(2 ))
98
- println(splits(1 )(0 ))
99
- println(splits(1 )(1 ))
100
- println(splits(1 )(2 ))
101
- println(bins(1 )(0 ))
102
- println(bins(1 )(1 ))
103
- println(bins(0 )(2 ))
104
- println(bins(0 )(3 ))
105
- // TODO: Add asserts
106
146
107
- }
147
+ // Checking splits
148
+
149
+ assert(splits(0 )(0 ).feature == 0 )
150
+ assert(splits(0 )(0 ).threshold == Double .MinValue )
151
+ assert(splits(0 )(0 ).featureType == Categorical )
152
+ assert(splits(0 )(0 ).categories.length == 1 )
153
+ assert(splits(0 )(0 ).categories.contains(1.0 ))
154
+
155
+ assert(splits(0 )(1 ).feature == 0 )
156
+ assert(splits(0 )(1 ).threshold == Double .MinValue )
157
+ assert(splits(0 )(1 ).featureType == Categorical )
158
+ assert(splits(0 )(1 ).categories.length == 2 )
159
+ assert(splits(0 )(1 ).categories.contains(1.0 ))
160
+ assert(splits(0 )(1 ).categories.contains(0.0 ))
161
+
162
+ assert(splits(0 )(2 ).feature == 0 )
163
+ assert(splits(0 )(2 ).threshold == Double .MinValue )
164
+ assert(splits(0 )(2 ).featureType == Categorical )
165
+ assert(splits(0 )(2 ).categories.length == 3 )
166
+ assert(splits(0 )(2 ).categories.contains(1.0 ))
167
+ assert(splits(0 )(2 ).categories.contains(0.0 ))
168
+ assert(splits(0 )(2 ).categories.contains(2.0 ))
169
+
170
+ assert(splits(0 )(3 ) == null )
171
+
172
+ assert(splits(1 )(0 ).feature == 1 )
173
+ assert(splits(1 )(0 ).threshold == Double .MinValue )
174
+ assert(splits(1 )(0 ).featureType == Categorical )
175
+ assert(splits(1 )(0 ).categories.length == 1 )
176
+ assert(splits(1 )(0 ).categories.contains(0.0 ))
177
+
178
+ assert(splits(1 )(1 ).feature == 1 )
179
+ assert(splits(1 )(1 ).threshold == Double .MinValue )
180
+ assert(splits(1 )(1 ).featureType == Categorical )
181
+ assert(splits(1 )(1 ).categories.length == 2 )
182
+ assert(splits(1 )(1 ).categories.contains(1.0 ))
183
+ assert(splits(1 )(1 ).categories.contains(0.0 ))
184
+
185
+ assert(splits(1 )(2 ).feature == 1 )
186
+ assert(splits(1 )(2 ).threshold == Double .MinValue )
187
+ assert(splits(1 )(2 ).featureType == Categorical )
188
+ assert(splits(1 )(2 ).categories.length == 3 )
189
+ assert(splits(1 )(2 ).categories.contains(1.0 ))
190
+ assert(splits(1 )(2 ).categories.contains(0.0 ))
191
+ assert(splits(1 )(2 ).categories.contains(2.0 ))
192
+
193
+ assert(splits(1 )(3 ) == null )
194
+
195
+
196
+ // Checks bins
197
+
198
+ assert(bins(0 )(0 ).category == 1.0 )
199
+ assert(bins(0 )(0 ).lowSplit.categories.length == 0 )
200
+ assert(bins(0 )(0 ).highSplit.categories.length == 1 )
201
+ assert(bins(0 )(0 ).highSplit.categories.contains(1.0 ))
202
+
203
+ assert(bins(0 )(1 ).category == 0.0 )
204
+ assert(bins(0 )(1 ).lowSplit.categories.length == 1 )
205
+ assert(bins(0 )(1 ).lowSplit.categories.contains(1.0 ))
206
+ assert(bins(0 )(1 ).highSplit.categories.length == 2 )
207
+ assert(bins(0 )(1 ).highSplit.categories.contains(1.0 ))
208
+ assert(bins(0 )(1 ).highSplit.categories.contains(0.0 ))
209
+
210
+ assert(bins(0 )(2 ).category == 2.0 )
211
+ assert(bins(0 )(2 ).lowSplit.categories.length == 2 )
212
+ assert(bins(0 )(2 ).lowSplit.categories.contains(1.0 ))
213
+ assert(bins(0 )(2 ).lowSplit.categories.contains(0.0 ))
214
+ assert(bins(0 )(2 ).highSplit.categories.length == 3 )
215
+ assert(bins(0 )(2 ).highSplit.categories.contains(1.0 ))
216
+ assert(bins(0 )(2 ).highSplit.categories.contains(0.0 ))
217
+ assert(bins(0 )(2 ).highSplit.categories.contains(2.0 ))
218
+
219
+ assert(bins(0 )(3 ).category == Double .MaxValue )
220
+
221
+ assert(bins(1 )(0 ).category == 0.0 )
222
+ assert(bins(1 )(0 ).lowSplit.categories.length == 0 )
223
+ assert(bins(1 )(0 ).highSplit.categories.length == 1 )
224
+ assert(bins(1 )(0 ).highSplit.categories.contains(0.0 ))
225
+
226
+ assert(bins(1 )(1 ).category == 1.0 )
227
+ assert(bins(1 )(1 ).lowSplit.categories.length == 1 )
228
+ assert(bins(1 )(1 ).lowSplit.categories.contains(0.0 ))
229
+ assert(bins(1 )(1 ).highSplit.categories.length == 2 )
230
+ assert(bins(1 )(1 ).highSplit.categories.contains(0.0 ))
231
+ assert(bins(1 )(1 ).highSplit.categories.contains(1.0 ))
232
+
233
+ assert(bins(1 )(2 ).category == 2.0 )
234
+ assert(bins(1 )(2 ).lowSplit.categories.length == 2 )
235
+ assert(bins(1 )(2 ).lowSplit.categories.contains(0.0 ))
236
+ assert(bins(1 )(2 ).lowSplit.categories.contains(1.0 ))
237
+ assert(bins(1 )(2 ).highSplit.categories.length == 3 )
238
+ assert(bins(1 )(2 ).highSplit.categories.contains(0.0 ))
239
+ assert(bins(1 )(2 ).highSplit.categories.contains(1.0 ))
240
+ assert(bins(1 )(2 ).highSplit.categories.contains(2.0 ))
241
+
242
+ assert(bins(1 )(3 ).category == Double .MaxValue )
108
243
109
- // TODO: Test max feature value > num bins
110
244
245
+ }
111
246
112
247
test(" classification stump with all categorical variables" ){
113
248
val arr = DecisionTreeSuite .generateCategoricalDataPoints()
@@ -117,22 +252,41 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
117
252
val (splits, bins) = DecisionTree .findSplitsBins(rdd,strategy)
118
253
strategy.numBins = 100
119
254
val bestSplits = DecisionTree .findBestSplits(rdd,new Array (7 ),strategy,0 ,Array [List [Filter ]](),splits,bins)
120
- println(bestSplits(0 )._1)
121
- println(bestSplits(0 )._2)
122
- // TODO: Add asserts
255
+
256
+ val split = bestSplits(0 )._1
257
+ assert(split.categories.length == 1 )
258
+ assert(split.categories.contains(1.0 ))
259
+ assert(split.featureType == Categorical )
260
+ assert(split.threshold == Double .MinValue )
261
+
262
+ val stats = bestSplits(0 )._2
263
+ assert(stats.gain > 0 )
264
+ assert(stats.predict > 0.4 )
265
+ assert(stats.predict < 0.5 )
266
+ assert(stats.impurity > 0.2 )
267
+
123
268
}
124
269
125
270
test(" regression stump with all categorical variables" ){
126
271
val arr = DecisionTreeSuite .generateCategoricalDataPoints()
127
272
assert(arr.length == 1000 )
128
273
val rdd = sc.parallelize(arr)
129
- val strategy = new Strategy (Classification , Gini ,3 ,100 ,categoricalFeaturesInfo = Map (0 -> 3 , 1 -> 3 ))
274
+ val strategy = new Strategy (Regression , Variance ,3 ,100 ,categoricalFeaturesInfo = Map (0 -> 3 , 1 -> 3 ))
130
275
val (splits, bins) = DecisionTree .findSplitsBins(rdd,strategy)
131
276
strategy.numBins = 100
132
277
val bestSplits = DecisionTree .findBestSplits(rdd,new Array (7 ),strategy,0 ,Array [List [Filter ]](),splits,bins)
133
- println(bestSplits(0 )._1)
134
- println(bestSplits(0 )._2)
135
- // TODO: Add asserts
278
+
279
+ val split = bestSplits(0 )._1
280
+ assert(split.categories.length == 1 )
281
+ assert(split.categories.contains(1.0 ))
282
+ assert(split.featureType == Categorical )
283
+ assert(split.threshold == Double .MinValue )
284
+
285
+ val stats = bestSplits(0 )._2
286
+ assert(stats.gain > 0 )
287
+ assert(stats.predict > 0.4 )
288
+ assert(stats.predict < 0.5 )
289
+ assert(stats.impurity > 0.2 )
136
290
}
137
291
138
292
@@ -157,7 +311,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
157
311
assert(0 == bestSplits(0 )._2.gain)
158
312
assert(0 == bestSplits(0 )._2.leftImpurity)
159
313
assert(0 == bestSplits(0 )._2.rightImpurity)
160
- println(bestSplits( 0 )._2.predict)
314
+
161
315
}
162
316
163
317
test(" stump with fixed label 1 for Gini" ){
0 commit comments