Skip to content

Commit 5841c28

Browse files
committed
unit tests for categorical features
Signed-off-by: Manish Amde <[email protected]>
1 parent f067d68 commit 5841c28

File tree

1 file changed

+191
-37
lines changed

1 file changed

+191
-37
lines changed

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 191 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,12 @@ import org.apache.spark.SparkContext._
2727
import org.jblas._
2828
import org.apache.spark.rdd.RDD
2929
import org.apache.spark.mllib.regression.LabeledPoint
30-
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini}
30+
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
3131
import org.apache.spark.mllib.tree.model.Filter
3232
import org.apache.spark.mllib.tree.configuration.Strategy
3333
import org.apache.spark.mllib.tree.configuration.Algo._
3434
import scala.collection.mutable
35+
import org.apache.spark.mllib.tree.configuration.FeatureType._
3536

3637
class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
3738

@@ -56,7 +57,6 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
5657
assert(bins.length==2)
5758
assert(splits(0).length==99)
5859
assert(bins(0).length==100)
59-
//println(splits(1)(98))
6060
}
6161

6262
test("split and bin calculation for categorical variables"){
@@ -69,13 +69,71 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
6969
assert(bins.length==2)
7070
assert(splits(0).length==99)
7171
assert(bins(0).length==100)
72-
println(splits(0)(0))
73-
println(splits(0)(1))
74-
println(bins(0)(0))
75-
println(splits(1)(0))
76-
println(splits(1)(1))
77-
println(bins(1)(0))
78-
//TODO: Add asserts
72+
73+
//Checking splits
74+
75+
assert(splits(0)(0).feature == 0)
76+
assert(splits(0)(0).threshold == Double.MinValue)
77+
assert(splits(0)(0).featureType == Categorical)
78+
assert(splits(0)(0).categories.length == 1)
79+
assert(splits(0)(0).categories.contains(1.0))
80+
81+
82+
assert(splits(0)(1).feature == 0)
83+
assert(splits(0)(1).threshold == Double.MinValue)
84+
assert(splits(0)(1).featureType == Categorical)
85+
assert(splits(0)(1).categories.length == 2)
86+
assert(splits(0)(1).categories.contains(1.0))
87+
assert(splits(0)(1).categories.contains(0.0))
88+
89+
assert(splits(0)(2) == null)
90+
91+
assert(splits(1)(0).feature == 1)
92+
assert(splits(1)(0).threshold == Double.MinValue)
93+
assert(splits(1)(0).featureType == Categorical)
94+
assert(splits(1)(0).categories.length == 1)
95+
assert(splits(1)(0).categories.contains(0.0))
96+
97+
98+
assert(splits(1)(1).feature == 1)
99+
assert(splits(1)(1).threshold == Double.MinValue)
100+
assert(splits(1)(1).featureType == Categorical)
101+
assert(splits(1)(1).categories.length == 2)
102+
assert(splits(1)(1).categories.contains(1.0))
103+
assert(splits(1)(1).categories.contains(0.0))
104+
105+
assert(splits(1)(2) == null)
106+
107+
108+
// Checks bins
109+
110+
assert(bins(0)(0).category == 1.0)
111+
assert(bins(0)(0).lowSplit.categories.length == 0)
112+
assert(bins(0)(0).highSplit.categories.length == 1)
113+
assert(bins(0)(0).highSplit.categories.contains(1.0))
114+
115+
assert(bins(0)(1).category == 0.0)
116+
assert(bins(0)(1).lowSplit.categories.length == 1)
117+
assert(bins(0)(1).lowSplit.categories.contains(1.0))
118+
assert(bins(0)(1).highSplit.categories.length == 2)
119+
assert(bins(0)(1).highSplit.categories.contains(1.0))
120+
assert(bins(0)(1).highSplit.categories.contains(0.0))
121+
122+
assert(bins(0)(2).category == Double.MaxValue)
123+
124+
assert(bins(1)(0).category == 0.0)
125+
assert(bins(1)(0).lowSplit.categories.length == 0)
126+
assert(bins(1)(0).highSplit.categories.length == 1)
127+
assert(bins(1)(0).highSplit.categories.contains(0.0))
128+
129+
assert(bins(1)(1).category == 1.0)
130+
assert(bins(1)(1).lowSplit.categories.length == 1)
131+
assert(bins(1)(1).lowSplit.categories.contains(0.0))
132+
assert(bins(1)(1).highSplit.categories.length == 2)
133+
assert(bins(1)(1).highSplit.categories.contains(0.0))
134+
assert(bins(1)(1).highSplit.categories.contains(1.0))
135+
136+
assert(bins(1)(2).category == Double.MaxValue)
79137

80138
}
81139

