diff --git a/CHANGELOG.md b/CHANGELOG.md
index af958255..63262786 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,17 @@ Types of changes
 * "Security" in case of vulnerabilities.
 -->
 
+## [Unreleased]
+
+### Added
+
+- Added an `overwrite_entities` parameter to the spaCy pipeline component to allow for overwriting spaCy entities.
+- Added `.pipe()` method to spaCy integration to allow for batched inference.
+
+### Changed
+
+- Stop overwriting spaCy entities by default.
+
 ## [1.2.5]
 
 ### Fixed
diff --git a/notebooks/spacy_integration.ipynb b/notebooks/spacy_integration.ipynb
index 9af305bb..3e245460 100644
--- a/notebooks/spacy_integration.ipynb
+++ b/notebooks/spacy_integration.ipynb
@@ -41,7 +41,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
     {
@@ -58,7 +58,7 @@
        " BCE)"
       ]
      },
-     "execution_count": 11,
+     "execution_count": 2,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -89,7 +89,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [
     {
@@ -192,7 +192,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [
     {
@@ -266,12 +266,12 @@
    "source": [
     "Much better!\n",
     "\n",
-    "But, what if we don't want to use a model with these labels? Well, this integration works for any [SpanMarker model on the Hugging Face Hub](https://huggingface.co/models?library=span-marker), so we can just pick another one. Let's now also ensure that the model stays on the CPU, just to see how that works."
+    "But, what if we don't want to use a model with these labels? Well, this integration works for any [SpanMarker model on the Hugging Face Hub](https://huggingface.co/models?library=span-marker), so we can just pick another one. Let's now also ensure that the model stays on the CPU, just to see how that works. Beyond that, we'll overwrite entities from spaCy's own NER model. This is recommended when the SpanMarker model uses a different label scheme than spaCy, which uses the labels from OntoNotes v5."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
     {
@@ -328,6 +328,7 @@
    "    config={\n",
    "        \"model\": \"tomaarsen/span-marker-xlm-roberta-base-fewnerd-fine-super\",\n",
    "        \"device\": \"cpu\",\n",
+   "        \"overwrite_entities\": True,\n",
    "    },\n",
    ")\n",
    "\n",
@@ -347,7 +348,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [
     {
@@ -360,7 +361,7 @@
        " (Paris, 'GPE')]"
       ]
      },
-     "execution_count": 16,
+     "execution_count": 7,
      "metadata": {},
      "output_type": "execute_result"
     }
diff --git a/span_marker/__init__.py b/span_marker/__init__.py
index 15be1969..62de8722 100644
--- a/span_marker/__init__.py
+++ b/span_marker/__init__.py
@@ -26,6 +26,7 @@
     "model": "tomaarsen/span-marker-roberta-large-ontonotes5",
     "batch_size": 4,
     "device": None,
+    "overwrite_entities": False,
 }
 
 @Language.factory(
@@ -39,14 +40,16 @@ def _spacy_span_marker_factory(
     model: str,
     batch_size: int,
     device: Optional[Union[str, torch.device]],
+    overwrite_entities: bool,
 ) -> SpacySpanMarkerWrapper:
-    # Remove the existing NER component, if it exists,
-    # to allow for SpanMarker to act as a drop-in replacement
-    try:
-        nlp.remove_pipe("ner")
-    except ValueError:
-        # The `ner` pipeline component was not found
-        pass
+    if overwrite_entities:
+        # Remove the existing NER component, if it exists,
+        # to allow for SpanMarker to act as a drop-in replacement
+        try:
+            nlp.remove_pipe("ner")
+        except ValueError:
+            # The `ner` pipeline component was not found
+            pass
     return SpacySpanMarkerWrapper(model, batch_size=batch_size, device=device)
 
 
diff --git a/span_marker/spacy_integration.py b/span_marker/spacy_integration.py
index 323311a3..8abb7c18 100644
--- a/span_marker/spacy_integration.py
+++ b/span_marker/spacy_integration.py
@@ -1,9 +1,11 @@
 import os
-from typing import Any, Optional, Union
+import types
+from typing import List, Optional, Union
 
 import torch
 from datasets import Dataset
 from spacy.tokens import Doc, Span
+from spacy.util import filter_spans, minibatch
 
 from span_marker.modeling import SpanMarkerModel
 
@@ -53,6 +55,7 @@ def __init__(
         *args,
         batch_size: int = 4,
         device: Optional[Union[str, torch.device]] = None,
+        overwrite_entities: bool = False,
         **kwargs,
     ) -> None:
         """Initialize a SpanMarker wrapper for spaCy.
@@ -63,6 +66,8 @@
             batch_size (int): The number of samples to include per batch. Higher is faster, but requires more memory.
                 Defaults to 4.
             device (Optional[Union[str, torch.device]]): The device to place the model on. Defaults to None.
+            overwrite_entities (bool): Whether to overwrite the existing entities in the `doc.ents` attribute.
+                Defaults to False.
         """
         self.model = SpanMarkerModel.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
         if device:
@@ -70,22 +75,35 @@
         elif torch.cuda.is_available():
             self.model.to("cuda")
         self.batch_size = batch_size
+        self.overwrite_entities = overwrite_entities
+
+    @staticmethod
+    def convert_inputs_to_dataset(inputs):
+        inputs = Dataset.from_dict(
+            {
+                "tokens": inputs,
+                "document_id": [0] * len(inputs),
+                "sentence_id": range(len(inputs)),
+            }
+        )
+        return inputs
+
+    def set_ents(self, doc: Doc, ents: List[Span]):
+        if self.overwrite_entities:
+            doc.set_ents(ents)
+        else:
+            doc.set_ents(filter_spans(ents + list(doc.ents)))
 
     def __call__(self, doc: Doc) -> Doc:
         """Fill `doc.ents` and `span.label_` using the chosen SpanMarker model."""
         sents = list(doc.sents)
         inputs = [[token.text if not token.is_space else "" for token in sent] for sent in sents]
 
+        # use document-level context in the inference if the model was also trained that way
         if self.model.config.trained_with_document_context:
-            inputs = Dataset.from_dict(
-                {
-                    "tokens": inputs,
-                    "document_id": [0] * len(inputs),
-                    "sentence_id": range(len(inputs)),
-                }
-            )
-        outputs = []
+            inputs = self.convert_inputs_to_dataset(inputs)
+        ents = []
         entities_list = self.model.predict(inputs, batch_size=self.batch_size)
         for sentence, entities in zip(sents, entities_list):
             for entity in entities:
@@ -93,7 +111,37 @@
                 end = entity["word_end_index"]
                 span = sentence[start:end]
                 span.label_ = entity["label"]
-                outputs.append(span)
+                ents.append(span)
+
+        self.set_ents(doc, ents)
 
-        doc.set_ents(outputs)
         return doc
+
+    def pipe(self, stream, batch_size=128):
+        """Fill `doc.ents` and `span.label_` using the chosen SpanMarker model."""
+        if isinstance(stream, str):
+            stream = [stream]
+
+        if not isinstance(stream, types.GeneratorType):
+            stream = self.nlp.pipe(stream, batch_size=batch_size)
+
+        for docs in minibatch(stream, size=batch_size):
+            inputs = [[token.text if not token.is_space else "" for token in doc] for doc in docs]
+
+            # use document-level context in the inference if the model was also trained that way
+            if self.model.config.trained_with_document_context:
+                inputs = self.convert_inputs_to_dataset(inputs)
+
+            entities_list = self.model.predict(inputs, batch_size=self.batch_size)
+            for doc, entities in zip(docs, entities_list):
+                ents = []
+                for entity in entities:
+                    start = entity["word_start_index"]
+                    end = entity["word_end_index"]
+                    span = doc[start:end]
+                    span.label_ = entity["label"]
+                    ents.append(span)
+
+                self.set_ents(doc, ents)
+
+                yield doc
diff --git a/tests/test_spacy_integration.py b/tests/test_spacy_integration.py
index 6ef8a4e3..70a86616 100644
--- a/tests/test_spacy_integration.py
+++ b/tests/test_spacy_integration.py
@@ -27,3 +27,32 @@ def test_span_marker_as_spacy_pipeline_component():
         ("Atlantic", "LOC"),
         ("Paris", "LOC"),
     ]
+
+def test_span_marker_as_spacy_pipeline_component_pipe():
+    nlp = spacy.load("en_core_web_sm", disable=["ner"])
+    batch_size = 2
+    wrapper = nlp.add_pipe(
+        "span_marker", config={"model": "tomaarsen/span-marker-bert-tiny-conll03", "batch_size": batch_size}
+    )
+    assert wrapper.batch_size == batch_size
+
+    docs = nlp.pipe(["Amelia Earhart flew her single engine Lockheed Vega 5B across the Atlantic to Paris."])
+    doc = list(docs)[0]
+    assert [(span.text, span.label_) for span in doc.ents] == [
+        ("Amelia Earhart", "PER"),
+        ("Lockheed Vega", "ORG"),
+        ("Atlantic", "LOC"),
+        ("Paris", "LOC"),
+    ]
+
+    # Override a setting that modifies how inference is performed,
+    # should not have any impact with just one sentence, i.e. it should still work.
+    wrapper.model.config.trained_with_document_context = True
+    docs = nlp.pipe(["Amelia Earhart flew her single engine Lockheed Vega 5B across the Atlantic to Paris."])
+    doc = list(docs)[0]
+    assert [(span.text, span.label_) for span in doc.ents] == [
+        ("Amelia Earhart", "PER"),
+        ("Lockheed Vega", "ORG"),
+        ("Atlantic", "LOC"),
+        ("Paris", "LOC"),
+    ]
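
Taken together, the change is easiest to see from the caller's side. Below is a minimal usage sketch assembled only from pieces that appear in this diff: the `span_marker` factory name, the `model`, `device`, and new `overwrite_entities` config keys, the FewNERD model from the notebook, and the example sentence from the test suite. It is an illustration, not part of the patch.

```python
# Minimal usage sketch for the new `overwrite_entities` option.
# Assumes spaCy (with `en_core_web_sm`) and the `span_marker` package are installed.
import spacy

nlp = spacy.load("en_core_web_sm")
nlp.add_pipe(
    "span_marker",
    config={
        "model": "tomaarsen/span-marker-xlm-roberta-base-fewnerd-fine-super",
        "device": "cpu",
        # True: drop spaCy's own OntoNotes-style entities and keep only SpanMarker's.
        # False (the new default): merge both via spacy.util.filter_spans.
        "overwrite_entities": True,
    },
)

doc = nlp("Amelia Earhart flew her single engine Lockheed Vega 5B across the Atlantic to Paris.")
print([(ent.text, ent.label_) for ent in doc.ents])
```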
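With the new `.pipe()` method, the component also participates in spaCy's batched processing: spaCy hands the wrapper a stream of `Doc` objects, which the wrapper groups with `minibatch` (default size 128) and runs through `SpanMarkerModel.predict` using the configured `batch_size`. A companion sketch mirroring the added test; the second input text is illustrative and not taken from the repository.

```python
# Sketch of batched inference through the new `.pipe()` support, mirroring the added test.
# The second input text is illustrative only.
import spacy

nlp = spacy.load("en_core_web_sm", disable=["ner"])
nlp.add_pipe(
    "span_marker",
    config={"model": "tomaarsen/span-marker-bert-tiny-conll03", "batch_size": 2},
)

texts = [
    "Amelia Earhart flew her single engine Lockheed Vega 5B across the Atlantic to Paris.",
    "Grace Hopper worked on the Harvard Mark I at Harvard University.",
]
for doc in nlp.pipe(texts):
    print([(ent.text, ent.label_) for ent in doc.ents])
```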
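One more detail from the refactor: both `__call__` and `pipe` now share `convert_inputs_to_dataset`, which wraps the tokenized sentences in a `datasets.Dataset` with `document_id` and `sentence_id` columns whenever the model was trained with document-level context. A standalone sketch of that conversion follows; the sentences are illustrative, and only the column layout is taken from the helper in this diff.

```python
# Standalone sketch of the structure produced by `convert_inputs_to_dataset`.
from datasets import Dataset

sentences = [
    ["Amelia", "Earhart", "flew", "to", "Paris", "."],
    ["She", "landed", "safely", "."],
]
dataset = Dataset.from_dict(
    {
        "tokens": sentences,                         # one token list per sentence
        "document_id": [0] * len(sentences),         # all sentences belong to document 0
        "sentence_id": list(range(len(sentences))),  # position of each sentence in the document
    }
)
print(dataset)  # three columns: tokens, document_id, sentence_id; num_rows: 2
```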