Skip to content

Commit

Permalink
Merge branch 'main' into notify-nightly-failures
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Mar 13, 2024
2 parents 42cff9d + b0b71e4 commit 3889dc2
Show file tree
Hide file tree
Showing 10 changed files with 283 additions and 350 deletions.
8 changes: 7 additions & 1 deletion integrations/pgvector/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,12 @@ def test_write_dataframe(self, document_store: PgvectorDocumentStore):
assert retrieved_docs == docs


def test_init(patches_for_unit_tests, monkeypatch): # noqa: ARG001 patches are not explicitly called but necessary
@pytest.mark.usefixtures("patches_for_unit_tests")
def test_init(monkeypatch):
monkeypatch.setenv("PG_CONN_STR", "some_connection_string")

document_store = PgvectorDocumentStore(
Expand All @@ -63,7 +68,8 @@ def test_init(patches_for_unit_tests, monkeypatch): # noqa: ARG001 patches are
assert document_store.hnsw_ef_search == 50


def test_to_dict(patches_for_unit_tests, monkeypatch): # noqa: ARG001 patches are not explicitly called but necessary
@pytest.mark.usefixtures("patches_for_unit_tests")
def test_to_dict(monkeypatch):
monkeypatch.setenv("PG_CONN_STR", "some_connection_string")

document_store = PgvectorDocumentStore(
Expand Down
6 changes: 3 additions & 3 deletions integrations/pgvector/tests/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import Mock

import pytest
from haystack.dataclasses import Document
from haystack.utils.auth import EnvVarSecret
from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever
Expand Down Expand Up @@ -55,9 +56,8 @@ def test_to_dict(self, mock_store):
},
}

def test_from_dict(
self, patches_for_unit_tests, monkeypatch # noqa:ARG002 patches are not explicitly called but necessary
):
@pytest.mark.usefixtures("patches_for_unit_tests")
def test_from_dict(self, monkeypatch):
monkeypatch.setenv("PG_CONN_STR", "some-connection-string")
t = "haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever"
data = {
Expand Down
2 changes: 1 addition & 1 deletion integrations/weaviate/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ services:
- '8080'
- --scheme
- http
image: semitechnologies/weaviate:1.23.2
image: semitechnologies/weaviate:1.24.1
ports:
- 8080:8080
- 50051:50051
Expand Down
2 changes: 1 addition & 1 deletion integrations/weaviate/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ classifiers = [
]
dependencies = [
"haystack-ai",
"weaviate-client==3.*",
"weaviate-client",
"haystack-pydoc-tools",
"python-dateutil",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
from haystack.errors import FilterError
from pandas import DataFrame

import weaviate
from weaviate.collections.classes.filters import Filter, FilterReturn

def convert_filters(filters: Dict[str, Any]) -> Dict[str, Any]:

def convert_filters(filters: Dict[str, Any]) -> FilterReturn:
"""
Convert filters from Haystack format to Weaviate format.
"""
Expand All @@ -14,7 +17,7 @@ def convert_filters(filters: Dict[str, Any]) -> Dict[str, Any]:
raise FilterError(msg)

if "field" in filters:
return {"operator": "And", "operands": [_parse_comparison_condition(filters)]}
return Filter.all_of([_parse_comparison_condition(filters)])
return _parse_logical_condition(filters)


Expand All @@ -29,7 +32,7 @@ def convert_filters(filters: Dict[str, Any]) -> Dict[str, Any]:
"not in": "in",
"AND": "OR",
"OR": "AND",
"NOT": "AND",
"NOT": "OR",
}


Expand All @@ -51,7 +54,13 @@ def _invert_condition(filters: Dict[str, Any]) -> Dict[str, Any]:
return inverted_condition


def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
LOGICAL_OPERATORS = {
"AND": Filter.all_of,
"OR": Filter.any_of,
}


def _parse_logical_condition(condition: Dict[str, Any]) -> FilterReturn:
if "operator" not in condition:
msg = f"'operator' key missing in {condition}"
raise FilterError(msg)
Expand All @@ -67,7 +76,7 @@ def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
operands.append(_parse_logical_condition(c))
else:
operands.append(_parse_comparison_condition(c))
return {"operator": operator.lower().capitalize(), "operands": operands}
return LOGICAL_OPERATORS[operator](operands)
elif operator == "NOT":
inverted_conditions = _invert_condition(condition)
return _parse_logical_condition(inverted_conditions)
Expand All @@ -76,28 +85,6 @@ def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
raise FilterError(msg)


def _infer_value_type(value: Any) -> str:
if value is None:
return "valueNull"

if isinstance(value, bool):
return "valueBoolean"
if isinstance(value, int):
return "valueInt"
if isinstance(value, float):
return "valueNumber"

if isinstance(value, str):
try:
parser.isoparse(value)
return "valueDate"
except ValueError:
return "valueText"

msg = f"Unknown value type {type(value)}"
raise FilterError(msg)


def _handle_date(value: Any) -> str:
if isinstance(value, str):
try:
Expand All @@ -107,25 +94,22 @@ def _handle_date(value: Any) -> str:
return value


def _equal(field: str, value: Any) -> Dict[str, Any]:
def _equal(field: str, value: Any) -> FilterReturn:
if value is None:
return {"path": field, "operator": "IsNull", "valueBoolean": True}
return {"path": field, "operator": "Equal", _infer_value_type(value): _handle_date(value)}
return weaviate.classes.query.Filter.by_property(field).is_none(True)
return weaviate.classes.query.Filter.by_property(field).equal(_handle_date(value))


def _not_equal(field: str, value: Any) -> Dict[str, Any]:
def _not_equal(field: str, value: Any) -> FilterReturn:
if value is None:
return {"path": field, "operator": "IsNull", "valueBoolean": False}
return {
"operator": "Or",
"operands": [
{"path": field, "operator": "NotEqual", _infer_value_type(value): _handle_date(value)},
{"path": field, "operator": "IsNull", "valueBoolean": True},
],
}
return weaviate.classes.query.Filter.by_property(field).is_none(False)

return weaviate.classes.query.Filter.by_property(field).not_equal(
_handle_date(value)
) | weaviate.classes.query.Filter.by_property(field).is_none(True)

def _greater_than(field: str, value: Any) -> Dict[str, Any]:

def _greater_than(field: str, value: Any) -> FilterReturn:
if value is None:
# When the value is None and '>' is used we create a filter that would return a Document
# if it has a field set and not set at the same time.
Expand All @@ -144,10 +128,10 @@ def _greater_than(field: str, value: Any) -> Dict[str, Any]:
if type(value) in [list, DataFrame]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)
return {"path": field, "operator": "GreaterThan", _infer_value_type(value): _handle_date(value)}
return weaviate.classes.query.Filter.by_property(field).greater_than(_handle_date(value))