@@ -85,29 +143,106 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
85143
val rdd = sc.parallelize(arr)
86144
val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
87145
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
88-
assert(splits.length==2)
89-
assert(bins.length==2)
90-
assert(splits(0).length==99)
91-
assert(bins(0).length==100)
92-
println(splits(0)(0))
93-
println(splits(0)(1))
94-
println(splits(0)(2))
95-
println(bins(0)(0))
96-
println(bins(0)(1))
97-
println(bins(0)(2))
98-
println(splits(1)(0))
99-
println(splits(1)(1))
100-
println(splits(1)(2))
101-
println(bins(1)(0))
102-
println(bins(1)(1))
103-
println(bins(0)(2))
104-
println(bins(0)(3))
105-
//TODO: Add asserts
106146

107-
}
147+
//Checking splits
148+
149+
assert(splits(0)(0).feature == 0)
150+
assert(splits(0)(0).threshold == Double.MinValue)
151+
assert(splits(0)(0).featureType == Categorical)
152+
assert(splits(0)(0).categories.length == 1)
153+
assert(splits(0)(0).categories.contains(1.0))
154+
155+
assert(splits(0)(1).feature == 0)
156+
assert(splits(0)(1).threshold == Double.MinValue)
157+
assert(splits(0)(1).featureType == Categorical)
158+
assert(splits(0)(1).categories.length == 2)
159+
assert(splits(0)(1).categories.contains(1.0))
160+
assert(splits(0)(1).categories.contains(0.0))
161+
162+
assert(splits(0)(2).feature == 0)
163+
assert(splits(0)(2).threshold == Double.MinValue)
164+
assert(splits(0)(2).featureType == Categorical)
165+
assert(splits(0)(2).categories.length == 3)
166+
assert(splits(0)(2).categories.contains(1.0))
167+
assert(splits(0)(2).categories.contains(0.0))
168+
assert(splits(0)(2).categories.contains(2.0))
169+
170+
assert(splits(0)(3) == null)
171+
172+
assert(splits(1)(0).feature == 1)
173+
assert(splits(1)(0).threshold == Double.MinValue)
174+
assert(splits(1)(0).featureType == Categorical)
175+
assert(splits(1)(0).categories.length == 1)
176+
assert(splits(1)(0).categories.contains(0.0))
177+
178+
assert(splits(1)(1).feature == 1)
179+
assert(splits(1)(1).threshold == Double.MinValue)
180+
assert(splits(1)(1).featureType == Categorical)
181+
assert(splits(1)(1).categories.length == 2)
182+
assert(splits(1)(1).categories.contains(1.0))
183+
assert(splits(1)(1).categories.contains(0.0))
184+
185+
assert(splits(1)(2).feature == 1)
186+
assert(splits(1)(2).threshold == Double.MinValue)
187+
assert(splits(1)(2).featureType == Categorical)
188+
assert(splits(1)(2).categories.length == 3)
189+
assert(splits(1)(2).categories.contains(1.0))
190+
assert(splits(1)(2).categories.contains(0.0))
191+
assert(splits(1)(2).categories.contains(2.0))
192+
193+
assert(splits(1)(3) == null)
194+
195+
196+
// Checks bins
197+
198+
assert(bins(0)(0).category == 1.0)
199+
assert(bins(0)(0).lowSplit.categories.length == 0)
200+
assert(bins(0)(0).highSplit.categories.length == 1)
201+
assert(bins(0)(0).highSplit.categories.contains(1.0))
202+
203+
assert(bins(0)(1).category == 0.0)
204+
assert(bins(0)(1).lowSplit.categories.length == 1)
205+
assert(bins(0)(1).lowSplit.categories.contains(1.0))
206+
assert(bins(0)(1).highSplit.categories.length == 2)
207+
assert(bins(0)(1).highSplit.categories.contains(1.0))
208+
assert(bins(0)(1).highSplit.categories.contains(0.0))
209+
210+
assert(bins(0)(2).category == 2.0)
211+
assert(bins(0)(2).lowSplit.categories.length == 2)
212+
assert(bins(0)(2).lowSplit.categories.contains(1.0))
213+
assert(bins(0)(2).lowSplit.categories.contains(0.0))
214+
assert(bins(0)(2).highSplit.categories.length == 3)
215+
assert(bins(0)(2).highSplit.categories.contains(1.0))
216+
assert(bins(0)(2).highSplit.categories.contains(0.0))
217+
assert(bins(0)(2).highSplit.categories.contains(2.0))
218+
219+
assert(bins(0)(3).category == Double.MaxValue)
220+
221+
assert(bins(1)(0).category == 0.0)
222+
assert(bins(1)(0).lowSplit.categories.length == 0)
223+
assert(bins(1)(0).highSplit.categories.length == 1)
224+
assert(bins(1)(0).highSplit.categories.contains(0.0))
225+
226+
assert(bins(1)(1).category == 1.0)
227+
assert(bins(1)(1).lowSplit.categories.length == 1)
228+
assert(bins(1)(1).lowSplit.categories.contains(0.0))
229+
assert(bins(1)(1).highSplit.categories.length == 2)
230+
assert(bins(1)(1).highSplit.categories.contains(0.0))
231+
assert(bins(1)(1).highSplit.categories.contains(1.0))
232+
233+
assert(bins(1)(2).category == 2.0)
234+
assert(bins(1)(2).lowSplit.categories.length == 2)
235+
assert(bins(1)(2).lowSplit.categories.contains(0.0))
236+
assert(bins(1)(2).lowSplit.categories.contains(1.0))
237+
assert(bins(1)(2).highSplit.categories.length == 3)
238+
assert(bins(1)(2).highSplit.categories.contains(0.0))
239+
assert(bins(1)(2).highSplit.categories.contains(1.0))
240+
assert(bins(1)(2).highSplit.categories.contains(2.0))
241+
242+
assert(bins(1)(3).category == Double.MaxValue)
108243

