Skip to content

Commit be841b1

Browse files
author
Eugen
committed
Merge pull request eugenp#195 from Doha2012/master
fix data collector
2 parents 114bf97 + df215fc commit be841b1

File tree

10 files changed

+8159
-8133
lines changed

10 files changed

+8159
-8133
lines changed

spring-security-oauth/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@
185185
<artifactId>guava</artifactId>
186186
<version>${guava.version}</version>
187187
</dependency>
188-
188+
189189
<!-- logging -->
190190

191191
<dependency>

spring-security-oauth/src/main/java/org/baeldung/config/ServletInitializer.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,9 @@ protected WebApplicationContext createRootApplicationContext() {
3030
@Override
3131
public void onStartup(ServletContext servletContext) throws ServletException {
3232
super.onStartup(servletContext);
33-
3433
servletContext.addListener(new SessionListener());
3534
registerProxyFilter(servletContext, "oauth2ClientContextFilter");
3635
registerProxyFilter(servletContext, "springSecurityFilterChain");
37-
3836
}
3937

4038
private void registerProxyFilter(ServletContext servletContext, String name) {

spring-security-oauth/src/main/java/org/baeldung/config/SessionListener.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import javax.servlet.http.HttpSessionEvent;
44
import javax.servlet.http.HttpSessionListener;
55

6+
import org.baeldung.reddit.util.MyFeatures;
67
import org.slf4j.Logger;
78
import org.slf4j.LoggerFactory;
89

@@ -13,6 +14,7 @@ public class SessionListener implements HttpSessionListener {
1314
public void sessionCreated(HttpSessionEvent event) {
1415
logger.info("==== Session is created ====");
1516
event.getSession().setMaxInactiveInterval(30 * 60);
17+
event.getSession().setAttribute("PREDICTION_FEATURE", MyFeatures.PREDICTION_FEATURE);
1618
}
1719

1820
@Override

spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import java.io.BufferedReader;
44
import java.io.FileReader;
55
import java.io.IOException;
6+
import java.util.Calendar;
7+
import java.util.TimeZone;
68

79
import org.apache.mahout.classifier.sgd.L2;
810
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
@@ -18,6 +20,7 @@ public class RedditClassifier {
1820

1921
public static int GOOD = 0;
2022
public static int BAD = 1;
23+
public static int MIN_SCORE = 5;
2124
private final OnlineLogisticRegression classifier;
2225
private final FeatureVectorEncoder titleEncoder;
2326
private final FeatureVectorEncoder domainEncoder;
@@ -44,7 +47,7 @@ public void trainClassifier(String fileName) throws IOException {
4447
}
4548

4649
while ((line != null) && (line != "")) {
47-
category = (line.startsWith("good")) ? GOOD : BAD;
50+
category = extractCategory(line);
4851
trainCount[category]++;
4952
features = convertLineToVector(line);
5053
classifier.train(category, features);
@@ -76,7 +79,7 @@ public double evaluateClassifier() throws IOException {
7679
Vector features;
7780
String line = reader.readLine();
7881
while ((line != null) && (line != "")) {
79-
category = (line.startsWith("good")) ? GOOD : BAD;
82+
category = extractCategory(line);
8083
evalCount[category]++;
8184
features = convertLineToVector(line);
8285
result = classify(features);
@@ -94,12 +97,21 @@ public double evaluateClassifier() throws IOException {
9497
}
9598

9699
// ==== private
100+
private int extractCategory(String line) {
101+
final int score = Integer.parseInt(line.substring(0, line.indexOf(';')));
102+
return (score < MIN_SCORE) ? BAD : GOOD;
103+
}
104+
97105
private Vector convertLineToVector(String line) {
98106
final Vector features = new RandomAccessSparseVector(4);
99107
final String[] items = line.split(";");
108+
final Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT"));
109+
cal.setTimeInMillis(Long.parseLong(items[1]) * 1000);
110+
final int hour = cal.get(Calendar.HOUR_OF_DAY);
111+
100112
titleEncoder.addToVector(items[3], features);
101113
domainEncoder.addToVector(items[4], features);
102-
features.set(2, Integer.parseInt(items[1])); // hour of day
114+
features.set(2, hour); // hour of day
103115
features.set(3, Integer.parseInt(items[2])); // number of words in the title
104116
return features;
105117
}

spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditDataCollector.java

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22

33
import java.io.FileWriter;
44
import java.io.IOException;
5-
import java.text.SimpleDateFormat;
65
import java.util.ArrayList;
7-
import java.util.Date;
86
import java.util.List;
97

108
import org.baeldung.reddit.util.UserAgentInterceptor;
@@ -20,57 +18,52 @@
2018
public class RedditDataCollector {
2119
public static final String TRAINING_FILE = "src/main/resources/train.csv";
2220
public static final String TEST_FILE = "src/main/resources/test.csv";
21+
public static final int LIMIT = 100;
22+
public static final Long YEAR = 31536000L;
2323
private final Logger logger = LoggerFactory.getLogger(getClass());
2424

25-
private String postAfter;
25+
private Long timestamp;
2626
private final RestTemplate restTemplate;
2727
private final String subreddit;
28-
private final int minScore;
2928

3029
public RedditDataCollector() {
3130
restTemplate = new RestTemplate();
3231
final List<ClientHttpRequestInterceptor> list = new ArrayList<ClientHttpRequestInterceptor>();
3332
list.add(new UserAgentInterceptor());
3433
restTemplate.setInterceptors(list);
35-
subreddit = "all";
36-
minScore = 4;
34+
subreddit = "java";
3735
}
3836

39-
public RedditDataCollector(String subreddit, int minScore) {
37+
public RedditDataCollector(String subreddit) {
4038
restTemplate = new RestTemplate();
4139
final List<ClientHttpRequestInterceptor> list = new ArrayList<ClientHttpRequestInterceptor>();
4240
list.add(new UserAgentInterceptor());
4341
restTemplate.setInterceptors(list);
4442
this.subreddit = subreddit;
45-
this.minScore = minScore;
4643
}
4744

4845
public void collectData() {
49-
final int limit = 100;
5046
final int noOfRounds = 80;
47+
timestamp = System.currentTimeMillis() / 1000;
5148
try {
5249
final FileWriter writer = new FileWriter(TRAINING_FILE);
5350
for (int i = 0; i < noOfRounds; i++) {
54-
getPosts(limit, writer);
51+
getPosts(writer);
5552
}
5653
writer.close();
5754

5855
final FileWriter testWriter = new FileWriter(TEST_FILE);
59-
getPosts(limit, testWriter);
56+
getPosts(testWriter);
6057
testWriter.close();
6158
} catch (final Exception e) {
6259
logger.error("write to file error", e);
6360
}
6461
}
6562

66-
// ==== private
67-
68-
private void getPosts(int limit, FileWriter writer) {
69-
String fullUrl = "http://www.reddit.com/r/" + subreddit + "/new.json?limit=" + limit;
70-
if (postAfter != null) {
71-
fullUrl += "&count=" + limit + "&after=" + postAfter;
72-
}
63+
// ==== Private
7364

65+
private void getPosts(FileWriter writer) {
66+
final String fullUrl = "http://www.reddit.com/r/" + subreddit + "/search.json?sort=new&q=timestamp:" + (timestamp - YEAR) + ".." + timestamp + "&restrict_sr=on&syntax=cloudsearch&limit=" + LIMIT;
7467
try {
7568
final JsonNode node = restTemplate.getForObject(fullUrl, JsonNode.class);
7669
parseNode(node, writer);
@@ -82,22 +75,18 @@ private void getPosts(int limit, FileWriter writer) {
8275
}
8376

8477
private void parseNode(JsonNode node, FileWriter writer) throws IOException {
85-
postAfter = node.get("data").get("after").asText();
86-
System.out.println(postAfter);
8778
String line;
88-
String category;
8979
List<String> words;
90-
final SimpleDateFormat df = new SimpleDateFormat("HH");
80+
int score;
9181
for (final JsonNode child : node.get("data").get("children")) {
92-
category = (child.get("data").get("score").asInt() < minScore) ? "bad" : "good";
82+
score = child.get("data").get("score").asInt();
9383
words = Splitter.onPattern("\\W").omitEmptyStrings().splitToList(child.get("data").get("title").asText());
94-
final Date date = new Date(child.get("data").get("created_utc").asLong() * 1000);
84+
timestamp = child.get("data").get("created_utc").asLong();
9585

96-
line = category + ";";
97-
line += df.format(date) + ";";
86+
line = score + ";";
87+
line += timestamp + ";";
9888
line += words.size() + ";" + Joiner.on(' ').join(words) + ";";
9989
line += child.get("data").get("domain").asText() + "\n";
100-
10190
writer.write(line);
10291
}
10392
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package org.baeldung.reddit.util;
2+
3+
public enum MyFeatures {
4+
5+
PREDICTION_FEATURE(false);
6+
7+
private boolean active;
8+
9+
private MyFeatures(boolean active) {
10+
this.active = active;
11+
}
12+
13+
public boolean isActive() {
14+
return this.active;
15+
}
16+
17+
}

0 commit comments

Comments
 (0)