diff --git a/argilla/docs/reference/argilla/settings/settings.md b/argilla/docs/reference/argilla/settings/settings.md index 4177d8e165..62261487a7 100644 --- a/argilla/docs/reference/argilla/settings/settings.md +++ b/argilla/docs/reference/argilla/settings/settings.md @@ -3,7 +3,7 @@ hide: footer --- # `rg.Settings` -`rg.Settings` is used to define the setttings of an Argilla `Dataset`. The settings can be used to configure the +`rg.Settings` is used to define the settings of an Argilla `Dataset`. The settings can be used to configure the behavior of the dataset, such as the fields, questions, guidelines, metadata, and vectors. The `Settings` class is passed to the `Dataset` class and used to create the dataset on the server. Once created, the settings of a dataset cannot be changed. @@ -32,6 +32,31 @@ dataset.create() To define the settings for fields, questions, metadata, vectors, or distribution, refer to the [`rg.TextField`](fields.md), [`rg.LabelQuestion`](questions.md), [`rg.TermsMetadataProperty`](metadata_property.md), and [`rg.VectorField`](vectors.md), [`rg.TaskDistribution`](task_distribution.md) class documentation. +### Adding or removing properties to settings + +The settings object can be modified before create the dataset by adding, replacing or removing properties by using +the method `settings.add` and `settings.<>.remove` + +```python +import argilla as rg + +settings = rg.Settings( + guidelines="Select the sentiment of the prompt.", + fields=[rg.TextField(name="prompt", use_markdown=True)], + questions=[rg.LabelQuestion(name="sentiment", labels=["positive", "negative"])], +) + +# Adding a new property +settings.add(rg.TextField(name="response", use_markdown=True)) + +# Replace an existing property by other property type +settings.add(rg.TextQuestion(name="response", use_markdown=False)) + +# Remove an existing property +settings.questions.remove("response") + +``` + ### Creating settings using built in templates Argilla provides built-in templates for creating settings for common dataset types. To use a template, use the class methods of the `Settings` class. There are three built-in templates available for classification, ranking, and rating tasks. Template settings also include default guidelines and mappings. diff --git a/argilla/src/argilla/records/_mapping/_mapper.py b/argilla/src/argilla/records/_mapping/_mapper.py index 9c39ae297c..a4c4a398a8 100644 --- a/argilla/src/argilla/records/_mapping/_mapper.py +++ b/argilla/src/argilla/records/_mapping/_mapper.py @@ -20,7 +20,7 @@ from argilla._exceptions import RecordsIngestionError from argilla.records._resource import Record from argilla.responses import Response -from argilla.settings import AbstractField, VectorField +from argilla.settings import FieldBase, VectorField from argilla.settings._metadata import MetadataPropertyBase from argilla.settings._question import QuestionPropertyBase from argilla.suggestions import Suggestion @@ -184,7 +184,7 @@ def _select_attribute_type(self, attribute_route: AttributeRoute) -> AttributeRo attribute_route.type = AttributeType.SUGGESTION elif isinstance(schema_item, QuestionPropertyBase) and attribute_route.type == AttributeType.RESPONSE: attribute_route.type = AttributeType.RESPONSE - elif isinstance(schema_item, AbstractField): + elif isinstance(schema_item, FieldBase): attribute_route.type = AttributeType.FIELD elif isinstance(schema_item, VectorField): attribute_route.type = AttributeType.VECTOR diff --git a/argilla/src/argilla/settings/_field.py b/argilla/src/argilla/settings/_field.py index 20bdd7db40..3f39c1c1f7 100644 --- a/argilla/src/argilla/settings/_field.py +++ b/argilla/src/argilla/settings/_field.py @@ -40,10 +40,10 @@ if TYPE_CHECKING: from argilla.datasets import Dataset -__all__ = ["Field", "AbstractField", "TextField", "ImageField", "ChatField", "CustomField"] +__all__ = ["Field", "FieldBase", "TextField", "ImageField", "ChatField", "CustomField"] -class AbstractField(ABC, SettingsPropertyBase): +class FieldBase(ABC, SettingsPropertyBase): """Abstract base class to work with Field resources""" _model: FieldModel @@ -96,7 +96,7 @@ def _with_client(self, client: "Argilla") -> "Self": return self -class TextField(AbstractField): +class TextField(FieldBase): """Text field for use in Argilla `Dataset` `Settings`""" def __init__( @@ -136,7 +136,7 @@ def use_markdown(self, value: bool) -> None: self._model.settings.use_markdown = value -class ImageField(AbstractField): +class ImageField(FieldBase): """Image field for use in Argilla `Dataset` `Settings`""" def __init__( @@ -167,7 +167,7 @@ def __init__( ) -class ChatField(AbstractField): +class ChatField(FieldBase): """Chat field for use in Argilla `Dataset` `Settings`""" def __init__( @@ -208,7 +208,7 @@ def use_markdown(self, value: bool) -> None: self._model.settings.use_markdown = value -class CustomField(AbstractField): +class CustomField(FieldBase): """Custom field for use in Argilla `Dataset` `Settings`""" def __init__( diff --git a/argilla/src/argilla/settings/_resource.py b/argilla/src/argilla/settings/_resource.py index 3ad3824157..97ced197c9 100644 --- a/argilla/src/argilla/settings/_resource.py +++ b/argilla/src/argilla/settings/_resource.py @@ -14,6 +14,7 @@ import json import os +import warnings from functools import cached_property from pathlib import Path from typing import List, Optional, TYPE_CHECKING, Dict, Union, Iterator, Sequence, Literal @@ -22,10 +23,10 @@ from argilla._exceptions import SettingsError, ArgillaAPIError, ArgillaSerializeError from argilla._models._dataset import DatasetModel from argilla._resource import Resource -from argilla.settings._field import Field, _field_from_dict, _field_from_model +from argilla.settings._field import Field, _field_from_dict, _field_from_model, FieldBase from argilla.settings._io import build_settings_from_repo_id -from argilla.settings._metadata import MetadataType, MetadataField -from argilla.settings._question import QuestionType, question_from_model, question_from_dict +from argilla.settings._metadata import MetadataType, MetadataField, MetadataPropertyBase +from argilla.settings._question import QuestionType, question_from_model, question_from_dict, QuestionPropertyBase from argilla.settings._task_distribution import TaskDistribution from argilla.settings._templates import DefaultSettingsMixin from argilla.settings._vector import VectorField @@ -206,10 +207,10 @@ def create(self) -> "Settings": self.validate() self._update_dataset_related_attributes() - self.__fields.create() - self.__questions.create() - self.__vectors.create() - self.__metadata.create() + self.__fields._create() + self.__questions._create() + self.__vectors._create() + self.__metadata._create() self._update_last_api_call() return self @@ -218,10 +219,10 @@ def update(self) -> "Resource": self.validate() self._update_dataset_related_attributes() - self.__fields.update() - self.__vectors.update() - self.__metadata.update() - # self.questions.update() + self.__fields._update() + self.__vectors._update() + self.__metadata._update() + self.__questions._update() self._update_last_api_call() return self @@ -286,6 +287,43 @@ def __eq__(self, other: "Settings") -> bool: return False return self.serialize() == other.serialize() # TODO: Create proper __eq__ methods for fields and questions + def add( + self, property: Union[Field, VectorField, MetadataType, QuestionType], override: bool = True + ) -> Union[Field, VectorField, MetadataType, QuestionType]: + """ + Add a property to the settings + + Args: + property: The property to add + override: If True, override the existing property with the same name. Otherwise, raise an error. Defaults to True. + + Returns: + The added property + + """ + # review all settings properties and remove any existing property with the same name + for attributes in [self.fields, self.questions, self.vectors, self.metadata]: + for prop in attributes: + if prop.name == property.name: + message = f"Property with name {property.name!r} already exists in settings as {prop.__class__.__name__!r}" + if override: + warnings.warn(message + ". Overriding the existing property.") + attributes.remove(prop) + else: + raise SettingsError(message) + + if isinstance(property, FieldBase): + self.fields.add(property) + elif isinstance(property, QuestionPropertyBase): + self.questions.add(property) + elif isinstance(property, VectorField): + self.vectors.add(property) + elif isinstance(property, MetadataPropertyBase): + self.metadata.add(property) + else: + raise ValueError(f"Unsupported property type: {type(property).__name__}") + return property + ##################### # Repr Methods # ##################### @@ -444,6 +482,7 @@ class SettingsProperties(Sequence[Property]): def __init__(self, settings: "Settings", properties: List[Property]): self._properties_by_name = {} self._settings = settings + self._removed_properties = [] for property in properties or []: if self._settings.dataset and hasattr(property, "dataset"): @@ -461,7 +500,7 @@ def __getitem__(self, key: Union[UUID, str, int]) -> Optional[Property]: return self._properties_by_name.get(key) def __iter__(self) -> Iterator[Property]: - return iter(self._properties_by_name.values()) + return iter([v for v in self._properties_by_name.values()]) def __len__(self): return len(self._properties_by_name) @@ -478,7 +517,17 @@ def add(self, property: Property) -> Property: setattr(self, property.name, property) return property - def create(self): + def remove(self, property: Union[str, Property]) -> None: + if isinstance(property, str): + property = self._properties_by_name.pop(property) + else: + property = self._properties_by_name.pop(property.name) + + if property: + delattr(self, property.name) + self._removed_properties.append(property) + + def _create(self): for property in self: try: property.dataset = self._settings.dataset @@ -486,7 +535,7 @@ def create(self): except ArgillaAPIError as e: raise SettingsError(f"Failed to create property {property.name!r}: {e.message}") from e - def update(self): + def _update(self): for item in self: try: item.dataset = self._settings.dataset @@ -494,6 +543,15 @@ def update(self): except ArgillaAPIError as e: raise SettingsError(f"Failed to update {item.name!r}: {e.message}") from e + self._delete() + + def _delete(self): + for item in self._removed_properties: + try: + item.delete() + except ArgillaAPIError as e: + raise SettingsError(f"Failed to delete {item.name!r}: {e.message}") from e + def serialize(self) -> List[dict]: return [property.serialize() for property in self] @@ -520,13 +578,19 @@ class to work with questions as we do with fields, vectors, or metadata (special Once issue https://github.com/argilla-io/argilla/issues/4931 is tackled, this class should be removed. """ - def create(self): + def _create(self): for question in self: try: self._create_question(question) except ArgillaAPIError as e: raise SettingsError(f"Failed to create question {question.name}") from e + def _update(self): + pass + + def _delete(self): + pass + def _create_question(self, question: QuestionType) -> None: question_model = self._settings._client.api.questions.create( dataset_id=self._settings.dataset.id, diff --git a/argilla/tests/integration/test_update_dataset_settings.py b/argilla/tests/integration/test_update_dataset_settings.py index b71e985151..5ec1883fba 100644 --- a/argilla/tests/integration/test_update_dataset_settings.py +++ b/argilla/tests/integration/test_update_dataset_settings.py @@ -12,11 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import uuid - import pytest -from argilla import Dataset, Settings, TextField, ChatField, LabelQuestion, Argilla, VectorField, FloatMetadataProperty +from argilla import ( + Dataset, + Settings, + TextField, + ChatField, + LabelQuestion, + Argilla, + VectorField, + FloatMetadataProperty, + TermsMetadataProperty, +) @pytest.fixture @@ -62,3 +70,19 @@ def test_update_distribution_settings(self, client: Argilla, dataset: Dataset): dataset = client.datasets(dataset.name) assert dataset.settings.distribution.min_submitted == 100 + + def test_remove_settings_property(self, client: Argilla, dataset: Dataset): + dataset.settings.metadata.add(TermsMetadataProperty(name="metadata")) + dataset.settings.vectors.add(VectorField(name="vector", dimensions=10)) + dataset.update() + + assert isinstance(dataset.settings.metadata["metadata"], TermsMetadataProperty) + assert isinstance(dataset.settings.vectors["vector"], VectorField) + + dataset.settings.metadata.remove("metadata") + dataset.settings.vectors.remove("vector") + + dataset.update() + + assert dataset.settings.metadata["metadata"] is None + assert dataset.settings.vectors["vector"] is None diff --git a/argilla/tests/unit/test_settings/test_settings.py b/argilla/tests/unit/test_settings/test_settings.py index f2829d3848..73f69883b9 100644 --- a/argilla/tests/unit/test_settings/test_settings.py +++ b/argilla/tests/unit/test_settings/test_settings.py @@ -205,14 +205,58 @@ def test_read_settings_without_distribution(self, mocker: "MockerFixture"): settings.get() assert settings.distribution == TaskDistribution.default() - class TestSettingsSerialization: - def test_serialize(self): - settings = rg.Settings( - guidelines="This is a guideline", - fields=[rg.TextField(name="prompt", use_markdown=True)], - questions=[rg.LabelQuestion(name="sentiment", labels=["positive", "negative"])], - ) - settings_serialized = settings.serialize() - assert settings_serialized["guidelines"] == "This is a guideline" - assert settings_serialized["fields"][0]["name"] == "prompt" - assert settings_serialized["fields"][0]["settings"]["use_markdown"] is True + def test_serialize(self): + settings = rg.Settings( + guidelines="This is a guideline", + fields=[rg.TextField(name="prompt", use_markdown=True)], + questions=[rg.LabelQuestion(name="sentiment", labels=["positive", "negative"])], + ) + settings_serialized = settings.serialize() + assert settings_serialized["guidelines"] == "This is a guideline" + assert settings_serialized["fields"][0]["name"] == "prompt" + assert settings_serialized["fields"][0]["settings"]["use_markdown"] is True + + def test_remove_property_from_settings(self): + settings = rg.Settings( + fields=[rg.TextField(name="text", title="text")], + questions=[rg.LabelQuestion(name="label", title="text", labels=["positive", "negative"])], + metadata=[rg.FloatMetadataProperty("source")], + vectors=[rg.VectorField(name="vector", dimensions=3)], + ) + + settings.fields.remove("text") + assert len(settings.fields) == 0 + + settings.questions.remove("label") + assert len(settings.questions) == 0 + + settings.metadata.remove("source") + assert len(settings.metadata) == 0 + + settings.vectors.remove("vector") + assert len(settings.vectors) == 0 + + def test_adding_properties_with_override_enabled(self): + settings = rg.Settings() + + settings.add(rg.TextField(name="text", title="text")) + assert len(settings.fields) == 1 + + settings.add(rg.TextQuestion(name="question", title="question")) + assert len(settings.questions) == 1 + + settings.add(rg.FloatMetadataProperty(name="text"), override=True) + assert len(settings.metadata) == 1 + assert len(settings.fields) == 0 + + def test_adding_properties_with_disabled_override(self): + settings = rg.Settings() + + settings.add(rg.TextField(name="text", title="text")) + assert len(settings.fields) == 1 + + settings.add(rg.TextQuestion(name="question", title="question")) + assert len(settings.questions) == 1 + + with pytest.raises(SettingsError, match="Property with name 'text' already exists"): + settings.add(rg.FloatMetadataProperty(name="text"), override=False)