 from chatbot.corpus.ubuntudata import UbuntuData
 from chatbot.corpus.lightweightdata import LightweightData

+
 class Batch:
     """Struct containing batches info
     """
@@ -74,8 +75,12 @@ def __init__(self, args):

         # Path variables
         self.corpusDir = os.path.join(self.args.rootDir, 'data', self.args.corpus)
-        self.samplesDir = os.path.join(self.args.rootDir, 'data/samples/')
-        self.samplesName = self._constructName()
+        basePath = self._constructBasePath()
+        self.fullSamplesPath = basePath + '.pkl'  # Full sentences length/vocab
+        self.filteredSamplesPath = basePath + '-length{}-filter{}.pkl'.format(
+            self.args.maxLength,
+            self.args.filterVocab,
+        )  # Sentences/vocab filtered for this model

         self.padToken = -1  # Padding
         self.goToken = -1  # Start of sequence
@@ -86,23 +91,24 @@ def __init__(self, args):

         self.word2id = {}
         self.id2word = {}  # For a rapid conversion
+        self.idCount = {}  # Useful to filter the words

-        self.loadCorpus(self.samplesDir)
+        self.loadCorpus()

         # Plot some stats:
         print('Loaded {}: {} words, {} QA'.format(self.args.corpus, len(self.word2id), len(self.trainingSamples)))

         if self.args.playDataset:
             self.playDataset()

-    def _constructName(self):
-        """Return the name of the dataset that the program should use with the current parameters.
-        Computed from the base name, the given tag (self.args.datasetTag) and the sentence length
+    def _constructBasePath(self):
+        """Return the base path prefix of the current dataset
101106 """
102- baseName = 'dataset-{}' .format (self .args .corpus )
107+ path = os .path .join (self .args .rootDir , 'data/samples/' )
108+ path += 'dataset-{}' .format (self .args .corpus )
103109 if self .args .datasetTag :
104- baseName += '-' + self .args .datasetTag
105- return '{}-{}.pkl' . format ( baseName , self . args . maxLength )
110+ path += '-' + self .args .datasetTag
111+ return path
106112
107113 def makeLighter (self , ratioDataset ):
108114 """Only keep a small fraction of the dataset, given by the ratio
@@ -141,6 +147,9 @@ def _createBatch(self, samples):
             if not self.args.test and self.args.autoEncode:  # Autoencode: use either the question or answer for both input and output
                 k = random.randint(0, 1)
                 sample = (sample[k], sample[k])
+            # TODO: Why is this re-processed at each epoch? We could precompute it
+            # once and reuse it every time. It is not the bottleneck, so it won't change
+            # much, and any precomputation should stay compatible with autoEncode etc.
             batch.encoderSeqs.append(list(reversed(sample[0])))  # Reverse inputs (and not outputs), little trick as defined on the original seq2seq paper
             batch.decoderSeqs.append([self.goToken] + sample[1] + [self.eosToken])  # Add the <go> and <eos> tokens
             batch.targetSeqs.append(batch.decoderSeqs[-1][1:])  # Same as decoder, but shifted to the left (ignore the <go>)
@@ -149,6 +158,7 @@ def _createBatch(self, samples):
             assert len(batch.encoderSeqs[i]) <= self.args.maxLengthEnco
             assert len(batch.decoderSeqs[i]) <= self.args.maxLengthDeco

+            # TODO: Should use tf batch function to automatically add padding and batch samples
             # Add padding & define weight
             batch.encoderSeqs[i] = [self.padToken] * (self.args.maxLengthEnco - len(batch.encoderSeqs[i])) + batch.encoderSeqs[i]  # Left padding for the input
             batch.weights.append([1.0] * len(batch.targetSeqs[i]) + [0.0] * (self.args.maxLengthDeco - len(batch.targetSeqs[i])))
@@ -204,6 +214,8 @@ def genNextSamples():
             for i in range(0, self.getSampleSize(), self.args.batchSize):
                 yield self.trainingSamples[i:min(i + self.args.batchSize, self.getSampleSize())]

+        # TODO: Should replace this with a generator (better: with a tf.queue)
+
         for samples in genNextSamples():
             batch = self._createBatch(samples)
             batches.append(batch)
@@ -223,56 +235,54 @@ def getVocabularySize(self):
223235 """
224236 return len (self .word2id )
225237
226- def loadCorpus (self , dirName ):
238+ def loadCorpus (self ):
227239 """Load/create the conversations data
228- Args:
229- dirName (str): The directory where to load/save the model
230240 """
231- datasetExist = False
232- if os .path .exists (os .path .join (dirName , self .samplesName )):
233- datasetExist = True
234-
241+ datasetExist = os .path .isfile (self .filteredSamplesPath )
235242 if not datasetExist : # First time we load the database: creating all files
236243 print ('Training samples not found. Creating dataset...' )
237244
238- optional = ''
239- if self .args .corpus == 'lightweight' and not self .args .datasetTag :
240- raise ValueError ('Use the --datasetTag to define the lightweight file to use.' )
241- else :
242- optional = '/' + self .args .datasetTag # HACK: Forward the filename
+            datasetExist = os.path.isfile(self.fullSamplesPath)  # Try to construct the dataset from the preprocessed entry
+            if not datasetExist:
+                print('Constructing full dataset...')

-            # Corpus creation
-            corpusData = TextData.availableCorpus[self.args.corpus](self.corpusDir + optional)
-            self.createCorpus(corpusData.getConversations())
+                optional = ''
+                if self.args.corpus == 'lightweight' and not self.args.datasetTag:
+                    raise ValueError('Use the --datasetTag to define the lightweight file to use.')
+                else:
+                    optional = '/' + self.args.datasetTag  # HACK: Forward the filename
+
+                # Corpus creation
+                corpusData = TextData.availableCorpus[self.args.corpus](self.corpusDir + optional)
+                self.createFullCorpus(corpusData.getConversations())
+
+            print('Filtering words...')
+            self.filterFromFull()  # Extract the sub vocabulary for the given maxLength and filterVocab

             # Saving
             print('Saving dataset...')
-            self.saveDataset(dirName)  # Saving tf samples
+            self.saveDataset()  # Saving tf samples
         else:
-            self.loadDataset(dirName)
+            self.loadDataset()

         assert self.padToken == 0

-    def saveDataset(self, dirName):
+    def saveDataset(self):
         """Save samples to file
-        Args:
-            dirName (str): The directory where to load/save the model
         """

-        with open(os.path.join(dirName, self.samplesName), 'wb') as handle:
+        with open(self.filteredSamplesPath, 'wb') as handle:
             data = {  # Warning: If adding something here, also modifying loadDataset
                 'word2id': self.word2id,
                 'id2word': self.id2word,
                 'trainingSamples': self.trainingSamples
             }
             pickle.dump(data, handle, -1)  # Using the highest protocol available

-    def loadDataset(self, dirName):
+    def loadDataset(self):
         """Load samples from file
-        Args:
-            dirName (str): The directory where to load the model
         """
-        dataset_path = os.path.join(dirName, self.samplesName)
+        dataset_path = self.filteredSamplesPath
         print('Loading dataset from {}'.format(dataset_path))
         with open(dataset_path, 'rb') as handle:
             data = pickle.load(handle)  # Warning: If adding something here, also modifying saveDataset
@@ -285,8 +295,21 @@ def loadDataset(self, dirName):
         self.eosToken = self.word2id['<eos>']
         self.unknownToken = self.word2id['<unknown>']  # Restore special words

-    def createCorpus(self, conversations):
-        """Extract all data from the given vocabulary
+    def filterFromFull(self):
+        """Load the pre-processed full corpus and filter the vocabulary / sentences
+        to match the given model options
+        """
+        # 2 steps:
+        #  1: first, iterate over all samples and add the sentences to the
+        #     training set. Decrement the word count for unused sentences?
+        #  2: then, iterate over the word counts and compute the new matching ids (can be unknown),
+        #     re-iterate over the entire dataset to update the ids.
+        pass  # TODO
+
+    def createFullCorpus(self, conversations):
+        """Extract all data from the given vocabulary.
+        Save the data on disk. Note that the entire corpus is pre-processed
+        without restriction on the sentence length or vocab size.
290313 """
291314 # Add standard tokens
292315 self .padToken = self .getWordId ('<pad>' ) # Padding (Warning: first things to add > id=0 !!)
@@ -342,7 +365,7 @@ def extractText(self, line, isTarget=False):
             tokens = nltk.word_tokenize(sentencesToken[i])

             # If the total length is not too big, we still can add one more sentence
-            if len(words) + len(tokens) <= self.args.maxLength:
+            if len(words) + len(tokens) <= self.args.maxLength:  # TODO: Filtering should not happen here
                 tempWords = []
                 for token in tokens:
                     tempWords.append(self.getWordId(token))  # Create the vocabulary and the training sentences
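
Below is a minimal sketch of how the new filterFromFull() stub might implement the two-step plan described in its comments. It is only an illustration, not part of this commit, and it rests on two assumptions that would need checking against the rest of the code: that self.idCount maps a word id to its number of occurrences (filled while building the full corpus), and that self.args.filterVocab is the minimum count a word needs to stay in the vocabulary.

    def filterFromFull(self):
        """Sketch only: keep the sentences/words compatible with the current model options.
        Assumes self.idCount[wordId] holds occurrence counts and that
        self.args.filterVocab is a minimum occurrence threshold.
        """
        # Step 1: drop the samples whose sentences exceed maxLength and
        # decrement the counts of the words they contained
        newSamples = []
        for inputWords, targetWords in self.trainingSamples:
            if len(inputWords) <= self.args.maxLength and len(targetWords) <= self.args.maxLength:
                newSamples.append([inputWords, targetWords])
            else:
                for wordId in inputWords + targetWords:
                    self.idCount[wordId] -= 1
        self.trainingSamples = newSamples

        # Step 2: rebuild the vocabulary, keeping the special tokens and every
        # word seen at least filterVocab times; everything else maps to <unknown>
        specialTokens = {self.padToken, self.goToken, self.eosToken, self.unknownToken}
        newMapping = {}  # Old id -> new id (or <unknown>)
        newWord2id = {}
        newId2word = {}
        for oldId, word in sorted(self.id2word.items()):  # Special tokens (ids 0-3) come first, so they keep their ids
            if oldId in specialTokens or self.idCount.get(oldId, 0) >= self.args.filterVocab:
                newId = len(newWord2id)
                newMapping[oldId] = newId
                newWord2id[word] = newId
                newId2word[newId] = word
            else:
                newMapping[oldId] = self.unknownToken

        # Rewrite every remaining sample with the new ids
        self.trainingSamples = [
            [[newMapping[w] for w in inputWords], [newMapping[w] for w in targetWords]]
            for inputWords, targetWords in self.trainingSamples
        ]
        self.word2id = newWord2id
        self.id2word = newId2word

Because the special tokens are added first in createFullCorpus, keeping them at the head of the remapping preserves their ids, so the assert self.padToken == 0 in loadCorpus() still holds after filtering.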