Skip to content

Commit

Permalink
refactor: remove API v1 responses exception captures (#4904)
Browse files Browse the repository at this point in the history
# Description

This PR removes all exception capture blocks from the API v1 responses
handler.

Refs #4871 

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [x] Refactor (change restructuring the codebase without changing
functionality)
- [ ] Improvement (change adding some improvement to an existing
functionality)
- [ ] Documentation update

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- [x] Improving and passing existing tests.

**Checklist**

- [ ] I added relevant documentation
- [ ] My code follows the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)
  • Loading branch information
jfcalvo authored May 31, 2024
1 parent e3aa28a commit b8847e9
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 53 deletions.
13 changes: 3 additions & 10 deletions argilla-server/src/argilla_server/apis/v1/handlers/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Security, status
from fastapi import APIRouter, Depends, Security, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

Expand Down Expand Up @@ -69,12 +69,7 @@ async def update_response(

await authorize(current_user, ResponsePolicyV1.update(response))

# TODO: We should split API v1 into different FastAPI apps so we can customize error management.
# After mapping ValueError to 422 errors for API v1 then we can remove this try except.
try:
return await datasets.update_response(db, search_engine, response, response_update)
except ValueError as err:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(err))
return await datasets.update_response(db, search_engine, response, response_update)


@router.delete("/responses/{response_id}", response_model=ResponseSchema)
Expand All @@ -93,6 +88,4 @@ async def delete_response(

await authorize(current_user, ResponsePolicyV1.delete(response))

await datasets.delete_response(db, search_engine, response)

return response
return await datasets.delete_response(db, search_engine, response)
9 changes: 5 additions & 4 deletions argilla-server/src/argilla_server/bulk/records_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
fetch_records_by_external_ids_as_dict,
fetch_records_by_ids_as_dict,
)
from argilla_server.errors.future import UnprocessableEntityError
from argilla_server.models import Dataset, Record, Response, Suggestion, Vector, VectorSettings
from argilla_server.schemas.v1.records import RecordCreate, RecordUpsert
from argilla_server.schemas.v1.records_bulk import (
Expand Down Expand Up @@ -99,10 +100,10 @@ async def _upsert_records_suggestions(
try:
SuggestionCreateValidator(suggestion_create).validate_for(question.parsed_settings, record)
upsert_many_suggestions.append(dict(**suggestion_create.dict(), record_id=record.id))
except ValueError as ex:
except (UnprocessableEntityError, ValueError) as ex:
raise ValueError(f"suggestion for question name={question.name} is not valid: {ex}")

except ValueError as ex:
except (UnprocessableEntityError, ValueError) as ex:
raise ValueError(f"Record at position {idx} does not have valid suggestions because {ex}") from ex

if not upsert_many_suggestions:
Expand Down Expand Up @@ -131,7 +132,7 @@ async def _upsert_records_responses(

ResponseCreateValidator(response_create).validate_for(record)
upsert_many_responses.append(dict(**response_create.dict(), record_id=record.id))
except ValueError as ex:
except (UnprocessableEntityError, ValueError) as ex:
raise ValueError(f"Record at position {idx} does not have valid responses because {ex}") from ex

if not upsert_many_responses:
Expand All @@ -158,7 +159,7 @@ async def _upsert_records_vectors(

VectorValidator(value).validate_for(settings)
upsert_many_vectors.append(dict(value=value, record_id=record.id, vector_settings_id=settings.id))
except ValueError as ex:
except (UnprocessableEntityError, ValueError) as ex:
raise ValueError(f"Record at position {idx} does not have valid vectors because {ex}") from ex

if not upsert_many_vectors:
Expand Down
32 changes: 15 additions & 17 deletions argilla-server/src/argilla_server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import contains_eager, joinedload, selectinload

import argilla_server.errors.future as errors
from argilla_server.contexts import accounts, questions
from argilla_server.enums import DatasetStatus, RecordInclude, UserRole
from argilla_server.errors.future import NotFoundError, NotUniqueError, UnprocessableEntityError
from argilla_server.models import (
Dataset,
Field,
Expand Down Expand Up @@ -131,10 +131,10 @@ async def list_datasets_by_workspace_id(db: AsyncSession, workspace_id: UUID) ->

async def create_dataset(db: AsyncSession, dataset_create: DatasetCreate):
if await Workspace.get(db, dataset_create.workspace_id) is None:
raise errors.UnprocessableEntityError(f"Workspace with id `{dataset_create.workspace_id}` not found")
raise UnprocessableEntityError(f"Workspace with id `{dataset_create.workspace_id}` not found")

if await get_dataset_by_name_and_workspace_id(db, dataset_create.name, dataset_create.workspace_id):
raise errors.NotUniqueError(
raise NotUniqueError(
f"Dataset with name `{dataset_create.name}` already exists for workspace with id `{dataset_create.workspace_id}`"
)

Expand Down Expand Up @@ -166,13 +166,13 @@ def _allowed_roles_for_metadata_property_create(metadata_property_create: Metada

async def publish_dataset(db: AsyncSession, search_engine: SearchEngine, dataset: Dataset) -> Dataset:
if dataset.is_ready:
raise errors.UnprocessableEntityError("Dataset is already published")
raise UnprocessableEntityError("Dataset is already published")

if await _count_required_fields_by_dataset_id(db, dataset.id) == 0:
raise errors.UnprocessableEntityError("Dataset cannot be published without required fields")
raise UnprocessableEntityError("Dataset cannot be published without required fields")

if await _count_required_questions_by_dataset_id(db, dataset.id) == 0:
raise errors.UnprocessableEntityError("Dataset cannot be published without required questions")
raise UnprocessableEntityError("Dataset cannot be published without required questions")

async with db.begin_nested():
dataset = await dataset.update(db, status=DatasetStatus.ready, autocommit=False)
Expand Down Expand Up @@ -205,12 +205,10 @@ async def get_field_by_name_and_dataset_id(db: AsyncSession, name: str, dataset_

async def create_field(db: AsyncSession, dataset: Dataset, field_create: FieldCreate) -> Field:
if dataset.is_ready:
raise errors.UnprocessableEntityError("Field cannot be created for a published dataset")
raise UnprocessableEntityError("Field cannot be created for a published dataset")

if await get_field_by_name_and_dataset_id(db, field_create.name, dataset.id):
raise errors.NotUniqueError(
f"Field with name `{field_create.name}` already exists for dataset with id `{dataset.id}`"
)
raise NotUniqueError(f"Field with name `{field_create.name}` already exists for dataset with id `{dataset.id}`")

return await Field.create(
db,
Expand All @@ -224,7 +222,7 @@ async def create_field(db: AsyncSession, dataset: Dataset, field_create: FieldCr

async def update_field(db: AsyncSession, field: Field, field_update: "FieldUpdate") -> Field:
if field_update.settings and field_update.settings.type != field.settings["type"]:
raise errors.UnprocessableEntityError(
raise UnprocessableEntityError(

Check warning on line 225 in argilla-server/src/argilla_server/contexts/datasets.py

View check run for this annotation

Codecov / codecov/patch

argilla-server/src/argilla_server/contexts/datasets.py#L225

Added line #L225 was not covered by tests
f"Field type cannot be changed. Expected '{field.settings['type']}' but got '{field_update.settings.type}'"
)

Expand All @@ -234,7 +232,7 @@ async def update_field(db: AsyncSession, field: Field, field_update: "FieldUpdat

async def delete_field(db: AsyncSession, field: Field) -> Field:
if field.dataset.is_ready:
raise errors.UnprocessableEntityError("Fields cannot be deleted for a published dataset")
raise UnprocessableEntityError("Fields cannot be deleted for a published dataset")

return await field.delete(db)

Expand All @@ -252,7 +250,7 @@ async def get_metadata_property_by_name_and_dataset_id_or_raise(
) -> MetadataProperty:
metadata_property = await get_metadata_property_by_name_and_dataset_id(db, name, dataset_id)
if metadata_property is None:
raise errors.NotFoundError(f"Metadata property with name `{name}` not found for dataset with id `{dataset_id}`")
raise NotFoundError(f"Metadata property with name `{name}` not found for dataset with id `{dataset_id}`")

return metadata_property

Expand All @@ -268,7 +266,7 @@ async def create_metadata_property(
metadata_property_create: MetadataPropertyCreate,
) -> MetadataProperty:
if await get_metadata_property_by_name_and_dataset_id(db, metadata_property_create.name, dataset.id):
raise errors.NotUniqueError(
raise NotUniqueError(
f"Metadata property with name `{metadata_property_create.name}` already exists for dataset with id `{dataset.id}`"
)

Expand Down Expand Up @@ -330,12 +328,12 @@ async def create_vector_settings(
db: AsyncSession, search_engine: "SearchEngine", dataset: Dataset, vector_settings_create: "VectorSettingsCreate"
) -> VectorSettings:
if await count_vectors_settings_by_dataset_id(db, dataset.id) >= CREATE_DATASET_VECTOR_SETTINGS_MAX_COUNT:
raise errors.UnprocessableEntityError(
raise UnprocessableEntityError(
f"The maximum number of vector settings has been reached for dataset with id `{dataset.id}`"
)

if await get_vector_settings_by_name_and_dataset_id(db, vector_settings_create.name, dataset.id):
raise errors.NotUniqueError(
raise NotUniqueError(
f"Vector settings with name `{vector_settings_create.name}` already exists for dataset with id `{dataset.id}`"
)

Expand Down Expand Up @@ -696,7 +694,7 @@ async def _build_record_suggestions(
)
)

except ValueError as e:
except (UnprocessableEntityError, ValueError) as e:
raise ValueError(f"suggestion for question_id={suggestion_create.question_id} is not valid: {e}") from e

return suggestions
Expand Down
49 changes: 30 additions & 19 deletions argilla-server/src/argilla_server/validators/response_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Optional

from argilla_server.enums import QuestionType, ResponseStatus
from argilla_server.errors.future import UnprocessableEntityError
from argilla_server.models import Record
from argilla_server.schemas.v1.questions import (
LabelSelectionQuestionSettings,
Expand Down Expand Up @@ -54,7 +55,7 @@ def validate_for(
elif question_settings.type == QuestionType.span:
SpanQuestionResponseValueValidator(self._response_value).validate_for(question_settings, record)
else:
raise ValueError(f"unknown question type f{question_settings.type!r}")
raise UnprocessableEntityError(f"unknown question type f{question_settings.type!r}")

Check warning on line 58 in argilla-server/src/argilla_server/validators/response_values.py

View check run for this annotation

Codecov / codecov/patch

argilla-server/src/argilla_server/validators/response_values.py#L58

Added line #L58 was not covered by tests


class TextQuestionResponseValueValidator:
Expand All @@ -66,7 +67,7 @@ def validate(self) -> None:

def _validate_value_type(self) -> None:
if not isinstance(self._response_value, str):
raise ValueError(f"text question expects a text value, found {type(self._response_value)}")
raise UnprocessableEntityError(f"text question expects a text value, found {type(self._response_value)}")

Check warning on line 70 in argilla-server/src/argilla_server/validators/response_values.py

View check run for this annotation

Codecov / codecov/patch

argilla-server/src/argilla_server/validators/response_values.py#L70

Added line #L70 was not covered by tests


class LabelSelectionQuestionResponseValueValidator:
Expand All @@ -82,7 +83,7 @@ def _validate_label_is_available_at_question_settings(
available_labels = [option.value for option in label_selection_question_settings.options]

if self._response_value not in available_labels:
raise ValueError(
raise UnprocessableEntityError(
f"{self._response_value!r} is not a valid label for label selection question.\nValid labels are: {available_labels!r}"
)

Expand All @@ -99,17 +100,17 @@ def validate_for(self, multi_label_selection_question_settings: MultiLabelSelect

def _validate_value_type(self) -> None:
if not isinstance(self._response_value, list):
raise ValueError(
raise UnprocessableEntityError(
f"multi label selection questions expects a list of values, found {type(self._response_value)}"
)

def _validate_labels_are_not_empty(self) -> None:
if len(self._response_value) == 0:
raise ValueError("multi label selection questions expects a list of values, found empty list")
raise UnprocessableEntityError("multi label selection questions expects a list of values, found empty list")

def _validate_labels_are_unique(self) -> None:
if len(self._response_value) != len(set(self._response_value)):
raise ValueError(
raise UnprocessableEntityError(
"multi label selection questions expect a list of unique values, but duplicates were found"
)

Expand All @@ -120,7 +121,7 @@ def _validate_labels_are_available_at_question_settings(
invalid_labels = sorted(list(set(self._response_value) - set(available_labels)))

if invalid_labels:
raise ValueError(
raise UnprocessableEntityError(
f"{invalid_labels!r} are not valid labels for multi label selection question.\nValid labels are: {available_labels!r}"
)

Expand All @@ -138,7 +139,7 @@ def _validate_rating_is_available_at_question_settings(
available_options = [option.value for option in rating_question_settings.options]

if self._response_value not in available_options:
raise ValueError(
raise UnprocessableEntityError(
f"{self._response_value!r} is not a valid rating for rating question.\nValid ratings are: {available_options!r}"
)

Expand All @@ -158,7 +159,9 @@ def validate_for(

def _validate_value_type(self) -> None:
if not isinstance(self._response_value, list):
raise ValueError(f"ranking question expects a list of values, found {type(self._response_value)}")
raise UnprocessableEntityError(
f"ranking question expects a list of values, found {type(self._response_value)}"
)

def _validate_all_rankings_are_present_when_submitted(
self, ranking_question_settings: RankingQuestionSettings, response_status: Optional[ResponseStatus] = None
Expand All @@ -170,7 +173,7 @@ def _validate_all_rankings_are_present_when_submitted(
available_values_len = len(available_values)

if len(self._response_value) != available_values_len:
raise ValueError(
raise UnprocessableEntityError(
f"ranking question expects a list containing {available_values_len} values, found a list of {len(self._response_value)} values"
)

Expand All @@ -187,7 +190,7 @@ def _validate_all_rankings_are_valid_when_submitted(
invalid_rankings = sorted(list(set(response_rankings) - set(available_rankings)))

if invalid_rankings:
raise ValueError(
raise UnprocessableEntityError(
f"{invalid_rankings!r} are not valid ranks for ranking question.\nValid ranks are: {available_rankings!r}"
)

Expand All @@ -199,15 +202,17 @@ def _validate_values_are_available_at_question_settings(
invalid_values = sorted(list(set(response_values) - set(available_values)))

if invalid_values:
raise ValueError(
raise UnprocessableEntityError(
f"{invalid_values!r} are not valid values for ranking question.\nValid values are: {available_values!r}"
)

def _validate_values_are_unique(self) -> None:
response_values = [value_item.value for value_item in self._response_value]

if len(response_values) != len(set(response_values)):
raise ValueError("ranking question expects a list of unique values, but duplicates were found")
raise UnprocessableEntityError(
"ranking question expects a list of unique values, but duplicates were found"
)


class SpanQuestionResponseValueValidator:
Expand All @@ -223,13 +228,17 @@ def validate_for(self, span_question_settings: SpanQuestionSettings, record: Rec

def _validate_value_type(self) -> None:
if not isinstance(self._response_value, list):
raise ValueError(f"span question expects a list of values, found {type(self._response_value)}")
raise UnprocessableEntityError(

Check warning on line 231 in argilla-server/src/argilla_server/validators/response_values.py

View check run for this annotation

Codecov / codecov/patch

argilla-server/src/argilla_server/validators/response_values.py#L231

Added line #L231 was not covered by tests
f"span question expects a list of values, found {type(self._response_value)}"
)

def _validate_question_settings_field_is_present_at_record(
self, span_question_settings: SpanQuestionSettings, record: Record
) -> None:
if span_question_settings.field not in record.fields:
raise ValueError(f"span question requires record to have field `{span_question_settings.field}`")
raise UnprocessableEntityError(
f"span question requires record to have field `{span_question_settings.field}`"
)

def _validate_start_end_are_within_record_field_limits(
self, span_question_settings: SpanQuestionSettings, record: Record
Expand All @@ -238,12 +247,12 @@ def _validate_start_end_are_within_record_field_limits(

for value_item in self._response_value:
if value_item.start > (field_len - 1):
raise ValueError(
raise UnprocessableEntityError(
f"span question response value `start` must have a value lower than record field `{span_question_settings.field}` length that is `{field_len}`"
)

if value_item.end > field_len:
raise ValueError(
raise UnprocessableEntityError(
f"span question response value `end` must have a value lower or equal than record field `{span_question_settings.field}` length that is `{field_len}`"
)

Expand All @@ -252,7 +261,7 @@ def _validate_labels_are_available_at_question_settings(self, span_question_sett

for value_item in self._response_value:
if not value_item.label in available_labels:
raise ValueError(
raise UnprocessableEntityError(
f"undefined label '{value_item.label}' for span question.\nValid labels are: {available_labels!r}"
)

Expand All @@ -267,4 +276,6 @@ def _validate_values_are_not_overlapped(self, span_question_settings: SpanQuesti
and value_item.start < other_value_item.end
and value_item.end > other_value_item.start
):
raise ValueError(f"overlapping values found between spans at index idx={span_i} and idx={span_j}")
raise UnprocessableEntityError(
f"overlapping values found between spans at index idx={span_i} and idx={span_j}"
)
Loading

0 comments on commit b8847e9

Please sign in to comment.