Skip to content

Commit

Permalink
[REFACTOR] argilla: Align questions to Resource API (#5680)
Browse files Browse the repository at this point in the history
# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

Closes #4931

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Refactor (change restructuring the codebase without changing
functionality)

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm my changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: burtenshaw <[email protected]>
  • Loading branch information
frascuchon and burtenshaw authored Nov 21, 2024
1 parent 5f6c291 commit d6bc6f8
Show file tree
Hide file tree
Showing 25 changed files with 389 additions and 669 deletions.
64 changes: 15 additions & 49 deletions argilla/src/argilla/_api/_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,76 +18,56 @@
import httpx
from argilla._api._base import ResourceAPI
from argilla._exceptions import api_error_handler
from argilla._models import (
TextQuestionModel,
LabelQuestionModel,
MultiLabelQuestionModel,
RankingQuestionModel,
RatingQuestionModel,
SpanQuestionModel,
QuestionBaseModel,
QuestionModel,
)
from argilla._models import QuestionModel

__all__ = ["QuestionsAPI"]


class QuestionsAPI(ResourceAPI[QuestionBaseModel]):
class QuestionsAPI(ResourceAPI[QuestionModel]):
"""Manage datasets via the API"""

http_client: httpx.Client

_TYPE_TO_MODEL_CLASS = {
"text": TextQuestionModel,
"label_selection": LabelQuestionModel,
"multi_label_selection": MultiLabelQuestionModel,
"ranking": RankingQuestionModel,
"rating": RatingQuestionModel,
"span": SpanQuestionModel,
}

################
# CRUD methods #
################

@api_error_handler
def create(
self,
dataset_id: UUID,
question: QuestionModel,
) -> QuestionModel:
url = f"/api/v1/datasets/{dataset_id}/questions"
url = f"/api/v1/datasets/{question.dataset_id}/questions"
response = self.http_client.post(url=url, json=question.model_dump())
response.raise_for_status()
response_json = response.json()
question_model = self._model_from_json(response_json=response_json)
self._log_message(message=f"Created question {question_model.name} in dataset {dataset_id}")
self._log_message(message=f"Created question {question_model.name} in dataset {question.dataset_id}")
return question_model

@api_error_handler
def update(
self,
question: QuestionModel,
) -> QuestionModel:
# TODO: Implement update method for fields with server side ID
raise NotImplementedError
url = f"/api/v1/questions/{question.id}"
response = self.http_client.patch(url, json=question.model_dump())
response.raise_for_status()
response_json = response.json()
updated_question = self._model_from_json(response_json)
self._log_message(message=f"Update question {updated_question.name} with id {question.id}")
return updated_question

@api_error_handler
def delete(self, question_id: UUID) -> None:
# TODO: Implement delete method for fields with server side ID
raise NotImplementedError
url = f"/api/v1/questions/{question_id}"
self.http_client.delete(url).raise_for_status()
self._log_message(message=f"Deleted question with id {question_id}")

####################
# Utility methods #
####################

def create_many(self, dataset_id: UUID, questions: List[QuestionModel]) -> List[QuestionModel]:
response_models = []
for question in questions:
response_model = self.create(dataset_id=dataset_id, question=question)
response_models.append(response_model)
return response_models

@api_error_handler
def list(self, dataset_id: UUID) -> List[QuestionModel]:
response = self.http_client.get(f"/api/v1/datasets/{dataset_id}/questions")
Expand All @@ -103,21 +83,7 @@ def list(self, dataset_id: UUID) -> List[QuestionModel]:
def _model_from_json(self, response_json: Dict) -> QuestionModel:
response_json["inserted_at"] = self._date_from_iso_format(date=response_json["inserted_at"])
response_json["updated_at"] = self._date_from_iso_format(date=response_json["updated_at"])
return self._get_model_from_response(response_json=response_json)
return QuestionModel(**response_json)

def _model_from_jsons(self, response_jsons: List[Dict]) -> List[QuestionModel]:
    """Convert each JSON payload in ``response_jsons`` via ``_model_from_json``, preserving order."""
    return [self._model_from_json(payload) for payload in response_jsons]

def _get_model_from_response(self, response_json: Dict) -> QuestionModel:
"""Get the model from the response"""
try:
question_type = response_json.get("settings", {}).get("type")
except Exception as e:
raise ValueError("Invalid field type: missing 'settings.type' in response") from e

question_class = self._TYPE_TO_MODEL_CLASS.get(question_type)
if question_class is None:
self._log_message(message=f"Unknown question type: {question_type}")
question_class = QuestionBaseModel

return question_class(**response_json, check_fields=False)
25 changes: 17 additions & 8 deletions argilla/src/argilla/_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,14 @@
FieldSettings,
)
from argilla._models._settings._questions import (
LabelQuestionModel,
LabelQuestionSettings,
MultiLabelQuestionModel,
QuestionBaseModel,
QuestionModel,
QuestionSettings,
RankingQuestionModel,
RatingQuestionModel,
SpanQuestionModel,
SpanQuestionSettings,
TextQuestionModel,
TextQuestionSettings,
LabelQuestionSettings,
RatingQuestionSettings,
MultiLabelQuestionSettings,
RankingQuestionSettings,
)
from argilla._models._settings._metadata import (
MetadataFieldModel,
Expand All @@ -61,5 +57,18 @@
FloatMetadataPropertySettings,
IntegerMetadataPropertySettings,
)
from argilla._models._settings._questions import (
QuestionModel,
QuestionSettings,
LabelQuestionSettings,
RatingQuestionSettings,
TextQuestionSettings,
MultiLabelQuestionSettings,
RankingQuestionSettings,
SpanQuestionSettings,
)
from argilla._models._settings._vectors import VectorFieldModel

