diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md
index ddd5f96013..b79ec7ce70 100644
--- a/argilla-server/CHANGELOG.md
+++ b/argilla-server/CHANGELOG.md
@@ -16,6 +16,10 @@ These are the section headers that we use:
 
 ## [Unreleased]()
 
+### Added
+
+- Added new `metadata` attribute to the endpoints for getting, creating and updating datasets, so it is now possible to store metadata associated with a dataset. ([#5586](https://github.com/argilla-io/argilla/pull/5586))
+
 ### Changed
 
 - Now it is possible to publish a dataset without required fields. Allowing being published with at least one field (required or not). ([#5569](https://github.com/argilla-io/argilla/pull/5569))
diff --git a/argilla-server/src/argilla_server/alembic/versions/660d6c6b3360_add_metadata_column_to_datasets_table.py b/argilla-server/src/argilla_server/alembic/versions/660d6c6b3360_add_metadata_column_to_datasets_table.py
new file mode 100644
index 0000000000..8fd3b5b0bf
--- /dev/null
+++ b/argilla-server/src/argilla_server/alembic/versions/660d6c6b3360_add_metadata_column_to_datasets_table.py
@@ -0,0 +1,39 @@
+# Copyright 2021-present, the Recognai S.L. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""add metadata column to datasets table
+
+Revision ID: 660d6c6b3360
+Revises: 237f7c674d74
+Create Date: 2024-10-04 16:47:21.611404
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = "660d6c6b3360"
+down_revision = "237f7c674d74"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.add_column("datasets", sa.Column("metadata", sa.JSON(), nullable=True))
+
+
+def downgrade() -> None:
+    op.drop_column("datasets", "metadata")
diff --git a/argilla-server/src/argilla_server/api/schemas/v1/datasets.py b/argilla-server/src/argilla_server/api/schemas/v1/datasets.py
index 2becb6c1f2..f5c757e862 100644
--- a/argilla-server/src/argilla_server/api/schemas/v1/datasets.py
+++ b/argilla-server/src/argilla_server/api/schemas/v1/datasets.py
@@ -13,12 +13,13 @@
 # limitations under the License.
 
 from datetime import datetime
-from typing import List, Literal, Optional, Union
+from typing import List, Literal, Optional, Union, Dict, Any
 from uuid import UUID
 
 from argilla_server.api.schemas.v1.commons import UpdateSchema
 from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus
 from argilla_server.pydantic_v1 import BaseModel, Field, constr
+from argilla_server.pydantic_v1.utils import GetterDict
 
 try:
     from typing import Annotated
@@ -104,6 +105,14 @@ class UsersProgress(BaseModel):
     users: List[UserProgress]
 
 
+class DatasetGetterDict(GetterDict):
+    def get(self, key: str, default: Any) -> Any:
+        if key == "metadata":
+            return getattr(self._obj, "metadata_", None)
+
+        return super().get(key, default)
+
+
 class Dataset(BaseModel):
     id: UUID
     name: str
@@ -111,6 +120,7 @@ class Dataset(BaseModel):
     allow_extra_metadata: bool
     status: DatasetStatus
     distribution: DatasetDistribution
+    metadata: Optional[Dict[str, Any]]
     workspace_id: UUID
     last_activity_at: datetime
     inserted_at: datetime
@@ -118,6 +128,7 @@ class Dataset(BaseModel):
 
     class Config:
         orm_mode = True
+        getter_dict = DatasetGetterDict
 
 
 class Datasets(BaseModel):
@@ -132,6 +143,7 @@ class DatasetCreate(BaseModel):
         strategy=DatasetDistributionStrategy.overlap,
         min_submitted=1,
     )
+    metadata: Optional[Dict[str, Any]] = None
     workspace_id: UUID
 
 
@@ -140,5 +152,6 @@ class DatasetUpdate(UpdateSchema):
     guidelines: Optional[DatasetGuidelines]
     allow_extra_metadata: Optional[bool]
    distribution: Optional[DatasetDistributionUpdate]
+    metadata_: Optional[Dict[str, Any]] = Field(None, alias="metadata")
 
     __non_nullable_fields__ = {"name", "allow_extra_metadata", "distribution"}
diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py
index d606581c7d..5ff0cd8ae8 100644
--- a/argilla-server/src/argilla_server/contexts/datasets.py
+++ b/argilla-server/src/argilla_server/contexts/datasets.py
@@ -136,12 +136,13 @@ async def list_datasets_by_workspace_id(db: AsyncSession, workspace_id: UUID) ->
     return result.scalars().all()
 
 
-async def create_dataset(db: AsyncSession, dataset_attrs: dict):
+async def create_dataset(db: AsyncSession, dataset_attrs: dict) -> Dataset:
     dataset = Dataset(
         name=dataset_attrs["name"],
         guidelines=dataset_attrs["guidelines"],
         allow_extra_metadata=dataset_attrs["allow_extra_metadata"],
         distribution=dataset_attrs["distribution"],
+        metadata_=dataset_attrs["metadata"],
         workspace_id=dataset_attrs["workspace_id"],
     )
 
diff --git a/argilla-server/src/argilla_server/database.py b/argilla-server/src/argilla_server/database.py
index 9d8eb7ae50..6441cb3593 100644
--- a/argilla-server/src/argilla_server/database.py
+++ b/argilla-server/src/argilla_server/database.py
@@ -37,6 +37,7 @@
         "1.18": "bda6fe24314e",
         "1.28": "ca7293c38970",
         "2.0": "237f7c674d74",
+        "2.4": "660d6c6b3360",
     }
 )
 
diff --git a/argilla-server/src/argilla_server/models/database.py b/argilla-server/src/argilla_server/models/database.py
index 2d61e3836e..bced3416fe 100644
--- a/argilla-server/src/argilla_server/models/database.py
+++ b/argilla-server/src/argilla_server/models/database.py
@@ -344,6 +344,7 @@ class Dataset(DatabaseModel):
     allow_extra_metadata: Mapped[bool] = mapped_column(default=True, server_default=sql.true())
     status: Mapped[DatasetStatus] = mapped_column(DatasetStatusEnum, default=DatasetStatus.draft, index=True)
     distribution: Mapped[dict] = mapped_column(MutableDict.as_mutable(JSON))
+    metadata_: Mapped[Optional[dict]] = mapped_column("metadata", JSON, nullable=True)
     workspace_id: Mapped[UUID] = mapped_column(ForeignKey("workspaces.id", ondelete="CASCADE"), index=True)
     inserted_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
     updated_at: Mapped[datetime] = mapped_column(default=inserted_at_current_value, onupdate=datetime.utcnow)
diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py
index 4261145d0c..ce955a29c9 100644
--- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py
+++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py
@@ -13,12 +13,15 @@
 # limitations under the License.
 
 import pytest
-from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus
-from argilla_server.models import Dataset
+
+from typing import Any
 from httpx import AsyncClient
 from sqlalchemy import func, select
 from sqlalchemy.ext.asyncio import AsyncSession
+from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus
+from argilla_server.models import Dataset
+
 from tests.factories import WorkspaceFactory
 
 
@@ -54,6 +57,7 @@ async def test_create_dataset_with_default_distribution(
                 "strategy": DatasetDistributionStrategy.overlap,
                 "min_submitted": 1,
             },
+            "metadata": None,
             "workspace_id": str(workspace.id),
             "last_activity_at": dataset.last_activity_at.isoformat(),
             "inserted_at": dataset.inserted_at.isoformat(),
@@ -91,6 +95,7 @@ async def test_create_dataset_with_overlap_distribution(
                 "strategy": DatasetDistributionStrategy.overlap,
                 "min_submitted": 4,
             },
+            "metadata": None,
             "workspace_id": str(workspace.id),
             "last_activity_at": dataset.last_activity_at.isoformat(),
             "inserted_at": dataset.inserted_at.isoformat(),
@@ -137,3 +142,63 @@ async def test_create_dataset_with_invalid_distribution_strategy(
         assert response.status_code == 422
 
         assert (await db.execute(select(func.count(Dataset.id)))).scalar_one() == 0
+
+    async def test_create_dataset_with_default_metadata(
+        self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict
+    ):
+        workspace = await WorkspaceFactory.create()
+
+        response = await async_client.post(
+            self.url(),
+            headers=owner_auth_header,
+            json={
+                "name": "Dataset Name",
+                "workspace_id": str(workspace.id),
+            },
+        )
+
+        assert response.status_code == 201
+        assert response.json()["metadata"] == None
+
+        dataset = (await db.execute(select(Dataset))).scalar_one()
+        assert dataset.metadata_ == None
+
+    async def test_create_dataset_with_custom_metadata(
+        self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict
+    ):
+        workspace = await WorkspaceFactory.create()
+
+        response = await async_client.post(
+            self.url(),
+            headers=owner_auth_header,
+            json={
+                "name": "Dataset Name",
+                "metadata": {"key": "value"},
+                "workspace_id": str(workspace.id),
+            },
+        )
+
+        assert response.status_code == 201
+        assert response.json()["metadata"] == {"key": "value"}
+
+        dataset = (await db.execute(select(Dataset))).scalar_one()
+        assert dataset.metadata_ == {"key": "value"}
+
+    @pytest.mark.parametrize("invalid_metadata", ["invalid_metadata", 123])
+    async def test_create_dataset_with_invalid_metadata(
+        self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict, invalid_metadata: Any
+    ):
+        workspace = await WorkspaceFactory.create()
+
+        response = await async_client.post(
+            self.url(),
+            headers=owner_auth_header,
+            json={
+                "name": "Dataset Name",
+                "metadata": invalid_metadata,
+                "workspace_id": str(workspace.id),
+            },
+        )
+
+        assert response.status_code == 422
+        assert (await db.execute(select(func.count(Dataset.id)))).scalar_one() == 0
diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py
index 91c4f54f17..113cdc3ce1 100644
--- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py
+++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any
 from uuid import UUID
 
 import pytest
@@ -72,29 +73,6 @@ async def test_update_dataset_without_distribution(self, async_client: AsyncClie
             "min_submitted": 1,
         }
 
-    async def test_update_dataset_without_distribution_for_published_dataset(
-        self, async_client: AsyncClient, owner_auth_header: dict
-    ):
-        dataset = await DatasetFactory.create(status=DatasetStatus.ready)
-
-        response = await async_client.patch(
-            self.url(dataset.id),
-            headers=owner_auth_header,
-            json={"name": "Dataset updated name"},
-        )
-
-        assert response.status_code == 200
-        assert response.json()["distribution"] == {
-            "strategy": DatasetDistributionStrategy.overlap,
-            "min_submitted": 1,
-        }
-
-        assert dataset.name == "Dataset updated name"
-        assert dataset.distribution == {
-            "strategy": DatasetDistributionStrategy.overlap,
-            "min_submitted": 1,
-        }
-
     async def test_update_dataset_distribution_with_invalid_strategy(
         self, async_client: AsyncClient, owner_auth_header: dict
     ):
@@ -152,3 +130,81 @@ async def test_update_dataset_distribution_as_none(self, async_client: AsyncClie
             "strategy": DatasetDistributionStrategy.overlap,
             "min_submitted": 1,
         }
+
+    async def test_update_dataset_metadata(self, async_client: AsyncClient, owner_auth_header: dict):
+        dataset = await DatasetFactory.create(metadata_={"key-a": "value-a", "key-b": "value-b"})
+
+        response = await async_client.patch(
+            self.url(dataset.id),
+            headers=owner_auth_header,
+            json={
+                "metadata": {
+                    "key-a": "value-a-updated",
+                    "key-c": "value-c",
+                },
+            },
+        )
+
+        assert response.status_code == 200
+        assert response.json()["metadata"] == {
+            "key-a": "value-a-updated",
+            "key-b": "value-b",
+            "key-c": "value-c",
+        }
+
+        assert dataset.metadata_ == {
+            "key-a": "value-a-updated",
+            "key-b": "value-b",
+            "key-c": "value-c",
+        }
+
+    async def test_update_dataset_without_metadata(self, async_client: AsyncClient, owner_auth_header: dict):
+        dataset = await DatasetFactory.create(metadata_={"key": "value"})
+
+        response = await async_client.patch(
+            self.url(dataset.id),
+            headers=owner_auth_header,
+            json={"name": "Dataset updated name"},
+        )
+
+        assert response.status_code == 200
+        assert response.json()["metadata"] == {"key": "value"}
+
+        assert dataset.name == "Dataset updated name"
+        assert dataset.metadata_ == {"key": "value"}
+
+    async def test_update_dataset_with_invalid_metadata(self, async_client: AsyncClient, owner_auth_header: dict):
+        dataset = await DatasetFactory.create(metadata_={"key": "value"})
+
+        response = await async_client.patch(
+            self.url(dataset.id),
+            headers=owner_auth_header,
+            json={"metadata": "invalid_metadata"},
+        )
+
+        assert response.status_code == 422
+        assert dataset.metadata_ == {"key": "value"}
+
+    async def test_update_dataset_metadata_as_empty_dict(self, async_client: AsyncClient, owner_auth_header: dict):
+        dataset = await DatasetFactory.create(metadata_={"key": "value"})
+
+        response = await async_client.patch(
+            self.url(dataset.id),
+            headers=owner_auth_header,
+            json={"metadata": {}},
+        )
+
+        assert response.status_code == 200
+        assert dataset.metadata_ == {"key": "value"}
+
+    async def test_update_dataset_metadata_as_none(self, async_client: AsyncClient, owner_auth_header: dict):
+        dataset = await DatasetFactory.create(metadata_={"key": "value"})
+
+        response = await async_client.patch(
+            self.url(dataset.id),
+            headers=owner_auth_header,
+            json={"metadata": None},
+        )
+
+        assert response.status_code == 200
+        assert dataset.metadata_ == None
diff --git a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py
index d362a4666d..dc867e3a6b 100644
--- a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py
+++ b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py
@@ -123,6 +123,7 @@ async def test_list_current_user_datasets(self, async_client: "AsyncClient", own
                        "strategy": DatasetDistributionStrategy.overlap,
                        "min_submitted": 1,
                    },
+                    "metadata": None,
                    "workspace_id": str(dataset_a.workspace_id),
                    "last_activity_at": dataset_a.last_activity_at.isoformat(),
                    "inserted_at": dataset_a.inserted_at.isoformat(),
@@ -138,6 +139,7 @@ async def test_list_current_user_datasets(self, async_client: "AsyncClient", own
                        "strategy": DatasetDistributionStrategy.overlap,
                        "min_submitted": 1,
                    },
+                    "metadata": None,
                    "workspace_id": str(dataset_b.workspace_id),
                    "last_activity_at": dataset_b.last_activity_at.isoformat(),
                    "inserted_at": dataset_b.inserted_at.isoformat(),
@@ -153,6 +155,7 @@ async def test_list_current_user_datasets(self, async_client: "AsyncClient", own
                        "strategy": DatasetDistributionStrategy.overlap,
                        "min_submitted": 1,
                    },
+                    "metadata": None,
                    "workspace_id": str(dataset_c.workspace_id),
                    "last_activity_at": dataset_c.last_activity_at.isoformat(),
                    "inserted_at": dataset_c.inserted_at.isoformat(),
@@ -684,6 +687,7 @@ async def test_get_dataset(self, async_client: "AsyncClient", owner_auth_header:
                "strategy": DatasetDistributionStrategy.overlap,
                "min_submitted": 1,
            },
+            "metadata": None,
            "workspace_id": str(dataset.workspace_id),
            "last_activity_at": dataset.last_activity_at.isoformat(),
            "inserted_at": dataset.inserted_at.isoformat(),
@@ -890,6 +894,7 @@ async def test_create_dataset(self, async_client: "AsyncClient", db: "AsyncSessi
                "strategy": DatasetDistributionStrategy.overlap,
                "min_submitted": 1,
            },
+            "metadata": None,
            "workspace_id": str(workspace.id),
            "last_activity_at": datetime.fromisoformat(response_body["last_activity_at"]).isoformat(),
            "inserted_at": datetime.fromisoformat(response_body["inserted_at"]).isoformat(),
@@ -4458,6 +4463,7 @@ async def test_update_dataset(self, async_client: "AsyncClient", db: "AsyncSessi
                "strategy": DatasetDistributionStrategy.overlap,
                "min_submitted": 1,
            },
+            "metadata": None,
            "workspace_id": str(dataset.workspace_id),
            "last_activity_at": dataset.last_activity_at.isoformat(),
            "inserted_at": dataset.inserted_at.isoformat(),
@@ -4540,7 +4546,10 @@ async def test_update_dataset_as_annotator(self, async_client: "AsyncClient"):
         response = await async_client.patch(
             f"/api/v1/datasets/{dataset.id}",
             headers={API_KEY_HEADER_NAME: user.api_key},
-            json={"name": "New Name", "guidelines": "New Guidelines"},
+            json={
+                "name": "New Name",
+                "guidelines": "New Guidelines",
+            },
         )
 
         assert response.status_code == 403
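The sketch below is not part of the diff; it illustrates how a client could exercise the new `metadata` attribute introduced by this PR. It assumes a locally running Argilla server, the `/api/v1/datasets` create and `/api/v1/datasets/{dataset_id}` update endpoints covered by the tests above, and the `X-Argilla-Api-Key` header; adjust the base URL, API key, and workspace id to your deployment.

```python
# Hypothetical usage sketch (not part of the PR): create a dataset with custom
# metadata and later patch it. Endpoint paths mirror the v1 handlers tested above;
# the base URL, API key header value, and workspace id are assumptions.
import asyncio

import httpx

API_URL = "http://localhost:6900"  # assumed local Argilla server
HEADERS = {"X-Argilla-Api-Key": "owner.apikey"}  # assumed owner API key
WORKSPACE_ID = "00000000-0000-0000-0000-000000000000"  # replace with a real workspace id


async def main() -> None:
    async with httpx.AsyncClient(base_url=API_URL, headers=HEADERS) as client:
        # POST /api/v1/datasets now accepts an optional `metadata` object.
        response = await client.post(
            "/api/v1/datasets",
            json={
                "name": "Dataset Name",
                "metadata": {"source": "forum-dump", "version": 1},
                "workspace_id": WORKSPACE_ID,
            },
        )
        dataset = response.json()

        # PATCH /api/v1/datasets/{dataset_id}: submitted keys are merged into the
        # stored metadata, as shown by test_update_dataset_metadata above.
        await client.patch(
            f"/api/v1/datasets/{dataset['id']}",
            json={"metadata": {"version": 2}},
        )


asyncio.run(main())
```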