17
17
18
18
package org .apache .spark .mllib .tree
19
19
20
- import org .apache .spark .SparkContext ._
21
20
import scala .util .control .Breaks ._
21
+ import org .apache .spark .SparkContext ._
22
22
import org .apache .spark .rdd .RDD
23
23
import org .apache .spark .mllib .tree .model ._
24
24
import org .apache .spark .{SparkContext , Logging }
@@ -101,7 +101,6 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log
101
101
val decisionTreeModel = {
102
102
return new DecisionTreeModel (topNode, strategy.algo)
103
103
}
104
-
105
104
return decisionTreeModel
106
105
}
107
106
@@ -538,10 +537,10 @@ object DecisionTree extends Serializable with Logging {
538
537
}
539
538
540
539
if (leftCount == 0 ) {
541
- return new InformationGainStats (0 ,topImpurity,Double .MinValue ,topImpurity,1 )
540
+ return new InformationGainStats (0 , topImpurity, Double .MinValue , topImpurity,1 )
542
541
}
543
542
if (rightCount == 0 ) {
544
- return new InformationGainStats (0 ,topImpurity,topImpurity,Double .MinValue ,0 )
543
+ return new InformationGainStats (0 , topImpurity, topImpurity, Double .MinValue ,0 )
545
544
}
546
545
547
546
val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
@@ -561,7 +560,7 @@ object DecisionTree extends Serializable with Logging {
561
560
// val predict = leftCount / (leftCount + rightCount)
562
561
val predict = (left1Count + right1Count) / (leftCount + rightCount)
563
562
564
- new InformationGainStats (gain,impurity,leftImpurity,rightImpurity,predict)
563
+ new InformationGainStats (gain, impurity, leftImpurity, rightImpurity, predict)
565
564
}
566
565
case Regression => {
567
566
val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex)
@@ -584,12 +583,12 @@ object DecisionTree extends Serializable with Logging {
584
583
}
585
584
586
585
if (leftCount == 0 ) {
587
- return new InformationGainStats (0 ,topImpurity,Double .MinValue ,topImpurity,
586
+ return new InformationGainStats (0 , topImpurity, Double .MinValue , topImpurity,
588
587
rightSum/ rightCount)
589
588
}
590
589
if (rightCount == 0 ) {
591
- return new InformationGainStats (0 ,topImpurity,topImpurity,
592
- Double .MinValue ,leftSum/ leftCount)
590
+ return new InformationGainStats (0 , topImpurity ,topImpurity,
591
+ Double .MinValue , leftSum/ leftCount)
593
592
}
594
593
595
594
val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares)
@@ -1024,7 +1023,4 @@ object DecisionTree extends Serializable with Logging {
1024
1023
.mean()
1025
1024
meanSumOfSquares
1026
1025
}
1027
-
1028
-
1029
-
1030
1026
}
0 commit comments