[ENHANCEMENT][REFACTOR] SDK: allow to remove settings (#5584)
# Description

This PR lets users change any of the settings before the dataset is
created, and the metadata and vector settings once the dataset has been created.

I included the `remove` method instead of a new
`settings.add(override=True)`, since the latter would require a lot of
internal refactoring with the current design.

One example is adjusting the settings inferred from the Hub
before creating the final dataset:

```python
import argilla as rg

settings = rg.Settings.from_hub("google/frames-benchmark")

# Drop the inferred "answer" field; it is re-added below as a question
settings.fields.remove("answer")

# Turn every "wiki*" field into a terms metadata property
for field in settings.fields:
    if field.name.startswith("wiki"):
        settings.fields.remove(field)
        settings.metadata.add(rg.TermsMetadataProperty(field.name))

settings.questions.add(rg.TextQuestion(name="answer", title="Answer"))

dataset = rg.Dataset.from_hub("google/frames-benchmark", settings=settings)
```

Or add new metadata or vector settings once the dataset has been created:
```python
# `client` is assumed to be a connected `rg.Argilla` instance
dataset = client.datasets("my-dataset")

dataset.settings.metadata.add(rg.TermsMetadataProperty(name="split"))
dataset.update()  # this line sends the change to the Argilla server

```
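The comment on `dataset.update()` matters: additions and removals are local until the update is sent. The sketch below models that behaviour with the standard library only. `FakeServer` and `PropertyCollection` are hypothetical stand-ins loosely modelled on the `_removed_properties` list this commit adds, not Argilla classes, and clearing the pending list after the flush is a choice of the sketch.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Set


@dataclass
class FakeServer:
    """Hypothetical in-memory stand-in for the Argilla server."""

    stored: Set[str] = field(default_factory=set)


@dataclass
class PropertyCollection:
    """Simplified model of a settings collection (not Argilla's real class)."""

    server: FakeServer
    _by_name: Dict[str, str] = field(default_factory=dict)
    _removed: List[str] = field(default_factory=list)

    def add(self, name: str) -> None:
        # Local change only; the server is untouched.
        self._by_name[name] = name

    def remove(self, name: str) -> None:
        # Local change only; remember the name so update() can delete it later.
        self._by_name.pop(name)
        self._removed.append(name)

    def update(self) -> None:
        # Push the current properties, then flush the pending deletions.
        self.server.stored.update(self._by_name)
        for name in self._removed:
            self.server.stored.discard(name)
        self._removed.clear()


server = FakeServer()
metadata = PropertyCollection(server)

metadata.add("split")
sent_before_update = "split" in server.stored  # nothing sent yet
metadata.update()
sent_after_update = "split" in server.stored

metadata.remove("split")
kept_before_update = "split" in server.stored  # still there until update()
metadata.update()
kept_after_update = "split" in server.stored
```

In this model a removal is only recorded locally, and the deletion happens on the next `update()` call.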

## Updated

You can remove a property by name:
```python
settings.fields.remove("text")
```

or by property instance:
```python
for field in settings.fields:
    if field.name.startswith("wiki"):
        settings.fields.remove(field)
```
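Removing inside a `for` loop is only safe if iteration works on a copy of the collection, which is what the `__iter__` change in this commit provides. Below is a minimal standard-library sketch of that idea. `Prop` and `Props` are hypothetical names used for illustration and are much simpler than Argilla's `SettingsProperties`.

```python
from typing import Dict, Iterator, List, Union


class Prop:
    """Hypothetical minimal property: it only has a name."""

    def __init__(self, name: str) -> None:
        self.name = name


class Props:
    """Simplified name-keyed collection (a sketch, not Argilla's class)."""

    def __init__(self, props: List[Prop]) -> None:
        self._by_name: Dict[str, Prop] = {p.name: p for p in props}

    def __iter__(self) -> Iterator[Prop]:
        # Iterate over a copy so callers may remove items inside the loop.
        return iter(list(self._by_name.values()))

    def __len__(self) -> int:
        return len(self._by_name)

    def remove(self, prop: Union[str, Prop]) -> Prop:
        # Accept either the property name or the property instance.
        name = prop if isinstance(prop, str) else prop.name
        return self._by_name.pop(name)


fields = Props([Prop("text"), Prop("wiki_1"), Prop("wiki_2"), Prop("answer")])

fields.remove("text")  # remove by name
for f in fields:
    if f.name.startswith("wiki"):
        fields.remove(f)  # remove by instance; safe thanks to the copy

remaining = [f.name for f in fields]
```

If `__iter__` returned `iter(self._by_name.values())` directly, the removal inside the loop would raise `RuntimeError` because the dictionary changes size during iteration.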

You can also override an existing property to change its type,
using the new `settings.add` method:
```python
settings = rg.Settings.from_hub(...)

# change some settings definitions before creating the dataset
# this removes the existing wikipedia_1 property and creates a new terms metadata one
settings.add(rg.TermsMetadataProperty("wikipedia_1"))
```
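The override rule treats property names as unique across fields, questions, vectors, and metadata together. The sketch below reproduces that rule with plain dictionaries. `SettingsSketch` is a hypothetical illustration: its `add` takes a group name rather than a property object, and it raises `ValueError` where the commit raises `SettingsError`.

```python
import warnings
from typing import Dict


class SettingsSketch:
    """Simplified model: each group maps a property name to a kind label."""

    def __init__(self) -> None:
        self.groups: Dict[str, Dict[str, str]] = {
            "fields": {},
            "questions": {},
            "vectors": {},
            "metadata": {},
        }

    def add(self, group: str, name: str, override: bool = True) -> None:
        # Names are unique across all groups, not just within one group.
        for other, props in self.groups.items():
            if name in props:
                message = f"Property with name {name!r} already exists in {other!r}"
                if not override:
                    raise ValueError(message)
                warnings.warn(message + ". Overriding the existing property.")
                del props[name]
        self.groups[group][name] = group


settings = SettingsSketch()
settings.add("fields", "wikipedia_1")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    settings.add("metadata", "wikipedia_1")  # replaces the field with metadata
```

With `override=False`, the same call raises instead of replacing the existing property.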

**Type of change**

- Refactor (change restructuring the codebase without changing
functionality)
- Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**

**Checklist**

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm my changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: burtenshaw <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Nov 6, 2024
1 parent 02df3d1 commit af2033c
Showing 6 changed files with 195 additions and 38 deletions.
27 changes: 26 additions & 1 deletion argilla/docs/reference/argilla/settings/settings.md
@@ -3,7 +3,7 @@ hide: footer
---
# `rg.Settings`

`rg.Settings` is used to define the setttings of an Argilla `Dataset`. The settings can be used to configure the
`rg.Settings` is used to define the settings of an Argilla `Dataset`. The settings can be used to configure the
behavior of the dataset, such as the fields, questions, guidelines, metadata, and vectors. The `Settings` class is
passed to the `Dataset` class and used to create the dataset on the server. Once created, the settings of a dataset
cannot be changed.
@@ -32,6 +32,31 @@ dataset.create()

To define the settings for fields, questions, metadata, vectors, or distribution, refer to the [`rg.TextField`](fields.md), [`rg.LabelQuestion`](questions.md), [`rg.TermsMetadataProperty`](metadata_property.md), and [`rg.VectorField`](vectors.md), [`rg.TaskDistribution`](task_distribution.md) class documentation.

### Adding or removing properties to settings

The settings object can be modified before creating the dataset by adding, replacing, or removing properties with
the `settings.add` and `settings.<>.remove` methods.

```python
import argilla as rg

settings = rg.Settings(
    guidelines="Select the sentiment of the prompt.",
    fields=[rg.TextField(name="prompt", use_markdown=True)],
    questions=[rg.LabelQuestion(name="sentiment", labels=["positive", "negative"])],
)

# Adding a new property
settings.add(rg.TextField(name="response", use_markdown=True))

# Replace an existing property by other property type
settings.add(rg.TextQuestion(name="response", use_markdown=False))

# Remove an existing property
settings.questions.remove("response")

```

### Creating settings using built in templates

Argilla provides built-in templates for creating settings for common dataset types. To use a template, use the class methods of the `Settings` class. There are three built-in templates available for classification, ranking, and rating tasks. Template settings also include default guidelines and mappings.
4 changes: 2 additions & 2 deletions argilla/src/argilla/records/_mapping/_mapper.py
@@ -20,7 +20,7 @@
from argilla._exceptions import RecordsIngestionError
from argilla.records._resource import Record
from argilla.responses import Response
from argilla.settings import AbstractField, VectorField
from argilla.settings import FieldBase, VectorField
from argilla.settings._metadata import MetadataPropertyBase
from argilla.settings._question import QuestionPropertyBase
from argilla.suggestions import Suggestion
@@ -184,7 +184,7 @@ def _select_attribute_type(self, attribute_route: AttributeRoute) -> AttributeRo
attribute_route.type = AttributeType.SUGGESTION
elif isinstance(schema_item, QuestionPropertyBase) and attribute_route.type == AttributeType.RESPONSE:
attribute_route.type = AttributeType.RESPONSE
elif isinstance(schema_item, AbstractField):
elif isinstance(schema_item, FieldBase):
attribute_route.type = AttributeType.FIELD
elif isinstance(schema_item, VectorField):
attribute_route.type = AttributeType.VECTOR
12 changes: 6 additions & 6 deletions argilla/src/argilla/settings/_field.py
@@ -40,10 +40,10 @@
if TYPE_CHECKING:
from argilla.datasets import Dataset

__all__ = ["Field", "AbstractField", "TextField", "ImageField", "ChatField", "CustomField"]
__all__ = ["Field", "FieldBase", "TextField", "ImageField", "ChatField", "CustomField"]


class AbstractField(ABC, SettingsPropertyBase):
class FieldBase(ABC, SettingsPropertyBase):
"""Abstract base class to work with Field resources"""

_model: FieldModel
@@ -96,7 +96,7 @@ def _with_client(self, client: "Argilla") -> "Self":
return self


class TextField(AbstractField):
class TextField(FieldBase):
"""Text field for use in Argilla `Dataset` `Settings`"""

def __init__(
@@ -136,7 +136,7 @@ def use_markdown(self, value: bool) -> None:
self._model.settings.use_markdown = value


class ImageField(AbstractField):
class ImageField(FieldBase):
"""Image field for use in Argilla `Dataset` `Settings`"""

def __init__(
@@ -167,7 +167,7 @@ def __init__(
)


class ChatField(AbstractField):
class ChatField(FieldBase):
"""Chat field for use in Argilla `Dataset` `Settings`"""

def __init__(
@@ -208,7 +208,7 @@ def use_markdown(self, value: bool) -> None:
self._model.settings.use_markdown = value


class CustomField(AbstractField):
class CustomField(FieldBase):
"""Custom field for use in Argilla `Dataset` `Settings`"""

def __init__(
94 changes: 79 additions & 15 deletions argilla/src/argilla/settings/_resource.py
@@ -14,6 +14,7 @@

import json
import os
import warnings
from functools import cached_property
from pathlib import Path
from typing import List, Optional, TYPE_CHECKING, Dict, Union, Iterator, Sequence, Literal
@@ -22,10 +23,10 @@
from argilla._exceptions import SettingsError, ArgillaAPIError, ArgillaSerializeError
from argilla._models._dataset import DatasetModel
from argilla._resource import Resource
from argilla.settings._field import Field, _field_from_dict, _field_from_model
from argilla.settings._field import Field, _field_from_dict, _field_from_model, FieldBase
from argilla.settings._io import build_settings_from_repo_id
from argilla.settings._metadata import MetadataType, MetadataField
from argilla.settings._question import QuestionType, question_from_model, question_from_dict
from argilla.settings._metadata import MetadataType, MetadataField, MetadataPropertyBase
from argilla.settings._question import QuestionType, question_from_model, question_from_dict, QuestionPropertyBase
from argilla.settings._task_distribution import TaskDistribution
from argilla.settings._templates import DefaultSettingsMixin
from argilla.settings._vector import VectorField
@@ -206,10 +207,10 @@ def create(self) -> "Settings":
self.validate()

self._update_dataset_related_attributes()
self.__fields.create()
self.__questions.create()
self.__vectors.create()
self.__metadata.create()
self.__fields._create()
self.__questions._create()
self.__vectors._create()
self.__metadata._create()

self._update_last_api_call()
return self
@@ -218,10 +219,10 @@ def update(self) -> "Resource":
self.validate()

self._update_dataset_related_attributes()
self.__fields.update()
self.__vectors.update()
self.__metadata.update()
# self.questions.update()
self.__fields._update()
self.__vectors._update()
self.__metadata._update()
self.__questions._update()

self._update_last_api_call()
return self
@@ -286,6 +287,43 @@ def __eq__(self, other: "Settings") -> bool:
return False
return self.serialize() == other.serialize() # TODO: Create proper __eq__ methods for fields and questions

    def add(
        self, property: Union[Field, VectorField, MetadataType, QuestionType], override: bool = True
    ) -> Union[Field, VectorField, MetadataType, QuestionType]:
        """
        Add a property to the settings
        Args:
            property: The property to add
            override: If True, override the existing property with the same name. Otherwise, raise an error. Defaults to True.
        Returns:
            The added property
        """
        # review all settings properties and remove any existing property with the same name
        for attributes in [self.fields, self.questions, self.vectors, self.metadata]:
            for prop in attributes:
                if prop.name == property.name:
                    message = f"Property with name {property.name!r} already exists in settings as {prop.__class__.__name__!r}"
                    if override:
                        warnings.warn(message + ". Overriding the existing property.")
                        attributes.remove(prop)
                    else:
                        raise SettingsError(message)

        if isinstance(property, FieldBase):
            self.fields.add(property)
        elif isinstance(property, QuestionPropertyBase):
            self.questions.add(property)
        elif isinstance(property, VectorField):
            self.vectors.add(property)
        elif isinstance(property, MetadataPropertyBase):
            self.metadata.add(property)
        else:
            raise ValueError(f"Unsupported property type: {type(property).__name__}")
        return property

#####################
# Repr Methods #
#####################
@@ -444,6 +482,7 @@ class SettingsProperties(Sequence[Property]):
def __init__(self, settings: "Settings", properties: List[Property]):
self._properties_by_name = {}
self._settings = settings
self._removed_properties = []

for property in properties or []:
if self._settings.dataset and hasattr(property, "dataset"):
@@ -461,7 +500,7 @@ def __getitem__(self, key: Union[UUID, str, int]) -> Optional[Property]:
return self._properties_by_name.get(key)

def __iter__(self) -> Iterator[Property]:
return iter(self._properties_by_name.values())
return iter([v for v in self._properties_by_name.values()])

def __len__(self):
return len(self._properties_by_name)
@@ -478,22 +517,41 @@ def add(self, property: Property) -> Property:
setattr(self, property.name, property)
return property

    def create(self):
    def remove(self, property: Union[str, Property]) -> None:
        if isinstance(property, str):
            property = self._properties_by_name.pop(property)
        else:
            property = self._properties_by_name.pop(property.name)

        if property:
            delattr(self, property.name)
            self._removed_properties.append(property)

    def _create(self):
        for property in self:
            try:
                property.dataset = self._settings.dataset
                property.create()
            except ArgillaAPIError as e:
                raise SettingsError(f"Failed to create property {property.name!r}: {e.message}") from e

    def update(self):
    def _update(self):
        for item in self:
            try:
                item.dataset = self._settings.dataset
                item.update() if item.id else item.create()
            except ArgillaAPIError as e:
                raise SettingsError(f"Failed to update {item.name!r}: {e.message}") from e

        self._delete()

    def _delete(self):
        for item in self._removed_properties:
            try:
                item.delete()
            except ArgillaAPIError as e:
                raise SettingsError(f"Failed to delete {item.name!r}: {e.message}") from e

def serialize(self) -> List[dict]:
return [property.serialize() for property in self]

@@ -520,13 +578,19 @@ class to work with questions as we do with fields, vectors, or metadata (special
Once issue https://github.com/argilla-io/argilla/issues/4931 is tackled, this class should be removed.
"""

    def create(self):
    def _create(self):
        for question in self:
            try:
                self._create_question(question)
            except ArgillaAPIError as e:
                raise SettingsError(f"Failed to create question {question.name}") from e

    def _update(self):
        pass

    def _delete(self):
        pass

def _create_question(self, question: QuestionType) -> None:
question_model = self._settings._client.api.questions.create(
dataset_id=self._settings.dataset.id,
30 changes: 27 additions & 3 deletions argilla/tests/integration/test_update_dataset_settings.py
@@ -12,11 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid

import pytest

from argilla import Dataset, Settings, TextField, ChatField, LabelQuestion, Argilla, VectorField, FloatMetadataProperty
from argilla import (
    Dataset,
    Settings,
    TextField,
    ChatField,
    LabelQuestion,
    Argilla,
    VectorField,
    FloatMetadataProperty,
    TermsMetadataProperty,
)


@pytest.fixture
@@ -62,3 +70,19 @@ def test_update_distribution_settings(self, client: Argilla, dataset: Dataset):

dataset = client.datasets(dataset.name)
assert dataset.settings.distribution.min_submitted == 100

    def test_remove_settings_property(self, client: Argilla, dataset: Dataset):
        dataset.settings.metadata.add(TermsMetadataProperty(name="metadata"))
        dataset.settings.vectors.add(VectorField(name="vector", dimensions=10))
        dataset.update()

        assert isinstance(dataset.settings.metadata["metadata"], TermsMetadataProperty)
        assert isinstance(dataset.settings.vectors["vector"], VectorField)

        dataset.settings.metadata.remove("metadata")
        dataset.settings.vectors.remove("vector")

        dataset.update()

        assert dataset.settings.metadata["metadata"] is None
        assert dataset.settings.vectors["vector"] is None
66 changes: 55 additions & 11 deletions argilla/tests/unit/test_settings/test_settings.py
@@ -205,14 +205,58 @@ def test_read_settings_without_distribution(self, mocker: "MockerFixture"):
settings.get()
assert settings.distribution == TaskDistribution.default()

class TestSettingsSerialization:
def test_serialize(self):
settings = rg.Settings(
guidelines="This is a guideline",
fields=[rg.TextField(name="prompt", use_markdown=True)],
questions=[rg.LabelQuestion(name="sentiment", labels=["positive", "negative"])],
)
settings_serialized = settings.serialize()
assert settings_serialized["guidelines"] == "This is a guideline"
assert settings_serialized["fields"][0]["name"] == "prompt"
assert settings_serialized["fields"][0]["settings"]["use_markdown"] is True
    def test_serialize(self):
        settings = rg.Settings(
            guidelines="This is a guideline",
            fields=[rg.TextField(name="prompt", use_markdown=True)],
            questions=[rg.LabelQuestion(name="sentiment", labels=["positive", "negative"])],
        )
        settings_serialized = settings.serialize()
        assert settings_serialized["guidelines"] == "This is a guideline"
        assert settings_serialized["fields"][0]["name"] == "prompt"
        assert settings_serialized["fields"][0]["settings"]["use_markdown"] is True

    def test_remove_property_from_settings(self):
        settings = rg.Settings(
            fields=[rg.TextField(name="text", title="text")],
            questions=[rg.LabelQuestion(name="label", title="text", labels=["positive", "negative"])],
            metadata=[rg.FloatMetadataProperty("source")],
            vectors=[rg.VectorField(name="vector", dimensions=3)],
        )

        settings.fields.remove("text")
        assert len(settings.fields) == 0

        settings.questions.remove("label")
        assert len(settings.questions) == 0

        settings.metadata.remove("source")
        assert len(settings.metadata) == 0

        settings.vectors.remove("vector")
        assert len(settings.vectors) == 0

    def test_adding_properties_with_override_enabled(self):
        settings = rg.Settings()

        settings.add(rg.TextField(name="text", title="text"))
        assert len(settings.fields) == 1

        settings.add(rg.TextQuestion(name="question", title="question"))
        assert len(settings.questions) == 1

        settings.add(rg.FloatMetadataProperty(name="text"), override=True)
        assert len(settings.metadata) == 1
        assert len(settings.fields) == 0

    def test_adding_properties_with_disabled_override(self):
        settings = rg.Settings()

        settings.add(rg.TextField(name="text", title="text"))
        assert len(settings.fields) == 1

        settings.add(rg.TextQuestion(name="question", title="question"))
        assert len(settings.questions) == 1

        with pytest.raises(SettingsError, match="Property with name 'text' already exists"):
            settings.add(rg.FloatMetadataProperty(name="text"), override=False)
