Skip to content

Commit dc1bcbd

Browse files
committed
modify reddit classifier test
1 parent: 91ea275 · commit: dc1bcbd

File tree

7 files changed

+12419
-8126
lines changed

7 files changed

+12419
-8126
lines changed

spring-security-oauth/src/main/java/org/baeldung/config/WebConfig.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,7 @@ public ScheduledTasks scheduledTasks(OAuth2ProtectedResourceDetails reddit) {
8686

8787
@Bean
8888
public RedditClassifier redditClassifier() throws IOException {
89-
final Resource file = new ClassPathResource("train.csv");
89+
final Resource file = new ClassPathResource("data.csv");
9090
final RedditClassifier redditClassifier = new RedditClassifier();
9191
redditClassifier.trainClassifier(file.getFile().getAbsolutePath());
9292
return redditClassifier;

spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java

Lines changed: 23 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -26,12 +26,12 @@ public class RedditClassifier {
2626
public static int GOOD = 0;
2727
public static int BAD = 1;
2828
public static int MIN_SCORE = 10;
29-
public static int NUM_OF_FEATURES = 1000;
3029

3130
private final AdaptiveLogisticRegression classifier;
3231
private final FeatureVectorEncoder titleEncoder;
3332
private final FeatureVectorEncoder domainEncoder;
3433
private CrossFoldLearner learner;
34+
private final int noOfFeatures;
3535
private double accuracy;
3636

3737
private final int[] trainCount = { 0, 0 };
@@ -41,31 +41,44 @@ public class RedditClassifier {
4141
private final int[] correctCount = { 0, 0 };
4242

4343
public RedditClassifier() {
44-
classifier = new AdaptiveLogisticRegression(2, NUM_OF_FEATURES, new L2());
44+
noOfFeatures = 1000;
45+
classifier = new AdaptiveLogisticRegression(2, 1000, new L2());
4546
classifier.setPoolSize(150);
4647
titleEncoder = new AdaptiveWordValueEncoder("title");
4748
titleEncoder.setProbes(2);
4849
domainEncoder = new StaticWordValueEncoder("domain");
4950
domainEncoder.setProbes(1);
5051
}
5152

53+
public RedditClassifier(int poolSize, int noOfFeatures) {
54+
this.noOfFeatures = noOfFeatures;
55+
classifier = new AdaptiveLogisticRegression(2, noOfFeatures, new L2());
56+
classifier.setPoolSize(poolSize);
57+
titleEncoder = new AdaptiveWordValueEncoder("title");
58+
titleEncoder.setProbes(1);
59+
domainEncoder = new StaticWordValueEncoder("domain");
60+
domainEncoder.setProbes(1);
61+
}
62+
5263
public void trainClassifier(String fileName) throws IOException {
5364
final List<NamedVector> vectors = extractVectors(readDataFile(fileName));
54-
final int noOfTraining = (int) (RedditDataCollector.DATA_SIZE * 0.8);
65+
final int size = vectors.size();
66+
final int noOfTraining = (int) (size * 0.8);
5567
final List<NamedVector> trainingData = vectors.subList(0, noOfTraining);
56-
final List<NamedVector> testData = vectors.subList(noOfTraining, RedditDataCollector.DATA_SIZE);
68+
final List<NamedVector> testData = vectors.subList(noOfTraining, size);
5769
int category;
5870
for (final NamedVector vector : trainingData) {
5971
category = (vector.getName() == "GOOD") ? GOOD : BAD;
6072
classifier.train(category, vector);
6173
trainCount[category]++;
6274
}
6375
System.out.println("Training count ========= Good = " + trainCount[0] + " ___ Bad = " + trainCount[1]);
76+
System.out.println("----------------------------------------------------------------- \n");
6477
evaluateClassifier(testData);
6578
}
6679

6780
public Vector convertPost(String title, String domain, int hour) {
68-
final Vector vector = new RandomAccessSparseVector(NUM_OF_FEATURES);
81+
final Vector vector = new RandomAccessSparseVector(noOfFeatures);
6982
final List<String> words = Splitter.onPattern("\\W").omitEmptyStrings().splitToList(title);
7083
vector.set(0, hour);
7184
vector.set(1, words.size());
@@ -105,10 +118,10 @@ private void evaluateClassifier(List<NamedVector> vectors) throws IOException {
105118
wrong++;
106119
}
107120
}
108-
System.out.println("Eval count ========= Good = " + evalCount[0] + " ___ Bad = " + evalCount[1]);
109-
System.out.println("Test result ======== Correct prediction = " + correct + " ----- Wrong prediction = " + wrong);
110-
System.out.println("Test result ======== Correct Good = " + correctCount[0] + " ----- Correct Bad = " + correctCount[1]);
111-
System.out.println("Test result ======== Good accuracy = " + (correctCount[0] / (evalCount[0] + 0.0)) + " ----- Bad accuracy = " + (correctCount[1] / (evalCount[1] + 0.0)));
121+
System.out.println("Eval count =================== Good = " + evalCount[0] + " ----- Bad = " + evalCount[1] + "\n");
122+
System.out.println("Overall Evaluation ============= Correct prediction = " + correct + " ----- Wrong prediction = " + wrong);
123+
System.out.println("Correctly Evaluated =========== Correct Good = " + correctCount[0] + " ----- Correct Bad = " + correctCount[1]);
124+
System.out.println("Correctly Evaluated (%) ======== Good accuracy = " + (correctCount[0] / (evalCount[0] + 0.0)) + " ----- Bad accuracy = " + (correctCount[1] / (evalCount[1] + 0.0)));
112125
this.accuracy = correct / (wrong + correct + 0.0);
113126
}
114127

@@ -133,7 +146,7 @@ private List<NamedVector> extractVectors(List<String> lines) {
133146
private NamedVector extractVector(String line) {
134147
final String[] items = line.split(",");
135148
final String category = extractCategory(Integer.parseInt(items[0]));
136-
final NamedVector vector = new NamedVector(new RandomAccessSparseVector(NUM_OF_FEATURES), category);
149+
final NamedVector vector = new NamedVector(new RandomAccessSparseVector(noOfFeatures), category);
137150
final Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT"));
138151
cal.setTimeInMillis(Long.parseLong(items[1]) * 1000);
139152

spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditDataCollector.java

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,8 @@
1616
import com.google.common.base.Splitter;
1717

1818
public class RedditDataCollector {
19-
public static final String TRAINING_FILE = "src/main/resources/train.csv";
20-
public static final int DATA_SIZE = 8000;
19+
public static final String DATA_FILE = "src/main/resources/data.csv";
20+
public static final int DATA_SIZE = 20000;
2121
public static final int LIMIT = 100;
2222
public static final Long YEAR = 31536000L;
2323
private final Logger logger = LoggerFactory.getLogger(getClass());
@@ -45,10 +45,11 @@ public RedditDataCollector(String subreddit) {
4545
public void collectData() throws IOException {
4646
final int noOfRounds = DATA_SIZE / LIMIT;
4747
timestamp = System.currentTimeMillis() / 1000;
48-
final FileWriter writer = new FileWriter(TRAINING_FILE);
48+
final FileWriter writer = new FileWriter(DATA_FILE);
4949
writer.write("Score, Timestamp in utc, Number of wrods in title, Title, Domain \n");
5050
for (int i = 0; i < noOfRounds; i++) {
5151
getPosts(writer);
52+
System.out.println(i);
5253
}
5354
writer.close();
5455
}

0 commit comments

Comments
 (0)