def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]:
def _greater_than_equal(field: str, value: Any) -> FilterReturn:
if value is None:
# When the value is None and '>=' is used we create a filter that would return a Document
# if it has a field set and not set at the same time.
Expand All @@ -166,10 +150,10 @@ def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]:
if type(value) in [list, DataFrame]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)
return {"path": field, "operator": "GreaterThanEqual", _infer_value_type(value): _handle_date(value)}
return weaviate.classes.query.Filter.by_property(field).greater_or_equal(_handle_date(value))


def _less_than(field: str, value: Any) -> Dict[str, Any]:
def _less_than(field: str, value: Any) -> FilterReturn:
if value is None:
# When the value is None and '<' is used we create a filter that would return a Document
# if it has a field set and not set at the same time.
Expand All @@ -188,10 +172,10 @@ def _less_than(field: str, value: Any) -> Dict[str, Any]:
if type(value) in [list, DataFrame]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)
return {"path": field, "operator": "LessThan", _infer_value_type(value): _handle_date(value)}
return weaviate.classes.query.Filter.by_property(field).less_than(_handle_date(value))


def _less_than_equal(field: str, value: Any) -> Dict[str, Any]:
def _less_than_equal(field: str, value: Any) -> FilterReturn:
if value is None:
# When the value is None and '<=' is used we create a filter that would return a Document
# if it has a field set and not set at the same time.
Expand All @@ -210,22 +194,23 @@ def _less_than_equal(field: str, value: Any) -> Dict[str, Any]:
if type(value) in [list, DataFrame]:
msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='"
raise FilterError(msg)
return {"path": field, "operator": "LessThanEqual", _infer_value_type(value): _handle_date(value)}
return weaviate.classes.query.Filter.by_property(field).less_or_equal(_handle_date(value))


def _in(field: str, value: Any) -> Dict[str, Any]:
def _in(field: str, value: Any) -> FilterReturn:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'in' or 'not in' comparators"
raise FilterError(msg)

return {"operator": "And", "operands": [_equal(field, v) for v in value]}
return weaviate.classes.query.Filter.by_property(field).contains_any(value)


def _not_in(field: str, value: Any) -> Dict[str, Any]:
def _not_in(field: str, value: Any) -> FilterReturn:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'in' or 'not in' comparators"
raise FilterError(msg)
return {"operator": "And", "operands": [_not_equal(field, v) for v in value]}
operands = [weaviate.classes.query.Filter.by_property(field).not_equal(v) for v in value]
return Filter.all_of(operands)


COMPARISON_OPERATORS = {
Expand All @@ -240,7 +225,7 @@ def _not_in(field: str, value: Any) -> Dict[str, Any]:
}


def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
def _parse_comparison_condition(condition: Dict[str, Any]) -> FilterReturn:
field: str = condition["field"]

if field.startswith("meta."):
Expand All @@ -265,15 +250,11 @@ def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
return COMPARISON_OPERATORS[operator](field, value)


def _match_no_document(field: str) -> Dict[str, Any]:
def _match_no_document(field: str) -> FilterReturn:
"""
Returns a filter that will match no Document. This is used to keep the behavior consistent
between different Document Stores.
"""
return {
"operator": "And",
"operands": [
{"path": field, "operator": "IsNull", "valueBoolean": False},
{"path": field, "operator": "IsNull", "valueBoolean": True},
],
}

operands = [weaviate.classes.query.Filter.by_property(field).is_none(val) for val in [False, True]]
return Filter.all_of(operands)
Loading

0 comments on commit 3889dc2

Please sign in to comment.