 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.init import *
+from torch.autograd import Variable
 from utils import read_conll, write_conll
 from operator import itemgetter
 import utils, time, random, decoder
 import numpy as np
 
 
+def Parameter(shape, init=xavier_uniform):
+    return Variable(init(torch.Tensor(*shape)), requires_grad=True)
+
+
 class MSTParserLSTMModel(nn.Module):
     def __init__(self, vocab, pos, rels, w2i, options):
+        super(MSTParserLSTMModel, self).__init__()
         random.seed(1)
-        self.activations = {'tanh': nn.Tanh(), 'sigmoid': nn.Sigmoid(), 'relu': nn.ReLU(),
-                            #Not yet supporting tanh3
-                            #'tanh3': (lambda x: nn.Tanh()(cwise_multiply(cwise_multiply(x, x), x)))
+        self.activations = {'tanh': F.tanh, 'sigmoid': F.sigmoid, 'relu': F.relu,
+                            # Not yet supporting tanh3
+                            # 'tanh3': (lambda x: nn.Tanh()(cwise_multiply(cwise_multiply(x, x), x)))
                             }
         self.activation = self.activations[options.activation]
 
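Note on the hunk above: the new `Parameter` helper returns a plain autograd `Variable`, not an `nn.Parameter`, so these tensors are never registered on the module and will not appear in `model.parameters()`; the optimizer has to be handed them explicitly. Also, `xavier_uniform` requires a tensor with at least two dimensions, so the 1-D bias shapes created later need a different init. A minimal registration-friendly sketch (an assumption, not part of this commit):

    def Parameter(shape, init=xavier_uniform):
        tensor = torch.Tensor(*shape)
        if tensor.dim() >= 2:
            init(tensor)      # xavier_uniform needs fan-in/fan-out, i.e. >= 2 dims
        else:
            tensor.zero_()    # zero-init 1-D biases instead
        return nn.Parameter(tensor)  # registered automatically when set as a module attribute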
@@ -56,12 +64,13 @@ def __init__(self, vocab, pos, rels, w2i, options):
             self.bbuilders = [nn.LSTMCell(self.ldims * 2, self.ldims),  # nn.LSTMCell(input_size, hidden_size)
                               nn.LSTMCell(self.ldims * 2, self.ldims)]
         elif self.layers > 0:
+            assert self.layers == 1, 'Deep LSTMs are not yet supported'
             self.builders = [
                 nn.LSTMCell(self.wdims + self.pdims + self.edim, self.ldims),
                 nn.LSTMCell(self.wdims + self.pdims + self.edim, self.ldims)]
         else:
-            self.builders = [SimpleRNNBuilder(1, self.wdims + self.pdims + self.edim, self.ldims, self.model),
-                             SimpleRNNBuilder(1, self.wdims + self.pdims + self.edim, self.ldims, self.model)]
+            self.builders = [nn.RNNCell(self.wdims + self.pdims + self.edim, self.ldims),
+                             nn.RNNCell(self.wdims + self.pdims + self.edim, self.ldims)]
 
         self.hidden_units = options.hidden_units
         self.hidden2_units = options.hidden2_units
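PyTorch's `nn.LSTMCell(input_size, hidden_size)` (and `nn.RNNCell`) replaces DyNet's `LSTMBuilder(layers, input_dim, hidden_dim, model)`, which is why the leading layer-count argument is dropped above. The cells also have no `initial_state()`/`add_input()`/`output()` API, so the builder-style calls still present in `predict` below will need manual stepping, roughly like this sketch (the `run_lstm` helper and its arguments are illustrative only):

    def run_lstm(cell, vecs, hidden_size):
        # vecs: list of (1, input_size) Variables, one per token
        h = Variable(torch.zeros(1, hidden_size))
        c = Variable(torch.zeros(1, hidden_size))
        outputs = []
        for v in vecs:
            h, c = cell(v, (h, c))  # LSTMCell(input, (h, c)) -> next (h, c)
            outputs.append(h)
        return outputs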
@@ -72,34 +81,31 @@ def __init__(self, vocab, pos, rels, w2i, options):
         self.vocab['*INITIAL*'] = 2
         self.pos['*INITIAL*'] = 2
 
-        self.wlookup = self.model.add_lookup_parameters((len(vocab) + 3, self.wdims))
-        self.plookup = self.model.add_lookup_parameters((len(pos) + 3, self.pdims))
-        self.rlookup = self.model.add_lookup_parameters((len(rels), self.rdims))
+        self.wlookup = nn.Embedding(len(vocab) + 3, self.wdims)
+        self.plookup = nn.Embedding(len(pos) + 3, self.pdims)
+        self.rlookup = nn.Embedding(len(rels), self.rdims)
 
-        self.hidLayerFOH = self.model.add_parameters((self.hidden_units, self.ldims * 2))
-        self.hidLayerFOM = self.model.add_parameters((self.hidden_units, self.ldims * 2))
-        self.hidBias = self.model.add_parameters((self.hidden_units))
+        self.hidLayerFOH = Parameter((self.hidden_units, self.ldims * 2))
+        self.hidLayerFOM = Parameter((self.hidden_units, self.ldims * 2))
+        self.hidBias = Parameter((self.hidden_units,))  # note the comma: Parameter unpacks a shape tuple
 
-        self.hid2Layer = self.model.add_parameters((self.hidden2_units, self.hidden_units))
-        self.hid2Bias = self.model.add_parameters((self.hidden2_units))
+        self.hid2Layer = Parameter((self.hidden2_units, self.hidden_units))
+        self.hid2Bias = Parameter((self.hidden2_units,))
 
-        self.outLayer = self.model.add_parameters(
+        self.outLayer = Parameter(
             (1, self.hidden2_units if self.hidden2_units > 0 else self.hidden_units))
 
         if self.labelsFlag:
-            self.rhidLayerFOH = self.model.add_parameters((self.hidden_units, 2 * self.ldims))
-            self.rhidLayerFOM = self.model.add_parameters((self.hidden_units, 2 * self.ldims))
-            self.rhidBias = self.model.add_parameters((self.hidden_units))
+            self.rhidLayerFOH = Parameter((self.hidden_units, 2 * self.ldims))
+            self.rhidLayerFOM = Parameter((self.hidden_units, 2 * self.ldims))
+            self.rhidBias = Parameter((self.hidden_units,))
 
-            self.rhid2Layer = self.model.add_parameters((self.hidden2_units, self.hidden_units))
-            self.rhid2Bias = self.model.add_parameters((self.hidden2_units))
+            self.rhid2Layer = Parameter((self.hidden2_units, self.hidden_units))
+            self.rhid2Bias = Parameter((self.hidden2_units,))
 
-            self.routLayer = self.model.add_parameters(
+            self.routLayer = Parameter(
                 (len(self.irels), self.hidden2_units if self.hidden2_units > 0 else self.hidden_units))
-            self.routBias = self.model.add_parameters((len(self.irels)))
-
-class MSTParserLSTM:
-    def __init__(self, vocab, pos, rels, w2i, options):
+            self.routBias = Parameter((len(self.irels),))
 
     def __getExpr(self, sentence, i, j, train):
 
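Unlike DyNet lookup parameters, an `nn.Embedding` is a module that is called on a `LongTensor` of indices rather than indexed with `[]`, so the `self.wlookup[...]` lookups kept in `predict` below will need updating. An illustrative sketch (the sizes are made up):

    wlookup = nn.Embedding(100, 50)
    idx = Variable(torch.LongTensor([3]))
    wordvec = wlookup(idx)        # shape (1, 50)
    row = wlookup.weight[3]       # direct single-row read, shape (50,)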
@@ -146,68 +152,70 @@ def Save(self, filename):
     def Load(self, filename):
         self.model.load(filename)
 
-    def Predict(self, conll_path):
-        with open(conll_path, 'r') as conllFP:
-            for iSentence, sentence in enumerate(read_conll(conllFP)):
-                conll_sentence = [entry for entry in sentence if isinstance(entry, utils.ConllEntry)]
+    def predict(self, sentence):
+        for entry in sentence:
+            wordvec = self.wlookup[int(self.vocab.get(entry.norm, 0))] if self.wdims > 0 else None
+            posvec = self.plookup[int(self.pos[entry.pos])] if self.pdims > 0 else None
+            evec = self.elookup[int(self.extrnd.get(entry.form, self.extrnd.get(entry.norm,
+                                                                                0)))] if self.external_embedding is not None else None
+            entry.vec = concatenate(filter(None, [wordvec, posvec, evec]))
 
-                for entry in conll_sentence:
-                    wordvec = self.wlookup[int(self.vocab.get(entry.norm, 0))] if self.wdims > 0 else None
-                    posvec = self.plookup[int(self.pos[entry.pos])] if self.pdims > 0 else None
-                    evec = self.elookup[int(self.extrnd.get(entry.form, self.extrnd.get(entry.norm,
-                                                                                        0)))] if self.external_embedding is not None else None
-                    entry.vec = concatenate(filter(None, [wordvec, posvec, evec]))
+            entry.lstms = [entry.vec, entry.vec]
+            entry.headfov = None
+            entry.modfov = None
 
-                    entry.lstms = [entry.vec, entry.vec]
-                    entry.headfov = None
-                    entry.modfov = None
+            entry.rheadfov = None
+            entry.rmodfov = None
 
-                    entry.rheadfov = None
-                    entry.rmodfov = None
+        if self.blstmFlag:
+            lstm_forward = self.builders[0].initial_state()
+            lstm_backward = self.builders[1].initial_state()
 
-                if self.blstmFlag:
-                    lstm_forward = self.builders[0].initial_state()
-                    lstm_backward = self.builders[1].initial_state()
+            for entry, rentry in zip(sentence, reversed(sentence)):
+                lstm_forward = lstm_forward.add_input(entry.vec)
+                lstm_backward = lstm_backward.add_input(rentry.vec)
 
-                    for entry, rentry in zip(conll_sentence, reversed(conll_sentence)):
-                        lstm_forward = lstm_forward.add_input(entry.vec)
-                        lstm_backward = lstm_backward.add_input(rentry.vec)
+                entry.lstms[1] = lstm_forward.output()
+                rentry.lstms[0] = lstm_backward.output()
 
-                        entry.lstms[1] = lstm_forward.output()
-                        rentry.lstms[0] = lstm_backward.output()
+        if self.bibiFlag:
+            for entry in sentence:
+                entry.vec = concatenate(entry.lstms)
 
-                if self.bibiFlag:
-                    for entry in conll_sentence:
-                        entry.vec = concatenate(entry.lstms)
+            blstm_forward = self.bbuilders[0].initial_state()
+            blstm_backward = self.bbuilders[1].initial_state()
 
-                    blstm_forward = self.bbuilders[0].initial_state()
-                    blstm_backward = self.bbuilders[1].initial_state()
+            for entry, rentry in zip(sentence, reversed(sentence)):
+                blstm_forward = blstm_forward.add_input(entry.vec)
+                blstm_backward = blstm_backward.add_input(rentry.vec)
 
-                    for entry, rentry in zip(conll_sentence, reversed(conll_sentence)):
-                        blstm_forward = blstm_forward.add_input(entry.vec)
-                        blstm_backward = blstm_backward.add_input(rentry.vec)
+                entry.lstms[1] = blstm_forward.output()
+                rentry.lstms[0] = blstm_backward.output()
 
-                        entry.lstms[1] = blstm_forward.output()
-                        rentry.lstms[0] = blstm_backward.output()
+        scores, exprs = self.__evaluate(sentence, True)
+        heads = decoder.parse_proj(scores)
 
-                scores, exprs = self.__evaluate(conll_sentence, True)
-                heads = decoder.parse_proj(scores)
+        for entry, head in zip(sentence, heads):
+            entry.pred_parent_id = head
+            entry.pred_relation = '_'
 
-                for entry, head in zip(conll_sentence, heads):
-                    entry.pred_parent_id = head
-                    entry.pred_relation = '_'
+        if self.labelsFlag:
+            for modifier, head in enumerate(heads[1:]):
+                scores, exprs = self.__evaluateLabel(sentence, head, modifier + 1)
+                sentence[modifier + 1].pred_relation = self.irels[
+                    max(enumerate(scores), key=itemgetter(1))[0]]
 
-                dump = False
 
-                if self.labelsFlag:
-                    for modifier, head in enumerate(heads[1:]):
-                        scores, exprs = self.__evaluateLabel(conll_sentence, head, modifier + 1)
-                        conll_sentence[modifier + 1].pred_relation = self.irels[
-                            max(enumerate(scores), key=itemgetter(1))[0]]
+class MSTTrainer:
+    def __init__(self, vocab, pos, rels, w2i, options):
+        self.model = MSTParserLSTMModel(vocab, pos, rels, w2i, options)
 
-                renew_cg()
-                if not dump:
-                    yield sentence
+    def Predict(self, conll_path):
+        with open(conll_path, 'r') as conllFP:
+            for iSentence, sentence in enumerate(read_conll(conllFP)):
+                conll_sentence = [entry for entry in sentence if isinstance(entry, utils.ConllEntry)]
+                self.model.predict(conll_sentence)
+                yield conll_sentence
 
     def Train(self, conll_path):
         errors = 0
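The hunk above splits parsing into a per-sentence `model.predict` plus a thin `MSTTrainer.Predict` generator, and drops DyNet's `renew_cg()` (PyTorch rebuilds the graph on every forward pass, so there is nothing to reset). The moved body still calls DyNet's `concatenate`, whose PyTorch counterpart is `torch.cat`; a hedged sketch of the replacement and of consuming the generator (`build_vec` is illustrative, and the `write_conll` call assumes the existing utils signature):

    def build_vec(wordvec, posvec, evec):
        parts = [v for v in [wordvec, posvec, evec] if v is not None]
        return torch.cat(parts, 0)  # join 1-D feature vectors end to end

    parser = MSTTrainer(vocab, pos, rels, w2i, options)
    write_conll('dev.pred.conll', parser.Predict('dev.conll'))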
@@ -249,7 +257,7 @@ def Train(self, conll_path):
 
                     if self.external_embedding is not None:
                         evec = self.elookup[self.extrnd.get(entry.form, self.extrnd.get(entry.norm, 0)) if (
-                            dropFlag or (random.random() < 0.5)) else 0]
+                                dropFlag or (random.random() < 0.5)) else 0]
                     entry.vec = concatenate(filter(None, [wordvec, posvec, evec]))
 
                     entry.lstms = [entry.vec, entry.vec]
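The context lines above carry the word-dropout trick for external embeddings: when `dropFlag` is false, the pretrained vector is swapped for the index-0 (unknown) row half the time. Pulled out as a sketch for clarity (the helper name is invented):

    def external_index(extrnd, form, norm, drop_flag):
        keep = drop_flag or (random.random() < 0.5)
        return extrnd.get(form, extrnd.get(norm, 0)) if keep else 0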
@@ -293,7 +301,7 @@ def Train(self, conll_path):
                         rscores, rexprs = self.__evaluateLabel(conll_sentence, head, modifier + 1)
                         goldLabelInd = self.rels[conll_sentence[modifier + 1].relation]
                         wrongLabelInd = \
-                            max(((l, scr) for l, scr in enumerate(rscores) if l != goldLabelInd), key=itemgetter(1))[0]
+                                max(((l, scr) for l, scr in enumerate(rscores) if l != goldLabelInd), key=itemgetter(1))[0]
                         if rscores[goldLabelInd] < rscores[wrongLabelInd] + 1:
                             lerrs.append(rexprs[wrongLabelInd] - rexprs[goldLabelInd])
 
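For context on the hunk above: each term appended to `lerrs` is a margin-1 hinge on the label scores (wrong-label score minus gold-label score). In the port these Variable-valued terms are eventually summed into one scalar loss and backpropagated, roughly as below (the optimizer and the reduction are assumptions; this diff does not show them):

    if lerrs:
        loss = sum(lerrs[1:], lerrs[0])  # PyTorch stand-in for DyNet's esum(lerrs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()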