Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add metadata support for datasets #5586

Merged
merged 7 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions argilla-server/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ These are the section headers that we use:

## [Unreleased]()

### Added

- Added new `metadata` attribute for endpoints getting, creating and updating Datasets so now it is possible to store metadata associated to a dataset. ([#5586](https://github.com/argilla-io/argilla/pull/5586))

### Changed

- Now it is possible to publish a dataset without required fields. Allowing being published with at least one field (required or not). ([#5569](https://github.com/argilla-io/argilla/pull/5569))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""add metadata column to datasets table

Revision ID: 660d6c6b3360
Revises: 237f7c674d74
Create Date: 2024-10-04 16:47:21.611404

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "660d6c6b3360"
down_revision = "237f7c674d74"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Apply this revision: add a nullable JSON `metadata` column to `datasets`."""
    metadata_column = sa.Column("metadata", sa.JSON(), nullable=True)
    op.add_column("datasets", metadata_column)


def downgrade() -> None:
    """Revert this revision: drop the `metadata` column from `datasets`.

    NOTE(review): any metadata stored in the column is lost on downgrade.
    """
    op.drop_column("datasets", "metadata")
15 changes: 14 additions & 1 deletion argilla-server/src/argilla_server/api/schemas/v1/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# limitations under the License.

from datetime import datetime
from typing import List, Literal, Optional, Union
from typing import List, Literal, Optional, Union, Dict, Any
from uuid import UUID

from argilla_server.api.schemas.v1.commons import UpdateSchema
from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus
from argilla_server.pydantic_v1 import BaseModel, Field, constr
from argilla_server.pydantic_v1.utils import GetterDict

try:
from typing import Annotated
Expand Down Expand Up @@ -104,20 +105,30 @@ class UsersProgress(BaseModel):
users: List[UserProgress]


class DatasetGetterDict(GetterDict):
    """Getter that maps the schema field `metadata` onto the ORM attribute `metadata_`.

    Every other field is resolved through the default `GetterDict` lookup.
    """

    def get(self, key: str, default: Any) -> Any:
        if key != "metadata":
            return super().get(key, default)
        # The ORM model stores the value under `metadata_`; fall back to None
        # when the attribute is absent.
        return getattr(self._obj, "metadata_", None)


class Dataset(BaseModel):
    """API response schema for a dataset, built directly from the ORM model."""

    id: UUID
    name: str
    guidelines: Optional[str]
    allow_extra_metadata: bool
    status: DatasetStatus
    distribution: DatasetDistribution
    # Read from the ORM attribute `metadata_` via DatasetGetterDict — presumably
    # because the plain name `metadata` is unavailable on the ORM model; verify
    # against models/database.py.
    metadata: Optional[Dict[str, Any]]
    workspace_id: UUID
    last_activity_at: datetime
    inserted_at: datetime
    updated_at: datetime

    class Config:
        # Allow constructing this schema from ORM objects (pydantic v1).
        orm_mode = True
        # Custom getter so `metadata` resolves to the ORM's `metadata_`.
        getter_dict = DatasetGetterDict


class Datasets(BaseModel):
Expand All @@ -132,6 +143,7 @@ class DatasetCreate(BaseModel):
strategy=DatasetDistributionStrategy.overlap,
min_submitted=1,
)
metadata: Optional[Dict[str, Any]] = None
workspace_id: UUID


Expand All @@ -140,5 +152,6 @@ class DatasetUpdate(UpdateSchema):
guidelines: Optional[DatasetGuidelines]
allow_extra_metadata: Optional[bool]
distribution: Optional[DatasetDistributionUpdate]
metadata_: Optional[Dict[str, Any]] = Field(None, alias="metadata")

__non_nullable_fields__ = {"name", "allow_extra_metadata", "distribution"}
3 changes: 2 additions & 1 deletion argilla-server/src/argilla_server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,13 @@ async def list_datasets_by_workspace_id(db: AsyncSession, workspace_id: UUID) ->
return result.scalars().all()


async def create_dataset(db: AsyncSession, dataset_attrs: dict):
async def create_dataset(db: AsyncSession, dataset_attrs: dict) -> Dataset:
dataset = Dataset(
name=dataset_attrs["name"],
guidelines=dataset_attrs["guidelines"],
allow_extra_metadata=dataset_attrs["allow_extra_metadata"],
distribution=dataset_attrs["distribution"],
metadata_=dataset_attrs["metadata"],
workspace_id=dataset_attrs["workspace_id"],
)

Expand Down
1 change: 1 addition & 0 deletions argilla-server/src/argilla_server/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"1.18": "bda6fe24314e",
"1.28": "ca7293c38970",
"2.0": "237f7c674d74",
"2.4": "660d6c6b3360",
}
)

Expand Down
1 change: 1 addition & 0 deletions argilla-server/src/argilla_server/models/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ class Dataset(DatabaseModel):
allow_extra_metadata: Mapped[bool] = mapped_column(default=True, server_default=sql.true())
status: Mapped[DatasetStatus] = mapped_column(DatasetStatusEnum, default=DatasetStatus.draft, index=True)
distribution: Mapped[dict] = mapped_column(MutableDict.as_mutable(JSON))
metadata_: Mapped[Optional[dict]] = mapped_column("metadata", JSON, nullable=True)
workspace_id: Mapped[UUID] = mapped_column(ForeignKey("workspaces.id", ondelete="CASCADE"), index=True)
inserted_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(default=inserted_at_current_value, onupdate=datetime.utcnow)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
# limitations under the License.

import pytest
from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus
from argilla_server.models import Dataset

from typing import Any
from httpx import AsyncClient
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus
from argilla_server.models import Dataset

from tests.factories import WorkspaceFactory


Expand Down Expand Up @@ -54,6 +57,7 @@ async def test_create_dataset_with_default_distribution(
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 1,
},
"metadata": None,
"workspace_id": str(workspace.id),
"last_activity_at": dataset.last_activity_at.isoformat(),
"inserted_at": dataset.inserted_at.isoformat(),
Expand Down Expand Up @@ -91,6 +95,7 @@ async def test_create_dataset_with_overlap_distribution(
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 4,
},
"metadata": None,
"workspace_id": str(workspace.id),
"last_activity_at": dataset.last_activity_at.isoformat(),
"inserted_at": dataset.inserted_at.isoformat(),
Expand Down Expand Up @@ -137,3 +142,63 @@ async def test_create_dataset_with_invalid_distribution_strategy(

assert response.status_code == 422
assert (await db.execute(select(func.count(Dataset.id)))).scalar_one() == 0

async def test_create_dataset_with_default_metadata(
    self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict
):
    """Creating a dataset without a `metadata` payload stores and returns None."""
    workspace = await WorkspaceFactory.create()

    response = await async_client.post(
        self.url(),
        headers=owner_auth_header,
        json={
            "name": "Dataset Name",
            "workspace_id": str(workspace.id),
        },
    )

    assert response.status_code == 201
    # Use `is None` rather than `== None` (PEP 8: comparisons to singletons).
    assert response.json()["metadata"] is None

    dataset = (await db.execute(select(Dataset))).scalar_one()
    assert dataset.metadata_ is None

async def test_create_dataset_with_custom_metadata(
    self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict
):
    """A `metadata` dict sent at creation time is persisted and echoed back."""
    workspace = await WorkspaceFactory.create()

    response = await async_client.post(
        self.url(),
        headers=owner_auth_header,
        json={
            "name": "Dataset Name",
            "metadata": {"key": "value"},
            "workspace_id": str(workspace.id),
        },
    )

    assert response.status_code == 201

    response_body = response.json()
    assert response_body["metadata"] == {"key": "value"}

    created_dataset = (await db.execute(select(Dataset))).scalar_one()
    assert created_dataset.metadata_ == {"key": "value"}

@pytest.mark.parametrize("invalid_metadata", ["invalid_metadata", 123])
async def test_create_dataset_with_invalid_metadata(
    self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict, invalid_metadata: Any
):
    """Non-object `metadata` payloads are rejected with 422 and nothing is stored."""
    workspace = await WorkspaceFactory.create()

    create_payload = {
        "name": "Dataset Name",
        "metadata": invalid_metadata,
        "workspace_id": str(workspace.id),
    }
    response = await async_client.post(self.url(), headers=owner_auth_header, json=create_payload)

    assert response.status_code == 422

    datasets_count = (await db.execute(select(func.count(Dataset.id)))).scalar_one()
    assert datasets_count == 0
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any
from uuid import UUID

import pytest
Expand Down Expand Up @@ -72,29 +73,6 @@ async def test_update_dataset_without_distribution(self, async_client: AsyncClie
"min_submitted": 1,
}

async def test_update_dataset_without_distribution_for_published_dataset(
self, async_client: AsyncClient, owner_auth_header: dict
):
dataset = await DatasetFactory.create(status=DatasetStatus.ready)

response = await async_client.patch(
self.url(dataset.id),
headers=owner_auth_header,
json={"name": "Dataset updated name"},
)

assert response.status_code == 200
assert response.json()["distribution"] == {
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 1,
}

assert dataset.name == "Dataset updated name"
assert dataset.distribution == {
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 1,
}

async def test_update_dataset_distribution_with_invalid_strategy(
self, async_client: AsyncClient, owner_auth_header: dict
):
Expand Down Expand Up @@ -152,3 +130,81 @@ async def test_update_dataset_distribution_as_none(self, async_client: AsyncClie
"strategy": DatasetDistributionStrategy.overlap,
"min_submitted": 1,
}

async def test_update_dataset_metadata(self, async_client: AsyncClient, owner_auth_header: dict):
    """PATCHing `metadata` merges the new keys into the stored dict (no full replace)."""
    dataset = await DatasetFactory.create(metadata_={"key-a": "value-a", "key-b": "value-b"})

    response = await async_client.patch(
        self.url(dataset.id),
        headers=owner_auth_header,
        json={
            "metadata": {
                "key-a": "value-a-updated",
                "key-c": "value-c",
            },
        },
    )

    # Untouched keys survive, updated keys change, and new keys are added.
    merged_metadata = {
        "key-a": "value-a-updated",
        "key-b": "value-b",
        "key-c": "value-c",
    }

    assert response.status_code == 200
    assert response.json()["metadata"] == merged_metadata
    assert dataset.metadata_ == merged_metadata

async def test_update_dataset_without_metadata(self, async_client: AsyncClient, owner_auth_header: dict):
    """Omitting `metadata` from the PATCH payload leaves the stored value untouched."""
    dataset = await DatasetFactory.create(metadata_={"key": "value"})

    response = await async_client.patch(
        self.url(dataset.id),
        headers=owner_auth_header,
        json={"name": "Dataset updated name"},
    )

    assert response.status_code == 200

    response_body = response.json()
    assert response_body["metadata"] == {"key": "value"}

    assert dataset.name == "Dataset updated name"
    assert dataset.metadata_ == {"key": "value"}

async def test_update_dataset_with_invalid_metadata(self, async_client: AsyncClient, owner_auth_header: dict):
    """A non-object `metadata` value is rejected with 422 and the stored dict is kept."""
    dataset = await DatasetFactory.create(metadata_={"key": "value"})

    update_payload = {"metadata": "invalid_metadata"}
    response = await async_client.patch(
        self.url(dataset.id),
        headers=owner_auth_header,
        json=update_payload,
    )

    assert response.status_code == 422
    assert dataset.metadata_ == {"key": "value"}

async def test_update_dataset_metadata_as_empty_dict(self, async_client: AsyncClient, owner_auth_header: dict):
    """PATCHing with an empty `metadata` dict merges nothing — existing keys remain."""
    dataset = await DatasetFactory.create(metadata_={"key": "value"})

    response = await async_client.patch(
        self.url(dataset.id),
        headers=owner_auth_header,
        json={"metadata": {}},
    )

    assert response.status_code == 200
    assert dataset.metadata_ == {"key": "value"}

async def test_update_dataset_metadata_as_none(self, async_client: AsyncClient, owner_auth_header: dict):
    """Explicitly sending `"metadata": null` clears the stored metadata."""
    dataset = await DatasetFactory.create(metadata_={"key": "value"})

    response = await async_client.patch(
        self.url(dataset.id),
        headers=owner_auth_header,
        json={"metadata": None},
    )

    assert response.status_code == 200
    # Use `is None` rather than `== None` (PEP 8: comparisons to singletons).
    assert dataset.metadata_ is None
Loading
Loading