Skip to content

Commit 26d3c4f

Browse files
committed
transition to pytorch
1 parent fe21745 commit 26d3c4f

File tree

2 files changed

+83
-75
lines changed

2 files changed

+83
-75
lines changed

bmstparser/src/mstlstm.py

Lines changed: 81 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,25 @@
11
import torch
22
import torch.nn as nn
3+
import torch.nn.functional as F
4+
from torch.nn.init import *
5+
from torch.autograd import Variable
36
from utils import read_conll, write_conll
47
from operator import itemgetter
58
import utils, time, random, decoder
69
import numpy as np
710

811

12+
def Parameter(shape, init=xavier_uniform):
13+
return Variable(init(torch.Tensor(*shape)), requires_grad=True)
14+
15+
916
class MSTParserLSTMModel(nn.Module):
1017
def __init__(self, vocab, pos, rels, w2i, options):
18+
super(MSTParserLSTMModel, self).__init__()
1119
random.seed(1)
12-
self.activations = {'tanh': nn.Tanh(), 'sigmoid': nn.Sigmoid(), 'relu': nn.ReLU(),
13-
#Not yet supporting tanh3
14-
#'tanh3': (lambda x: nn.Tanh()(cwise_multiply(cwise_multiply(x, x), x)))
20+
self.activations = {'tanh': F.tanh, 'sigmoid': F.sigmoid, 'relu': F.relu,
21+
# Not yet supporting tanh3
22+
# 'tanh3': (lambda x: nn.Tanh()(cwise_multiply(cwise_multiply(x, x), x)))
1523
}
1624
self.activation = self.activations[options.activation]
1725

@@ -56,12 +64,13 @@ def __init__(self, vocab, pos, rels, w2i, options):
5664
self.bbuilders = [nn.LSTMCell(1, self.ldims * 2, self.ldims),
5765
nn.LSTMCell(1, self.ldims * 2, self.ldims)]
5866
elif self.layers > 0:
67+
assert self.layers == 1, 'Not yet support deep LSTM'
5968
self.builders = [
6069
nn.LSTMCell(self.layers, self.wdims + self.pdims + self.edim, self.ldims),
6170
nn.LSTMCell(self.layers, self.wdims + self.pdims + self.edim, self.ldims)]
6271
else:
63-
self.builders = [SimpleRNNBuilder(1, self.wdims + self.pdims + self.edim, self.ldims, self.model),
64-
SimpleRNNBuilder(1, self.wdims + self.pdims + self.edim, self.ldims, self.model)]
72+
self.builders = [nn.RNNCell(self.wdims + self.pdims + self.edim, self.ldims),
73+
nn.RNNCell(self.wdims + self.pdims + self.edim, self.ldims)]
6574

6675
self.hidden_units = options.hidden_units
6776
self.hidden2_units = options.hidden2_units
@@ -72,34 +81,31 @@ def __init__(self, vocab, pos, rels, w2i, options):
7281
self.vocab['*INITIAL*'] = 2
7382
self.pos['*INITIAL*'] = 2
7483

75-
self.wlookup = self.model.add_lookup_parameters((len(vocab) + 3, self.wdims))
76-
self.plookup = self.model.add_lookup_parameters((len(pos) + 3, self.pdims))
77-
self.rlookup = self.model.add_lookup_parameters((len(rels), self.rdims))
84+
self.wlookup = nn.Embedding(len(vocab) + 3, self.wdims)
85+
self.plookup = nn.Embedding(len(pos) + 3, self.pdims)
86+
self.rlookup = nn.Embedding(len(rels), self.rdims)
7887

79-
self.hidLayerFOH = self.model.add_parameters((self.hidden_units, self.ldims * 2))
80-
self.hidLayerFOM = self.model.add_parameters((self.hidden_units, self.ldims * 2))
81-
self.hidBias = self.model.add_parameters((self.hidden_units))
88+
self.hidLayerFOH = Parameter((self.hidden_units, self.ldims * 2))
89+
self.hidLayerFOM = Parameter((self.hidden_units, self.ldims * 2))
90+
self.hidBias = Parameter((self.hidden_units))
8291

83-
self.hid2Layer = self.model.add_parameters((self.hidden2_units, self.hidden_units))
84-
self.hid2Bias = self.model.add_parameters((self.hidden2_units))
92+
self.hid2Layer = Parameter((self.hidden2_units, self.hidden_units))
93+
self.hid2Bias = Parameter((self.hidden2_units))
8594

86-
self.outLayer = self.model.add_parameters(
95+
self.outLayer = Parameter(
8796
(1, self.hidden2_units if self.hidden2_units > 0 else self.hidden_units))
8897

8998
if self.labelsFlag:
90-
self.rhidLayerFOH = self.model.add_parameters((self.hidden_units, 2 * self.ldims))
91-
self.rhidLayerFOM = self.model.add_parameters((self.hidden_units, 2 * self.ldims))
92-
self.rhidBias = self.model.add_parameters((self.hidden_units))
99+
self.rhidLayerFOH = Parameter((self.hidden_units, 2 * self.ldims))
100+
self.rhidLayerFOM = Parameter((self.hidden_units, 2 * self.ldims))
101+
self.rhidBias = Parameter((self.hidden_units))
93102

94-
self.rhid2Layer = self.model.add_parameters((self.hidden2_units, self.hidden_units))
95-
self.rhid2Bias = self.model.add_parameters((self.hidden2_units))
103+
self.rhid2Layer = Parameter((self.hidden2_units, self.hidden_units))
104+
self.rhid2Bias = Parameter((self.hidden2_units))
96105

97-
self.routLayer = self.model.add_parameters(
106+
self.routLayer = Parameter(
98107
(len(self.irels), self.hidden2_units if self.hidden2_units > 0 else self.hidden_units))
99-
self.routBias = self.model.add_parameters((len(self.irels)))
100-
101-
class MSTParserLSTM:
102-
def __init__(self, vocab, pos, rels, w2i, options):
108+
self.routBias = Parameter((len(self.irels)))
103109

104110
def __getExpr(self, sentence, i, j, train):
105111

@@ -146,68 +152,70 @@ def Save(self, filename):
146152
def Load(self, filename):
147153
self.model.load(filename)
148154

149-
def Predict(self, conll_path):
150-
with open(conll_path, 'r') as conllFP:
151-
for iSentence, sentence in enumerate(read_conll(conllFP)):
152-
conll_sentence = [entry for entry in sentence if isinstance(entry, utils.ConllEntry)]
155+
def predict(self, sentence):
156+
for entry in sentence:
157+
wordvec = self.wlookup[int(self.vocab.get(entry.norm, 0))] if self.wdims > 0 else None
158+
posvec = self.plookup[int(self.pos[entry.pos])] if self.pdims > 0 else None
159+
evec = self.elookup[int(self.extrnd.get(entry.form, self.extrnd.get(entry.norm,
160+
0)))] if self.external_embedding is not None else None
161+
entry.vec = concatenate(filter(None, [wordvec, posvec, evec]))
153162

154-
for entry in conll_sentence:
155-
wordvec = self.wlookup[int(self.vocab.get(entry.norm, 0))] if self.wdims > 0 else None
156-
posvec = self.plookup[int(self.pos[entry.pos])] if self.pdims > 0 else None
157-
evec = self.elookup[int(self.extrnd.get(entry.form, self.extrnd.get(entry.norm,
158-
0)))] if self.external_embedding is not None else None
159-
entry.vec = concatenate(filter(None, [wordvec, posvec, evec]))
163+
entry.lstms = [entry.vec, entry.vec]
164+
entry.headfov = None
165+
entry.modfov = None
160166

161-
entry.lstms = [entry.vec, entry.vec]
162-
entry.headfov = None
163-
entry.modfov = None
167+
entry.rheadfov = None
168+
entry.rmodfov = None
164169

165-
entry.rheadfov = None
166-
entry.rmodfov = None
170+
if self.blstmFlag:
171+
lstm_forward = self.builders[0].initial_state()
172+
lstm_backward = self.builders[1].initial_state()
167173

168-
if self.blstmFlag:
169-
lstm_forward = self.builders[0].initial_state()
170-
lstm_backward = self.builders[1].initial_state()
174+
for entry, rentry in zip(sentence, reversed(sentence)):
175+
lstm_forward = lstm_forward.add_input(entry.vec)
176+
lstm_backward = lstm_backward.add_input(rentry.vec)
171177

172-
for entry, rentry in zip(conll_sentence, reversed(conll_sentence)):
173-
lstm_forward = lstm_forward.add_input(entry.vec)
174-
lstm_backward = lstm_backward.add_input(rentry.vec)
178+
entry.lstms[1] = lstm_forward.output()
179+
rentry.lstms[0] = lstm_backward.output()
175180

176-
entry.lstms[1] = lstm_forward.output()
177-
rentry.lstms[0] = lstm_backward.output()
181+
if self.bibiFlag:
182+
for entry in sentence:
183+
entry.vec = concatenate(entry.lstms)
178184

179-
if self.bibiFlag:
180-
for entry in conll_sentence:
181-
entry.vec = concatenate(entry.lstms)
185+
blstm_forward = self.bbuilders[0].initial_state()
186+
blstm_backward = self.bbuilders[1].initial_state()
182187

183-
blstm_forward = self.bbuilders[0].initial_state()
184-
blstm_backward = self.bbuilders[1].initial_state()
188+
for entry, rentry in zip(sentence, reversed(sentence)):
189+
blstm_forward = blstm_forward.add_input(entry.vec)
190+
blstm_backward = blstm_backward.add_input(rentry.vec)
185191

186-
for entry, rentry in zip(conll_sentence, reversed(conll_sentence)):
187-
blstm_forward = blstm_forward.add_input(entry.vec)
188-
blstm_backward = blstm_backward.add_input(rentry.vec)
192+
entry.lstms[1] = blstm_forward.output()
193+
rentry.lstms[0] = blstm_backward.output()
189194

190-
entry.lstms[1] = blstm_forward.output()
191-
rentry.lstms[0] = blstm_backward.output()
195+
scores, exprs = self.__evaluate(sentence, True)
196+
heads = decoder.parse_proj(scores)
192197

193-
scores, exprs = self.__evaluate(conll_sentence, True)
194-
heads = decoder.parse_proj(scores)
198+
for entry, head in zip(sentence, heads):
199+
entry.pred_parent_id = head
200+
entry.pred_relation = '_'
195201

196-
for entry, head in zip(conll_sentence, heads):
197-
entry.pred_parent_id = head
198-
entry.pred_relation = '_'
202+
if self.labelsFlag:
203+
for modifier, head in enumerate(heads[1:]):
204+
scores, exprs = self.__evaluateLabel(sentence, head, modifier + 1)
205+
sentence[modifier + 1].pred_relation = self.irels[
206+
max(enumerate(scores), key=itemgetter(1))[0]]
199207

200-
dump = False
201208

202-
if self.labelsFlag:
203-
for modifier, head in enumerate(heads[1:]):
204-
scores, exprs = self.__evaluateLabel(conll_sentence, head, modifier + 1)
205-
conll_sentence[modifier + 1].pred_relation = self.irels[
206-
max(enumerate(scores), key=itemgetter(1))[0]]
209+
class MSTTrainer:
210+
def __init__(self, vocab, pos, rels, w2i, options):
211+
self.model = MSTParserLSTMModel(vocab, pos, rels, w2i, options)
207212

208-
renew_cg()
209-
if not dump:
210-
yield sentence
213+
def Predict(self, conll_path):
214+
with open(conll_path, 'r') as conllFP:
215+
for iSentence, sentence in enumerate(read_conll(conllFP)):
216+
conll_sentence = [entry for entry in sentence if isinstance(entry, utils.ConllEntry)]
217+
self.model.predict(conll_sentence)
218+
yield conll_sentence
211219

212220
def Train(self, conll_path):
213221
errors = 0
@@ -249,7 +257,7 @@ def Train(self, conll_path):
249257

250258
if self.external_embedding is not None:
251259
evec = self.elookup[self.extrnd.get(entry.form, self.extrnd.get(entry.norm, 0)) if (
252-
dropFlag or (random.random() < 0.5)) else 0]
260+
dropFlag or (random.random() < 0.5)) else 0]
253261
entry.vec = concatenate(filter(None, [wordvec, posvec, evec]))
254262

255263
entry.lstms = [entry.vec, entry.vec]
@@ -293,7 +301,7 @@ def Train(self, conll_path):
293301
rscores, rexprs = self.__evaluateLabel(conll_sentence, head, modifier + 1)
294302
goldLabelInd = self.rels[conll_sentence[modifier + 1].relation]
295303
wrongLabelInd = \
296-
max(((l, scr) for l, scr in enumerate(rscores) if l != goldLabelInd), key=itemgetter(1))[0]
304+
max(((l, scr) for l, scr in enumerate(rscores) if l != goldLabelInd), key=itemgetter(1))[0]
297305
if rscores[goldLabelInd] < rscores[wrongLabelInd] + 1:
298306
lerrs.append(rexprs[wrongLabelInd] - rexprs[goldLabelInd])
299307

bmstparser/src/parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
stored_opt.external_embedding = options.external_embedding
4444

4545
print 'Initializing lstm mstparser:'
46-
parser = mstlstm.MSTParserLSTM(words, pos, rels, w2i, stored_opt)
46+
parser = mstlstm.MSTTrainer(words, pos, rels, w2i, stored_opt)
4747

4848
parser.Load(options.model)
4949
conllu = (os.path.splitext(options.conll_test.lower())[1] == '.conllu')
@@ -69,7 +69,7 @@
6969
print 'Finished collecting vocab'
7070

7171
print 'Initializing lstm mstparser:'
72-
parser = mstlstm.MSTParserLSTM(words, pos, rels, w2i, options)
72+
parser = mstlstm.MSTTrainer(words, pos, rels, w2i, options)
7373

7474
for epoch in xrange(options.epochs):
7575
print 'Starting epoch', epoch

0 commit comments

Comments
 (0)