Skip to content

Commit 1448018

Browse files
committed
the trainer looks bug free
1 parent 5395553 commit 1448018

File tree

2 files changed

+10
-14
lines changed

2 files changed

+10
-14
lines changed

bmstparser/src/mstlstm.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -211,12 +211,6 @@ def __evaluateLabel(self, sentence, i, j):
211211

212212
return output.data.numpy()[0], output[0]
213213

214-
def Save(self, filename):
215-
self.model.save(filename)
216-
217-
def Load(self, filename):
218-
self.model.load(filename)
219-
220214
def predict(self, sentence):
221215
for entry in sentence:
222216
wordvec = self.wlookup(scalar(int(self.vocab.get(entry.norm, 0)))) if self.wdims > 0 else None
@@ -269,7 +263,7 @@ def predict(self, sentence):
269263
scores, exprs = self.__evaluateLabel(sentence, head, modifier + 1)
270264
sentence[modifier + 1].pred_relation = self.irels[max(enumerate(scores), key=itemgetter(1))[0]]
271265

272-
def get_loss(self, sentence, errs, lerrs):
266+
def forward(self, sentence, errs, lerrs):
273267

274268
for entry in sentence:
275269
c = float(self.wordsCount.get(entry.norm, 0))
@@ -358,6 +352,9 @@ def save(self, fn):
358352
torch.save(self.model.state_dict(), tmp)
359353
shutil.move(tmp, fn)
360354

355+
def load(self, fn):
356+
self.model.load_state_dict(torch.load(fn))
357+
361358
def train(self, conll_path):
362359
print torch.__version__
363360
batch = 1
@@ -384,7 +381,7 @@ def train(self, conll_path):
384381
etotal = 0
385382

386383
conll_sentence = [entry for entry in sentence if isinstance(entry, utils.ConllEntry)]
387-
e = self.model.get_loss(conll_sentence, errs, lerrs)
384+
e = self.model.forward(conll_sentence, errs, lerrs)
388385
eerrors += e
389386
eloss += e
390387
mloss += e

bmstparser/src/parser.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,22 +45,21 @@
4545

4646
print 'Initializing lstm mstparser:'
4747
parser = mstlstm.MSTParserLSTM(words, pos, rels, w2i, stored_opt)
48-
49-
parser.Load(options.model)
48+
parser.load(options.model)
5049
conllu = (os.path.splitext(options.conll_test.lower())[1] == '.conllu')
51-
tespath = os.path.join(options.output, 'test_pred.conll' if not conllu else 'test_pred.conllu')
50+
testpath = os.path.join(options.output, 'test_pred.conll' if not conllu else 'test_pred.conllu')
5251

5352
ts = time.time()
5453
test_res = list(parser.predict(options.conll_test))
5554
te = time.time()
5655
print 'Finished predicting test.', te - ts, 'seconds.'
57-
utils.write_conll(tespath, test_res)
56+
utils.write_conll(testpath, test_res)
5857

5958
if not conllu:
60-
os.system('perl src/utils/eval.pl -g ' + options.conll_test + ' -s ' + tespath + ' > ' + tespath + '.txt')
59+
os.system('perl src/utils/eval.pl -g ' + options.conll_test + ' -s ' + testpath + ' > ' + testpath + '.txt')
6160
else:
6261
os.system(
63-
'python src/utils/evaluation_script/conll17_ud_eval.py -v -w src/utils/evaluation_script/weights.clas ' + options.conll_test + ' ' + tespath + ' > ' + testpath + '.txt')
62+
'python src/utils/evaluation_script/conll17_ud_eval.py -v -w src/utils/evaluation_script/weights.clas ' + options.conll_test + ' ' + testpath + ' > ' + testpath + '.txt')
6463
else:
6564
print 'Preparing vocab'
6665
words, w2i, pos, rels = utils.vocab(options.conll_train)

0 commit comments

Comments
 (0)