2727#include " dynet/rnn.h"
2828#include " c2.h"
2929
30- unsigned batch_size = 50 ;
3130float pdrop = 0.3 ;
3231bool DEBUG = false ;
3332cpyp::Corpus corpus;
@@ -368,6 +367,13 @@ if(DEBUG) std::cerr<<"action index " << action_count<<"\n";
368367 current_valid_actions.push_back (a);
369368 }
370369if (DEBUG) std::cerr<<" possible action " << current_valid_actions.size ()<<" \n " ;
370+ if (DEBUG){
371+
372+ for (unsigned i = 0 ; i < current_valid_actions.size (); i ++){
373+ std::cerr<<current_valid_actions[i]<<" " ;
374+ }
375+ std::cerr<<" \n " ;
376+ }
371377 // stack attention
372378 Expression prev_h = state_lstm.final_h ()[0 ];
373379 vector<Expression> s_att;
@@ -419,13 +425,17 @@ if(DEBUG) std::cerr<<"attention ok\n";
419425 Expression n_combo = rectify (combo);
420426 Expression rt = affine_transform ({rtbias, combo2rt, n_combo});
421427if (DEBUG) std::cerr<<" to action layer ok\n " ;
422- Expression adiste = log_softmax (rt, current_valid_actions);
428+ Expression rts = select_rows (rt, current_valid_actions);
429+ // Expression adiste = log_softmax(rt, current_valid_actions);
430+ Expression adiste = log_softmax (rts);
431+ if (DEBUG) std::cerr<<" select action ok\n " ;
423432 vector<float > adist = as_vector (hg->incremental_forward (adiste));
424- double best_score = adist[current_valid_actions[0 ]];
433+ // double best_score = adist[current_valid_actions[0]];
434+ double best_score = adist[0 ];
425435 unsigned best_a = current_valid_actions[0 ];
426436 for (unsigned i = 1 ; i < current_valid_actions.size (); ++i) {
427- if (adist[current_valid_actions[i] ] > best_score) {
428- best_score = adist[current_valid_actions[i] ];
437+ if (adist[i ] > best_score) {
438+ best_score = adist[i ];
429439 best_a = current_valid_actions[i];
430440 }
431441 }
@@ -682,12 +692,11 @@ int main(int argc, char** argv) {
682692 cerr << " TRAINING STARTED AT: " << localtime (&time_start) << endl;
683693 while (!requested_stop) {
684694 ++iter;
685- for (unsigned sii = 0 ; sii < status_every_i_iterations; ++sii) {
686-
687- ComputationGraph hg;
688- vector<Expression> batch_nll;
689695
690- for (unsigned batch = 0 ; batch <= batch_size; ++batch){
696+ {
697+ ComputationGraph hg;
698+ vector<Expression> batch_nll;
699+ for (unsigned sii = 0 ; sii < status_every_i_iterations; ++sii) {
691700
692701 if (si == corpus.nsentences ) {
693702 si = 0 ;
@@ -714,11 +723,12 @@ int main(int argc, char** argv) {
714723 llh += lp;
715724 ++si;
716725 trs += actions.size ();
717- }
718- hg.backward (sum (batch_nll));
719- sgd->update (1.0 );
720726 }
727+ hg.backward (sum (batch_nll));
728+ sgd->update (1.0 );
721729 sgd->status ();
730+ }
731+
722732 time_t time_now = std::chrono::system_clock::to_time_t (std::chrono::system_clock::now ());
723733 cerr << " update #" << iter << " (epoch " << (tot_seen / corpus.nsentences ) << " |time=" << localtime (&time_now) << " )\t llh: " << llh<<" ppl: " << exp (llh / trs) << " err: " << (trs - right) / trs << endl;
724734 llh = trs = right = 0 ;
0 commit comments