
Commit a55504b

Start new dataset pre-processing
Parent: 13c58c9

1 file changed (+60, -37)


chatbot/textdata.py

Lines changed: 60 additions & 37 deletions
@@ -33,6 +33,7 @@
 from chatbot.corpus.ubuntudata import UbuntuData
 from chatbot.corpus.lightweightdata import LightweightData

+
 class Batch:
     """Struct containing batches info
     """
@@ -74,8 +75,12 @@ def __init__(self, args):

         # Path variables
         self.corpusDir = os.path.join(self.args.rootDir, 'data', self.args.corpus)
-        self.samplesDir = os.path.join(self.args.rootDir, 'data/samples/')
-        self.samplesName = self._constructName()
+        basePath = self._constructBasePath()
+        self.fullSamplesPath = basePath + '.pkl'  # Full sentences length/vocab
+        self.filteredSamplesPath = basePath + '-lenght{}-filter{}.pkl'.format(
+            self.args.maxLength,
+            self.args.filterVocab,
+        )  # Sentences/vocab filtered for this model

         self.padToken = -1  # Padding
         self.goToken = -1  # Start of sequence
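
The former samplesDir/samplesName pair becomes two cache paths: fullSamplesPath for the whole pre-processed corpus and filteredSamplesPath for the sentences/vocabulary kept under the current maxLength and filterVocab options. As a rough illustration only (not part of the commit), the snippet below replays the path logic for hypothetical argument values; note the resulting filename keeps the 'lenght' spelling of the format string above.

import os

# Hypothetical values, for illustration only
rootDir, corpus, datasetTag = '.', 'cornell', ''
maxLength, filterVocab = 10, 1

basePath = os.path.join(rootDir, 'data/samples/') + 'dataset-{}'.format(corpus)
if datasetTag:
    basePath += '-' + datasetTag

fullSamplesPath = basePath + '.pkl'  # shared by every model trained on this corpus
filteredSamplesPath = basePath + '-lenght{}-filter{}.pkl'.format(maxLength, filterVocab)

print(fullSamplesPath)      # ./data/samples/dataset-cornell.pkl
print(filteredSamplesPath)  # ./data/samples/dataset-cornell-lenght10-filter1.pkl
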
@@ -86,23 +91,24 @@ def __init__(self, args):

         self.word2id = {}
         self.id2word = {}  # For a rapid conversion
+        self.idCount = {}  # Useful to filters the words

-        self.loadCorpus(self.samplesDir)
+        self.loadCorpus()

         # Plot some stats:
         print('Loaded {}: {} words, {} QA'.format(self.args.corpus, len(self.word2id), len(self.trainingSamples)))

         if self.args.playDataset:
             self.playDataset()

-    def _constructName(self):
-        """Return the name of the dataset that the program should use with the current parameters.
-        Computer from the base name, the given tag (self.args.datasetTag) and the sentence length
+    def _constructBasePath(self):
+        """Return the name of the base prefix of the current dataset
         """
-        baseName = 'dataset-{}'.format(self.args.corpus)
+        path = os.path.join(self.args.rootDir, 'data/samples/')
+        path += 'dataset-{}'.format(self.args.corpus)
         if self.args.datasetTag:
-            baseName += '-' + self.args.datasetTag
-        return '{}-{}.pkl'.format(baseName, self.args.maxLength)
+            path += '-' + self.args.datasetTag
+        return path

     def makeLighter(self, ratioDataset):
         """Only keep a small fraction of the dataset, given by the ratio
@@ -141,6 +147,9 @@ def _createBatch(self, samples):
             if not self.args.test and self.args.autoEncode:  # Autoencode: use either the question or answer for both input and output
                 k = random.randint(0, 1)
                 sample = (sample[k], sample[k])
+            # TODO: Why re-processed that at each epoch ? Could precompute that
+            # once and reuse those every time. Is not the bottleneck so won't change
+            # much ? and if preprocessing, should be compatible with autoEncode & cie.
             batch.encoderSeqs.append(list(reversed(sample[0])))  # Reverse inputs (and not outputs), little trick as defined on the original seq2seq paper
             batch.decoderSeqs.append([self.goToken] + sample[1] + [self.eosToken])  # Add the <go> and <eos> tokens
             batch.targetSeqs.append(batch.decoderSeqs[-1][1:])  # Same as decoder, but shifted to the left (ignore the <go>)
@@ -149,6 +158,7 @@ def _createBatch(self, samples):
             assert len(batch.encoderSeqs[i]) <= self.args.maxLengthEnco
             assert len(batch.decoderSeqs[i]) <= self.args.maxLengthDeco

+            # TODO: Should use tf batch function to automatically add padding and batch samples
             # Add padding & define weight
             batch.encoderSeqs[i] = [self.padToken] * (self.args.maxLengthEnco - len(batch.encoderSeqs[i])) + batch.encoderSeqs[i]  # Left padding for the input
             batch.weights.append([1.0] * len(batch.targetSeqs[i]) + [0.0] * (self.args.maxLengthDeco - len(batch.targetSeqs[i])))
@@ -204,6 +214,8 @@ def genNextSamples():
             for i in range(0, self.getSampleSize(), self.args.batchSize):
                 yield self.trainingSamples[i:min(i + self.args.batchSize, self.getSampleSize())]

+        # TODO: Should replace that by generator (better: by tf.queue)
+
         for samples in genNextSamples():
             batch = self._createBatch(samples)
             batches.append(batch)
@@ -223,56 +235,54 @@ def getVocabularySize(self):
         """
         return len(self.word2id)

-    def loadCorpus(self, dirName):
+    def loadCorpus(self):
         """Load/create the conversations data
-        Args:
-            dirName (str): The directory where to load/save the model
         """
-        datasetExist = False
-        if os.path.exists(os.path.join(dirName, self.samplesName)):
-            datasetExist = True
-
+        datasetExist = os.path.isfile(self.filteredSamplesPath)
         if not datasetExist:  # First time we load the database: creating all files
             print('Training samples not found. Creating dataset...')

-            optional = ''
-            if self.args.corpus == 'lightweight' and not self.args.datasetTag:
-                raise ValueError('Use the --datasetTag to define the lightweight file to use.')
-            else:
-                optional = '/' + self.args.datasetTag  # HACK: Forward the filename
+            datasetExist = os.path.isfile(self.fullSamplesPath)  # Try to construct the dataset from the preprocessed entry
+            if not datasetExist:
+                print('Constructing full dataset...')

-            # Corpus creation
-            corpusData = TextData.availableCorpus[self.args.corpus](self.corpusDir + optional)
-            self.createCorpus(corpusData.getConversations())
+                optional = ''
+                if self.args.corpus == 'lightweight' and not self.args.datasetTag:
+                    raise ValueError('Use the --datasetTag to define the lightweight file to use.')
+                else:
+                    optional = '/' + self.args.datasetTag  # HACK: Forward the filename
+
+                # Corpus creation
+                corpusData = TextData.availableCorpus[self.args.corpus](self.corpusDir + optional)
+                self.createFullCorpus(corpusData.getConversations())
+
+            print('Filtering words...')
+            self.filterFromFull()  # Extract the sub vocabulary for the given maxLength and filterVocab

             # Saving
             print('Saving dataset...')
-            self.saveDataset(dirName)  # Saving tf samples
+            self.saveDataset()  # Saving tf samples
         else:
-            self.loadDataset(dirName)
+            self.loadDataset()

         assert self.padToken == 0

-    def saveDataset(self, dirName):
+    def saveDataset(self):
         """Save samples to file
-        Args:
-            dirName (str): The directory where to load/save the model
         """

-        with open(os.path.join(dirName, self.samplesName), 'wb') as handle:
+        with open(os.path.join(self.filteredSamplesPath), 'wb') as handle:
             data = {  # Warning: If adding something here, also modifying loadDataset
                 'word2id': self.word2id,
                 'id2word': self.id2word,
                 'trainingSamples': self.trainingSamples
             }
             pickle.dump(data, handle, -1)  # Using the highest protocol available

-    def loadDataset(self, dirName):
+    def loadDataset(self):
         """Load samples from file
-        Args:
-            dirName (str): The directory where to load the model
         """
-        dataset_path = os.path.join(dirName, self.samplesName)
+        dataset_path = os.path.join(self.filteredSamplesPath)
         print('Loading dataset from {}'.format(dataset_path))
         with open(dataset_path, 'rb') as handle:
             data = pickle.load(handle)  # Warning: If adding something here, also modifying saveDataset
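
Since the filtered cache is a plain pickle whose keys are defined in saveDataset() above, it can be sanity-checked outside the chatbot. A minimal sketch, assuming a cache file already exists (the path below is only an example):

import pickle

# Example filename; the real one depends on corpus, datasetTag, maxLength and filterVocab
with open('data/samples/dataset-cornell-lenght10-filter1.pkl', 'rb') as handle:
    data = pickle.load(handle)

print('{} words, {} QA pairs'.format(len(data['word2id']), len(data['trainingSamples'])))
question, answer = data['trainingSamples'][0]
print([data['id2word'][wordId] for wordId in question])  # first question, mapped back to words
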
@@ -285,8 +295,21 @@ def loadDataset(self, dirName):
         self.eosToken = self.word2id['<eos>']
         self.unknownToken = self.word2id['<unknown>']  # Restore special words

-    def createCorpus(self, conversations):
-        """Extract all data from the given vocabulary
+    def filterFromFull():
+        """ Load the pre-processed full corpus and filter the vocabulary / sentences
+        to match the given model option
+        """
+        # 2 steps:
+        # 1: first, iterate over all samples, and add sentences into the
+        # training set. Decrement word count for unused sentences ?
+        # 2: then, iterate over word count and compute new matching ids (can be unknown),
+        # reiterate over the entire dataset to actualize new ids.
+        pass  # TODO
+
+    def createFullCorpus(self, conversations):
+        """Extract all data from the given vocabulary.
+        Save the data on disk. Note that the entire corpus is pre-processed
+        without restriction on the sentence lenght or vocab size.
         """
         # Add standard tokens
         self.padToken = self.getWordId('<pad>')  # Padding (Warning: first things to add > id=0 !!)
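
filterFromFull() is only a stub at this point (note it is declared without self at this stage) and its two planned steps live in the comments above. Purely as a sketch of where that TODO could go, and not the project's eventual implementation, here is a standalone version of those two steps over plain structures mirroring trainingSamples, id2word and idCount; the function name, its parameters and the special-token handling are assumptions.

def filter_full_corpus(training_samples, id2word, id_count, max_length, filter_vocab,
                       special_words=('<pad>', '<go>', '<eos>', '<unknown>')):
    """Return (filtered_samples, new_word2id, new_id2word)."""
    # Step 1: keep only QA pairs whose question and answer fit in max_length,
    # decrementing the counts of words that only occurred in dropped pairs
    kept = []
    for question, answer in training_samples:
        if len(question) <= max_length and len(answer) <= max_length:
            kept.append((question, answer))
        else:
            for word_id in question + answer:
                id_count[word_id] -= 1

    # Step 2: build a compact id space; special tokens and words used more than
    # filter_vocab times keep an id, everything else collapses onto <unknown>
    new_word2id, new_id2word, remap = {}, {}, {}
    for old_id in sorted(id2word):  # <pad> is added first, so its id stays 0
        word = id2word[old_id]
        if word in special_words or id_count.get(old_id, 0) > filter_vocab:
            new_id = len(new_word2id)
            new_word2id[word] = new_id
            new_id2word[new_id] = word
            remap[old_id] = new_id

    # Re-encode the kept sentences with the new ids
    unknown_id = new_word2id['<unknown>']
    kept = [([remap.get(w, unknown_id) for w in question],
             [remap.get(w, unknown_id) for w in answer])
            for question, answer in kept]
    return kept, new_word2id, new_id2word

Decrementing the counts in the first step is what makes the frequency threshold of the second step reflect only the sentences that were actually kept, which seems to be the intent of the "Decrement word count for unused sentences ?" comment.
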
@@ -342,7 +365,7 @@ def extractText(self, line, isTarget=False):
             tokens = nltk.word_tokenize(sentencesToken[i])

             # If the total length is not too big, we still can add one more sentence
-            if len(words) + len(tokens) <= self.args.maxLength:
+            if len(words) + len(tokens) <= self.args.maxLength:  # TODO: Filter don't happen here
                 tempWords = []
                 for token in tokens:
                     tempWords.append(self.getWordId(token))  # Create the vocabulary and the training sentences
