From 44da3d2ed2e96adec8121c64923f0f0db3155b54 Mon Sep 17 00:00:00 2001
From: Paco Aranda
Date: Thu, 13 Jun 2024 09:58:30 +0200
Subject: [PATCH] [BUGFIX] `argilla-server`: Query on response values without an user (#5003)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# Description

This PR includes a temporary workaround to fix searching records when filtering by response values without providing a user. The solution uses the index mapping to identify the response field paths to use in the query, and those fields are combined into a single bool query with an OR (`should`) operator.
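
For illustration, the filter built for a `TermsFilter` on the `rating` question looks roughly like the sketch below when the index mapping contains response fields for two users. The user ids are made-up placeholders, and the dict mirrors what `es_terms_query`/`es_bool_query` return:

```python
# Approximate filter produced for
# TermsFilter(ResponseFilterScope(question="rating", user=None), values=["2", "4"]),
# assuming the index mapping lists response fields for two (example) user ids.
example_filter = {
    "bool": {
        "should": [
            {"terms": {"responses.00000000-0000-0000-0000-000000000001.values.rating": ["2", "4"]}},
            {"terms": {"responses.00000000-0000-0000-0000-000000000002.values.rating": ["2", "4"]}},
        ],
        "minimum_should_match": 1,
    }
}
```

If the mapping contains no response field for the question, the builder falls back to an empty dict, so the query is simply not restricted by response value.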

**Type of change**

(Please delete options that are not relevant. Remember to title the PR according to the type of change)

- [X] Bug fix (non-breaking change which fixes an issue)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And ideally, reference `tests`)

- [ ] Test A
- [ ] Test B

**Checklist**

- [ ] I followed the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above)
- [ ] I have added relevant notes to the `CHANGELOG.md` file (See https://keepachangelog.com/)

---
 argilla-server/.gitignore                      |  4 +
 argilla-server/pdm.lock                        |  2 +-
 .../argilla_server/search_engine/commons.py    | 83 +++++++++++++++----
 .../tests/unit/search_engine/test_commons.py   | 69 +++++++++++++--
 4 files changed, 137 insertions(+), 21 deletions(-)

diff --git a/argilla-server/.gitignore b/argilla-server/.gitignore
index 965d4f8682..9485964133 100644
--- a/argilla-server/.gitignore
+++ b/argilla-server/.gitignore
@@ -164,3 +164,7 @@ cython_debug/
 
 # Misc
 .DS_Store
+
+
+# Generated static files
+src/argilla_server/static/
\ No newline at end of file
diff --git a/argilla-server/pdm.lock b/argilla-server/pdm.lock
index 03d7fc5bf5..d7236facac 100644
--- a/argilla-server/pdm.lock
+++ b/argilla-server/pdm.lock
@@ -5,7 +5,7 @@ groups = ["default", "test", "postgresql"]
 strategy = ["cross_platform", "inherit_metadata"]
 lock_version = "4.4.1"
 
-content_hash = "sha256:0cc19bee819fb8659dba34b024dad5b736971274d6b62077f7edd155afe2c1c9"
+content_hash = "sha256:495a399675d75cf0686a39118a520e835a978ab18916ab467c349f4703629122"
 
 [[package]]
 name = "aiofiles"
diff --git a/argilla-server/src/argilla_server/search_engine/commons.py b/argilla-server/src/argilla_server/search_engine/commons.py
index 928162ad80..22af1f12fa 100644
--- a/argilla-server/src/argilla_server/search_engine/commons.py
+++ b/argilla-server/src/argilla_server/search_engine/commons.py
@@ -17,6 +17,9 @@
 from typing import Any, Dict, Iterable, List, Optional, Union
 from uuid import UUID
 
+from elasticsearch8 import AsyncElasticsearch
+from opensearchpy import AsyncOpenSearch
+
 from argilla_server.enums import FieldType, MetadataPropertyType, RecordSortField, ResponseStatusFilter, SimilarityOrder
 from argilla_server.models import (
     Dataset,
@@ -253,6 +256,34 @@ def is_response_status_scope(scope: FilterScope) -> bool:
     return isinstance(scope, ResponseFilterScope) and scope.property == "status" and scope.question is None
 
 
+def is_response_value_scope_without_user(scope: FilterScope) -> bool:
+    return (
+        isinstance(scope, ResponseFilterScope)
+        and scope.user is None
+        and scope.question is not None
+        and (scope.property is None or scope.property == "value")
+    )
+
+
+def _get_response_value_fields_for_question(index_mapping: dict, question: str) -> List[str]:
+    """This helper uses the index mapping retrieved with the client.get_mapping method to read the defined
+    properties and extract the fields defined for a specific question. The number of fields will depend on the
+    number of users that have answered the question.
+
+    This is a workaround to fix errors when querying response values without a user, and it will be removed once
+    we review the mappings for responses.
+    """
+
+    mapping_def = next(iter(index_mapping.values()))
+    mapping_properties: Dict[str, Any] = mapping_def["mappings"]["properties"]
+
+    response_fields = []
+    for user_id, user_responses in mapping_properties["responses"]["properties"].items():
+        if question in user_responses["properties"]["values"]["properties"]:
+            response_fields.append(es_field_for_response_value(User(id=UUID(user_id)), question=question))
+    return response_fields
+
+
 @dataclasses.dataclass
 class BaseElasticAndOpenSearchEngine(SearchEngine):
     """
@@ -279,6 +310,8 @@ class BaseElasticAndOpenSearchEngine(SearchEngine):
     # See https://www.elastic.co/guide/en/elasticsearch/reference/current/mapping-settings-limit.html#mapping-settings-limit
     default_total_fields_limit: int = 2000
 
+    client: Union[AsyncElasticsearch, AsyncOpenSearch] = dataclasses.field(init=False)
+
     async def create_index(self, dataset: Dataset):
         settings = self._configure_index_settings()
         mappings = self._configure_index_mappings(dataset)
@@ -397,6 +430,7 @@ async def similarity_search(
         if bool(value) == bool(record):
             raise ValueError("Must provide either vector value or record to compute the similarity search")
 
+        index = await self._get_dataset_index(dataset)
         vector_value = value
         record_id = None
 
@@ -412,13 +446,13 @@ async def similarity_search(
         query_filters = []
         if filter:
+            index_mapping = await self.client.indices.get_mapping(index=index)
             # Wrapping filter in a list to use easily on each engine implementation
-            query_filters = [self.build_elasticsearch_filter(filter)]
+            query_filters = [self.build_elasticsearch_filter(filter, index_mapping)]
         if query:
             query_filters.append(self._build_text_query(dataset, text=query))
 
-        index = await self._get_dataset_index(dataset)
         response = await self._request_similarity_search(
             index=index,
             vector_settings=vector_settings,
@@ -430,9 +464,9 @@ async def similarity_search(
 
         return await self._process_search_response(response, threshold)
 
-    def build_elasticsearch_filter(self, filter: Filter) -> Dict[str, Any]:
+    def build_elasticsearch_filter(self, filter: Filter, index_mapping: dict) -> Dict[str, Any]:
         if isinstance(filter, AndFilter):
-            filters = [self.build_elasticsearch_filter(f) for f in filter.filters]
+            filters = [self.build_elasticsearch_filter(f, index_mapping) for f in filter.filters]
             return es_bool_query(should=filters, minimum_should_match=len(filters))
 
         # This is a special case for response status filter, since it's compound by multiple filters
@@ -442,14 +476,13 @@ def build_elasticsearch_filter(self, filter: Filter) -> Dict[str, Any]:
             )
             return self._build_response_status_filter(status_filter)
 
-        es_field = self._scope_to_elasticsearch_field(filter.scope)
+        # This case is a workaround to fix errors when querying response values without a user.
+        # Once we review mappings for responses, we should remove this.
+        if is_response_value_scope_without_user(filter.scope):
+            return self._build_response_value_filter_without_user(filter, index_mapping)
 
-        if isinstance(filter, TermsFilter):
-            return es_terms_query(es_field, values=filter.values)
-        elif isinstance(filter, RangeFilter):
-            return es_range_query(es_field, gte=filter.ge, lte=filter.le)
-        else:
-            raise ValueError(f"Cannot process request for filter {filter}")
+        es_field = self._scope_to_elasticsearch_field(filter.scope)
+        return self._map_filter_to_es_filter(filter, es_field)
 
     def build_elasticsearch_sort(self, sort: List[Order]) -> str:
         sort_config = []
@@ -490,6 +523,28 @@ def _build_response_status_filter(status_filter: UserResponseStatusFilter) -> Di
         return {"bool": {"should": filters, "minimum_should_match": 1}}
 
+    def _build_response_value_filter_without_user(self, filter: Filter, index_mapping: dict) -> dict:
+        """This is a workaround to fix errors when querying response values without a user, and it consists of
+        combining all the filters for each user in a bool query using an OR operator.
+
+        This should be removed once we review mappings for responses.
+        """
+        question_response_fields = _get_response_value_fields_for_question(index_mapping, filter.scope.question)
+
+        all_user_filters = [self._map_filter_to_es_filter(filter, field) for field in question_response_fields]
+
+        if all_user_filters:
+            return es_bool_query(should=all_user_filters, minimum_should_match=1)
+        return {}
+
+    def _map_filter_to_es_filter(self, filter: Filter, es_field: str) -> dict:
+        if isinstance(filter, TermsFilter):
+            return es_terms_query(es_field, values=filter.values)
+        elif isinstance(filter, RangeFilter):
+            return es_range_query(es_field, gte=filter.ge, lte=filter.le)
+        else:
+            raise ValueError(f"Cannot process request for filter {filter}")
+
     def _inverse_vector(self, vector_value: List[float]) -> List[float]:
         return [vector_value[i] * -1 for i in range(0, len(vector_value))]
 
@@ -583,12 +638,14 @@ async def search(
         if sort_by:
             sort = _unify_sort_by_with_order(sort_by, sort)
         # END TODO
 
+        index = await self._get_dataset_index(dataset)
         text_query = self._build_text_query(dataset, text=query)
         bool_query: Dict[str, Any] = {"must": [text_query]}
 
         if filter:
-            bool_query["filter"] = self.build_elasticsearch_filter(filter)
+            index_mapping = await self.client.indices.get_mapping(index=index)
+            bool_query["filter"] = self.build_elasticsearch_filter(filter, index_mapping)
 
         es_query = {"bool": bool_query}
 
@@ -603,8 +660,6 @@ async def search(
             }
         }
 
-        index = await self._get_dataset_index(dataset)
-
         es_sort = self.build_elasticsearch_sort(sort) if sort else None
         response = await self._index_search_request(index, query=es_query, size=limit, from_=offset, sort=es_sort)
 
diff --git a/argilla-server/tests/unit/search_engine/test_commons.py b/argilla-server/tests/unit/search_engine/test_commons.py
index 4d4757c161..0c0f673ff7 100644
--- a/argilla-server/tests/unit/search_engine/test_commons.py
+++ b/argilla-server/tests/unit/search_engine/test_commons.py
@@ -21,6 +21,7 @@
 from argilla_server.search_engine import (
     FloatMetadataFilter,
     IntegerMetadataFilter,
+    ResponseFilterScope,
     SortBy,
     SuggestionFilterScope,
     TermsFilter,
@@ -92,8 +93,8 @@ async def dataset_for_pagination(opensearch: OpenSearch):
 @pytest_asyncio.fixture(scope="function")
 @pytest.mark.asyncio
 async def test_banking_sentiment_dataset_non_indexed():
-    text_question = await TextQuestionFactory()
-    rating_question = await RatingQuestionFactory()
+    text_question = await TextQuestionFactory(name="text")
+    rating_question = await RatingQuestionFactory(name="rating")
 
     dataset = await DatasetFactory.create(
         fields=[
@@ -594,6 +595,50 @@ async def test_search_with_response_status_filter(
         assert len(result.items) == expected_items
         assert result.total == expected_items
 
+    async def test_search_with_response_value_without_user(
+        self,
+        search_engine: BaseElasticAndOpenSearchEngine,
+        opensearch: OpenSearch,
+        test_banking_sentiment_dataset: Dataset,
+    ):
+        user = await UserFactory.create()
+        await self._configure_record_responses(
+            opensearch,
+            test_banking_sentiment_dataset,
+            [ResponseStatusFilter.draft],
+            number_of_answered_records=2,
+            user=user,
+            rating_value=2,
+        )
+
+        another_user = await UserFactory.create()
+        await self._configure_record_responses(
+            opensearch,
+            test_banking_sentiment_dataset,
+            [ResponseStatusFilter.draft],
+            number_of_answered_records=3,
+            user=another_user,
+            rating_value=4,
+        )
+
+        results_for_user = await search_engine.search(
+            test_banking_sentiment_dataset,
+            filter=TermsFilter(ResponseFilterScope(question="rating", user=None), values=["2"]),
+        )
+        assert results_for_user.total == 2
+
+        results_for_another_user = await search_engine.search(
+            test_banking_sentiment_dataset,
+            filter=TermsFilter(ResponseFilterScope(question="rating", user=None), values=["4"]),
+        )
+        assert results_for_another_user.total == 3
+
+        combined_results = await search_engine.search(
+            test_banking_sentiment_dataset,
+            filter=TermsFilter(ResponseFilterScope(question="rating", user=None), values=["2", "4"]),
+        )
+        assert combined_results.total == 3
+
     @pytest.mark.parametrize(
         "statuses, expected_items",
         [
@@ -707,7 +752,7 @@ async def test_search_with_response_status_filter_does_not_affect_the_result_sco
     @pytest.mark.parametrize(
         "property, filter_match_value, filter_unmatch_value",
-        [("value", "A", "C"), ("score", 0.5, 0), ("agent", "peter", "john"), ("type", "human", "model")],
+        [("value", "A", "C"), ("score", "0.5", "0"), ("agent", "peter", "john"), ("type", "human", "model")],
     )
     async def test_search_with_suggestion_filter(
         self,
@@ -1377,6 +1422,7 @@ async def _configure_record_responses(
         response_status: List[ResponseStatusFilter],
         number_of_answered_records: int,
         user: Optional[User] = None,
+        rating_value: Optional[int] = None,
     ):
         index_name = es_index_name_for_dataset(dataset)
 
@@ -1392,7 +1438,9 @@ async def _configure_record_responses(
         # Create two responses with the same status (one in each record)
         for i, status in enumerate(response_status):
             if status != ResponseStatusFilter.missing:
-                await self._update_records_responses(opensearch, index_name, selected_records, status, user)
+                await self._update_records_responses(
+                    opensearch, index_name, selected_records, status, user, rating_value
+                )
 
         for status in all_statuses:
             if status not in response_status and status != ResponseStatusFilter.missing:
@@ -1407,11 +1455,20 @@ async def _update_records_responses(
         records: List[Record],
         status: ResponseStatusFilter,
         user: Optional[User] = None,
+        rating_value: Optional[int] = None,
     ):
         another_user = await UserFactory.create()
 
         for record in records:
-            users_responses = {f"{another_user.id}.status": status.value}
+            users_responses = {
+                f"{another_user.id}.status": status.value,
+                f"{another_user.id}.values.rating": -1,
+            }
             if user:
-                users_responses.update({f"{user.id}.status": status.value})
+                users_responses.update(
+                    {
+                        f"{user.id}.status": status.value,
+                        f"{user.id}.values.rating": rating_value or -1,
+                    }
+                )
             opensearch.update(index_name, id=record.id, body={"doc": {"responses": users_responses}})
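
Illustration only (not part of the patch): a minimal, self-contained sketch of the mapping walk that `_get_response_value_fields_for_question` performs. The `get_mapping` payload, index name, and user ids below are invented for the example, and the real helper builds each field path through `es_field_for_response_value`, which the sketch approximates with an f-string.

```python
from typing import Any, Dict, List
from uuid import UUID

# Invented example of what client.indices.get_mapping(index=...) could return for a
# dataset index where two users have responses, but only one answered the "rating" question.
EXAMPLE_GET_MAPPING_RESPONSE: Dict[str, Any] = {
    "example-dataset-index": {
        "mappings": {
            "properties": {
                "responses": {
                    "properties": {
                        "11111111-1111-1111-1111-111111111111": {
                            "properties": {
                                "status": {"type": "keyword"},
                                "values": {"properties": {"rating": {"type": "keyword"}}},
                            }
                        },
                        "22222222-2222-2222-2222-222222222222": {
                            "properties": {
                                "status": {"type": "keyword"},
                                "values": {"properties": {"text": {"type": "text"}}},
                            }
                        },
                    }
                }
            }
        }
    }
}


def response_value_fields_for_question(index_mapping: Dict[str, Any], question: str) -> List[str]:
    """Standalone sketch of the helper: one field path per user that has the question mapped."""
    mapping_def = next(iter(index_mapping.values()))
    mapping_properties = mapping_def["mappings"]["properties"]

    response_fields = []
    for user_id, user_responses in mapping_properties["responses"]["properties"].items():
        if question in user_responses["properties"]["values"]["properties"]:
            # The patch delegates to es_field_for_response_value(User(id=UUID(user_id)), question=question);
            # the f-string below mirrors the "responses.<user_id>.values.<question>" layout used by the tests.
            response_fields.append(f"responses.{UUID(user_id)}.values.{question}")
    return response_fields


if __name__ == "__main__":
    # Only the first user has a "rating" sub-field, so a single path is returned.
    print(response_value_fields_for_question(EXAMPLE_GET_MAPPING_RESPONSE, "rating"))
    # ['responses.11111111-1111-1111-1111-111111111111.values.rating']
```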