
Commit 0c1021b

Load and save full corpus
1 parent a55504b commit 0c1021b

File tree: 1 file changed, +66 -36 lines changed


chatbot/textdata.py

Lines changed: 66 additions & 36 deletions
@@ -90,17 +90,20 @@ def __init__(self, args):
         self.trainingSamples = []  # 2d array containing each question and his answer [[input,target]]
 
         self.word2id = {}
-        self.id2word = {}  # For a rapid conversion
+        self.id2word = {}  # For a rapid conversion (TODO: Could replace dict by list)
         self.idCount = {}  # Useful to filters the words
 
         self.loadCorpus()
 
         # Plot some stats:
-        print('Loaded {}: {} words, {} QA'.format(self.args.corpus, len(self.word2id), len(self.trainingSamples)))
+        self._printStats()
 
         if self.args.playDataset:
             self.playDataset()
 
+    def _printStats(self):
+        print('Loaded {}: {} words, {} QA'.format(self.args.corpus, len(self.word2id), len(self.trainingSamples)))
+
     def _constructBasePath(self):
         """Return the name of the base prefix of the current dataset
         """
@@ -255,39 +258,49 @@ def loadCorpus(self):
                 # Corpus creation
                 corpusData = TextData.availableCorpus[self.args.corpus](self.corpusDir + optional)
                 self.createFullCorpus(corpusData.getConversations())
+                self._printStats()
+                self.saveDataset(self.fullSamplesPath)
+            else:
+                self.loadDataset(self.fullSamplesPath)
 
             print('Filtering words...')
             self.filterFromFull()  # Extract the sub vocabulary for the given maxLength and filterVocab
 
             # Saving
             print('Saving dataset...')
-            self.saveDataset()  # Saving tf samples
+            self.saveDataset(self.filteredSamplesPath)  # Saving tf samples
         else:
-            self.loadDataset()
+            self.loadDataset(self.filteredSamplesPath)
 
         assert self.padToken == 0
 
-    def saveDataset(self):
+    def saveDataset(self, filename):
         """Save samples to file
+        Args:
+            filename (str): pickle filename
         """
 
-        with open(os.path.join(self.filteredSamplesPath), 'wb') as handle:
+        with open(os.path.join(filename), 'wb') as handle:
             data = {  # Warning: If adding something here, also modifying loadDataset
                 'word2id': self.word2id,
                 'id2word': self.id2word,
+                'idCount': self.idCount,
                 'trainingSamples': self.trainingSamples
-                }
+            }
             pickle.dump(data, handle, -1)  # Using the highest protocol available
 
-    def loadDataset(self):
+    def loadDataset(self, filename):
         """Load samples from file
+        Args:
+            filename (str): pickle filename
         """
-        dataset_path = os.path.join(self.filteredSamplesPath)
+        dataset_path = os.path.join(filename)
         print('Loading dataset from {}'.format(dataset_path))
         with open(dataset_path, 'rb') as handle:
            data = pickle.load(handle)  # Warning: If adding something here, also modifying saveDataset
            self.word2id = data['word2id']
            self.id2word = data['id2word']
+           self.idCount = data['idCount']
            self.trainingSamples = data['trainingSamples']
 
            self.padToken = self.word2id['<pad>']
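
With the new filename parameter, the same pickle layout now caches two datasets: the full corpus at fullSamplesPath and the filtered one at filteredSamplesPath, so the expensive corpus construction only runs once. A minimal standalone sketch of that round trip (the path, helper names, and sample data below are illustrative assumptions, not the project's API):

import os
import pickle

full_path = 'dataset-full.pkl'  # hypothetical cache file; TextData uses fullSamplesPath

def save_dataset(filename, word2id, id2word, id_count, training_samples):
    # Same dictionary layout as the commit's saveDataset
    with open(filename, 'wb') as handle:
        pickle.dump({
            'word2id': word2id,
            'id2word': id2word,
            'idCount': id_count,
            'trainingSamples': training_samples,
        }, handle, -1)  # Highest pickle protocol, as in the diff

def load_dataset(filename):
    with open(filename, 'rb') as handle:
        return pickle.load(handle)

if not os.path.isfile(full_path):
    # Expensive step happens only once: build the full corpus, then cache it
    save_dataset(full_path, {'<pad>': 0}, {0: '<pad>'}, {0: 1}, [])
data = load_dataset(full_path)  # Later runs reload the cache directly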
@@ -305,6 +318,33 @@ def filterFromFull():
         # 2: then, iterate over word count and compute new matching ids (can be unknown),
         # reiterate over the entire dataset to actualize new ids.
         pass  # TODO
+        words = []
+
+        # Extract sentences
+        sentencesToken = nltk.sent_tokenize(line)
+
+        # We add sentence by sentence until we reach the maximum length
+        for i in range(len(sentencesToken)):
+            # If question: we only keep the last sentences
+            # If answer: we only keep the first sentences
+            if not isTarget:
+                i = len(sentencesToken)-1 - i
+
+            tokens = nltk.word_tokenize(sentencesToken[i])
+
+            # If the total length is not too big, we still can add one more sentence
+            if len(words) + len(tokens) <= self.args.maxLength:  # TODO: Filter don't happen here
+                tempWords = []
+                for token in tokens:
+                    tempWords.append(self.getWordId(token))  # Create the vocabulary and the training sentences
+
+                if isTarget:
+                    words = words + tempWords
+                else:
+                    words = tempWords + words
+            else:
+                break  # We reach the max length already
+
 
     def createFullCorpus(self, conversations):
         """Extract all data from the given vocabulary.
@@ -350,34 +390,22 @@ def extractText(self, line, isTarget=False):
         Return:
             list<int>: the list of the word ids of the sentence
         """
-        words = []
+        sentences = []  # List[List[str]]
 
         # Extract sentences
         sentencesToken = nltk.sent_tokenize(line)
 
         # We add sentence by sentence until we reach the maximum length
         for i in range(len(sentencesToken)):
-            # If question: we only keep the last sentences
-            # If answer: we only keep the first sentences
-            if not isTarget:
-                i = len(sentencesToken)-1 - i
-
             tokens = nltk.word_tokenize(sentencesToken[i])
 
-            # If the total length is not too big, we still can add one more sentence
-            if len(words) + len(tokens) <= self.args.maxLength:  # TODO: Filter don't happen here
-                tempWords = []
-                for token in tokens:
-                    tempWords.append(self.getWordId(token))  # Create the vocabulary and the training sentences
+            tempWords = []
+            for token in tokens:
+                tempWords.append(self.getWordId(token))  # Create the vocabulary and the training sentences
 
-                if isTarget:
-                    words = words + tempWords
-                else:
-                    words = tempWords + words
-            else:
-                break  # We reach the max length already
+            sentences.append(tempWords)
 
-        return words
+        return sentences
 
     def getWordId(self, word, create=True):
         """Get the id of the word (and add it to the dictionary if not existing). If the word does not exist and
@@ -392,17 +420,19 @@ def getWordId(self, word, create=True):
 
         word = word.lower()  # Ignore case
 
+        # At inference, we simply look up for the word
+        if not create:
+            wordId = self.word2id.get(word, self.unknownToken)
         # Get the id if the word already exist
-        wordId = self.word2id.get(word, -1)
-
+        elif word in self.word2id:
+            wordId = self.word2id[word]
+            self.idCount[wordId] += 1
         # If not, we create a new entry
-        if wordId == -1:
-            if create:
-                wordId = len(self.word2id)
-                self.word2id[word] = wordId
-                self.id2word[wordId] = word
-            else:
-                wordId = self.unknownToken
+        else:
+            wordId = len(self.word2id)
+            self.word2id[word] = wordId
+            self.id2word[wordId] = word
+            self.idCount[wordId] = 1
 
         return wordId
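
getWordId now also maintains idCount, and the create flag separates training (grow the vocabulary and count usages) from inference (map unseen words to the unknown token). A self-contained sketch of that behaviour, with invented token ids and module-level dicts standing in for the TextData attributes:

UNKNOWN = 0
word2id, id2word, id_count = {'<unknown>': UNKNOWN}, {UNKNOWN: '<unknown>'}, {UNKNOWN: 0}

def get_word_id(word, create=True):
    word = word.lower()
    if not create:                      # Inference: never grow the vocabulary
        return word2id.get(word, UNKNOWN)
    if word in word2id:                 # Known word: bump its usage count
        word_id = word2id[word]
        id_count[word_id] += 1
    else:                               # New word: register it with count 1
        word_id = len(word2id)
        word2id[word] = word_id
        id2word[word_id] = word
        id_count[word_id] = 1
    return word_id

get_word_id('Hello'); get_word_id('hello'); get_word_id('world')
print(id_count)  # {0: 0, 1: 2, 2: 1} -> 'hello' seen twice, 'world' once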
