Skip to content

Commit

Permalink
[ENHANCEMENT] argilla server: Return users on dataset progress (#5701)
Browse files Browse the repository at this point in the history
# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

This PR adds support for returning a list of usernames in the dataset
progress endpoint.

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Refactor (change restructuring the codebase without changing
functionality)
- Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm my changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: José Francisco Calvo <[email protected]>
  • Loading branch information
frascuchon and jfcalvo committed Nov 27, 2024
1 parent b849166 commit 1338444
Show file tree
Hide file tree
Showing 15 changed files with 306 additions and 61 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""add datasets_users table
Revision ID: 580a6553186f
Revises: 6ed1b8bf8e08
Create Date: 2024-11-20 12:15:24.631417
"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "580a6553186f"
down_revision = "6ed1b8bf8e08"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Create the ``datasets_users`` join table and backfill it.

    The table links users to the datasets they have responded to, with a
    composite primary key on (dataset_id, user_id) and cascading deletes so
    rows disappear together with their dataset or user. After creating the
    table, backfill one row per distinct (dataset_id, user_id) pair derived
    from existing responses, so historical annotators are represented.

    Raises:
        Exception: if the connected database dialect is neither PostgreSQL
            nor SQLite (the message names the offending dialect).
    """
    op.create_table(
        "datasets_users",
        sa.Column("dataset_id", sa.Uuid(), nullable=False),
        sa.Column("user_id", sa.Uuid(), nullable=False),
        sa.Column("inserted_at", sa.DateTime(), nullable=False),
        sa.Column("updated_at", sa.DateTime(), nullable=False),
        sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("dataset_id", "user_id"),
    )
    op.create_index(op.f("ix_datasets_users_dataset_id"), "datasets_users", ["dataset_id"], unique=False)
    op.create_index(op.f("ix_datasets_users_user_id"), "datasets_users", ["user_id"], unique=False)

    bind = op.get_bind()

    # DISTINCT in the subquery collapses multiple responses by the same user
    # on the same dataset into a single join-table row.
    statement = """
        INSERT INTO datasets_users (dataset_id, user_id, inserted_at, updated_at)
        SELECT dataset_id, user_id, {now_func}, {now_func} FROM (
            SELECT DISTINCT records.dataset_id AS dataset_id, responses.user_id as user_id
            FROM responses
            JOIN records ON records.id = responses.record_id
        ) AS subquery
    """

    # Each supported dialect spells "current timestamp" differently; fail
    # loudly — naming the dialect — rather than silently skipping the backfill.
    now_func_by_dialect = {"postgresql": "NOW()", "sqlite": "datetime('now')"}
    now_func = now_func_by_dialect.get(bind.dialect.name)
    if now_func is None:
        raise Exception(f"Unsupported database dialect: {bind.dialect.name!r}")

    op.execute(statement.format(now_func=now_func))


def downgrade() -> None:
    """Revert the migration: drop both indexes, then the join table itself."""
    # Indexes must go before the table they index.
    for index_name in ("ix_datasets_users_user_id", "ix_datasets_users_dataset_id"):
        op.drop_index(op.f(index_name), table_name="datasets_users")
    op.drop_table("datasets_users")
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,12 @@ async def get_current_user_dataset_metrics(

await authorize(current_user, DatasetPolicy.get(dataset))

result = await datasets.get_user_dataset_metrics(search_engine, current_user, dataset)
result = await datasets.get_user_dataset_metrics(db, search_engine, current_user, dataset)

return DatasetMetrics(responses=result)


@router.get("/datasets/{dataset_id}/progress", response_model=DatasetProgress)
@router.get("/datasets/{dataset_id}/progress", response_model=DatasetProgress, response_model_exclude_unset=True)
async def get_dataset_progress(
*,
dataset_id: UUID,
Expand All @@ -171,7 +171,7 @@ async def get_dataset_progress(

await authorize(current_user, DatasetPolicy.get(dataset))

result = await datasets.get_dataset_progress(search_engine, dataset)
result = await datasets.get_dataset_progress(db, search_engine, dataset)

return DatasetProgress(**result)

Expand All @@ -181,14 +181,13 @@ async def get_dataset_users_progress(
*,
dataset_id: UUID,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
current_user: User = Security(auth.get_current_user),
):
dataset = await Dataset.get_or_raise(db, dataset_id)

await authorize(current_user, DatasetPolicy.get(dataset))

progress = await datasets.get_dataset_users_progress(dataset.id)
progress = await datasets.get_dataset_users_progress(db, dataset)

return UsersProgress(users=progress)

Expand Down
15 changes: 9 additions & 6 deletions argilla-server/src/argilla_server/api/schemas/v1/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,6 @@ class DatasetMetrics(BaseModel):
responses: ResponseMetrics


class DatasetProgress(BaseModel):
total: int
completed: int
pending: int


class RecordResponseDistribution(BaseModel):
submitted: int = 0
discarded: int = 0
Expand All @@ -101,6 +95,15 @@ class UserProgress(BaseModel):
completed: RecordResponseDistribution = RecordResponseDistribution()
pending: RecordResponseDistribution = RecordResponseDistribution()

model_config = ConfigDict(from_attributes=True)


class DatasetProgress(BaseModel):
total: int
completed: int
pending: int
users: List[UserProgress] = Field(default_factory=list)


class UsersProgress(BaseModel):
users: List[UserProgress]
Expand Down
10 changes: 10 additions & 0 deletions argilla-server/src/argilla_server/bulk/records_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from argilla_server.api.schemas.v1.responses import UserResponseCreate
from argilla_server.api.schemas.v1.suggestions import SuggestionCreate
from argilla_server.models.database import DatasetUser
from argilla_server.webhooks.v1.enums import RecordEvent
from argilla_server.webhooks.v1.records import notify_record_event as notify_record_event_v1
from argilla_server.contexts import distribution
Expand Down Expand Up @@ -109,13 +110,22 @@ async def _upsert_records_responses(
self, records_and_responses: List[Tuple[Record, List[UserResponseCreate]]]
) -> List[Response]:
upsert_many_responses = []
datasets_users = set()
for idx, (record, responses) in enumerate(records_and_responses):
for response_create in responses or []:
upsert_many_responses.append(dict(**response_create.model_dump(), record_id=record.id))
datasets_users.add((response_create.user_id, record.dataset_id))

if not upsert_many_responses:
return []

await DatasetUser.upsert_many(
self._db,
objects=[{"user_id": user_id, "dataset_id": dataset_id} for user_id, dataset_id in datasets_users],
constraints=[DatasetUser.user_id, DatasetUser.dataset_id],
autocommit=False,
)

return await Response.upsert_many(
self._db,
objects=upsert_many_responses,
Expand Down
85 changes: 41 additions & 44 deletions argilla-server/src/argilla_server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
ResponseCreate,
ResponseUpdate,
ResponseUpsert,
UserResponseCreate,
)
from argilla_server.api.schemas.v1.vector_settings import (
VectorSettings as VectorSettingsSchema,
Expand All @@ -58,6 +57,7 @@
VectorSettingsCreate,
)
from argilla_server.api.schemas.v1.vectors import Vector as VectorSchema
from argilla_server.models.database import DatasetUser
from argilla_server.webhooks.v1.enums import DatasetEvent, ResponseEvent, RecordEvent
from argilla_server.webhooks.v1.records import (
build_record_event as build_record_event_v1,
Expand Down Expand Up @@ -391,11 +391,12 @@ async def _configure_query_relationships(


async def get_user_dataset_metrics(
db: AsyncSession,
search_engine: SearchEngine,
user: User,
dataset: Dataset,
) -> dict:
total_records = (await get_dataset_progress(search_engine, dataset))["total"]
total_records = (await get_dataset_progress(db, search_engine, dataset))["total"]
result = await search_engine.get_dataset_user_progress(dataset, user)

submitted_responses = result.get("submitted", 0)
Expand All @@ -413,34 +414,52 @@ async def get_user_dataset_metrics(


async def get_dataset_progress(
db: AsyncSession,
search_engine: SearchEngine,
dataset: Dataset,
) -> dict:
result = await search_engine.get_dataset_progress(dataset)
users = await get_users_with_responses_for_dataset(db, dataset)

return {
"total": result.get("total", 0),
"completed": result.get("completed", 0),
"pending": result.get("pending", 0),
"users": users,
}


async def get_dataset_users_progress(dataset_id: UUID) -> List[dict]:
async def get_users_with_responses_for_dataset(
db: AsyncSession,
dataset: Dataset,
) -> Sequence[User]:
query = (
select(DatasetUser)
.filter_by(dataset_id=dataset.id)
.options(selectinload(DatasetUser.user))
.order_by(DatasetUser.inserted_at.asc())
)

result = await db.scalars(query)
return [r.user for r in result.all()]


async def get_dataset_users_progress(db: AsyncSession, dataset: Dataset) -> List[dict]:
query = (
select(User.username, Record.status, Response.status, func.count(Response.id))
.join(Record)
.join(User)
.where(Record.dataset_id == dataset_id)
.where(Record.dataset_id == dataset.id)
.group_by(User.username, Record.status, Response.status)
)

async for session in get_async_db():
annotators_progress = defaultdict(lambda: defaultdict(dict))
results = (await session.execute(query)).all()
annotators_progress = defaultdict(lambda: defaultdict(dict))
results = (await db.execute(query)).all()

for username, record_status, response_status, count in results:
annotators_progress[username][record_status][response_status] = count
for username, record_status, response_status, count in results:
annotators_progress[username][record_status][response_status] = count

return [{"username": username, **progress} for username, progress in annotators_progress.items()]
return [{"username": username, **progress} for username, progress in annotators_progress.items()]


_EXTRA_METADATA_FLAG = "extra"
Expand Down Expand Up @@ -567,38 +586,6 @@ async def _validate_record_metadata(
raise UnprocessableEntityError(f"metadata is not valid: {e}") from e


async def _build_record_responses(
db: AsyncSession,
record: Record,
responses_create: Optional[List[UserResponseCreate]],
cache: Optional[Set[UUID]] = None,
) -> List[Response]:
"""Create responses for a record."""
if not responses_create:
return []

responses = []

for idx, response_create in enumerate(responses_create):
try:
cache = await validate_user_exists(db, response_create.user_id, cache)

ResponseCreateValidator.validate(response_create, record)

responses.append(
Response(
values=jsonable_encoder(response_create.values),
status=response_create.status,
user_id=response_create.user_id,
record=record,
)
)
except (UnprocessableEntityError, ValueError) as e:
raise UnprocessableEntityError(f"response at position {idx} is not valid: {e}") from e

return responses


async def _build_record_suggestions(
db: AsyncSession,
record: Record,
Expand Down Expand Up @@ -830,8 +817,13 @@ async def create_response(
user_id=user.id,
autocommit=False,
)

await _touch_dataset_last_activity_at(db, record.dataset)
await DatasetUser.upsert(
db,
schema={"dataset_id": record.dataset_id, "user_id": user.id},
constraints=[DatasetUser.dataset_id, DatasetUser.user_id],
autocommit=False,
)

await db.commit()

Expand Down Expand Up @@ -888,7 +880,12 @@ async def upsert_response(
autocommit=False,
)
await _touch_dataset_last_activity_at(db, response.record.dataset)

await DatasetUser.upsert(
db,
schema={"dataset_id": record.dataset_id, "user_id": user.id},
constraints=[DatasetUser.dataset_id, DatasetUser.user_id],
autocommit=False,
)
await db.commit()

await distribution.update_record_status(search_engine, record.id)
Expand Down
1 change: 1 addition & 0 deletions argilla-server/src/argilla_server/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"2.0": "237f7c674d74",
"2.4": "660d6c6b3360",
"2.5": "6ed1b8bf8e08",
"2.6": "580a6553186f",
}
)

Expand Down
Loading

0 comments on commit 1338444

Please sign in to comment.