Skip to content

Commit

Permalink
[BUGFIX] Filter record metadata value based on metadata property poli…
Browse files Browse the repository at this point in the history
…cies (#4906)

<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

This PR fixes metadata visibility errors when fetching records for a
user.


**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- [ ] Test A
- [ ] Test B

**Checklist**

- [ ] I followed the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)
  • Loading branch information
frascuchon authored May 29, 2024
1 parent 23c3d54 commit 65e2627
Show file tree
Hide file tree
Showing 4 changed files with 274 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import re
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query, Security, status
Expand All @@ -29,7 +29,7 @@
from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE
from argilla_server.models import Dataset as DatasetModel
from argilla_server.models import Record, User
from argilla_server.policies import DatasetPolicyV1, authorize
from argilla_server.policies import DatasetPolicyV1, RecordPolicyV1, authorize, is_authorized
from argilla_server.schemas.v1.datasets import Dataset
from argilla_server.schemas.v1.records import (
Filters,
Expand Down Expand Up @@ -83,7 +83,6 @@
_VALID_SORT_VALUES = tuple(sort.value for sort in SortOrder)
_METADATA_PROPERTY_SORT_BY_REGEX = re.compile(r"^metadata\.(?P<name>(?=.*[a-z0-9])[a-z0-9_-]+)$")


SortByQueryParamParsed = Annotated[
Dict[str, str],
Depends(
Expand Down Expand Up @@ -410,7 +409,7 @@ async def list_current_user_dataset_records(
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE),
current_user: User = Security(auth.get_current_user),
):
dataset = await _get_dataset_or_raise(db, dataset_id)
dataset = await _get_dataset_or_raise(db, dataset_id, with_metadata_properties=True)

await authorize(current_user, DatasetPolicyV1.get(dataset))

Expand All @@ -427,6 +426,10 @@ async def list_current_user_dataset_records(
sort_by_query_param=sort_by_query_param,
)

for record in records:
record.dataset = dataset
record.metadata_ = await _filter_record_metadata_for_user(record, current_user)

return Records(items=records, total=total)


Expand Down Expand Up @@ -570,8 +573,7 @@ async def search_current_user_dataset_records(
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE),
current_user: User = Security(auth.get_current_user),
):
dataset = await _get_dataset_or_raise(db, dataset_id, with_fields=True)

dataset = await _get_dataset_or_raise(db, dataset_id, with_fields=True, with_metadata_properties=True)
await authorize(current_user, DatasetPolicyV1.search_records(dataset))

await _validate_search_records_query(db, body, dataset_id)
Expand All @@ -589,7 +591,7 @@ async def search_current_user_dataset_records(
sort_by_query_param=sort_by_query_param,
)

record_id_score_map = {
record_id_score_map: Dict[UUID, Dict[str, Union[float, SearchRecord, None]]] = {
response.record_id: {"query_score": response.score, "search_record": None}
for response in search_responses.items
}
Expand All @@ -603,6 +605,9 @@ async def search_current_user_dataset_records(
)

for record in records:
record.dataset = dataset
record.metadata_ = await _filter_record_metadata_for_user(record, current_user)

record_id_score_map[record.id]["search_record"] = SearchRecord(
record=RecordSchema.from_orm(record), query_score=record_id_score_map[record.id]["query_score"]
)
Expand Down Expand Up @@ -698,3 +703,14 @@ async def list_dataset_records_search_suggestions_options(
for sa in suggestion_agents_by_question
]
)


async def _filter_record_metadata_for_user(record: Record, user: User) -> Optional[Dict[str, Any]]:
if record.metadata_ is None:
return None

metadata = {}
for metadata_name in list(record.metadata_.keys()):
if await is_authorized(user, RecordPolicyV1.get_metadata(record, metadata_name)):
metadata[metadata_name] = record.metadata_[metadata_name]
return metadata
16 changes: 16 additions & 0 deletions argilla-server/src/argilla_server/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,22 @@ async def is_allowed(actor: User) -> bool:

return is_allowed

@classmethod
def get_metadata(cls, record: Record, metadata_name: str):
async def is_allowed(actor: User) -> bool:
if actor.is_owner:
return True

metadata_property = record.dataset.metadata_property_by_name(metadata_name)
if metadata_property:
return await is_authorized(actor, MetadataPropertyPolicyV1.get(metadata_property))

return actor.is_admin and await _exists_workspace_user_by_user_and_workspace_id(
actor, record.dataset.workspace_id
)

return is_allowed


class ResponsePolicyV1:
@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,177 @@
from uuid import UUID

import pytest
from argilla_server.constants import API_KEY_HEADER_NAME
from argilla_server.enums import UserRole
from argilla_server.search_engine import SearchEngine, SearchResponseItem, SearchResponses
from httpx import AsyncClient

from tests.factories import DatasetFactory, RecordFactory, TextFieldFactory, VectorFactory, VectorSettingsFactory
from tests.factories import (
AdminFactory,
AnnotatorFactory,
DatasetFactory,
RecordFactory,
TermsMetadataPropertyFactory,
TextFieldFactory,
VectorFactory,
VectorSettingsFactory,
WorkspaceUserFactory,
)


@pytest.mark.asyncio
class TestSearchCurrentUserDatasetRecords:
def url(self, dataset_id: UUID) -> str:
return f"/api/v1/me/datasets/{dataset_id}/records/search"

async def test_search_with_filtered_metadata(
self, async_client: AsyncClient, mock_search_engine: SearchEngine, owner_auth_header: dict
):
dataset = await DatasetFactory.create()

await TextFieldFactory.create(name="input", dataset=dataset)
await TermsMetadataPropertyFactory.create(
name="annotator_meta", dataset=dataset, allowed_roles=[UserRole.admin, UserRole.annotator]
)
await TermsMetadataPropertyFactory.create(name="admin_meta", dataset=dataset, allowed_roles=[UserRole.admin])
await TermsMetadataPropertyFactory.create(name="owner_meta", dataset=dataset, allowed_roles=[])
record = await RecordFactory.create(
metadata_={"admin_meta": "value", "annotator_meta": "value", "owner_meta": "value", "extra": "value"},
dataset=dataset,
)

mock_search_engine.search.return_value = SearchResponses(
items=[SearchResponseItem(record_id=record.id, score=1.0)],
total=1,
)

response = await async_client.post(
self.url(dataset.id),
headers=owner_auth_header,
json={"query": {}},
)

assert response.status_code == 200
assert response.json() == {
"items": [
{
"record": {
"id": str(record.id),
"fields": record.fields,
"metadata": record.metadata_,
"external_id": record.external_id,
"dataset_id": str(dataset.id),
"inserted_at": record.inserted_at.isoformat(),
"updated_at": record.updated_at.isoformat(),
},
"query_score": 1.0,
}
],
"total": 1,
}

async def test_search_with_filtered_metadata_as_annotator(
self,
async_client: AsyncClient,
mock_search_engine: SearchEngine,
):
user = await AnnotatorFactory.create()
dataset = await DatasetFactory.create()
await WorkspaceUserFactory.create(user_id=user.id, workspace_id=dataset.workspace_id)

await TextFieldFactory.create(name="input", dataset=dataset)
await TermsMetadataPropertyFactory.create(
name="annotator_meta", dataset=dataset, allowed_roles=[UserRole.admin, UserRole.annotator]
)
await TermsMetadataPropertyFactory.create(name="admin_meta", dataset=dataset, allowed_roles=[UserRole.admin])
await TermsMetadataPropertyFactory.create(name="owner_meta", dataset=dataset, allowed_roles=[])

record = await RecordFactory.create(
metadata_={"admin_meta": "value", "annotator_meta": "value", "owner_meta": "value", "extra": "value"},
dataset=dataset,
)

mock_search_engine.search.return_value = SearchResponses(
items=[SearchResponseItem(record_id=record.id, score=1.0)],
total=1,
)

response = await async_client.post(
self.url(dataset.id),
headers={API_KEY_HEADER_NAME: user.api_key},
json={"query": {}},
)

assert response.status_code == 200
assert response.json() == {
"items": [
{
"record": {
"id": str(record.id),
"fields": record.fields,
"metadata": {"annotator_meta": "value"},
"external_id": record.external_id,
"dataset_id": str(dataset.id),
"inserted_at": record.inserted_at.isoformat(),
"updated_at": record.updated_at.isoformat(),
},
"query_score": 1.0,
}
],
"total": 1,
}

async def test_search_with_filtered_metadata_as_admin(
self,
async_client: AsyncClient,
mock_search_engine: SearchEngine,
):
dataset = await DatasetFactory.create()

user = await AdminFactory.create()
await WorkspaceUserFactory.create(user_id=user.id, workspace_id=dataset.workspace_id)

await TextFieldFactory.create(name="input", dataset=dataset)
await TermsMetadataPropertyFactory.create(
name="annotator_meta", dataset=dataset, allowed_roles=[UserRole.admin, UserRole.annotator]
)
await TermsMetadataPropertyFactory.create(name="admin_meta", dataset=dataset, allowed_roles=[UserRole.admin])
await TermsMetadataPropertyFactory.create(name="owner_meta", dataset=dataset, allowed_roles=[])
record = await RecordFactory.create(
metadata_={"admin_meta": "value", "annotator_meta": "value", "owner_meta": "value", "extra": "value"},
dataset=dataset,
)

mock_search_engine.search.return_value = SearchResponses(
items=[SearchResponseItem(record_id=record.id, score=1.0)],
total=1,
)

response = await async_client.post(
self.url(dataset.id),
headers={API_KEY_HEADER_NAME: user.api_key},
json={"query": {}},
)

assert response.status_code == 200
assert response.json() == {
"items": [
{
"record": {
"id": str(record.id),
"fields": record.fields,
"metadata": {"admin_meta": "value", "annotator_meta": "value", "extra": "value"},
"external_id": record.external_id,
"dataset_id": str(dataset.id),
"inserted_at": record.inserted_at.isoformat(),
"updated_at": record.updated_at.isoformat(),
},
"query_score": 1.0,
}
],
"total": 1,
}

async def test_with_vector_query_using_record_without_vector(
self, async_client: AsyncClient, owner_auth_header: dict
):
Expand Down
Loading

0 comments on commit 65e2627

Please sign in to comment.