from argilla._models._user import UserModel, Role
from argilla._models._workspace import WorkspaceModel
from argilla._models._webhook import WebhookModel, EventType
164 changes: 164 additions & 0 deletions argilla/src/argilla/_models/_settings/_questions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright 2024-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Annotated, Union, Optional, ClassVar, List, Dict, Literal
from uuid import UUID

from pydantic import ConfigDict, field_validator, Field, BaseModel, model_validator, field_serializer
from pydantic_core.core_schema import ValidationInfo

from argilla._models import ResourceModel

try:
from typing import Self
except ImportError:
from typing_extensions import Self


class LabelQuestionSettings(BaseModel):
    """Settings payload for questions whose ``type`` is ``"label_selection"``."""

    # Lower bound enforced on ``visible_options`` (``ge=``) and the option
    # count from which ``visible_options`` is filled in automatically.
    _MIN_VISIBLE_OPTIONS: ClassVar[int] = 3

    type: Literal["label_selection"] = "label_selection"

    options: List[Dict[str, Optional[str]]] = Field(default_factory=list, validate_default=True)
    visible_options: Optional[int] = Field(None, validate_default=True, ge=_MIN_VISIBLE_OPTIONS)

    @field_validator("options", mode="before")
    @classmethod
    def __labels_are_unique(cls, options: List[Dict[str, Optional[str]]]) -> List[Dict[str, Optional[str]]]:
        """Reject option lists in which two entries share the same ``"value"``."""
        distinct_values = {entry["value"] for entry in options}
        if len(options) != len(distinct_values):
            raise ValueError("All labels must be unique")
        return options

    @model_validator(mode="after")
    def __validate_visible_options(self) -> "Self":
        """Default ``visible_options`` to the option count when it is unset and enough options exist."""
        current_options = self.options
        if self.visible_options is None and current_options:
            total = len(current_options)
            if total >= self._MIN_VISIBLE_OPTIONS:
                self.visible_options = total
        return self


class MultiLabelQuestionSettings(LabelQuestionSettings):
    """Label settings variant tagged ``"multi_label_selection"`` with an extra ``options_order`` field."""

    type: Literal["multi_label_selection"] = "multi_label_selection"
    options_order: Literal["natural", "suggestion"] = Field(
        default="natural",
        description="The order of the labels in the UI.",
    )


class RankingQuestionSettings(BaseModel):
    """Settings payload for questions whose ``type`` is ``"ranking"``."""

    type: Literal["ranking"] = "ranking"

    options: List[Dict[str, Optional[str]]] = Field(default_factory=list, validate_default=True)

    @field_validator("options", mode="before")
    @classmethod
    def __values_are_unique(cls, options: List[Dict[str, Optional[str]]]) -> List[Dict[str, Optional[str]]]:
        """Reject option lists in which two entries share the same ``"value"``."""
        distinct_values = {entry["value"] for entry in options}
        if len(options) != len(distinct_values):
            raise ValueError("All values must be unique")
        return options


