Skip to content

Commit

Permalink
Fix clab#2: calculate and print las on test
Browse files Browse the repository at this point in the history
  • Loading branch information
danielhers committed Jul 28, 2015
1 parent 53a95fd commit 1cb8c5e
Showing 1 changed file with 25 additions and 4 deletions.
29 changes: 25 additions & 4 deletions parser/lstm-parse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,8 @@ void signal_callback_handler(int /* signum */) {
requested_stop = true;
}

unsigned compute_correct(const map<int,int>& ref, const map<int,int>& hyp, unsigned len) {
template<typename T>
unsigned compute_correct(const map<int,T>& ref, const map<int,T>& hyp, unsigned len) {
unsigned res = 0;
for (unsigned i = 0; i < len; ++i) {
auto ri = ref.find(i);
Expand All @@ -440,6 +441,24 @@ unsigned compute_correct(const map<int,int>& ref, const map<int,int>& hyp, unsig
return res;
}

template<typename T1, typename T2>
unsigned compute_correct(const map<int,T1>& ref1, const map<int,T1>& hyp1,
const map<int,T2>& ref2, const map<int,T2>& hyp2, unsigned len) {
unsigned res = 0;
for (unsigned i = 0; i < len; ++i) {
auto r1 = ref1.find(i);
auto h1 = hyp1.find(i);
auto r2 = ref2.find(i);
auto h2 = hyp2.find(i);
assert(r1 != ref1.end());
assert(h1 != hyp1.end());
assert(r2 != ref2.end());
assert(h2 != hyp2.end());
if (r1->second == h1->second && r2->second == h2->second) ++res;
}
return res;
}

void output_conll(const vector<unsigned>& sentence, const vector<unsigned>& pos,
const vector<string>& sentenceUnkStrings,
const map<unsigned, string>& intToWords,
Expand Down Expand Up @@ -714,7 +733,8 @@ int main(int argc, char** argv) {
double llh = 0;
double trs = 0;
double right = 0;
double correct_heads = 0;
double correct_heads_unlabeled = 0;
double correct_heads_labeled = 0;
double total_heads = 0;
auto t_start = std::chrono::high_resolution_clock::now();
unsigned corpus_size = corpus.nsentencesDev;
Expand All @@ -736,11 +756,12 @@ int main(int argc, char** argv) {
map<int,int> ref = parser.compute_heads(sentence.size(), actions, corpus.actions, &rel_ref);
map<int,int> hyp = parser.compute_heads(sentence.size(), pred, corpus.actions, &rel_hyp);
output_conll(sentence, sentencePos, sentenceUnkStr, corpus.intToWords, corpus.intToPos, hyp, rel_hyp);
correct_heads += compute_correct(ref, hyp, sentence.size() - 1);
correct_heads_unlabeled += compute_correct(ref, hyp, sentence.size() - 1);
correct_heads_labeled += compute_correct(ref, hyp, rel_ref, rel_hyp, sentence.size() - 1);
total_heads += sentence.size() - 1;
}
auto t_end = std::chrono::high_resolution_clock::now();
cerr << "TEST llh=" << llh << " ppl: " << exp(llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads / total_heads) << "\t[" << corpus_size << " sents in " << std::chrono::duration<double, std::milli>(t_end-t_start).count() << " ms]" << endl;
cerr << "TEST llh=" << llh << " ppl: " << exp(llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads_unlabeled / total_heads) << " las: " << (correct_heads_labeled / total_heads) << "\t[" << corpus_size << " sents in " << std::chrono::duration<double, std::milli>(t_end-t_start).count() << " ms]" << endl;
}
for (unsigned i = 0; i < corpus.actions.size(); ++i) {
//cerr << corpus.actions[i] << '\t' << parser.p_r->values[i].transpose() << endl;
Expand Down

0 comments on commit 1cb8c5e

Please sign in to comment.