Skip to content

Commit 424c770

Browse files
author
eugenp
committed
cleanup work
1 parent cddb1ed commit 424c770

File tree

2 files changed

+19
-25
lines changed

2 files changed

+19
-25
lines changed

spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java

Lines changed: 18 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -22,32 +22,24 @@
2222
import com.google.common.io.Files;
2323

2424
public class RedditClassifier {
25-
2625
public static int GOOD = 0;
2726
public static int BAD = 1;
28-
public static int MIN_SCORE = 10;
27+
public static int MIN_SCORE = 7;
28+
29+
private final int[] trainCount = { 0, 0 };
30+
private final int[] evalCount = { 0, 0 };
31+
private final int[] correctCount = { 0, 0 };
2932

3033
private final AdaptiveLogisticRegression classifier;
3134
private final FeatureVectorEncoder titleEncoder;
3235
private final FeatureVectorEncoder domainEncoder;
33-
private CrossFoldLearner learner;
3436
private final int noOfFeatures;
35-
private double accuracy;
36-
37-
private final int[] trainCount = { 0, 0 };
3837

39-
private final int[] evalCount = { 0, 0 };
40-
41-
private final int[] correctCount = { 0, 0 };
38+
private CrossFoldLearner learner;
39+
private double accuracy;
4240

4341
public RedditClassifier() {
44-
noOfFeatures = 1000;
45-
classifier = new AdaptiveLogisticRegression(2, 1000, new L2());
46-
classifier.setPoolSize(150);
47-
titleEncoder = new AdaptiveWordValueEncoder("title");
48-
titleEncoder.setProbes(2);
49-
domainEncoder = new StaticWordValueEncoder("domain");
50-
domainEncoder.setProbes(1);
42+
this(150, 1000);
5143
}
5244

5345
public RedditClassifier(final int poolSize, final int noOfFeatures) {
@@ -60,6 +52,8 @@ public RedditClassifier(final int poolSize, final int noOfFeatures) {
6052
domainEncoder.setProbes(1);
6153
}
6254

55+
// API
56+
6357
public void trainClassifier(final String fileName) throws IOException {
6458
final List<NamedVector> vectors = extractVectors(readDataFile(fileName));
6559
final int size = vectors.size();
@@ -151,25 +145,25 @@ private NamedVector extractVector(final String line) {
151145
final String title = items[3];
152146
final String theRootDomain = items[4];
153147

154-
final String category = extractCategory(Integer.parseInt(numberOfVotes));
155-
156-
final NamedVector vector = new NamedVector(new RandomAccessSparseVector(noOfFeatures), category);
148+
final RandomAccessSparseVector internalVector = new RandomAccessSparseVector(noOfFeatures);
157149

158150
final Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT"));
159151
cal.setTimeInMillis(Long.parseLong(time) * 1000);
160-
vector.set(0, cal.get(Calendar.HOUR_OF_DAY)); // hour of day
152+
internalVector.set(0, cal.get(Calendar.HOUR_OF_DAY)); // hour of day
161153

162-
vector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title
154+
internalVector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title
163155

164-
domainEncoder.addToVector(theRootDomain, vector);
156+
domainEncoder.addToVector(theRootDomain, internalVector);
165157
final String[] words = title.split(" ");
166158
// titleEncoder.setProbes(words.length);
167159

168160
// TODO: use a Java 8 stream with filter and remove the 1 and 2 character words; example: "a", "of", "to"
169161
for (final String word : words) {
170-
titleEncoder.addToVector(word, vector);
162+
titleEncoder.addToVector(word, internalVector);
171163
}
172-
return vector;
164+
165+
final String category = extractCategory(Integer.parseInt(numberOfVotes));
166+
return new NamedVector(internalVector, category);
173167
}
174168

175169
private String extractCategory(final int score) {

spring-security-oauth/src/test/java/org/baeldung/classifier/RedditClassifierTest.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ public void givenSmallerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccura
3333

3434
@Test
3535
public void givenLargerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException {
36-
final RedditClassifier classifier = new RedditClassifier(200, 2000);
36+
final RedditClassifier classifier = new RedditClassifier(250, 2500);
3737
classifier.trainClassifier(RedditDataCollector.DATA_FILE);
3838
final double result = classifier.getAccuracy();
3939
System.out.println("==== Custom Classifier (large) Accuracy = " + result);

0 commit comments

Comments
 (0)