
Commit

Merge pull request #19 from Pandora-Intelligence/character-ranges
Add option to get character indices
davidberenstein1957 authored Apr 5, 2023
2 parents e038afb + d4494b2 commit 390dc71
Showing 1 changed file with 38 additions and 5 deletions.
43 changes: 38 additions & 5 deletions crosslingual_coreference/CrossLingualPredictor.py
@@ -1,11 +1,11 @@
 import itertools
 import pathlib
-from typing import List, Union
+from typing import List, Tuple, Union

 import requests
 import tqdm  # progress bar
 from allennlp.predictors.predictor import Predictor
-from spacy.tokens import Doc
+from spacy.tokens import Doc, Span

 from .CorefResolver import CorefResolver as Resolver

@@ -126,6 +126,7 @@ def predict(self, text: str) -> dict:

         merged_clusters = self.merge_clusters(corrected_clusters)
         resolved_text, heads = self.resolver.replace_corefs(doc, merged_clusters)
+        merged_clusters, heads = self.convert_indices(merged_clusters, heads, doc)

         prediction = {"clusters": merged_clusters, "resolved_text": resolved_text, "cluster_heads": heads}

@@ -152,9 +153,11 @@ def pipe(self, texts: List[str]):
         json_predictions = self.predictor.predict_batch_json(json_batch)
         clusters_predictions = [prediction.get("clusters") for prediction in json_predictions]

-        for spacy_doc, cluster in zip(spacy_document_list, clusters_predictions):
-            resolved_text, heads = self.resolver.replace_corefs(spacy_doc, cluster)
-            predictions.append({"clusters": cluster, "resolved_text": resolved_text, "cluster_heads": heads})
+        for spacy_doc, clusters in zip(spacy_document_list, clusters_predictions):
+            resolved_text, heads = self.resolver.replace_corefs(spacy_doc, clusters)
+            clusters, heads = self.convert_indices(clusters, heads, spacy_doc)
+
+            predictions.append({"clusters": clusters, "resolved_text": resolved_text, "cluster_heads": heads})

         return predictions

@@ -226,3 +229,33 @@ def merge_clusters(
         main_doc_clus.sort()
         main_doc_clus = list(k for k, _ in itertools.groupby(main_doc_clus))
         return main_doc_clus
+
+    @staticmethod
+    def convert_indices(merged_clusters: List[List[List[int]]], heads: dict, spacy_doc: Doc) -> Tuple[list, dict]:
+        """Convert indices from token to character level
+
+        Args:
+            merged_clusters (List[List[List[int]]]): List of clusters
+            heads (Dict[List[int]]): Dictionary of cluster heads
+            spacy_doc (Doc): Spacy doc object
+
+        Returns:
+            List[List[List[int]]], Dict[List[int]]: Tuple of converted clusters and heads
+        """
+        char_merged_clusters = []
+        char_heads = {}
+
+        # clusters
+        for cluster in merged_clusters:
+            char_cluster = []
+            for span in cluster:
+                spacy_span = Span(spacy_doc, span[0], span[1] + 1)
+                char_cluster.append([spacy_span.start_char, spacy_span.end_char])
+            char_merged_clusters.append(char_cluster)
+
+        # cluster heads
+        for head_key in heads.keys():
+            span = Span(spacy_doc, heads[head_key][0], heads[head_key][1] + 1)
+            char_heads[head_key] = [span.start_char, span.end_char]
+
+        return char_merged_clusters, char_heads
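
For context, here is a minimal sketch of the token-to-character conversion the new convert_indices helper performs, using only spacy.tokens.Span. The blank English pipeline and the hand-built four-word Doc are illustrative assumptions, not part of this commit:

import spacy
from spacy.tokens import Doc, Span

nlp = spacy.blank("en")
# Hand-built Doc standing in for the parsed input text.
doc = Doc(nlp.vocab, words=["Anna", "told", "her", "sister"], spaces=[True, True, True, False])

# Token-level mention [start, end] with an inclusive end, as stored in the clusters.
token_span = [2, 3]  # "her sister"

# Same conversion as convert_indices: Span takes an exclusive end, hence the +1.
spacy_span = Span(doc, token_span[0], token_span[1] + 1)
char_span = [spacy_span.start_char, spacy_span.end_char]

print(spacy_span.text)                       # her sister
print(char_span)                             # [10, 20]
print(doc.text[char_span[0]:char_span[1]])   # her sister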

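And a hedged usage sketch of how the character ranges surface through predict(): the Predictor import, constructor arguments, and example sentence follow the project README as assumptions here and are not part of this diff.

# Usage sketch; constructor arguments are assumptions taken from the project README
# and may differ between versions.
from crosslingual_coreference import Predictor

text = "Do not forget about Momofuku Ando! He created instant noodles in Osaka."

predictor = Predictor(language="en_core_web_sm", device=-1, model_name="minilm")
prediction = predictor.predict(text)

# With this change each mention is a [start_char, end_char] pair, so the original
# text can be sliced directly (end_char is exclusive, as with spaCy spans).
for cluster in prediction["clusters"]:
    print([text[start:end] for start, end in cluster])

print(prediction["resolved_text"])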