Skip to content

Commit 42369da

Browse files
committed
update classifier
1 parent cd4927c commit 42369da

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public class RedditClassifier {
2525

2626
public static int GOOD = 0;
2727
public static int BAD = 1;
28-
public static int MIN_SCORE = 7;
28+
public static int MIN_SCORE = 10;
2929
public static int NUM_OF_FEATURES = 1000;
3030

3131
private final AdaptiveLogisticRegression classifier;
@@ -42,9 +42,9 @@ public class RedditClassifier {
4242

4343
public RedditClassifier() {
4444
classifier = new AdaptiveLogisticRegression(2, NUM_OF_FEATURES, new L2());
45-
classifier.setPoolSize(50);
45+
classifier.setPoolSize(150);
4646
titleEncoder = new AdaptiveWordValueEncoder("title");
47-
titleEncoder.setProbes(1);
47+
titleEncoder.setProbes(2);
4848
domainEncoder = new StaticWordValueEncoder("domain");
4949
domainEncoder.setProbes(1);
5050
}
@@ -65,13 +65,15 @@ public void trainClassifier(String fileName) throws IOException {
6565
}
6666

6767
public Vector convertPost(String title, String domain, int hour) {
68-
final Vector features = new RandomAccessSparseVector(4);
69-
final int noOfWords = Splitter.onPattern("\\W").omitEmptyStrings().splitToList(title).size();
70-
titleEncoder.addToVector(title, features);
71-
domainEncoder.addToVector(domain, features);
72-
features.set(2, hour);
73-
features.set(3, noOfWords);
74-
return features;
68+
final Vector vector = new RandomAccessSparseVector(NUM_OF_FEATURES);
69+
final List<String> words = Splitter.onPattern("\\W").omitEmptyStrings().splitToList(title);
70+
vector.set(0, hour);
71+
vector.set(1, words.size());
72+
domainEncoder.addToVector(domain, vector);
73+
for (final String word : words) {
74+
titleEncoder.addToVector(word, vector);
75+
}
76+
return vector;
7577
}
7678

7779
public int classify(Vector features) {
@@ -106,6 +108,7 @@ private void evaluateClassifier(List<NamedVector> vectors) throws IOException {
106108
System.out.println("Eval count ========= Good = " + evalCount[0] + " ___ Bad = " + evalCount[1]);
107109
System.out.println("Test result ======== Correct prediction = " + correct + " ----- Wrong prediction = " + wrong);
108110
System.out.println("Test result ======== Correct Good = " + correctCount[0] + " ----- Correct Bad = " + correctCount[1]);
111+
System.out.println("Test result ======== Good accuracy = " + (correctCount[0] / (evalCount[0] + 0.0)) + " ----- Bad accuracy = " + (correctCount[1] / (evalCount[1] + 0.0)));
109112
this.accuracy = correct / (wrong + correct + 0.0);
110113
}
111114

0 commit comments

Comments
 (0)