From fc0d076b2d73c99e6bfded8286e2870ae5c4bf02 Mon Sep 17 00:00:00 2001 From: Stef Smeets Date: Tue, 13 Dec 2022 14:35:30 +0100 Subject: [PATCH 01/15] Add submodule for conversational entity linking --- setup.cfg | 1 + src/REL/crel/__init__.py | 0 src/REL/crel/bert_md.py | 94 ++++ src/REL/crel/conv_el.py | 142 ++++++ src/REL/crel/rel_ed.py | 60 +++ src/REL/crel/s2e_pe/__init__.py | 0 src/REL/crel/s2e_pe/consts.py | 3 + .../crel/s2e_pe/coref_bucket_batch_sampler.py | 68 +++ src/REL/crel/s2e_pe/data.py | 260 ++++++++++ src/REL/crel/s2e_pe/modeling.py | 476 ++++++++++++++++++ src/REL/crel/s2e_pe/pe.py | 259 ++++++++++ src/REL/crel/s2e_pe/pe_data.py | 468 +++++++++++++++++ src/REL/crel/s2e_pe/utils.py | 202 ++++++++ tests/test_crel.py | 94 ++++ 14 files changed, 2127 insertions(+) create mode 100644 src/REL/crel/__init__.py create mode 100644 src/REL/crel/bert_md.py create mode 100644 src/REL/crel/conv_el.py create mode 100644 src/REL/crel/rel_ed.py create mode 100644 src/REL/crel/s2e_pe/__init__.py create mode 100644 src/REL/crel/s2e_pe/consts.py create mode 100644 src/REL/crel/s2e_pe/coref_bucket_batch_sampler.py create mode 100644 src/REL/crel/s2e_pe/data.py create mode 100644 src/REL/crel/s2e_pe/modeling.py create mode 100644 src/REL/crel/s2e_pe/pe.py create mode 100644 src/REL/crel/s2e_pe/pe_data.py create mode 100644 src/REL/crel/s2e_pe/utils.py create mode 100644 tests/test_crel.py diff --git a/setup.cfg b/setup.cfg index 8fbd4af..53a5c09 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,6 +47,7 @@ install_requires = konoha flair>=0.11 segtok + spacy torch nltk anyascii diff --git a/src/REL/crel/__init__.py b/src/REL/crel/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/REL/crel/bert_md.py b/src/REL/crel/bert_md.py new file mode 100644 index 0000000..23737bd --- /dev/null +++ b/src/REL/crel/bert_md.py @@ -0,0 +1,94 @@ +import torch +from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline + + +class BERT_MD: + def __init__(self, file_pretrained): + """ + + Args: + file_pretrained = "./tmp/ft-conel/" + + Note: + The output of self.ner_model(s_input) is like + - s_input: e.g, 'Burger King franchise' + - return: e.g., [{'entity': 'B-ment', 'score': 0.99364895, 'index': 1, 'word': 'Burger', 'start': 0, 'end': 6}, ...] + """ + + model = AutoModelForTokenClassification.from_pretrained(file_pretrained) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model.to(device) + tokenizer = AutoTokenizer.from_pretrained(file_pretrained) + self.ner_model = pipeline( + "ner", + model=model, + tokenizer=tokenizer, + device=device.index if device.index != None else -1, + ignore_labels=[], + ) + + def md(self, s, flag_warning=False): + """Perform mention detection + + Args: + s: input string + flag_warning: if True, print warning message + + Returns: + REL style annotation results: [(start_position, length, mention), ...] + E.g., [[0, 15, 'The Netherlands'], ...] + """ + + ann = self.ner_model(s) # Get ann results from BERT-NER model + + ret = [] + pos_start, pos_end = -1, -1 # Initialize variables + + for i in range(len(ann)): + w, ner = ann[i]["word"], ann[i]["entity"] + assert ner in [ + "B-ment", + "I-ment", + "O", + ], f"Unexpected ner tag: {ner}. If you use BERT-NER as it is, then you should flag_use_normal_bert_ner_tag=True." 
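+            # The branches below stitch the BIO tags (and stray "##" subwords)
+            # back into REL-style [start_position, length, mention] spans; e.g.
+            # tags like ('The', 'B-ment'), ('Netherlands', 'I-ment'), ('is', 'O')
+            # for "The Netherlands is ..." collapse into [0, 15, 'The Netherlands'].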
+ if ner == "B-ment" and w[:2] != "##": + if (pos_start != -1) and (pos_end != -1): # If B-ment is already found + ret.append( + [pos_start, pos_end - pos_start, s[pos_start:pos_end]] + ) # save the previously identified mention + pos_start, pos_end = -1, -1 # Initialize + pos_start, pos_end = ann[i]["start"], ann[i]["end"] + + elif ner == "B-ment" and w[:2] == "##": + if (ann[i]["index"] == ann[i - 1]["index"] + 1) and ( + ann[i - 1]["entity"] != "B-ment" + ): # If previous token has an entity (ner) label and it is NOT "B-ment" (i.e., ##xxx should not be the begin of the entity) + if flag_warning: + print( + f"WARNING: ##xxx (in this case {w}) should not be the begin of the entity" + ) + + elif ( + i > 0 + and (ner == "I-ment") + and (ann[i]["index"] == ann[i - 1]["index"] + 1) + ): # If w is I-ment and previous word's index (i.e., ann[i-1]['index']) is also a mention + pos_end = ann[i]["end"] # update pos_end + + # This only happens when flag_ignore_o is False + elif ( + ner == "O" + and w[:2] == "##" + and ( + ann[i - 1]["entity"] == "B-ment" or ann[i - 1]["entity"] == "I-ment" + ) + ): # If w is "O" and ##xxx, and previous token's index (i.e., ann[i-1]['index']) is B-ment or I-ment + pos_end = ann[i]["end"] # update pos_end + + # Append remaining ment + if (pos_start != -1) and (pos_end != -1): + ret.append( + [pos_start, pos_end - pos_start, s[pos_start:pos_end]] + ) # Save last mention + + return ret diff --git a/src/REL/crel/conv_el.py b/src/REL/crel/conv_el.py new file mode 100644 index 0000000..1ba276a --- /dev/null +++ b/src/REL/crel/conv_el.py @@ -0,0 +1,142 @@ +import importlib +import sys +from pathlib import Path + +from .bert_md import BERT_MD +from .rel_ed import REL_ED +from .s2e_pe import pe_data +from .s2e_pe.pe import EEMD, PEMD + + +class ConvEL: + def __init__( + self, base_url=".", wiki_version="wiki_2019", user_config=None, threshold=0 + ): + self.threshold = threshold + + self.wiki_version = wiki_version + self.base_url = base_url + self.file_pretrained = str(Path(base_url) / "bert_conv-td") + + self.bert_md = BERT_MD(self.file_pretrained) + self.rel_ed = REL_ED(self.base_url, self.wiki_version) + self.eemd = EEMD(s2e_pe_model=str(Path(base_url) / "s2e_ast_onto")) + self.pemd = PEMD() + + self.preprocess = pe_data.PreProcess() + self.postprocess = pe_data.PostProcess() + + # These are always initialize when get_annotations() is called + self.conv_hist_for_pe = ( + [] + ) # initialize the history of conversation, which is used in PE Linking + self.ment2ent = {} # This will be used for PE Linking + + def _error_check(self, conv): + assert type(conv) == list + for turn in conv: + assert type(turn) == dict + assert set(turn.keys()) == {"speaker", "utterance"} + assert turn["speaker"] in [ + "USER", + "SYSTEM", + ], f'Speaker should be either "USER" or "SYSTEM", but got {turn["speaker"]}' + + def _el(self, utt): + """Perform entity linking""" + # MD + md_results = self.bert_md.md(utt) + + # ED + spans = [[r[0], r[1]] for r in md_results] # r[0]: start, r[1]: length + el_results = self.rel_ed.ed(utt, spans) # ED + + self.conv_hist_for_pe[-1]["mentions"] = [r[2] for r in el_results] + self.ment2ent.update( + {r[2]: r[3] for r in el_results} + ) # If there is a mismatch of annotations for the same mentions, the last one (the most closest turn's one to the PEM) will be used. 
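+        # Each returned annotation is r[:4] in REL style, so one entry could
+        # look like [25, 11, 'time travel', 'Time_travel'] (illustrative values only).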
+ + return [r[:4] for r in el_results] # [start_pos, length, mention, entity] + + def _pe(self, utt): + """Perform PE Linking""" + + ret = [] + + # Step 1: PE Mention Detection + pem_results = self.pemd.pem_detector(utt) + pem2result = {r[2]: r for r in pem_results} + + # Step 2: Finding corresponding explicit entity mentions (EEMs) + # NOTE: Current implementation can handle only one target PEM at a time + outputs = [] + for _, _, pem in pem_results: # pems: [[start_pos, length, pem], ...] + self.conv_hist_for_pe[-1]["pems"] = [ + pem + ] # Create a conv for each target PEM that you want to link + + # Preprocessing + token_with_info = self.preprocess.get_tokens_with_info( + self.conv_hist_for_pe + ) + input_data = self.preprocess.get_input_of_pe_linking(token_with_info) + + assert ( + len(input_data) == 1 + ), f"Current implementation can handle only one target PEM at a time" + input_data = input_data[0] + + # Finding corresponding explicit entity mentions (EEMs) + scores = self.eemd.get_scores(input_data) + + # Post processing + outputs += self.postprocess.get_results( + input_data, self.conv_hist_for_pe, self.threshold, scores + ) + + self.conv_hist_for_pe[-1]["pems"] = [] # Remove the target PEM + + # Step 3: Get corresponding entity + for r in outputs: + pem = r["personal_entity_mention"] + pem_result = pem2result[pem] # [start_pos, length, pem] + eem = r["mention"] # Explicit entity mention + ent = self.ment2ent[eem] # Corresponding entity + ret.append( + [pem_result[0], pem_result[1], pem_result[2], ent] + ) # [start_pos, length, PEM, entity] + + return ret + + def annotate(self, conv): + """Get conversational entity linking annotations + + Args: + conv: A list of dicts, each dict contains "speaker" and "utterance" keys. + + Returns: + A list of dicts, each dict contains conv's ones + "annotations" key. + """ + self._error_check(conv) + ret = [] + self.conv_hist_for_pe = [] # Initialize + self.ment2ent = {} # Initialize + + for turn in conv: + utt = turn["utterance"] + assert turn["speaker"] in [ + "USER", + "SYSTEM", + ], f'Speaker should be either "USER" or "SYSTEM", but got {turn["speaker"]}' + ret.append({"speaker": turn["speaker"], "utterance": utt}) + + self.conv_hist_for_pe.append({}) + self.conv_hist_for_pe[-1]["speaker"] = turn["speaker"] + self.conv_hist_for_pe[-1]["utterance"] = utt + + if turn["speaker"] == "USER": + el_results = self._el(utt) + pe_results = self._pe(utt) + ret[-1]["annotations"] = el_results + pe_results + + return ret diff --git a/src/REL/crel/rel_ed.py b/src/REL/crel/rel_ed.py new file mode 100644 index 0000000..a52148c --- /dev/null +++ b/src/REL/crel/rel_ed.py @@ -0,0 +1,60 @@ +import sys + +from REL.entity_disambiguation import EntityDisambiguation +from REL.mention_detection import MentionDetection +from REL.utils import process_results + + +class REL_ED: + def __init__(self, base_url, wiki_version): + + config = { + "mode": "eval", + "model_path": f"{base_url}/{wiki_version}/generated/model", + } + + self.mention_detection = MentionDetection( + base_url, wiki_version + ) # This is only used for format spans + self.model = EntityDisambiguation(base_url, wiki_version, config) + + def generate_response(self, text, spans): + """Generate ED results + + Returns: + - list of tuples for each entity found. 
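+              (the first four elements of each tuple are start_pos, length,
+              mention and entity, as consumed via `r[:4]` in `ConvEL._el`)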
+ + Note: + - Original code: https://github.com/informagi/REL/blob/9ca253b1d371966c39219ed672f39784fd833d8d/REL/server.py#L101 + """ + + API_DOC = "API_DOC" + + if len(text) == 0 or len(spans) == 0: + return [] + + # Get the mentions from the spans + processed = {API_DOC: [text, spans]} + mentions_dataset, total_ment = self.mention_detection.format_spans(processed) + + # Disambiguation + predictions, timing = self.model.predict(mentions_dataset) + + # Process result. + result = process_results( + mentions_dataset, + predictions, + processed, + include_offset=False if (len(spans) > 0) else True, + ) + + # Singular document. + if len(result) > 0: + return [*result.values()][0] + + return [] + + def ed(self, text, spans): + """Change tuple to list to match the output format of REL API.""" + response = self.generate_response(text, spans) + return [list(ent) for ent in response] diff --git a/src/REL/crel/s2e_pe/__init__.py b/src/REL/crel/s2e_pe/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/REL/crel/s2e_pe/consts.py b/src/REL/crel/s2e_pe/consts.py new file mode 100644 index 0000000..97bffe5 --- /dev/null +++ b/src/REL/crel/s2e_pe/consts.py @@ -0,0 +1,3 @@ +SPEAKER_START = 49518 # 'Ġ#####' +SPEAKER_END = 22560 # 'Ġ###' +NULL_ID_FOR_COREF = 0 diff --git a/src/REL/crel/s2e_pe/coref_bucket_batch_sampler.py b/src/REL/crel/s2e_pe/coref_bucket_batch_sampler.py new file mode 100644 index 0000000..b252973 --- /dev/null +++ b/src/REL/crel/s2e_pe/coref_bucket_batch_sampler.py @@ -0,0 +1,68 @@ +import logging +import math +import random +from typing import Iterable, List + +from torch.utils import data +from torch.utils.data import DataLoader + +logger = logging.getLogger(__name__) + + +class BucketBatchSampler(DataLoader): + def __init__( + self, + data_source: data.Dataset, + max_total_seq_len: int, + sorting_keys: List[str] = None, + padding_noise: float = 0.1, + drop_last: bool = False, + batch_size_1: bool = False, + ): + self.sorting_keys = sorting_keys + self.padding_noise = padding_noise + self.max_total_seq_len = max_total_seq_len + self.data_source = data_source + data_source.examples.sort(key=lambda x: len(x[1].token_ids), reverse=True) + self.drop_last = drop_last + self.batches = ( + self.prepare_batches() if not batch_size_1 else self.prepare_eval_batches() + ) + + def prepare_batches(self): + batches = [] + batch = [] + per_example_batch_len = 0 + for ( + doc_key, + elem, + ) in self.data_source: # NOTE: The `doc_key` also contains subtoken map + batch.append(elem) + batch = self.data_source.pad_batch(batch, len(batch[0].token_ids)) + batches.append((doc_key, batch)) + batch = [] + + if len(batch) == 0: + return batches + batch = self.data_source.pad_batch(batch, len(batch[0].token_ids)) + batches.append( + (doc_key, batch) + ) # 220229: Change to return dockey as `prepare_eval_batches()` does. 
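+        # Each element of `batches` is a (doc_key, padded_batch) pair; the
+        # doc_key part also carries the end-token-to-word (subtoken) map noted
+        # in the loop above, and padded_batch holds the padded input_ids,
+        # attention_mask and clusters tensors.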
+ return batches + + def __iter__(self) -> Iterable[List[int]]: + random.shuffle(self.batches) + yield from self.batches + + def __len__(self): + return len(self.batches) + + def calc_effective_per_example_batch_len(self, example_len): + return math.ceil((example_len + 2) / 512) * 512 + + def prepare_eval_batches(self): + batches = [] + for doc_key, elem in self.data_source: + batch = self.data_source.pad_batch([elem], len(elem.token_ids)) + batches.append((doc_key, batch)) + return batches diff --git a/src/REL/crel/s2e_pe/data.py b/src/REL/crel/s2e_pe/data.py new file mode 100644 index 0000000..7cb5488 --- /dev/null +++ b/src/REL/crel/s2e_pe/data.py @@ -0,0 +1,260 @@ +import json +import logging +import os +import pickle +from collections import namedtuple + +import torch +from torch.utils.data import Dataset + +from .consts import NULL_ID_FOR_COREF, SPEAKER_END, SPEAKER_START +from .utils import flatten_list_of_lists + +CorefExample = namedtuple("CorefExample", ["token_ids", "clusters"]) + +logger = logging.getLogger(__name__) + + +class CorefDataset(Dataset): + def __init__(self, input_data, tokenizer, model_name_or_path, max_seq_length=-1): + self.tokenizer = tokenizer + ( + examples, + self.max_mention_num, + self.max_cluster_size, + self.max_num_clusters, + dockey2eems_tokenspan, + dockey2pems_tokenspan, + ) = self._parse_jsonlines(input_data) + self.max_seq_length = max_seq_length + ( + self.examples, + self.lengths, + self.num_examples_filtered, + self.dockey2eems_subtokenspan, + self.dockey2pems_subtokenspan, + ) = self._tokenize( + examples, dockey2eems_tokenspan, dockey2pems_tokenspan, model_name_or_path + ) + logger.info( + f"Finished preprocessing Coref dataset. {len(self.examples)} examples were extracted, {self.num_examples_filtered} were filtered due to sequence length." + ) + + def _parse_jsonlines(self, d): + examples = [] + max_mention_num = -1 + max_cluster_size = -1 + max_num_clusters = -1 + dockey2pems_tokenspan = {} + dockey2eems_tokenspan = {} + doc_key = d["doc_key"] + assert ( + type(d["sentences"][0]) == list + ), "'sentences' should be 2d list, not just a 1d list of the tokens." 
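+        # e.g. d["sentences"] == [["I", "think", "science", "fiction", ...], ...]
+        # (one token list per utterance); it is flattened into a single list of
+        # input words below.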
+ input_words = flatten_list_of_lists(d["sentences"]) + clusters = d["clusters"] + max_mention_num = max(max_mention_num, len(flatten_list_of_lists(clusters))) + max_cluster_size = max( + max_cluster_size, + max(len(cluster) for cluster in clusters) if clusters else 0, + ) + max_num_clusters = max(max_num_clusters, len(clusters) if clusters else 0) + speakers = flatten_list_of_lists(d["speakers"]) + examples.append((doc_key, input_words, clusters, speakers)) + dockey2eems_tokenspan[doc_key] = d["mentions"] + dockey2pems_tokenspan[doc_key] = d["pems"] + return ( + examples, + max_mention_num, + max_cluster_size, + max_num_clusters, + dockey2eems_tokenspan, + dockey2pems_tokenspan, + ) + + def _tokenize( + self, examples, dockey2eems_tokenspan, dockey2pems_tokenspan, model_name_or_path + ): + coref_examples = [] + lengths = [] + num_examples_filtered = 0 + dockey2eems_subtokenspan = {} + dockey2pems_subtokenspan = {} + for doc_key, words, clusters, speakers in examples: + word_idx_to_start_token_idx = dict() + word_idx_to_end_token_idx = dict() + end_token_idx_to_word_idx = [0] # for + + token_ids = [] + last_speaker = None + for idx, (word, speaker) in enumerate(zip(words, speakers)): + if last_speaker != speaker: + speaker_prefix = ( + [SPEAKER_START] + + self.tokenizer.encode(" " + speaker, add_special_tokens=False) + + [SPEAKER_END] + ) + last_speaker = speaker + else: + speaker_prefix = [] + for _ in range(len(speaker_prefix)): + end_token_idx_to_word_idx.append(idx) + token_ids.extend(speaker_prefix) + word_idx_to_start_token_idx[idx] = len(token_ids) + 1 # +1 for + tokenized = self.tokenizer.encode(" " + word, add_special_tokens=False) + for _ in range(len(tokenized)): + end_token_idx_to_word_idx.append(idx) + token_ids.extend(tokenized) + word_idx_to_end_token_idx[idx] = len( + token_ids + ) # old_seq_len + 1 (for ) + len(tokenized_word) - 1 (we start counting from zero) = len(token_ids) + + if 0 < self.max_seq_length < len(token_ids): + num_examples_filtered += 1 + continue + + new_clusters = [ + [ + (word_idx_to_start_token_idx[start], word_idx_to_end_token_idx[end]) + for start, end in cluster + ] + for cluster in clusters + ] + lengths.append(len(token_ids)) + + coref_examples.append( + ( + (doc_key, end_token_idx_to_word_idx), + CorefExample(token_ids=token_ids, clusters=new_clusters), + ) + ) + + dockey2eems_subtokenspan[doc_key] = [ + (word_idx_to_start_token_idx[start], word_idx_to_end_token_idx[end]) + for start, end in dockey2eems_tokenspan[doc_key] + ] + dockey2pems_subtokenspan[doc_key] = [ + (word_idx_to_start_token_idx[start], word_idx_to_end_token_idx[end]) + for start, end in dockey2pems_tokenspan[doc_key] + ] + + return ( + coref_examples, + lengths, + num_examples_filtered, + dockey2eems_subtokenspan, + dockey2pems_subtokenspan, + ) + + def __len__(self): + return len(self.examples) + + def __getitem__(self, item): + return self.examples[item] + + def pad_clusters_inside(self, clusters): + return [ + cluster + + [(NULL_ID_FOR_COREF, NULL_ID_FOR_COREF)] + * (self.max_cluster_size - len(cluster)) + for cluster in clusters + ] + + def pad_clusters_outside(self, clusters): + return clusters + [[]] * (self.max_num_clusters - len(clusters)) + + def pad_clusters(self, clusters): + clusters = self.pad_clusters_outside(clusters) + clusters = self.pad_clusters_inside(clusters) + return clusters + + def _pe_create_tensored_batch(self, padded_batch, len_example): + """Create tensored_batch avoiding errors + Original code was: + `tensored_batch = 
tuple(torch.stack([example[i].squeeze() for example in padded_batch], dim=0) for i in range(len(example)))` + However, this does not handle the single cluster case (E.g., "clusters": [[[135, 136], [273, 273]]] in the train.english.jsonlines) + + The error caused by the above is like (220322): + gold_clusters = [tuple(tuple(m) for m in gc if NULL_ID_FOR_COREF not in m) for gc in gold_clusters.tolist()] + TypeError: argument of type 'int' is not iterable + + - Updates: + - 220228: Created + - 220322: Write the error details + """ + assert len_example == 3 + tensored_batch = tuple() + for i in range(len_example): + to_stack = [] + for example in padded_batch: + assert ( + len(example) == 3 + ), f"example contains three components: input_ids, attention_mask, and clusters. Current len(examples): {len(example)}" + if i < 2: # input_ids and attention_mask + to_stack.append(example[i].squeeze()) + elif i == 2: # clusters + to_stack.append( + example[i] + ) # squeeze() cause the error to single-cluster case + # add to_stack to tensored_batch (tuple) + tensored_batch += (torch.stack(to_stack, dim=0),) + return tensored_batch + + def pad_batch(self, batch, max_length): + max_length += 2 # we have additional two special tokens , + padded_batch = [] + for example in batch: + # encoded_dict = self.tokenizer.encode_plus(example[0], # This does not work transformers v4.18.0 (works with v3.3.1) + # See: https://github.com/huggingface/transformers/issues/10297 + encoded_dict = self.tokenizer.encode_plus( + example[0], + use_fast=False, + add_special_tokens=True, + pad_to_max_length=True, + max_length=max_length, + return_attention_mask=True, + return_tensors="pt", + ) + clusters = self.pad_clusters(example.clusters) + example = (encoded_dict["input_ids"], encoded_dict["attention_mask"]) + ( + torch.tensor(clusters), + ) + padded_batch.append(example) + # tensored_batch = tuple(torch.stack([example[i].squeeze() for example in padded_batch], dim=0) for i in range(len(example))) + tensored_batch = self._pe_create_tensored_batch( + padded_batch, len(example) + ) # HJ: 220228 + return tensored_batch + + +def get_dataset(tokenizer, input_data, conf): + """Read input data + + Args: + - tokenizer + - input_data (dict): Input dict containing the following keys: + dict_keys(['clusters', 'doc_key', 'mentions', 'pems', 'sentences', 'speakers']) + E.g., + test_jsonl = { + "clusters": [[[78, 83], [88, 89]]], # This can be blank when you want to perform prediction. + "doc_key": "dialind:0_turn:3_pem:my-favorite-forms-of-science-fiction", # doc_key should be unique, no restrictions on the format + "mentions": [[35, 35], [37, 38], [74, 74], [85, 85], [88, 89]], # mentions and spans should be token-level spans (i.e., different from REL). See original document of s2e-coref. + "pems": [[78, 83]], + "sentences": [["I", "think", "science", "fiction", "is", ...], ...], + "speakers": [["SYSTEM", "SYSTEM", "SYSTEM", ..., "USER", "USER", "USER", ...], ...], } + + + Returns: + - dataset (CorefDataset): + + Notes: + - Currently, parallel processing is not supported, i.e., you cannot input more than or equal to two sentences or PEMs at the same time. 
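+    Example (illustrative):
+        >>> dataset = get_dataset(tokenizer, test_jsonl, conf)
+        >>> len(dataset)  # number of (doc_key, CorefExample) pairs kept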
+ """ + + coref_dataset = CorefDataset( + input_data, + tokenizer, + max_seq_length=conf.max_seq_length, + model_name_or_path=conf.model_name_or_path, + ) + return coref_dataset diff --git a/src/REL/crel/s2e_pe/modeling.py b/src/REL/crel/s2e_pe/modeling.py new file mode 100644 index 0000000..7ed3150 --- /dev/null +++ b/src/REL/crel/s2e_pe/modeling.py @@ -0,0 +1,476 @@ +import torch +from torch.nn import Dropout, LayerNorm, Linear, Module +from transformers import BertModel, BertPreTrainedModel, LongformerModel + +try: # If you use `211018_s2e_coref` + from transformers.modeling_bert import ACT2FN +except: # If you use `jupyterlab-debugger` + from transformers.models.bert.modeling_bert import ACT2FN + +import json +import os + +from .utils import extract_clusters # , ce_get_start_end_subtoken_num +from .utils import extract_mentions_to_predicted_clusters_from_clusters, mask_tensor + + +class FullyConnectedLayer(Module): + def __init__(self, config, input_dim, output_dim, dropout_prob): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.dropout_prob = dropout_prob + + self.dense = Linear(self.input_dim, self.output_dim) + self.layer_norm = LayerNorm(self.output_dim, eps=config.layer_norm_eps) + self.activation_func = ACT2FN[config.hidden_act] + self.dropout = Dropout(self.dropout_prob) + + def forward(self, inputs): + temp = inputs + temp = self.dense(temp) + temp = self.activation_func(temp) + temp = self.layer_norm(temp) + temp = self.dropout(temp) + return temp + + +class S2E(BertPreTrainedModel): + def __init__(self, config, args): + super().__init__(config) + self.max_span_length = args.max_span_length + self.top_lambda = args.top_lambda + self.ffnn_size = args.ffnn_size + self.do_mlps = self.ffnn_size > 0 + self.ffnn_size = self.ffnn_size if self.do_mlps else config.hidden_size + self.normalise_loss = args.normalise_loss + + # self.longformer = LongformerModel(config) + self.longformer = BertModel(config) + + self.start_mention_mlp = ( + FullyConnectedLayer( + config, config.hidden_size, self.ffnn_size, args.dropout_prob + ) + if self.do_mlps + else None + ) + self.end_mention_mlp = ( + FullyConnectedLayer( + config, config.hidden_size, self.ffnn_size, args.dropout_prob + ) + if self.do_mlps + else None + ) + self.start_coref_mlp = ( + FullyConnectedLayer( + config, config.hidden_size, self.ffnn_size, args.dropout_prob + ) + if self.do_mlps + else None + ) + self.end_coref_mlp = ( + FullyConnectedLayer( + config, config.hidden_size, self.ffnn_size, args.dropout_prob + ) + if self.do_mlps + else None + ) + + self.start_coref_mlp = ( + FullyConnectedLayer( + config, config.hidden_size, self.ffnn_size, args.dropout_prob + ) + if self.do_mlps + else None + ) + self.end_coref_mlp = ( + FullyConnectedLayer( + config, config.hidden_size, self.ffnn_size, args.dropout_prob + ) + if self.do_mlps + else None + ) + + self.mention_start_classifier = Linear(self.ffnn_size, 1) + self.mention_end_classifier = Linear(self.ffnn_size, 1) + self.mention_s2e_classifier = Linear(self.ffnn_size, self.ffnn_size) + + self.antecedent_s2s_classifier = Linear(self.ffnn_size, self.ffnn_size) + self.antecedent_e2e_classifier = Linear(self.ffnn_size, self.ffnn_size) + self.antecedent_s2e_classifier = Linear(self.ffnn_size, self.ffnn_size) + self.antecedent_e2s_classifier = Linear(self.ffnn_size, self.ffnn_size) + + self.init_weights() + + def _get_span_mask(self, batch_size, k, max_k): + """ + :param batch_size: int + :param k: tensor of size [batch_size], with the required k 
for each example + :param max_k: int + :return: [batch_size, max_k] of zero-ones, where 1 stands for a valid span and 0 for a padded span + """ + size = (batch_size, max_k) + idx = torch.arange(max_k, device=self.device).unsqueeze(0).expand(size) + len_expanded = k.unsqueeze(1).expand(size).to(self.device) + ret = (idx < len_expanded).int() + return ret + + def _prune_topk_mentions(self, mention_logits, attention_mask): + """ + :param mention_logits: Shape [batch_size, seq_length, seq_length] + :param attention_mask: [batch_size, seq_length] + :param top_lambda: + :return: + """ + batch_size, seq_length, _ = mention_logits.size() + actual_seq_lengths = torch.sum(attention_mask, dim=-1) # [batch_size] + + k = ( + actual_seq_lengths * self.top_lambda + ).int() # [batch_size] # top_lambda is in argument of the `run_coref.py` + max_k = int( + torch.max(k) + ) # This is the k for the largest input in the batch, we will need to pad + + _, topk_1d_indices = torch.topk( + mention_logits.view(batch_size, -1), dim=-1, k=max_k + ) # [batch_size, max_k] + span_mask = self._get_span_mask(batch_size, k, max_k) # [batch_size, max_k] + topk_1d_indices = (topk_1d_indices * span_mask) + (1 - span_mask) * ( + (seq_length**2) - 1 + ) # We take different k for each example + sorted_topk_1d_indices, _ = torch.sort( + topk_1d_indices, dim=-1 + ) # [batch_size, max_k] + + topk_mention_start_ids = ( + sorted_topk_1d_indices // seq_length + ) # [batch_size, max_k] + topk_mention_end_ids = ( + sorted_topk_1d_indices % seq_length + ) # [batch_size, max_k] + + topk_mention_logits = mention_logits[ + torch.arange(batch_size).unsqueeze(-1).expand(batch_size, max_k), + topk_mention_start_ids, + topk_mention_end_ids, + ] # [batch_size, max_k] + + topk_mention_logits = topk_mention_logits.unsqueeze( + -1 + ) + topk_mention_logits.unsqueeze( + -2 + ) # [batch_size, max_k, max_k] + + return ( + topk_mention_start_ids, + topk_mention_end_ids, + span_mask, + topk_mention_logits, + ) + + def _ce_prune_pem_eem( + self, mention_logits, pem_eem_subtokenspan + ): # attention_mask, subtoken_map, pem_eem_subtokenspan): + + batch_size, seq_length, _ = mention_logits.size() + assert batch_size == 1 # HJ: currently, only batch_size==1 is supported + + k = torch.Tensor([len(pem_eem_subtokenspan)]) + + max_k = int( + torch.max(k) + ) # This is the k for the largest input in the batch, we will need to pad + + span_mask = self._get_span_mask(batch_size, k, max_k) # [batch_size, max_k] + + pem_eem_start_ids, pem_eem_end_ids = torch.Tensor( + [[span[0] for span in pem_eem_subtokenspan]] + ).long().to(self.device), torch.Tensor( + [[span[1] for span in pem_eem_subtokenspan]] + ).long().to( + self.device + ) # HJ: 220302: [[...]] should be use because we have n batchs (here, we use n=1 though) + + return ( + pem_eem_start_ids, + pem_eem_end_ids, + span_mask, + None, + ) # topk_mention_logits + + def _mask_antecedent_logits(self, antecedent_logits, span_mask): + # We now build the matrix for each pair of spans (i,j) - whether j is a candidate for being antecedent of i? 
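+        # A strictly lower-triangular mask (tril with diagonal=-1) keeps only
+        # pairs where the candidate antecedent j precedes span i; multiplying by
+        # span_mask then drops padded spans before mask_tensor() is applied.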
+ antecedents_mask = torch.ones_like(antecedent_logits, dtype=self.dtype).tril( + diagonal=-1 + ) # [batch_size, k, k] + antecedents_mask = antecedents_mask * span_mask.unsqueeze( + -1 + ) # [batch_size, k, k] + antecedent_logits = mask_tensor(antecedent_logits, antecedents_mask) + return antecedent_logits + + def _get_cluster_labels_after_pruning(self, span_starts, span_ends, all_clusters): + """ + :param span_starts: [batch_size, max_k] + :param span_ends: [batch_size, max_k] + :param all_clusters: [batch_size, max_cluster_size, max_clusters_num, 2] + :return: [batch_size, max_k, max_k + 1] - [b, i, j] == 1 if i is antecedent of j + """ + batch_size, max_k = span_starts.size() + new_cluster_labels = torch.zeros((batch_size, max_k, max_k + 1), device="cpu") + all_clusters_cpu = all_clusters.cpu().numpy() + for b, (starts, ends, gold_clusters) in enumerate( + zip(span_starts.cpu().tolist(), span_ends.cpu().tolist(), all_clusters_cpu) + ): + gold_clusters = extract_clusters(gold_clusters) + mention_to_gold_clusters = ( + extract_mentions_to_predicted_clusters_from_clusters(gold_clusters) + ) + gold_mentions = set(mention_to_gold_clusters.keys()) + for i, (start, end) in enumerate(zip(starts, ends)): + if (start, end) not in gold_mentions: + continue + for j, (a_start, a_end) in enumerate(list(zip(starts, ends))[:i]): + if (a_start, a_end) in mention_to_gold_clusters[(start, end)]: + new_cluster_labels[b, i, j] = 1 + new_cluster_labels = new_cluster_labels.to(self.device) + no_antecedents = 1 - torch.sum(new_cluster_labels, dim=-1).bool().float() + new_cluster_labels[:, :, -1] = no_antecedents + return new_cluster_labels + + def _get_marginal_log_likelihood_loss( + self, coref_logits, cluster_labels_after_pruning, span_mask + ): + """ + :param coref_logits: [batch_size, max_k, max_k] + :param cluster_labels_after_pruning: [batch_size, max_k, max_k] + :param span_mask: [batch_size, max_k] + :return: + """ + gold_coref_logits = mask_tensor(coref_logits, cluster_labels_after_pruning) + + gold_log_sum_exp = torch.logsumexp( + gold_coref_logits, dim=-1 + ) # [batch_size, max_k] + all_log_sum_exp = torch.logsumexp(coref_logits, dim=-1) # [batch_size, max_k] + + gold_log_probs = gold_log_sum_exp - all_log_sum_exp + losses = -gold_log_probs + losses = losses * span_mask + per_example_loss = torch.sum(losses, dim=-1) # [batch_size] + if self.normalise_loss: + per_example_loss = per_example_loss / losses.size(-1) + loss = per_example_loss.mean() + return loss + + def _get_mention_mask(self, mention_logits_or_weights): + """ + Returns a tensor of size [batch_size, seq_length, seq_length] where valid spans + (start <= end < start + max_span_length) are 1 and the rest are 0 + :param mention_logits_or_weights: Either the span mention logits or weights, size [batch_size, seq_length, seq_length] + """ + mention_mask = torch.ones_like(mention_logits_or_weights, dtype=self.dtype) + mention_mask = mention_mask.triu(diagonal=0) + mention_mask = mention_mask.tril(diagonal=self.max_span_length - 1) + return mention_mask + + def _calc_mention_logits(self, start_mention_reps, end_mention_reps): + start_mention_logits = self.mention_start_classifier( + start_mention_reps + ).squeeze( + -1 + ) # [batch_size, seq_length] + end_mention_logits = self.mention_end_classifier(end_mention_reps).squeeze( + -1 + ) # [batch_size, seq_length] + + temp = self.mention_s2e_classifier( + start_mention_reps + ) # [batch_size, seq_length] + joint_mention_logits = torch.matmul( + temp, end_mention_reps.permute([0, 2, 1]) + ) # 
[batch_size, seq_length, seq_length] + + mention_logits = ( + joint_mention_logits + + start_mention_logits.unsqueeze(-1) + + end_mention_logits.unsqueeze(-2) + ) + mention_mask = self._get_mention_mask( + mention_logits + ) # [batch_size, seq_length, seq_length] + mention_logits = mask_tensor( + mention_logits, mention_mask + ) # [batch_size, seq_length, seq_length] + return mention_logits + + def _calc_coref_logits(self, top_k_start_coref_reps, top_k_end_coref_reps): + # s2s + temp = self.antecedent_s2s_classifier( + top_k_start_coref_reps + ) # [batch_size, max_k, dim] + top_k_s2s_coref_logits = torch.matmul( + temp, top_k_start_coref_reps.permute([0, 2, 1]) + ) # [batch_size, max_k, max_k] + + # e2e + temp = self.antecedent_e2e_classifier( + top_k_end_coref_reps + ) # [batch_size, max_k, dim] + top_k_e2e_coref_logits = torch.matmul( + temp, top_k_end_coref_reps.permute([0, 2, 1]) + ) # [batch_size, max_k, max_k] + + # s2e + temp = self.antecedent_s2e_classifier( + top_k_start_coref_reps + ) # [batch_size, max_k, dim] + top_k_s2e_coref_logits = torch.matmul( + temp, top_k_end_coref_reps.permute([0, 2, 1]) + ) # [batch_size, max_k, max_k] + + # e2s + temp = self.antecedent_e2s_classifier( + top_k_end_coref_reps + ) # [batch_size, max_k, dim] + top_k_e2s_coref_logits = torch.matmul( + temp, top_k_start_coref_reps.permute([0, 2, 1]) + ) # [batch_size, max_k, max_k] + + # sum all terms + coref_logits = ( + top_k_s2e_coref_logits + + top_k_e2s_coref_logits + + top_k_s2s_coref_logits + + top_k_e2e_coref_logits + ) # [batch_size, max_k, max_k] + return coref_logits + + def _ce_get_scores( + self, mention_start_ids, mention_end_ids, subtoken_map, final_logits, doc_key + ): + """Get scores""" + scores = ( + [] + ) # score_dc = [{'span_subtoken':(start_id, end_id), 'coref_logits':coref_logits}, ...] + _mention_start_ids_flatten = mention_start_ids[0] # [N_mention] + _mention_end_ids_flatten = mention_end_ids[0] # [N_mention] + _final_logits_2d = final_logits[ + 0 + ] # [N_mention, N_mention+1] (+1 is for threshold) + # Get the length of the _mention_start_ids_flatten + N = len(_mention_start_ids_flatten) + + for i in range(N): # loop for anaphoras + for j in range(N): # loop for antecedents + if i <= j: + continue # anaphoras should not come before antecedents + scores.append( + { + "doc_key": doc_key, + "span_token_anaphora": ( + int(subtoken_map[_mention_start_ids_flatten[i]]), + int(subtoken_map[_mention_end_ids_flatten[i]]), + ), + "span_token_antecedent": ( + int(subtoken_map[_mention_start_ids_flatten[j]]), + int(subtoken_map[_mention_end_ids_flatten[j]]), + ), + # 'span_subtoken': (int(_mention_start_ids_flatten[i]), int(_mention_end_ids_flatten[i])), + # 'subtoken_map': subtoken_map, + "coref_logits": float(_final_logits_2d[i][j]), + } + ) + + return scores + + # def forward(self, input_ids, attention_mask=None, gold_clusters=None, return_all_outputs=False): + def forward( + self, + input_ids, + attention_mask=None, + gold_clusters=None, + return_all_outputs=False, + subtoken_map=None, + pem_eem_subtokenspan=None, + doc_key=None, + ): + # TODO: Change the argument of this forward func, from `pem_eem` to `pem` and `eem` + # And do pem_eem = pem+eem + + outputs = self.longformer(input_ids, attention_mask=attention_mask) + sequence_output = outputs[0] # [batch_size, seq_len, dim] + # MEMO: `sequence_output` should be a hidden vector of the model. 
+ # Originally, this seems to be acquired by `outputs.last_hidden_state` (https://huggingface.co/transformers/master/model_doc/longformer.html) + + # Compute representations + start_mention_reps = ( + self.start_mention_mlp(sequence_output) if self.do_mlps else sequence_output + ) + end_mention_reps = ( + self.end_mention_mlp(sequence_output) if self.do_mlps else sequence_output + ) + + start_coref_reps = ( + self.start_coref_mlp(sequence_output) if self.do_mlps else sequence_output + ) + end_coref_reps = ( + self.end_coref_mlp(sequence_output) if self.do_mlps else sequence_output + ) + + # mention scores + mention_logits = self._calc_mention_logits(start_mention_reps, end_mention_reps) + + # prune mentions + # (span_mask: [batch_size, max_k] of zero-ones, where 1 stands for a valid span and 0 for a padded span) + # mention_start_ids, mention_end_ids, span_mask, topk_mention_logits = self._prune_topk_mentions(mention_logits, attention_mask) + mention_start_ids, mention_end_ids, span_mask, _ = self._ce_prune_pem_eem( + mention_logits, pem_eem_subtokenspan + ) + + batch_size, _, dim = start_coref_reps.size() + max_k = mention_start_ids.size(-1) + size = (batch_size, max_k, dim) + + # Antecedent scores + # gather reps + topk_start_coref_reps = torch.gather( + start_coref_reps, dim=1, index=mention_start_ids.unsqueeze(-1).expand(size) + ) + topk_end_coref_reps = torch.gather( + end_coref_reps, dim=1, index=mention_end_ids.unsqueeze(-1).expand(size) + ) + coref_logits = self._calc_coref_logits( + topk_start_coref_reps, topk_end_coref_reps + ) + final_logits = coref_logits # topk_mention_logits + coref_logits + final_logits = self._mask_antecedent_logits(final_logits, span_mask) + # adding zero logits for null span + final_logits = torch.cat( + (final_logits, torch.zeros((batch_size, max_k, 1), device=self.device)), + dim=-1, + ) # [batch_size, max_k, max_k + 1] + scores = self._ce_get_scores( + mention_start_ids, mention_end_ids, subtoken_map, final_logits, doc_key + ) + + if return_all_outputs: + outputs = (mention_start_ids, mention_end_ids, final_logits, mention_logits) + else: + outputs = tuple() + + if gold_clusters is not None: + losses = {} + labels_after_pruning = self._get_cluster_labels_after_pruning( + mention_start_ids, mention_end_ids, gold_clusters + ) + loss = self._get_marginal_log_likelihood_loss( + final_logits, labels_after_pruning, span_mask + ) # HJ: 220303: `labels_after_pruning` is strange... + losses.update({"loss": loss}) + outputs = (loss,) + outputs + (losses,) + + return outputs, scores diff --git a/src/REL/crel/s2e_pe/pe.py b/src/REL/crel/s2e_pe/pe.py new file mode 100644 index 0000000..e6ca80c --- /dev/null +++ b/src/REL/crel/s2e_pe/pe.py @@ -0,0 +1,259 @@ +# PEMD +from tokenizers.pre_tokenizers import Whitespace + +pre_tokenizer = Whitespace() +try: + import spacy + + nlp = spacy.load("en_core_web_md") +except: # From Google Colab (see https://stackoverflow.com/a/59197634) + import spacy.cli + + spacy.cli.download("en_core_web_md") + import en_core_web_md + + nlp = en_core_web_md.load() +import torch +from transformers import AutoConfig, AutoTokenizer, LongformerConfig + +# EEMD +from . 
import data +from .coref_bucket_batch_sampler import BucketBatchSampler +from .modeling import S2E +from .pe_data import PreProcess # to use get_span() + + +class PEMD: + """Responsible for personal entity mention detection""" + + def __init__(self): + self.pronouns = ["my", "our"] # These should be lowercase + self.preprocess = PreProcess() # to use get_span() + + def _extract_text_with_pronoun(self, utt: str, max_candidate_num=10): + """ + + Args: + max_candidate_num (int): Max following words num (which equals to candidate num). Does not contain "my/our" in this count. + + Example: + Input: 'Our town is big into high school football - our quarterback just left to go play for Clemson. Oh, that is my town.' + Output: + [{'extracted_text': 'Our town is big into high school football - our quarterback', + 'pronoun': ('Our', (0, 3))}, ...] + """ + if any( + [True for p in self.pronouns if p in utt.lower()] + ): # If at least one pronoun is in utt.lower() + ret = [] + else: # If no pronouns are in utt.lower() + return [] + + try: # if tokenizer version is 0.10.3 etc where pre_tokenize_str is available + ws = pre_tokenizer.pre_tokenize_str( + utt + ) # E.g., [('Our', (0, 3)), ('town', (4, 8)), ...] + except: # if 0.8.1.rc2 etc where pre_tokenizer_str is NOT available + ws = pre_tokenizer.pre_tokenize(utt) + for i, (w, _) in enumerate(ws): + if w.lower() in self.pronouns: + n_options = min( + max_candidate_num, len(ws[i:]) - 1 + ) # `len(ws[i:])` contains "my/our" so have to operate -1 + text_w_pronoun = utt[ + ws[i][1][0] : ws[i + n_options][1][1] + ] # E.g., 'our quarterback just ...' # `ws[i][1][0]`: start position. `ws[i+n+2][1][1]`: end position. + ret.append({"pronoun": ws[i], "extracted_text": text_w_pronoun}) + return ret + + def pem_detector(self, utt): + """Mention detection for personal entities + + Args: + utt (str): Input utterance. + E.g., 'I agree. One of my favorite forms of science fiction is anything related to time travel! I find it fascinating.' + + Returns: + list of detected personal entity mentions. + E.g., ['my favorite forms of science fiction'] + + """ + _dc_list = self._extract_text_with_pronoun( + utt + ) # E.g., [{'extracted_text': 'Our town is big into high ...', 'pronoun': ('Our', (0, 3))}, ...] 
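+        # Each candidate text starts at "my"/"our"; the spaCy loop below extends
+        # the mention over ADJ/NOUN/PROPN/NUM/PART tokens (plus "of", "in", "the",
+        # "a", "an") and stops at the first other token, e.g. yielding
+        # "my favorite forms of science fiction".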
+ if len(_dc_list) == 0: + return [] + else: + texts_w_pronoun = [ + _dc["extracted_text"] for _dc in _dc_list + ] # E.g., ['Our town is big into ...', 'My dog loves human food!'] + + ret = [] + for text in texts_w_pronoun: # For each extracted text + doc = nlp(text) + ment = "" + end_pos = 0 # start_pos is always 0 + for i, token in enumerate(doc): + # print(token.pos_, token.text) + if i == 0: + assert ( + token.text.lower() in self.pronouns + ), f"{token.text} does not start with 'my' or 'our'" + end_pos = token.idx + len(token.text) # update end_pos + else: # i > 0 + if token.pos_ in [ + "ADJ", + "NOUN", + "PROPN", + "NUM", + "PART", + ] or token.text.lower() in [ + "of", + "in", + "the", + "a", + "an", + ]: + end_pos = token.idx + len(token.text) # update end_pos + else: + break + ment = text[:end_pos] + + ###### Post process ####### + # if end with " of " then remove it + for drop_term in [" of", " in", " to"]: + ment = ( + ment[: -(len(drop_term) - 1)] if ment.endswith(drop_term) else ment + ) + + if ( + len(ment) > min([len(pron) for pron in self.pronouns]) + 1 + ): # Want to ignore the case: "My " + ret.append(ment.strip()) + + # 220406: TMP error check + # TODO: Check this part whether it is needed or not + assert len(ment) != "our ", f'Should fix "if len(ment)>len(CLUE)+1" part.' + + # Sanity check + for ment in ret: + assert ment in utt, f"{ment} is not in {utt}" + + # Change to REL format [start_pos, length, mention] + ret = [ + [start_pos, len(m), m] + for m in ret + for start_pos, _ in self.preprocess.get_span( + m, utt, flag_start_end_span_representation=False + ) + ] + + return ret + + +class EEMD: + """Find corresponding explicit entity mention using s2e-coref-based method""" + + def __init__(self, s2e_pe_model): + self.conf = self.Config(s2e_pe_model=s2e_pe_model) + self.model = self._read_model() + + class Config: + """Inner class for config""" + + def __init__(self, s2e_pe_model): + self.max_seq_length = 4096 + self.model_name_or_path = s2e_pe_model + self.max_total_seq_len = 4096 + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # self.device = torch.device("cpu") # TMP + + # Config for NN model + # Params are from: https://github.com/yuvalkirstain/s2e-coref/blob/main/cli.py + self.max_span_length = 30 + self.top_lambda = 0.4 + self.ffnn_size = 3072 + self.normalise_loss = False + self.dropout_prob = 0.3 + + def _read_model(self): + config_class = LongformerConfig + base_model_prefix = "longformer" + + transformer_config = AutoConfig.from_pretrained( + self.conf.model_name_or_path + ) # , cache_dir=args.cache_dir) + + S2E.config_class = config_class + S2E.base_model_prefix = base_model_prefix + + model = S2E.from_pretrained( + self.conf.model_name_or_path, config=transformer_config, args=self.conf + ) + + model.to(self.conf.device) + + return model + + def get_scores(self, input_data): + """Calculate the score of each mention pair + Args: + input_data (dict): Input data. + E.g., {'clusters': [], # Not used for inference + 'doc_key': 'tmp', # Not used for inference + 'mentions': [[2, 3], [77, 78]], # Detected concept and NE mentions + 'pems': [[67, 72]], # Detected personal entity mention. Only one PEM is allowed now. + 'sentences': [['I', 'think', 'science', 'fiction', 'is', ...]], # tokenized sentences using tokenizers.pre_tokenizers + 'speakers': [['USER', 'USER', 'USER', 'USER', 'USER', ...]], # speaker information + 'text': None + } + + Returns: + The scores for each mention pair. 
The pairs which does not have any PEM are removed in the later post-processing. + E.g., + [{'doc_key': 'tmp', + 'span_token_anaphora': (67, 72), + 'span_token_antecedent': (2, 3), ...] + """ + assert ( + type(input_data) == dict + ), f"input_data should be a dict, but got {type(input_data)}" + tokenizer = AutoTokenizer.from_pretrained( + self.conf.model_name_or_path, use_fast=False + ) + # `use_fast=False` should be supecified for v4.18.0 (do not need to do this for v3.3.1) + # See: https://github.com/huggingface/transformers/issues/10297#issuecomment-812548803 + + eval_dataset = data.get_dataset(tokenizer, input_data, self.conf) + + eval_dataloader = BucketBatchSampler( + eval_dataset, + max_total_seq_len=self.conf.max_total_seq_len, + batch_size_1=True, + ) + + assert len(eval_dataloader) == 1, f"Currently, only 1 batch is supported" + for i, ((doc_key, subtoken_maps), batch) in enumerate(eval_dataloader): + # NOTE: subtoken_maps should NOT be used to map word -> subtoken!!! + # The original name of subtoken_maps is `end_token_idx_to_word_idx`, meaning this is intended to map subtoken (end token) -> word. + # NOTE: Currently, this code only supports only one example at a time, however, for the futurework, we keep this for loop here. + batch = tuple(tensor.to(self.conf.device) for tensor in batch) + input_ids, attention_mask, gold_clusters = batch + + with torch.no_grad(): + _, scores = self.model( + input_ids=input_ids, # This calls __call__ in module.py in PyTorch, and it calls S2E.forward(). + attention_mask=attention_mask, + gold_clusters=gold_clusters, + return_all_outputs=True, + subtoken_map=subtoken_maps, # HJ + pem_eem_subtokenspan=( + sorted( + eval_dataset.dockey2eems_subtokenspan[doc_key] + + eval_dataset.dockey2pems_subtokenspan[doc_key] + ) + ), # HJ + doc_key=doc_key, + ) # HJ: 220221 + + return scores diff --git a/src/REL/crel/s2e_pe/pe_data.py b/src/REL/crel/s2e_pe/pe_data.py new file mode 100644 index 0000000..250243d --- /dev/null +++ b/src/REL/crel/s2e_pe/pe_data.py @@ -0,0 +1,468 @@ +from tokenizers.pre_tokenizers import Whitespace + +pre_tokenizer = Whitespace() + +TMP_DOC_ID = "tmp" # temporary doc id + + +class PreProcess: + """Create input for PE Linking module""" + + def _error_check(self, conv): + assert type(conv) == list + assert len(conv) > 0, f"conv should be a list of dicts, but got {conv}" + for turn in conv: + assert type(turn) == dict, f"conv should be a list of dicts, but got {turn}" + assert set(turn.keys()) == { + "speaker", + "utterance", + "mentions", + "pems", + } or set(turn.keys()) == { + "speaker", + "utterance", + }, f"Each turn should have either [speaker, utterance, mentions, pems] keys for USER or [speaker, utterance] keys for SYSTEM. If there is no pems or mentions, then set them to empty list." + assert turn["speaker"] in [ + "USER", + "SYSTEM", + ], f'The speaker should be either USER or SYSTEM, but got {turn["speaker"]}' + assert ( + type(turn["utterance"]) == str + ), f'The utterance should be a string, but got {turn["utterance"]}' + if turn["speaker"] == "USER": + assert ( + type(turn["mentions"]) == list + ), f'The mentions should be a list, but got {turn["mentions"]}' + assert ( + type(turn["pems"]) == list + ), f'The pems should be a list, but got {turn["pems"]}' + + # Check there are only one pem per conv + pems = [pem for turn in conv if "pems" in turn for pem in turn["pems"]] + assert ( + len(pems) == 1 + ), f"Current implementation only supports one pem per input conv. 
If there are multiple PEM, then split them into multiple conv." # This is also a TODO for the future + + def get_span(self, ment, text, flag_start_end_span_representation=True): + """Get (start, end) span of a mention (inc. PEM) in a text + + Args: + ment (str): E.g., 'Paris' + text (str): E.g., 'Paris. The football club Paris Saint-Germain and the rugby union club Stade Français are based in Paris.' + + Returns: mention spans + if flag_start_end_span_representation==True: + E.g., [(0, 5), (25, 30), (98, 103)] + if flag_start_end_span_representation==False: + E.g., [(0, 5), (25, 5), (98, 5)] + + Note: + - re.finditer is NOT used since it takes regex pattern (not string) as input and fails for the patterns such as: + text = 'You you dance? I love cuban Salsa but also like other types as well. dance-dance.' + ment = 'dance? ' + """ + assert ment in text, f"The mention {ment} is not in the text {text}" + spans = [] # [(start_pos, length), ...] + offset = 0 + while True: + try: + start_pos = text.index(ment, offset) + spans.append((start_pos, len(ment))) + offset = start_pos + len(ment) + except: + break + + if ( + flag_start_end_span_representation + ): # (start_pos, length) --> (start_pos, end_pos) + spans = [(s, l + s) for s, l in spans] # pos_end = pos_start + length + + return spans + + def _token_belongs_to_mention( + self, m_spans: list, t_span: tuple, utt: str, print_warning=False + ) -> bool: + """Check whether token span is in ment span(s) or not + + Args: + m_spans: e.g., [(0, 4), (10, 14), (2,4)] + t_span: e.g., (1, 3) + """ + + def _error_check(m_spans, t_span): + assert len(t_span) == 2 + assert ( + t_span[1] > t_span[0] + ) # Note that span must be (start_ind, end_ind), NOT the REL style output of (start_ind, length) + assert any( + [True if m_span[1] > m_span[0] else False for m_span in m_spans] + ) # The same as above + + _error_check(m_spans, t_span) + + # Main + for m_span in m_spans: + + # if token span is out of mention span (i.e., does not have any overlaps), then go to next + t_out_m = (t_span[1] <= m_span[0]) or (m_span[1] <= t_span[0]) + if t_out_m: + continue + + # Check whether (1) token is in mention, (2) mention is in token, or (3) token partially overlaps with mention + t_in_m = (m_span[0] <= t_span[0]) and (t_span[1] <= m_span[1]) + m_in_t = (t_span[0] <= m_span[0]) and ( + m_span[1] <= t_span[1] + ) # To deal with the case of ment:"physicians" ent:"Physician" + t_ol_m = not (t_in_m or t_out_m) # token overlaps with mention + + if t_in_m or m_in_t: # Case 1 or 2 + return True + elif t_ol_m: # Case 3 + if print_warning: + print( + f"WARNING: There is overlaps between mention and word spans. 
t_span:{t_span} ({utt[t_span[0]:t_span[1]]}), m_span:{m_span} ({utt[m_span[0]:m_span[1]]})" + ) + # NOTE: Treat this token as not belonging to the mention + + return False + + def _tokens_info( + self, + text: str, + speaker: str, + ments: list, + pems: list, + ): + """Append information for each token + The information includes: + - span: (start_pos, end_pos) of the token, acquired from the pre_tokenizers + - mention: what mention the token belongs to (or if not in any mention, None) + - pem: same as mention, but for pems + - speaker: the speaker of the utterance, either "USER" or "SYSTEM" + + Args: + text: the text of the utterance + speaker: the speaker of the utterance, either "USER" or "SYSTEM" + ments: list of mentions + pems: list of pems + + Returns: list of tokens with information + E.g., [{'token': 'Blue', + 'span': (0, 4), + 'mention': 'Blue', + 'pem': None, + 'speaker': 'USER'}, ...] + """ + ments = list(sorted(ments, key=len)) # to perform longest match + tokens_conv = [] + + try: # if tokenizer version is 0.10.3 etc where pre_tokenize_str is available + tokens_per_utt = pre_tokenizer.pre_tokenize(text) + except: # if 0.8.1.rc2 etc where pre_tokenizer_str is NOT available + tokens_per_utt = pre_tokenizer.pre_tokenize_str(text) + + ment2span = {ment: self.get_span(ment, text) for ment in ments} # mention spans + pem2span = {pem: self.get_span(pem, text) for pem in pems} # pem spans + # NOTE: get_span() cannot consider the case of the same surface mention occurring multiple times in the same utterance, only one of which is detected by the MD module. + # This is because the function detects ALL of the spans for the given surface mention. + # However, this case is extremely rare thus here the code ignores this case. + + for tkn_span in tokens_per_utt: + tkn = tkn_span[0] + span = tkn_span[1] + ment_out, pem_out = None, None # Initialize + + # First check if token is in any PEMs + for pem in pems: + if self._token_belongs_to_mention(pem2span[pem], span, text): + pem_out = pem + break + + # If token is not in any pem, then check if it is in any mention + if pem_out is None: + for ment in ments: + if self._token_belongs_to_mention(ment2span[ment], span, text): + ment_out = ment + break + + tokens_conv.append( + { + "token": tkn, + "span": span, + "mention": ment_out, + "pem": pem_out, + "speaker": speaker, + } + ) + return tokens_conv + + def get_tokens_with_info(self, conv): + """Get tokens with information of: + - span: (start_pos, end_pos) of the token, acquired from the pre_tokenizers + - mention: what mention the token belongs to (or if not in any mention, None) + - pem: same as mention, but for pems + - speaker: the speaker of the utterance, either "USER" or "SYSTEM" + + Args: + conv: a conversation + E.g., + {"speaker": "USER", + "utterance": "I agree. One of my favorite forms of science fiction is anything related to time travel! I find it fascinating.", + "mentions": ["science fiction", "time travel", ], + "pems": ["my favorite forms of science fiction", ],}, + """ + self._error_check(conv) + ret = [] + for turn_num, turn in enumerate(conv): + speaker = turn["speaker"] + if speaker == "USER": + ments = turn["mentions"] if "mentions" in turn else [] + pems = turn["pems"] if "pems" in turn else [] + elif speaker == "SYSTEM": + ments = [] + pems = [] + else: + raise ValueError( + f'Unknown speaker: {speaker}. Speaker must be either "USER" or "SYSTEM".' 
+ ) + tkn_info = self._tokens_info(turn["utterance"], speaker, ments, pems) + for elem in tkn_info: + elem["turn_number"] = turn_num + ret += tkn_info + return ret + + def _get_token_level_span(self, token_info: list, key_for_ment: str): + """Get token-level spans for mentions and pems + Token-level span is the input for s2e + + Args: + token_info (list): E.g., [{'token': 'Blue', 'span': (0, 4), 'corresponding_ment': 'Blue', 'speaker': 'USER', 'turn_number': 0}, ...] + key_ment_or_pem (str): 'mention' or 'pem' + + Returns: + E.g., [[2, 5], [8, 9], [0, 3]] + """ + # Error check + assert key_for_ment in [ + "mention", + "pem", + ] # key_for_ment should be mention or pem + + pem_and_eem = [] + start_pos, end_pos = None, None + for i in range(len(token_info)): + prev_ment = token_info[i - 1][key_for_ment] if i > 0 else None + curr_ment = token_info[i][key_for_ment] + futr_ment = ( + token_info[i + 1][key_for_ment] if (i + 1) < len(token_info) else None + ) + + if (prev_ment != curr_ment) and (curr_ment != None): # mention start + if start_pos == None: # Error check + start_pos = i + else: + raise ValueError("pos should be None to assign the value") + if (futr_ment != curr_ment) and (curr_ment != None): + # print(curr_ment,start_pos,end_pos) + if end_pos == None: # Error check + end_pos = i + else: + raise ValueError("pos should be None to assign the value") + + # print(prev_ment,curr_ment,futr_ment,'\tSTART_END_POS:',start_pos,end_pos) + + if (start_pos != None) and (end_pos != None): + # print('curr_ment:',curr_ment) + pem_and_eem.append([start_pos, end_pos]) + start_pos, end_pos = None, None # Initialize + + return pem_and_eem + + def get_input_of_pe_linking(self, token_info): + """Get the input of PE Linking module + + Args: + token_info (list): list of dict where tokens and their information is stored. Output of get_tokens_with_info() + E.g., + [{'token': 'I', # token + 'span': (0, 1), # (start_pos, end_pos) of the token + 'mention': None, # what mention the token belongs to (or if not in any mention, None) + 'pem': None, # same as mention, but for pems + 'speaker': 'USER', # the speaker of the utterance, either "USER" or "SYSTEM" + 'turn_number': 0}, # turn number of the utterance which the token belongs to (0-based) + ...] + + Returns: + Input of PE Linking module (the same input as s2e-coref) + """ + ret = [] + pem_spans = self._get_token_level_span(token_info, "pem") + # TODO: This is redundant and not efficient. + # Instead `pem_spans` can be acquired by pos_to_token mapping, which maps char_position --> token_position + # But I am too lazy to implement this code now. (Especially since it affects all other functions and data flows) + + for ( + pem_span + ) in ( + pem_spans + ): # [[142, 143], [256, 258], [309, 310]]. Note that this is token-level, not char-level + (start, end) = pem_span + turn_num = token_info[start]["turn_number"] + assert ( + turn_num == token_info[end]["turn_number"] + ), f"Start token and end token should have the same turn_number. start: {start}, end: {end}" + tokens_until_current_turn = [ + e for e in token_info if e["turn_number"] <= turn_num + ] # Extract tokens until the current turn + + ret.append( + { + "clusters": [], # Not used for inference + "doc_key": "tmp", # Not used for inference + "mentions": self._get_token_level_span( + tokens_until_current_turn, "mention" + ), # Detected mention spans, where format is (start_token_ind, end_token_ind) # E.g., [[28, 43], [67, 78]]. 
TODO: The same as above todo + "pems": [ + pem_span + ], # Detected personal entity mention span. The format is the same as mention spans # E.g., [[7, 43]]. NOTE: Currently our tool support only one mention at a time. + "sentences": [ + [e["token"] for e in tokens_until_current_turn] + ], # Tokenized sentences. E.g., ['I', 'think', 'science', 'fiction', 'is', ...] + "speakers": [ + [e["speaker"] for e in tokens_until_current_turn] + ], # Speaker information. E.g., ['SYSTEM', 'SYSTEM', ..., 'USER', 'USER', ...] + "text": None, + } + ) + return ret + + +class PostProcess: + """Handle output of PE Linking module""" + + def _get_ment2score( + self, doc_key: str, mentions: list, pems: list, scores: list, flag_print=False + ) -> dict: + """Get mention to score map + + Args: + doc_key (str): E.g., 'dialind:1_turn:0_pem:My-favourite-type-of-cake' + mentions: E.g., [[6, 7], [12, 12], [14, 15]] + pems: E.g., [[0, 4]] + scores: The scores for all mention (inc. PE) pair + E.g., [{'doc_key': 'tmp', 'span_token_anaphora': [8, 8], 'span_token_antecedent': [0, 0], 'coref_logits': -66.80387115478516}, ...] + + Returns: + {(6, 7): -42.52804183959961, + (12, 12): -83.429443359375, + (14, 15): -47.706520080566406} + + """ + assert all( + [isinstance(m, list) for m in mentions] + ) # Check both mentions and pems are 2d lists + assert all([isinstance(m, list) for m in pems]) # The same for pems + assert len(pems) == 1 # Check we only have one PEM + if doc_key not in [sj["doc_key"] for sj in scores]: # 220403 + if flag_print: + print( + f"{doc_key} not in scores. It might be that EL tool could not detect any candidate EEMs for this PEM. Return empty dict." + ) + return {} # ment2score = {} + + # Change all ments and pems to tuple to be able to compare + ment_tpl_list = [ + tuple(m) for m in mentions + ] # E.g., [(6, 7), (12, 12), (14, 15)] + pem_tpl = tuple(pems[0]) # E.g., (0, 4) + + ment2score = {} + span_hist = set() # This is used to error check + for sj in scores: + if sj["doc_key"] == doc_key: + # print(sj) + span_ano = tuple(sj["span_token_anaphora"]) # E.g., (6, 7) + span_ant = tuple(sj["span_token_antecedent"]) # E.g., (0, 4) + span_hist.add(span_ano) + span_hist.add(span_ant) + + if ( + span_ano == pem_tpl and span_ant in ment_tpl_list + ): # anaphora is the PEM case + ment2score[span_ant] = sj["coref_logits"] + elif ( + span_ant == pem_tpl and span_ano in ment_tpl_list + ): # antecedent is the PEM case + ment2score[span_ano] = sj["coref_logits"] + + # Check all ment_tpl_list and pem_tpl are in span_hist + assert sorted(ment_tpl_list + [pem_tpl]) == sorted( + list(span_hist) + ), f"mentions in score.json and pred.jsonl should be the same. {sorted(ment_tpl_list + [pem_tpl])} != {sorted(list(span_hist))}. doc_key: {doc_key}" + return ment2score + + def _convert_to_mention_from_token(self, mention, comb_text): + """ + Args: + mention (list): [start, end] (this is token-level (NOT subtoken-level)) + """ + start = mention[0] # output['subtoken_map'][mention[0]] + end = mention[1] + 1 # output['subtoken_map'][mention[1]] + 1 + mtext = "".join(" ".join(comb_text[start:end]).split(" ##")) + return mtext + + def get_results(self, pel_input, conv, threshold, scores): + """Get PE Linking post-processed results + + Args: + pel_input (dict): input for PE Linking module + E.g., {'clusters': [], # Not used for inference + 'doc_key': 'tmp', # Not used for inference + 'mentions': [[2, 3], [77, 78]], # Detected concept and NE mentions + 'pems': [[67, 72]], # Detected personal entity mention. Only one PEM is allowed now. 
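Putting the pieces together, the s2e-style entry assembled by `get_input_of_pe_linking` has the shape sketched below. The field names follow the docstrings above; the token values are invented for illustration and are not part of the patch.

```python
# Toy token records for a single USER turn (fields as in get_tokens_with_info).
tokens = [
    {"token": "My",     "speaker": "USER", "turn_number": 0, "mention": None,     "pem": "My city"},
    {"token": "city",   "speaker": "USER", "turn_number": 0, "mention": None,     "pem": "My city"},
    {"token": "is",     "speaker": "USER", "turn_number": 0, "mention": None,     "pem": None},
    {"token": "London", "speaker": "USER", "turn_number": 0, "mention": "London", "pem": None},
]

pel_entry = {
    "clusters": [],        # unused at inference time
    "doc_key": "tmp",      # unused at inference time
    "mentions": [[3, 3]],  # token-level span of the detected mention "London"
    "pems": [[0, 1]],      # token-level span of the personal entity mention "My city"
    "sentences": [[t["token"] for t in tokens]],
    "speakers": [[t["speaker"] for t in tokens]],
    "text": None,
}
print(pel_entry["sentences"][0])  # ['My', 'city', 'is', 'London']
```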
+ 'sentences': [['I', 'think', 'science', 'fiction', 'is', ...]], # tokenized sentences using tokenizers.pre_tokenizers + 'speakers': [['USER', 'USER', 'USER', 'USER', 'USER', ...]], # speaker information + 'text': None + } + threshold: default 0 + conv: The conversation input to preprocessing module (conversation before preprocessing) + scores: The scores for all mention (inc. PE) pair + E.g., + [{'doc_key': 'tmp', + 'span_token_anaphora': (67, 72), # This could be either a mention or a PEM + 'span_token_antecedent': (2, 3), # The same as above + 'coref_logits': -4.528693675994873}, # Output score from PE Linking module + Returns: + E.g., + [{'personal_entity_mention': 'my favorite forms of science fiction', + 'mention': 'time travel', + 'score': 4.445976734161377}] + + """ + assert type(pel_input) == dict, f"pel_input should be a dict. {type(pel_input)}" + ments_span_tokenlevel = [m for m in pel_input["mentions"]] + pems_span_tokenlevel = [p for p in pel_input["pems"]] # len(pems) == 1 + assert ( + len(pems_span_tokenlevel) == 1 + ), f"len(pems_span_tokenlevel) should be 1. {len(pems_span_tokenlevel)}" + + mspan2score = self._get_ment2score( + TMP_DOC_ID, ments_span_tokenlevel, pems_span_tokenlevel, scores + ) # Mention span to score + comb_text = pel_input["sentences"][ + 0 + ] # pel_input['sentences'] should have only one sentence + + pred_peas = [] + + pem = [m for turn in conv if turn["speaker"] == "USER" for m in turn["pems"]][ + 0 + ] # Each conv has only one pem for current implementation + for ment_span_tokenlevel in pel_input["mentions"]: + score = mspan2score[tuple(ment_span_tokenlevel)] + ment = self._convert_to_mention_from_token(ment_span_tokenlevel, comb_text) + if score > threshold: + pred_peas.append( + {"personal_entity_mention": pem, "mention": ment, "score": score} + ) + return pred_peas diff --git a/src/REL/crel/s2e_pe/utils.py b/src/REL/crel/s2e_pe/utils.py new file mode 100644 index 0000000..048dabb --- /dev/null +++ b/src/REL/crel/s2e_pe/utils.py @@ -0,0 +1,202 @@ +import json +import os +from datetime import datetime +from time import time + +import numpy as np + +# import git +import torch + +from .consts import NULL_ID_FOR_COREF + + +def flatten_list_of_lists(lst): + return [elem for sublst in lst for elem in sublst] + + +def extract_clusters(gold_clusters): + gold_clusters = [ + tuple(tuple(m) for m in gc if NULL_ID_FOR_COREF not in m) + for gc in gold_clusters.tolist() + ] + return gold_clusters + + +def extract_mentions_to_predicted_clusters_from_clusters(gold_clusters): + mention_to_gold = {} + for gc in gold_clusters: + for mention in gc: + mention_to_gold[tuple(mention)] = gc + return mention_to_gold + + +# def extract_clusters_for_decode(mention_to_antecedent): +def extract_clusters_for_decode(mention_to_antecedent, pems_subtoken): + """ + Args: + pems (list): E.g., [(2,3), (8,11), ...] 
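The word-piece re-joining used by `_convert_to_mention_from_token` in `pe_data.py` above is easy to check on its own; the tokens below are made up for the example.

```python
# Recover surface text from WordPiece-style tokens: join with spaces, then
# remove the " ##" separators that mark sub-word continuations.
comb_text = ["time", "trav", "##el", "is", "fascinating"]
start, end = 0, 3  # tokens 0..2, i.e. token-level span [0, 2] plus one
mtext = "".join(" ".join(comb_text[start:end]).split(" ##"))
print(mtext)  # -> "time travel"
```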
+ """ + + mention_to_antecedent = sorted(mention_to_antecedent) + mention_to_cluster = {} + clusters = [] + for mention, antecedent in mention_to_antecedent: + if (mention in pems_subtoken) or (antecedent in pems_subtoken): + if antecedent in mention_to_cluster: + cluster_idx = mention_to_cluster[antecedent] + clusters[cluster_idx].append(mention) + mention_to_cluster[mention] = cluster_idx + + else: + cluster_idx = len(clusters) + mention_to_cluster[mention] = cluster_idx + mention_to_cluster[antecedent] = cluster_idx + clusters.append([antecedent, mention]) + clusters = [tuple(cluster) for cluster in clusters] + return clusters, mention_to_cluster + + +def ce_extract_clusters_for_decode_with_one_mention_per_pem( + starts, end_offsets, coref_logits, pems_subtoken, flag_use_threshold +): + """ + + Args: + - flag_use_threshold: + True: Default. If PEM does not meet a threshold (default: 0), then all mentions are ignored. The threshold is stored in final element of each row of coref_logits. + False: Ignore threshold, pick the highest logit EEM for each PEM. + Updates: + - 220302: Created + """ + if flag_use_threshold: + max_antecedents = np.argmax( + coref_logits, axis=1 + ).tolist() # HJ: 220225: mention_to_antecedents takes max score. We have at most two predicted EEMs (one is coreference is PEM case, and the other is antecedent is PEM case). + else: + max_antecedents = np.argmax( + coref_logits[:, :-1], axis=1 + ).tolist() # HJ: 220225: mention_to_antecedents takes max score. We have at most two predicted EEMs (one is coreference is PEM case, and the other is antecedent is PEM case). + + # Create {(ment, antecedent): logits} dict + mention_antecedent_to_coreflogit_dict = { + ( + (int(start), int(end)), + (int(starts[max_antecedent]), int(end_offsets[max_antecedent])), + ): logit[max_antecedent] + for start, end, max_antecedent, logit in zip( + starts, end_offsets, max_antecedents, coref_logits + ) + if max_antecedent < len(starts) + } + # 220403: Drop if key has the same start and end pos for anaphora and antecedent + mention_antecedent_to_coreflogit_dict = { + k: v for k, v in mention_antecedent_to_coreflogit_dict.items() if k[0] != k[1] + } + if len(mention_antecedent_to_coreflogit_dict) == 0: + return [] + + # Select the ment-ant pair containing the PEM + + mention_antecedent_to_coreflogit_dict_with_pem = { + (m, a): logit + for (m, a), logit in mention_antecedent_to_coreflogit_dict.items() + if (m in pems_subtoken) or (a in pems_subtoken) + } + if len(mention_antecedent_to_coreflogit_dict_with_pem) == 0: + return [] + + # Select the max score + _max_logit = max(mention_antecedent_to_coreflogit_dict_with_pem.values()) + if flag_use_threshold and (_max_logit <= 0): + print(f"WARNING: _max_logit = {_max_logit}") + # _max_logit = _max_logit if _max_logit > 0 else 0 # HJ: 220302: If we set a threshold, then this does not work. + assert ( + coref_logits[-1][-1] == 0 + ), f"The threshold should be 0. If you set your threshold, then the code above should be fixed." + # Select the pair with max score + mention_to_antecedent_max_pem = { + ((m[0], m[1]), (a[0], a[1])) + for (m, a), logit in mention_antecedent_to_coreflogit_dict_with_pem.items() + if logit == _max_logit + } + assert ( + len(mention_to_antecedent_max_pem) <= 1 + ), f"Two or more mentions have the same max score: {mention_to_antecedent_max_pem}" + + predicted_clusters, _ = extract_clusters_for_decode( + mention_to_antecedent_max_pem, pems_subtoken + ) # TODO: 220302: Using `extract_clusters_for_decode` here is redundant. 
+ return predicted_clusters + + +def mask_tensor(t, mask): + t = t + ((1.0 - mask.float()) * -10000.0) + t = torch.clamp(t, min=-10000.0, max=10000.0) + return t + + +# def write_meta_data(output_dir, args): +# output_path = os.path.join(output_dir, "meta.json") +# repo = git.Repo(search_parent_directories=True) +# hexsha = repo.head.commit.hexsha +# ts = time() +# print(f"Writing {output_path}") +# with open(output_path, mode='w') as f: +# json.dump( +# { +# 'git_hexsha': hexsha, +# 'args': {k: str(v) for k, v in args.__dict__.items()}, +# 'date': datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M:%S') +# }, +# f, +# indent=4, +# sort_keys=True) + +# def ce_get_start_end_subtoken_num(start_token, end_token, subtoken_map): +# """ +# Example: +# ### Input ### +# start_token, end_token = (2,4) +# subtoken_map = [0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] # subtoken_map + +# ### Output ### +# (6, 8) + +# Notes: +# - This function is used in modeling.py and this util.py +# """ +# N = len(subtoken_map) +# start_subtoken = subtoken_map.index(start_token) +# end_subtoken = (N-1) - subtoken_map[::-1].index(end_token) + +# return start_subtoken, end_subtoken + +# def ce_get_pem_ments(predict_file): +# """ +# Args: +# args.predict_file: E.g., '../data/data_dir/test.english.jsonlines' +# Return: +# - pems (dict): E.g., {'dialind_0': [[2, 4]], 'dialind_1': [[66, 67]], ...} +# - ments (dict): E.g., {'dialind_0': [[0, 0], [3, 4], [35, 35], ...} +# """ +# pems = {} +# ments = {} +# with open(predict_file) as f: +# jsonl = f.readlines() +# print(len(jsonl)) +# for l in jsonl: +# l = json.loads(l) +# assert l['doc_key'] not in pems, 'doc_key should be unique' +# assert l['doc_key'] not in ments, 'doc_key should be unique' +# pems[l['doc_key']] = l['pems'] +# ments[l['doc_key']] = l['mentions'] +# # pems.append(l['pems']) # PEM +# # ments.append(l['mentions']) # mentions + +# # Error check: pem should not be in ments +# for doc_key in pems: # doc_key: E.g., 'dialind_0' +# for pem in pems[doc_key]: # pem: E.g., [2, 4] +# assert pem not in ments[doc_key], f'PEM {pem} should not be in ments. Fix this at 010_preprocess_eval_ConEL.ipynb' + +# return pems, ments diff --git a/tests/test_crel.py b/tests/test_crel.py new file mode 100644 index 0000000..c50802d --- /dev/null +++ b/tests/test_crel.py @@ -0,0 +1,94 @@ +import os +import pytest +from REL.crel.conv_el import ConvEL +import yaml +from pathlib import Path + +os.environ["CUDA_VISIBLE_DEVICES"] = "1" + + +@pytest.fixture +def cel(): + return ConvEL(base_url = os.environ.get('REL_BASE_URL', '.')) + + +def test_conv1(cel): + example = [ + { + "speaker": + "USER", + "utterance": + "I think science fiction is an amazing genre for anything. Future science, technology, time travel, FTL travel, they're all such interesting concepts.", + }, + { + "speaker": + "SYSTEM", + "utterance": + "Awesome! I really love how sci-fi storytellers focus on political/social/philosophical issues that would still be around even in the future. Makes them relatable.", + }, + { + "speaker": + "USER", + "utterance": + "I agree. One of my favorite forms of science fiction is anything related to time travel! 
I find it fascinating.", + }, + ] + + result = cel.annotate(example) + assert isinstance(result, list) + + expected_annotations = [ + [[8, 15, 'science fiction', 'Science_fiction'], + [38, 5, 'genre', 'Genre_fiction'], + [74, 10, 'technology', 'Technology'], + [86, 11, 'time travel', 'Time_travel'], + [99, 10, 'FTL travel', 'Faster-than-light']], + [[37, 15, 'science fiction', 'Science_fiction'], + [76, 11, 'time travel', 'Time_travel'], + [16, 36, 'my favorite forms of science fiction', 'Time_travel']], + ] + + annotations = [ + res['annotations'] for res in result if res['speaker'] == 'USER' + ] + + assert annotations == expected_annotations + + +def test_conv2(cel): + example = [ + { + "speaker": + "USER", + "utterance": + "I am allergic to tomatoes but we have a lot of famous Italian restaurants here in London.", + }, + { + "speaker": "SYSTEM", + "utterance": "Some people are allergic to histamine in tomatoes.", + }, + { + "speaker": + "USER", + "utterance": + "Talking of food, can you recommend me a restaurant in my city for our anniversary?", + }, + ] + + result = cel.annotate(example) + assert isinstance(result, list) + + annotations = [ + res['annotations'] for res in result if res['speaker'] == 'USER' + ] + + expected_annotations = [ + [[17, 8, 'tomatoes', 'Tomato'], + [54, 19, 'Italian restaurants', 'Italian_cuisine'], + [82, 6, 'London', 'London']], + [[11, 4, 'food', 'Food'], + [40, 10, 'restaurant', 'Restaurant'], + [54, 7, 'my city', 'London']], + ] + + assert annotations == expected_annotations From 5fc89c5291c58fc46970951e8c0b51a68a1d5de4 Mon Sep 17 00:00:00 2001 From: Stef Smeets Date: Tue, 13 Dec 2022 15:50:05 +0100 Subject: [PATCH 02/15] Implement API for conversational entity linking --- scripts/efficiency_test.py | 9 +++-- src/REL/server.py | 74 ++++++++++++++++++++++---------------- 2 files changed, 49 insertions(+), 34 deletions(-) diff --git a/scripts/efficiency_test.py b/scripts/efficiency_test.py index fc66f54..b903712 100644 --- a/scripts/efficiency_test.py +++ b/scripts/efficiency_test.py @@ -1,12 +1,15 @@ import numpy as np import requests +import os from REL.training_datasets import TrainingEvaluationDatasets np.random.seed(seed=42) -base_url = "/Users/vanhulsm/Desktop/projects/data/" -wiki_version = "wiki_2014" +base_url = os.environ.get("REL_BASE_URL") +wiki_version = "wiki_2019" +host = 'localhost' +port = '5555' datasets = TrainingEvaluationDatasets(base_url, wiki_version).load()["aida_testB"] # random_docs = np.random.choice(list(datasets.keys()), 50) @@ -40,7 +43,7 @@ print(myjson) print("Output API:") - print(requests.post("http://192.168.178.11:1235", json=myjson).json()) + print(requests.post(f"http://{host}:{port}", json=myjson).json()) print("----------------------------") diff --git a/src/REL/server.py b/src/REL/server.py index d26d6a9..c9b014d 100644 --- a/src/REL/server.py +++ b/src/REL/server.py @@ -6,6 +6,8 @@ from REL.mention_detection import MentionDetection from REL.utils import process_results +from REL.crel.conv_el import ConvEL + API_DOC = "API_DOC" @@ -25,6 +27,8 @@ def __init__(self, *args, **kwargs): self.custom_ner = not isinstance(tagger_ner, SequenceTagger) self.mention_detection = MentionDetection(base_url, wiki_version) + self.conv_linking = ConvEL(base_url=base_url, wiki_version=wiki_version) + super().__init__(*args, **kwargs) def do_GET(self): @@ -64,8 +68,8 @@ def do_POST(self): self.send_response(200) self.end_headers() - text, spans = self.read_json(post_data) - response = self.generate_response(text, spans) + data = 
self.read_json(post_data) + response = self.generate_response(**data) self.wfile.write(bytes(json.dumps(response), "utf-8")) except Exception as e: @@ -83,22 +87,23 @@ def read_json(self, post_data): """ data = json.loads(post_data.decode("utf-8")) - text = data["text"] - text = text.replace("&", "&") + + if isinstance(data["text"], str): + data["text"] = data["text"].replace("&", "&") # GERBIL sends dictionary, users send list of lists. if "spans" in data: try: - spans = [list(d.values()) for d in data["spans"]] + data["spans"] = [list(d.values()) for d in data["spans"]] except Exception: - spans = data["spans"] pass - else: - spans = [] - return text, spans + data.setdefault("spans", []) + data.setdefault("conversation", False) + + return data - def generate_response(self, text, spans): + def generate_response(self, *, text: list, spans: list, conversation: bool): """ Generates response for API. Can be either ED only or EL, meaning end-to-end. @@ -108,29 +113,36 @@ def generate_response(self, text, spans): if len(text) == 0: return [] - if len(spans) > 0: - # ED. - processed = {API_DOC: [text, spans]} - mentions_dataset, total_ment = self.mention_detection.format_spans( - processed - ) - else: - # EL - processed = {API_DOC: [text, spans]} - mentions_dataset, total_ment = self.mention_detection.find_mentions( - processed, self.tagger_ner - ) + if conversation: - # Disambiguation - predictions, timing = self.model.predict(mentions_dataset) + result = self.conv_linking.annotate(text) + return result - # Process result. - result = process_results( - mentions_dataset, - predictions, - processed, - include_offset=False if ((len(spans) > 0) or self.custom_ner) else True, - ) + else: + + if len(spans) > 0: + # ED. + processed = {API_DOC: [text, spans]} + mentions_dataset, total_ment = self.mention_detection.format_spans( + processed + ) + else: + # EL + processed = {API_DOC: [text, spans]} + mentions_dataset, total_ment = self.mention_detection.find_mentions( + processed, self.tagger_ner + ) + + # Disambiguation + predictions, timing = self.model.predict(mentions_dataset) + + # Process result. + result = process_results( + mentions_dataset, + predictions, + processed, + include_offset=False if ((len(spans) > 0) or self.custom_ner) else True, + ) # Singular document. 
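The span handling in `read_json` above accepts two client formats: GERBIL posts spans as a list of dicts, while users post a list of lists. A standalone sketch of that normalization, with invented payloads:

```python
def normalize_spans(data: dict) -> dict:
    """Mimic the server's handling of the optional 'spans' field."""
    if "spans" in data:
        try:
            # GERBIL-style: [{"start": 0, "length": 5}, ...] -> [[0, 5], ...]
            data["spans"] = [list(d.values()) for d in data["spans"]]
        except Exception:
            pass  # already a list of lists
    data.setdefault("spans", [])
    return data


print(normalize_spans({"text": "Hello world", "spans": [{"start": 0, "length": 5}]}))
print(normalize_spans({"text": "Hello world", "spans": [[0, 5]]}))
print(normalize_spans({"text": "Hello world"}))
```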
if len(result) > 0: From 3aa2b6e419338341a7e43f40692399970c47a8b6 Mon Sep 17 00:00:00 2001 From: Stef Smeets Date: Tue, 13 Dec 2022 15:59:22 +0100 Subject: [PATCH 03/15] Add script for testing the server response --- scripts/test_server.py | 57 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 scripts/test_server.py diff --git a/scripts/test_server.py b/scripts/test_server.py new file mode 100644 index 0000000..7035e47 --- /dev/null +++ b/scripts/test_server.py @@ -0,0 +1,57 @@ +import os +import requests + +# Script for testing the implementation of the conversational entity linking API +# +# To run: +# +# python .\src\REL\server.py $REL_BASE_URL wiki_2019 +# or +# python .\src\REL\server.py $env:REL_BASE_URL wiki_2019 +# +# Set $REL_BASE_URL to where your data are stored (`base_url`) +# +# These paths must exist: +# - `$REL_BASE_URL/bert_conv` +# - `$REL_BASE_URL/s2e_ast_onto ` +# +# (see https://github.com/informagi/conversational-entity-linking-2022/tree/main/tool#step-1-download-models) +# + + +host = 'localhost' +port = '5555' + +text1 = { + "text": "REL is a modular Entity Linking package that can both be integrated in existing pipelines or be used as an API.", + "spans": [] +} + +conv1 = {"conversation": "True", + "text" : [ + { + "speaker": + "USER", + "utterance": + "I am allergic to tomatoes but we have a lot of famous Italian restaurants here in London.", + }, + { + "speaker": "SYSTEM", + "utterance": "Some people are allergic to histamine in tomatoes.", + }, + { + "speaker": + "USER", + "utterance": + "Talking of food, can you recommend me a restaurant in my city for our anniversary?", + }, + ] +} + +for myjson in text1, conv1: + print('Input API:') + print(myjson) + print() + print('Output API:') + print(requests.post(f"http://{host}:{port}", json=myjson).json()) + print('----------------------------') \ No newline at end of file From 637a7b54625e38fd5cd8b07c777ddb815e92090b Mon Sep 17 00:00:00 2001 From: Stef Smeets Date: Mon, 19 Dec 2022 15:29:19 +0100 Subject: [PATCH 04/15] Re-implement server using fastapi/pydantic --- requirements.txt | 9 +- setup.cfg | 9 +- src/REL/db/base.py | 2 +- src/REL/server.py | 236 +++++++++++++++++---------------------------- 4 files changed, 101 insertions(+), 155 deletions(-) diff --git a/requirements.txt b/requirements.txt index c5a5969..c84bf33 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,10 @@ +anyascii colorama -konoha +fastapi flair>=0.11 +konoha +nltk +pydantic segtok torch -nltk -anyascii +uvicorn diff --git a/setup.cfg b/setup.cfg index 8fbd4af..19ab16c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,13 +43,16 @@ package_dir = = src include_package_data = True install_requires = + anyascii colorama - konoha + fastapi flair>=0.11 + konoha + nltk + pydantic segtok torch - nltk - anyascii + uvicorn [options.extras_require] develop = diff --git a/src/REL/db/base.py b/src/REL/db/base.py index 8eec44d..2526946 100644 --- a/src/REL/db/base.py +++ b/src/REL/db/base.py @@ -40,7 +40,7 @@ def initialize_db(self, fname, table_name, columns): db (sqlite3.Connection): a SQLite3 database with an embeddings table. """ # open database in autocommit mode by setting isolation_level to None. 
- db = sqlite3.connect(fname, isolation_level=None) + db = sqlite3.connect(fname, isolation_level=None, check_same_thread=False) q = "create table if not exists {}(word text primary key, {})".format( table_name, ", ".join(["{} {}".format(k, v) for k, v in columns.items()]) diff --git a/src/REL/server.py b/src/REL/server.py index d26d6a9..c488330 100644 --- a/src/REL/server.py +++ b/src/REL/server.py @@ -1,152 +1,99 @@ -import json -from http.server import BaseHTTPRequestHandler - +from REL.entity_disambiguation import EntityDisambiguation +from REL.ner import load_flair_ner from flair.models import SequenceTagger - from REL.mention_detection import MentionDetection from REL.utils import process_results -API_DOC = "API_DOC" - - - -def make_handler(base_url, wiki_version, model, tagger_ner): - """ - Class/function combination that is used to setup an API that can be used for e.g. GERBIL evaluation. - """ - class GetHandler(BaseHTTPRequestHandler): - def __init__(self, *args, **kwargs): - self.model = model - self.tagger_ner = tagger_ner - - self.base_url = base_url - self.wiki_version = wiki_version - - self.custom_ner = not isinstance(tagger_ner, SequenceTagger) - self.mention_detection = MentionDetection(base_url, wiki_version) - - super().__init__(*args, **kwargs) - - def do_GET(self): - self.send_response(200) - self.end_headers() - self.wfile.write( - bytes( - json.dumps( - { - "schemaVersion": 1, - "label": "status", - "message": "up", - "color": "green", - } - ), - "utf-8", - ) - ) - return - - def do_HEAD(self): - # send bad request response code - self.send_response(400) - self.end_headers() - self.wfile.write(bytes(json.dumps([]), "utf-8")) - return - - def do_POST(self): - """ - Returns response. - - :return: - """ - try: - content_length = int(self.headers["Content-Length"]) - post_data = self.rfile.read(content_length) - self.send_response(200) - self.end_headers() - - text, spans = self.read_json(post_data) - response = self.generate_response(text, spans) - - self.wfile.write(bytes(json.dumps(response), "utf-8")) - except Exception as e: - print(f"Encountered exception: {repr(e)}") - self.send_response(400) - self.end_headers() - self.wfile.write(bytes(json.dumps([]), "utf-8")) - return - - def read_json(self, post_data): - """ - Reads input JSON message. - - :return: document text and spans. - """ - - data = json.loads(post_data.decode("utf-8")) - text = data["text"] - text = text.replace("&", "&") - - # GERBIL sends dictionary, users send list of lists. - if "spans" in data: - try: - spans = [list(d.values()) for d in data["spans"]] - except Exception: - spans = data["spans"] - pass - else: - spans = [] - - return text, spans - - def generate_response(self, text, spans): - """ - Generates response for API. Can be either ED only or EL, meaning end-to-end. - - :return: list of tuples for each entity found. - """ - - if len(text) == 0: - return [] - - if len(spans) > 0: - # ED. - processed = {API_DOC: [text, spans]} - mentions_dataset, total_ment = self.mention_detection.format_spans( - processed - ) - else: - # EL - processed = {API_DOC: [text, spans]} - mentions_dataset, total_ment = self.mention_detection.find_mentions( - processed, self.tagger_ner - ) - - # Disambiguation - predictions, timing = self.model.predict(mentions_dataset) - - # Process result. - result = process_results( - mentions_dataset, - predictions, - processed, - include_offset=False if ((len(spans) > 0) or self.custom_ner) else True, - ) - - # Singular document. 
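As an aside, the `check_same_thread=False` change in `db/base.py` above matters once the server runs under uvicorn, because a request can be handled on a worker thread other than the one that opened the SQLite connection. A minimal demonstration of the default behavior, independent of REL:

```python
import sqlite3
import threading

conn = sqlite3.connect(":memory:")  # default: check_same_thread=True


def query(connection):
    try:
        print(connection.execute("SELECT 1").fetchone())
    except sqlite3.ProgrammingError as exc:
        print("other thread:", exc)


t = threading.Thread(target=query, args=(conn,))
t.start()
t.join()  # prints the "created in a thread" ProgrammingError

# With check_same_thread=False the same call succeeds from any thread;
# the caller is then responsible for serializing access.
conn2 = sqlite3.connect(":memory:", check_same_thread=False)
t2 = threading.Thread(target=query, args=(conn2,))
t2.start()
t2.join()  # prints (1,)
```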
- if len(result) > 0: - return [*result.values()][0] +class ModelHandler: + API_DOC = "API_DOC" + + def __init__(self, base_url, wiki_version, ed_model, ner_model): + self.model = model + self.tagger_ner = tagger_ner + + self.base_url = base_url + self.wiki_version = wiki_version + + self.custom_ner = not isinstance(tagger_ner, SequenceTagger) + self.mention_detection = MentionDetection(base_url, wiki_version) + + def generate_response(self, + *, + text: list, + spans: list, + conversation: bool = False): + """ + Generates response for API. Can be either ED only or EL, meaning end-to-end. + + :return: list of tuples for each entity found. + """ + + if len(text) == 0: return [] - return GetHandler + if len(spans) > 0: + # ED. + processed = {self.API_DOC: [text, spans]} + mentions_dataset, total_ment = self.mention_detection.format_spans( + processed) + else: + # EL + processed = {self.API_DOC: [text, spans]} + mentions_dataset, total_ment = self.mention_detection.find_mentions( + processed, self.tagger_ner) + + # Disambiguation + predictions, timing = self.model.predict(mentions_dataset) + + # Process result. + result = process_results( + mentions_dataset, + predictions, + processed, + include_offset=False if + ((len(spans) > 0) or self.custom_ner) else True, + ) + + # Singular document. + if len(result) > 0: + return [*result.values()][0] + + return [] + + +from fastapi import FastAPI +from pydantic import BaseModel +from typing import List + +app = FastAPI() + +@app.get("/") +def root(): + """Returns server status.""" + return { + "schemaVersion": 1, + "label": "status", + "message": "up", + "color": "green", + } + + +class EntityConfig(BaseModel): + text: str = Field(..., description="Text for entity linking or disambiguation.") + spans: List[str] = Field(..., description="Spans for entity disambiguation.") + + +@app.post("/") +def root(config: EntityConfig): + """Submit your text here for entity disambiguation or linking.""" + response = handler.generate_response(text=config.text, spans=config.spans) + return response if __name__ == "__main__": import argparse - from http.server import HTTPServer - - from REL.entity_disambiguation import EntityDisambiguation - from REL.ner import load_flair_ner + import uvicorn p = argparse.ArgumentParser() p.add_argument("base_url") @@ -161,14 +108,7 @@ def generate_response(self, text, spans): ed_model = EntityDisambiguation( args.base_url, args.wiki_version, {"mode": "eval", "model_path": args.ed_model} ) - server_address = (args.bind, args.port) - server = HTTPServer( - server_address, - make_handler(args.base_url, args.wiki_version, ed_model, ner_model), - ) - try: - print("Ready for listening.") - server.serve_forever() - except KeyboardInterrupt: - exit(0) + handler = ModelHandler(args.base_url, args.wiki_version, ed_model, ner_model) + + uvicorn.run(app, port=args.port, host=args.bind) From fcd57454593a70de8c84d89cbc19ce9d1f92b369 Mon Sep 17 00:00:00 2001 From: Stef Smeets Date: Mon, 19 Dec 2022 15:35:11 +0100 Subject: [PATCH 05/15] Fix names --- src/REL/server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/REL/server.py b/src/REL/server.py index c488330..4946851 100644 --- a/src/REL/server.py +++ b/src/REL/server.py @@ -8,7 +8,7 @@ class ModelHandler: API_DOC = "API_DOC" - def __init__(self, base_url, wiki_version, ed_model, ner_model): + def __init__(self, base_url, wiki_version, model, tagger_ner): self.model = model self.tagger_ner = tagger_ner @@ -63,7 +63,7 @@ def generate_response(self, from fastapi import 
FastAPI -from pydantic import BaseModel +from pydantic import BaseModel, Field from typing import List app = FastAPI() From cef1e0f5ce3094fcef21b73c102d152893073b30 Mon Sep 17 00:00:00 2001 From: Stef Smeets Date: Mon, 19 Dec 2022 16:25:57 +0100 Subject: [PATCH 06/15] Implement conversational entity linking in server --- scripts/test_server.py | 13 +++++--- src/REL/server.py | 75 +++++++++++++++++++++++++----------------- 2 files changed, 53 insertions(+), 35 deletions(-) diff --git a/scripts/test_server.py b/scripts/test_server.py index 7035e47..2953248 100644 --- a/scripts/test_server.py +++ b/scripts/test_server.py @@ -27,7 +27,7 @@ "spans": [] } -conv1 = {"conversation": "True", +conv1 = { "text" : [ { "speaker": @@ -48,10 +48,15 @@ ] } -for myjson in text1, conv1: + +for endpoint, myjson in ( + ('', text1), + ('conversation/', conv1) + ): print('Input API:') print(myjson) print() print('Output API:') - print(requests.post(f"http://{host}:{port}", json=myjson).json()) - print('----------------------------') \ No newline at end of file + print(requests.post(f"http://{host}:{port}/{endpoint}", json=myjson).json()) + print('----------------------------') + diff --git a/src/REL/server.py b/src/REL/server.py index 1864ac0..03908e2 100644 --- a/src/REL/server.py +++ b/src/REL/server.py @@ -3,6 +3,7 @@ from flair.models import SequenceTagger from REL.mention_detection import MentionDetection from REL.utils import process_results +from REL.crel.conv_el import ConvEL class ModelHandler: API_DOC = "API_DOC" @@ -21,7 +22,7 @@ def generate_response(self, *, text: list, spans: list, - conversation: bool = False): + ): """ Generates response for API. Can be either ED only or EL, meaning end-to-end. @@ -31,37 +32,30 @@ def generate_response(self, if len(text) == 0: return [] - if conversation: - - result = self.conv_linking.annotate(text) - return result - + if len(spans) > 0: + # ED. + processed = {self.API_DOC: [text, spans]} + mentions_dataset, total_ment = self.mention_detection.format_spans( + processed + ) else: - - if len(spans) > 0: - # ED. - processed = {self.API_DOC: [text, spans]} - mentions_dataset, total_ment = self.mention_detection.format_spans( - processed - ) - else: - # EL - processed = {self.API_DOC: [text, spans]} - mentions_dataset, total_ment = self.mention_detection.find_mentions( - processed, self.tagger_ner - ) - - # Disambiguation - predictions, timing = self.model.predict(mentions_dataset) - - # Process result. - result = process_results( - mentions_dataset, - predictions, - processed, - include_offset=False if ((len(spans) > 0) or self.custom_ner) else True, + # EL + processed = {self.API_DOC: [text, spans]} + mentions_dataset, total_ment = self.mention_detection.find_mentions( + processed, self.tagger_ner ) + # Disambiguation + predictions, timing = self.model.predict(mentions_dataset) + + # Process result. + result = process_results( + mentions_dataset, + predictions, + processed, + include_offset=False if ((len(spans) > 0) or self.custom_ner) else True, + ) + # Singular document. 
if len(result) > 0: return [*result.values()][0] @@ -71,7 +65,7 @@ def generate_response(self, from fastapi import FastAPI from pydantic import BaseModel, Field -from typing import List, Optional +from typing import List, Optional, Literal app = FastAPI() @@ -94,7 +88,24 @@ class EntityConfig(BaseModel): @app.post("/") def root(config: EntityConfig): """Submit your text here for entity disambiguation or linking.""" - response = handler.generate_response(text=config.text, spans=config.spans, conversation=config.conversation) + response = handler.generate_response(text=config.text, spans=config.spans) + return response + + +class ConversationTurn(BaseModel): + speaker: Literal["USER", "SYSTEM"] = Field(..., "Speaker for this turn.") + utterance: str = Field(..., description="Input utterance.") + + +class ConversationConfig(BaseModel): + text: List[ConversationTurn] = Field(..., "Conversation as list of turns between two speakers.") + + +@app.post("/conversation/") +def conversation(config: ConversationConfig): + """Submit your text here for conversational entity linking.""" + text = config.dict()['text'] + response = conv_handler.annotate(text) return response @@ -118,4 +129,6 @@ def root(config: EntityConfig): handler = ModelHandler(args.base_url, args.wiki_version, ed_model, ner_model) + conv_handler = ConvEL(args.base_url, args.wiki_version) + uvicorn.run(app, port=args.port, host=args.bind) From 2e4fddc633fe088c34d1cd887c2b5e76e95a365f Mon Sep 17 00:00:00 2001 From: Stef Smeets Date: Tue, 20 Dec 2022 12:06:29 +0100 Subject: [PATCH 07/15] Refactor response handler code to submodule --- src/REL/response_model.py | 65 +++++++++++++++++++++++++++++++++ src/REL/server.py | 77 +++++---------------------------------- 2 files changed, 74 insertions(+), 68 deletions(-) create mode 100644 src/REL/response_model.py diff --git a/src/REL/response_model.py b/src/REL/response_model.py new file mode 100644 index 0000000..b41da49 --- /dev/null +++ b/src/REL/response_model.py @@ -0,0 +1,65 @@ +from REL.entity_disambiguation import EntityDisambiguation +from REL.ner import load_flair_ner +from flair.models import SequenceTagger +from REL.mention_detection import MentionDetection +from REL.utils import process_results + + +class ResponseModel: + API_DOC = "API_DOC" + + def __init__(self, base_url, wiki_version, model, tagger_ner=None): + self.model = model + self.tagger_ner = tagger_ner + + self.base_url = base_url + self.wiki_version = wiki_version + + self.custom_ner = not isinstance(tagger_ner, SequenceTagger) + self.mention_detection = MentionDetection(base_url, wiki_version) + + def generate_response(self, + *, + text: list, + spans: list, + ): + """ + Generates response for API. Can be either ED only or EL, meaning end-to-end. + + :return: list of tuples for each entity found. + """ + + if len(text) == 0: + return [] + + processed = {self.API_DOC: [text, spans]} + + if len(spans) > 0: + # ED. + mentions_dataset, total_ment = self.mention_detection.format_spans( + processed + ) + else: + # EL + mentions_dataset, total_ment = self.mention_detection.find_mentions( + processed, self.tagger_ner + ) + + # Disambiguation + predictions, timing = self.model.predict(mentions_dataset) + + include_offset = (len(spans) == 0) and not self.custom_ner + + # Process result. + result = process_results( + mentions_dataset, + predictions, + processed, + include_offset=include_offset, + ) + + # Singular document. 
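The conversation request models added to `server.py` above lean on pydantic for validation, so a turn with an unknown speaker is rejected before it ever reaches the linker. A standalone sketch of that behavior, with the model definitions copied from the patch and an invented payload:

```python
from typing import List, Literal

from pydantic import BaseModel, Field, ValidationError


class ConversationTurn(BaseModel):
    speaker: Literal["USER", "SYSTEM"] = Field(..., description="Speaker for this turn.")
    utterance: str = Field(..., description="Input utterance.")


class ConversationConfig(BaseModel):
    text: List[ConversationTurn] = Field(..., description="Conversation as list of turns.")


ok = ConversationConfig(text=[{"speaker": "USER", "utterance": "Hi there!"}])
print(ok.text[0].speaker)  # USER

try:
    ConversationConfig(text=[{"speaker": "BOT", "utterance": "Hi there!"}])
except ValidationError as err:
    print("rejected:", err.errors()[0]["loc"])  # ('text', 0, 'speaker')
```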
+ if len(result) > 0: + return [*result.values()][0] + + return [] \ No newline at end of file diff --git a/src/REL/server.py b/src/REL/server.py index 03908e2..b321ad2 100644 --- a/src/REL/server.py +++ b/src/REL/server.py @@ -1,67 +1,4 @@ -from REL.entity_disambiguation import EntityDisambiguation -from REL.ner import load_flair_ner -from flair.models import SequenceTagger -from REL.mention_detection import MentionDetection -from REL.utils import process_results -from REL.crel.conv_el import ConvEL - -class ModelHandler: - API_DOC = "API_DOC" - - def __init__(self, base_url, wiki_version, model, tagger_ner): - self.model = model - self.tagger_ner = tagger_ner - - self.base_url = base_url - self.wiki_version = wiki_version - - self.custom_ner = not isinstance(tagger_ner, SequenceTagger) - self.mention_detection = MentionDetection(base_url, wiki_version) - - def generate_response(self, - *, - text: list, - spans: list, - ): - """ - Generates response for API. Can be either ED only or EL, meaning end-to-end. - - :return: list of tuples for each entity found. - """ - - if len(text) == 0: - return [] - - if len(spans) > 0: - # ED. - processed = {self.API_DOC: [text, spans]} - mentions_dataset, total_ment = self.mention_detection.format_spans( - processed - ) - else: - # EL - processed = {self.API_DOC: [text, spans]} - mentions_dataset, total_ment = self.mention_detection.find_mentions( - processed, self.tagger_ner - ) - - # Disambiguation - predictions, timing = self.model.predict(mentions_dataset) - - # Process result. - result = process_results( - mentions_dataset, - predictions, - processed, - include_offset=False if ((len(spans) > 0) or self.custom_ner) else True, - ) - - # Singular document. - if len(result) > 0: - return [*result.values()][0] - - return [] - +from REL.response_model import ResponseModel from fastapi import FastAPI from pydantic import BaseModel, Field @@ -93,12 +30,12 @@ def root(config: EntityConfig): class ConversationTurn(BaseModel): - speaker: Literal["USER", "SYSTEM"] = Field(..., "Speaker for this turn.") + speaker: Literal["USER", "SYSTEM"] = Field(..., description="Speaker for this turn.") utterance: str = Field(..., description="Input utterance.") class ConversationConfig(BaseModel): - text: List[ConversationTurn] = Field(..., "Conversation as list of turns between two speakers.") + text: List[ConversationTurn] = Field(..., description="Conversation as list of turns between two speakers.") @app.post("/conversation/") @@ -122,13 +59,17 @@ def conversation(config: ConversationConfig): p.add_argument("--port", "-p", default=5555, type=int) args = p.parse_args() + from REL.crel.conv_el import ConvEL + from REL.entity_disambiguation import EntityDisambiguation + from REL.ner import load_flair_ner + ner_model = load_flair_ner(args.ner_model) ed_model = EntityDisambiguation( args.base_url, args.wiki_version, {"mode": "eval", "model_path": args.ed_model} ) - handler = ModelHandler(args.base_url, args.wiki_version, ed_model, ner_model) + handler = ResponseModel(args.base_url, args.wiki_version, ed_model, ner_model) - conv_handler = ConvEL(args.base_url, args.wiki_version) + conv_handler = ConvEL(args.base_url, args.wiki_version, ed_model=ed_model) uvicorn.run(app, port=args.port, host=args.bind) From 51cdb97065335791e96498744e6da18eff82030c Mon Sep 17 00:00:00 2001 From: Stef Smeets Date: Tue, 20 Dec 2022 12:07:13 +0100 Subject: [PATCH 08/15] Remove `rel_ed` and make ConvEL depend on REL MD --- src/REL/crel/conv_el.py | 27 ++++++++++++++++--- src/REL/crel/rel_ed.py | 60 
----------------------------------------- 2 files changed, 23 insertions(+), 64 deletions(-) delete mode 100644 src/REL/crel/rel_ed.py diff --git a/src/REL/crel/conv_el.py b/src/REL/crel/conv_el.py index 1ba276a..53c3046 100644 --- a/src/REL/crel/conv_el.py +++ b/src/REL/crel/conv_el.py @@ -3,14 +3,14 @@ from pathlib import Path from .bert_md import BERT_MD -from .rel_ed import REL_ED from .s2e_pe import pe_data from .s2e_pe.pe import EEMD, PEMD +from REL.response_model import ResponseModel class ConvEL: def __init__( - self, base_url=".", wiki_version="wiki_2019", user_config=None, threshold=0 + self, base_url=".", wiki_version="wiki_2019", ed_model=None, user_config=None, threshold=0 ): self.threshold = threshold @@ -19,7 +19,12 @@ def __init__( self.file_pretrained = str(Path(base_url) / "bert_conv-td") self.bert_md = BERT_MD(self.file_pretrained) - self.rel_ed = REL_ED(self.base_url, self.wiki_version) + + if not ed_model: + ed_model = self._default_ed_model() + + self.response_model = ResponseModel(self.base_url, self.wiki_version, model=ed_model) + self.eemd = EEMD(s2e_pe_model=str(Path(base_url) / "s2e_ast_onto")) self.pemd = PEMD() @@ -32,6 +37,13 @@ def __init__( ) # initialize the history of conversation, which is used in PE Linking self.ment2ent = {} # This will be used for PE Linking + def _default_ed_model(self): + from REL.entity_disambiguation import EntityDisambiguation + return EntityDisambiguation(self.base_url, self.wiki_version, config={ + "mode": "eval", + "model_path": f"{self.base_url}/{self.wiki_version}/generated/model", + }) + def _error_check(self, conv): assert type(conv) == list for turn in conv: @@ -49,7 +61,7 @@ def _el(self, utt): # ED spans = [[r[0], r[1]] for r in md_results] # r[0]: start, r[1]: length - el_results = self.rel_ed.ed(utt, spans) # ED + el_results = self.ed(utt, spans) # ED self.conv_hist_for_pe[-1]["mentions"] = [r[2] for r in el_results] self.ment2ent.update( @@ -140,3 +152,10 @@ def annotate(self, conv): ret[-1]["annotations"] = el_results + pe_results return ret + + def ed(self, text, spans): + """Change tuple to list to match the output format of REL API.""" + response = self.response_model.generate_response(text=text, spans=spans) + return [list(ent) for ent in response] + + diff --git a/src/REL/crel/rel_ed.py b/src/REL/crel/rel_ed.py deleted file mode 100644 index a52148c..0000000 --- a/src/REL/crel/rel_ed.py +++ /dev/null @@ -1,60 +0,0 @@ -import sys - -from REL.entity_disambiguation import EntityDisambiguation -from REL.mention_detection import MentionDetection -from REL.utils import process_results - - -class REL_ED: - def __init__(self, base_url, wiki_version): - - config = { - "mode": "eval", - "model_path": f"{base_url}/{wiki_version}/generated/model", - } - - self.mention_detection = MentionDetection( - base_url, wiki_version - ) # This is only used for format spans - self.model = EntityDisambiguation(base_url, wiki_version, config) - - def generate_response(self, text, spans): - """Generate ED results - - Returns: - - list of tuples for each entity found. - - Note: - - Original code: https://github.com/informagi/REL/blob/9ca253b1d371966c39219ed672f39784fd833d8d/REL/server.py#L101 - """ - - API_DOC = "API_DOC" - - if len(text) == 0 or len(spans) == 0: - return [] - - # Get the mentions from the spans - processed = {API_DOC: [text, spans]} - mentions_dataset, total_ment = self.mention_detection.format_spans(processed) - - # Disambiguation - predictions, timing = self.model.predict(mentions_dataset) - - # Process result. 
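With `rel_ed` gone, `ConvEL` either constructs its own `EntityDisambiguation` instance or reuses one that is already loaded, as the server now does; reusing avoids loading the ED model twice. A usage sketch, assuming the model files and `base_url` layout described in the tutorial added below are in place:

```python
from REL.crel.conv_el import ConvEL
from REL.entity_disambiguation import EntityDisambiguation

base_url = "/path/to/base_url"   # adjust to your local data directory
wiki_version = "wiki_2019"

# Option 1: let ConvEL build the default ED model itself.
cel = ConvEL(base_url=base_url, wiki_version=wiki_version)

# Option 2: reuse an ED model loaded elsewhere (this is what server.py does).
ed_model = EntityDisambiguation(
    base_url,
    wiki_version,
    {"mode": "eval", "model_path": f"{base_url}/{wiki_version}/generated/model"},
)
cel = ConvEL(base_url=base_url, wiki_version=wiki_version, ed_model=ed_model)
```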
-        result = process_results(
-            mentions_dataset,
-            predictions,
-            processed,
-            include_offset=False if (len(spans) > 0) else True,
-        )
-
-        # Singular document.
-        if len(result) > 0:
-            return [*result.values()][0]
-
-        return []
-
-    def ed(self, text, spans):
-        """Change tuple to list to match the output format of REL API."""
-        response = self.generate_response(text, spans)
-        return [list(ent) for ent in response]

From bd646242a2609ec81b4bff43450444f4cfc416db Mon Sep 17 00:00:00 2001
From: Stef Smeets
Date: Tue, 20 Dec 2022 15:13:17 +0100
Subject: [PATCH 09/15] Fix argument name

---
 src/REL/crel/conv_el.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/REL/crel/conv_el.py b/src/REL/crel/conv_el.py
index 53c3046..868f962 100644
--- a/src/REL/crel/conv_el.py
+++ b/src/REL/crel/conv_el.py
@@ -39,7 +39,7 @@ def __init__(
 
     def _default_ed_model(self):
         from REL.entity_disambiguation import EntityDisambiguation
-        return EntityDisambiguation(self.base_url, self.wiki_version, config={
+        return EntityDisambiguation(self.base_url, self.wiki_version, user_config={
             "mode": "eval",
             "model_path": f"{self.base_url}/{self.wiki_version}/generated/model",
         })

From 1db9086b9bae2dfc595466329d76ebffd3ff8f8a Mon Sep 17 00:00:00 2001
From: Stef Smeets
Date: Tue, 20 Dec 2022 15:56:44 +0100
Subject: [PATCH 10/15] Add some docs for conversational entity linking

---
 docs/tutorials/conversations.md | 71 +++++++++++++++++++++++++++++++++
 docs/tutorials/index.md         |  1 +
 mkdocs.yml                      |  4 +-
 3 files changed, 74 insertions(+), 2 deletions(-)
 create mode 100644 docs/tutorials/conversations.md

diff --git a/docs/tutorials/conversations.md b/docs/tutorials/conversations.md
new file mode 100644
index 0000000..afaa276
--- /dev/null
+++ b/docs/tutorials/conversations.md
@@ -0,0 +1,71 @@
+# Conversational entity linking
+
+The `crel` submodule provides the conversational entity linking tool trained on the [ConEL-2 dataset](https://github.com/informagi/conversational-entity-linking-2022#conel-2-conversational-entity-linking-dataset).
+
+Unlike existing EL methods, `crel` is developed to identify both named entities and concepts.
+It also uses coreference resolution techniques to identify personal entities and references to the explicit entity mentions in the conversations.
+
+This tutorial describes how to start with conversational entity linking on a local machine.
+
+For more information, see the original [repository on conversational entity linking](https://github.com/informagi/conversational-entity-linking-2022).
+
+## Start with your local environment
+
+### Step 1: Download models
+
+First, download the models below:
+
+- **MD for concepts and NEs**:
+
+    [Click here to download models](https://drive.google.com/file/d/1OoC2XZp4uBy0eB_EIuIhEHdcLEry2LtU/view?usp=sharing)
+
+    Extract `bert_conv-td` to your `base_url`
+- **Personal Entity Linking**:
+
+    [Click here to download models](https://drive.google.com/file/d/1-jW8xkxh5GV-OuUBfMeT2Tk7tEzvH181/view?usp=sharing)
+
+    Extract `s2e_ast_onto` to your `base_url`
+
+Additionally, conversational entity linking uses the wiki 2019 dataset. For more information on where to place the data and the `base_url`, check out [this page](../how_to_get_started). If set up correctly, your `base_url` should contain these directories:
+
+
+```bash
+.
+└── bert_conv-td
+└── s2e_ast_onto
+└── wiki_2019
+```
+
+
+### Step 2: Example code
+
+This example shows how to link a short conversation. Note that the speakers must be named "USER" and "SYSTEM".
+ + +```python +>>> from REL.crel.conv_el import ConvEL +>>> +>>> cel = ConvEL(base_url="C:/path/to/base_url/") +>>> +>>> conversation = [ +>>> {"speaker": "USER", +>>> "utterance": "I am allergic to tomatoes but we have a lot of famous Italian restaurants here in London.",}, +>>> +>>> {"speaker": "SYSTEM", +>>> "utterance": "Some people are allergic to histamine in tomatoes.",}, +>>> +>>> {"speaker": "USER", +>>> "utterance": "Talking of food, can you recommend me a restaurant in my city for our anniversary?",}, +>>> ] +>>> +>>> annotated = cel.annotate(conversation) +>>> [item for item in annotated if item['speaker'] == 'USER'] +[{'speaker': 'USER', + 'utterance': 'I am allergic to tomatoes but we have a lot of famous Italian restaurants here in London.', + 'annotations': [[17, 8, 'tomatoes', 'Tomato'], + [54, 19, 'Italian restaurants', 'Italian_cuisine'], + [82, 6, 'London', 'London']]}, + {'speaker': 'USER', + 'utterance': 'Talking of food, can you recommend me a restaurant in my city for our anniversary?', + 'annotations': [[11, 4, 'food', 'Food'], + [40, 10, 'restaurant', 'Restaurant'], + [54, 7, 'my city', 'London']]}] + +``` + diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md index c96d1bd..371db1a 100644 --- a/docs/tutorials/index.md +++ b/docs/tutorials/index.md @@ -14,3 +14,4 @@ The remainder of the tutorials are optional and for users who wish to e.g. train 5. [Reproducing our results](reproducing_our_results/) 6. [REL as systemd service](systemd_instructions/) 7. [Notes on using custom models](custom_models/) +7. [Conversational entity linking](conversations/) diff --git a/mkdocs.yml b/mkdocs.yml index 1ef6fd5..2a11e66 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -10,6 +10,7 @@ nav: - tutorials/reproducing_our_results.md - tutorials/systemd_instructions.md - tutorials/custom_models.md + - tutorials/conversations.md - Python API reference: - api/entity_disambiguation.md - api/generate_train_test.md @@ -72,11 +73,10 @@ plugins: - https://numpy.org/doc/stable/objects.inv - https://docs.scipy.org/doc/scipy/objects.inv - https://pandas.pydata.org/docs/objects.inv - selection: + options: docstring_style: sphinx docstring_options: ignore_init_summary: yes - rendering: show_submodules: no show_source: true docstring_section_style: list From 64959f8bd42890e7bbc10c2e069fbe78f265cefd Mon Sep 17 00:00:00 2001 From: Stef Smeets Date: Tue, 20 Dec 2022 16:27:59 +0100 Subject: [PATCH 11/15] Skip CREL test if running on CI We have no way of testing, because the data are not available --- tests/test_crel.py | 84 +++++++++++++++++++++++----------------------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/tests/test_crel.py b/tests/test_crel.py index c50802d..76d1a30 100644 --- a/tests/test_crel.py +++ b/tests/test_crel.py @@ -7,30 +7,27 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "1" -@pytest.fixture +@pytest.fixture() def cel(): - return ConvEL(base_url = os.environ.get('REL_BASE_URL', '.')) + return ConvEL(base_url=os.environ.get("REL_BASE_URL", ".")) +@pytest.mark.skipif( + os.environ.get("GITHUB_ACTIONS"), reason="No way of testing this on Github actions." +) def test_conv1(cel): example = [ { - "speaker": - "USER", - "utterance": - "I think science fiction is an amazing genre for anything. Future science, technology, time travel, FTL travel, they're all such interesting concepts.", + "speaker": "USER", + "utterance": "I think science fiction is an amazing genre for anything. 
Future science, technology, time travel, FTL travel, they're all such interesting concepts.", }, { - "speaker": - "SYSTEM", - "utterance": - "Awesome! I really love how sci-fi storytellers focus on political/social/philosophical issues that would still be around even in the future. Makes them relatable.", + "speaker": "SYSTEM", + "utterance": "Awesome! I really love how sci-fi storytellers focus on political/social/philosophical issues that would still be around even in the future. Makes them relatable.", }, { - "speaker": - "USER", - "utterance": - "I agree. One of my favorite forms of science fiction is anything related to time travel! I find it fascinating.", + "speaker": "USER", + "utterance": "I agree. One of my favorite forms of science fiction is anything related to time travel! I find it fascinating.", }, ] @@ -38,57 +35,60 @@ def test_conv1(cel): assert isinstance(result, list) expected_annotations = [ - [[8, 15, 'science fiction', 'Science_fiction'], - [38, 5, 'genre', 'Genre_fiction'], - [74, 10, 'technology', 'Technology'], - [86, 11, 'time travel', 'Time_travel'], - [99, 10, 'FTL travel', 'Faster-than-light']], - [[37, 15, 'science fiction', 'Science_fiction'], - [76, 11, 'time travel', 'Time_travel'], - [16, 36, 'my favorite forms of science fiction', 'Time_travel']], + [ + [8, 15, "science fiction", "Science_fiction"], + [38, 5, "genre", "Genre_fiction"], + [74, 10, "technology", "Technology"], + [86, 11, "time travel", "Time_travel"], + [99, 10, "FTL travel", "Faster-than-light"], + ], + [ + [37, 15, "science fiction", "Science_fiction"], + [76, 11, "time travel", "Time_travel"], + [16, 36, "my favorite forms of science fiction", "Time_travel"], + ], ] - annotations = [ - res['annotations'] for res in result if res['speaker'] == 'USER' - ] + annotations = [res["annotations"] for res in result if res["speaker"] == "USER"] assert annotations == expected_annotations +@pytest.mark.skipif( + os.environ.get("GITHUB_ACTIONS"), reason="No way of testing this on Github actions." 
+) def test_conv2(cel): example = [ { - "speaker": - "USER", - "utterance": - "I am allergic to tomatoes but we have a lot of famous Italian restaurants here in London.", + "speaker": "USER", + "utterance": "I am allergic to tomatoes but we have a lot of famous Italian restaurants here in London.", }, { "speaker": "SYSTEM", "utterance": "Some people are allergic to histamine in tomatoes.", }, { - "speaker": - "USER", - "utterance": - "Talking of food, can you recommend me a restaurant in my city for our anniversary?", + "speaker": "USER", + "utterance": "Talking of food, can you recommend me a restaurant in my city for our anniversary?", }, ] result = cel.annotate(example) assert isinstance(result, list) - annotations = [ - res['annotations'] for res in result if res['speaker'] == 'USER' - ] + annotations = [res["annotations"] for res in result if res["speaker"] == "USER"] expected_annotations = [ - [[17, 8, 'tomatoes', 'Tomato'], - [54, 19, 'Italian restaurants', 'Italian_cuisine'], - [82, 6, 'London', 'London']], - [[11, 4, 'food', 'Food'], - [40, 10, 'restaurant', 'Restaurant'], - [54, 7, 'my city', 'London']], + [ + [17, 8, "tomatoes", "Tomato"], + [54, 19, "Italian restaurants", "Italian_cuisine"], + [82, 6, "London", "London"], + ], + [ + [11, 4, "food", "Food"], + [40, 10, "restaurant", "Restaurant"], + [54, 7, "my city", "London"], + ], ] assert annotations == expected_annotations From 8097e620322a246eacd79e59b29bab937b6ed0b9 Mon Sep 17 00:00:00 2001 From: Stef Smeets Date: Tue, 20 Dec 2022 16:52:23 +0100 Subject: [PATCH 12/15] Update skipif condition --- tests/test_crel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_crel.py b/tests/test_crel.py index 76d1a30..d12e759 100644 --- a/tests/test_crel.py +++ b/tests/test_crel.py @@ -13,7 +13,7 @@ def cel(): @pytest.mark.skipif( - os.environ.get("GITHUB_ACTIONS"), reason="No way of testing this on Github actions." + os.environ.get("GITHUB_ACTIONS")=='true',, reason="No way of testing this on Github actions." ) def test_conv1(cel): example = [ @@ -55,7 +55,7 @@ def test_conv1(cel): @pytest.mark.skipif( - os.environ.get("GITHUB_ACTIONS"), reason="No way of testing this on Github actions." + os.environ.get("GITHUB_ACTIONS")=='true', reason="No way of testing this on Github actions." ) def test_conv2(cel): example = [ From 52ae1468e35d14084edddf4a96c29fab98007b0f Mon Sep 17 00:00:00 2001 From: Stef Smeets Date: Tue, 20 Dec 2022 16:54:28 +0100 Subject: [PATCH 13/15] Change to getenv function --- tests/test_crel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_crel.py b/tests/test_crel.py index d12e759..8c9669b 100644 --- a/tests/test_crel.py +++ b/tests/test_crel.py @@ -13,7 +13,7 @@ def cel(): @pytest.mark.skipif( - os.environ.get("GITHUB_ACTIONS")=='true',, reason="No way of testing this on Github actions." + os.getenv("GITHUB_ACTIONS")=='true',, reason="No way of testing this on Github actions." ) def test_conv1(cel): example = [ @@ -55,7 +55,7 @@ def test_conv1(cel): @pytest.mark.skipif( - os.environ.get("GITHUB_ACTIONS")=='true', reason="No way of testing this on Github actions." + os.getenv("GITHUB_ACTIONS")=='true', reason="No way of testing this on Github actions." 
) def test_conv2(cel): example = [ From aa691f5c44c668939469dd09a1cb413a4831fdc8 Mon Sep 17 00:00:00 2001 From: Stef Smeets Date: Tue, 20 Dec 2022 17:00:36 +0100 Subject: [PATCH 14/15] Fix syntax error --- tests/test_crel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_crel.py b/tests/test_crel.py index 8c9669b..36fc551 100644 --- a/tests/test_crel.py +++ b/tests/test_crel.py @@ -13,7 +13,7 @@ def cel(): @pytest.mark.skipif( - os.getenv("GITHUB_ACTIONS")=='true',, reason="No way of testing this on Github actions." + os.getenv("GITHUB_ACTIONS")=='true', reason="No way of testing this on Github actions." ) def test_conv1(cel): example = [ From f88f1b039d36f1d7e8ab33f089acaf7ac21604eb Mon Sep 17 00:00:00 2001 From: Stef Smeets Date: Mon, 23 Jan 2023 16:12:19 +0100 Subject: [PATCH 15/15] Remove unused code --- src/REL/crel/s2e_pe/utils.py | 68 ------------------------------------ 1 file changed, 68 deletions(-) diff --git a/src/REL/crel/s2e_pe/utils.py b/src/REL/crel/s2e_pe/utils.py index 048dabb..7207b59 100644 --- a/src/REL/crel/s2e_pe/utils.py +++ b/src/REL/crel/s2e_pe/utils.py @@ -5,7 +5,6 @@ import numpy as np -# import git import torch from .consts import NULL_ID_FOR_COREF @@ -31,7 +30,6 @@ def extract_mentions_to_predicted_clusters_from_clusters(gold_clusters): return mention_to_gold -# def extract_clusters_for_decode(mention_to_antecedent): def extract_clusters_for_decode(mention_to_antecedent, pems_subtoken): """ Args: @@ -134,69 +132,3 @@ def mask_tensor(t, mask): t = t + ((1.0 - mask.float()) * -10000.0) t = torch.clamp(t, min=-10000.0, max=10000.0) return t - - -# def write_meta_data(output_dir, args): -# output_path = os.path.join(output_dir, "meta.json") -# repo = git.Repo(search_parent_directories=True) -# hexsha = repo.head.commit.hexsha -# ts = time() -# print(f"Writing {output_path}") -# with open(output_path, mode='w') as f: -# json.dump( -# { -# 'git_hexsha': hexsha, -# 'args': {k: str(v) for k, v in args.__dict__.items()}, -# 'date': datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M:%S') -# }, -# f, -# indent=4, -# sort_keys=True) - -# def ce_get_start_end_subtoken_num(start_token, end_token, subtoken_map): -# """ -# Example: -# ### Input ### -# start_token, end_token = (2,4) -# subtoken_map = [0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] # subtoken_map - -# ### Output ### -# (6, 8) - -# Notes: -# - This function is used in modeling.py and this util.py -# """ -# N = len(subtoken_map) -# start_subtoken = subtoken_map.index(start_token) -# end_subtoken = (N-1) - subtoken_map[::-1].index(end_token) - -# return start_subtoken, end_subtoken - -# def ce_get_pem_ments(predict_file): -# """ -# Args: -# args.predict_file: E.g., '../data/data_dir/test.english.jsonlines' -# Return: -# - pems (dict): E.g., {'dialind_0': [[2, 4]], 'dialind_1': [[66, 67]], ...} -# - ments (dict): E.g., {'dialind_0': [[0, 0], [3, 4], [35, 35], ...} -# """ -# pems = {} -# ments = {} -# with open(predict_file) as f: -# jsonl = f.readlines() -# print(len(jsonl)) -# for l in jsonl: -# l = json.loads(l) -# assert l['doc_key'] not in pems, 'doc_key should be unique' -# assert l['doc_key'] not in ments, 'doc_key should be unique' -# pems[l['doc_key']] = l['pems'] -# ments[l['doc_key']] = l['mentions'] -# # pems.append(l['pems']) # PEM -# # ments.append(l['mentions']) # mentions - -# # Error check: pem should not be in ments -# for doc_key in pems: # doc_key: E.g., 'dialind_0' -# for pem in pems[doc_key]: # pem: E.g., [2, 4] -# assert pem not in ments[doc_key], f'PEM {pem} 
should not be in ments. Fix this at 010_preprocess_eval_ConEL.ipynb' - -# return pems, ments
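For reference, the helpers that survive this cleanup in `s2e_pe/utils.py` behave as follows; a small usage sketch with toy clusters, assuming REL is installed with these patches applied:

```python
from REL.crel.s2e_pe.utils import (
    extract_mentions_to_predicted_clusters_from_clusters,
    flatten_list_of_lists,
)

clusters = [((0, 1), (6, 7)), ((12, 12),)]

print(flatten_list_of_lists(clusters))
# [(0, 1), (6, 7), (12, 12)]

print(extract_mentions_to_predicted_clusters_from_clusters(clusters))
# {(0, 1): ((0, 1), (6, 7)), (6, 7): ((0, 1), (6, 7)), (12, 12): ((12, 12),)}
```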