
Commit 13c58c9

Update config saving/restoring
1 parent: fa451ce

File tree

1 file changed: +15 −9 lines

chatbot/chatbot.py

Lines changed: 15 additions & 9 deletions
@@ -70,7 +70,7 @@ def __init__(self):
         self.MODEL_NAME_BASE = 'model'
         self.MODEL_EXT = '.ckpt'
         self.CONFIG_FILENAME = 'params.ini'
-        self.CONFIG_VERSION = '0.4'
+        self.CONFIG_VERSION = '0.5'
         self.TEST_IN_NAME = 'data/test/samples.txt'
         self.TEST_OUT_SUFFIX = '_predictions.txt'
         self.SENTENCES_PREFIX = ['Q: ', 'A: ']
@@ -113,7 +113,7 @@ def parseArgs(args):
         datasetArgs.add_argument('--datasetTag', type=str, default='', help='add a tag to the dataset (file where to load the vocabulary and the precomputed samples, not the original corpus). Useful to manage multiple versions. Also used to define the file used for the lightweight format.') # The samples are computed from the corpus if it does not exist already. They are saved in \'data/samples/\'
         datasetArgs.add_argument('--ratioDataset', type=float, default=1.0, help='ratio of dataset used to avoid using the whole dataset') # Not implemented, useless ?
         datasetArgs.add_argument('--maxLength', type=int, default=10, help='maximum length of the sentence (for input and output), define number of maximum step of the RNN')
-        datasetArgs.add_argument('--lightweightFile', type=str, default=None, help='file containing our lightweight-formatted corpus')
+        datasetArgs.add_argument('--filterVocab', type=int, default=1, help='remove rarely used words (by default words used only once). 0 to keep all words.')
 
         # Network options (Warning: if modifying something here, also make the change on save/loadParams() )
         nnArgs = parser.add_argument_group('Network options', 'architecture related option')
@@ -536,18 +536,20 @@ def loadModelParams(self):
 
         # Restoring the parameters
         self.globStep = config['General'].getint('globStep')
-        self.args.maxLength = config['General'].getint('maxLength') # We need to restore the model length because of the textData associated and the vocabulary size (TODO: Compatibility mode between different maxLength)
         self.args.watsonMode = config['General'].getboolean('watsonMode')
         self.args.autoEncode = config['General'].getboolean('autoEncode')
         self.args.corpus = config['General'].get('corpus')
-        self.args.datasetTag = config['General'].get('datasetTag', '')
-        self.args.embeddingSource = config['General'].get('embeddingSource', '')
+
+        self.args.datasetTag = config['Dataset'].get('datasetTag')
+        self.args.maxLength = config['Dataset'].getint('maxLength') # We need to restore the model length because of the textData associated and the vocabulary size (TODO: Compatibility mode between different maxLength)
+        self.args.filterVocab = config['Dataset'].getint('filterVocab')
 
         self.args.hiddenSize = config['Network'].getint('hiddenSize')
         self.args.numLayers = config['Network'].getint('numLayers')
         self.args.softmaxSamples = config['Network'].getint('softmaxSamples')
         self.args.initEmbeddings = config['Network'].getboolean('initEmbeddings')
         self.args.embeddingSize = config['Network'].getint('embeddingSize')
+        self.args.embeddingSource = config['Network'].get('embeddingSource')
 
 
         # No restoring for training params, batch size or other non model dependent parameters
@@ -556,11 +558,12 @@ def loadModelParams(self):
         print()
         print('Warning: Restoring parameters:')
         print('globStep: {}'.format(self.globStep))
-        print('maxLength: {}'.format(self.args.maxLength))
         print('watsonMode: {}'.format(self.args.watsonMode))
         print('autoEncode: {}'.format(self.args.autoEncode))
         print('corpus: {}'.format(self.args.corpus))
         print('datasetTag: {}'.format(self.args.datasetTag))
+        print('maxLength: {}'.format(self.args.maxLength))
+        print('filterVocab: {}'.format(self.args.filterVocab))
         print('hiddenSize: {}'.format(self.args.hiddenSize))
         print('numLayers: {}'.format(self.args.numLayers))
         print('softmaxSamples: {}'.format(self.args.softmaxSamples))
@@ -585,19 +588,22 @@ def saveModelParams(self):
         config['General'] = {}
         config['General']['version'] = self.CONFIG_VERSION
         config['General']['globStep'] = str(self.globStep)
-        config['General']['maxLength'] = str(self.args.maxLength)
         config['General']['watsonMode'] = str(self.args.watsonMode)
         config['General']['autoEncode'] = str(self.args.autoEncode)
         config['General']['corpus'] = str(self.args.corpus)
-        config['General']['datasetTag'] = str(self.args.datasetTag)
-        config['General']['embeddingSource'] = str(self.args.embeddingSource)
+
+        config['Dataset'] = {}
+        config['Dataset']['datasetTag'] = str(self.args.datasetTag)
+        config['Dataset']['maxLength'] = str(self.args.maxLength)
+        config['Dataset']['filterVocab'] = str(self.args.filterVocab)
 
         config['Network'] = {}
         config['Network']['hiddenSize'] = str(self.args.hiddenSize)
         config['Network']['numLayers'] = str(self.args.numLayers)
         config['Network']['softmaxSamples'] = str(self.args.softmaxSamples)
         config['Network']['initEmbeddings'] = str(self.args.initEmbeddings)
         config['Network']['embeddingSize'] = str(self.args.embeddingSize)
+        config['Network']['embeddingSource'] = str(self.args.embeddingSource)
 
         # Keep track of the learning params (but without restoring them)
         config['Training (won\'t be restored)'] = {}
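
Taken together, the save/load hunks define the new v0.5 layout of params.ini: the dataset-dependent options (datasetTag, maxLength, and the new filterVocab) move out of [General] into a new [Dataset] section, and embeddingSource moves into [Network]. Below is a minimal round-trip sketch of that layout, assuming the dict-style section access in the diff comes from Python's standard configparser; the values are illustrative placeholders, not the project's defaults.

import configparser

# Sketch of the v0.5 params.ini layout implied by saveModelParams().
config = configparser.ConfigParser()
config['General'] = {
    'version': '0.5',     # matches the new self.CONFIG_VERSION
    'globStep': '0',
    'watsonMode': 'False',
    'autoEncode': 'False',
    'corpus': 'cornell',  # placeholder corpus name
}
config['Dataset'] = {     # new section introduced by this commit
    'datasetTag': '',
    'maxLength': '10',
    'filterVocab': '1',
}
config['Network'] = {
    'hiddenSize': '256',
    'numLayers': '2',
    'softmaxSamples': '0',
    'initEmbeddings': 'False',
    'embeddingSize': '64',
    'embeddingSource': '',  # moved here from [General]
}

with open('params.ini', 'w') as configFile:
    config.write(configFile)

# Reading back mirrors loadModelParams():
config = configparser.ConfigParser()
config.read('params.ini')
maxLength = config['Dataset'].getint('maxLength')      # -> 10
filterVocab = config['Dataset'].getint('filterVocab')  # -> 1

Note that configparser stores everything as strings, which is why saveModelParams() wraps each value in str() while loadModelParams() reads back through getint()/getboolean().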

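Because CONFIG_VERSION jumps from '0.4' to '0.5', a params.ini written by an older checkpoint no longer matches the layout the new loadModelParams() expects. A hypothetical upgrade helper, not part of this commit, that performs the same key moves as the diff might look like this:

import configparser

def upgradeConfigV4ToV5(path='params.ini'):
    """Hypothetical migration sketch: move keys the way this commit does.

    Assumes a v0.4 file, where [Network] already exists and datasetTag,
    maxLength, and embeddingSource still live under [General].
    """
    config = configparser.ConfigParser()
    config.read(path)
    if config['General'].get('version') != '0.4':
        return  # nothing to migrate

    # New [Dataset] section (configparser sections behave like dicts)
    config['Dataset'] = {
        'datasetTag': config['General'].pop('datasetTag', ''),
        'maxLength': config['General'].pop('maxLength'),
        'filterVocab': '1',  # new option: keep the default (drop words used once)
    }
    # embeddingSource now belongs to [Network]
    config['Network']['embeddingSource'] = config['General'].pop('embeddingSource', '')
    config['General']['version'] = '0.5'

    with open(path, 'w') as configFile:
        config.write(configFile)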