Mention detection with Bert #151

Status: Open. Wants to merge 62 commits into base: main.

Commits (62):
* af0da5e: efficiency test without server (eriktks, Dec 15, 2022)
* 26072db: efficiency test without server (eriktks, Dec 15, 2022)
* 5f850dc: efficiency test without server (eriktks, Dec 15, 2022)
* ca1b937: efficiency test without server (eriktks, Dec 15, 2022)
* 234e886: efficiency test without server (eriktks, Dec 15, 2022)
* f8f3d7e: fixed bert server usage (eriktks, Dec 16, 2022)
* f3914e9: fixed gerbil test problem (eriktks, Dec 16, 2022)
* 84f28d7: added multilingual bert (eriktks, Dec 20, 2022)
* 67696a9: refactored code (eriktks, Dec 22, 2022)
* 0af9d19: refactored code (eriktks, Dec 23, 2022)
* 8c18251: smooth installation updates (eriktks, Jan 5, 2023)
* a6ae211: fixed tests/test_ed_pipeline.py (eriktks, Jan 5, 2023)
* 1dd3a54: made required arguments optional (eriktks, Jan 5, 2023)
* e7b1604: code cleanup (eriktks, Jan 13, 2023)
* cd24c27: prune word-internal mentions (eriktks, Jan 13, 2023)
* f2a5514: solve initials bug (eriktks, Jan 13, 2023)
* 0b90309: file path standardization (eriktks, Jan 13, 2023)
* dc008c0: flagged flair with splitting (eriktks, Jan 16, 2023)
* 151bac4: move evaluate_predictions.py to scripts (eriktks, Jan 19, 2023)
* 894837d: move evaluate_predictions.py to scripts (eriktks, Jan 19, 2023)
* 08a5938: simplified NER tagger selection (eriktks, Jan 30, 2023)
* 083f7f5: skipped tests requiring data (eriktks, Jan 30, 2023)
* 061edc1: add defaults for arguments (eriktks, Jan 30, 2023)
* 369275f: replace next by continue (eriktks, Jan 30, 2023)
* 4e6703e: replace with list comprhension (eriktks, Jan 30, 2023)
* 7b15c15: simplify code (eriktks, Jan 30, 2023)
* 1ca5fbe: values became keyword arguments (eriktks, Jan 30, 2023)
* 9e3dae1: string formatting replaced rounding (eriktks, Jan 30, 2023)
* 25a5bf1: Update tests/test_evaluate_predictions.py (eriktks, Feb 14, 2023)
* 508483a: Update tests/test_evaluate_predictions.py (eriktks, Feb 14, 2023)
* cd86c35: Update tests/test_evaluate_predictions.py (eriktks, Feb 14, 2023)
* fa188e8: Update tests/test_evaluate_predictions.py (eriktks, Feb 14, 2023)
* cbdc791: fixed data format (eriktks, Feb 14, 2023)
* d31135e: make tests work (eriktks, Feb 13, 2024)
* 10a3d87: make tests work (eriktks, Feb 13, 2024)
* 93899ed: make tests work (eriktks, Feb 13, 2024)
* 72b31e0: make tests work (eriktks, Feb 13, 2024)
* 44dc91d: removed redundant function (eriktks, Feb 13, 2024)
* ea80f2f: use_server on same level (eriktks, Feb 20, 2024)
* 83ff7e6: fixed unreadable code (eriktks, Feb 27, 2024)
* 94cc303: base_url for defining path (eriktks, Feb 27, 2024)
* cdcf86c: use startswith iso re.search (eriktks, Feb 27, 2024)
* 8a05197: simplified computations (eriktks, Feb 27, 2024)
* 565a96a: print with % iso f (eriktks, Mar 5, 2024)
* 9d713cc: removed is_flair function argument (eriktks, Mar 5, 2024)
* afbb17f: removed function argument tagger_ner_name (eriktks, Mar 5, 2024)
* 6b5b10d: updated combine_entities output format (eriktks, Mar 5, 2024)
* c21e85f: replaced re.sub by str.removeprefix (eriktks, Mar 19, 2024)
* bdc149c: removed redundant re calls (eriktks, Mar 19, 2024)
* 2e73d0c: crash without loaded model (eriktks, Mar 19, 2024)
* 3bf20d2: simplified split_docs_value variable (eriktks, Mar 19, 2024)
* ef5c5a9: removed hard-coded paths (eriktks, Mar 19, 2024)
* e7304fb: enabling manual action run (eriktks, Mar 19, 2024)
* 7cd274f: changing python version (eriktks, Mar 19, 2024)
* e500990: chnaged pytest arguments (eriktks, Mar 19, 2024)
* 4858dce: fixing merge conflicts (eriktks, Mar 28, 2024)
* 425d917: solving most merge conflicts (eriktks, Mar 29, 2024)
* 191b08a: corrected incomplete path (eriktks, Apr 8, 2024)
* 59fb3ae: added documentation for ner (eriktks, Apr 8, 2024)
* 19ca5a4: restricted scipy version (eriktks, Apr 9, 2024)
* abf4152: restricted scipy version (eriktks, Apr 9, 2024)
* becdac1: restricted scipy version (eriktks, Apr 9, 2024)
The diff below shows changes from 17 of the 62 commits.
14 changes: 14 additions & 0 deletions README.md
@@ -5,6 +5,20 @@
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/radboud-el)](https://pypi.org/project/radboud-el/)
[![PyPI](https://img.shields.io/pypi/v/radboud-el.svg?style=flat)](https://pypi.org/project/radboud-el/)

---

Example tests:

* Flair: `python3 scripts/efficiency_test.py --process_sentences`
* Bert: `python3 scripts/efficiency_test.py --use_bert_base_uncased --split_docs_value 500`
* Server (slower):
* `python3 src/REL/server.py --use_bert_base_uncased --split_docs_value 500 --ed-model ed-wiki-2019 data wiki_2019`
* `python3 scripts/efficiency_test.py --use_server`

Requires the REL data files (`ed-wiki-2019`, `generic` and `wiki_2019`) to be installed in the `doc` directory.

---

REL is a modular Entity Linking package that is provided as a Python package as well as a web API. REL has various meanings: one might first notice that it stands for relation, a fitting name for the problems that can be tackled with this package. Additionally, in Dutch a 'rel' means a disturbance of the public order, which is exactly what we aim to achieve with the release of this package.

REL utilizes *English* Wikipedia as a knowledge base and can be used for the following tasks:
94 changes: 72 additions & 22 deletions scripts/efficiency_test.py
@@ -1,34 +1,74 @@
import argparse
from REL.evaluate_predictions import evaluate
import json
import numpy as np
import os
import re
import requests

from REL.ner.set_tagger_ner import set_tagger_ner
from REL.training_datasets import TrainingEvaluationDatasets

parser = argparse.ArgumentParser()
parser.add_argument("--max_docs", help = "number of documents")
parser.add_argument("--process_sentences", help = "process sentences rather than documents", action="store_true")
parser.add_argument("--split_docs_value", help = "threshold number of tokens to split document")
parser.add_argument("--use_bert_base_cased", help = "use Bert base cased rather than Flair", action="store_true")
parser.add_argument("--use_bert_large_cased", help = "use Bert large cased rather than Flair", action="store_true")
parser.add_argument("--use_bert_base_uncased", help = "use Bert base uncased rather than Flair", action="store_true")
parser.add_argument("--use_bert_large_uncased", help = "use Bert large uncased rather than Flair", action="store_true")
parser.add_argument("--use_bert_multilingual", help = "use Bert multilingual rather than Flair", action="store_true")
parser.add_argument("--use_server", help = "use server", action="store_true")
parser.add_argument("--wiki_version", help = "Wiki version")
args = parser.parse_args()

np.random.seed(seed=42)

base_url = "/Users/vanhulsm/Desktop/projects/data/"
wiki_version = "wiki_2014"
base_url = os.path.abspath(os.path.dirname(__file__) + "/../data/")
process_sentences = args.process_sentences

if args.split_docs_value:
    split_docs_value = int(args.split_docs_value)
else:
    split_docs_value = 0

if args.max_docs:
    max_docs = int(args.max_docs)
else:
    max_docs = 50

if args.wiki_version:
    wiki_version = args.wiki_version
else:
    wiki_version = "wiki_2019"
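As an aside, a more compact sketch (not part of this PR) would let argparse handle the casting and defaulting that the three if/else blocks above implement:

```python
# Hypothetical alternative: argparse converts and defaults the values itself,
# making the manual if/else blocks unnecessary.
parser.add_argument("--max_docs", type=int, default=50, help="number of documents")
parser.add_argument("--split_docs_value", type=int, default=0, help="threshold number of tokens to split document")
parser.add_argument("--wiki_version", default="wiki_2019", help="Wiki version")
```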

datasets = TrainingEvaluationDatasets(base_url, wiki_version).load()["aida_testB"]

# random_docs = np.random.choice(list(datasets.keys()), 50)
use_server = args.use_server
tagger_ner, tagger_ner_name = set_tagger_ner(args.use_bert_base_cased, args.use_bert_base_uncased, args.use_bert_large_cased, args.use_bert_large_uncased, args.use_bert_multilingual)

print(f"max_docs={max_docs} wiki_version={wiki_version} tagger_ner_name={tagger_ner_name} process_sentences={process_sentences} split_docs_value={split_docs_value}")

-server = True
docs = {}
all_results = {}
for i, doc in enumerate(datasets):
    sentences = []
    for x in datasets[doc]:
        if x["sentence"] not in sentences:
            sentences.append(x["sentence"])
-    text = ". ".join([x for x in sentences])
+    if len(sentences) == 0:
+        next

-    if len(docs) == 50:
-        print("length docs is 50.")
+    text = ". ".join([x for x in sentences])
+    if len(docs) >= max_docs:
+        print(f"length docs is {len(docs)}.")
        print("====================")
        break

    if len(text.split()) > 200:
        docs[doc] = [text, []]
        # Demo script that can be used to query the API.
-        if server:
+        if use_server:
            myjson = {
                "text": text,
                "spans": [
@@ -40,13 +80,24 @@
            print(myjson)

            print("Output API:")
-            print(requests.post("http://192.168.178.11:1235", json=myjson).json())
+            results = requests.post("http://0.0.0.0:5555", json=myjson)
+            print(results.json())
            print("----------------------------")

            try:
                results_list = []
                for result in results.json():
                    results_list.append({"mention": result[2], "prediction": result[3]})  # Flair + Bert
                all_results[doc] = results_list
            except json.decoder.JSONDecodeError:
                print("The analysis results are not in json format:", str(results))
                all_results[doc] = []

if len(all_results) > 0:
    evaluate(all_results)
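For reference, a sketch of the row shape this parsing assumes in the server response. Only indices 2 and 3 are actually used by the script; the other fields shown here are an assumption about the (start, length, mention, entity, scores, tag) layout of the REL API:

```python
# Assumed response row; the script only relies on result[2] (mention surface
# form) and result[3] (predicted entity). The other fields are illustrative.
example_row = [0, 14, "LEICESTERSHIRE", "Leicestershire_County_Cricket_Club", 0.98, 0.99, "ORG"]
mention, prediction = example_row[2], example_row[3]
```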

# --------------------- Now total --------------------------------
# ------------- RUN SEPARATELY TO BALANCE LOAD--------------------
-if not server:
+if not use_server:
    from time import time

    import flair
@@ -56,27 +107,26 @@
    from REL.entity_disambiguation import EntityDisambiguation
    from REL.mention_detection import MentionDetection

-    base_url = "C:/Users/mickv/desktop/data_back/"
+    from REL.ner.bert_wrapper import load_bert_ner

-    flair.device = torch.device("cuda:0")
+    flair.device = torch.device("cpu")

    mention_detection = MentionDetection(base_url, wiki_version)

-    # Alternatively use Flair NER tagger.
-    tagger_ner = SequenceTagger.load("ner-fast")

    start = time()
-    mentions_dataset, n_mentions = mention_detection.find_mentions(docs, tagger_ner)
-    print("MD took: {}".format(time() - start))
+    mentions_dataset, n_mentions = mention_detection.find_mentions(docs, tagger_ner_name, process_sentences, split_docs_value, tagger_ner)
+    print("MD took: {} seconds".format(round(time() - start, 2)))

-    # 3. Load model.
+    # 3. Load ED model.
    config = {
        "mode": "eval",
        "model_path": "{}/{}/generated/model".format(base_url, wiki_version),
    }
-    model = EntityDisambiguation(base_url, wiki_version, config)
+    ed_model = EntityDisambiguation(base_url, wiki_version, config)

    # 4. Entity disambiguation.
    start = time()
-    predictions, timing = model.predict(mentions_dataset)
-    print("ED took: {}".format(time() - start))
+    predictions, timing = ed_model.predict(mentions_dataset)
+    print("ED took: {} seconds".format(round(time() - start, 2)))

    evaluate(predictions)
6 changes: 3 additions & 3 deletions scripts/gerbil_middleware/Makefile
@@ -1,10 +1,10 @@
default: build dockerize

build:
-	mvn clean package -U
+	mvn clean package -U

dockerize:
-	docker build -t git.project-hobbit.eu:4567/gerbil/spotwrapnifws4test .
+	docker build -t git.project-hobbit.eu:4567/gerbil/spotwrapnifws4test .

push:
-	docker push git.project-hobbit.eu:4567/gerbil/spotwrapnifws4test
+	docker push git.project-hobbit.eu:4567/gerbil/spotwrapnifws4test
2 changes: 1 addition & 1 deletion scripts/gerbil_middleware/pom.xml
@@ -76,7 +76,7 @@
    <dependency>
      <groupId>org.apache.jena</groupId>
      <artifactId>jena-core</artifactId>
-      <version>4.2.0</version>
+      <version>2.11.1</version>
    </dependency>
    <dependency>
      <groupId>org.apache.jena</groupId>
2 changes: 2 additions & 0 deletions setup.cfg
@@ -50,6 +50,8 @@ install_requires =
    torch
    nltk
    anyascii
+    termcolor
+    syntok

[options.extras_require]
develop =
7 changes: 4 additions & 3 deletions src/REL/db/base.py
@@ -184,9 +184,10 @@ def lookup_wik(self, w, table_name, column):
"select {} from {} where word = :word".format(column, table_name),
{"word": w},
).fetchone()
res = (
e if e is None else json.loads(e[0].decode()) if column == "p_e_m" else e[0]
)
try:
res = ( e if e is None else json.loads(e[0].decode()) if column == "p_e_m" else e[0] )
except Exception:
res = ( e if e is None else json.loads("".join(chr(int(x, 2)) for x in e[0].split())) if column == "p_e_m" else e[0] )
eriktks marked this conversation as resolved.
Show resolved Hide resolved

return res

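The new except branch in lookup_wik handles p_e_m values that were stored as space-separated 8-bit binary codes instead of UTF-8 bytes. A minimal sketch of that fallback decoding, with an invented stored value:

```python
import json

# Invented example value: each space-separated group encodes one character
# in 8-bit binary; together they spell the JSON string '{"a":1}'.
encoded = "01111011 00100010 01100001 00100010 00111010 00110001 01111101"
decoded = "".join(chr(int(x, 2)) for x in encoded.split())
assert decoded == '{"a":1}'
print(json.loads(decoded))  # {'a': 1}
```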
136 changes: 136 additions & 0 deletions src/REL/evaluate_predictions.py
@@ -0,0 +1,136 @@
import re


UNUSED = -1


def get_gold_data(doc):
    GOLD_DATA_FILE = "./data/generic/test_datasets/AIDA/AIDA-YAGO2-dataset.tsv"
    entities = []

    in_file = open(GOLD_DATA_FILE, "r")
    for line in in_file:
        if re.search(f"^-DOCSTART- \({doc} ", line):
            break
    for line in in_file:
        if re.search(f"^-DOCSTART- ", line):
            break
        fields = line.strip().split("\t")
        if len(fields) > 3:
            if fields[1] == "B":
                entities.append([fields[2], fields[3]])
    return entities
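get_gold_data skips to the document's -DOCSTART- header, then collects a [mention, entity] pair from every line that has at least four tab-separated fields with B (begin of mention) in the second field. An illustrative fragment of the expected input (invented values; the real AIDA TSV carries additional columns):

```
-DOCSTART- (947testa SOCCER)
SOCCER	B	SOCCER	--NME--
LEICESTERSHIRE	B	LEICESTERSHIRE	Leicestershire_County_Cricket_Club
```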


def md_match(gold_entities, predicted_entities, predicted_links, gold_i, predicted_i):
    return gold_entities[gold_i][0].lower() == predicted_entities[predicted_i][0].lower()


def el_match(gold_entities, predicted_entities, predicted_links, gold_i, predicted_i):
    return (gold_entities[gold_i][0].lower() == predicted_entities[predicted_i][0].lower() and
            gold_entities[gold_i][1].lower() == predicted_entities[predicted_i][1].lower())


def find_correct_els(gold_entities, predicted_entities, gold_links, predicted_links):
    for gold_i in range(0, len(gold_entities)):
        if gold_links[gold_i] == UNUSED:
            for predicted_i in range(0, len(predicted_entities)):
                if (predicted_links[predicted_i] == UNUSED and
                        el_match(gold_entities, predicted_entities, predicted_links, gold_i, predicted_i)):
                    gold_links[gold_i] = predicted_i
                    predicted_links[predicted_i] = gold_i
    return gold_links, predicted_links


def find_correct_mds(gold_entities, predicted_entities, gold_links, predicted_links):
    for gold_i in range(0, len(gold_entities)):
        if gold_links[gold_i] == UNUSED:
            for predicted_i in range(0, len(predicted_entities)):
                if (predicted_links[predicted_i] == UNUSED and
                        md_match(gold_entities, predicted_entities, predicted_links, gold_i, predicted_i)):
                    gold_links[gold_i] = predicted_i
                    predicted_links[predicted_i] = gold_i
    return gold_links, predicted_links



def compare_entities(gold_entities, predicted_entities):
    gold_links = len(gold_entities) * [UNUSED]
    predicted_links = len(predicted_entities) * [UNUSED]
    gold_links, predicted_links = find_correct_els(gold_entities, predicted_entities, gold_links, predicted_links)
    gold_links, predicted_links = find_correct_mds(gold_entities, predicted_entities, gold_links, predicted_links)
    return gold_links, predicted_links
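To make the two-pass matching concrete, here is a small trace with invented data: the first pass links pairs that agree on both mention and entity, the second links leftover pairs that agree on the mention surface form only.

```python
gold = [["Bonn", "Bonn"], ["EU", "European_Union"]]
pred = [["EU", "European_Union"], ["Bonn", "Berlin"]]

gold_links, pred_links = compare_entities(gold, pred)
# EL pass links gold[1] <-> pred[0] (mention and entity match);
# MD pass then links gold[0] <-> pred[1] (mention matches, entity differs).
assert gold_links == [1, 0] and pred_links == [1, 0]
```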


def count_entities(gold_entities, predicted_entities, gold_links, predicted_links):
    correct = 0
    wrong_md = 0
    wrong_el = 0
    missed = 0
    for predicted_i in range(0, len(predicted_links)):
        if predicted_links[predicted_i] == UNUSED:
            wrong_md += 1
        elif predicted_entities[predicted_i][1] == gold_entities[predicted_links[predicted_i]][1]:
            correct += 1
        else:
            wrong_el += 1
    for gold_i in range(0, len(gold_links)):
        if gold_links[gold_i] == UNUSED:
            missed += 1
    return correct, wrong_md, wrong_el, missed


def compare_and_count_entities(gold_entities, predicted_entities):
    gold_links, predicted_links = compare_entities(gold_entities, predicted_entities)
    return count_entities(gold_entities, predicted_entities, gold_links, predicted_links)


def compute_md_scores(correct_all, wrong_md_all, wrong_el_all, missed_all):
    if correct_all + wrong_el_all > 0:
        precision_md = 100 * (correct_all + wrong_el_all) / (correct_all + wrong_el_all + wrong_md_all)
        recall_md = 100 * (correct_all + wrong_el_all) / (correct_all + wrong_el_all + missed_all)
        f1_md = 2 * precision_md * recall_md / (precision_md + recall_md)
    else:
        precision_md = 0
        recall_md = 0
        f1_md = 0
    return precision_md, recall_md, f1_md


def compute_el_scores(correct_all, wrong_md_all, wrong_el_all, missed_all):
    if correct_all > 0:
        precision_el = 100 * correct_all / (correct_all + wrong_md_all + wrong_el_all)
        recall_el = 100 * correct_all / (correct_all + wrong_el_all + missed_all)
        f1_el = 2 * precision_el * recall_el / (precision_el + recall_el)
    else:
        precision_el = 0.0
        recall_el = 0
        f1_el = 0
    return precision_el, recall_el, f1_el
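A worked example with invented counts shows how the two score sets differ: mention detection counts wrong-entity predictions (wrong_el) as detection hits, entity linking does not.

```python
# correct=8, wrong_md=2, wrong_el=1, missed=3 (invented counts)
p_md, r_md, f1_md = compute_md_scores(8, 2, 1, 3)
# precision_md = 100 * 9 / 11 = 81.8, recall_md = 100 * 9 / 12 = 75.0
p_el, r_el, f1_el = compute_el_scores(8, 2, 1, 3)
# precision_el = 100 * 8 / 11 = 72.7, recall_el = 100 * 8 / 12 = 66.7
```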


def print_scores(correct_all, wrong_md_all, wrong_el_all, missed_all):
    precision_md, recall_md, f1_md = compute_md_scores(correct_all, wrong_md_all, wrong_el_all, missed_all)
    precision_el, recall_el, f1_el = compute_el_scores(correct_all, wrong_md_all, wrong_el_all, missed_all)
    print("Results: PMD RMD FMD PEL REL FEL: ", end="")
    print(f"{precision_md:0.1f}% {recall_md:0.1f}% {f1_md:0.1f}% | ", end="")
    print(f"{precision_el:0.1f}% {recall_el:0.1f}% {f1_el:0.1f}%")
    return precision_md, recall_md, f1_md, precision_el, recall_el, f1_el


def evaluate(predictions):
    correct_all = 0
    wrong_md_all = 0
    wrong_el_all = 0
    missed_all = 0
    for doc in predictions:
        gold_entities = get_gold_data(doc)
        predicted_entities = []
        for mention in predictions[doc]:
            predicted_entities.append([mention["mention"], mention["prediction"]])
        correct, wrong_md, wrong_el, missed = compare_and_count_entities(gold_entities, predicted_entities)
        correct_all += correct
        wrong_md_all += wrong_md
        wrong_el_all += wrong_el
        missed_all += missed
    print_scores(correct_all, wrong_md_all, wrong_el_all, missed_all)
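A minimal usage sketch (the document key and prediction values are hypothetical; real keys are AIDA document identifiers):

```python
predictions = {
    "1163testb SOCCER": [  # hypothetical AIDA-style document id
        {"mention": "LEICESTERSHIRE", "prediction": "Leicestershire_County_Cricket_Club"},
    ],
}
evaluate(predictions)  # prints: Results: PMD RMD FMD PEL REL FEL: <six percentages>
```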