@@ -559,9 +559,7 @@ int main(int argc, char** argv) {
559559// =====================================================================================================================
560560
561561 parser::StandardOracle corpus (&termdict, &adict, &posdict, &arcdict);
562- parser::StandardOracle dev_corpus (&termdict, &adict, &posdict, &arcdict);
563562 corpus.load_oracle (params.train_file , true );
564- dev_corpus.load_oracle (params.dev_file , true );
565563
566564 if (params.words_file != " " ) {
567565 cerr << " Loading from " << params.words_file << " with" << params.pretrained_dim << " dimensions\n " ;
@@ -580,6 +578,8 @@ int main(int argc, char** argv) {
580578 }
581579
582580 termdict.freeze ();
581+ termdict.set_unk (" UNK" );
582+ adict.convert (" RIGHT-ARC(preconj)" ); // dev data has new possible action
583583 adict.freeze ();
584584 arcdict.freeze ();
585585 posdict.freeze ();
@@ -593,11 +593,12 @@ int main(int argc, char** argv) {
593593 if (wc.second == 1 ) singletons[wc.first ] = true ;
594594 }
595595
596- ARC_SIZE = arcdict.size ()+ 10 ;
597- POS_SIZE = posdict.size ()+ 10 ;
598- VOCAB_SIZE = termdict.size ()+ 10 ;
599- ACTION_SIZE = adict.size ()+ 10 ;
596+ ARC_SIZE = arcdict.size ();
597+ POS_SIZE = posdict.size ();
598+ VOCAB_SIZE = termdict.size ();
599+ ACTION_SIZE = adict.size ();
600600
601+
601602 for (unsigned i = 0 ; i < adict.size (); ++i) possible_actions.push_back (i);
602603
603604 cerr<<" action:\n " ;
@@ -615,6 +616,11 @@ int main(int argc, char** argv) {
615616 cerr<<i<<" :" <<arcdict.convert (i)<<" \n " ;
616617 }
617618
619+ parser::StandardOracle dev_corpus (&termdict, &adict, &posdict, &arcdict);
620+ parser::StandardOracle test_corpus (&termdict, &adict, &posdict, &arcdict);
621+ if (params.dev_file != " " ) dev_corpus.load_oracle (params.dev_file , true );
622+ if (params.test_file != " " ) test_corpus.load_oracle (params.test_file , true );
623+
618624// ==========================================================================================================================
619625
620626 Model model;
@@ -763,8 +769,8 @@ int main(int argc, char** argv) {
763769 delete sgd;
764770 } // should do training?
765771 else { // do test evaluation
766- ofstream out (" test.out" );
767- unsigned test_size = dev_corpus .size ();
772+ ofstream out (" test.out" );
773+ unsigned test_size = test_corpus .size ();
768774
769775 double llh = 0 ;
770776 double trs = 0 ;
@@ -774,8 +780,8 @@ int main(int argc, char** argv) {
774780 double total_heads = 0 ;
775781 if (params.samples !=0 ){
776782 for (unsigned sii = 0 ; sii < test_size; ++sii) {
777- const auto & sentence=dev_corpus .sents [sii];
778- const vector<int >& actions=dev_corpus .actions [sii];
783+ const auto & sentence=test_corpus .sents [sii];
784+ const vector<int >& actions=test_corpus .actions [sii];
779785 for (unsigned z = 0 ; z < params.samples ; ++z) {
780786 ComputationGraph hg;
781787 vector<int > pred;
@@ -790,8 +796,8 @@ int main(int argc, char** argv) {
790796 }
791797 auto t_start = std::chrono::high_resolution_clock::now ();
792798 for (unsigned sii = 0 ; sii < test_size; ++sii) {
793- const auto & sentence=dev_corpus .sents [sii];
794- const vector<int >& actions=dev_corpus .actions [sii];
799+ const auto & sentence=test_corpus .sents [sii];
800+ const vector<int >& actions=test_corpus .actions [sii];
795801 ComputationGraph hg;
796802 vector<int > pred;
797803 Expression nll = parser.log_prob_parser (&hg, sentence, actions, &right, &pred, false , false );
0 commit comments