@@ -90,17 +90,20 @@ def __init__(self, args):
         self.trainingSamples = []  # 2d array containing each question and its answer [[input, target]]

         self.word2id = {}
-        self.id2word = {}  # For a rapid conversion
+        self.id2word = {}  # For a rapid conversion (TODO: Could replace dict by list)
         self.idCount = {}  # Useful to filter the words

         self.loadCorpus()

         # Print some stats:
-        print('Loaded {}: {} words, {} QA'.format(self.args.corpus, len(self.word2id), len(self.trainingSamples)))
+        self._printStats()

         if self.args.playDataset:
             self.playDataset()

+    def _printStats(self):
+        print('Loaded {}: {} words, {} QA'.format(self.args.corpus, len(self.word2id), len(self.trainingSamples)))
+
     def _constructBasePath(self):
         """Return the name of the base prefix of the current dataset
         """
@@ -255,39 +258,49 @@ def loadCorpus(self):
                 # Corpus creation
                 corpusData = TextData.availableCorpus[self.args.corpus](self.corpusDir + optional)
                 self.createFullCorpus(corpusData.getConversations())
+                self._printStats()
+                self.saveDataset(self.fullSamplesPath)
+            else:
+                self.loadDataset(self.fullSamplesPath)

             print('Filtering words...')
             self.filterFromFull()  # Extract the sub vocabulary for the given maxLength and filterVocab

             # Saving
             print('Saving dataset...')
-            self.saveDataset()  # Saving tf samples
+            self.saveDataset(self.filteredSamplesPath)  # Saving tf samples
         else:
-            self.loadDataset()
+            self.loadDataset(self.filteredSamplesPath)

         assert self.padToken == 0

-    def saveDataset(self):
+    def saveDataset(self, filename):
         """Save samples to file
+        Args:
+            filename (str): pickle filename
         """

-        with open(os.path.join(self.filteredSamplesPath), 'wb') as handle:
+        with open(os.path.join(filename), 'wb') as handle:
             data = {  # Warning: If adding something here, also modify loadDataset
                 'word2id': self.word2id,
                 'id2word': self.id2word,
+                'idCount': self.idCount,
                 'trainingSamples': self.trainingSamples
-                }
+            }
             pickle.dump(data, handle, -1)  # Using the highest protocol available

-    def loadDataset(self):
+    def loadDataset(self, filename):
         """Load samples from file
+        Args:
+            filename (str): pickle filename
         """
-        dataset_path = os.path.join(self.filteredSamplesPath)
+        dataset_path = os.path.join(filename)
         print('Loading dataset from {}'.format(dataset_path))
         with open(dataset_path, 'rb') as handle:
             data = pickle.load(handle)  # Warning: If adding something here, also modify saveDataset
             self.word2id = data['word2id']
             self.id2word = data['id2word']
+            self.idCount = data['idCount']
             self.trainingSamples = data['trainingSamples']

             self.padToken = self.word2id['<pad>']
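Note: saveDataset()/loadDataset() are now parameterized by the pickle path, so the same pair handles both the unfiltered corpus (fullSamplesPath) and the filtered one (filteredSamplesPath), and idCount travels alongside the vocabulary. A minimal, self-contained sketch of the round-trip these methods perform (the file name and contents are made up; only the dict keys mirror the code above):

    import pickle

    data = {
        'word2id': {'<pad>': 0, 'hello': 1},
        'id2word': {0: '<pad>', 1: 'hello'},
        'idCount': {1: 1},
        'trainingSamples': [[[1], [1]]],
    }
    with open('samples.pkl', 'wb') as handle:  # e.g. fullSamplesPath
        pickle.dump(data, handle, -1)  # -1 = highest protocol available
    with open('samples.pkl', 'rb') as handle:
        restored = pickle.load(handle)
    assert restored['word2id']['hello'] == 1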
@@ -305,6 +318,33 @@ def filterFromFull():
         # 2: then, iterate over word count and compute new matching ids (can be unknown),
         # re-iterate over the entire dataset to apply the new ids.
         pass  # TODO
+        words = []
+
+        # Extract sentences
+        sentencesToken = nltk.sent_tokenize(line)
+
+        # We add sentence by sentence until we reach the maximum length
+        for i in range(len(sentencesToken)):
+            # If question: we only keep the last sentences
+            # If answer: we only keep the first sentences
+            if not isTarget:
+                i = len(sentencesToken) - 1 - i
+
+            tokens = nltk.word_tokenize(sentencesToken[i])
+
+            # If the total length is not too big, we still can add one more sentence
+            if len(words) + len(tokens) <= self.args.maxLength:  # TODO: Filtering doesn't happen here yet
+                tempWords = []
+                for token in tokens:
+                    tempWords.append(self.getWordId(token))  # Create the vocabulary and the training sentences
+
+                if isTarget:
+                    words = words + tempWords
+                else:
+                    words = tempWords + words
+            else:
+                break  # We have already reached the max length
+

     def createFullCorpus(self, conversations):
         """Extract all data from the given vocabulary.
@@ -350,34 +390,22 @@ def extractText(self, line, isTarget=False):
         Return:
-            list<int>: the list of the word ids of the sentence
+            list<list<int>>: the list of sentences, each one a list of word ids
         """
-        words = []
+        sentences = []  # List[List[int]]

         # Extract sentences
         sentencesToken = nltk.sent_tokenize(line)

-        # We add sentence by sentence until we reach the maximum length
+        # Tokenize each sentence separately; length filtering now happens later
         for i in range(len(sentencesToken)):
-            # If question: we only keep the last sentences
-            # If answer: we only keep the first sentences
-            if not isTarget:
-                i = len(sentencesToken) - 1 - i
-
             tokens = nltk.word_tokenize(sentencesToken[i])

-            # If the total length is not too big, we still can add one more sentence
-            if len(words) + len(tokens) <= self.args.maxLength:  # TODO: Filtering doesn't happen here yet
-                tempWords = []
-                for token in tokens:
-                    tempWords.append(self.getWordId(token))  # Create the vocabulary and the training sentences
+            tempWords = []
+            for token in tokens:
+                tempWords.append(self.getWordId(token))  # Create the vocabulary and the training sentences

-                if isTarget:
-                    words = words + tempWords
-                else:
-                    words = tempWords + words
-            else:
-                break  # We have already reached the max length
+            sentences.append(tempWords)

-        return words
+        return sentences

     def getWordId(self, word, create=True):
         """Get the id of the word (and add it to the dictionary if not existing). If the word does not exist and
@@ -392,17 +420,19 @@ def getWordId(self, word, create=True):

         word = word.lower()  # Ignore case

+        # At inference, we simply look up the word
+        if not create:
+            wordId = self.word2id.get(word, self.unknownToken)
         # Get the id if the word already exists
-        wordId = self.word2id.get(word, -1)
-
+        elif word in self.word2id:
+            wordId = self.word2id[word]
+            self.idCount[wordId] += 1
         # If not, we create a new entry
-        if wordId == -1:
-            if create:
-                wordId = len(self.word2id)
-                self.word2id[word] = wordId
-                self.id2word[wordId] = word
-            else:
-                wordId = self.unknownToken
+        else:
+            wordId = len(self.word2id)
+            self.word2id[word] = wordId
+            self.id2word[wordId] = word
+            self.idCount[wordId] = 1

         return wordId

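Note: getWordId now keeps idCount up to date while the corpus is built (incrementing on every occurrence, initializing to 1 on creation) and skips that bookkeeping entirely at inference (create=False). A plausible downstream use, matching the maxLength/filterVocab comment in loadCorpus; the self.args.filterVocab name is assumed by analogy with self.args.maxLength:

    # Keep only words seen at least filterVocab times; filterFromFull
    # would remap every other id to the unknown token.
    kept_ids = {wid for wid, count in self.idCount.items()
                if count >= self.args.filterVocab}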