Skip to content

Commit

Permalink
Merge pull request #127 from Semantics-of-Sustainability/fix/metadata
Browse files Browse the repository at this point in the history
Fix/metadata
  • Loading branch information
carschno authored Oct 24, 2024
2 parents 44bb799 + bba45dd commit c873786
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 27 deletions.
6 changes: 5 additions & 1 deletion scripts/build_sos_wv_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,11 @@ def arguments_parser():
if args.overwrite:
db.delete_collection(corpus_name)

ingested_files = set(db.get_metadata_values(corpus_name, "provenance"))
try:
ingested_files = set(db.get_metadata_values(corpus_name, "provenance"))
except ValueError as e:
logging.debug(e)
ingested_files = set()
logging.info(f"Skipping {len(ingested_files)} files for '{corpus_name}'.")

corpus_config = corpus_reader[corpus_name]
Expand Down
54 changes: 30 additions & 24 deletions tempo_embeddings/text/passage.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,36 @@
import datetime
import hashlib
import logging
import string
from typing import Any, Iterable, Optional

import pandas as pd
from dateutil.parser import ParserError, parse
from pydantic import BaseModel, ConfigDict, ValidationError, field_validator

from .highlighting import Highlighting


class Passage:
"""A text passage with optional metadata and highlighting."""

_TYPE_CONVERTERS = {"year": int, "date": parse}
"""When exporting to dict, convert these fields using the specified converter."""
class Metadata(BaseModel):
    """Typed view of a passage's metadata.

    The fields declared below are validated and coerced by pydantic. Any other
    key is accepted as-is because the model is configured with ``extra="allow"``.
    """

    model_config = ConfigDict(extra="allow")

    year: Optional[int] = None
    date: Optional[datetime.datetime] = None
    sentence_index: Optional[int] = None

    @field_validator("date", mode="before")
    @classmethod
    def parse_date(cls, value: Any) -> Any:
        """Parse a ``date`` given as a string before pydantic validates the field.

        Args:
            value: The raw ``date`` value. Only strings are handled here; any
                other value is returned unchanged for pydantic to validate.

        Returns:
            A timezone-aware ``datetime`` for string input (UTC is assumed when
            the string carries no timezone), otherwise ``value`` as given.

        Raises:
            ValueError: If the string cannot be parsed as a date. pydantic
                converts this into a ``ValidationError`` for the caller.
        """
        if not isinstance(value, str):
            return value

        try:
            parsed = parse(value)
        except (ParserError, OverflowError) as e:
            # pydantic's ValidationError cannot be instantiated directly;
            # a validator signals failure by raising ValueError instead.
            raise ValueError(f"Could not parse date '{value}': {e}") from e

        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=datetime.timezone.utc)
        return parsed

def __init__(
self,
Expand All @@ -25,11 +42,11 @@ def __init__(
full_word_spans: Optional[list[tuple[int, int]]] = None,
char2tokens: Optional[list[int]] = None,
unique_id: str = None,
) -> None:
):
# pylint: disable=too-many-arguments
self._text = text.strip()
self._unique_id = unique_id
self._metadata = metadata or {}
self._metadata = Passage.Metadata(**(metadata or {}))
self._highlighting = highlighting
self._embedding = embedding
self._embedding_compressed = embedding_compressed
Expand All @@ -50,7 +67,7 @@ def tokenized_text(self) -> str:

@property
def metadata(self) -> dict:
return self._metadata
return self._metadata.model_dump(exclude_none=True)

@property
def highlighting(self) -> Optional[Highlighting]:
Expand Down Expand Up @@ -180,7 +197,7 @@ def hover_data(

return {"text": self.highlighted_text()} | metadata

def to_dict(self) -> pd.DataFrame:
def to_dict(self) -> dict[str, Any]:
"""Returns a dictionary representation of the passage."""

# TODO: merge with hover_data()
Expand All @@ -202,7 +219,7 @@ def set_metadata(self, key: str, value: Any) -> None:
key: The metadata key to set.
value: The value to set the metadata key to.
"""
self._metadata[key] = value
setattr(self._metadata, key, value)

def word_span(self, start, end) -> tuple[int, int]:
word_index = self.tokenization.char_to_word(start)
Expand Down Expand Up @@ -241,20 +258,20 @@ def __len__(self) -> int:
def __hash__(self) -> int:
return (
hash(self._text)
+ hash(frozenset(self._metadata.keys()))
+ hash(frozenset(self._metadata.values()))
+ hash(frozenset(self.metadata.keys()))
+ hash(frozenset(self.metadata.values()))
+ hash(self._highlighting)
)

def __eq__(self, other: object) -> bool:
return (
isinstance(other, Passage)
and self._text == other._text
and self._metadata == other._metadata
and self.metadata == other.metadata
)

def __repr__(self) -> str:
return f"Passage({self._text!r}, {self._metadata!r}, {self._highlighting!r})"
return f"Passage({self._text!r}, {self.metadata!r}, {self._highlighting!r})"

def _partial_match(self, token: str, case_sensitive) -> Iterable[Highlighting]:
if case_sensitive:
Expand Down Expand Up @@ -328,7 +345,7 @@ def with_highlighting(self, highlighting: Highlighting) -> "Passage":
raise RuntimeError("Passage already has a highlighting.")
return Passage(
text=self._text,
metadata=self._metadata,
metadata=self.metadata,
highlighting=highlighting,
full_word_spans=self.full_word_spans,
char2tokens=self.char2tokens,
Expand All @@ -346,17 +363,6 @@ def from_weaviate_record(cls, _object, *, collection: str) -> "Passage":
"""

metadata = _object.properties | {"collection": collection}

# convert date types; can be removed once Weaviate index has the right types
for field in cls._TYPE_CONVERTERS:
if field in metadata:
try:
metadata[field] = cls._TYPE_CONVERTERS[field](metadata[field])
except ParserError as e:
logging.error(
f"Could not convert '{metadata[field]}' in '{field}': {e}"
)

text = metadata.pop("passage")
highlighting = (
Highlighting.from_string(metadata.pop("highlighting"))
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/text/test_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,8 @@ def test_to_dataframe(self, corpus):
"ID_DB": _id,
"highlight_start": 1,
"highlight_end": 3,
"provenance": "test_file",
"year": year,
"provenance": "test_file",
"x": 0.0,
"y": 0.0,
"corpus": "TestCorpus",
Expand Down
8 changes: 7 additions & 1 deletion tests/unit/text/test_passage.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ def test_init(self, text, metadata, expected):
def test_words(self, passage, expected):
assert list(passage.words()) == expected

def test_set_metadata(self):
    """Setting a new metadata key adds it next to the keys given at construction."""
    subject = Passage("test", metadata={"key": "value"})
    expected_metadata = {"key": "value", "test_key": "test_value"}

    subject.set_metadata("test_key", "test_value")

    assert subject.metadata == expected_metadata

@pytest.mark.xfail(
platform.system() == "Windows",
raises=WeaviateStartUpError,
Expand All @@ -153,7 +159,7 @@ def test_from_weaviate_record(self, weaviate_db_manager_with_data, test_passages
test_passages,
**STRICT,
):
expected.metadata["collection"] = collection
expected.set_metadata("collection", collection)
assert (
Passage.from_weaviate_record(_object, collection=collection) == expected
)

0 comments on commit c873786

Please sign in to comment.