Skip to content

Commit fbb574a

Browse files
author
Eugen
committed
Merge pull request eugenp#194 from Doha2012/master
add reddit classifier
2 parents 4448dd6 + eecee8b commit fbb574a

File tree

10 files changed

+8442
-3
lines changed

10 files changed

+8442
-3
lines changed

spring-security-oauth/pom.xml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,14 @@
156156
<scope>runtime</scope>
157157
</dependency>
158158

159-
159+
<!-- apache mahout -->
160+
161+
<dependency>
162+
<groupId>org.apache.mahout</groupId>
163+
<artifactId>mahout-core</artifactId>
164+
<version>0.9</version>
165+
</dependency>
166+
160167
<!-- marshalling -->
161168

162169
<dependency>

spring-security-oauth/src/main/java/org/baeldung/config/WebConfig.java

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
11
package org.baeldung.config;
22

3+
import java.io.IOException;
4+
import java.util.ArrayList;
35
import java.util.Arrays;
6+
import java.util.List;
47

8+
import org.baeldung.reddit.classifier.RedditClassifier;
9+
import org.baeldung.reddit.util.UserAgentInterceptor;
510
import org.baeldung.web.schedule.ScheduledTasks;
11+
import org.springframework.beans.factory.annotation.Autowired;
612
import org.springframework.beans.factory.annotation.Value;
713
import org.springframework.context.annotation.Bean;
814
import org.springframework.context.annotation.ComponentScan;
915
import org.springframework.context.annotation.Configuration;
1016
import org.springframework.context.annotation.PropertySource;
1117
import org.springframework.context.support.PropertySourcesPlaceholderConfigurer;
18+
import org.springframework.core.env.Environment;
19+
import org.springframework.core.io.ClassPathResource;
20+
import org.springframework.core.io.Resource;
21+
import org.springframework.http.client.ClientHttpRequestInterceptor;
1222
import org.springframework.scheduling.annotation.EnableAsync;
1323
import org.springframework.scheduling.annotation.EnableScheduling;
1424
import org.springframework.security.oauth2.client.OAuth2ClientContext;
@@ -36,6 +46,9 @@
3646
@ComponentScan({ "org.baeldung.web" })
3747
public class WebConfig extends WebMvcConfigurerAdapter {
3848

49+
@Autowired
50+
private Environment env;
51+
3952
@Bean
4053
public static PropertySourcesPlaceholderConfigurer propertySourcesPlaceholderConfigurer() {
4154
return new PropertySourcesPlaceholderConfigurer();
@@ -63,10 +76,22 @@ public void addViewControllers(final ViewControllerRegistry registry) {
6376
@Bean
6477
public ScheduledTasks scheduledTasks(OAuth2ProtectedResourceDetails reddit) {
6578
final ScheduledTasks s = new ScheduledTasks();
66-
s.setRedditRestTemplate(new OAuth2RestTemplate(reddit));
79+
final List<ClientHttpRequestInterceptor> list = new ArrayList<ClientHttpRequestInterceptor>();
80+
list.add(new UserAgentInterceptor());
81+
final OAuth2RestTemplate restTemplate = new OAuth2RestTemplate(reddit);
82+
restTemplate.setInterceptors(list);
83+
s.setRedditRestTemplate(restTemplate);
6784
return s;
6885
}
6986

87+
@Bean
88+
public RedditClassifier redditClassifier() throws IOException {
89+
final Resource file = new ClassPathResource("train.csv");
90+
final RedditClassifier redditClassifier = new RedditClassifier();
91+
redditClassifier.trainClassifier(file.getFile().getAbsolutePath());
92+
return redditClassifier;
93+
}
94+
7095
@Override
7196
public void addResourceHandlers(ResourceHandlerRegistry registry) {
7297
registry.addResourceHandler("/resources/**").addResourceLocations("/resources/");
@@ -108,6 +133,9 @@ public OAuth2ProtectedResourceDetails reddit() {
108133
@Bean
109134
public OAuth2RestTemplate redditRestTemplate(OAuth2ClientContext clientContext) {
110135
final OAuth2RestTemplate template = new OAuth2RestTemplate(reddit(), clientContext);
136+
final List<ClientHttpRequestInterceptor> list = new ArrayList<ClientHttpRequestInterceptor>();
137+
list.add(new UserAgentInterceptor());
138+
template.setInterceptors(list);
111139
final AccessTokenProvider accessTokenProvider = new AccessTokenProviderChain(Arrays.<AccessTokenProvider> asList(new MyAuthorizationCodeAccessTokenProvider(), new ImplicitAccessTokenProvider(), new ResourceOwnerPasswordAccessTokenProvider(),
112140
new ClientCredentialsAccessTokenProvider()));
113141
template.setAccessTokenProvider(accessTokenProvider);
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package org.baeldung.reddit.classifier;
2+
3+
import java.io.BufferedReader;
4+
import java.io.FileReader;
5+
import java.io.IOException;
6+
7+
import org.apache.mahout.classifier.sgd.L2;
8+
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
9+
import org.apache.mahout.math.RandomAccessSparseVector;
10+
import org.apache.mahout.math.Vector;
11+
import org.apache.mahout.vectorizer.encoders.AdaptiveWordValueEncoder;
12+
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
13+
import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
14+
15+
import com.google.common.base.Splitter;
16+
17+
public class RedditClassifier {
18+
19+
public static int GOOD = 0;
20+
public static int BAD = 1;
21+
private final OnlineLogisticRegression classifier;
22+
private final FeatureVectorEncoder titleEncoder;
23+
private final FeatureVectorEncoder domainEncoder;
24+
25+
private final int[] trainCount = { 0, 0 };
26+
27+
private final int[] evalCount = { 0, 0 };
28+
29+
public RedditClassifier() {
30+
classifier = new OnlineLogisticRegression(2, 4, new L2(1));
31+
titleEncoder = new AdaptiveWordValueEncoder("title");
32+
titleEncoder.setProbes(1);
33+
domainEncoder = new StaticWordValueEncoder("domain");
34+
domainEncoder.setProbes(1);
35+
}
36+
37+
public void trainClassifier(String fileName) throws IOException {
38+
final BufferedReader reader = new BufferedReader(new FileReader(fileName));
39+
int category;
40+
Vector features;
41+
String line = reader.readLine();
42+
if (line == null) {
43+
new RedditDataCollector().collectData();
44+
}
45+
46+
while ((line != null) && (line != "")) {
47+
category = (line.startsWith("good")) ? GOOD : BAD;
48+
trainCount[category]++;
49+
features = convertLineToVector(line);
50+
classifier.train(category, features);
51+
line = reader.readLine();
52+
}
53+
reader.close();
54+
System.out.println("Training count ========= " + trainCount[0] + "___" + trainCount[1]);
55+
}
56+
57+
public int classify(Vector features) {
58+
return classifier.classifyFull(features).maxValueIndex();
59+
}
60+
61+
public Vector convertPost(String title, String domain, int hour) {
62+
final Vector features = new RandomAccessSparseVector(4);
63+
final int noOfWords = Splitter.onPattern("\\W").omitEmptyStrings().splitToList(title).size();
64+
titleEncoder.addToVector(title, features);
65+
domainEncoder.addToVector(domain, features);
66+
features.set(2, hour);
67+
features.set(3, noOfWords);
68+
return features;
69+
}
70+
71+
public double evaluateClassifier() throws IOException {
72+
final BufferedReader reader = new BufferedReader(new FileReader(RedditDataCollector.TEST_FILE));
73+
int category, result;
74+
int correct = 0;
75+
int wrong = 0;
76+
Vector features;
77+
String line = reader.readLine();
78+
while ((line != null) && (line != "")) {
79+
category = (line.startsWith("good")) ? GOOD : BAD;
80+
evalCount[category]++;
81+
features = convertLineToVector(line);
82+
result = classify(features);
83+
if (category == result) {
84+
correct++;
85+
} else {
86+
wrong++;
87+
}
88+
line = reader.readLine();
89+
}
90+
reader.close();
91+
System.out.println(correct + " ----- " + wrong);
92+
System.out.println("Eval count ========= " + evalCount[0] + "___" + evalCount[1]);
93+
return correct / (wrong + correct + 0.0);
94+
}
95+
96+
// ==== private
97+
private Vector convertLineToVector(String line) {
98+
final Vector features = new RandomAccessSparseVector(4);
99+
final String[] items = line.split(";");
100+
titleEncoder.addToVector(items[3], features);
101+
domainEncoder.addToVector(items[4], features);
102+
features.set(2, Integer.parseInt(items[1])); // hour of day
103+
features.set(3, Integer.parseInt(items[2])); // number of words in the title
104+
return features;
105+
}
106+
107+
}
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
package org.baeldung.reddit.classifier;
2+
3+
import java.io.FileWriter;
4+
import java.io.IOException;
5+
import java.text.SimpleDateFormat;
6+
import java.util.ArrayList;
7+
import java.util.Date;
8+
import java.util.List;
9+
10+
import org.baeldung.reddit.util.UserAgentInterceptor;
11+
import org.slf4j.Logger;
12+
import org.slf4j.LoggerFactory;
13+
import org.springframework.http.client.ClientHttpRequestInterceptor;
14+
import org.springframework.web.client.RestTemplate;
15+
16+
import com.fasterxml.jackson.databind.JsonNode;
17+
import com.google.common.base.Joiner;
18+
import com.google.common.base.Splitter;
19+
20+
public class RedditDataCollector {
21+
public static final String TRAINING_FILE = "src/main/resources/train.csv";
22+
public static final String TEST_FILE = "src/main/resources/test.csv";
23+
private final Logger logger = LoggerFactory.getLogger(getClass());
24+
25+
private String postAfter;
26+
private final RestTemplate restTemplate;
27+
private final String subreddit;
28+
private final int minScore;
29+
30+
public RedditDataCollector() {
31+
restTemplate = new RestTemplate();
32+
final List<ClientHttpRequestInterceptor> list = new ArrayList<ClientHttpRequestInterceptor>();
33+
list.add(new UserAgentInterceptor());
34+
restTemplate.setInterceptors(list);
35+
subreddit = "all";
36+
minScore = 4;
37+
}
38+
39+
public RedditDataCollector(String subreddit, int minScore) {
40+
restTemplate = new RestTemplate();
41+
final List<ClientHttpRequestInterceptor> list = new ArrayList<ClientHttpRequestInterceptor>();
42+
list.add(new UserAgentInterceptor());
43+
restTemplate.setInterceptors(list);
44+
this.subreddit = subreddit;
45+
this.minScore = minScore;
46+
}
47+
48+
public void collectData() {
49+
final int limit = 100;
50+
final int noOfRounds = 80;
51+
try {
52+
final FileWriter writer = new FileWriter(TRAINING_FILE);
53+
for (int i = 0; i < noOfRounds; i++) {
54+
getPosts(limit, writer);
55+
}
56+
writer.close();
57+
58+
final FileWriter testWriter = new FileWriter(TEST_FILE);
59+
getPosts(limit, testWriter);
60+
testWriter.close();
61+
} catch (final Exception e) {
62+
logger.error("write to file error", e);
63+
}
64+
}
65+
66+
// ==== private
67+
68+
private void getPosts(int limit, FileWriter writer) {
69+
String fullUrl = "http://www.reddit.com/r/" + subreddit + "/new.json?limit=" + limit;
70+
if (postAfter != null) {
71+
fullUrl += "&count=" + limit + "&after=" + postAfter;
72+
}
73+
74+
try {
75+
final JsonNode node = restTemplate.getForObject(fullUrl, JsonNode.class);
76+
parseNode(node, writer);
77+
Thread.sleep(3000);
78+
} catch (final Exception e) {
79+
logger.error("server error", e);
80+
}
81+
82+
}
83+
84+
private void parseNode(JsonNode node, FileWriter writer) throws IOException {
85+
postAfter = node.get("data").get("after").asText();
86+
System.out.println(postAfter);
87+
String line;
88+
String category;
89+
List<String> words;
90+
final SimpleDateFormat df = new SimpleDateFormat("HH");
91+
for (final JsonNode child : node.get("data").get("children")) {
92+
category = (child.get("data").get("score").asInt() < minScore) ? "bad" : "good";
93+
words = Splitter.onPattern("\\W").omitEmptyStrings().splitToList(child.get("data").get("title").asText());
94+
final Date date = new Date(child.get("data").get("created_utc").asLong() * 1000);
95+
96+
line = category + ";";
97+
line += df.format(date) + ";";
98+
line += words.size() + ";" + Joiner.on(' ').join(words) + ";";
99+
line += child.get("data").get("domain").asText() + "\n";
100+
101+
writer.write(line);
102+
}
103+
}
104+
105+
public static void main(String[] args) {
106+
final RedditDataCollector collector = new RedditDataCollector();
107+
collector.collectData();
108+
}
109+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package org.baeldung.reddit.util;
2+
3+
import java.io.IOException;
4+
5+
import org.springframework.http.HttpHeaders;
6+
import org.springframework.http.HttpRequest;
7+
import org.springframework.http.client.ClientHttpRequestExecution;
8+
import org.springframework.http.client.ClientHttpRequestInterceptor;
9+
import org.springframework.http.client.ClientHttpResponse;
10+
11+
public class UserAgentInterceptor implements ClientHttpRequestInterceptor {
12+
13+
@Override
14+
public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) throws IOException {
15+
16+
final HttpHeaders headers = request.getHeaders();
17+
headers.add("User-Agent", "Schedule with Reddit");
18+
return execution.execute(request, body);
19+
}
20+
}

spring-security-oauth/src/main/java/org/baeldung/web/RedditController.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.baeldung.persistence.dao.UserRepository;
1313
import org.baeldung.persistence.model.Post;
1414
import org.baeldung.persistence.model.User;
15+
import org.baeldung.reddit.classifier.RedditClassifier;
1516
import org.baeldung.reddit.util.RedditApiConstants;
1617
import org.slf4j.Logger;
1718
import org.slf4j.LoggerFactory;
@@ -30,6 +31,7 @@
3031
import org.springframework.web.bind.annotation.RequestMapping;
3132
import org.springframework.web.bind.annotation.RequestMethod;
3233
import org.springframework.web.bind.annotation.RequestParam;
34+
import org.springframework.web.bind.annotation.ResponseBody;
3335
import org.springframework.web.bind.annotation.ResponseStatus;
3436

3537
import com.fasterxml.jackson.databind.JsonNode;
@@ -40,6 +42,7 @@ public class RedditController {
4042
private final Logger logger = LoggerFactory.getLogger(getClass());
4143

4244
private static final SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm");
45+
private final SimpleDateFormat dfHour = new SimpleDateFormat("HH");
4346

4447
@Autowired
4548
private OAuth2RestTemplate redditRestTemplate;
@@ -50,6 +53,9 @@ public class RedditController {
5053
@Autowired
5154
private PostRepository postReopsitory;
5255

56+
@Autowired
57+
private RedditClassifier redditClassifier;
58+
5359
@RequestMapping("/login")
5460
public final String redditLogin() {
5561
final JsonNode node = redditRestTemplate.getForObject("https://oauth.reddit.com/api/v1/me", JsonNode.class);
@@ -122,6 +128,14 @@ public final String getScheduledPosts(final Model model) {
122128
return "postListView";
123129
}
124130

131+
@RequestMapping(value = "/predicatePostResponse", method = RequestMethod.POST)
132+
@ResponseBody
133+
public final String predicatePostResponse(@RequestParam(value = "title") final String title, @RequestParam(value = "domain") final String domain) {
134+
final int hour = Integer.parseInt(dfHour.format(new Date()));
135+
final int result = redditClassifier.classify(redditClassifier.convertPost(title, domain, hour));
136+
return (result == RedditClassifier.GOOD) ? "{Good Response}" : "{Bad response}";
137+
}
138+
125139
// === post actions
126140

127141
@RequestMapping(value = "/deletePost/{id}", method = RequestMethod.DELETE)

0 commit comments

Comments
 (0)