Skip to content

Commit 580f619

Browse files
committed
update gpu
1 parent b1bd319 commit 580f619

File tree

1 file changed

+23
-13
lines changed

1 file changed

+23
-13
lines changed

impl/dependency-parser-batch.cc

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
#include "dynet/rnn.h"
2828
#include "c2.h"
2929

30-
unsigned batch_size = 50;
3130
float pdrop = 0.3;
3231
bool DEBUG = false;
3332
cpyp::Corpus corpus;
@@ -368,6 +367,13 @@ if(DEBUG) std::cerr<<"action index " << action_count<<"\n";
368367
current_valid_actions.push_back(a);
369368
}
370369
if(DEBUG) std::cerr<<"possible action " << current_valid_actions.size()<<"\n";
370+
if(DEBUG){
371+
372+
for(unsigned i = 0; i < current_valid_actions.size(); i ++){
373+
std::cerr<<current_valid_actions[i]<<" ";
374+
}
375+
std::cerr<<"\n";
376+
}
371377
//stack attention
372378
Expression prev_h = state_lstm.final_h()[0];
373379
vector<Expression> s_att;
@@ -419,13 +425,17 @@ if(DEBUG) std::cerr<<"attention ok\n";
419425
Expression n_combo = rectify(combo);
420426
Expression rt = affine_transform({rtbias, combo2rt, n_combo});
421427
if(DEBUG) std::cerr<<"to action layer ok\n";
422-
Expression adiste = log_softmax(rt, current_valid_actions);
428+
Expression rts = select_rows(rt, current_valid_actions);
429+
//Expression adiste = log_softmax(rt, current_valid_actions);
430+
Expression adiste = log_softmax(rts);
431+
if(DEBUG) std::cerr<<"select action ok\n";
423432
vector<float> adist = as_vector(hg->incremental_forward(adiste));
424-
double best_score = adist[current_valid_actions[0]];
433+
//double best_score = adist[current_valid_actions[0]];
434+
double best_score = adist[0];
425435
unsigned best_a = current_valid_actions[0];
426436
for (unsigned i = 1; i < current_valid_actions.size(); ++i) {
427-
if (adist[current_valid_actions[i]] > best_score) {
428-
best_score = adist[current_valid_actions[i]];
437+
if (adist[i] > best_score) {
438+
best_score = adist[i];
429439
best_a = current_valid_actions[i];
430440
}
431441
}
@@ -682,12 +692,11 @@ int main(int argc, char** argv) {
682692
cerr << "TRAINING STARTED AT: " << localtime(&time_start) << endl;
683693
while(!requested_stop) {
684694
++iter;
685-
for (unsigned sii = 0; sii < status_every_i_iterations; ++sii) {
686-
687-
ComputationGraph hg;
688-
vector<Expression> batch_nll;
689695

690-
for(unsigned batch = 0; batch <= batch_size; ++batch){
696+
{
697+
ComputationGraph hg;
698+
vector<Expression> batch_nll;
699+
for (unsigned sii = 0; sii < status_every_i_iterations; ++sii) {
691700

692701
if (si == corpus.nsentences) {
693702
si = 0;
@@ -714,11 +723,12 @@ int main(int argc, char** argv) {
714723
llh += lp;
715724
++si;
716725
trs += actions.size();
717-
}
718-
hg.backward(sum(batch_nll));
719-
sgd->update(1.0);
720726
}
727+
hg.backward(sum(batch_nll));
728+
sgd->update(1.0);
721729
sgd->status();
730+
}
731+
722732
time_t time_now = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
723733
cerr << "update #" << iter << " (epoch " << (tot_seen / corpus.nsentences) << " |time=" << localtime(&time_now) << ")\tllh: "<< llh<<" ppl: " << exp(llh / trs) << " err: " << (trs - right) / trs << endl;
724734
llh = trs = right = 0;

0 commit comments

Comments
 (0)