diff --git a/argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_record_metadata_column.py b/argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_metadata_column_to_records_table.py similarity index 82% rename from argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_record_metadata_column.py rename to argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_metadata_column_to_records_table.py index 2cbdc4b744..6a366186bc 100644 --- a/argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_record_metadata_column.py +++ b/argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_metadata_column_to_records_table.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""add record metadata column +"""add metadata column to records table Revision ID: 3ff6484f8b37 Revises: ae5522b4c674 @@ -30,12 +30,8 @@ def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### op.add_column("records", sa.Column("metadata", sa.JSON(), nullable=True)) - # ### end Alembic commands ### def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### op.drop_column("records", "metadata") - # ### end Alembic commands ### diff --git a/argilla-server/src/argilla_server/alembic/versions/45a12f74448b_add_distribution_column_to_datasets_table.py b/argilla-server/src/argilla_server/alembic/versions/45a12f74448b_add_distribution_column_to_datasets_table.py new file mode 100644 index 0000000000..dd8c3c8d78 --- /dev/null +++ b/argilla-server/src/argilla_server/alembic/versions/45a12f74448b_add_distribution_column_to_datasets_table.py @@ -0,0 +1,44 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""add distribution column to datasets table + +Revision ID: 45a12f74448b +Revises: ca7293c38970 +Create Date: 2024-06-13 11:23:43.395093 + +""" +import json + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "45a12f74448b" +down_revision = "ca7293c38970" +branch_labels = None +depends_on = None + +DISTRIBUTION_VALUE = json.dumps({"strategy": "overlap", "min_submitted": 1}) + + +def upgrade() -> None: + op.add_column("datasets", sa.Column("distribution", sa.JSON(), nullable=True)) + op.execute(f"UPDATE datasets SET distribution = '{DISTRIBUTION_VALUE}'") + with op.batch_alter_table("datasets") as batch_op: + batch_op.alter_column("distribution", nullable=False) + + +def downgrade() -> None: + op.drop_column("datasets", "distribution") diff --git a/argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_.py b/argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_datasets_table.py similarity index 81% rename from argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_.py rename to argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_datasets_table.py index 0710902db4..50795b3e50 100644 --- a/argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_.py +++ b/argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_datasets_table.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""add allow_extra_metadata column to dataset table +"""add allow_extra_metadata column to datasets table Revision ID: b8458008b60e Revises: 7cbcccf8b57a @@ -30,14 +30,10 @@ def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### op.add_column( "datasets", sa.Column("allow_extra_metadata", sa.Boolean(), server_default=sa.text("true"), nullable=False) ) - # ### end Alembic commands ### def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### op.drop_column("datasets", "allow_extra_metadata") - # ### end Alembic commands ### diff --git a/argilla-server/src/argilla_server/api/schemas/v1/datasets.py b/argilla-server/src/argilla_server/api/schemas/v1/datasets.py index 5cac33bdb7..219024ce4e 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/datasets.py +++ b/argilla-server/src/argilla_server/api/schemas/v1/datasets.py @@ -13,11 +13,11 @@ # limitations under the License. from datetime import datetime -from typing import List, Optional +from typing import List, Literal, Optional, Union from uuid import UUID from argilla_server.api.schemas.v1.commons import UpdateSchema -from argilla_server.enums import DatasetStatus +from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus from argilla_server.pydantic_v1 import BaseModel, Field, constr try: @@ -44,6 +44,25 @@ ] +class DatasetOverlapDistribution(BaseModel): + strategy: Literal[DatasetDistributionStrategy.overlap] + min_submitted: int + + +DatasetDistribution = DatasetOverlapDistribution + + +class DatasetOverlapDistributionCreate(BaseModel): + strategy: Literal[DatasetDistributionStrategy.overlap] + min_submitted: int = Field( + ge=1, + description="Minimum number of submitted responses to consider a record as completed", + ) + + +DatasetDistributionCreate = DatasetOverlapDistributionCreate + + class RecordMetrics(BaseModel): count: int @@ -74,6 +93,7 @@ class Dataset(BaseModel): guidelines: Optional[str] allow_extra_metadata: bool status: DatasetStatus + distribution: DatasetDistribution workspace_id: UUID last_activity_at: datetime inserted_at: datetime @@ -91,6 +111,10 @@ class DatasetCreate(BaseModel): name: DatasetName guidelines: Optional[DatasetGuidelines] allow_extra_metadata: bool = True + distribution: DatasetDistributionCreate = DatasetOverlapDistributionCreate( + strategy=DatasetDistributionStrategy.overlap, + min_submitted=1, + ) workspace_id: UUID diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index 2dbc62a717..ee153aff59 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -136,6 +136,7 @@ async def create_dataset(db: AsyncSession, dataset_create: DatasetCreate): name=dataset_create.name, guidelines=dataset_create.guidelines, allow_extra_metadata=dataset_create.allow_extra_metadata, + distribution=dataset_create.distribution.dict(), workspace_id=dataset_create.workspace_id, ) diff --git a/argilla-server/src/argilla_server/enums.py b/argilla-server/src/argilla_server/enums.py index 13b4843280..7ea01f561b 100644 --- a/argilla-server/src/argilla_server/enums.py +++ b/argilla-server/src/argilla_server/enums.py @@ -43,6 +43,10 @@ class DatasetStatus(str, Enum): ready = "ready" +class DatasetDistributionStrategy(str, Enum): + overlap = "overlap" + + class UserRole(str, Enum): owner = "owner" admin = "admin" diff --git a/argilla-server/src/argilla_server/models/database.py b/argilla-server/src/argilla_server/models/database.py index b84e81180a..aa45abc556 100644 --- a/argilla-server/src/argilla_server/models/database.py +++ b/argilla-server/src/argilla_server/models/database.py @@ -304,6 +304,7 @@ class Dataset(DatabaseModel): guidelines: Mapped[Optional[str]] = mapped_column(Text) allow_extra_metadata: Mapped[bool] = mapped_column(default=True, server_default=sql.true()) status: Mapped[DatasetStatus] = mapped_column(DatasetStatusEnum, default=DatasetStatus.draft, index=True) + distribution: Mapped[dict] = mapped_column(MutableDict.as_mutable(JSON)) workspace_id: Mapped[UUID] = mapped_column(ForeignKey("workspaces.id", ondelete="CASCADE"), index=True) inserted_at: Mapped[datetime] = mapped_column(default=datetime.utcnow) updated_at: Mapped[datetime] = mapped_column(default=inserted_at_current_value, onupdate=datetime.utcnow) diff --git a/argilla-server/tests/factories.py b/argilla-server/tests/factories.py index 5c77b9a0f5..c429fed9af 100644 --- a/argilla-server/tests/factories.py +++ b/argilla-server/tests/factories.py @@ -16,7 +16,7 @@ import random import factory -from argilla_server.enums import FieldType, MetadataPropertyType, OptionsOrder +from argilla_server.enums import DatasetDistributionStrategy, FieldType, MetadataPropertyType, OptionsOrder from argilla_server.models import ( Dataset, Field, @@ -203,6 +203,7 @@ class Meta: model = Dataset name = factory.Sequence(lambda n: f"dataset-{n}") + distribution = {"strategy": DatasetDistributionStrategy.overlap, "min_submitted": 1} workspace = factory.SubFactory(WorkspaceFactory) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py new file mode 100644 index 0000000000..7ad7246cda --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py @@ -0,0 +1,119 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus +from argilla_server.models import Dataset +from httpx import AsyncClient +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from tests.factories import WorkspaceFactory + + +@pytest.mark.asyncio +class TestCreateDataset: + def url(self) -> str: + return "/api/v1/datasets" + + async def test_create_dataset_with_default_distribution( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace = await WorkspaceFactory.create() + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "name": "Dataset Name", + "workspace_id": str(workspace.id), + }, + ) + + dataset = (await db.execute(select(Dataset))).scalar_one() + + assert response.status_code == 201 + assert response.json() == { + "id": str(dataset.id), + "name": "Dataset Name", + "guidelines": None, + "allow_extra_metadata": True, + "status": DatasetStatus.draft, + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, + "workspace_id": str(workspace.id), + "last_activity_at": dataset.last_activity_at.isoformat(), + "inserted_at": dataset.inserted_at.isoformat(), + "updated_at": dataset.updated_at.isoformat(), + } + + async def test_create_dataset_with_overlap_distribution( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace = await WorkspaceFactory.create() + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "name": "Dataset Name", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + }, + "workspace_id": str(workspace.id), + }, + ) + + dataset = (await db.execute(select(Dataset))).scalar_one() + + assert response.status_code == 201 + assert response.json() == { + "id": str(dataset.id), + "name": "Dataset Name", + "guidelines": None, + "allow_extra_metadata": True, + "status": DatasetStatus.draft, + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + }, + "workspace_id": str(workspace.id), + "last_activity_at": dataset.last_activity_at.isoformat(), + "inserted_at": dataset.inserted_at.isoformat(), + "updated_at": dataset.updated_at.isoformat(), + } + + async def test_create_dataset_with_overlap_distribution_using_invalid_min_submitted_value( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace = await WorkspaceFactory.create() + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "name": "Dataset name", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 0, + }, + "workspace_id": str(workspace.id), + }, + ) + + assert response.status_code == 422 + assert (await db.execute(select(func.count(Dataset.id)))).scalar_one() == 0 diff --git a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py index b8aca01cfb..a3e2ed56d2 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py @@ -34,6 +34,7 @@ ) from argilla_server.constants import API_KEY_HEADER_NAME from argilla_server.enums import ( + DatasetDistributionStrategy, DatasetStatus, OptionsOrder, RecordInclude, @@ -116,6 +117,10 @@ async def test_list_current_user_datasets(self, async_client: "AsyncClient", own "guidelines": None, "allow_extra_metadata": True, "status": "draft", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset_a.workspace_id), "last_activity_at": dataset_a.last_activity_at.isoformat(), "inserted_at": dataset_a.inserted_at.isoformat(), @@ -127,6 +132,10 @@ async def test_list_current_user_datasets(self, async_client: "AsyncClient", own "guidelines": "guidelines", "allow_extra_metadata": True, "status": "draft", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset_b.workspace_id), "last_activity_at": dataset_b.last_activity_at.isoformat(), "inserted_at": dataset_b.inserted_at.isoformat(), @@ -138,6 +147,10 @@ async def test_list_current_user_datasets(self, async_client: "AsyncClient", own "guidelines": None, "allow_extra_metadata": True, "status": "ready", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset_c.workspace_id), "last_activity_at": dataset_c.last_activity_at.isoformat(), "inserted_at": dataset_c.inserted_at.isoformat(), @@ -653,8 +666,6 @@ async def test_list_dataset_vectors_settings_without_authentication(self, async_ assert response.status_code == 401 - # Helper function to create records with responses - async def test_get_dataset(self, async_client: "AsyncClient", owner_auth_header: dict): dataset = await DatasetFactory.create(name="dataset") @@ -667,6 +678,10 @@ async def test_get_dataset(self, async_client: "AsyncClient", owner_auth_header: "guidelines": None, "allow_extra_metadata": True, "status": "draft", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset.workspace_id), "last_activity_at": dataset.last_activity_at.isoformat(), "inserted_at": dataset.inserted_at.isoformat(), @@ -839,13 +854,16 @@ async def test_create_dataset(self, async_client: "AsyncClient", db: "AsyncSessi await db.refresh(workspace) response_body = response.json() - assert (await db.execute(select(func.count(Dataset.id)))).scalar() == 1 assert response_body == { "id": str(UUID(response_body["id"])), "name": "name", "guidelines": "guidelines", "allow_extra_metadata": False, "status": "draft", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(workspace.id), "last_activity_at": datetime.fromisoformat(response_body["last_activity_at"]).isoformat(), "inserted_at": datetime.fromisoformat(response_body["inserted_at"]).isoformat(), @@ -4752,6 +4770,10 @@ async def test_update_dataset(self, async_client: "AsyncClient", db: "AsyncSessi "guidelines": guidelines, "allow_extra_metadata": allow_extra_metadata, "status": "ready", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset.workspace_id), "last_activity_at": dataset.last_activity_at.isoformat(), "inserted_at": dataset.inserted_at.isoformat(),