All pipelines working and tested
pedrojlazevedo committed Apr 15, 2020
1 parent 9416131 commit de8a711
Showing 14 changed files with 405 additions and 727 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -11,3 +11,5 @@ fever-baselines/glove.6B.zip
list_of_Wikipedia_files.txt
.idea/
rte/entailment_predictions/
rte/entailment_predictions*
output/
74 changes: 50 additions & 24 deletions doc_retrieval.py
@@ -48,8 +48,8 @@ def get_docs_with_oie(claim, wiki_entities,client):
ents.add(sub.text)
for obj in clause['O']:
ents.add(obj.text)
print(ner_spacy)
print(ents)
# print(ner_spacy)
# print(ents)

if len(ents) > 4:
ents = clean_entities(ents)
@@ -88,7 +88,7 @@ def getClosestDocs(wiki_entities, entities):
entities[i] = str(entities[i])
selected_docs = set()
for ent in entities:
print(ent)
# print(ent)
ent = ud.normalize('NFC', ent)

best_1 = 1.1
@@ -105,34 +105,35 @@ def getClosestDocs(wiki_entities, entities):
for we in wiki_entities:
dists.append((distance(we, ent), we))
b = datetime.datetime.now()
print(b-a)
# print(b-a)

pair_1 = min(dists, key=operator.itemgetter(0))
dists.remove(pair_1)
pair_2 = min(dists, key=operator.itemgetter(0))
# dists.remove(pair_1)
# pair_2 = min(dists, key=operator.itemgetter(0))

best_match_1 = pair_1[1]
best_match_2 = pair_2[1]
# best_match_2 = pair_2[1]

best_match_1 = best_match_1.replace(" ", "_")
best_match_1 = best_match_1.replace("/", "-SLH-")
best_match_1 = best_match_1.replace("(", "-LRB-")
best_match_1 = best_match_1.replace(")", "-RRB-")

best_match_2 = best_match_2.replace(" ", "_")
best_match_2 = best_match_2.replace("/", "-SLH-")
best_match_2 = best_match_2.replace("(", "-LRB-")
best_match_2 = best_match_2.replace(")", "-RRB-")
# best_match_2 = best_match_2.replace(" ", "_")
# best_match_2 = best_match_2.replace("/", "-SLH-")
# best_match_2 = best_match_2.replace("(", "-LRB-")
# best_match_2 = best_match_2.replace(")", "-RRB-")

best_match_3 = best_match_3.replace(" ", "_")
best_match_3 = best_match_3.replace("/", "-SLH-")
best_match_3 = best_match_3.replace("(", "-LRB-")
best_match_3 = best_match_3.replace(")", "-RRB-")

selected_docs.add(best_match_1)
selected_docs.add(best_match_2)
# selected_docs.add(best_match_2)
# selected_docs.append(best_match_3)
print(selected_docs)
# print(selected_docs)
print(selected_docs)
return list(selected_docs), entities


@@ -146,7 +147,32 @@ def getRelevantDocs(claim, wiki_entities, ner_module="spaCy", nlp=None): # ,mat
else:
print("Error: Incorrect Document Retrieval Specifications")
return
return getClosestDocs(wiki_entities, entities)
return get_closest_docs_ner(wiki_entities, entities)


def get_closest_docs_ner(wiki_entities,entities):
entities = list(entities)
for i in range(len(entities)):
entities[i] = str(entities[i])
selected_docs = []
for ent in entities:
ent = ud.normalize('NFC',ent)
if ent in wiki_entities:
best_match = ent
else:
best = 11111111111
best_match = ""
for we in wiki_entities:
dist = distance(we,ent)
if dist < best:
best = dist
best_match = we
best_match = best_match.replace(" ","_")
best_match = best_match.replace("/","-SLH-")
best_match = best_match.replace("(","-LRB-")
best_match = best_match.replace(")","-RRB-")
selected_docs.append(best_match)
return selected_docs, entities


def getDocContent(wiki_folder, doc_id):
@@ -165,16 +191,16 @@ def getDocContent(wiki_folder, doc_id):

"""
def getDocContentFromFile(wiki_folder, doc_filename, doc_id):
fileContent= jsonlines.open(wiki_folder + "/" + doc_filename)
for doc in fileContent:
if doc["id"] == doc_id:
doc["fileId"] = doc_filename
return doc
return None
fileContent= jsonlines.open(wiki_folder + "/" + doc_filename)
for doc in fileContent:
if doc["id"] == doc_id:
doc["fileId"] = doc_filename
return doc
return None
"""


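Note on the change above: the new get_closest_docs_ner helper pairs each NER entity with the single closest wiki page title and re-encodes that title as a file name. Below is a minimal standalone sketch of the same idea, assuming that `distance` is Levenshtein edit distance (as provided by the python-Levenshtein package) and that page files follow the FEVER-style encoding used in this repository; the function name and example are illustrative only.

from Levenshtein import distance
import unicodedata as ud

def closest_wiki_doc(entity, wiki_titles):
    # Return the encoded file name of the wiki page whose title is closest to `entity`.
    entity = ud.normalize('NFC', str(entity))
    if entity in wiki_titles:
        best_match = entity
    else:
        # Exhaustive scan: keep the title with the smallest edit distance.
        best_match = min(wiki_titles, key=lambda title: distance(title, entity))
    # Re-encode the title the way the split wiki dump names its files.
    return (best_match.replace(" ", "_")
                      .replace("/", "-SLH-")
                      .replace("(", "-LRB-")
                      .replace(")", "-RRB-"))

# Hypothetical usage:
# closest_wiki_doc("Grease (film)", ["Grease (film)", "Greece"])  ->  "Grease_-LRB-film-RRB-"

The linear scan over every title mirrors the committed code; a real deployment would likely pre-index the titles, but that is outside the scope of this diff.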
51 changes: 38 additions & 13 deletions generate_rte_preds.py
@@ -8,15 +8,20 @@
import codecs
import unicodedata as ud
from openie import StanfordOpenIE
import gensim


relevant_sentences_file = "data/dev_concatenation.jsonl"
concatenate_file = "data/dev_concatenation_oie_2.jsonl"
relevant_sentences_file = "data/dev_sentence_selection.jsonl"
ner_file = "data/subsample_train_concatenation_2_1.jsonl"
concatenate_file = "data/dev_sentence_selection_final.jsonl"

instances = []
zero_results = 0
INCLUDE_NER = False
INCLUDE_OIE = True
RUN_RTE = False
INCLUDE_OIE = False
RUN_RTE = True
RUN_TRIPLE_BASED = False
RUN_SENTENCE_BERT = True

relevant_sentences_file = jsonlines.open(relevant_sentences_file)
if RUN_RTE:
@@ -27,18 +32,25 @@
predictor = Predictor.from_archive(model)

wiki_dir = "data/wiki-pages/wiki-pages"
wiki_split_docs_dir = "data/wiki-pages-split"
wiki_split_docs_dir = "../wiki-pages-split"

claim_num = 1

wiki_entities = os.listdir(wiki_split_docs_dir)
# sum = 0
for i in range(len(wiki_entities)):
wiki_entities[i] = wiki_entities[i].replace("-SLH-", "/")
wiki_entities[i] = wiki_entities[i].replace("_", " ")
wiki_entities[i] = wiki_entities[i][:-5]
wiki_entities[i] = wiki_entities[i].replace("-LRB-", "(")
wiki_entities[i] = wiki_entities[i].replace("-RRB-", ")")
# tokens_sentence = gensim.utils.simple_preprocess(wiki_entities[i])
# tokens_sentence = gensim.utils.simple_preprocess(wiki_entities[i])
# num = len(tokens_sentence)
# #print(num)
# sum += num
# print(sum/len(wiki_entities))
# exit(0)

# wiki_entities[i] = ' '.join(map(str, tokens_sentence))

print("Wiki entities successfully parsed")
@@ -64,17 +76,30 @@ def run_rte(claim, evidence, claim_num):


with StanfordOpenIE() as client:
with jsonlines.open(concatenate_file, mode='w') as writer_c:
with jsonlines.open(concatenate_file, mode='w') as writer_c, \
jsonlines.open(ner_file, mode='w') as writer_n:
for i in range(0, len(instances)):
claim = instances[i]['claim']
print(claim)
evidence = instances[i]['predicted_sentences']
if RUN_TRIPLE_BASED:
if 'predicted_sentences_triple' in instances[i]:
evidence = instances[i]['predicted_sentences_triple']
print(evidence)
else:
evidence = instances[i]['predicted_sentences']
elif RUN_SENTENCE_BERT:
if 'predicted_sentences_bert' in instances[i]:
evidence = instances[i]['predicted_sentences_bert']
print("hello" + str(evidence))
else:
evidence = instances[i]['predicted_sentences']
else:
evidence = instances[i]['predicted_sentences']
potential_evidence_sentences = []

for sentence in evidence:
# print(sentence)
# print(sentence[0])
# load document from TF-IDF

# load document from sentence pair
relevant_doc = ud.normalize('NFC', sentence[0])
relevant_doc = relevant_doc.replace("/", "-SLH-")
file = codecs.open(wiki_split_docs_dir + "/" + relevant_doc + ".json", "r", "utf-8")
@@ -115,6 +140,7 @@ def run_rte(claim, evidence, claim_num):

instances[i]['predicted_pages_ner'] = relevant_docs
instances[i]['predicted_sentences_ner'] = predicted_evidence
writer_n.write(instances[i])

if RUN_RTE:
preds = run_rte(claim, potential_evidence_sentences, claim_num)
@@ -132,8 +158,6 @@ def run_rte(claim, evidence, claim_num):

saveFile.close()
claim_num += 1
# print(claim_num)
# print(instances[i])

if INCLUDE_OIE:
relevant_docs, entities = doc_retrieval.get_docs_with_oie(claim, wiki_entities, client)
@@ -143,4 +167,5 @@ def run_rte(claim, evidence, claim_num):
writer_c.write(instances[i])
print("Claim number: " + str(i) + " of " + str(len(instances)))


print("Number of Zero Sentences Found: " + str(zero_results))
146 changes: 0 additions & 146 deletions generate_rte_preds_1.py

This file was deleted.
