diff --git a/src/diracx/db/os/utils.py b/src/diracx/db/os/utils.py index 4c7125f41..2eb842bc0 100644 --- a/src/diracx/db/os/utils.py +++ b/src/diracx/db/os/utils.py @@ -7,12 +7,16 @@ import logging import os from abc import ABCMeta, abstractmethod +from datetime import datetime from typing import Any, AsyncIterator, Self from opensearchpy import AsyncOpenSearch +from diracx.core.exceptions import InvalidQueryError from diracx.core.extensions import select_from_extension +OS_DATE_FORMAT = "%Y-%m-%dT%H:%M:%S.%f%z" + logger = logging.getLogger(__name__) @@ -128,10 +132,97 @@ async def upsert(self, doc_id, document) -> None: async def search( self, parameters, search, sorts, *, per_page: int = 100, page: int | None = None ) -> list[dict[str, Any]]: - # TODO: Implement properly + """Search the database for matching results. + + See the DiracX search API documentation for details. + """ + body = {} + if parameters: + body["_source"] = parameters + if search: + body["query"] = apply_search_filters(self.fields, search) + body["sort"] = [] + for sort in sorts: + field_name = sort["parameter"] + field_type = self.fields.get(field_name, {}).get("type") + require_type("sort", field_name, field_type, {"keyword", "long", "date"}) + body["sort"].append({field_name: {"order": sort["direction"]}}) + + params = {} + if page is not None: + params["from"] = (page - 1) * per_page + params["size"] = per_page + response = await self.client.search( - body={"query": {"bool": {"must": [{"term": {"JobID": 798811207}}]}}}, - params=dict(size=per_page), - index=f"{self.index_prefix}*", + body=body, params=params, index=f"{self.index_prefix}*" + ) + hits = [hit["_source"] for hit in response["hits"]["hits"]] + + # Dates are returned as strings, convert them to Python datetimes + for hit in hits: + for field_name in hit: + if field_name not in self.fields: + continue + if self.fields[field_name]["type"] == "date": + hit[field_name] = datetime.strptime(hit[field_name], OS_DATE_FORMAT) + + return hits + + +def require_type(operator, field_name, field_type, allowed_types): + if field_type not in allowed_types: + raise InvalidQueryError( + f"Cannot apply {operator} to {field_name} ({field_type=}, {allowed_types=})" ) - return [hit["_source"] for hit in response["hits"]["hits"]] + + +def apply_search_filters(db_fields, search): + """Build an OpenSearch query from the given DiracX search parameters. + + If the searched parameters cannot be efficiently translated to a query for + OpenSearch an InvalidQueryError exception is raised. + """ + result = { + "must": [], + "must_not": [], + } + for query in search: + field_name = query["parameter"] + field_type = db_fields.get(field_name, {}).get("type") + if field_type is None: + raise InvalidQueryError( + f"Field {field_name} is not included in the index mapping" + ) + + match operator := query["operator"]: + case "eq": + require_type( + operator, field_name, field_type, {"keyword", "long", "date"} + ) + result["must"].append({"term": {field_name: {"value": query["value"]}}}) + case "neq": + require_type( + operator, field_name, field_type, {"keyword", "long", "date"} + ) + result["must_not"].append( + {"term": {field_name: {"value": query["value"]}}} + ) + case "gt": + require_type(operator, field_name, field_type, {"long", "date"}) + result["must"].append({"range": {field_name: {"gt": query["value"]}}}) + case "lt": + require_type(operator, field_name, field_type, {"long", "date"}) + result["must"].append({"range": {field_name: {"lt": query["value"]}}}) + case "in": + require_type( + operator, field_name, field_type, {"keyword", "long", "date"} + ) + result["must"].append({"terms": {field_name: query["values"]}}) + # TODO: Implement like and ilike + # If the pattern is a simple "col like 'abc%'", we can use a prefix query + # Else we need to use a wildcard query where we replace % with * and _ with ? + # This should also need to handle escaping of %/_/*/? + case _: + raise InvalidQueryError(f"Unknown filter {query=}") + + return {"bool": result} diff --git a/src/diracx/db/sql/utils.py b/src/diracx/db/sql/utils.py index 619ebe07e..9a2dc530e 100644 --- a/src/diracx/db/sql/utils.py +++ b/src/diracx/db/sql/utils.py @@ -177,7 +177,9 @@ def apply_search_filters(table, stmt, search): elif query["operator"] == "in": expr = column.in_(query["values"]) elif query["operator"] in "like": - expr = column.like(query["values"]) + expr = column.like(query["value"]) + elif query["operator"] in "ilike": + expr = column.ilike(query["value"]) else: raise InvalidQueryError(f"Unknown filter {query=}") stmt = stmt.where(expr) diff --git a/tests/db/opensearch/conftest.py b/tests/db/opensearch/conftest.py index 417c2d9f5..3f8f89050 100644 --- a/tests/db/opensearch/conftest.py +++ b/tests/db/opensearch/conftest.py @@ -27,7 +27,8 @@ class DummyOSDB(BaseOSDB): fields = { "DateField": {"type": "date"}, - "IntegerField": {"type": "long"}, + "IntField": {"type": "long"}, + "KeywordField0": {"type": "keyword"}, "KeywordField1": {"type": "keyword"}, "KeywordField2": {"type": "keyword"}, "TextField": {"type": "text"}, diff --git a/tests/db/opensearch/test_index_template.py b/tests/db/opensearch/test_index_template.py index c01a516cc..5760ee451 100644 --- a/tests/db/opensearch/test_index_template.py +++ b/tests/db/opensearch/test_index_template.py @@ -9,7 +9,7 @@ DUMMY_DOCUMENT = { "DateField": datetime.now(tz=timezone.utc), - "IntegerField": 1234, + "IntField": 1234, "KeywordField1": "keyword1", "KeywordField2": "keyword two", "TextField": "text value", diff --git a/tests/db/opensearch/test_search.py b/tests/db/opensearch/test_search.py new file mode 100644 index 000000000..da6a4ba19 --- /dev/null +++ b/tests/db/opensearch/test_search.py @@ -0,0 +1,380 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +import pytest + +from diracx.core.exceptions import InvalidQueryError + +from .conftest import DummyOSDB + +DOC1 = { + "DateField": datetime.now(tz=timezone.utc), + "IntField": 1234, + "KeywordField0": "a", + "KeywordField1": "keyword1", + "KeywordField2": "keyword one", + "TextField": "text value", + "UnknownField": "unknown field 1", +} +DOC2 = { + "DateField": datetime.now(tz=timezone.utc) - timedelta(days=1, minutes=34), + "IntField": 679, + "KeywordField0": "c", + "KeywordField1": "keyword1", + "KeywordField2": "keyword two", + "TextField": "another text value", + "UnknownField": "unknown field 2", +} +DOC3 = { + "DateField": datetime.now(tz=timezone.utc) - timedelta(days=1), + "IntField": 42, + "KeywordField0": "b", + "KeywordField1": "keyword2", + "KeywordField2": "keyword two", + "TextField": "yet another text value", +} + + +@pytest.fixture() +async def prefilled_db(dummy_opensearch_db: DummyOSDB): + """Fill the database with dummy records for testing.""" + await dummy_opensearch_db.upsert(798811211, DOC1) + await dummy_opensearch_db.upsert(998811211, DOC2) + await dummy_opensearch_db.upsert(798811212, DOC3) + + # Force a refresh to make sure the documents are available + await dummy_opensearch_db.client.indices.refresh( + index=f"{dummy_opensearch_db.index_prefix}*" + ) + + yield dummy_opensearch_db + + +async def test_specified_parameters(prefilled_db: DummyOSDB): + results = await prefilled_db.search(None, [], []) + assert len(results) == 3 + assert DOC1 in results and DOC2 in results and DOC3 in results + + results = await prefilled_db.search([], [], []) + assert len(results) == 3 + assert DOC1 in results and DOC2 in results and DOC3 in results + + results = await prefilled_db.search(["IntField"], [], []) + expected_results = [] + for doc in [DOC1, DOC2, DOC3]: + expected_doc = {key: doc[key] for key in {"IntField"}} + # Ensure the document is not already in the list + # If it is the all() check below no longer makes sense + assert expected_doc not in expected_results + expected_results.append(expected_doc) + assert len(results) == len(expected_results) + assert all(result in expected_results for result in results) + + results = await prefilled_db.search(["IntField", "UnknownField"], [], []) + expected_results = [ + {"IntField": DOC1["IntField"], "UnknownField": DOC1["UnknownField"]}, + {"IntField": DOC2["IntField"], "UnknownField": DOC2["UnknownField"]}, + {"IntField": DOC3["IntField"]}, + ] + assert len(results) == len(expected_results) + assert all(result in expected_results for result in results) + + +async def test_pagination_asc(prefilled_db: DummyOSDB): + sort = [{"parameter": "IntField", "direction": "asc"}] + + results = await prefilled_db.search(None, [], sort) + assert results == [DOC3, DOC2, DOC1] + + # Pagination has no effect if a specific page isn't requested + results = await prefilled_db.search(None, [], sort, per_page=2) + assert results == [DOC3, DOC2, DOC1] + + results = await prefilled_db.search(None, [], sort, per_page=2, page=1) + assert results == [DOC3, DOC2] + + results = await prefilled_db.search(None, [], sort, per_page=2, page=2) + assert results == [DOC1] + + results = await prefilled_db.search(None, [], sort, per_page=2, page=3) + assert results == [] + + results = await prefilled_db.search(None, [], sort, per_page=1, page=1) + assert results == [DOC3] + + results = await prefilled_db.search(None, [], sort, per_page=1, page=2) + assert results == [DOC2] + + results = await prefilled_db.search(None, [], sort, per_page=1, page=3) + assert results == [DOC1] + + results = await prefilled_db.search(None, [], sort, per_page=1, page=4) + assert results == [] + + +async def test_pagination_desc(prefilled_db: DummyOSDB): + sort = [{"parameter": "IntField", "direction": "desc"}] + + results = await prefilled_db.search(None, [], sort, per_page=2, page=1) + assert results == [DOC1, DOC2] + + results = await prefilled_db.search(None, [], sort, per_page=2, page=2) + assert results == [DOC3] + + +async def test_eq_filter_long(prefilled_db: DummyOSDB): + part = {"parameter": "IntField", "operator": "eq"} + + # Search for an ID which doesn't exist + results = await prefilled_db.search(None, [part | {"value": "78"}], []) + assert results == [] + + # Check the DB contains what we expect when not filtering + results = await prefilled_db.search(None, [], []) + assert len(results) == 3 + assert DOC1 in results + assert DOC2 in results + assert DOC3 in results + + # Search separately for the two documents which do exist + results = await prefilled_db.search(None, [part | {"value": "1234"}], []) + assert results == [DOC1] + results = await prefilled_db.search(None, [part | {"value": "679"}], []) + assert results == [DOC2] + results = await prefilled_db.search(None, [part | {"value": "42"}], []) + assert results == [DOC3] + + +async def test_operators_long(prefilled_db: DummyOSDB): + part = {"parameter": "IntField"} + + query = part | {"operator": "neq", "value": "1234"} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == {DOC2["IntField"], DOC3["IntField"]} + + query = part | {"operator": "in", "values": ["1234", "42"]} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == {DOC1["IntField"], DOC3["IntField"]} + + query = part | {"operator": "lt", "value": "1234"} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == {DOC2["IntField"], DOC3["IntField"]} + + query = part | {"operator": "lt", "value": "679"} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == {DOC3["IntField"]} + + query = part | {"operator": "gt", "value": "1234"} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == set() + + query = part | {"operator": "lt", "value": "42"} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == set() + + +async def test_operators_date(prefilled_db: DummyOSDB): + part = {"parameter": "DateField"} + + query = part | {"operator": "eq", "value": DOC3["DateField"]} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == {DOC3["IntField"]} + + query = part | {"operator": "neq", "value": DOC2["DateField"]} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == {DOC1["IntField"], DOC3["IntField"]} + + doc1_time = DOC1["DateField"].strftime("%Y-%m-%dT%H:%M") + doc2_time = DOC2["DateField"].strftime("%Y-%m-%dT%H:%M") + doc3_time = DOC3["DateField"].strftime("%Y-%m-%dT%H:%M") + + query = part | {"operator": "in", "values": [doc1_time, doc2_time]} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == {DOC1["IntField"], DOC2["IntField"]} + + query = part | {"operator": "lt", "value": doc1_time} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == {DOC2["IntField"], DOC3["IntField"]} + + query = part | {"operator": "lt", "value": doc3_time} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == {DOC2["IntField"]} + + query = part | {"operator": "lt", "value": doc2_time} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == set() + + query = part | {"operator": "gt", "value": doc1_time} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == set() + + query = part | {"operator": "gt", "value": doc3_time} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == {DOC1["IntField"]} + + query = part | {"operator": "gt", "value": doc2_time} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == {DOC1["IntField"], DOC3["IntField"]} + + +@pytest.mark.parametrize( + "date_format", + [ + "%Y-%m-%d", + "%Y-%m-%dT%H", + "%Y-%m-%dT%H:%M", + "%Y-%m-%dT%H:%M:%S", + "%Y-%m-%dT%H:%M:%S.%fZ", + ], +) +async def test_operators_date_partial_doc1(prefilled_db: DummyOSDB, date_format: str): + """Search by datetime without specifying an exact match + + The parameterized date_format argument should match DOC1 but not DOC2 or DOC3. + """ + formatted_date = DOC1["DateField"].strftime(date_format) + + query = {"parameter": "DateField", "operator": "eq", "value": formatted_date} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == {DOC1["IntField"]} + + query = {"parameter": "DateField", "operator": "neq", "value": formatted_date} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == {DOC2["IntField"], DOC3["IntField"]} + + +async def test_operators_keyword(prefilled_db: DummyOSDB): + part = {"parameter": "KeywordField1"} + + query = part | {"operator": "eq", "value": DOC1["KeywordField1"]} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == {DOC1["IntField"], DOC2["IntField"]} + + query = part | {"operator": "neq", "value": DOC1["KeywordField1"]} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == {DOC3["IntField"]} + + part = {"parameter": "KeywordField0"} + + query = part | { + "operator": "in", + "values": [DOC1["KeywordField0"], DOC3["KeywordField0"]], + } + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == {DOC1["IntField"], DOC3["IntField"]} + + query = part | {"operator": "in", "values": ["missing"]} + results = await prefilled_db.search(["IntField"], [query], []) + assert {x["IntField"] for x in results} == set() + + with pytest.raises(InvalidQueryError): + query = part | {"operator": "lt", "value": "a"} + await prefilled_db.search(["IntField"], [query], []) + + with pytest.raises(InvalidQueryError): + query = part | {"operator": "gt", "value": "a"} + await prefilled_db.search(["IntField"], [query], []) + + +async def test_unknown_operator(prefilled_db: DummyOSDB): + with pytest.raises(InvalidQueryError): + await prefilled_db.search( + None, [{"parameter": "IntField", "operator": "unknown"}], [] + ) + + +async def test_unindexed_field(prefilled_db: DummyOSDB): + with pytest.raises(InvalidQueryError): + await prefilled_db.search( + None, [{"parameter": "UnknownField", "eq": "eq", "value": "foobar"}], [] + ) + + +async def test_sort_long(prefilled_db: DummyOSDB): + results = await prefilled_db.search( + None, [], [{"parameter": "IntField", "direction": "asc"}] + ) + assert results == [DOC3, DOC2, DOC1] + results = await prefilled_db.search( + None, [], [{"parameter": "IntField", "direction": "desc"}] + ) + assert results == [DOC1, DOC2, DOC3] + + +async def test_sort_date(prefilled_db: DummyOSDB): + results = await prefilled_db.search( + None, [], [{"parameter": "DateField", "direction": "asc"}] + ) + assert results == [DOC2, DOC3, DOC1] + results = await prefilled_db.search( + None, [], [{"parameter": "DateField", "direction": "desc"}] + ) + assert results == [DOC1, DOC3, DOC2] + + +async def test_sort_keyword(prefilled_db: DummyOSDB): + results = await prefilled_db.search( + None, [], [{"parameter": "KeywordField0", "direction": "asc"}] + ) + assert results == [DOC1, DOC3, DOC2] + results = await prefilled_db.search( + None, [], [{"parameter": "KeywordField0", "direction": "desc"}] + ) + assert results == [DOC2, DOC3, DOC1] + + +async def test_sort_text(prefilled_db: DummyOSDB): + with pytest.raises(InvalidQueryError): + await prefilled_db.search( + None, [], [{"parameter": "TextField", "direction": "asc"}] + ) + + +async def test_sort_unknown(prefilled_db: DummyOSDB): + with pytest.raises(InvalidQueryError): + await prefilled_db.search( + None, [], [{"parameter": "UnknownField", "direction": "asc"}] + ) + + +async def test_sort_multiple(prefilled_db: DummyOSDB): + results = await prefilled_db.search( + None, + [], + [ + {"parameter": "KeywordField1", "direction": "asc"}, + {"parameter": "IntField", "direction": "asc"}, + ], + ) + assert results == [DOC2, DOC1, DOC3] + + results = await prefilled_db.search( + None, + [], + [ + {"parameter": "KeywordField1", "direction": "asc"}, + {"parameter": "IntField", "direction": "desc"}, + ], + ) + assert results == [DOC1, DOC2, DOC3] + + results = await prefilled_db.search( + None, + [], + [ + {"parameter": "KeywordField1", "direction": "desc"}, + {"parameter": "IntField", "direction": "asc"}, + ], + ) + assert results == [DOC3, DOC2, DOC1] + + results = await prefilled_db.search( + None, + [], + [ + {"parameter": "IntField", "direction": "asc"}, + {"parameter": "KeywordField1", "direction": "asc"}, + ], + ) + assert results == [DOC3, DOC2, DOC1]