class RatingQuestionSettings(BaseModel):
    """Settings payload for questions whose ``type`` is ``"rating"``; ``options`` is required."""

    type: Literal["rating"] = "rating"

    options: List[dict] = Field(..., validate_default=True)

    @field_validator("options", mode="before")
    @classmethod
    def __values_are_unique(cls, options: List[dict]) -> List[dict]:
        """Reject option lists in which two entries share the same ``"value"``."""
        distinct_values = {entry["value"] for entry in options}
        if len(options) != len(distinct_values):
            raise ValueError("All values must be unique")
        return options


class SpanQuestionSettings(BaseModel):
    """Settings payload for questions whose ``type`` is ``"span"``."""

    # Lower bound enforced on ``visible_options`` (``ge=``) and the option
    # count from which ``visible_options`` is filled in automatically.
    _MIN_VISIBLE_OPTIONS: ClassVar[int] = 3

    type: Literal["span"] = "span"

    allow_overlapping: bool = False
    field: Optional[str] = None
    options: List[Dict[str, Optional[str]]] = Field(default_factory=list, validate_default=True)
    visible_options: Optional[int] = Field(None, validate_default=True, ge=_MIN_VISIBLE_OPTIONS)

    @field_validator("options", mode="before")
    @classmethod
    def __values_are_unique(cls, options: List[Dict[str, Optional[str]]]) -> List[Dict[str, Optional[str]]]:
        """Reject option lists in which two entries share the same ``"value"``."""
        distinct_values = {entry["value"] for entry in options}
        if len(options) != len(distinct_values):
            raise ValueError("All values must be unique")
        return options

    @model_validator(mode="after")
    def __validate_visible_options(self) -> "Self":
        """Default ``visible_options`` to the option count when it is unset and enough options exist."""
        current_options = self.options
        if self.visible_options is None and current_options:
            total = len(current_options)
            if total >= self._MIN_VISIBLE_OPTIONS:
                self.visible_options = total
        return self


class TextQuestionSettings(BaseModel):
    """Settings payload for questions whose ``type`` is ``"text"``."""

    type: Literal["text"] = "text"

    use_markdown: bool = Field(default=False)


# Tagged union of every question settings model. ``discriminator="type"``
# makes pydantic choose the member from the payload's ``type`` value; each
# member declares a distinct ``Literal`` for that field.
QuestionSettings = Annotated[
Union[
LabelQuestionSettings,
MultiLabelQuestionSettings,
RankingQuestionSettings,
RatingQuestionSettings,
SpanQuestionSettings,
TextQuestionSettings,
],
Field(..., discriminator="type"),
]


class QuestionModel(ResourceModel):
    """Resource model describing a single question.

    ``settings`` holds the type-specific configuration and ``type`` exposes
    its ``type`` tag. ``title`` falls back to ``name`` when it is missing or
    empty. Assignments are re-validated (``validate_assignment=True``).
    """

    model_config = ConfigDict(validate_assignment=True)

    name: str
    settings: QuestionSettings

    # ``None`` default plus ``validate_default=True`` forces ``_title_default``
    # to run even when the caller supplies no title.
    title: str = Field(None, validate_default=True)
    description: Optional[str] = None
    required: bool = True

    dataset_id: Optional[UUID] = None

    @field_validator("title", mode="before")
    @classmethod
    def _title_default(cls, title, info: ValidationInfo):
        """Return ``title``, or the already-validated ``name`` when ``title`` is falsy.

        ``info.data`` only contains fields that validated successfully. If
        ``name`` is missing or invalid, ``.get`` yields ``None`` so the
        failure is reported by the regular ``str`` validation of ``title``
        (alongside the ``name`` error) instead of escaping as a ``KeyError``.
        """
        return title or info.data.get("name")

    @property
    def type(self) -> str:
        """The ``type`` tag of the attached settings object."""
        return self.settings.type

    @field_serializer("id", "dataset_id", when_used="unless-none")
    def serialize_id(self, value: UUID) -> str:
        """Serialize non-``None`` ``id`` / ``dataset_id`` values as strings."""
        return str(value)
35 changes: 0 additions & 35 deletions argilla/src/argilla/_models/_settings/_questions/__init__.py

This file was deleted.

Loading

0 comments on commit d6bc6f8

Please sign in to comment.