Added essential metrics for every step

pedrojlazevedo committed Mar 5, 2020
1 parent 066c22e commit 97956ca
Showing 4 changed files with 253 additions and 11 deletions.
101 changes: 95 additions & 6 deletions metrics.py
@@ -1,16 +1,20 @@
import jsonlines
import json
import sys
from scorer import fever_score

train_file = "data/subsample_train.jsonl"
train_relevant_file = "data/subsample_train_relevant_docs.jsonl"
train_concatenate_file = "data/subsample_train_concatenation.jsonl"
train_predictions_file = "predictions/predictions_train.jsonl"

train_file = jsonlines.open(train_file)
train_relevant_file = jsonlines.open(train_relevant_file)
train_concatenate_file = jsonlines.open(train_concatenate_file)
train_predictions_file = jsonlines.open(train_predictions_file)

train_set = []
train_relevant = []
train_concatenate = []
train_prediction = []

for lines in train_file:
@@ -23,6 +27,17 @@
lines['claim'] = lines['claim'].replace("-RRB-", " ) ")
train_relevant.append(lines)

for lines in train_concatenate_file:
lines['claim'] = lines['claim'].replace("-LRB-", " ( ")
lines['claim'] = lines['claim'].replace("-RRB-", " ) ")
train_concatenate.append(lines)

# copy the gold evidence into each prediction (redundant here, since fever_score is also given the gold data via `actual`)
info_by_id = dict((d['id'], dict(d, index=index)) for (index, d) in enumerate(train_set))
for lines in train_predictions_file:
lines['evidence'] = info_by_id.get(lines['id'])['evidence']
train_prediction.append(lines)

# All claims
stop = 0

@@ -32,7 +47,8 @@
verifiable : 1 if the claim is verifiable, 0 otherwise
docs : set of documents that verify the claim
docs_sep : set of the individual documents (each document listed separately)
sentences: list of tuples of <doc, line>
evidences: list of tuples of <doc, line>
difficulties: list with the number of evidence sentences in each evidence set
'''
gold_data = []

@@ -51,21 +67,26 @@
gold_documents_seperated = set()
sentences_pair = set()
evidences = claim['evidence']

difficulties = []
for evidence in evidences:
doc_name = ''
difficulty = 0
if len(evidence) > 1: # needs more than one evidence sentence (possibly across documents) to be verifiable
for e in evidence:
doc_name += str(e[2])
doc_name += " "
sentences_pair.add((str(e[2]), str(e[3]))) # add gold sentences
gold_documents_seperated.add(str(e[2])) # add the document
difficulty += 1
doc_name = doc_name[:-1] # erase the last blank space
else:
doc_name = str(evidence[0][2])
gold_documents_seperated.add(str(evidence[0][2]))
sentences_pair.add((str(evidence[0][2]), str(evidence[0][3])))
difficulty = 1
difficulties.append(difficulty)
gold_documents.add(doc_name)
gold_dict['difficulties'] = difficulties
gold_dict['docs'] = gold_documents
gold_dict['evidences'] = sentences_pair
gold_dict['docs_sep'] = gold_documents_seperated
@@ -92,6 +113,13 @@
recall_incorrect = 0
specificity = 0

precision_sent_correct = 0
precision_sent_incorrect = 0
recall_sent_correct = 0
recall_sent_incorrect = 0
sent_found = 0
sent_found_if_doc_found = 0

total_claim = 0
for claim in train_relevant:
_id = claim['id']
@@ -102,6 +130,7 @@
continue

# document analysis
# TODO: Analyse NER and TF-IDF
doc_correct = 0
doc_incorrect = 0
gold_incorrect = 0
@@ -130,7 +159,31 @@
doc_found += 1

# sentence analysis TODO: check sentences
sentences = set()
for sent in claim['predicted_sentences']:
sentences.add((str(sent[0]), str(sent[1])))

evidences = gold_dict['evidences']
sent_correct = 0
sent_incorrect = 0
flag = False
for sent in sentences:
if sent in evidences:
sent_correct += 1
flag = True
else:
sent_incorrect += 1

if flag:
sent_found += 1

if doc_correct and flag:
sent_found_if_doc_found += 1

precision_sent_correct += sent_correct / len(sentences)
precision_sent_incorrect += sent_incorrect / len(sentences)
recall_sent_correct += sent_correct / len(evidences)
recall_sent_incorrect += sent_incorrect / len(evidences)

# TODO: create all possible pair in order to see if it appears in gold_dict['docs']
# claim['predicted_sentences']
@@ -140,17 +193,53 @@
stop += 1
if stop == -1:
break

# scores from fever
results = fever_score(train_prediction, actual = train_set)

precision_correct /= total_claim
precision_incorrect /= total_claim
recall_correct /= total_claim
recall_incorrect /= total_claim
specificity /= total_claim
doc_found /= total_claim


print("\n#############")
print("# DOCUMENTS #")
print("#############")
print("Precision (Document Retrieved):\t\t\t\t\t\t " + str(precision_correct)) # precision
print("Fall-out (incorrect documents):\t\t\t " + str(precision_incorrect)) # precision
print("Fall-out (incorrect documents):\t\t\t\t\t\t " + str(precision_incorrect)) # precision
print("Recall (Relevant Documents):\t\t\t\t\t\t " + str(recall_correct)) # recall
print("Percentage of gold documents NOT found:\t\t\t\t " + str(recall_incorrect)) # recall
print("Fall-out: " + str(specificity))
print("Percentage of at least one document found correctly: " + str(doc_found)) # recall
print("Percentage of at least one document found correctly: " + str(doc_found)) # recall


precision_sent_correct /= total_claim
precision_sent_incorrect /= total_claim
recall_sent_correct /= total_claim
recall_sent_incorrect /= total_claim
sent_found /= total_claim
sent_found_if_doc_found /= total_claim
another_sent = sent_found_if_doc_found / doc_found

print("\n#############")
print("# SENTENCES #")
print("#############")
print("Precision (Sentences Retrieved):\t\t\t\t\t " + str(precision_sent_correct)) # precision
print("Precision (incorrect Sentences):\t\t\t\t\t " + str(precision_sent_incorrect)) # precision
print("Recall (Relevant Sentences):\t\t\t\t\t\t " + str(recall_sent_correct)) # recall
print("Percentage of gold Sentences NOT found:\t\t\t\t " + str(recall_sent_incorrect)) # recall
print("Percentage of at least one Sentence found correctly: " + str(sent_found)) # recall
print("Percentage of at least one Sentence found correctly: " + str(sent_found_if_doc_found)) # recall
print("Percentage of at least one Sentence found correctly: " + str(another_sent)) # recall

print("\n#########")
print("# FEVER #")
print("#########")
print("Strict_score: \t\t" + str(results[0]))
print("Acc_score: \t\t\t" + str(results[1]))
print("Precision: \t\t\t" + str(results[2]))
print("Recall: \t\t\t" + str(results[3]))
print("F1-Score: \t\t\t" + str(results[4]))

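For reference, the sentence metrics printed above are macro-averaged over claims: each claim contributes correct/|predicted sentences| to precision and correct/|gold sentences| to recall, and the running sums are divided by total_claim at the end. Below is a minimal sketch of that computation, using two invented claims rather than data from the repository:

# Toy illustration of the per-claim sentence precision/recall used in metrics.py.
claims = [
    {"predicted": {("Page_A", "0"), ("Page_B", "3")}, "gold": {("Page_A", "0")}},
    {"predicted": {("Page_C", "1")}, "gold": {("Page_C", "1"), ("Page_C", "2")}},
]

precision_sent = 0.0
recall_sent = 0.0
for c in claims:
    correct = len(c["predicted"] & c["gold"])        # gold sentences that were retrieved
    precision_sent += correct / len(c["predicted"])  # per-claim precision
    recall_sent += correct / len(c["gold"])          # per-claim recall

precision_sent /= len(claims)  # macro-average over claims
recall_sent /= len(claims)
print(precision_sent, recall_sent)  # 0.75 0.75
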
4 changes: 2 additions & 2 deletions rte/rte.py
@@ -65,9 +65,9 @@ def determinePredictedLabel(preds):
[len(nonePredictions), len(supportPredictions), len(contradictionPredictions)])
mostCommonPrediction = np.argmax(numberOfPredictionsPerLabel)

if mostCommonPrediction == 0:
if mostCommonPrediction == 1:
return (0, supportPredictions)
elif mostCommonPrediction == 1:
elif mostCommonPrediction == 2:
return (1, contradictionPredictions)
else:
return (2, [])
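The change above realigns the argmax index with the counts array built a few lines earlier, which is ordered [none, support, contradiction]. A minimal sketch of the corrected mapping; the meaning of the returned codes 0/1/2 (support, contradiction, none) is inferred from the returned prediction lists, not stated elsewhere in the diff:

import numpy as np

# Counts per label, in the order used by the snippet: [none, support, contradiction].
counts = np.array([1, 4, 2])
most_common = int(np.argmax(counts))  # 1, i.e. "support" has the most votes

if most_common == 1:      # support
    label = 0
elif most_common == 2:    # contradiction
    label = 1
else:                     # none
    label = 2
print(label)  # 0
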
153 changes: 153 additions & 0 deletions scorer.py
@@ -0,0 +1,153 @@
import six

def check_predicted_evidence_format(instance):
if 'predicted_evidence' in instance.keys() and len(instance['predicted_evidence']):
assert all(isinstance(prediction, list)
for prediction in instance["predicted_evidence"]), \
"Predicted evidence must be a list of (page,line) lists"

assert all(len(prediction) == 2
for prediction in instance["predicted_evidence"]), \
"Predicted evidence must be a list of (page,line) lists"

assert all(isinstance(prediction[0], six.string_types)
for prediction in instance["predicted_evidence"]), \
"Predicted evidence must be a list of (page<string>,line<int>) lists"

assert all(isinstance(prediction[1], int)
for prediction in instance["predicted_evidence"]), \
"Predicted evidence must be a list of (page<string>,line<int>) lists"


def is_correct_label(instance):
return instance["label"].upper() == instance["predicted_label"].upper()


def is_strictly_correct(instance, max_evidence=None):
#Strict evidence matching is not applied to the NEI class
check_predicted_evidence_format(instance)

if instance["label"].upper() != "NOT ENOUGH INFO" and is_correct_label(instance):
assert 'predicted_evidence' in instance, "Predicted evidence must be provided for strict scoring"

if max_evidence is None:
max_evidence = len(instance["predicted_evidence"])


for evidence_group in instance["evidence"]:
#Filter out the annotation ids. We just want the evidence page and line number
actual_sentences = [[e[2], e[3]] for e in evidence_group]
#Only return true if an entire group of actual sentences is in the predicted sentences
if all([actual_sent in instance["predicted_evidence"][:max_evidence] for actual_sent in actual_sentences]):
return True

#If the class is NEI, we don't score the evidence retrieval component
elif instance["label"].upper() == "NOT ENOUGH INFO" and is_correct_label(instance):
return True

return False


def evidence_macro_precision(instance, max_evidence=None):
this_precision = 0.0
this_precision_hits = 0.0

if instance["label"].upper() != "NOT ENOUGH INFO":
all_evi = [[e[2], e[3]] for eg in instance["evidence"] for e in eg if e[3] is not None]

predicted_evidence = instance["predicted_evidence"] if max_evidence is None else \
instance["predicted_evidence"][:max_evidence]

for prediction in predicted_evidence:
if prediction in all_evi:
this_precision += 1.0
this_precision_hits += 1.0

return (this_precision / this_precision_hits) if this_precision_hits > 0 else 1.0, 1.0

return 0.0, 0.0

def evidence_macro_recall(instance, max_evidence=None):
# We only want to score F1/Precision/Recall of recalled evidence for non-NEI claims
if instance["label"].upper() != "NOT ENOUGH INFO":
# If there's no evidence to predict, return 1
if len(instance["evidence"]) == 0 or all([len(eg) == 0 for eg in instance]):
return 1.0, 1.0

predicted_evidence = instance["predicted_evidence"] if max_evidence is None else \
instance["predicted_evidence"][:max_evidence]

for evidence_group in instance["evidence"]:
evidence = [[e[2], e[3]] for e in evidence_group]
if all([item in predicted_evidence for item in evidence]):
# We only want to score complete groups of evidence. Incomplete groups are worthless.
return 1.0, 1.0
return 0.0, 1.0
return 0.0, 0.0


# Micro is not used. This code is just included to demonstrate our model of macro/micro
def evidence_micro_precision(instance):
this_precision = 0
this_precision_hits = 0

# We only want to score micro precision of recalled evidence for non-NEI claims
if instance["label"].upper() != "NOT ENOUGH INFO":
all_evi = [[e[2], e[3]] for eg in instance["evidence"] for e in eg if e[3] is not None]

for prediction in instance["predicted_evidence"]:
if prediction in all_evi:
this_precision += 1.0
this_precision_hits += 1.0

return this_precision, this_precision_hits


def fever_score(predictions,actual=None, max_evidence=5):
correct = 0
strict = 0

macro_precision = 0
macro_precision_hits = 0

macro_recall = 0
macro_recall_hits = 0

for idx,instance in enumerate(predictions):
assert 'predicted_evidence' in instance.keys(), 'evidence must be provided for the prediction'

#If it's a blind test set, we need to copy in the values from the actual data
if 'evidence' not in instance or 'label' not in instance:
assert actual is not None, 'in blind evaluation mode, actual data must be provided'
assert len(actual) == len(predictions), 'actual data and predicted data length must match'
assert 'evidence' in actual[idx].keys(), 'evidence must be provided for the actual evidence'
instance['evidence'] = actual[idx]['evidence']
instance['label'] = actual[idx]['label']

assert 'evidence' in instance.keys(), 'gold evidence must be provided'

if is_correct_label(instance):
correct += 1.0

if is_strictly_correct(instance, max_evidence):
strict+=1.0

macro_prec = evidence_macro_precision(instance, max_evidence)
macro_precision += macro_prec[0]
macro_precision_hits += macro_prec[1]

macro_rec = evidence_macro_recall(instance, max_evidence)
macro_recall += macro_rec[0]
macro_recall_hits += macro_rec[1]

total = len(predictions)

strict_score = strict / total
acc_score = correct / total

pr = (macro_precision / macro_precision_hits) if macro_precision_hits > 0 else 1.0
rec = (macro_recall / macro_recall_hits) if macro_recall_hits > 0 else 0.0

f1 = 2.0 * pr * rec / (pr + rec)

return strict_score, acc_score, pr, rec, f1
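
A minimal usage sketch for fever_score with a single invented instance; the field layout follows the checks above (predicted_evidence as [page, line] pairs, gold evidence groups as [annotation_id, evidence_id, page, line]; the ids 101/202 and the page name are placeholders):

from scorer import fever_score

predictions = [{
    "label": "SUPPORTS",
    "predicted_label": "SUPPORTS",
    "predicted_evidence": [["Some_Page", 0]],       # (page, line) pairs
    "evidence": [[[101, 202, "Some_Page", 0]]],     # one gold group with one sentence
}]

strict, acc, pr, rec, f1 = fever_score(predictions)
print(strict, acc, pr, rec, f1)  # 1.0 1.0 1.0 1.0 1.0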
6 changes: 3 additions & 3 deletions train_label_classifier.py
@@ -189,8 +189,8 @@ def predict_test(predictions_test, entailment_predictions_test, new_predictions_
i += 1


predictions_train = "predictions_train.jsonl"
predictions_test = "predictions.jsonl"
predictions_train = "predictions/predictions_train.jsonl"
predictions_test = "predictions/predictions.jsonl"
new_predictions_file = "predictions/new_predictions.jsonl"

gold_train = "data/subsample_train_relevant_docs.jsonl"
@@ -220,7 +220,7 @@ def predict_test(predictions_test, entailment_predictions_test, new_predictions_
# clf= Pipeline([('scaler', MinMaxScaler()), ('clf', svm.SVC())])

clf.fit(x_train, y_train)

print("Fit Done")
joblib.dump(clf, 'label_classifier.pkl')
# clf = joblib.load('filename.pkl')

