Skip to content

Commit

Permalink
Merge pull request #126 from Semantics-of-Sustainability/fix/staten_g…
Browse files Browse the repository at this point in the history
…eneraal

Fix/staten generaal
  • Loading branch information
carschno authored Oct 23, 2024
2 parents 06a8766 + 684a910 commit 44bb799
Show file tree
Hide file tree
Showing 7 changed files with 494 additions and 355 deletions.
686 changes: 363 additions & 323 deletions notebooks/term_frequency.ipynb

Large diffs are not rendered by default.

69 changes: 64 additions & 5 deletions tempo_embeddings/embeddings/weaviate_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,28 @@ def get_metadata_values(self, collection: str, field: str) -> list[str]:

return [group.grouped_by.value for group in response.groups]

def properties(self, collection: str) -> set[str]:
"""Get the properties of a collection.
Args:
collection (str): The collection name
Returns:
set[str]: The property names
Raises:
ValueError: If the collection does not exist
"""
try:
return {
property.name
for property in self._client.collections.get(collection)
.config.get()
.properties
}
except UnexpectedStatusCodeError as e:
raise ValueError(
f"Error retrieving properties for collection '{collection}': {e}"
) from e

def ingest(
self,
corpus: Corpus,
Expand Down Expand Up @@ -353,12 +375,15 @@ def get_corpus(
filter_words,
YearSpan(year_from, year_to),
metadata_filters,
metadata_not,
QueryBuilder.clean_metadata(metadata_not, self.properties(collection)),
),
include_vector=include_embeddings,
)
passages: tuple[Passage] = tuple(
[Passage.from_weaviate_record(o) for o in response.objects]
[
Passage.from_weaviate_record(o, collection=collection)
for o in response.objects
]
)
label = collection
if passages and filter_words:
Expand Down Expand Up @@ -401,7 +426,9 @@ def doc_frequency(
filters=QueryBuilder.build_filter(
filter_words=search_terms,
metadata=metadata,
metadata_not=metadata_not,
metadata_not=QueryBuilder.clean_metadata(
metadata_not, self.properties(collection)
),
),
total_count=True,
)
Expand Down Expand Up @@ -483,6 +510,11 @@ def query_vector_neighbors(
) -> Iterable[tuple[Passage, float]]:
# TODO: use autocut: https://weaviate.io/developers/weaviate/api/graphql/additional-operators#autocut

if max_neighbors > 10000:
logging.warning(
"Limiting maximum number of neighbors to 10000 (was: %d)", max_neighbors
)
max_neighbors = 10000
wv_collection = self._client.collections.get(collection)
response = wv_collection.query.near_vector(
near_vector=vector,
Expand All @@ -491,12 +523,18 @@ def query_vector_neighbors(
include_vector=True,
return_metadata=MetadataQuery(distance=True),
filters=QueryBuilder.build_filter(
year_span=year_span, metadata_not=metadata_not
year_span=year_span,
metadata_not=QueryBuilder.clean_metadata(
metadata_not, self.properties(collection)
),
),
)

for o in response.objects:
yield Passage.from_weaviate_record(o), o.metadata.distance
yield (
Passage.from_weaviate_record(o, collection=collection),
o.metadata.distance,
)

def query_text_neighbors(
self, collection: Collection, text: list[float], k_neighbors=10
Expand Down Expand Up @@ -750,3 +788,24 @@ def build_filter(
filters.append(Filter.by_property(field).not_equal(value))

return Filter.all_of(filters) if filters else None

@staticmethod
def clean_metadata(
metadata: Optional[dict[str, Any]], collection_properties: set[str]
) -> dict[str, Any]:
"""Remove metadata fields that are not in the collection properties.
Args:
metadata (dict[str, Any]): A metadata dictionary containing field-value pairs
collection_properties (set[str]): The collection properties
Returns:
dict[str, Any]: A new dictionary containing only fields that are in the collection properties
"""
if metadata:
return {
key: value
for key, value in metadata.items()
if key in collection_properties
}
else:
return {}
9 changes: 5 additions & 4 deletions tempo_embeddings/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@
"/data/datasets/research-semantics-of-sustainability/semantics-of-sustainability/data/"
),
# Snellius:
Path("/home/cschnober/data/"),
Path().home() / "data",
# Yoda drive mounted on MacOS:
Path(
"/Volumes/i-lab.data.uu.nl/research-semantics-of-sustainability/semantics-of-sustainability/data"
),
# Path(
# "/Volumes/i-lab.data.uu.nl/research-semantics-of-sustainability/semantics-of-sustainability/data"
# ),
]
"""Directories in which corpora are stored; the first one found is used."""

Expand All @@ -60,6 +60,7 @@
DEVICE: Optional[str] = os.environ.get("DEVICE")

WEAVIATE_CONFIG_COLLECTION: str = "TempoEmbeddings"
WEAVIATE_SERVERS = [("Research Cloud", "145.38.187.187"), ("local", "localhost")]

STRICT = {"strict": True} if int(platform.python_version_tuple()[1]) >= 10 else {}
"""Optional argument for zip() to enforce strict mode in Python 3.10+."""
14 changes: 10 additions & 4 deletions tempo_embeddings/text/passage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Iterable, Optional

import pandas as pd
from dateutil.parser import parse
from dateutil.parser import ParserError, parse

from .highlighting import Highlighting

Expand Down Expand Up @@ -335,21 +335,27 @@ def with_highlighting(self, highlighting: Highlighting) -> "Passage":
)

@classmethod
def from_weaviate_record(cls, _object) -> "Passage":
def from_weaviate_record(cls, _object, *, collection: str) -> "Passage":
"""Create a Passage from a Weaviate object.
Args:
_object: A Weaviate object.
collection: The collection the object belongs to
Returns:
A Passage object.
"""

