
Commit

Merge pull request #134 from Semantics-of-Sustainability/fix/read_errors
Fix/read errors
carschno authored Nov 4, 2024
2 parents 2da3a5b + fe6191c commit 9d36c92
Showing 4 changed files with 65 additions and 11 deletions.
8 changes: 4 additions & 4 deletions tempo_embeddings/settings.py
@@ -35,9 +35,9 @@
# Snellius:
Path().home() / "data",
# Yoda drive mounted on MacOS:
# Path(
# "/Volumes/i-lab.data.uu.nl/research-semantics-of-sustainability/semantics-of-sustainability/data"
# ),
Path(
"/Volumes/i-lab.data.uu.nl/research-semantics-of-sustainability/semantics-of-sustainability/data"
),
]
"""Directories in which corpora are stored; the first one found is used."""

@@ -47,7 +47,7 @@
print(f"Using corpus directory: '{CORPUS_DIR}'")
except StopIteration:
logging.error(f"No corpus directory found in {_CORPUS_DIRS}")
CORPUS_DIR = None
CORPUS_DIR = Path(".")

DEFAULT_LANGUAGE_MODEL: str = (
"NetherlandsForensicInstitute/robbert-2022-dutch-sentence-transformers"
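The net effect in settings.py: the Yoda network drive is searched again on macOS, and when none of the candidate directories exists, CORPUS_DIR falls back to the current working directory instead of None, so later path operations do not fail on a missing value. A minimal sketch of the resulting selection logic; the existence check (here Path.is_dir()) sits outside this hunk and is an assumption:

import logging
from pathlib import Path

_CORPUS_DIRS = [
    # Snellius:
    Path().home() / "data",
    # Yoda drive mounted on MacOS:
    Path(
        "/Volumes/i-lab.data.uu.nl/research-semantics-of-sustainability/semantics-of-sustainability/data"
    ),
]
"""Directories in which corpora are stored; the first one found is used."""

try:
    CORPUS_DIR = next(path for path in _CORPUS_DIRS if path.is_dir())  # assumed check
    print(f"Using corpus directory: '{CORPUS_DIR}'")
except StopIteration:
    logging.error(f"No corpus directory found in {_CORPUS_DIRS}")
    CORPUS_DIR = Path(".")  # fall back to the working directory rather than None
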
15 changes: 10 additions & 5 deletions tempo_embeddings/text/passage.py
@@ -5,7 +5,7 @@
from typing import Any, Iterable, Optional

from dateutil.parser import ParserError, parse
from pydantic import BaseModel, ConfigDict, ValidationError, field_validator
from pydantic import BaseModel, ConfigDict, field_validator

from .highlighting import Highlighting

@@ -29,7 +29,7 @@ def parse_date(cls, value) -> datetime.datetime:
if not value.tzinfo:
value = value.replace(tzinfo=datetime.timezone.utc)
except ParserError as e:
raise ValidationError(e)
raise ValueError("Error in Metadata") from e
return value

def __init__(
@@ -43,6 +43,11 @@ def __init__(
char2tokens: Optional[list[int]] = None,
unique_id: str = None,
):
"""
Initializes a Passage object.
Raises:
ValidationError: If the metadata is invalid."""
# pylint: disable=too-many-arguments
self._text = text.strip()
self._unique_id = unique_id
@@ -423,9 +428,9 @@ def merge_until(self, passages: list["Passage"], *, length: int) -> "Passage":
A new Passage potentially merged to increase its length.
"""

if len(self) > length * 1.5:
logging.warning(
"Very long passage (%d characters): %s", len(self), self.metadata
if len(self) > length * 2:
logging.info(
"Very long passage (%d >> %d): %s", len(self), length, self.metadata
)
elif passages and (len(self) + len(passages[0]) <= length):
return (self + passages.pop(0)).merge_until(passages, length=length)
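Why the raise changed in parse_date: pydantic expects validators to signal failure with a plain ValueError (or AssertionError), which it then wraps into a ValidationError for the caller; its ValidationError is not meant to be constructed directly from another exception. A minimal sketch of the pattern with a simplified stand-in model (the real metadata model has more fields and configuration; mode="before" is assumed here so the validator sees the raw string):

import datetime

from dateutil.parser import ParserError, parse
from pydantic import BaseModel, ValidationError, field_validator


class Metadata(BaseModel):
    # Simplified stand-in for the Passage metadata model.
    date: datetime.datetime

    @field_validator("date", mode="before")
    @classmethod
    def parse_date(cls, value):
        if isinstance(value, str):
            try:
                value = parse(value)
            except ParserError as e:
                # pydantic converts this ValueError into a ValidationError
                raise ValueError("Error in Metadata") from e
        if not value.tzinfo:
            value = value.replace(tzinfo=datetime.timezone.utc)
        return value


try:
    Metadata(date="99-05-1889")  # no valid day/month combination; dateutil raises ParserError
except ValidationError as e:
    print(e)  # the wrapped "Error in Metadata" is reported for the 'date' field
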
21 changes: 19 additions & 2 deletions tempo_embeddings/text/segmenter.py
@@ -9,6 +9,7 @@
import stanza
import torch
import wtpsplit
from pydantic import ValidationError
from sentence_splitter import SentenceSplitter

from .. import settings
@@ -60,21 +61,37 @@ def split(self, text: str) -> Iterable[str]:
return NotImplemented

def passages(
self, text: str, *, metadata: Optional[dict[str, Any]] = None
self,
text: str,
*,
metadata: Optional[dict[str, Any]] = None,
strict: bool = False,
) -> Iterable[Passage]:
"""Yield passages from the text.
Args:
text: the text to split into passages.
metadata: the metadata to attach to the passages.
strict: if True, raise an error on invalid metadata. Otherwise, log a warning.
Yields:
Passage: the passages from the text.
"""

for idx, sentence in enumerate(self.split(text)):
metadata = (metadata or {}) | {"sentence_index": idx}

yield Passage(sentence, metadata)
try:
yield Passage(sentence, metadata)
except ValidationError as e:
if strict:
raise e
else:
logging.warning(
"Skipping sentence %d with invalid metadata in '%s': %s",
idx,
metadata.get("provenance"),
e,
)

def passages_from_dict_reader(
self,
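For callers of the segmenter, the new strict flag decides whether a sentence with unparseable metadata aborts processing or is skipped with a logged warning (the default). A hedged usage sketch; the import path for SentenceSplitterSegmenter is assumed from the test module below, and the metadata values mirror the new test cases:

from pydantic import ValidationError

from tempo_embeddings.text.segmenter import SentenceSplitterSegmenter  # assumed import path

segmenter = SentenceSplitterSegmenter("en")
bad_metadata = {"date": "99-05-1889", "provenance": "example.csv"}  # "99-05-1889" is not a valid date

# Default (strict=False): the offending sentence is skipped and a warning is logged,
# naming the provenance taken from the metadata.
assert list(segmenter.passages("Some text.", metadata=bad_metadata)) == []

# strict=True: the ValidationError raised while building the Passage metadata propagates.
try:
    list(segmenter.passages("Some text.", metadata=bad_metadata, strict=True))
except ValidationError as e:
    print("Invalid metadata:", e)
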
32 changes: 32 additions & 0 deletions tests/unit/text/test_segmenter.py
@@ -3,6 +3,7 @@
from io import StringIO

import pytest
from pydantic import ValidationError

from tempo_embeddings.text.highlighting import Highlighting
from tempo_embeddings.text.passage import Passage
@@ -105,6 +106,37 @@ def test_passages(self, text, expected):
passages = SentenceSplitterSegmenter("en").passages(text)
assert list(passages) == expected

@pytest.mark.parametrize(
"text,metadata,strict,expected,exception",
[
("text", {}, True, [Passage("text", metadata={"sentence_index": 0})], None),
(
"text",
{"date": "01-05-1889"},
True,
[Passage("text", metadata={"date": "01-05-1889", "sentence_index": 0})],
None,
),
(
"text",
{"date": "01-05-1889"},
False,
[Passage("text", metadata={"date": "01-05-1889", "sentence_index": 0})],
None,
),
("text", {"date": "99-05-1889"}, True, [], pytest.raises(ValidationError)),
("text", {"date": "99-05-1889"}, False, [], None),
],
)
def test_passages_metadata_strict(
self, text, metadata, strict, expected, exception
):
with exception or does_not_raise():
passages = SentenceSplitterSegmenter("en").passages(
text, metadata=metadata, strict=strict
)
assert list(passages) == expected

@pytest.mark.parametrize(
"_csv, length, provenance, filter_terms, expected",
[