Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to get character indices #19

Merged
merged 7 commits into from
Apr 5, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 38 additions & 5 deletions crosslingual_coreference/CrossLingualPredictor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import itertools
import pathlib
from typing import List, Union
from typing import List, Tuple, Union

import requests
import tqdm # progress bar
from allennlp.predictors.predictor import Predictor
from spacy.tokens import Doc
from spacy.tokens import Doc, Span

from .CorefResolver import CorefResolver as Resolver

Expand Down Expand Up @@ -126,6 +126,7 @@ def predict(self, text: str) -> dict:

merged_clusters = self.merge_clusters(corrected_clusters)
resolved_text, heads = self.resolver.replace_corefs(doc, merged_clusters)
merged_clusters, heads = self.convert_indices(merged_clusters, heads, doc)

prediction = {"clusters": merged_clusters, "resolved_text": resolved_text, "cluster_heads": heads}

Expand All @@ -152,9 +153,11 @@ def pipe(self, texts: List[str]):
json_predictions = self.predictor.predict_batch_json(json_batch)
clusters_predictions = [prediction.get("clusters") for prediction in json_predictions]

for spacy_doc, cluster in zip(spacy_document_list, clusters_predictions):
resolved_text, heads = self.resolver.replace_corefs(spacy_doc, cluster)
predictions.append({"clusters": cluster, "resolved_text": resolved_text, "cluster_heads": heads})
for spacy_doc, clusters in zip(spacy_document_list, clusters_predictions):
resolved_text, heads = self.resolver.replace_corefs(spacy_doc, clusters)
clusters, heads = self.convert_indices(clusters, heads, spacy_doc)

predictions.append({"clusters": clusters, "resolved_text": resolved_text, "cluster_heads": heads})

return predictions

Expand Down Expand Up @@ -226,3 +229,33 @@ def merge_clusters(
main_doc_clus.sort()
main_doc_clus = list(k for k, _ in itertools.groupby(main_doc_clus))
return main_doc_clus

@staticmethod
def convert_indices(merged_clusters: List[List[List[int]]], heads: dict, spacy_doc: Doc) -> Tuple[list, dict]:
    """Convert cluster and head indices from token level to character level.

    Args:
        merged_clusters (List[List[List[int]]]): list of clusters, each a list of
            ``[start_token, end_token]`` spans, with the end token inclusive.
        heads (Dict[str, List[int]]): mapping of cluster head to its
            ``[start_token, end_token]`` span, with the end token inclusive.
        spacy_doc (Doc): spaCy Doc the token indices refer to.

    Returns:
        Tuple[list, dict]: the clusters and heads with ``[start_char, end_char]``
        character offsets in place of the token indices.
    """
    # NOTE: the stored token spans use an inclusive end index, while
    # spacy.tokens.Span expects an exclusive one — hence the `+ 1`.

    # clusters: map every [start_token, end_token] to [start_char, end_char]
    char_merged_clusters = [
        [
            [char_span.start_char, char_span.end_char]
            for char_span in (Span(spacy_doc, tok_span[0], tok_span[1] + 1) for tok_span in cluster)
        ]
        for cluster in merged_clusters
    ]

    # cluster heads: same conversion, keyed by head (iterate items() to avoid
    # repeated dict lookups)
    char_heads = {}
    for head_key, tok_span in heads.items():
        head_span = Span(spacy_doc, tok_span[0], tok_span[1] + 1)
        char_heads[head_key] = [head_span.start_char, head_span.end_char]

    return char_merged_clusters, char_heads