@@ -26,12 +26,12 @@ public class RedditClassifier {
2626 public static int GOOD = 0 ;
2727 public static int BAD = 1 ;
2828 public static int MIN_SCORE = 10 ;
29- public static int NUM_OF_FEATURES = 1000 ;
3029
3130 private final AdaptiveLogisticRegression classifier ;
3231 private final FeatureVectorEncoder titleEncoder ;
3332 private final FeatureVectorEncoder domainEncoder ;
3433 private CrossFoldLearner learner ;
34+ private final int noOfFeatures ;
3535 private double accuracy ;
3636
3737 private final int [] trainCount = { 0 , 0 };
@@ -41,31 +41,44 @@ public class RedditClassifier {
4141 private final int [] correctCount = { 0 , 0 };
4242
4343 public RedditClassifier () {
44- classifier = new AdaptiveLogisticRegression (2 , NUM_OF_FEATURES , new L2 ());
44+ noOfFeatures = 1000 ;
45+ classifier = new AdaptiveLogisticRegression (2 , 1000 , new L2 ());
4546 classifier .setPoolSize (150 );
4647 titleEncoder = new AdaptiveWordValueEncoder ("title" );
4748 titleEncoder .setProbes (2 );
4849 domainEncoder = new StaticWordValueEncoder ("domain" );
4950 domainEncoder .setProbes (1 );
5051 }
5152
53+ public RedditClassifier (int poolSize , int noOfFeatures ) {
54+ this .noOfFeatures = noOfFeatures ;
55+ classifier = new AdaptiveLogisticRegression (2 , noOfFeatures , new L2 ());
56+ classifier .setPoolSize (poolSize );
57+ titleEncoder = new AdaptiveWordValueEncoder ("title" );
58+ titleEncoder .setProbes (1 );
59+ domainEncoder = new StaticWordValueEncoder ("domain" );
60+ domainEncoder .setProbes (1 );
61+ }
62+
5263 public void trainClassifier (String fileName ) throws IOException {
5364 final List <NamedVector > vectors = extractVectors (readDataFile (fileName ));
54- final int noOfTraining = (int ) (RedditDataCollector .DATA_SIZE * 0.8 );
65+ final int size = vectors .size ();
66+ final int noOfTraining = (int ) (size * 0.8 );
5567 final List <NamedVector > trainingData = vectors .subList (0 , noOfTraining );
56- final List <NamedVector > testData = vectors .subList (noOfTraining , RedditDataCollector . DATA_SIZE );
68+ final List <NamedVector > testData = vectors .subList (noOfTraining , size );
5769 int category ;
5870 for (final NamedVector vector : trainingData ) {
5971 category = (vector .getName () == "GOOD" ) ? GOOD : BAD ;
6072 classifier .train (category , vector );
6173 trainCount [category ]++;
6274 }
6375 System .out .println ("Training count ========= Good = " + trainCount [0 ] + " ___ Bad = " + trainCount [1 ]);
76+ System .out .println ("----------------------------------------------------------------- \n " );
6477 evaluateClassifier (testData );
6578 }
6679
6780 public Vector convertPost (String title , String domain , int hour ) {
68- final Vector vector = new RandomAccessSparseVector (NUM_OF_FEATURES );
81+ final Vector vector = new RandomAccessSparseVector (noOfFeatures );
6982 final List <String > words = Splitter .onPattern ("\\ W" ).omitEmptyStrings ().splitToList (title );
7083 vector .set (0 , hour );
7184 vector .set (1 , words .size ());
@@ -105,10 +118,10 @@ private void evaluateClassifier(List<NamedVector> vectors) throws IOException {
105118 wrong ++;
106119 }
107120 }
108- System .out .println ("Eval count ========= Good = " + evalCount [0 ] + " ___ Bad = " + evalCount [1 ]);
109- System .out .println ("Test result ======== Correct prediction = " + correct + " ----- Wrong prediction = " + wrong );
110- System .out .println ("Test result ======== Correct Good = " + correctCount [0 ] + " ----- Correct Bad = " + correctCount [1 ]);
111- System .out .println ("Test result ======== Good accuracy = " + (correctCount [0 ] / (evalCount [0 ] + 0.0 )) + " ----- Bad accuracy = " + (correctCount [1 ] / (evalCount [1 ] + 0.0 )));
121+ System .out .println ("Eval count =================== Good = " + evalCount [0 ] + " ----- Bad = " + evalCount [1 ] + " \n " );
122+ System .out .println ("Overall Evaluation ===== ======== Correct prediction = " + correct + " ----- Wrong prediction = " + wrong );
123+ System .out .println ("Correctly Evaluated === ======== Correct Good = " + correctCount [0 ] + " ----- Correct Bad = " + correctCount [1 ]);
124+ System .out .println ("Correctly Evaluated (%) ======== Good accuracy = " + (correctCount [0 ] / (evalCount [0 ] + 0.0 )) + " ----- Bad accuracy = " + (correctCount [1 ] / (evalCount [1 ] + 0.0 )));
112125 this .accuracy = correct / (wrong + correct + 0.0 );
113126 }
114127
@@ -133,7 +146,7 @@ private List<NamedVector> extractVectors(List<String> lines) {
133146 private NamedVector extractVector (String line ) {
134147 final String [] items = line .split ("," );
135148 final String category = extractCategory (Integer .parseInt (items [0 ]));
136- final NamedVector vector = new NamedVector (new RandomAccessSparseVector (NUM_OF_FEATURES ), category );
149+ final NamedVector vector = new NamedVector (new RandomAccessSparseVector (noOfFeatures ), category );
137150 final Calendar cal = Calendar .getInstance (TimeZone .getTimeZone ("GMT" ));
138151 cal .setTimeInMillis (Long .parseLong (items [1 ]) * 1000 );
139152
0 commit comments