109-
//TODO: Test max feature value > num bins
110244

245+
}
111246

112247
test("classification stump with all categorical variables"){
113248
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
@@ -117,22 +252,41 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
117252
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
118253
strategy.numBins = 100
119254
val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins)
120-
println(bestSplits(0)._1)
121-
println(bestSplits(0)._2)
122-
//TODO: Add asserts
255+
256+
val split = bestSplits(0)._1
257+
assert(split.categories.length == 1)
258+
assert(split.categories.contains(1.0))
259+
assert(split.featureType == Categorical)
260+
assert(split.threshold == Double.MinValue)
261+
262+
val stats = bestSplits(0)._2
263+
assert(stats.gain > 0)
264+
assert(stats.predict > 0.4)
265+
assert(stats.predict < 0.5)
266+
assert(stats.impurity > 0.2)
267+
123268
}
124269

125270
test("regression stump with all categorical variables"){
126271
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
127272
assert(arr.length == 1000)
128273
val rdd = sc.parallelize(arr)
129-
val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
274+
val strategy = new Strategy(Regression,Variance,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
130275
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
131276
strategy.numBins = 100
132277
val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins)
133-
println(bestSplits(0)._1)
134-
println(bestSplits(0)._2)
135-
//TODO: Add asserts
278+
279+
val split = bestSplits(0)._1
280+
assert(split.categories.length == 1)
281+
assert(split.categories.contains(1.0))
282+
assert(split.featureType == Categorical)
283+
assert(split.threshold == Double.MinValue)
284+
285+
val stats = bestSplits(0)._2
286+
assert(stats.gain > 0)
287+
assert(stats.predict > 0.4)
288+
assert(stats.predict < 0.5)
289+
assert(stats.impurity > 0.2)
136290
}
137291

138292

@@ -157,7 +311,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
157311
assert(0==bestSplits(0)._2.gain)
158312
assert(0==bestSplits(0)._2.leftImpurity)
159313
assert(0==bestSplits(0)._2.rightImpurity)
160-
println(bestSplits(0)._2.predict)
314+
161315
}
162316

163317
test("stump with fixed label 1 for Gini"){

0 commit comments

Comments
 (0)