metadata = _object.properties
metadata = _object.properties | {"collection": collection}

# convert date types; can be removed once Weaviate index has the right types
for field in cls._TYPE_CONVERTERS:
if field in metadata:
metadata[field] = cls._TYPE_CONVERTERS[field](metadata[field])
try:
metadata[field] = cls._TYPE_CONVERTERS[field](metadata[field])
except ParserError as e:
logging.error(
f"Could not convert '{metadata[field]}' in '{field}': {e}"
)

text = metadata.pop("passage")
highlighting = (
Expand Down
16 changes: 11 additions & 5 deletions tempo_embeddings/visualization/jscatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@ class JScatterVisualizer:
def __init__(
self,
corpus,
categorical_fields: list[str] = ["newspaper", "label"],
categorical_fields: list[str] = ["collection", "label"],
continuous_filter_fields: list[str] = ["year"],
tooltip_fields: list[str] = ["year", "text", "label", "top words", "newspaper"],
fillna: dict[str, str] = {"newspaper": "NRC"},
tooltip_fields: list[str] = [
"year",
"text",
"label",
"top words",
"collection",
],
fillna: dict[str, str] = None,
color_by: str = "label",
keyword_extractor: Optional[KeywordExtractor] = None,
):
Expand Down Expand Up @@ -142,7 +148,7 @@ def __init__(
"""Keeps track of filtered indices per filter field."""

self._corpora: list[Corpus] = corpora
self._fillna = fillna
self._fillna = fillna or {}
self._tooltip_fields = tooltip_fields
self._color_by = color_by

Expand Down Expand Up @@ -219,7 +225,7 @@ def _category_field_filter(
logging.warning(f"Categorical field '{field}' not found, ignoring")
return

options = self._df[field].unique().tolist()
options = self._df[field].dropna().unique().tolist()

if len(options) > 1:
selector = widgets.SelectMultiple(
Expand Down
47 changes: 35 additions & 12 deletions tests/unit/embeddings/test_weaviate_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ def test_get_corpus(self, weaviate_db_manager_with_data):
for passage, year in zip(
sorted(corpus.passages, key=lambda p: p.metadata["year"]), range(1950, 1956)
):
assert passage.metadata == {"provenance": "test_file", "year": year}
assert passage.metadata == {
"provenance": "test_file",
"year": year,
"collection": "TestCorpus",
}

@pytest.mark.parametrize(
"term, metadata, normalize, expected",
Expand Down Expand Up @@ -124,18 +128,11 @@ def test_neighbours(self, weaviate_db_manager_with_data, corpus, k):

neighbours: Corpus = weaviate_db_manager_with_data.neighbours(sub_corpus, k)

if k + sub_corpus_size >= len(corpus):
expected_passages = set(corpus.passages) ^ set(sub_corpus.passages)
assert set(neighbours.passages) == expected_passages
else:
assert len(neighbours) == k
assert all(passage in corpus.passages for passage in neighbours.passages)
assert all(
passage not in sub_corpus.passages for passage in neighbours.passages
)

assert len(neighbours) == min(
k, len(corpus)
), "Number of neighbours should be k or all passages."
assert neighbours.label == f"TestCorpus {str(k)} neighbours"
assert neighbours.umap is sub_corpus.umap
assert neighbours.umap is sub_corpus.umap, "UMAP model should be inherited"

@pytest.mark.parametrize("k", [0, 1, 2, 3, 4, 5, 10])
@pytest.mark.skip(reason="Weaviate distances are different from distances()")
Expand Down Expand Up @@ -271,6 +268,19 @@ def test_get_metadata_values(
)
assert sorted(values) == sorted(expected)

@pytest.mark.parametrize(
"collection, expected, exception",
[
("TestCorpus", {"provenance", "year", "passage", "highlighting"}, None),
("invalid", {}, pytest.raises(ValueError)),
],
)
def test_properties(
self, weaviate_db_manager_with_data, collection, expected, exception
):
with exception or does_not_raise():
assert weaviate_db_manager_with_data.properties(collection) == expected

def test_validate_config_missing_collection(self, weaviate_db_manager, corpus):
weaviate_db_manager.validate_config() # Empty collection

Expand Down Expand Up @@ -493,3 +503,16 @@ def test_build_filter(
filter_words, year_span, metadata, metadata_not
)
TestQueryBuilder.assert_filter_equals(filter, expected)

@pytest.mark.parametrize(
"metadata, properties, expected",
[
({}, set(), {}),
(None, set(), {}),
({"field1": "value1"}, {"field1"}, {"field1": "value1"}),
({"field1": "value1"}, {"field2"}, {}),
({"field1": "value1", "f2": "v2"}, {"field1"}, {"field1": "value1"}),
],
)
def test_clean_metadata(self, metadata, properties, expected):
assert QueryBuilder.clean_metadata(metadata, properties) == expected
8 changes: 6 additions & 2 deletions tests/unit/text/test_passage.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,9 @@ def test_words(self, passage, expected):
reason="Weaviate Embedded not supported on Windows",
)
def test_from_weaviate_record(self, weaviate_db_manager_with_data, test_passages):
collection = "TestCorpus"
objects = (
weaviate_db_manager_with_data._client.collections.get("TestCorpus")
weaviate_db_manager_with_data._client.collections.get(collection)
.query.fetch_objects(include_vector=True)
.objects
)
Expand All @@ -152,4 +153,7 @@ def test_from_weaviate_record(self, weaviate_db_manager_with_data, test_passages
test_passages,
**STRICT,
):
assert Passage.from_weaviate_record(_object) == expected
expected.metadata["collection"] = collection
assert (
Passage.from_weaviate_record(_object, collection=collection) == expected
)

0 comments on commit 44bb799

Please sign in to comment.