Skip to content

Commit

Permalink
feat: add dataset support to be created using distribution settings
Browse files Browse the repository at this point in the history
  • Loading branch information
jfcalvo committed Jun 13, 2024
1 parent 8e9f42b commit d5b762c
Show file tree
Hide file tree
Showing 10 changed files with 224 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""add record metadata column
"""add metadata column to records table
Revision ID: 3ff6484f8b37
Revises: ae5522b4c674
Expand All @@ -30,12 +30,8 @@


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("records", sa.Column("metadata", sa.JSON(), nullable=True))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("records", "metadata")
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""add distribution column to datasets table
Revision ID: 45a12f74448b
Revises: ca7293c38970
Create Date: 2024-06-13 11:23:43.395093
"""
import json

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "45a12f74448b"
down_revision = "ca7293c38970"
branch_labels = None
depends_on = None

DISTRIBUTION_VALUE = json.dumps({"strategy": "overlap", "min_submitted": 1})


def upgrade() -> None:
op.add_column("datasets", sa.Column("distribution", sa.JSON(), nullable=True))
op.execute(f"UPDATE datasets SET distribution = '{DISTRIBUTION_VALUE}'")
with op.batch_alter_table("datasets") as batch_op:
batch_op.alter_column("distribution", nullable=False)


def downgrade() -> None:
op.drop_column("datasets", "distribution")
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""add allow_extra_metadata column to dataset table
"""add allow_extra_metadata column to datasets table
Revision ID: b8458008b60e
Revises: 7cbcccf8b57a
Expand All @@ -30,14 +30,10 @@


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"datasets", sa.Column("allow_extra_metadata", sa.Boolean(), server_default=sa.text("true"), nullable=False)
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("datasets", "allow_extra_metadata")
# ### end Alembic commands ###
28 changes: 26 additions & 2 deletions argilla-server/src/argilla_server/api/schemas/v1/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# limitations under the License.

from datetime import datetime
from typing import List, Optional
from typing import List, Literal, Optional, Union
from uuid import UUID

from argilla_server.api.schemas.v1.commons import UpdateSchema
from argilla_server.enums import DatasetStatus
from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus
from argilla_server.pydantic_v1 import BaseModel, Field, constr

try:
Expand All @@ -44,6 +44,25 @@
]


class DatasetOverlapDistribution(BaseModel):
strategy: Literal[DatasetDistributionStrategy.overlap]
min_submitted: int


DatasetDistribution = DatasetOverlapDistribution


class DatasetOverlapDistributionCreate(BaseModel):
strategy: Literal[DatasetDistributionStrategy.overlap]
min_submitted: int = Field(
ge=1,
description="Minimum number of submitted responses to consider a record as completed",
)


DatasetDistributionCreate = DatasetOverlapDistributionCreate


class RecordMetrics(BaseModel):
count: int

Expand Down Expand Up @@ -74,6 +93,7 @@ class Dataset(BaseModel):
guidelines: Optional[str]
allow_extra_metadata: bool
status: DatasetStatus
distribution: DatasetDistribution
workspace_id: UUID
last_activity_at: datetime
inserted_at: datetime
Expand All @@ -91,6 +111,10 @@ class DatasetCreate(BaseModel):
name: DatasetName
guidelines: Optional[DatasetGuidelines]
allow_extra_metadata: bool = True
distribution: DatasetDistributionCreate = DatasetOverlapDistributionCreate(
strategy=DatasetDistributionStrategy.overlap,
min_submitted=1,
)
workspace_id: UUID


Expand Down
1 change: 1 addition & 0 deletions argilla-server/src/argilla_server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ async def create_dataset(db: AsyncSession, dataset_create: DatasetCreate):
name=dataset_create.name,
guidelines=dataset_create.guidelines,
allow_extra_metadata=dataset_create.allow_extra_metadata,
distribution=dataset_create.distribution.dict(),
workspace_id=dataset_create.workspace_id,
)

Expand Down
4 changes: 4 additions & 0 deletions argilla-server/src/argilla_server/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class DatasetStatus(str, Enum):
ready = "ready"


class DatasetDistributionStrategy(str, Enum):
overlap = "overlap"


class UserRole(str, Enum):
owner = "owner"
admin = "admin"
Expand Down
1 change: 1 addition & 0 deletions argilla-server/src/argilla_server/models/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ class Dataset(DatabaseModel):
guidelines: Mapped[Optional[str]] = mapped_column(Text)
allow_extra_metadata: Mapped[bool] = mapped_column(default=True, server_default=sql.true())
status: Mapped[DatasetStatus] = mapped_column(DatasetStatusEnum, default=DatasetStatus.draft, index=True)
distribution: Mapped[dict] = mapped_column(MutableDict.as_mutable(JSON))
workspace_id: Mapped[UUID] = mapped_column(ForeignKey("workspaces.id", ondelete="CASCADE"), index=True)
inserted_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(default=inserted_at_current_value, onupdate=datetime.utcnow)
Expand Down
3 changes: 2 additions & 1 deletion argilla-server/tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import random

import factory
from argilla_server.enums import FieldType, MetadataPropertyType, OptionsOrder
from argilla_server.enums import DatasetDistributionStrategy, FieldType, MetadataPropertyType, OptionsOrder
from argilla_server.models import (
Dataset,
Field,
Expand Down Expand Up @@ -203,6 +203,7 @@ class Meta:
model = Dataset

name = factory.Sequence(lambda n: f"dataset-{n}")
distribution = {"strategy": DatasetDistributionStrategy.overlap, "min_submitted": 1}
workspace = factory.SubFactory(WorkspaceFactory)


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus
from argilla_server.models import Dataset
from httpx import AsyncClient
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from tests.factories import WorkspaceFactory


@pytest.mark.asyncio
class TestCreateDataset:
def url(self) -> str:
return "/api/v1/datasets"

async def test_create_dataset_with_default_distribution(
self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict
):
workspace = await WorkspaceFactory.create()

response = await async_client.post(
self.url(),
headers=owner_auth_header,
json={
"name": "Dataset Name",
"workspace_id": str(workspace.id),
},
)

dataset = (await db.execute(select(Dataset))).scalar_one()

assert response.status_code == 201
assert response.json() == {
"id": str(dataset.id),
"name": "Dataset Name",
"guidelines": None,
"allow_extra_metadata": True,
"status": DatasetStatus.draft,
"distribution": {
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 1,
},
"workspace_id": str(workspace.id),
"last_activity_at": dataset.last_activity_at.isoformat(),
"inserted_at": dataset.inserted_at.isoformat(),
"updated_at": dataset.updated_at.isoformat(),
}

async def test_create_dataset_with_overlap_distribution(
self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict
):
workspace = await WorkspaceFactory.create()

response = await async_client.post(
self.url(),
headers=owner_auth_header,
json={
"name": "Dataset Name",
"distribution": {
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 4,
},
"workspace_id": str(workspace.id),
},
)

dataset = (await db.execute(select(Dataset))).scalar_one()

assert response.status_code == 201
assert response.json() == {
"id": str(dataset.id),
"name": "Dataset Name",
"guidelines": None,
"allow_extra_metadata": True,
"status": DatasetStatus.draft,
"distribution": {
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 4,
},
"workspace_id": str(workspace.id),
"last_activity_at": dataset.last_activity_at.isoformat(),
"inserted_at": dataset.inserted_at.isoformat(),
"updated_at": dataset.updated_at.isoformat(),
}

async def test_create_dataset_with_overlap_distribution_using_invalid_min_submitted_value(
self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict
):
workspace = await WorkspaceFactory.create()

response = await async_client.post(
self.url(),
headers=owner_auth_header,
json={
"name": "Dataset name",
"distribution": {
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 0,
},
"workspace_id": str(workspace.id),
},
)

assert response.status_code == 422
assert (await db.execute(select(func.count(Dataset.id)))).scalar_one() == 0
Loading

0 comments on commit d5b762c

Please sign in to comment.