Commit 346d5f1

Limit vocabulary size
1 parent 4364b39 commit 346d5f1

2 files changed: +44 -11 lines

chatbot/chatbot.py

Lines changed: 12 additions & 1 deletion
@@ -114,6 +114,9 @@ def parseArgs(args):
         datasetArgs.add_argument('--ratioDataset', type=float, default=1.0, help='ratio of dataset used to avoid using the whole dataset')  # Not implemented, useless ?
         datasetArgs.add_argument('--maxLength', type=int, default=10, help='maximum length of the sentence (for input and output), define number of maximum step of the RNN')
         datasetArgs.add_argument('--filterVocab', type=int, default=1, help='remove rarelly used words (by default words used only once). 0 to keep all words.')
+        datasetArgs.add_argument('--increaseTrainingPairs', type=bool, default=False, help='Use every line in the dataset as both input and target, thus multiplying by two the training set.')
+        datasetArgs.add_argument('--vocabularySize', type=int, default=40000, help='Limit the number of words in the vocabulary')
+

         # Network options (Warning: if modifying something here, also make the change on save/loadParams() )
         nnArgs = parser.add_argument_group('Network options', 'architecture related option')
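The two new dataset options follow the existing argparse pattern. Below is a minimal, self-contained sketch (not the project's actual parser) of how such flags parse; it uses action='store_true' for the boolean switch because argparse's type=bool converts any non-empty string, including 'False', to True:

import argparse

# Standalone sketch only; the option names mirror the diff, the parser itself is illustrative.
parser = argparse.ArgumentParser()
datasetArgs = parser.add_argument_group('Dataset options')
datasetArgs.add_argument('--vocabularySize', type=int, default=40000,
                         help='limit the number of words kept in the vocabulary')
datasetArgs.add_argument('--increaseTrainingPairs', action='store_true',
                         help='use every line as both input and target')

args = parser.parse_args(['--vocabularySize', '20000', '--increaseTrainingPairs'])
print(args.vocabularySize, args.increaseTrainingPairs)  # 20000 True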
@@ -543,6 +546,9 @@ def loadModelParams(self):
         self.args.datasetTag = config['Dataset'].get('datasetTag')
         self.args.maxLength = config['Dataset'].getint('maxLength')  # We need to restore the model length because of the textData associated and the vocabulary size (TODO: Compatibility mode between different maxLength)
         self.args.filterVocab = config['Dataset'].getint('filterVocab')
+        self.args.increaseTrainingPairs = config['Dataset'].getboolean('increaseTrainingPairs')
+        self.args.vocabularySize = config['Dataset'].getint('vocabularySize')
+

         self.args.hiddenSize = config['Network'].getint('hiddenSize')
         self.args.numLayers = config['Network'].getint('numLayers')
@@ -564,6 +570,8 @@ def loadModelParams(self):
         print('datasetTag: {}'.format(self.args.datasetTag))
         print('maxLength: {}'.format(self.args.maxLength))
         print('filterVocab: {}'.format(self.args.filterVocab))
+        print('increaseTrainingPairs: {}'.format(self.args.increaseTrainingPairs))
+        print('vocabularySize: {}'.format(self.args.vocabularySize))
         print('hiddenSize: {}'.format(self.args.hiddenSize))
         print('numLayers: {}'.format(self.args.numLayers))
         print('softmaxSamples: {}'.format(self.args.softmaxSamples))
@@ -596,7 +604,10 @@ def saveModelParams(self):
         config['Dataset']['datasetTag'] = str(self.args.datasetTag)
         config['Dataset']['maxLength'] = str(self.args.maxLength)
         config['Dataset']['filterVocab'] = str(self.args.filterVocab)
-
+        config['Dataset']['increaseTrainingPairs'] = str(self.args.increaseTrainingPairs)
+        config['Dataset']['vocabularySize'] = str(self.args.vocabularySize)
+
+
         config['Network'] = {}
         config['Network']['hiddenSize'] = str(self.args.hiddenSize)
         config['Network']['numLayers'] = str(self.args.numLayers)
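saveModelParams() persists the new values as strings and loadModelParams() restores them through configparser's typed getters. A minimal round-trip sketch, assuming an illustrative params.ini file name:

import configparser

# Write the parameters the same way saveModelParams() does: everything as str().
config = configparser.ConfigParser()
config['Dataset'] = {}
config['Dataset']['vocabularySize'] = str(40000)
config['Dataset']['increaseTrainingPairs'] = str(False)
with open('params.ini', 'w') as configFile:  # hypothetical file name
    config.write(configFile)

# Read them back with the typed getters used in loadModelParams().
restored = configparser.ConfigParser()
restored.read('params.ini')
print(restored['Dataset'].getint('vocabularySize'))             # 40000
print(restored['Dataset'].getboolean('increaseTrainingPairs'))  # False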

chatbot/textdata.py

File mode changed from 100644 to 100755
Lines changed: 32 additions & 10 deletions
@@ -26,6 +26,7 @@
 import random
 import string
 from collections import OrderedDict
+import collections

 from chatbot.corpus.cornelldata import CornellData
 from chatbot.corpus.opensubsdata import OpensubsData
@@ -77,9 +78,10 @@ def __init__(self, args):
         self.corpusDir = os.path.join(self.args.rootDir, 'data', self.args.corpus)
         basePath = self._constructBasePath()
         self.fullSamplesPath = basePath + '.pkl'  # Full sentences length/vocab
-        self.filteredSamplesPath = basePath + '-lenght{}-filter{}.pkl'.format(
+        self.filteredSamplesPath = basePath + '-lenght{}-filter{}-vocabSize{}.pkl'.format(
             self.args.maxLength,
             self.args.filterVocab,
+            self.args.vocabularySize,
         )  # Sentences/vocab filtered for this model

         self.padToken = -1  # Padding
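Adding the vocabulary size to the filtered-samples file name keys the preprocessing cache on the new parameter, so runs with different --vocabularySize values do not reuse each other's pickle. A toy illustration with a made-up base path:

basePath = 'data/samples/dataset-cornell'  # made-up base path
maxLength, filterVocab, vocabularySize = 10, 1, 40000
filteredSamplesPath = basePath + '-lenght{}-filter{}-vocabSize{}.pkl'.format(
    maxLength, filterVocab, vocabularySize)
print(filteredSamplesPath)  # data/samples/dataset-cornell-lenght10-filter1-vocabSize40000.pkl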
@@ -366,20 +368,25 @@ def mergeSentences(sentences, fromEnd=False):
         }
         new_mapping = {}  # Map the full words ids to the new one (TODO: Should be a list)
         newId = 0
+
+        print("Filtering dataset with vocabSize={} and wordCount > {}".format(self.args.vocabularySize, self.args.filterVocab))
+        word_counter = collections.Counter(self.idCount)
+        selected_word_ids = word_counter.most_common(self.args.vocabularySize)
+        selected_word_ids = {k: v for k, v in selected_word_ids if v > self.args.filterVocab}
+
         for wordId, count in [(i, self.idCount[i]) for i in range(len(self.idCount))]:  # Iterate in order
-            if (count <= self.args.filterVocab and
-                    wordId not in specialTokens):  # Cadidate to filtering (Warning: don't filter special token)
-                new_mapping[wordId] = self.unknownToken
-                del self.word2id[self.id2word[wordId]]  # The word isn't used anymore
-                del self.id2word[wordId]
-            else:  # Update the words ids
+            if wordId in selected_word_ids or wordId in specialTokens:  # Update the word id
                 new_mapping[wordId] = newId
                 word = self.id2word[wordId]  # The new id has changed, update the dictionaries
                 del self.id2word[wordId]  # Will be recreated if newId == wordId
                 self.word2id[word] = newId
                 self.id2word[newId] = word
                 newId += 1
-
+            else:  # Not in our list nor special, map it to unknownToken
+                new_mapping[wordId] = self.unknownToken
+                del self.word2id[self.id2word[wordId]]  # The word isn't used anymore
+                del self.id2word[wordId]
+
         # Last step: replace old ids by new ones and filters empty sentences
         def replace_words(words):
             valid = False  # Filter empty sequences
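The selection step above boils down to collections.Counter.most_common() plus the existing count threshold. A standalone sketch with toy stand-ins for idCount, vocabularySize, filterVocab and unknownToken (special tokens left out for brevity):

import collections

idCount = {0: 50, 1: 3, 2: 1, 3: 7, 4: 1}  # wordId -> occurrence count (toy data)
vocabularySize = 3
filterVocab = 1
unknownToken = -2

# Keep the N most frequent ids, then drop those seen filterVocab times or fewer.
selected = dict(collections.Counter(idCount).most_common(vocabularySize))
selected = {k: v for k, v in selected.items() if v > filterVocab}

# Remap kept ids to a compact range; everything else maps to the unknown token.
new_mapping = {}
newId = 0
for wordId in sorted(idCount):
    if wordId in selected:
        new_mapping[wordId] = newId
        newId += 1
    else:
        new_mapping[wordId] = unknownToken

print(new_mapping)  # {0: 0, 1: 1, 2: -2, 3: 2, 4: -2}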
@@ -390,15 +397,25 @@ def replace_words(words):
             return valid

         self.trainingSamples.clear()
+        self.idCount.clear()  # Let's recreate idCount
+
         for inputWords, targetWords in tqdm(newSamples, desc='Replace ids:', leave=False):
             valid = True
             valid &= replace_words(inputWords)
             valid &= replace_words(targetWords)
+            valid &= targetWords.count(self.unknownToken) == 0  # Filter target with out-of-vocabulary target words

             if valid:
                 self.trainingSamples.append([inputWords, targetWords])  # TODO: Could replace list by tuple
+                # Recreate idCount
+                for wordId in inputWords + targetWords:
+                    if wordId in self.idCount:
+                        self.idCount[wordId] = self.idCount[wordId] + 1
+                    else:
+                        self.idCount[wordId] = 1
+        print("Final vocabulary size of", len(self.word2id) - len(specialTokens))
+

-        self.idCount.clear()  # Not usefull anymore. Free data


     def createFullCorpus(self, conversations):
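The second pass above drops any sample whose target still contains the unknown token and rebuilds the word counts from the sentences that survive. A toy illustration (unknownToken and the sample list are made up):

import collections

unknownToken = -2
samples = [
    ([1, 2], [3, 4]),   # kept
    ([1, -2], [3]),     # kept: unknown word only on the input side
    ([1], [2, -2]),     # dropped: unknown word in the target
]

trainingSamples = []
idCount = collections.defaultdict(int)
for inputWords, targetWords in samples:
    if targetWords.count(unknownToken) == 0:
        trainingSamples.append([inputWords, targetWords])
        for wordId in inputWords + targetWords:
            idCount[wordId] += 1

print(len(trainingSamples))  # 2
print(dict(idCount))         # {1: 2, 2: 1, 3: 2, 4: 1, -2: 1}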
@@ -424,9 +441,14 @@ def extractConversation(self, conversation):
         Args:
             conversation (Obj): a conversation object containing the lines to extract
         """
+
+        if self.args.increaseTrainingPairs:
+            step = 1
+        else:
+            step = 2

         # Iterate over all the lines of the conversation
-        for i in tqdm_wrap(range(len(conversation['lines']) - 1),  # We ignore the last line (no answer for it)
+        for i in tqdm_wrap(range(0, len(conversation['lines']) - 1, step),  # We ignore the last line (no answer for it)
                            desc='Conversation', leave=False):
             inputLine = conversation['lines'][i]
             targetLine = conversation['lines'][i+1]
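The step logic in extractConversation pairs line i with line i+1; with --increaseTrainingPairs the window advances one line at a time so each line serves as both a target and the next input, roughly doubling the number of pairs. A toy illustration of the two modes:

def extract_pairs(lines, increaseTrainingPairs):
    # Mirrors the step selection above on a plain list of lines (toy stand-in).
    step = 1 if increaseTrainingPairs else 2
    return [(lines[i], lines[i + 1]) for i in range(0, len(lines) - 1, step)]

lines = ['A', 'B', 'C', 'D', 'E']
print(extract_pairs(lines, increaseTrainingPairs=True))   # [('A', 'B'), ('B', 'C'), ('C', 'D'), ('D', 'E')]
print(extract_pairs(lines, increaseTrainingPairs=False))  # [('A', 'B'), ('C', 'D')]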
