@@ -211,12 +211,6 @@ def __evaluateLabel(self, sentence, i, j):
211211
212212 return output .data .numpy ()[0 ], output [0 ]
213213
214- def Save (self , filename ):
215- self .model .save (filename )
216-
217- def Load (self , filename ):
218- self .model .load (filename )
219-
220214 def predict (self , sentence ):
221215 for entry in sentence :
222216 wordvec = self .wlookup (scalar (int (self .vocab .get (entry .norm , 0 )))) if self .wdims > 0 else None
@@ -269,7 +263,7 @@ def predict(self, sentence):
269263 scores , exprs = self .__evaluateLabel (sentence , head , modifier + 1 )
270264 sentence [modifier + 1 ].pred_relation = self .irels [max (enumerate (scores ), key = itemgetter (1 ))[0 ]]
271265
272- def get_loss (self , sentence , errs , lerrs ):
266+ def forward (self , sentence , errs , lerrs ):
273267
274268 for entry in sentence :
275269 c = float (self .wordsCount .get (entry .norm , 0 ))
@@ -358,6 +352,9 @@ def save(self, fn):
358352 torch .save (self .model .state_dict (), tmp )
359353 shutil .move (tmp , fn )
360354
355+ def load (self , fn ):
356+ self .model .load_state_dict (torch .load (fn ))
357+
361358 def train (self , conll_path ):
362359 print torch .__version__
363360 batch = 1
@@ -384,7 +381,7 @@ def train(self, conll_path):
384381 etotal = 0
385382
386383 conll_sentence = [entry for entry in sentence if isinstance (entry , utils .ConllEntry )]
387- e = self .model .get_loss (conll_sentence , errs , lerrs )
384+ e = self .model .forward (conll_sentence , errs , lerrs )
388385 eerrors += e
389386 eloss += e
390387 mloss += e
0 commit comments