diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records_bulk.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records_bulk.py index 69cc536a0f..9aef5ca0ac 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records_bulk.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records_bulk.py @@ -26,7 +26,6 @@ from argilla_server.models import Dataset, User from argilla_server.search_engine import SearchEngine, get_search_engine from argilla_server.security import auth -from argilla_server.telemetry import TelemetryClient, get_telemetry_client router = APIRouter() @@ -43,7 +42,6 @@ async def create_dataset_records_bulk( db: AsyncSession = Depends(get_async_db), search_engine: SearchEngine = Depends(get_search_engine), current_user: User = Security(auth.get_current_user), - telemetry_client: TelemetryClient = Depends(get_telemetry_client), ): dataset = await Dataset.get_or_raise( db, @@ -58,9 +56,7 @@ async def create_dataset_records_bulk( await authorize(current_user, DatasetPolicy.create_records(dataset)) - records_bulk = await CreateRecordsBulk(db, search_engine).create_records_bulk(dataset, records_bulk_create) - - return records_bulk + return await CreateRecordsBulk(db, search_engine).create_records_bulk(dataset, records_bulk_create) @router.put("/datasets/{dataset_id}/records/bulk", response_model=RecordsBulk) @@ -71,7 +67,6 @@ async def upsert_dataset_records_bulk( db: AsyncSession = Depends(get_async_db), search_engine: SearchEngine = Depends(get_search_engine), current_user: User = Security(auth.get_current_user), - telemetry_client: TelemetryClient = Depends(get_telemetry_client), ): dataset = await Dataset.get_or_raise( db, @@ -86,9 +81,4 @@ async def upsert_dataset_records_bulk( await authorize(current_user, DatasetPolicy.upsert_records(dataset)) - records_bulk = await UpsertRecordsBulk(db, search_engine).upsert_records_bulk(dataset, records_bulk_create) - - updated = len(records_bulk.updated_item_ids) - created = len(records_bulk.items) - updated - - return records_bulk + return await UpsertRecordsBulk(db, search_engine).upsert_records_bulk(dataset, records_bulk_create) diff --git a/argilla-server/src/argilla_server/bulk/records_bulk.py b/argilla-server/src/argilla_server/bulk/records_bulk.py index babaa2b70c..7af00f3443 100644 --- a/argilla-server/src/argilla_server/bulk/records_bulk.py +++ b/argilla-server/src/argilla_server/bulk/records_bulk.py @@ -220,7 +220,7 @@ async def upsert_records_bulk(self, dataset: Dataset, bulk_upsert: RecordsBulkUp await self._db.commit() - await self._notify_record_events(records) + await self._notify_upsert_record_events(records) return RecordsBulkWithUpdateInfo( items=records, @@ -241,7 +241,7 @@ async def _fetch_existing_dataset_records( return {**records_by_external_id, **records_by_id} - async def _notify_record_events(self, records: List[Record]) -> None: + async def _notify_upsert_record_events(self, records: List[Record]) -> None: for record in records: if record.inserted_at == record.updated_at: await notify_record_event_v1(self._db, RecordEvent.created, record) diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index dfdcad6ca7..56f91804f6 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -827,7 +827,7 @@ async def delete_records( ) -> None: params = [Record.id.in_(records_ids), Record.dataset_id == dataset.id] - records = (await db.execute(select(Record).filter(*params))).scalars().all() + records = (await db.execute(select(Record).filter(*params).order_by(Record.inserted_at.asc()))).scalars().all() deleted_record_events_v1 = [] for record in records: diff --git a/argilla-server/tests/conftest.py b/argilla-server/tests/conftest.py index 67d704bf2c..55b4a53af5 100644 --- a/argilla-server/tests/conftest.py +++ b/argilla-server/tests/conftest.py @@ -12,17 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -from typing import TYPE_CHECKING, AsyncGenerator, Generator - import httpx +import asyncio import pytest import pytest_asyncio + +from rq import Queue +from typing import TYPE_CHECKING, AsyncGenerator, Generator +from sqlalchemy import NullPool, create_engine +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine + from argilla_server.cli.database.migrate import migrate_db from argilla_server.database import database_url_sync +from argilla_server.jobs.queues import REDIS_CONNECTION from argilla_server.settings import settings -from sqlalchemy import NullPool, create_engine -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from tests.database import SyncTestSession, TestSession, set_task @@ -97,6 +100,16 @@ def sync_db(sync_connection: "Connection") -> Generator["Session", None, None]: sync_connection.rollback() +@pytest.fixture(autouse=True) +def empty_job_queues(): + queues = Queue.all(connection=REDIS_CONNECTION) + + for queue in queues: + queue.empty() + + yield + + @pytest.fixture def async_db_proxy(mocker: "MockerFixture", sync_db: "Session") -> "AsyncSession": """Create a mocked `AsyncSession` that proxies to the sync session. This will allow us to execute the async CLI commands diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_bulk.py b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_bulk.py index 4a40ed48a1..21831a1c8c 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_bulk.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_bulk.py @@ -12,15 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from uuid import UUID - import pytest -from argilla_server.enums import DatasetStatus, QuestionType, ResponseStatus, SuggestionType -from argilla_server.models.database import Record, Response, Suggestion, User + +from uuid import UUID from httpx import AsyncClient +from fastapi.encoders import jsonable_encoder from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession +from argilla_server.jobs.queues import HIGH_QUEUE +from argilla_server.webhooks.v1.enums import RecordEvent +from argilla_server.webhooks.v1.records import build_record_event +from argilla_server.models.database import Record, Response, Suggestion, User +from argilla_server.enums import DatasetStatus, QuestionType, ResponseStatus, SuggestionType + from tests.factories import ( DatasetFactory, LabelSelectionQuestionFactory, @@ -32,6 +37,7 @@ ImageFieldFactory, TextQuestionFactory, ChatFieldFactory, + WebhookFactory, ) @@ -551,3 +557,48 @@ async def test_create_dataset_records_bulk_with_chat_field_without_content_key( } } assert (await db.execute(select(func.count(Record.id)))).scalar_one() == 0 + + async def test_create_dataset_records_bulk_enqueue_webhook_record_created_events( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create(status=DatasetStatus.ready) + await TextFieldFactory.create(name="prompt", dataset=dataset) + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + webhook = await WebhookFactory.create(events=[RecordEvent.created]) + + response = await async_client.post( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "items": [ + { + "fields": { + "prompt": "Does exercise help reduce stress?", + }, + }, + { + "fields": { + "prompt": "What is the best way to reduce stress?", + }, + }, + ], + }, + ) + + assert response.status_code == 201 + + records = (await db.execute(select(Record).order_by(Record.inserted_at.asc()))).scalars().all() + + event_a = await build_record_event(db, RecordEvent.created, records[0]) + event_b = await build_record_event(db, RecordEvent.created, records[1]) + + assert HIGH_QUEUE.count == 2 + + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.created + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event_a.data) + + assert HIGH_QUEUE.jobs[1].args[0] == webhook.id + assert HIGH_QUEUE.jobs[1].args[1] == RecordEvent.created + assert HIGH_QUEUE.jobs[1].args[3] == jsonable_encoder(event_b.data) diff --git a/argilla-server/tests/unit/api/handlers/v1/records/test_upsert_dataset_records_bulk.py b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_upsert_dataset_records_bulk.py similarity index 58% rename from argilla-server/tests/unit/api/handlers/v1/records/test_upsert_dataset_records_bulk.py rename to argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_upsert_dataset_records_bulk.py index 82b035a58a..73c37c5bb1 100644 --- a/argilla-server/tests/unit/api/handlers/v1/records/test_upsert_dataset_records_bulk.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_upsert_dataset_records_bulk.py @@ -16,11 +16,26 @@ from uuid import UUID from httpx import AsyncClient +from fastapi.encoders import jsonable_encoder +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession -from argilla_server.models import User + +from argilla_server.models import User, Record +from argilla_server.jobs.queues import HIGH_QUEUE from argilla_server.enums import DatasetDistributionStrategy, ResponseStatus, DatasetStatus, RecordStatus +from argilla_server.webhooks.v1.enums import RecordEvent +from argilla_server.webhooks.v1.records import build_record_event -from tests.factories import DatasetFactory, RecordFactory, TextQuestionFactory, ResponseFactory, AnnotatorFactory +from tests.factories import ( + DatasetFactory, + RecordFactory, + TextFieldFactory, + TextQuestionFactory, + AnnotatorFactory, + WebhookFactory, + ResponseFactory, +) @pytest.mark.asyncio @@ -151,3 +166,95 @@ async def test_upsert_dataset_records_bulk_updates_records_status( assert record_b.status == RecordStatus.pending assert record_c.status == RecordStatus.pending assert record_d.status == RecordStatus.pending + + async def test_upsert_dataset_records_bulk_enqueue_webhook_record_created_events( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create(status=DatasetStatus.ready) + await TextFieldFactory.create(name="prompt", dataset=dataset) + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + webhook = await WebhookFactory.create(events=[RecordEvent.created, RecordEvent.updated]) + + response = await async_client.put( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "items": [ + { + "fields": { + "prompt": "Does exercise help reduce stress?", + }, + }, + { + "fields": { + "prompt": "What is the best way to reduce stress?", + }, + }, + ], + }, + ) + + assert response.status_code == 200 + + records = (await db.execute(select(Record).order_by(Record.inserted_at.asc()))).scalars().all() + + event_a = await build_record_event(db, RecordEvent.created, records[0]) + event_b = await build_record_event(db, RecordEvent.created, records[1]) + + assert HIGH_QUEUE.count == 2 + + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.created + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event_a.data) + + assert HIGH_QUEUE.jobs[1].args[0] == webhook.id + assert HIGH_QUEUE.jobs[1].args[1] == RecordEvent.created + assert HIGH_QUEUE.jobs[1].args[3] == jsonable_encoder(event_b.data) + + async def test_upsert_dataset_records_bulk_enqueue_webhook_record_updated_events( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create(status=DatasetStatus.ready) + await TextFieldFactory.create(name="prompt", dataset=dataset) + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + records = await RecordFactory.create_batch(2, dataset=dataset) + + webhook = await WebhookFactory.create(events=[RecordEvent.created, RecordEvent.updated]) + + response = await async_client.put( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "items": [ + { + "id": str(records[0].id), + "metadata": { + "metadata-key": "metadata-value", + }, + }, + { + "id": str(records[1].id), + "metadata": { + "metadata-key": "metadata-value", + }, + }, + ], + }, + ) + + assert response.status_code == 200 + + event_a = await build_record_event(db, RecordEvent.updated, records[0]) + event_b = await build_record_event(db, RecordEvent.updated, records[1]) + + assert HIGH_QUEUE.count == 2 + + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.updated + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event_a.data) + + assert HIGH_QUEUE.jobs[1].args[0] == webhook.id + assert HIGH_QUEUE.jobs[1].args[1] == RecordEvent.updated + assert HIGH_QUEUE.jobs[1].args[3] == jsonable_encoder(event_b.data) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/records/test_delete_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/datasets/records/test_delete_dataset_records.py new file mode 100644 index 0000000000..19773b3512 --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/records/test_delete_dataset_records.py @@ -0,0 +1,60 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from uuid import UUID +from httpx import AsyncClient +from fastapi.encoders import jsonable_encoder +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.jobs.queues import HIGH_QUEUE +from argilla_server.webhooks.v1.enums import RecordEvent +from argilla_server.webhooks.v1.records import build_record_event + +from tests.factories import DatasetFactory, RecordFactory, WebhookFactory + + +@pytest.mark.asyncio +class TestDeleteDatasetRecords: + def url(self, dataset_id: UUID) -> str: + return f"/api/v1/datasets/{dataset_id}/records" + + async def test_delete_dataset_records_enqueue_webhook_record_deleted_events( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create() + records = await RecordFactory.create_batch(2, dataset=dataset) + webhook = await WebhookFactory.create(events=[RecordEvent.deleted]) + + event_a = await build_record_event(db, RecordEvent.deleted, records[0]) + event_b = await build_record_event(db, RecordEvent.deleted, records[1]) + + response = await async_client.delete( + self.url(dataset.id), + headers=owner_auth_header, + params={"ids": f"{records[0].id},{records[1].id}"}, + ) + + assert response.status_code == 204 + + assert HIGH_QUEUE.count == 2 + + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.deleted + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event_a.data) + + assert HIGH_QUEUE.jobs[1].args[0] == webhook.id + assert HIGH_QUEUE.jobs[1].args[1] == RecordEvent.deleted + assert HIGH_QUEUE.jobs[1].args[3] == jsonable_encoder(event_b.data) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py index 4261145d0c..567900a529 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py @@ -13,13 +13,19 @@ # limitations under the License. import pytest -from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus -from argilla_server.models import Dataset + from httpx import AsyncClient from sqlalchemy import func, select +from fastapi.encoders import jsonable_encoder from sqlalchemy.ext.asyncio import AsyncSession -from tests.factories import WorkspaceFactory +from argilla_server.models import Dataset +from argilla_server.jobs.queues import HIGH_QUEUE +from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus +from argilla_server.webhooks.v1.enums import DatasetEvent +from argilla_server.webhooks.v1.datasets import build_dataset_event + +from tests.factories import WebhookFactory, WorkspaceFactory @pytest.mark.asyncio @@ -137,3 +143,28 @@ async def test_create_dataset_with_invalid_distribution_strategy( assert response.status_code == 422 assert (await db.execute(select(func.count(Dataset.id)))).scalar_one() == 0 + + async def test_create_dataset_enqueue_webhook_dataset_created_event( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace = await WorkspaceFactory.create() + webhook = await WebhookFactory.create(events=[DatasetEvent.created]) + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "name": "Dataset Name", + "workspace_id": str(workspace.id), + }, + ) + + assert response.status_code == 201 + + dataset = (await db.execute(select(Dataset))).scalar_one() + event = await build_dataset_event(db, DatasetEvent.created, dataset) + + assert HIGH_QUEUE.count == 1 + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == DatasetEvent.created + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event.data) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_delete_dataset.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_delete_dataset.py new file mode 100644 index 0000000000..d01feabcbf --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_delete_dataset.py @@ -0,0 +1,52 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from uuid import UUID +from httpx import AsyncClient +from fastapi.encoders import jsonable_encoder +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.jobs.queues import HIGH_QUEUE +from argilla_server.webhooks.v1.enums import DatasetEvent +from argilla_server.webhooks.v1.datasets import build_dataset_event + +from tests.factories import DatasetFactory, WebhookFactory + + +@pytest.mark.asyncio +class TestDeleteDataset: + def url(self, dataset_id: UUID) -> str: + return f"/api/v1/datasets/{dataset_id}" + + async def test_delete_dataset_enqueue_webhook_dataset_deleted_event( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create() + webhook = await WebhookFactory.create(events=[DatasetEvent.deleted]) + + event = await build_dataset_event(db, DatasetEvent.deleted, dataset) + + response = await async_client.delete( + self.url(dataset.id), + headers=owner_auth_header, + ) + + assert response.status_code == 200 + + assert HIGH_QUEUE.count == 1 + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == DatasetEvent.deleted + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event.data) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_publish_dataset.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_publish_dataset.py new file mode 100644 index 0000000000..9fb3a11481 --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_publish_dataset.py @@ -0,0 +1,55 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from uuid import UUID +from httpx import AsyncClient +from fastapi.encoders import jsonable_encoder +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.jobs.queues import HIGH_QUEUE +from argilla_server.webhooks.v1.enums import DatasetEvent +from argilla_server.webhooks.v1.datasets import build_dataset_event + +from tests.factories import DatasetFactory, TextFieldFactory, RatingQuestionFactory, WebhookFactory + + +@pytest.mark.asyncio +class TestPublishDataset: + def url(self, dataset_id: UUID) -> str: + return f"/api/v1/datasets/{dataset_id}/publish" + + async def test_publish_dataset_enqueue_webhook_dataset_published_event( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create() + await TextFieldFactory.create(dataset=dataset, required=True) + await RatingQuestionFactory.create(dataset=dataset, required=True) + + webhook = await WebhookFactory.create(events=[DatasetEvent.published]) + + response = await async_client.put( + self.url(dataset.id), + headers=owner_auth_header, + ) + + assert response.status_code == 200 + + event = await build_dataset_event(db, DatasetEvent.published, dataset) + + assert HIGH_QUEUE.count == 1 + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == DatasetEvent.published + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event.data) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py index 91c4f54f17..e3dbf842bb 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py @@ -12,13 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from uuid import UUID - import pytest + +from uuid import UUID from httpx import AsyncClient +from fastapi.encoders import jsonable_encoder +from sqlalchemy.ext.asyncio import AsyncSession +from argilla_server.jobs.queues import HIGH_QUEUE from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus -from tests.factories import DatasetFactory, RecordFactory, ResponseFactory +from argilla_server.webhooks.v1.datasets import build_dataset_event +from argilla_server.webhooks.v1.enums import DatasetEvent + +from tests.factories import DatasetFactory, RecordFactory, ResponseFactory, WebhookFactory @pytest.mark.asyncio @@ -152,3 +158,24 @@ async def test_update_dataset_distribution_as_none(self, async_client: AsyncClie "strategy": DatasetDistributionStrategy.overlap, "min_submitted": 1, } + + async def test_update_dataset_enqueue_webhook_dataset_updated_event( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create() + webhook = await WebhookFactory.create(events=[DatasetEvent.updated]) + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={"name": "Updated dataset"}, + ) + + assert response.status_code == 200 + + event = await build_dataset_event(db, DatasetEvent.updated, dataset) + + assert HIGH_QUEUE.count == 1 + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == DatasetEvent.updated + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event.data) diff --git a/argilla-server/tests/unit/api/handlers/v1/records/test_create_record_response.py b/argilla-server/tests/unit/api/handlers/v1/records/test_create_record_response.py index ce433d036d..68fbd93685 100644 --- a/argilla-server/tests/unit/api/handlers/v1/records/test_create_record_response.py +++ b/argilla-server/tests/unit/api/handlers/v1/records/test_create_record_response.py @@ -12,19 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from datetime import datetime -from uuid import UUID - import pytest +from uuid import UUID +from datetime import datetime from httpx import AsyncClient from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession +from fastapi.encoders import jsonable_encoder -from argilla_server.enums import ResponseStatus, RecordStatus, DatasetDistributionStrategy from argilla_server.models import Response, User +from argilla_server.jobs.queues import HIGH_QUEUE +from argilla_server.webhooks.v1.enums import RecordEvent, ResponseEvent +from argilla_server.webhooks.v1.responses import build_response_event +from argilla_server.webhooks.v1.records import build_record_event +from argilla_server.enums import ResponseStatus, RecordStatus, DatasetDistributionStrategy -from tests.factories import DatasetFactory, RecordFactory, SpanQuestionFactory, TextQuestionFactory +from tests.factories import DatasetFactory, RecordFactory, SpanQuestionFactory, TextQuestionFactory, WebhookFactory @pytest.mark.asyncio @@ -516,3 +520,118 @@ async def test_create_record_response_does_not_updates_record_status_to_complete assert response.status_code == 201 assert record.status == RecordStatus.pending + + async def test_create_record_response_enqueue_webhook_response_created_event( + self, db: AsyncSession, async_client: AsyncClient, owner: User, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 2, + } + ) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset) + + webhook = await WebhookFactory.create(events=[ResponseEvent.created]) + + resp = await async_client.post( + self.url(record.id), + headers=owner_auth_header, + json={ + "values": { + "text-question": { + "value": "text question response", + }, + }, + "status": ResponseStatus.submitted, + }, + ) + + assert resp.status_code == 201 + + response = (await db.execute(select(Response))).scalar_one() + event = await build_response_event(db, ResponseEvent.created, response) + + assert HIGH_QUEUE.count == 1 + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == ResponseEvent.created + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event.data) + + async def test_create_record_response_enqueue_webhook_record_updated_event( + self, db: AsyncSession, async_client: AsyncClient, owner: User, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + ) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset) + + webhook = await WebhookFactory.create(events=[RecordEvent.updated]) + + response = await async_client.post( + self.url(record.id), + headers=owner_auth_header, + json={ + "values": { + "text-question": { + "value": "text question response", + }, + }, + "status": ResponseStatus.submitted, + }, + ) + + assert response.status_code == 201 + + event = await build_record_event(db, RecordEvent.updated, record) + + assert HIGH_QUEUE.count == 1 + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.updated + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event.data) + + async def test_create_record_response_enqueue_webhook_record_completed_event( + self, db: AsyncSession, async_client: AsyncClient, owner: User, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + ) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset) + + webhook = await WebhookFactory.create(events=[RecordEvent.completed]) + + response = await async_client.post( + self.url(record.id), + headers=owner_auth_header, + json={ + "values": { + "text-question": { + "value": "text question response", + }, + }, + "status": ResponseStatus.submitted, + }, + ) + + assert response.status_code == 201 + + event = await build_record_event(db, RecordEvent.completed, record) + + assert HIGH_QUEUE.count == 1 + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.completed + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event.data) diff --git a/argilla-server/tests/unit/api/handlers/v1/records/test_delete_record.py b/argilla-server/tests/unit/api/handlers/v1/records/test_delete_record.py new file mode 100644 index 0000000000..ab017e50fa --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/records/test_delete_record.py @@ -0,0 +1,52 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from uuid import UUID +from httpx import AsyncClient +from fastapi.encoders import jsonable_encoder +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.jobs.queues import HIGH_QUEUE +from argilla_server.webhooks.v1.enums import RecordEvent +from argilla_server.webhooks.v1.records import build_record_event + +from tests.factories import RecordFactory, WebhookFactory + + +@pytest.mark.asyncio +class TestDeleteRecord: + def url(self, record_id: UUID) -> str: + return f"/api/v1/records/{record_id}" + + async def test_delete_record_enqueue_webhook_record_deleted_event( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + record = await RecordFactory.create() + webhook = await WebhookFactory.create(events=[RecordEvent.deleted]) + + event = await build_record_event(db, RecordEvent.deleted, record) + + response = await async_client.delete( + self.url(record.id), + headers=owner_auth_header, + ) + + assert response.status_code == 200 + + assert HIGH_QUEUE.count == 1 + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.deleted + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event.data) diff --git a/argilla-server/tests/unit/api/handlers/v1/records/test_update_record.py b/argilla-server/tests/unit/api/handlers/v1/records/test_update_record.py new file mode 100644 index 0000000000..d8eba32655 --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/records/test_update_record.py @@ -0,0 +1,53 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from uuid import UUID +from httpx import AsyncClient +from fastapi.encoders import jsonable_encoder +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.jobs.queues import HIGH_QUEUE +from argilla_server.webhooks.v1.enums import RecordEvent +from argilla_server.webhooks.v1.records import build_record_event + +from tests.factories import RecordFactory, WebhookFactory + + +@pytest.mark.asyncio +class TestUpdateRecord: + def url(self, record_id: UUID) -> str: + return f"/api/v1/records/{record_id}" + + async def test_update_record_enqueue_webhook_record_updated_event( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + record = await RecordFactory.create() + webhook = await WebhookFactory.create(events=[RecordEvent.updated]) + + response = await async_client.patch( + self.url(record.id), + headers=owner_auth_header, + json={}, + ) + + assert response.status_code == 200 + + event = await build_record_event(db, RecordEvent.updated, record) + + assert HIGH_QUEUE.count == 1 + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.updated + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event.data) diff --git a/argilla-server/tests/unit/api/handlers/v1/responses/test_create_current_user_responses_bulk.py b/argilla-server/tests/unit/api/handlers/v1/responses/test_create_current_user_responses_bulk.py index 07b4bf0199..3cfe3fb7a4 100644 --- a/argilla-server/tests/unit/api/handlers/v1/responses/test_create_current_user_responses_bulk.py +++ b/argilla-server/tests/unit/api/handlers/v1/responses/test_create_current_user_responses_bulk.py @@ -11,28 +11,36 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os +import pytest + +from uuid import UUID, uuid4 from datetime import datetime from unittest.mock import call -from uuid import UUID, uuid4 +from httpx import AsyncClient +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession +from fastapi.encoders import jsonable_encoder -import pytest from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import ResponseStatus, RecordStatus +from argilla_server.enums import DatasetDistributionStrategy, ResponseStatus, RecordStatus +from argilla_server.jobs.queues import HIGH_QUEUE from argilla_server.models import Response, User from argilla_server.search_engine import SearchEngine from argilla_server.use_cases.responses.upsert_responses_in_bulk import UpsertResponsesInBulkUseCase -from httpx import AsyncClient -from sqlalchemy import func, select -from sqlalchemy.ext.asyncio import AsyncSession - +from argilla_server.webhooks.v1.enums import RecordEvent, ResponseEvent +from argilla_server.webhooks.v1.responses import build_response_event +from argilla_server.webhooks.v1.records import build_record_event from tests.factories import ( AnnotatorFactory, DatasetFactory, RatingQuestionFactory, RecordFactory, ResponseFactory, + WebhookFactory, WorkspaceUserFactory, + TextQuestionFactory, ) @@ -447,3 +455,183 @@ async def refresh_records(records): await use_case.execute([bulk_item.item for bulk_item in bulk_items], user) profiler.open_in_browser() + + async def test_create_current_user_responses_bulk_enqueue_webhook_response_created_event( + self, db: AsyncSession, async_client: AsyncClient, owner: User, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 2, + }, + ) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset) + + webhook = await WebhookFactory.create(events=[ResponseEvent.created, ResponseEvent.updated]) + + resp = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "items": [ + { + "values": { + "text-question": { + "value": "Created value", + }, + }, + "status": ResponseStatus.submitted, + "record_id": str(record.id), + }, + ], + }, + ) + + assert resp.status_code == 200 + + response = (await db.execute(select(Response))).scalar_one() + event = await build_response_event(db, ResponseEvent.created, response) + + assert HIGH_QUEUE.count == 1 + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == ResponseEvent.created + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event.data) + + async def test_create_current_user_responses_bulk_enqueue_webhook_response_updated_event( + self, db: AsyncSession, async_client: AsyncClient, owner: User, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 2, + }, + ) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset) + + response = await ResponseFactory.create( + values={"text-question": {"value": "Created value"}}, + status=ResponseStatus.submitted, + record=record, + user=owner, + ) + + webhook = await WebhookFactory.create(events=[ResponseEvent.created, ResponseEvent.updated]) + + resp = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "items": [ + { + "values": { + "text-question": { + "value": "Updated value", + }, + }, + "status": ResponseStatus.submitted, + "record_id": str(record.id), + }, + ], + }, + ) + + assert resp.status_code == 200 + + event = await build_response_event(db, ResponseEvent.updated, response) + + assert HIGH_QUEUE.count == 1 + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == ResponseEvent.updated + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event.data) + + async def test_create_current_user_responses_bulk_enqueue_webhook_record_updated_event( + self, db: AsyncSession, async_client: AsyncClient, owner: User, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, + ) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset) + + webhook = await WebhookFactory.create(events=[RecordEvent.updated]) + + resp = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "items": [ + { + "values": { + "text-question": { + "value": "Created value", + }, + }, + "status": ResponseStatus.submitted, + "record_id": str(record.id), + }, + ], + }, + ) + + assert resp.status_code == 200 + + event = await build_record_event(db, RecordEvent.updated, record) + + assert HIGH_QUEUE.count == 1 + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.updated + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event.data) + + async def test_create_current_user_responses_bulk_enqueue_webhook_record_completed_event( + self, db: AsyncSession, async_client: AsyncClient, owner: User, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, + ) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset) + + webhook = await WebhookFactory.create(events=[RecordEvent.completed]) + + resp = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "items": [ + { + "values": { + "text-question": { + "value": "Created value", + }, + }, + "status": ResponseStatus.submitted, + "record_id": str(record.id), + }, + ], + }, + ) + + assert resp.status_code == 200 + + event = await build_record_event(db, RecordEvent.completed, record) + + assert HIGH_QUEUE.count == 1 + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.completed + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event.data) diff --git a/argilla-server/tests/unit/api/handlers/v1/responses/test_delete_response.py b/argilla-server/tests/unit/api/handlers/v1/responses/test_delete_response.py index 6b9d4ec749..af66f6d2ec 100644 --- a/argilla-server/tests/unit/api/handlers/v1/responses/test_delete_response.py +++ b/argilla-server/tests/unit/api/handlers/v1/responses/test_delete_response.py @@ -12,16 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from uuid import UUID - import pytest +from uuid import UUID from httpx import AsyncClient +from fastapi.encoders import jsonable_encoder +from sqlalchemy.ext.asyncio import AsyncSession from argilla_server.models import User +from argilla_server.jobs.queues import HIGH_QUEUE +from argilla_server.webhooks.v1.enums import RecordEvent, ResponseEvent +from argilla_server.webhooks.v1.responses import build_response_event +from argilla_server.webhooks.v1.records import build_record_event from argilla_server.enums import DatasetDistributionStrategy, RecordStatus, ResponseStatus -from tests.factories import DatasetFactory, RecordFactory, ResponseFactory, TextQuestionFactory +from tests.factories import DatasetFactory, RecordFactory, ResponseFactory, TextQuestionFactory, WebhookFactory @pytest.mark.asyncio @@ -64,3 +69,56 @@ async def test_delete_response_does_not_updates_record_status_to_pending( assert resp.status_code == 200 assert record.status == RecordStatus.completed + + async def test_delete_response_enqueue_webhook_response_deleted_event( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + response = await ResponseFactory.create() + webhook = await WebhookFactory.create(events=[ResponseEvent.deleted]) + + event = await build_response_event(db, ResponseEvent.deleted, response) + + resp = await async_client.delete(self.url(response.id), headers=owner_auth_header) + + assert resp.status_code == 200 + + assert HIGH_QUEUE.count == 1 + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == ResponseEvent.deleted + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event.data) + + async def test_delete_response_enqueue_webhook_record_updated_event( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + record = await RecordFactory.create() + responses = await ResponseFactory.create_batch(2, record=record) + webhook = await WebhookFactory.create(events=[RecordEvent.updated]) + + response = await async_client.delete(self.url(responses[0].id), headers=owner_auth_header) + + assert response.status_code == 200 + + event = await build_record_event(db, RecordEvent.updated, record) + + assert HIGH_QUEUE.count == 1 + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.updated + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event.data) + + async def test_delete_response_enqueue_webhook_record_completed_event( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + record = await RecordFactory.create() + responses = await ResponseFactory.create_batch(2, record=record) + webhook = await WebhookFactory.create(events=[RecordEvent.completed]) + + response = await async_client.delete(self.url(responses[0].id), headers=owner_auth_header) + + assert response.status_code == 200 + + event = await build_record_event(db, RecordEvent.completed, record) + + assert HIGH_QUEUE.count == 1 + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.completed + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event.data) diff --git a/argilla-server/tests/unit/api/handlers/v1/responses/test_update_response.py b/argilla-server/tests/unit/api/handlers/v1/responses/test_update_response.py index d5097f8c7b..4d5f8a4792 100644 --- a/argilla-server/tests/unit/api/handlers/v1/responses/test_update_response.py +++ b/argilla-server/tests/unit/api/handlers/v1/responses/test_update_response.py @@ -12,19 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -from datetime import datetime -from uuid import UUID - import pytest -from httpx import AsyncClient +from uuid import UUID +from datetime import datetime +from httpx import AsyncClient from sqlalchemy import select from sqlalchemy.ext.asyncio.session import AsyncSession +from fastapi.encoders import jsonable_encoder -from argilla_server.enums import ResponseStatus, DatasetDistributionStrategy, RecordStatus from argilla_server.models import Response, User +from argilla_server.jobs.queues import HIGH_QUEUE +from argilla_server.webhooks.v1.enums import RecordEvent, ResponseEvent +from argilla_server.webhooks.v1.responses import build_response_event +from argilla_server.webhooks.v1.records import build_record_event +from argilla_server.enums import ResponseStatus, DatasetDistributionStrategy, RecordStatus -from tests.factories import DatasetFactory, RecordFactory, ResponseFactory, SpanQuestionFactory, TextQuestionFactory +from tests.factories import ( + DatasetFactory, + RecordFactory, + ResponseFactory, + SpanQuestionFactory, + TextQuestionFactory, + WebhookFactory, +) @pytest.mark.asyncio @@ -625,3 +636,147 @@ async def test_update_response_updates_record_status_to_pending( assert resp.status_code == 200 assert record.status == RecordStatus.pending + + async def test_update_response_enqueue_webhook_response_updated_event( + self, db: AsyncSession, async_client: AsyncClient, owner: User, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 2, + }, + ) + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset) + + response = await ResponseFactory.create( + values={ + "text-question": { + "value": "Hello", + }, + }, + status=ResponseStatus.submitted, + user=owner, + record=record, + ) + + webhook = await WebhookFactory.create(events=[ResponseEvent.updated]) + + resp = await async_client.put( + self.url(response.id), + headers=owner_auth_header, + json={ + "values": { + "text-question": { + "value": "Update value", + }, + }, + "status": ResponseStatus.submitted, + }, + ) + + assert resp.status_code == 200 + + event = await build_response_event(db, ResponseEvent.updated, response) + + assert HIGH_QUEUE.count == 1 + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == ResponseEvent.updated + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event.data) + + async def test_update_response_enqueue_webhook_record_updated_event( + self, db: AsyncSession, async_client: AsyncClient, owner: User, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, + ) + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset) + + response = await ResponseFactory.create( + values={ + "text-question": { + "value": "Hello", + }, + }, + status=ResponseStatus.draft, + user=owner, + record=record, + ) + + webhook = await WebhookFactory.create(events=[RecordEvent.updated]) + + resp = await async_client.put( + self.url(response.id), + headers=owner_auth_header, + json={ + "values": { + "text-question": { + "value": "Update value", + }, + }, + "status": ResponseStatus.submitted, + }, + ) + + assert resp.status_code == 200 + + event = await build_record_event(db, RecordEvent.updated, record) + + assert HIGH_QUEUE.count == 1 + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.updated + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event.data) + + async def test_update_response_enqueue_webhook_record_completed_event( + self, db: AsyncSession, async_client: AsyncClient, owner: User, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, + ) + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset) + + response = await ResponseFactory.create( + values={ + "text-question": { + "value": "Hello", + }, + }, + status=ResponseStatus.draft, + user=owner, + record=record, + ) + + webhook = await WebhookFactory.create(events=[RecordEvent.completed]) + + resp = await async_client.put( + self.url(response.id), + headers=owner_auth_header, + json={ + "values": { + "text-question": { + "value": "Update value", + }, + }, + "status": ResponseStatus.submitted, + }, + ) + + assert resp.status_code == 200 + + event = await build_record_event(db, RecordEvent.completed, record) + + assert HIGH_QUEUE.count == 1 + assert HIGH_QUEUE.jobs[0].args[0] == webhook.id + assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.completed + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event.data) diff --git a/argilla-server/tests/unit/jobs/__init__.py b/argilla-server/tests/unit/jobs/__init__.py new file mode 100644 index 0000000000..4b6cecae7f --- /dev/null +++ b/argilla-server/tests/unit/jobs/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/argilla-server/tests/unit/jobs/webhook_jobs/__init__.py b/argilla-server/tests/unit/jobs/webhook_jobs/__init__.py new file mode 100644 index 0000000000..4b6cecae7f --- /dev/null +++ b/argilla-server/tests/unit/jobs/webhook_jobs/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/argilla-server/tests/unit/jobs/webhook_jobs/test_enqueue_notify_events.py b/argilla-server/tests/unit/jobs/webhook_jobs/test_enqueue_notify_events.py new file mode 100644 index 0000000000..4f50f00ca2 --- /dev/null +++ b/argilla-server/tests/unit/jobs/webhook_jobs/test_enqueue_notify_events.py @@ -0,0 +1,58 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime +import pytest + +from fastapi.encoders import jsonable_encoder +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.jobs.queues import HIGH_QUEUE +from argilla_server.jobs.webhook_jobs import enqueue_notify_events +from argilla_server.webhooks.v1.enums import ResponseEvent +from argilla_server.webhooks.v1.responses import build_response_event + +from tests.factories import ResponseFactory, WebhookFactory + + +@pytest.mark.asyncio +class TestEnqueueNotifyEvents: + async def test_enqueue_notify_events(self, db: AsyncSession): + response = await ResponseFactory.create() + + webhooks = await WebhookFactory.create_batch(2, events=[ResponseEvent.created]) + webhooks_disabled = await WebhookFactory.create_batch(2, events=[ResponseEvent.created], enabled=False) + webhooks_with_other_events = await WebhookFactory.create_batch(2, events=[ResponseEvent.deleted]) + + event = await build_response_event(db, ResponseEvent.created, response) + jsonable_data = jsonable_encoder(event.data) + + await enqueue_notify_events( + db=db, + event=ResponseEvent.created, + timestamp=event.timestamp, + data=jsonable_data, + ) + + assert HIGH_QUEUE.count == 2 + + assert HIGH_QUEUE.jobs[0].args[0] == webhooks[0].id + assert HIGH_QUEUE.jobs[0].args[1] == ResponseEvent.created + assert HIGH_QUEUE.jobs[0].args[2] == event.timestamp + assert HIGH_QUEUE.jobs[0].args[3] == jsonable_data + + assert HIGH_QUEUE.jobs[1].args[0] == webhooks[1].id + assert HIGH_QUEUE.jobs[1].args[1] == ResponseEvent.created + assert HIGH_QUEUE.jobs[1].args[2] == event.timestamp + assert HIGH_QUEUE.jobs[1].args[3] == jsonable_data