Skip to content

Commit

Permalink
[BUGFIX] argilla-server: Query on response values without an user (#…
Browse files Browse the repository at this point in the history
…5003)

<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

This PR includes a temporary workaround to fix searching records when
filtering by response values without providing a user.

The solution uses index mapping to identify field paths to use in the
query. Those fields are combined into a single bool query with an OR
(should) operator.

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [X] Bug fix (non-breaking change which fixes an issue)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- [ ] Test A
- [ ] Test B

**Checklist**

- [ ] I followed the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)
  • Loading branch information
frascuchon authored Jun 13, 2024
1 parent e7efb09 commit 44da3d2
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 21 deletions.
4 changes: 4 additions & 0 deletions argilla-server/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,7 @@ cython_debug/

# Misc
.DS_Store


# Generated static files
src/argilla_server/static/
2 changes: 1 addition & 1 deletion argilla-server/pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

83 changes: 69 additions & 14 deletions argilla-server/src/argilla_server/search_engine/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from typing import Any, Dict, Iterable, List, Optional, Union
from uuid import UUID

from elasticsearch8 import AsyncElasticsearch
from opensearchpy import AsyncOpenSearch

from argilla_server.enums import FieldType, MetadataPropertyType, RecordSortField, ResponseStatusFilter, SimilarityOrder
from argilla_server.models import (
Dataset,
Expand Down Expand Up @@ -253,6 +256,34 @@ def is_response_status_scope(scope: FilterScope) -> bool:
return isinstance(scope, ResponseFilterScope) and scope.property == "status" and scope.question is None


def is_response_value_scope_without_user(scope: FilterScope) -> bool:
    """Tell whether ``scope`` targets the response value of a question with no user set.

    The scope must be a ``ResponseFilterScope`` whose ``user`` is ``None``, whose
    ``question`` is set, and whose ``property`` is either unset or ``"value"``.
    """
    if not isinstance(scope, ResponseFilterScope):
        return False

    if scope.user is not None or scope.question is None:
        return False

    return scope.property is None or scope.property == "value"


def _get_response_value_fields_for_question(index_mapping: dict, question: str) -> List[str]:
"""This function helper use the index mapping retrieved using client.get_mapping method to get all the defined
properties to extract the defined fields for a specific question. The number of fields will depend on the number
of users that have answered the question.
This is a workaround to fix errors when querying response value without user and it will be removed once we review
mappings for responses.
"""

mapping_def = next(iter(index_mapping.values()))
mapping_properties: Dict[str, Any] = mapping_def["mappings"]["properties"]

response_fields = []
for user_id, user_responses in mapping_properties["responses"]["properties"].items():
if question in user_responses["properties"]["values"]["properties"]:
response_fields.append(es_field_for_response_value(User(id=UUID(user_id)), question=question))
return response_fields


@dataclasses.dataclass
class BaseElasticAndOpenSearchEngine(SearchEngine):
"""
Expand All @@ -279,6 +310,8 @@ class BaseElasticAndOpenSearchEngine(SearchEngine):
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/mapping-settings-limit.html#mapping-settings-limit
default_total_fields_limit: int = 2000

client: Union[AsyncElasticsearch, AsyncOpenSearch] = dataclasses.field(init=False)

async def create_index(self, dataset: Dataset):
settings = self._configure_index_settings()
mappings = self._configure_index_mappings(dataset)
Expand Down Expand Up @@ -397,6 +430,7 @@ async def similarity_search(
if bool(value) == bool(record):
raise ValueError("Must provide either vector value or record to compute the similarity search")

index = await self._get_dataset_index(dataset)
vector_value = value
record_id = None

Expand All @@ -412,13 +446,13 @@ async def similarity_search(

query_filters = []
if filter:
index_mapping = self.client.indices.get_mapping(index=index)
# Wrapping filter in a list to use easily on each engine implementation
query_filters = [self.build_elasticsearch_filter(filter)]
query_filters = [self.build_elasticsearch_filter(filter, index_mapping)]

if query:
query_filters.append(self._build_text_query(dataset, text=query))

index = await self._get_dataset_index(dataset)
response = await self._request_similarity_search(
index=index,
vector_settings=vector_settings,
Expand All @@ -430,9 +464,9 @@ async def similarity_search(

return await self._process_search_response(response, threshold)

def build_elasticsearch_filter(self, filter: Filter) -> Dict[str, Any]:
def build_elasticsearch_filter(self, filter: Filter, index_mapping: dict) -> Dict[str, Any]:
if isinstance(filter, AndFilter):
filters = [self.build_elasticsearch_filter(f) for f in filter.filters]
filters = [self.build_elasticsearch_filter(f, index_mapping) for f in filter.filters]
return es_bool_query(should=filters, minimum_should_match=len(filters))

# This is a special case for response status filter, since it's compound by multiple filters
Expand All @@ -442,14 +476,13 @@ def build_elasticsearch_filter(self, filter: Filter) -> Dict[str, Any]:
)
return self._build_response_status_filter(status_filter)

es_field = self._scope_to_elasticsearch_field(filter.scope)
# This case is a workaround to fix errors when querying response value without user.
# Once we review mappings for responses, we should remove this.
if is_response_value_scope_without_user(filter.scope):
return self._build_response_value_filter_without_user(filter, index_mapping)

if isinstance(filter, TermsFilter):
return es_terms_query(es_field, values=filter.values)
elif isinstance(filter, RangeFilter):
return es_range_query(es_field, gte=filter.ge, lte=filter.le)
else:
raise ValueError(f"Cannot process request for filter {filter}")
es_field = self._scope_to_elasticsearch_field(filter.scope)
return self._map_filter_to_es_filter(filter, es_field)

def build_elasticsearch_sort(self, sort: List[Order]) -> str:
sort_config = []
Expand Down Expand Up @@ -490,6 +523,28 @@ def _build_response_status_filter(status_filter: UserResponseStatusFilter) -> Di

return {"bool": {"should": filters, "minimum_should_match": 1}}

def _build_response_value_filter_without_user(self, filter: Filter, index_mapping: dict) -> dict:
    """Build an OR query over the per-user response value fields of the filtered question.

    This is a workaround to fix errors when querying response values without a user: the
    filter is applied to each user's field for the question and the resulting filters are
    combined in a single bool query (``should`` with ``minimum_should_match=1``).
    It should be removed once the mappings for responses are reviewed.
    """
    per_user_fields = _get_response_value_fields_for_question(index_mapping, filter.scope.question)

    # NOTE(review): an empty dict is returned when no field exists for the question;
    # confirm that callers / the search engine accept an empty filter clause.
    if not per_user_fields:
        return {}

    per_user_filters = [self._map_filter_to_es_filter(filter, es_field) for es_field in per_user_fields]
    return es_bool_query(should=per_user_filters, minimum_should_match=1)

def _map_filter_to_es_filter(self, filter: Filter, es_field: str) -> dict:
    """Translate ``filter`` into the matching query over ``es_field``.

    A ``TermsFilter`` becomes a terms query and a ``RangeFilter`` becomes a range query.

    Raises:
        ValueError: if ``filter`` is neither a ``TermsFilter`` nor a ``RangeFilter``.
    """
    if isinstance(filter, TermsFilter):
        return es_terms_query(es_field, values=filter.values)

    if isinstance(filter, RangeFilter):
        return es_range_query(es_field, gte=filter.ge, lte=filter.le)

    raise ValueError(f"Cannot process request for filter {filter}")

def _inverse_vector(self, vector_value: List[float]) -> List[float]:
return [vector_value[i] * -1 for i in range(0, len(vector_value))]

Expand Down Expand Up @@ -583,12 +638,14 @@ async def search(
if sort_by:
sort = _unify_sort_by_with_order(sort_by, sort)
# END TODO
index = await self._get_dataset_index(dataset)

text_query = self._build_text_query(dataset, text=query)
bool_query: Dict[str, Any] = {"must": [text_query]}

if filter:
bool_query["filter"] = self.build_elasticsearch_filter(filter)
index_mapping = await self.client.indices.get_mapping(index=index)
bool_query["filter"] = self.build_elasticsearch_filter(filter, index_mapping)

es_query = {"bool": bool_query}

Expand All @@ -603,8 +660,6 @@ async def search(
}
}

index = await self._get_dataset_index(dataset)

es_sort = self.build_elasticsearch_sort(sort) if sort else None
response = await self._index_search_request(index, query=es_query, size=limit, from_=offset, sort=es_sort)

Expand Down
69 changes: 63 additions & 6 deletions argilla-server/tests/unit/search_engine/test_commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from argilla_server.search_engine import (
FloatMetadataFilter,
IntegerMetadataFilter,
ResponseFilterScope,
SortBy,
SuggestionFilterScope,
TermsFilter,
Expand Down Expand Up @@ -92,8 +93,8 @@ async def dataset_for_pagination(opensearch: OpenSearch):
@pytest_asyncio.fixture(scope="function")
@pytest.mark.asyncio
async def test_banking_sentiment_dataset_non_indexed():
text_question = await TextQuestionFactory()
rating_question = await RatingQuestionFactory()
text_question = await TextQuestionFactory(name="text")
rating_question = await RatingQuestionFactory(name="rating")

dataset = await DatasetFactory.create(
fields=[
Expand Down Expand Up @@ -594,6 +595,50 @@ async def test_search_with_response_status_filter(
assert len(result.items) == expected_items
assert result.total == expected_items

# Checks that searching with a response-value filter on the "rating" question works
# when the filter scope carries no user (user=None).
async def test_search_with_response_value_without_user(
self,
search_engine: BaseElasticAndOpenSearchEngine,
opensearch: OpenSearch,
test_banking_sentiment_dataset: Dataset,
):
# First user: draft responses configured for 2 records with rating_value=2.
user = await UserFactory.create()
await self._configure_record_responses(
opensearch,
test_banking_sentiment_dataset,
[ResponseStatusFilter.draft],
number_of_answered_records=2,
user=user,
rating_value=2,
)

# Second user: draft responses configured for 3 records with rating_value=4.
another_user = await UserFactory.create()
await self._configure_record_responses(
opensearch,
test_banking_sentiment_dataset,
[ResponseStatusFilter.draft],
number_of_answered_records=3,
user=another_user,
rating_value=4,
)

# Filtering by the first user's rating value, without naming the user.
results_for_user = await search_engine.search(
test_banking_sentiment_dataset,
filter=TermsFilter(ResponseFilterScope(question="rating", user=None), values=["2"]),
)
assert results_for_user.total == 2

# Filtering by the second user's rating value, without naming the user.
results_for_another_user = await search_engine.search(
test_banking_sentiment_dataset,
filter=TermsFilter(ResponseFilterScope(question="rating", user=None), values=["4"]),
)
assert results_for_another_user.total == 3

# Filtering by both values at once.
# NOTE(review): the expected total of 3 (not 5) presumably means both users answered
# overlapping records — confirm against _configure_record_responses.
combined_results = await search_engine.search(
test_banking_sentiment_dataset,
filter=TermsFilter(ResponseFilterScope(question="rating", user=None), values=["2", "4"]),
)
assert combined_results.total == 3

@pytest.mark.parametrize(
"statuses, expected_items",
[
Expand Down Expand Up @@ -707,7 +752,7 @@ async def test_search_with_response_status_filter_does_not_affect_the_result_sco

@pytest.mark.parametrize(
"property, filter_match_value, filter_unmatch_value",
[("value", "A", "C"), ("score", 0.5, 0), ("agent", "peter", "john"), ("type", "human", "model")],
[("value", "A", "C"), ("score", "0.5", "0"), ("agent", "peter", "john"), ("type", "human", "model")],
)
async def test_search_with_suggestion_filter(
self,
Expand Down Expand Up @@ -1377,6 +1422,7 @@ async def _configure_record_responses(
response_status: List[ResponseStatusFilter],
number_of_answered_records: int,
user: Optional[User] = None,
rating_value: Optional[int] = None,
):
index_name = es_index_name_for_dataset(dataset)

Expand All @@ -1392,7 +1438,9 @@ async def _configure_record_responses(
# Create two responses with the same status (one in each record)
for i, status in enumerate(response_status):
if status != ResponseStatusFilter.missing:
await self._update_records_responses(opensearch, index_name, selected_records, status, user)
await self._update_records_responses(
opensearch, index_name, selected_records, status, user, rating_value
)

for status in all_statuses:
if status not in response_status and status != ResponseStatusFilter.missing:
Expand All @@ -1407,11 +1455,20 @@ async def _update_records_responses(
records: List[Record],
status: ResponseStatusFilter,
user: Optional[User] = None,
rating_value: Optional[int] = None,
):
another_user = await UserFactory.create()

for record in records:
users_responses = {f"{another_user.id}.status": status.value}
users_responses = {
f"{another_user.id}.status": status.value,
f"{another_user.id}.values.rating": -1,
}
if user:
users_responses.update({f"{user.id}.status": status.value})
users_responses.update(
{
f"{user.id}.status": status.value,
f"{user.id}.values.rating": rating_value or -1,
}
)
opensearch.update(index_name, id=record.id, body={"doc": {"responses": users_responses}})

0 comments on commit 44da3d2

Please sign in to comment.