diff --git a/tempo_embeddings/settings.py b/tempo_embeddings/settings.py
index c7741af..fc7b46a 100644
--- a/tempo_embeddings/settings.py
+++ b/tempo_embeddings/settings.py
@@ -35,9 +35,9 @@
     # Snellius:
     Path().home() / "data",
     # Yoda drive mounted on MacOS:
-    # Path(
-    #     "/Volumes/i-lab.data.uu.nl/research-semantics-of-sustainability/semantics-of-sustainability/data"
-    # ),
+    Path(
+        "/Volumes/i-lab.data.uu.nl/research-semantics-of-sustainability/semantics-of-sustainability/data"
+    ),
 ]
 """Directories in which corpora are stored; the first one found is used."""
 
@@ -47,7 +47,7 @@
     print(f"Using corpus directory: '{CORPUS_DIR}'")
 except StopIteration:
     logging.error(f"No corpus directory found in {_CORPUS_DIRS}")
-    CORPUS_DIR = None
+    CORPUS_DIR = Path(".")
 
 DEFAULT_LANGUAGE_MODEL: str = (
     "NetherlandsForensicInstitute/robbert-2022-dutch-sentence-transformers"
diff --git a/tempo_embeddings/text/passage.py b/tempo_embeddings/text/passage.py
index 879f534..dbb1ac0 100644
--- a/tempo_embeddings/text/passage.py
+++ b/tempo_embeddings/text/passage.py
@@ -5,7 +5,7 @@
 from typing import Any, Iterable, Optional
 
 from dateutil.parser import ParserError, parse
-from pydantic import BaseModel, ConfigDict, ValidationError, field_validator
+from pydantic import BaseModel, ConfigDict, field_validator
 
 from .highlighting import Highlighting
 
@@ -29,7 +29,7 @@ def parse_date(cls, value) -> datetime.datetime:
             if not value.tzinfo:
                 value = value.replace(tzinfo=datetime.timezone.utc)
         except ParserError as e:
-            raise ValidationError(e)
+            raise ValueError("Error in Metadata") from e
         return value
 
     def __init__(
@@ -43,6 +43,11 @@ def __init__(
         char2tokens: Optional[list[int]] = None,
         unique_id: str = None,
     ):
+        """
+        Initializes a Passage object.
+
+        Raises:
+            ValidationError: If the metadata is invalid."""
         # pylint: disable=too-many-arguments
         self._text = text.strip()
         self._unique_id = unique_id
@@ -423,9 +428,9 @@ def merge_until(self, passages: list["Passage"], *, length: int) -> "Passage":
 
             A new Passage potentially merged to increase its length.
         """
-        if len(self) > length * 1.5:
-            logging.warning(
-                "Very long passage (%d characters): %s", len(self), self.metadata
+        if len(self) > length * 2:
+            logging.info(
+                "Very long passage (%d >> %d): %s", len(self), length, self.metadata
             )
         elif passages and (len(self) + len(passages[0]) <= length):
             return (self + passages.pop(0)).merge_until(passages, length=length)
diff --git a/tempo_embeddings/text/segmenter.py b/tempo_embeddings/text/segmenter.py
index fc78f29..adf03eb 100644
--- a/tempo_embeddings/text/segmenter.py
+++ b/tempo_embeddings/text/segmenter.py
@@ -9,6 +9,7 @@
 import stanza
 import torch
 import wtpsplit
+from pydantic import ValidationError
 from sentence_splitter import SentenceSplitter
 
 from .. import settings
@@ -60,13 +61,18 @@ def split(self, text: str) -> Iterable[str]:
         return NotImplemented
 
     def passages(
-        self, text: str, *, metadata: Optional[dict[str, Any]] = None
+        self,
+        text: str,
+        *,
+        metadata: Optional[dict[str, Any]] = None,
+        strict: bool = False,
     ) -> Iterable[Passage]:
         """Yield passages from the text.
 
         Args:
             text: the text to split into passages.
             metadata: the metadata to attach to the passages.
+            strict: if True, raise an error on invalid metadata. Otherwise, log a warning.
         Yields:
             Passage: the passages from the text.
         """
@@ -74,7 +80,18 @@ def passages(
 
         for idx, sentence in enumerate(self.split(text)):
             metadata = (metadata or {}) | {"sentence_index": idx}
-            yield Passage(sentence, metadata)
+            try:
+                yield Passage(sentence, metadata)
+            except ValidationError as e:
+                if strict:
+                    raise e
+                else:
+                    logging.warning(
+                        "Skipping sentence %d with invalid metadata in '%s': %s",
+                        idx,
+                        metadata.get("provenance"),
+                        e,
+                    )
 
     def passages_from_dict_reader(
         self,
diff --git a/tests/unit/text/test_segmenter.py b/tests/unit/text/test_segmenter.py
index 075f687..0825d27 100644
--- a/tests/unit/text/test_segmenter.py
+++ b/tests/unit/text/test_segmenter.py
@@ -3,6 +3,7 @@
 from io import StringIO
 
 import pytest
+from pydantic import ValidationError
 
 from tempo_embeddings.text.highlighting import Highlighting
 from tempo_embeddings.text.passage import Passage
@@ -105,6 +106,37 @@ def test_passages(self, text, expected):
         passages = SentenceSplitterSegmenter("en").passages(text)
         assert list(passages) == expected
 
+    @pytest.mark.parametrize(
+        "text,metadata,strict,expected,exception",
+        [
+            ("text", {}, True, [Passage("text", metadata={"sentence_index": 0})], None),
+            (
+                "text",
+                {"date": "01-05-1889"},
+                True,
+                [Passage("text", metadata={"date": "01-05-1889", "sentence_index": 0})],
+                None,
+            ),
+            (
+                "text",
+                {"date": "01-05-1889"},
+                False,
+                [Passage("text", metadata={"date": "01-05-1889", "sentence_index": 0})],
+                None,
+            ),
+            ("text", {"date": "99-05-1889"}, True, [], pytest.raises(ValidationError)),
+            ("text", {"date": "99-05-1889"}, False, [], None),
+        ],
+    )
+    def test_passages_metadata_strict(
+        self, text, metadata, strict, expected, exception
+    ):
+        with exception or does_not_raise():
+            passages = SentenceSplitterSegmenter("en").passages(
+                text, metadata=metadata, strict=strict
+            )
+            assert list(passages) == expected
+
     @pytest.mark.parametrize(
         "_csv, length, provenance, filter_terms, expected",
         [