
Commit 57342fc

Add feature augmentation code (kg vector generation)
1 parent 80b86b7 commit 57342fc

File tree

4 files changed: +405 -6 lines changed


README.md

Lines changed: 12 additions & 0 deletions
Lines changed: 12 additions & 0 deletions

````diff
@@ -89,6 +89,18 @@ The arguments of the command represent
 
 The location of the result file is specified by config.\{zhang15_dbpedia, news20\}_train_augmented_aggregated_path.
 
+
+### How to perform feature augmentation / create v_{w,c}
+
+An example:
+```bash
+python3 kg_vector_generation.py --data dbpedia
+```
+The argument of the command represents
+* `data`: Dataset, either `dbpedia` or `20news`.
+
+The locations of the result files are specified by config.\{zhang15_dbpedia, news20\}_kg_vector_dir.
+
 ### How to train / test Phase 1
 
 Pending
````
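For reference, the output of the added README section can be sanity-checked once the command has run. The sketch below is illustrative only: the class label `company` and the use of the DBpedia config entries are assumptions, but the file-name pattern (`kg_vector_dir + kg_vector_prefix + <class label> + ".pickle"`) and the per-node vector layout follow from `kg_vector_generation.py` in this commit.

```python
# Illustrative sanity check of one generated KG-vector file (class label "company" is a guess).
import pickle
import config  # src_reject/config.py

path = (config.zhang15_dbpedia_kg_vector_dir
        + config.zhang15_dbpedia_kg_vector_prefix
        + "company.pickle")
with open(path, "rb") as f:
    vectors = pickle.load(f)   # dict: ConceptNet URI -> numpy vector v_{w,c}

uri, v = next(iter(vectors.items()))
print(uri, v.shape)            # expected (30,): 3 node partitions x (3*hop + 1) features, hop = 3
```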

src_reject/config.py

Lines changed: 3 additions & 6 deletions
Lines changed: 3 additions & 6 deletions

```diff
@@ -111,7 +111,7 @@
 
 word_embed_file_path = "../data/glove/glove.6B.200d.txt"
 word_embed_gensim_file_path = '../data/glove/glove.6B.200d.gensim.txt'
-conceptnet_path = "../wordEmbeddings/conceptnet-assertions-en-5.6.0.csv"
+conceptnet_path = "../data/conceptnet-assertions-en-5.6.0.csv"
 POS_OF_WORD_path = "../data/POS_OF_WORD.pickle"
 WORD_TOPIC_TRANSLATION_path = "../data/WORD_TOPIC_TRANSLATION.pickle"
 
@@ -195,9 +195,8 @@
 
 # zhang15_dbpedia_kg_vector_dir = zhang15_dbpedia_dir + "KG_VECTOR_3/"
 # zhang15_dbpedia_kg_vector_prefix = "KG_VECTORS_3_"
-# TODO by Peter, how to get KG_Vector files
+zhang15_dbpedia_kg_vector_node_data_path = zhang15_dbpedia_dir + 'NODES_DATA.pickle'
 zhang15_dbpedia_kg_vector_dir = zhang15_dbpedia_dir + "KG_VECTOR_CLUSTER_3GROUP/"
-# zhang15_dbpedia_kg_vector_dir = zhang15_dbpedia_dir + "KG_VECTOR_CLUSTER_ALLGROUP/"
 zhang15_dbpedia_kg_vector_prefix = "VECTORS_CLUSTER_3_"
 
 zhang15_dbpedia_word_embed_matrix_path = zhang15_dbpedia_dir + "word_embed_matrix.npz"
@@ -293,9 +292,7 @@
 
 news20_vocab_path = news20_dir + "vocab.txt"
 
-# TODO by Peter, how to get kg vectors
-# news20_kg_vector_dir = news20_dir + "KG_VECTOR_3_Lem/"
-# news20_kg_vector_prefix = "lemmatised_KG_VECTORS_3_"
+news20_kg_vector_node_data_path = news20_dir + 'NODES_DATA.pickle'
 news20_kg_vector_dir = news20_dir + "KG_VECTOR_CLUSTER_3GROUP/"
 news20_kg_vector_prefix = "VECTORS_CLUSTER_3_"
 
```
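A practical note on the new `*_kg_vector_dir` entries: `kg_vector_generation.py` (added below) writes its per-class pickles straight into these directories without creating them, so they presumably need to exist before the script is run. A minimal sketch, assuming the config values above:

```python
# Sketch: create the output directories referenced by the new config entries
# before running kg_vector_generation.py (the script itself does not create them).
import os
import config

for d in (config.zhang15_dbpedia_kg_vector_dir, config.news20_kg_vector_dir):
    os.makedirs(d, exist_ok=True)
```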
src_reject/kg_vector_generation.py

Lines changed: 294 additions & 0 deletions
New file:

```python
import pickle, json, requests, csv, copy, os, re
import numpy as np
import pprint as pp
import urllib.request, urllib.parse
from sklearn.metrics.pairwise import cosine_similarity
from json import JSONDecodeError
from text_to_uri import *

import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
from tqdm import tqdm

import config

## Global variables initialisation

lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))

pos_dict = {'JJ': 'a', 'JJR': 'a', 'JJS': 'a',
            'NN': 'n', 'NNP': 'n', 'NNPS': 'n', 'NNS': 'n',
            'RB': 'r', 'RBR': 'r', 'RBS': 'r',
            'VB': 'v', 'VBD': 'v', 'VBG': 'v', 'VBN': 'v', 'VBP': 'v', 'VBZ': 'v'}

NODES_DATA = dict()
lemmatise_dict = dict()


## Functions

### Category (Class) related functions

class Category:

    def __init__(self, label, description, hierarchy):  # Create an empty path
        self.label = label.strip().lower()
        self.description = description.strip().lower()
        self.hierarchy = hierarchy.strip().lower().split(';')
        self.nodes = {'the_class': [],
                      'super_class': [],
                      'description': [],
                      }
        self.find_nodes()

    def __repr__(self):
        return self.label + ' => \n' + '\n'.join([key + ': ' + str(val) for key, val in self.nodes.items()]) + '\n'

    def find_nodes(self):
        # The class
        self.nodes['the_class'] = get_all_nodes_from_label(self.label)

        # Super class
        for a_super_class in self.hierarchy:
            self.nodes['super_class'].extend(get_all_nodes_from_label(a_super_class))
        self.nodes['super_class'] = list(set(self.nodes['super_class']))

        # Description
        text = nltk.pos_tag(word_tokenize(self.description))
        for token in text:
            if token[1].startswith('NN') and token[0] not in stop_words:  # Noun and not stop words
                self.nodes['description'].extend(get_all_nodes_from_label(token[0]))
        self.nodes['description'] = list(set(self.nodes['description']))

    def get_all_nodes(self):
        ans = []
        for key, val in self.nodes.items():
            ans.extend(val)
        return set(ans)


def get_class_info(filename):
    with open(filename, encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile, delimiter=',')
        ans = [row for row in reader]
    print("No. of classes =", len(ans))
    print("Header =", ans[0].keys())
    return ans

def get_all_nodes_from_label(label):
    ans = []
    if standardized_uri('en', label) in NODES_DATA:
        ans.append(standardized_uri('en', label))
    for token in label.split():
        if token not in stop_words:
            token = lemmatise_ConceptNet_label(token)
            if standardized_uri('en', token) in NODES_DATA and standardized_uri('en', token) not in ans:
                ans.append(standardized_uri('en', token))
    return ans

### ConceptNet (nodes) related functions

class ConceptNet_node:

    def __init__(self, uri):  # Create a node
        self.uri = remove_word_sense(uri)
        self.label = uri[uri.rfind('/')+1:]
        self.neighbors = {0: set([self.uri]),
                          1: set()}

    def find_neighbors(self, hop):
        if hop not in self.neighbors:
            one_hop_less = self.find_neighbors(hop-1)
            ans = set()
            for n in one_hop_less:
                ans = ans.union(NODES_DATA[n].find_neighbors(1))
            ans = ans.difference(self.find_neighbors_within(hop-1))
            self.neighbors[hop] = ans
            print('Finish finding neighbors of ', self.uri, 'hop =', hop)
        return self.neighbors[hop]

    def find_neighbors_within(self, hop):
        assert hop >= 0, 'Hop number must be non-negative'
        if hop == 0:
            return self.neighbors[0]
        else:
            return self.find_neighbors(hop).union(self.find_neighbors_within(hop-1))


def get_neighbors_of_cluster(node_set, hop):
    ans = set()
    for n in node_set:
        assert n in NODES_DATA, "Invalid node " + n
        ans = ans.union(NODES_DATA[n].find_neighbors_within(hop))
    return ans

def remove_word_sense(sub):
    if sub.count('/') > 3:
        if sub.count('/') > 4:
            print(sub)
            assert False, "URI error (with more than 4 slashes)"
        sub = sub[:sub.rfind('/')]
    return sub

def get_label_from_uri(uri):
    uri = remove_word_sense(uri)
    return uri[uri.rfind('/')+1:]

def lemmatise_ConceptNet_label(label):
    if '_' in label:
        return label
    else:
        tag = nltk.pos_tag([label])[0][1]
        if tag not in pos_dict:
            return label
        else:
            return lemmatizer.lemmatize(label, pos_dict[tag])

def lemmatise_ConceptNet_uri(uri):
    label = get_label_from_uri(uri)
    lemmatised_label = lemmatise_ConceptNet_label(label)
    return standardized_uri('en', lemmatised_label)

def create_lemmatised_dict(ns):  # ns is a set of nodes from read_all_nodes()
    nodes = dict()
    for n in tqdm(ns):
        nodes[n] = lemmatise_ConceptNet_uri(n)
    return nodes

### Loading ConceptNet functions

def read_all_nodes(filename):  # get all distinct uri in conceptnet (without part of speech)
    nodes = set()
    with open(filename, 'r', encoding="utf8") as csvfile:
        reader = csv.reader(csvfile, delimiter='\t')
        for line in tqdm(reader):
            if not line[2].startswith('/c/en/') or not line[3].startswith('/c/en/'):  # only relationships with english nodes
                continue
            sub = remove_word_sense(line[2])
            obj = remove_word_sense(line[3])
            nodes.add(sub)
            nodes.add(obj)
    return nodes


def load_one_hop_data(filename, NODES_DATA, rel_list):
    count_edges = 0
    with open(filename, 'r', encoding="utf8") as csvfile:
        reader = csv.reader(csvfile, delimiter='\t')
        for line in tqdm(reader):
            rel = line[1].strip()
            if rel_list is None or rel in rel_list:
                details = json.loads(line[4])
                w = details['weight']
                if w < 1.0:
                    continue
                if not line[2].startswith('/c/en/') or not line[3].startswith('/c/en/'):  # only relationships with english nodes
                    continue
                sub = lemmatise_dict[remove_word_sense(line[2])]
                obj = lemmatise_dict[remove_word_sense(line[3])]
                if sub != obj:
                    NODES_DATA[sub].neighbors[1].add(obj)
                    NODES_DATA[obj].neighbors[1].add(sub)
                    count_edges += 1
    print("Total no. of registered edges =", count_edges)


def load_ConceptNet():
    global lemmatise_dict, NODES_DATA

    filename = config.conceptnet_path

    # Find all lemmatised nodes
    print("Reading all nodes from ConceptNet")
    ALL_NODES = read_all_nodes(filename)
    print('Before lemmatising, no. of all nodes = ', len(ALL_NODES))
    lemmatise_dict = create_lemmatised_dict(ALL_NODES)
    ALL_NODES = set(lemmatise_dict.values())
    print('After lemmatising, no. of all nodes = ', len(ALL_NODES))

    # Create all lemmatised nodes objects in the process
    for n in ALL_NODES:
        NODES_DATA[n] = ConceptNet_node(n)
    del ALL_NODES
    print('Finish creating lemmatised nodes')

    # Load one hop data from ConceptNet
    rel_list = ['/r/IsA', '/r/PartOf', '/r/AtLocation', '/r/RelatedTo']
    load_one_hop_data(filename, NODES_DATA, rel_list)
    print('Finish loading one hop data')

### Creating KG vector function

def get_vector_of(n, all_c_nodes, hop):  # n = uri, c = Category_node
    v = np.zeros(3 * hop + 1)
    v[0] = 1.0 if n in all_c_nodes else 0.0
    for i in range(hop):
        have_hops = [n in NODES_DATA[c].find_neighbors(i+1) for c in all_c_nodes]
        if len(have_hops) > 0:
            v[3 * i + 1] = float(any(have_hops))
            v[3 * i + 2] = float(sum(have_hops))
            v[3 * i + 3] = float(np.mean(have_hops))
        else:
            v[3 * i + 1] = 0.0
            v[3 * i + 2] = 0.0
            v[3 * i + 3] = 0.0
    return v


## Main Program
def main_program(class_filename, node_data_filename, kg_vector_dir, kg_vector_prefix):
    # - Load conceptnet
    load_ConceptNet()

    # - Load class data and form a cluster of nodes for each class
    class_nodes = set()
    class_info = get_class_info(class_filename)
    classes = [Category(c['ConceptNet'], c['ClassDescription'], c['Hierarchy']) for c in class_info]
    for c in classes:
        class_nodes = class_nodes.union(c.get_all_nodes())
    print(len(class_nodes), class_nodes)

    for c in classes:
        print(c)

    class_clusters = dict()
    for c in classes:
        class_clusters[c.label] = c.get_all_nodes()
    print(class_clusters)

    # - Find neighbors of nodes in each cluster
    for c in tqdm(class_nodes):
        print('Processing class', c)
        NODES_DATA[c].find_neighbors(3)

    pickle.dump(NODES_DATA, open(node_data_filename, "wb"))

    # - Calculate KG vectors for each class
    for c in tqdm(classes):
        all_c_nodes = c.get_all_nodes()
        all_neighbors = get_neighbors_of_cluster(all_c_nodes, hop=3)
        print(c, len(all_neighbors))

        vectors = dict()
        for n in all_neighbors:
            # Consider each partition of nodes separately
            vectors[n] = np.concatenate((get_vector_of(n, c.nodes['the_class'], hop=3),
                                         get_vector_of(n, c.nodes['super_class'], hop=3),
                                         get_vector_of(n, c.nodes['description'], hop=3)), axis=0)

        pickle.dump(vectors, open(kg_vector_dir + kg_vector_prefix + c.label + ".pickle", "wb"))
        print('Finish calculating vectors for', c.label)

if __name__ == "__main__":
    print(config.dataset)
    if config.dataset == "dbpedia":
        main_program(config.zhang15_dbpedia_class_label_path, config.zhang15_dbpedia_kg_vector_node_data_path,
                     config.zhang15_dbpedia_kg_vector_dir, config.zhang15_dbpedia_kg_vector_prefix)
    elif config.dataset == "20news":
        main_program(config.news20_class_label_path, config.news20_kg_vector_node_data_path,
                     config.news20_kg_vector_dir, config.news20_kg_vector_prefix)
    else:
        raise Exception("config.dataset %s not found" % config.dataset)
    pass
```
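In summary, for each class c the script collects ConceptNet nodes for three partitions (the class label itself, its super-classes, and nouns from its description), expands them up to 3 hops, and writes one vector per neighbouring node w. Each partition contributes 3·hop + 1 = 10 features (a membership flag, then any/count/mean reachability per hop), so the stored v_{w,c} is their 30-dimensional concatenation. The toy sketch below only illustrates that layout; the numbers are made up and do not come from real ConceptNet data.

```python
# Toy illustration of the v_{w,c} layout produced by get_vector_of / main_program above.
import numpy as np

hop = 3
per_partition = 3 * hop + 1            # 1 membership flag + (any, count, mean) per hop = 10 slots

the_class   = np.zeros(per_partition)
super_class = np.zeros(per_partition)
description = np.zeros(per_partition)

# Suppose node w is itself one of the 'the_class' nodes and is reachable in 1 hop
# from 2 out of 5 'description' nodes (made-up numbers).
the_class[0]   = 1.0                   # membership flag
description[1] = 1.0                   # any(): at least one description node reaches w in 1 hop
description[2] = 2.0                   # sum(): how many description nodes reach w in 1 hop
description[3] = 2.0 / 5.0             # mean() over the 5 description nodes

v_wc = np.concatenate((the_class, super_class, description), axis=0)
print(v_wc.shape)                      # (30,)
```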
