Skip to content

Commit

Permalink
Updated script to run with the SpacyNER tagger and REL linker (castor…
Browse files Browse the repository at this point in the history
…ini#1226)

* Updated script to run with the SpacyNER tagger and REL linker. It is RAM intensive

* Shortened left/right context creation into a one-liner. Corrected indentation from tabs to spaces

* Added safeguard to keep documents without entities in the final collection file. Fixed positions that refer to the original position of the mentions in the original text.

Co-authored-by: Gustavo Gonçalves <[email protected]>
  • Loading branch information
gsgoncalves and gsgoncalves authored Jul 10, 2022
1 parent b3f3d94 commit f553d43
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 42 deletions.
17 changes: 17 additions & 0 deletions docs/working-with-entity-linking.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,22 @@ python entity_linking.py --input_path [input_jsonl_file] --rel_base_url [base_ur
--spacy_model [en_core_web_sm, en_core_web_lg, etc.] --output_path [output_jsonl_file]
```

An extended example assuming you're running the script from the scripts dir:
```bash
REL_DATA_PATH=/home/$USER/REL/data
INPUT_JSONL_FILE=../collections/msmarco-passage/collection_jsonl/docs00.json
mkdir ../collections/msmarco-passage/collection_jsonl_with_entities/
OUTPUT_JSONL_FILE=../collections/msmarco-passage/msmarco-passage/collection_jsonl_with_entities/docs00.json
BASE_URL=$REL_DATA_PATH
ED_MODEL=$REL_DATA_PATH/ed-wiki-2019/model
WIKI_VERSION=wiki_2019
WIKIMAPPER_INDEX=$REL_DATA_PATH/index_enwiki-20190420.db

python entity_linking.py --input_path $INPUT_JSONL_FILE \
--rel_base_url $BASE_URL --rel_ed_model_path $ED_MODEL \
--rel_wiki_version $WIKI_VERSION --wikimapper_index $WIKIMAPPER_INDEX \
--spacy_model en_core_web_sm --output_path $OUTPUT_JSONL_FILE
```

It should take about 5 to 10 minutes to run entity linking on 5,000 MS MARCO passages on Compute Canada.
See [this](https://github.com/castorini/onboarding/blob/master/docs/cc-guide.md#compute-canada) for instructions about running scripts on Compute Canada.
140 changes: 98 additions & 42 deletions scripts/entity_linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,75 +17,130 @@
import argparse
import jsonlines
import spacy
from REL.REL.mention_detection import MentionDetection
from REL.REL.utils import process_results
from REL.REL.entity_disambiguation import EntityDisambiguation
from REL.REL.ner import NERBase, Span
import sys
from REL.mention_detection import MentionDetectionBase
from REL.utils import process_results, split_in_words
from REL.entity_disambiguation import EntityDisambiguation
from REL.ner import Span
from wikimapper import WikiMapper

from typing import Dict, List, Tuple
from tqdm import tqdm

# Spacy Mention Detection class which overrides the NERBase class in the REL entity linking process
class NERSpacy(NERBase):
def __init__(self):
class NERSpacyMD(MentionDetectionBase):
def __init__(self, base_url:str, wiki_version:str, spacy_model:str):
super().__init__(base_url, wiki_version)
# we only want to link entities of specific types
self.ner_labels = ['PERSON', 'NORP', 'FAC', 'ORG', 'GPE', 'LOC', 'PRODUCT', 'EVENT', 'WORK_OF_ART',
'LAW', 'LANGUAGE', 'DATE', 'TIME', 'MONEY', 'QUANTITY']
self.spacy_model = spacy_model
spacy.prefer_gpu()
self.tagger = spacy.load(spacy_model)

# mandatory function which overrides NERBase.predict()
def predict(self, doc):
mentions = []
def predict(self, doc: spacy.tokens.Doc) -> List[Span]:
spans = []
for ent in doc.ents:
if ent.label_ in self.ner_labels:
mentions.append(Span(ent.text, ent.start_char, ent.end_char, 0, ent.label_))
return mentions
spans.append(Span(ent.text, ent.start_char, ent.end_char, 0, ent.label_))
return spans

"""
Responsible for finding mentions given a set of documents in a batch-wise manner. More specifically,
it returns the mention, its left/right context and a set of candidates.
:return: Dictionary with mentions per document.
"""

def find_mentions(self, dataset: Dict[str, str]) -> Tuple[Dict[str, List[Dict]], int]:
results = {}
total_ment = 0
for i, doc in tqdm(enumerate(dataset), desc='Finding mentions', total=len(dataset)):
result_doc = []
doc_text = dataset[doc]
spacy_doc = self.tagger(doc_text)
spans = self.predict(spacy_doc)
for entity in spans:
text, start_pos, end_pos, conf, tag = (
entity.text,
entity.start_pos,
entity.end_pos,
entity.score,
entity.tag,
)
m = self.preprocess_mention(text)
cands = self.get_candidates(m)
if len(cands) == 0:
continue
total_ment += 1
# Re-create ngram as 'text' is at times changed by Flair (e.g. double spaces are removed).
ngram = doc_text[start_pos:end_pos]
left_ctxt = " ".join(split_in_words(doc_text[:start_pos])[-100:])
right_ctxt = " ".join(split_in_words(doc_text[end_pos:])[:100])
res = {
"mention": m,
"context": (left_ctxt, right_ctxt),
"candidates": cands,
"gold": ["NONE"],
"pos": start_pos,
"sent_idx": 0,
"ngram": ngram,
"end_pos": end_pos,
"sentence": doc_text,
"conf_md": conf,
"tag": tag,
}
result_doc.append(res)
results[doc] = result_doc
return results, total_ment


# run REL entity linking on processed doc
def rel_entity_linking(spacy_docs, rel_base_url, rel_wiki_version, rel_ed_model_path):
mention_detection = MentionDetection(rel_base_url, rel_wiki_version)
tagger_spacy = NERSpacy()
mentions_dataset, _ = mention_detection.find_mentions(spacy_docs, tagger_spacy)
def rel_entity_linking(docs: Dict[str,str], spacy_model:str, rel_base_url:str, rel_wiki_version:str, rel_ed_model_path:str) -> Dict[str, List[Tuple]]:
mention_detection = NERSpacyMD(rel_base_url, rel_wiki_version, spacy_model)
mentions_dataset, _ = mention_detection.find_mentions(docs)
config = {
'mode': 'eval',
'model_path': rel_ed_model_path,
}
ed_model = EntityDisambiguation(rel_base_url, rel_wiki_version, config)
predictions, _ = ed_model.predict(mentions_dataset)

linked_entities = process_results(mentions_dataset, predictions, spacy_docs)
linked_entities = process_results(mentions_dataset, predictions, docs)
return linked_entities


# apply spaCy nlp processing pipeline on each doc
def apply_spacy_pipeline(input_path, spacy_model):
nlp = spacy.load(spacy_model)
spacy_docs = {}
# read input pyserini json docs into a dictionary
def read_docs(input_path: str) -> Dict[str, str]:
docs = {}
with jsonlines.open(input_path) as reader:
for obj in reader:
spacy_docs[obj['id']] = nlp(obj['contents'])
return spacy_docs
for obj in tqdm(reader, desc='Reading docs'):
docs[obj['id']] = obj['contents']
return docs


# enrich REL entity linking results with entities' wikidata ids, and write final results as json objects
def enrich_el_results(rel_linked_entities, spacy_docs, wikimapper_index):
# rel_linked_entities: Tuples of entities are composed by start_pos:int, mention_length:int, ent_text:str, ent_wikipedia_id:str, conf_score:float, ner_score:int, ent_type:str
def enrich_el_results(rel_linked_entities: Dict[str, List[Tuple]], docs: Dict[str, str], wikimapper_index:str) -> List[Dict]:
wikimapper = WikiMapper(wikimapper_index)
linked_entities_json = []
for docid, ents in rel_linked_entities.items():
linked_entities_info = []
for start_pos, end_pos, ent_text, ent_wikipedia_id, ent_type in ents:
# find entities' wikidata ids using their REL results (i.e. linked wikipedia ids)
ent_wikipedia_id = ent_wikipedia_id.replace('&amp;', '&')
ent_wikidata_id = wikimapper.title_to_id(ent_wikipedia_id)

# write results as json objects
linked_entities_info.append({'start_pos': start_pos, 'end_pos': end_pos, 'ent_text': ent_text,
'wikipedia_id': ent_wikipedia_id, 'wikidata_id': ent_wikidata_id,
'ent_type': ent_type})
linked_entities_json.append({'id': docid, 'contents': spacy_docs[docid].text,
'entities': linked_entities_info})
for docid, doc_text in tqdm(docs.items(), desc='Enriching EL results', total=len(rel_linked_entities)):
if docid not in rel_linked_entities:
linked_entities_json.append({'id': docid, 'contents': doc_text, 'entities': []})
else:
linked_entities_info = []
ents = rel_linked_entities[docid]
for start_pos, mention_length, ent_text, ent_wikipedia_id, conf_score, ner_score, ent_type in ents:
# find entities' wikidata ids using their REL results (i.e. linked wikipedia ids)
ent_wikipedia_id = ent_wikipedia_id.replace('&amp;', '&')
ent_wikidata_id = wikimapper.title_to_id(ent_wikipedia_id)

# write results as json objects
linked_entities_info.append({'start_pos': start_pos, 'end_pos': start_pos + mention_length, 'ent_text': ent_text,
'wikipedia_id': ent_wikipedia_id, 'wikidata_id': ent_wikidata_id,
'ent_type': ent_type})
linked_entities_json.append({'id': docid, 'contents': doc_text, 'entities': linked_entities_info})
return linked_entities_json


def main():
parser = argparse.ArgumentParser()
parser.add_argument('-p', '--input_path', type=str, help='path to input texts')
Expand All @@ -97,13 +152,14 @@ def main():
parser.add_argument('-o', '--output_path', type=str, help='path to output json file')
args = parser.parse_args()

spacy_docs = apply_spacy_pipeline(args.input_path, args.spacy_model)
rel_linked_entities = rel_entity_linking(spacy_docs, args.rel_base_url, args.rel_wiki_version,
docs = read_docs(args.input_path)
rel_linked_entities = rel_entity_linking(docs, args.spacy_model, args.rel_base_url, args.rel_wiki_version,
args.rel_ed_model_path)
linked_entities_json = enrich_el_results(rel_linked_entities, spacy_docs, args.wikimapper_index)
linked_entities_json = enrich_el_results(rel_linked_entities, docs, args.wikimapper_index)
with jsonlines.open(args.output_path, mode='w') as writer:
writer.write_all(linked_entities_json)


if __name__ == '__main__':
main()
main()
sys.exit(0)

0 comments on commit f553d43

Please sign in to comment.