Skip to content

Commit 799c612

Browse files
author
Eugen
committed
Merge pull request eugenp#200 from Doha2012/master
minor changes to reddit classifier
2 parents 424c770 + 3ffe842 commit 799c612

File tree

3 files changed

+10
-19
lines changed

3 files changed

+10
-19
lines changed

spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
public class RedditClassifier {
2525
public static int GOOD = 0;
2626
public static int BAD = 1;
27-
public static int MIN_SCORE = 7;
2827

2928
private final int[] trainCount = { 0, 0 };
3029
private final int[] evalCount = { 0, 0 };
@@ -34,15 +33,16 @@ public class RedditClassifier {
3433
private final FeatureVectorEncoder titleEncoder;
3534
private final FeatureVectorEncoder domainEncoder;
3635
private final int noOfFeatures;
37-
36+
private final int minScore;
3837
private CrossFoldLearner learner;
3938
private double accuracy;
4039

4140
public RedditClassifier() {
42-
this(150, 1000);
41+
this(150, 1000, 7);
4342
}
4443

45-
public RedditClassifier(final int poolSize, final int noOfFeatures) {
44+
public RedditClassifier(final int poolSize, final int noOfFeatures, int minScore) {
45+
this.minScore = minScore;
4646
this.noOfFeatures = noOfFeatures;
4747
classifier = new AdaptiveLogisticRegression(2, noOfFeatures, new L2());
4848
classifier.setPoolSize(poolSize);
@@ -154,20 +154,15 @@ private NamedVector extractVector(final String line) {
154154
internalVector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title
155155

156156
domainEncoder.addToVector(theRootDomain, internalVector);
157-
final String[] words = title.split(" ");
158-
// titleEncoder.setProbes(words.length);
159-
160-
// TODO: use a Java 8 stream with filter and remove the 1 and 2 character words; example: "a", "of", "to"
161-
for (final String word : words) {
162-
titleEncoder.addToVector(word, internalVector);
163-
}
157+
final List<String> words = Splitter.on(' ').splitToList(title);
158+
words.stream().filter(word -> word.length() > 2).forEach(word -> titleEncoder.addToVector(word, internalVector));
164159

165160
final String category = extractCategory(Integer.parseInt(numberOfVotes));
166161
return new NamedVector(internalVector, category);
167162
}
168163

169164
private String extractCategory(final int score) {
170-
return (score < MIN_SCORE) ? "BAD" : "GOOD";
165+
return (score < minScore) ? "BAD" : "GOOD";
171166
}
172167

173168
}

spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditDataCollector.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,7 @@ public class RedditDataCollector {
2727
private final String subreddit;
2828

2929
public RedditDataCollector() {
30-
restTemplate = new RestTemplate();
31-
final List<ClientHttpRequestInterceptor> list = new ArrayList<ClientHttpRequestInterceptor>();
32-
list.add(new UserAgentInterceptor());
33-
restTemplate.setInterceptors(list);
34-
subreddit = "java";
30+
this("java");
3531
}
3632

3733
public RedditDataCollector(String subreddit) {

spring-security-oauth/src/test/java/org/baeldung/classifier/RedditClassifierTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public void whenUsingDefaultClassifier_thenAccurate() throws IOException {
2323

2424
@Test
2525
public void givenSmallerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException {
26-
final RedditClassifier classifier = new RedditClassifier(100, 500);
26+
final RedditClassifier classifier = new RedditClassifier(100, 500, 7);
2727
classifier.trainClassifier(RedditDataCollector.DATA_FILE);
2828
final double result = classifier.getAccuracy();
2929
System.out.println("==== Custom Classifier (small) Accuracy = " + result);
@@ -33,7 +33,7 @@ public void givenSmallerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccura
3333

3434
@Test
3535
public void givenLargerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException {
36-
final RedditClassifier classifier = new RedditClassifier(250, 2500);
36+
final RedditClassifier classifier = new RedditClassifier(250, 2500, 7);
3737
classifier.trainClassifier(RedditDataCollector.DATA_FILE);
3838
final double result = classifier.getAccuracy();
3939
System.out.println("==== Custom Classifier (large) Accuracy = " + result);

0 commit comments

Comments
 (0)