diff --git a/parser/lstm-parse.cc b/parser/lstm-parse.cc index ad9dfe6..a4b9214 100644 --- a/parser/lstm-parse.cc +++ b/parser/lstm-parse.cc @@ -845,7 +845,8 @@ void signal_callback_handler(int /* signum */) { requested_stop = true; } -unsigned compute_correct(const map& ref, const map& hyp, unsigned len) { +template +unsigned compute_correct(const map& ref, const map& hyp, unsigned len) { unsigned res = 0; for (unsigned i = 0; i < len; ++i) { auto ri = ref.find(i); @@ -857,6 +858,24 @@ unsigned compute_correct(const map& ref, const map& hyp, unsig return res; } +template +unsigned compute_correct(const map& ref1, const map& hyp1, + const map& ref2, const map& hyp2, unsigned len) { + unsigned res = 0; + for (unsigned i = 0; i < len; ++i) { + auto r1 = ref1.find(i); + auto h1 = hyp1.find(i); + auto r2 = ref2.find(i); + auto h2 = hyp2.find(i); + assert(r1 != ref1.end()); + assert(h1 != hyp1.end()); + assert(r2 != ref2.end()); + assert(h2 != hyp2.end()); + if (r1->second == h1->second && r2->second == h2->second) ++res; + } + return res; +} + void output_conll(const vector& sentence, const vector& pos, const vector& sentenceUnkStrings, const map& intToWords, @@ -1142,7 +1161,8 @@ int main(int argc, char** argv) { double llh = 0; double trs = 0; double right = 0; - double correct_heads = 0; + double correct_heads_unlabeled = 0; + double correct_heads_labeled = 0; double total_heads = 0; auto t_start = std::chrono::high_resolution_clock::now(); unsigned corpus_size = corpus.nsentencesDev; @@ -1169,11 +1189,12 @@ int main(int argc, char** argv) { map ref = parser.compute_heads(sentence.size(), actions, corpus.actions, &rel_ref); map hyp = parser.compute_heads(sentence.size(), pred, corpus.actions, &rel_hyp); output_conll(sentence, sentencePos, sentenceUnkStr, corpus.intToWords, corpus.intToPos, hyp, rel_hyp); - correct_heads += compute_correct(ref, hyp, sentence.size() - 1); + correct_heads_unlabeled += compute_correct(ref, hyp, sentence.size() - 1); + correct_heads_labeled += compute_correct(ref, hyp, rel_ref, rel_hyp, sentence.size() - 1); total_heads += sentence.size() - 1; } auto t_end = std::chrono::high_resolution_clock::now(); - cerr << "TEST llh=" << llh << " ppl: " << exp(llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads / total_heads) << "\t[" << corpus_size << " sents in " << std::chrono::duration(t_end-t_start).count() << " ms]" << endl; + cerr << "TEST llh=" << llh << " ppl: " << exp(llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads_unlabeled / total_heads) << " las: " << (correct_heads_labeled / total_heads) << "\t[" << corpus_size << " sents in " << std::chrono::duration(t_end-t_start).count() << " ms]" << endl; } for (unsigned i = 0; i < corpus.actions.size(); ++i) { //cerr << corpus.actions[i] << '\t' << parser.p_r->values[i].transpose() << endl;