feat: add missing tests for webhooks feature #5537

Merged 1 commit on Sep 25, 2024
@@ -26,7 +26,6 @@
 from argilla_server.models import Dataset, User
 from argilla_server.search_engine import SearchEngine, get_search_engine
 from argilla_server.security import auth
-from argilla_server.telemetry import TelemetryClient, get_telemetry_client

 router = APIRouter()
@@ -43,7 +42,6 @@ async def create_dataset_records_bulk(
     db: AsyncSession = Depends(get_async_db),
     search_engine: SearchEngine = Depends(get_search_engine),
     current_user: User = Security(auth.get_current_user),
-    telemetry_client: TelemetryClient = Depends(get_telemetry_client),
 ):
     dataset = await Dataset.get_or_raise(
         db,
@@ -58,9 +56,7 @@

     await authorize(current_user, DatasetPolicy.create_records(dataset))

-    records_bulk = await CreateRecordsBulk(db, search_engine).create_records_bulk(dataset, records_bulk_create)
-
-    return records_bulk
+    return await CreateRecordsBulk(db, search_engine).create_records_bulk(dataset, records_bulk_create)


 @router.put("/datasets/{dataset_id}/records/bulk", response_model=RecordsBulk)
@@ -71,7 +67,6 @@ async def upsert_dataset_records_bulk(
     db: AsyncSession = Depends(get_async_db),
     search_engine: SearchEngine = Depends(get_search_engine),
     current_user: User = Security(auth.get_current_user),
-    telemetry_client: TelemetryClient = Depends(get_telemetry_client),
 ):
     dataset = await Dataset.get_or_raise(
         db,
@@ -86,9 +81,4 @@

     await authorize(current_user, DatasetPolicy.upsert_records(dataset))

-    records_bulk = await UpsertRecordsBulk(db, search_engine).upsert_records_bulk(dataset, records_bulk_create)
-
-    updated = len(records_bulk.updated_item_ids)
-    created = len(records_bulk.items) - updated
-
-    return records_bulk
+    return await UpsertRecordsBulk(db, search_engine).upsert_records_bulk(dataset, records_bulk_create)
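Both handlers now return the bulk service result directly. The injected `telemetry_client` was never used, and the locally computed `updated`/`created` counters were dead code; record events are notified from inside the bulk services themselves (see `records_bulk.py` below).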
4 changes: 2 additions & 2 deletions argilla-server/src/argilla_server/bulk/records_bulk.py
@@ -220,7 +220,7 @@ async def upsert_records_bulk(self, dataset: Dataset, bulk_upsert: RecordsBulkUpsert)

         await self._db.commit()

-        await self._notify_record_events(records)
+        await self._notify_upsert_record_events(records)

         return RecordsBulkWithUpdateInfo(
             items=records,
@@ -241,7 +241,7 @@ async def _fetch_existing_dataset_records(

         return {**records_by_external_id, **records_by_id}

-    async def _notify_record_events(self, records: List[Record]) -> None:
+    async def _notify_upsert_record_events(self, records: List[Record]) -> None:
         for record in records:
             if record.inserted_at == record.updated_at:
                 await notify_record_event_v1(self._db, RecordEvent.created, record)
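The rename makes the helper's role explicit: it only serves the upsert path, deciding per record whether the upsert created or merely updated it. A minimal sketch of the full method, assuming the branch collapsed out of this diff notifies `RecordEvent.updated` (only the `created` branch is visible above):

    # Sketch only; the else branch is collapsed in this diff and is assumed.
    async def _notify_upsert_record_events(self, records: List[Record]) -> None:
        for record in records:
            if record.inserted_at == record.updated_at:
                # untouched since insertion, so this upsert created the record
                await notify_record_event_v1(self._db, RecordEvent.created, record)
            else:
                # assumed: a pre-existing record that this upsert modified
                await notify_record_event_v1(self._db, RecordEvent.updated, record)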
2 changes: 1 addition & 1 deletion argilla-server/src/argilla_server/contexts/datasets.py
@@ -827,7 +827,7 @@ async def delete_records(
 ) -> None:
     params = [Record.id.in_(records_ids), Record.dataset_id == dataset.id]

-    records = (await db.execute(select(Record).filter(*params))).scalars().all()
+    records = (await db.execute(select(Record).filter(*params).order_by(Record.inserted_at.asc()))).scalars().all()

     deleted_record_events_v1 = []
     for record in records:
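Sorting the fetched records by `inserted_at` makes the order in which their `deleted` events are enqueued deterministic; the new `TestDeleteDatasetRecords` test below relies on this when asserting the contents of `HIGH_QUEUE.jobs[0]` and `HIGH_QUEUE.jobs[1]`.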
23 changes: 18 additions & 5 deletions argilla-server/tests/conftest.py
@@ -12,17 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import asyncio
-from typing import TYPE_CHECKING, AsyncGenerator, Generator
-
 import httpx
+import asyncio
 import pytest
 import pytest_asyncio

+from rq import Queue
+from typing import TYPE_CHECKING, AsyncGenerator, Generator
+from sqlalchemy import NullPool, create_engine
+from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
+
 from argilla_server.cli.database.migrate import migrate_db
 from argilla_server.database import database_url_sync
+from argilla_server.jobs.queues import REDIS_CONNECTION
 from argilla_server.settings import settings
-from sqlalchemy import NullPool, create_engine
-from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

 from tests.database import SyncTestSession, TestSession, set_task
@@ -97,6 +100,16 @@ def sync_db(sync_connection: "Connection") -> Generator["Session", None, None]:
     sync_connection.rollback()


+@pytest.fixture(autouse=True)
+def empty_job_queues():
+    queues = Queue.all(connection=REDIS_CONNECTION)
+
+    for queue in queues:
+        queue.empty()
+
+    yield
+
+
 @pytest.fixture
 def async_db_proxy(mocker: "MockerFixture", sync_db: "Session") -> "AsyncSession":
     """Create a mocked `AsyncSession` that proxies to the sync session. This will allow us to execute the async CLI commands
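The autouse `empty_job_queues` fixture clears every RQ queue before each test, so assertions like `HIGH_QUEUE.count == 2` in the webhook tests below always start from zero. As a hedged aside, a test that also wanted to execute the enqueued jobs could drain the queue with plain RQ (standard `rq` API, not something this PR adds):

    from rq import SimpleWorker

    from argilla_server.jobs.queues import HIGH_QUEUE, REDIS_CONNECTION

    # Run every currently enqueued job synchronously, then return (burst mode).
    worker = SimpleWorker([HIGH_QUEUE], connection=REDIS_CONNECTION)
    worker.work(burst=True)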
@@ -12,15 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from uuid import UUID
-
 import pytest
-from argilla_server.enums import DatasetStatus, QuestionType, ResponseStatus, SuggestionType
-from argilla_server.models.database import Record, Response, Suggestion, User

+from uuid import UUID
 from httpx import AsyncClient
+from fastapi.encoders import jsonable_encoder
 from sqlalchemy import func, select
 from sqlalchemy.ext.asyncio import AsyncSession

+from argilla_server.jobs.queues import HIGH_QUEUE
+from argilla_server.webhooks.v1.enums import RecordEvent
+from argilla_server.webhooks.v1.records import build_record_event
+from argilla_server.models.database import Record, Response, Suggestion, User
+from argilla_server.enums import DatasetStatus, QuestionType, ResponseStatus, SuggestionType
+
 from tests.factories import (
     DatasetFactory,
     LabelSelectionQuestionFactory,
@@ -32,6 +37,7 @@
     ImageFieldFactory,
     TextQuestionFactory,
     ChatFieldFactory,
+    WebhookFactory,
 )


@@ -551,3 +557,48 @@ async def test_create_dataset_records_bulk_with_chat_field_without_content_key(
            }
        }
        assert (await db.execute(select(func.count(Record.id)))).scalar_one() == 0

    async def test_create_dataset_records_bulk_enqueue_webhook_record_created_events(
        self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict
    ):
        dataset = await DatasetFactory.create(status=DatasetStatus.ready)
        await TextFieldFactory.create(name="prompt", dataset=dataset)
        await TextQuestionFactory.create(name="text-question", dataset=dataset)

        webhook = await WebhookFactory.create(events=[RecordEvent.created])

        response = await async_client.post(
            self.url(dataset.id),
            headers=owner_auth_header,
            json={
                "items": [
                    {
                        "fields": {
                            "prompt": "Does exercise help reduce stress?",
                        },
                    },
                    {
                        "fields": {
                            "prompt": "What is the best way to reduce stress?",
                        },
                    },
                ],
            },
        )

        assert response.status_code == 201

        records = (await db.execute(select(Record).order_by(Record.inserted_at.asc()))).scalars().all()

        event_a = await build_record_event(db, RecordEvent.created, records[0])
        event_b = await build_record_event(db, RecordEvent.created, records[1])

        assert HIGH_QUEUE.count == 2

        assert HIGH_QUEUE.jobs[0].args[0] == webhook.id
        assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.created
        assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event_a.data)

        assert HIGH_QUEUE.jobs[1].args[0] == webhook.id
        assert HIGH_QUEUE.jobs[1].args[1] == RecordEvent.created
        assert HIGH_QUEUE.jobs[1].args[3] == jsonable_encoder(event_b.data)
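The assertions deliberately skip `args[2]`: each job is pinned down only by webhook id, event type, and JSON-encoded payload. A hedged reconstruction of the enqueue call shape these indices imply (the job function name and the third argument are illustrative assumptions, not shown in this diff):

    HIGH_QUEUE.enqueue(
        deliver_webhook_event,         # hypothetical job function name
        webhook.id,                    # args[0]
        RecordEvent.created,           # args[1]
        event_timestamp,               # args[2], hypothetical and not asserted
        jsonable_encoder(event.data),  # args[3]
    )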
@@ -16,11 +16,26 @@

 from uuid import UUID
 from httpx import AsyncClient
+from fastapi.encoders import jsonable_encoder
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession

-from argilla_server.models import User
+from argilla_server.models import User, Record
+from argilla_server.jobs.queues import HIGH_QUEUE
 from argilla_server.enums import DatasetDistributionStrategy, ResponseStatus, DatasetStatus, RecordStatus
+from argilla_server.webhooks.v1.enums import RecordEvent
+from argilla_server.webhooks.v1.records import build_record_event

-from tests.factories import DatasetFactory, RecordFactory, TextQuestionFactory, ResponseFactory, AnnotatorFactory
+from tests.factories import (
+    DatasetFactory,
+    RecordFactory,
+    TextFieldFactory,
+    TextQuestionFactory,
+    AnnotatorFactory,
+    WebhookFactory,
+    ResponseFactory,
+)


 @pytest.mark.asyncio
@@ -151,3 +166,95 @@ async def test_upsert_dataset_records_bulk_updates_records_status(
        assert record_b.status == RecordStatus.pending
        assert record_c.status == RecordStatus.pending
        assert record_d.status == RecordStatus.pending

    async def test_upsert_dataset_records_bulk_enqueue_webhook_record_created_events(
        self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict
    ):
        dataset = await DatasetFactory.create(status=DatasetStatus.ready)
        await TextFieldFactory.create(name="prompt", dataset=dataset)
        await TextQuestionFactory.create(name="text-question", dataset=dataset)

        webhook = await WebhookFactory.create(events=[RecordEvent.created, RecordEvent.updated])

        response = await async_client.put(
            self.url(dataset.id),
            headers=owner_auth_header,
            json={
                "items": [
                    {
                        "fields": {
                            "prompt": "Does exercise help reduce stress?",
                        },
                    },
                    {
                        "fields": {
                            "prompt": "What is the best way to reduce stress?",
                        },
                    },
                ],
            },
        )

        assert response.status_code == 200

        records = (await db.execute(select(Record).order_by(Record.inserted_at.asc()))).scalars().all()

        event_a = await build_record_event(db, RecordEvent.created, records[0])
        event_b = await build_record_event(db, RecordEvent.created, records[1])

        assert HIGH_QUEUE.count == 2

        assert HIGH_QUEUE.jobs[0].args[0] == webhook.id
        assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.created
        assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event_a.data)

        assert HIGH_QUEUE.jobs[1].args[0] == webhook.id
        assert HIGH_QUEUE.jobs[1].args[1] == RecordEvent.created
        assert HIGH_QUEUE.jobs[1].args[3] == jsonable_encoder(event_b.data)
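Even though this is the upsert endpoint, brand-new records yield `created` events; the `inserted_at == updated_at` comparison in `_notify_upsert_record_events` is what routes them there. The next test exercises the complementary `updated` path.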

    async def test_upsert_dataset_records_bulk_enqueue_webhook_record_updated_events(
        self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict
    ):
        dataset = await DatasetFactory.create(status=DatasetStatus.ready)
        await TextFieldFactory.create(name="prompt", dataset=dataset)
        await TextQuestionFactory.create(name="text-question", dataset=dataset)

        records = await RecordFactory.create_batch(2, dataset=dataset)

        webhook = await WebhookFactory.create(events=[RecordEvent.created, RecordEvent.updated])

        response = await async_client.put(
            self.url(dataset.id),
            headers=owner_auth_header,
            json={
                "items": [
                    {
                        "id": str(records[0].id),
                        "metadata": {
                            "metadata-key": "metadata-value",
                        },
                    },
                    {
                        "id": str(records[1].id),
                        "metadata": {
                            "metadata-key": "metadata-value",
                        },
                    },
                ],
            },
        )

        assert response.status_code == 200

        event_a = await build_record_event(db, RecordEvent.updated, records[0])
        event_b = await build_record_event(db, RecordEvent.updated, records[1])

        assert HIGH_QUEUE.count == 2

        assert HIGH_QUEUE.jobs[0].args[0] == webhook.id
        assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.updated
        assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event_a.data)

        assert HIGH_QUEUE.jobs[1].args[0] == webhook.id
        assert HIGH_QUEUE.jobs[1].args[1] == RecordEvent.updated
        assert HIGH_QUEUE.jobs[1].args[3] == jsonable_encoder(event_b.data)
@@ -0,0 +1,60 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from uuid import UUID
from httpx import AsyncClient
from fastapi.encoders import jsonable_encoder
from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server.jobs.queues import HIGH_QUEUE
from argilla_server.webhooks.v1.enums import RecordEvent
from argilla_server.webhooks.v1.records import build_record_event

from tests.factories import DatasetFactory, RecordFactory, WebhookFactory


@pytest.mark.asyncio
class TestDeleteDatasetRecords:
    def url(self, dataset_id: UUID) -> str:
        return f"/api/v1/datasets/{dataset_id}/records"

    async def test_delete_dataset_records_enqueue_webhook_record_deleted_events(
        self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict
    ):
        dataset = await DatasetFactory.create()
        records = await RecordFactory.create_batch(2, dataset=dataset)
        webhook = await WebhookFactory.create(events=[RecordEvent.deleted])

        event_a = await build_record_event(db, RecordEvent.deleted, records[0])
        event_b = await build_record_event(db, RecordEvent.deleted, records[1])

        response = await async_client.delete(
            self.url(dataset.id),
            headers=owner_auth_header,
            params={"ids": f"{records[0].id},{records[1].id}"},
        )

        assert response.status_code == 204

        assert HIGH_QUEUE.count == 2

        assert HIGH_QUEUE.jobs[0].args[0] == webhook.id
        assert HIGH_QUEUE.jobs[0].args[1] == RecordEvent.deleted
        assert HIGH_QUEUE.jobs[0].args[3] == jsonable_encoder(event_a.data)

        assert HIGH_QUEUE.jobs[1].args[0] == webhook.id
        assert HIGH_QUEUE.jobs[1].args[1] == RecordEvent.deleted
        assert HIGH_QUEUE.jobs[1].args[3] == jsonable_encoder(event_b.data)
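Unlike the create and upsert tests, the expected events here are built before the DELETE request is issued: once the endpoint returns 204 the records are gone, and `build_record_event` would have nothing left to serialize.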