2222import com .google .common .io .Files ;
2323
2424public class RedditClassifier {
25-
2625 public static int GOOD = 0 ;
2726 public static int BAD = 1 ;
28- public static int MIN_SCORE = 10 ;
27+ public static int MIN_SCORE = 7 ;
28+
29+ private final int [] trainCount = { 0 , 0 };
30+ private final int [] evalCount = { 0 , 0 };
31+ private final int [] correctCount = { 0 , 0 };
2932
3033 private final AdaptiveLogisticRegression classifier ;
3134 private final FeatureVectorEncoder titleEncoder ;
3235 private final FeatureVectorEncoder domainEncoder ;
33- private CrossFoldLearner learner ;
3436 private final int noOfFeatures ;
35- private double accuracy ;
36-
37- private final int [] trainCount = { 0 , 0 };
3837
39- private final int [] evalCount = { 0 , 0 };
40-
41- private final int [] correctCount = { 0 , 0 };
38+ private CrossFoldLearner learner ;
39+ private double accuracy ;
4240
4341 public RedditClassifier () {
44- noOfFeatures = 1000 ;
45- classifier = new AdaptiveLogisticRegression (2 , 1000 , new L2 ());
46- classifier .setPoolSize (150 );
47- titleEncoder = new AdaptiveWordValueEncoder ("title" );
48- titleEncoder .setProbes (2 );
49- domainEncoder = new StaticWordValueEncoder ("domain" );
50- domainEncoder .setProbes (1 );
42+ this (150 , 1000 );
5143 }
5244
5345 public RedditClassifier (final int poolSize , final int noOfFeatures ) {
@@ -60,6 +52,8 @@ public RedditClassifier(final int poolSize, final int noOfFeatures) {
6052 domainEncoder .setProbes (1 );
6153 }
6254
55+ // API
56+
6357 public void trainClassifier (final String fileName ) throws IOException {
6458 final List <NamedVector > vectors = extractVectors (readDataFile (fileName ));
6559 final int size = vectors .size ();
@@ -151,25 +145,25 @@ private NamedVector extractVector(final String line) {
151145 final String title = items [3 ];
152146 final String theRootDomain = items [4 ];
153147
154- final String category = extractCategory (Integer .parseInt (numberOfVotes ));
155-
156- final NamedVector vector = new NamedVector (new RandomAccessSparseVector (noOfFeatures ), category );
148+ final RandomAccessSparseVector internalVector = new RandomAccessSparseVector (noOfFeatures );
157149
158150 final Calendar cal = Calendar .getInstance (TimeZone .getTimeZone ("GMT" ));
159151 cal .setTimeInMillis (Long .parseLong (time ) * 1000 );
160- vector .set (0 , cal .get (Calendar .HOUR_OF_DAY )); // hour of day
152+ internalVector .set (0 , cal .get (Calendar .HOUR_OF_DAY )); // hour of day
161153
162- vector .set (1 , Integer .parseInt (numberOfWordInTitle )); // number of words in the title
154+ internalVector .set (1 , Integer .parseInt (numberOfWordInTitle )); // number of words in the title
163155
164- domainEncoder .addToVector (theRootDomain , vector );
156+ domainEncoder .addToVector (theRootDomain , internalVector );
165157 final String [] words = title .split (" " );
166158 // titleEncoder.setProbes(words.length);
167159
168160 // TODO: use a Java 8 stream with filter and remove the 1 and 2 character words; example: "a", "of", "to"
169161 for (final String word : words ) {
170- titleEncoder .addToVector (word , vector );
162+ titleEncoder .addToVector (word , internalVector );
171163 }
172- return vector ;
164+
165+ final String category = extractCategory (Integer .parseInt (numberOfVotes ));
166+ return new NamedVector (internalVector , category );
173167 }
174168
175169 private String extractCategory (final int score ) {
0 commit comments