From b8847e9792e85f23e5b9e421d893330c0113b579 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Fri, 31 May 2024 17:58:54 +0200 Subject: [PATCH] refactor: remove API v1 responses exception captures (#4904) # Description This PR removes all exception capture blocks for the API v1 responses handler. Refs #4871 **Type of change** (Please delete options that are not relevant. Remember to title the PR according to the type of change) - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [x] Refactor (change restructuring the codebase without changing functionality) - [ ] Improvement (change adding some improvement to an existing functionality) - [ ] Documentation update **How Has This Been Tested** (Please describe the tests that you ran to verify your changes. And ideally, reference `tests`) - [x] Improving and passing existing tests. 
**Checklist** - [ ] I added relevant documentation - [ ] follows the style guidelines of this project - [ ] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [ ] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [ ] I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- .../apis/v1/handlers/responses.py | 13 ++--- .../src/argilla_server/bulk/records_bulk.py | 9 ++-- .../src/argilla_server/contexts/datasets.py | 32 ++++++------ .../validators/response_values.py | 49 ++++++++++++------- .../argilla_server/validators/responses.py | 11 +++-- .../tests/unit/api/v1/test_records.py | 9 ++++ .../unit/api/v1/users/test_create_user.py | 1 + 7 files changed, 71 insertions(+), 53 deletions(-) diff --git a/argilla-server/src/argilla_server/apis/v1/handlers/responses.py b/argilla-server/src/argilla_server/apis/v1/handlers/responses.py index 49b1059859..0ad493d7ae 100644 --- a/argilla-server/src/argilla_server/apis/v1/handlers/responses.py +++ b/argilla-server/src/argilla_server/apis/v1/handlers/responses.py @@ -14,7 +14,7 @@ from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, Security, status +from fastapi import APIRouter, Depends, Security, status from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload @@ -69,12 +69,7 @@ async def update_response( await authorize(current_user, ResponsePolicyV1.update(response)) - # TODO: We should split API v1 into different FastAPI apps so we can customize error management. - # After mapping ValueError to 422 errors for API v1 then we can remove this try except. 
- try: - return await datasets.update_response(db, search_engine, response, response_update) - except ValueError as err: - raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(err)) + return await datasets.update_response(db, search_engine, response, response_update) @router.delete("/responses/{response_id}", response_model=ResponseSchema) @@ -93,6 +88,4 @@ async def delete_response( await authorize(current_user, ResponsePolicyV1.delete(response)) - await datasets.delete_response(db, search_engine, response) - - return response + return await datasets.delete_response(db, search_engine, response) diff --git a/argilla-server/src/argilla_server/bulk/records_bulk.py b/argilla-server/src/argilla_server/bulk/records_bulk.py index 043b9dc7a4..100cc55d5e 100644 --- a/argilla-server/src/argilla_server/bulk/records_bulk.py +++ b/argilla-server/src/argilla_server/bulk/records_bulk.py @@ -25,6 +25,7 @@ fetch_records_by_external_ids_as_dict, fetch_records_by_ids_as_dict, ) +from argilla_server.errors.future import UnprocessableEntityError from argilla_server.models import Dataset, Record, Response, Suggestion, Vector, VectorSettings from argilla_server.schemas.v1.records import RecordCreate, RecordUpsert from argilla_server.schemas.v1.records_bulk import ( @@ -99,10 +100,10 @@ async def _upsert_records_suggestions( try: SuggestionCreateValidator(suggestion_create).validate_for(question.parsed_settings, record) upsert_many_suggestions.append(dict(**suggestion_create.dict(), record_id=record.id)) - except ValueError as ex: + except (UnprocessableEntityError, ValueError) as ex: raise ValueError(f"suggestion for question name={question.name} is not valid: {ex}") - except ValueError as ex: + except (UnprocessableEntityError, ValueError) as ex: raise ValueError(f"Record at position {idx} does not have valid suggestions because {ex}") from ex if not upsert_many_suggestions: @@ -131,7 +132,7 @@ async def _upsert_records_responses( 
ResponseCreateValidator(response_create).validate_for(record) upsert_many_responses.append(dict(**response_create.dict(), record_id=record.id)) - except ValueError as ex: + except (UnprocessableEntityError, ValueError) as ex: raise ValueError(f"Record at position {idx} does not have valid responses because {ex}") from ex if not upsert_many_responses: @@ -158,7 +159,7 @@ async def _upsert_records_vectors( VectorValidator(value).validate_for(settings) upsert_many_vectors.append(dict(value=value, record_id=record.id, vector_settings_id=settings.id)) - except ValueError as ex: + except (UnprocessableEntityError, ValueError) as ex: raise ValueError(f"Record at position {idx} does not have valid vectors because {ex}") from ex if not upsert_many_vectors: diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index dcbee0a6b9..08ed907154 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -37,9 +37,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import contains_eager, joinedload, selectinload -import argilla_server.errors.future as errors from argilla_server.contexts import accounts, questions from argilla_server.enums import DatasetStatus, RecordInclude, UserRole +from argilla_server.errors.future import NotFoundError, NotUniqueError, UnprocessableEntityError from argilla_server.models import ( Dataset, Field, @@ -131,10 +131,10 @@ async def list_datasets_by_workspace_id(db: AsyncSession, workspace_id: UUID) -> async def create_dataset(db: AsyncSession, dataset_create: DatasetCreate): if await Workspace.get(db, dataset_create.workspace_id) is None: - raise errors.UnprocessableEntityError(f"Workspace with id `{dataset_create.workspace_id}` not found") + raise UnprocessableEntityError(f"Workspace with id `{dataset_create.workspace_id}` not found") if await get_dataset_by_name_and_workspace_id(db, 
dataset_create.name, dataset_create.workspace_id): - raise errors.NotUniqueError( + raise NotUniqueError( f"Dataset with name `{dataset_create.name}` already exists for workspace with id `{dataset_create.workspace_id}`" ) @@ -166,13 +166,13 @@ def _allowed_roles_for_metadata_property_create(metadata_property_create: Metada async def publish_dataset(db: AsyncSession, search_engine: SearchEngine, dataset: Dataset) -> Dataset: if dataset.is_ready: - raise errors.UnprocessableEntityError("Dataset is already published") + raise UnprocessableEntityError("Dataset is already published") if await _count_required_fields_by_dataset_id(db, dataset.id) == 0: - raise errors.UnprocessableEntityError("Dataset cannot be published without required fields") + raise UnprocessableEntityError("Dataset cannot be published without required fields") if await _count_required_questions_by_dataset_id(db, dataset.id) == 0: - raise errors.UnprocessableEntityError("Dataset cannot be published without required questions") + raise UnprocessableEntityError("Dataset cannot be published without required questions") async with db.begin_nested(): dataset = await dataset.update(db, status=DatasetStatus.ready, autocommit=False) @@ -205,12 +205,10 @@ async def get_field_by_name_and_dataset_id(db: AsyncSession, name: str, dataset_ async def create_field(db: AsyncSession, dataset: Dataset, field_create: FieldCreate) -> Field: if dataset.is_ready: - raise errors.UnprocessableEntityError("Field cannot be created for a published dataset") + raise UnprocessableEntityError("Field cannot be created for a published dataset") if await get_field_by_name_and_dataset_id(db, field_create.name, dataset.id): - raise errors.NotUniqueError( - f"Field with name `{field_create.name}` already exists for dataset with id `{dataset.id}`" - ) + raise NotUniqueError(f"Field with name `{field_create.name}` already exists for dataset with id `{dataset.id}`") return await Field.create( db, @@ -224,7 +222,7 @@ async def 
create_field(db: AsyncSession, dataset: Dataset, field_create: FieldCr async def update_field(db: AsyncSession, field: Field, field_update: "FieldUpdate") -> Field: if field_update.settings and field_update.settings.type != field.settings["type"]: - raise errors.UnprocessableEntityError( + raise UnprocessableEntityError( f"Field type cannot be changed. Expected '{field.settings['type']}' but got '{field_update.settings.type}'" ) @@ -234,7 +232,7 @@ async def update_field(db: AsyncSession, field: Field, field_update: "FieldUpdat async def delete_field(db: AsyncSession, field: Field) -> Field: if field.dataset.is_ready: - raise errors.UnprocessableEntityError("Fields cannot be deleted for a published dataset") + raise UnprocessableEntityError("Fields cannot be deleted for a published dataset") return await field.delete(db) @@ -252,7 +250,7 @@ async def get_metadata_property_by_name_and_dataset_id_or_raise( ) -> MetadataProperty: metadata_property = await get_metadata_property_by_name_and_dataset_id(db, name, dataset_id) if metadata_property is None: - raise errors.NotFoundError(f"Metadata property with name `{name}` not found for dataset with id `{dataset_id}`") + raise NotFoundError(f"Metadata property with name `{name}` not found for dataset with id `{dataset_id}`") return metadata_property @@ -268,7 +266,7 @@ async def create_metadata_property( metadata_property_create: MetadataPropertyCreate, ) -> MetadataProperty: if await get_metadata_property_by_name_and_dataset_id(db, metadata_property_create.name, dataset.id): - raise errors.NotUniqueError( + raise NotUniqueError( f"Metadata property with name `{metadata_property_create.name}` already exists for dataset with id `{dataset.id}`" ) @@ -330,12 +328,12 @@ async def create_vector_settings( db: AsyncSession, search_engine: "SearchEngine", dataset: Dataset, vector_settings_create: "VectorSettingsCreate" ) -> VectorSettings: if await count_vectors_settings_by_dataset_id(db, dataset.id) >= 
CREATE_DATASET_VECTOR_SETTINGS_MAX_COUNT: - raise errors.UnprocessableEntityError( + raise UnprocessableEntityError( f"The maximum number of vector settings has been reached for dataset with id `{dataset.id}`" ) if await get_vector_settings_by_name_and_dataset_id(db, vector_settings_create.name, dataset.id): - raise errors.NotUniqueError( + raise NotUniqueError( f"Vector settings with name `{vector_settings_create.name}` already exists for dataset with id `{dataset.id}`" ) @@ -696,7 +694,7 @@ async def _build_record_suggestions( ) ) - except ValueError as e: + except (UnprocessableEntityError, ValueError) as e: raise ValueError(f"suggestion for question_id={suggestion_create.question_id} is not valid: {e}") from e return suggestions diff --git a/argilla-server/src/argilla_server/validators/response_values.py b/argilla-server/src/argilla_server/validators/response_values.py index ffd68f0d76..9df2d8885b 100644 --- a/argilla-server/src/argilla_server/validators/response_values.py +++ b/argilla-server/src/argilla_server/validators/response_values.py @@ -15,6 +15,7 @@ from typing import Optional from argilla_server.enums import QuestionType, ResponseStatus +from argilla_server.errors.future import UnprocessableEntityError from argilla_server.models import Record from argilla_server.schemas.v1.questions import ( LabelSelectionQuestionSettings, @@ -54,7 +55,7 @@ def validate_for( elif question_settings.type == QuestionType.span: SpanQuestionResponseValueValidator(self._response_value).validate_for(question_settings, record) else: - raise ValueError(f"unknown question type f{question_settings.type!r}") + raise UnprocessableEntityError(f"unknown question type f{question_settings.type!r}") class TextQuestionResponseValueValidator: @@ -66,7 +67,7 @@ def validate(self) -> None: def _validate_value_type(self) -> None: if not isinstance(self._response_value, str): - raise ValueError(f"text question expects a text value, found {type(self._response_value)}") + raise 
UnprocessableEntityError(f"text question expects a text value, found {type(self._response_value)}") class LabelSelectionQuestionResponseValueValidator: @@ -82,7 +83,7 @@ def _validate_label_is_available_at_question_settings( available_labels = [option.value for option in label_selection_question_settings.options] if self._response_value not in available_labels: - raise ValueError( + raise UnprocessableEntityError( f"{self._response_value!r} is not a valid label for label selection question.\nValid labels are: {available_labels!r}" ) @@ -99,17 +100,17 @@ def validate_for(self, multi_label_selection_question_settings: MultiLabelSelect def _validate_value_type(self) -> None: if not isinstance(self._response_value, list): - raise ValueError( + raise UnprocessableEntityError( f"multi label selection questions expects a list of values, found {type(self._response_value)}" ) def _validate_labels_are_not_empty(self) -> None: if len(self._response_value) == 0: - raise ValueError("multi label selection questions expects a list of values, found empty list") + raise UnprocessableEntityError("multi label selection questions expects a list of values, found empty list") def _validate_labels_are_unique(self) -> None: if len(self._response_value) != len(set(self._response_value)): - raise ValueError( + raise UnprocessableEntityError( "multi label selection questions expect a list of unique values, but duplicates were found" ) @@ -120,7 +121,7 @@ def _validate_labels_are_available_at_question_settings( invalid_labels = sorted(list(set(self._response_value) - set(available_labels))) if invalid_labels: - raise ValueError( + raise UnprocessableEntityError( f"{invalid_labels!r} are not valid labels for multi label selection question.\nValid labels are: {available_labels!r}" ) @@ -138,7 +139,7 @@ def _validate_rating_is_available_at_question_settings( available_options = [option.value for option in rating_question_settings.options] if self._response_value not in available_options: - raise 
ValueError( + raise UnprocessableEntityError( f"{self._response_value!r} is not a valid rating for rating question.\nValid ratings are: {available_options!r}" ) @@ -158,7 +159,9 @@ def validate_for( def _validate_value_type(self) -> None: if not isinstance(self._response_value, list): - raise ValueError(f"ranking question expects a list of values, found {type(self._response_value)}") + raise UnprocessableEntityError( + f"ranking question expects a list of values, found {type(self._response_value)}" + ) def _validate_all_rankings_are_present_when_submitted( self, ranking_question_settings: RankingQuestionSettings, response_status: Optional[ResponseStatus] = None @@ -170,7 +173,7 @@ def _validate_all_rankings_are_present_when_submitted( available_values_len = len(available_values) if len(self._response_value) != available_values_len: - raise ValueError( + raise UnprocessableEntityError( f"ranking question expects a list containing {available_values_len} values, found a list of {len(self._response_value)} values" ) @@ -187,7 +190,7 @@ def _validate_all_rankings_are_valid_when_submitted( invalid_rankings = sorted(list(set(response_rankings) - set(available_rankings))) if invalid_rankings: - raise ValueError( + raise UnprocessableEntityError( f"{invalid_rankings!r} are not valid ranks for ranking question.\nValid ranks are: {available_rankings!r}" ) @@ -199,7 +202,7 @@ def _validate_values_are_available_at_question_settings( invalid_values = sorted(list(set(response_values) - set(available_values))) if invalid_values: - raise ValueError( + raise UnprocessableEntityError( f"{invalid_values!r} are not valid values for ranking question.\nValid values are: {available_values!r}" ) @@ -207,7 +210,9 @@ def _validate_values_are_unique(self) -> None: response_values = [value_item.value for value_item in self._response_value] if len(response_values) != len(set(response_values)): - raise ValueError("ranking question expects a list of unique values, but duplicates were found") + 
raise UnprocessableEntityError( + "ranking question expects a list of unique values, but duplicates were found" + ) class SpanQuestionResponseValueValidator: @@ -223,13 +228,17 @@ def validate_for(self, span_question_settings: SpanQuestionSettings, record: Rec def _validate_value_type(self) -> None: if not isinstance(self._response_value, list): - raise ValueError(f"span question expects a list of values, found {type(self._response_value)}") + raise UnprocessableEntityError( + f"span question expects a list of values, found {type(self._response_value)}" + ) def _validate_question_settings_field_is_present_at_record( self, span_question_settings: SpanQuestionSettings, record: Record ) -> None: if span_question_settings.field not in record.fields: - raise ValueError(f"span question requires record to have field `{span_question_settings.field}`") + raise UnprocessableEntityError( + f"span question requires record to have field `{span_question_settings.field}`" + ) def _validate_start_end_are_within_record_field_limits( self, span_question_settings: SpanQuestionSettings, record: Record @@ -238,12 +247,12 @@ def _validate_start_end_are_within_record_field_limits( for value_item in self._response_value: if value_item.start > (field_len - 1): - raise ValueError( + raise UnprocessableEntityError( f"span question response value `start` must have a value lower than record field `{span_question_settings.field}` length that is `{field_len}`" ) if value_item.end > field_len: - raise ValueError( + raise UnprocessableEntityError( f"span question response value `end` must have a value lower or equal than record field `{span_question_settings.field}` length that is `{field_len}`" ) @@ -252,7 +261,7 @@ def _validate_labels_are_available_at_question_settings(self, span_question_sett for value_item in self._response_value: if not value_item.label in available_labels: - raise ValueError( + raise UnprocessableEntityError( f"undefined label '{value_item.label}' for span question.\nValid 
labels are: {available_labels!r}" ) @@ -267,4 +276,6 @@ def _validate_values_are_not_overlapped(self, span_question_settings: SpanQuesti and value_item.start < other_value_item.end and value_item.end > other_value_item.start ): - raise ValueError(f"overlapping values found between spans at index idx={span_i} and idx={span_j}") + raise UnprocessableEntityError( + f"overlapping values found between spans at index idx={span_i} and idx={span_j}" + ) diff --git a/argilla-server/src/argilla_server/validators/responses.py b/argilla-server/src/argilla_server/validators/responses.py index d833aaa7b9..cc9b093e47 100644 --- a/argilla-server/src/argilla_server/validators/responses.py +++ b/argilla-server/src/argilla_server/validators/responses.py @@ -15,6 +15,7 @@ from typing import Union from argilla_server.enums import QuestionType, ResponseStatus +from argilla_server.errors.future import UnprocessableEntityError from argilla_server.models import Record from argilla_server.schemas.v1.responses import ResponseCreate, ResponseUpdate, ResponseUpsert from argilla_server.validators.response_values import ResponseValueValidator @@ -36,19 +37,23 @@ def _is_submitted_response(self) -> bool: def _validate_values_are_present_when_submitted(self) -> None: if self._is_submitted_response and not self._response_change.values: - raise ValueError("missing response values for submitted response") + raise UnprocessableEntityError("missing response values for submitted response") def _validate_required_questions_have_values(self, record: Record) -> None: for question in record.dataset.questions: if self._is_submitted_response and question.required and question.name not in self._response_change.values: - raise ValueError(f"missing response value for required question with name={question.name!r}") + raise UnprocessableEntityError( + f"missing response value for required question with name={question.name!r}" + ) def _validate_values_have_configured_questions(self, record: Record) -> None: 
question_names = [question.name for question in record.dataset.questions] for value_question_name in self._response_change.values or []: if value_question_name not in question_names: - raise ValueError(f"found response value for non configured question with name={value_question_name!r}") + raise UnprocessableEntityError( + f"found response value for non configured question with name={value_question_name!r}" + ) def _validate_values(self, record: Record) -> None: if not self._response_change.values: diff --git a/argilla-server/tests/unit/api/v1/test_records.py b/argilla-server/tests/unit/api/v1/test_records.py index 99fe09e930..3400fd7e71 100644 --- a/argilla-server/tests/unit/api/v1/test_records.py +++ b/argilla-server/tests/unit/api/v1/test_records.py @@ -816,6 +816,15 @@ async def test_create_record_response_with_extra_question_responses( }, "multi label selection questions expects a list of values, found ", ), + ( + create_multi_label_selection_questions, + { + "values": { + "multi_label_selection_question_1": {"value": ["option1", "option2", "option1"]}, + }, + }, + "multi label selection questions expect a list of unique values, but duplicates were found", + ), ( create_multi_label_selection_questions, { diff --git a/argilla-server/tests/unit/api/v1/users/test_create_user.py b/argilla-server/tests/unit/api/v1/users/test_create_user.py index 90eebbf0c2..2b6aa03710 100644 --- a/argilla-server/tests/unit/api/v1/users/test_create_user.py +++ b/argilla-server/tests/unit/api/v1/users/test_create_user.py @@ -200,6 +200,7 @@ async def test_create_user_with_existent_username( assert response.status_code == 409 assert response.json() == {"detail": "User username `username` is not unique"} + assert (await db.execute(select(func.count(User.id)))).scalar() == 2 async def test_create_user_with_invalid_username(