From 6452993d6da0e3539814a8b1ecab32ea3eac43b1 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Tue, 18 Jun 2024 12:17:26 +0200 Subject: [PATCH 1/3] [ENHANCEMENT] Argilla SDK: Updating record fields and vectors (#5026) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description This PR reviews the record attributes and normalizes how to work with fields, vectors, and metadata. Now all three are treated as dictionaries, and users can update them the same way they would update a dictionary, or by setting new attributes: ```python record = Record(fields={"name": "John"}) record.fields.update({"name": "Jane", "age": "30"}) record.fields.new_field = "value" record.vectors["new-vector"] = [1.0, 2.0, 3.0] record.vectors.vector = [1.0, 2.0, 3.0] record.metadata["new-key"] = "new_value" record.metadata.key = "new_value" ``` Once this approach is approved, I will create a new PR changing the docs. **Type of change** (Please delete options that are not relevant. Remember to title the PR according to the type of change) - [ ] New feature (non-breaking change which adds functionality) - [x] Refactor (change restructuring the codebase without changing functionality) - [x] Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** (Please describe the tests that you ran to verify your changes. 
And ideally, reference `tests`) - [ ] Test A - [ ] Test B **Checklist** - [ ] I added relevant documentation - [ ] I followed the style guidelines of this project - [ ] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [ ] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [ ] I have added relevant notes to the `CHANGELOG.md` file (See https://keepachangelog.com/) --- argilla/src/argilla/_models/__init__.py | 4 +- .../src/argilla/_models/_record/_record.py | 9 +- .../src/argilla/_models/_record/_vector.py | 10 +- .../src/argilla/records/_dataset_records.py | 18 +- argilla/src/argilla/records/_resource.py | 158 +++++++++--------- .../tests/integration/test_export_dataset.py | 6 +- .../tests/integration/test_export_records.py | 6 +- argilla/tests/integration/test_metadata.py | 4 +- ...est_record_export_import_compatibillity.py | 17 +- .../tests/unit/test_resources/test_records.py | 43 ++++- 10 files changed, 154 insertions(+), 121 deletions(-) diff --git a/argilla/src/argilla/_models/__init__.py b/argilla/src/argilla/_models/__init__.py index 3e0421d65f..6d318d084f 100644 --- a/argilla/src/argilla/_models/__init__.py +++ b/argilla/src/argilla/_models/__init__.py @@ -18,10 +18,10 @@ from argilla._models._workspace import WorkspaceModel from argilla._models._user import UserModel, Role from argilla._models._dataset import DatasetModel -from argilla._models._record._record import RecordModel +from argilla._models._record._record import RecordModel, FieldValue from argilla._models._record._suggestion import SuggestionModel from argilla._models._record._response import UserResponseModel, ResponseStatus -from argilla._models._record._vector import VectorModel +from argilla._models._record._vector import VectorModel, VectorValue from argilla._models._record._metadata import 
MetadataModel, MetadataValue from argilla._models._search import ( SearchQueryModel, diff --git a/argilla/src/argilla/_models/_record/_record.py b/argilla/src/argilla/_models/_record/_record.py index 4668163543..0286dc1c12 100644 --- a/argilla/src/argilla/_models/_record/_record.py +++ b/argilla/src/argilla/_models/_record/_record.py @@ -16,17 +16,20 @@ from pydantic import Field, field_serializer, field_validator -from argilla._models._resource import ResourceModel from argilla._models._record._metadata import MetadataModel, MetadataValue from argilla._models._record._response import UserResponseModel from argilla._models._record._suggestion import SuggestionModel from argilla._models._record._vector import VectorModel +from argilla._models._resource import ResourceModel + +__all__ = ["RecordModel", "FieldValue"] +FieldValue = Union[str, None] class RecordModel(ResourceModel): """Schema for the records of a `Dataset`""" - fields: Optional[Dict[str, Union[str, None]]] = None + fields: Optional[Dict[str, FieldValue]] = None metadata: Optional[Union[List[MetadataModel], Dict[str, MetadataValue]]] = Field(default_factory=dict) vectors: Optional[List[VectorModel]] = Field(default_factory=list) responses: Optional[List[UserResponseModel]] = Field(default_factory=list) @@ -49,7 +52,7 @@ def serialize_metadata(self, value: List[MetadataModel]) -> Dict[str, Any]: return {metadata.name: metadata.value for metadata in value} @field_serializer("fields", when_used="always") - def serialize_empty_fields(self, value: Dict[str, Union[str, None]]) -> Dict[str, Union[str, None]]: + def serialize_empty_fields(self, value: Dict[str, Union[str, None]]) -> Optional[Dict[str, Union[str, None]]]: """Serialize empty fields to None.""" if isinstance(value, dict) and len(value) == 0: return None diff --git a/argilla/src/argilla/_models/_record/_vector.py b/argilla/src/argilla/_models/_record/_vector.py index 9efcb1ae75..d1c6f2e250 100644 --- a/argilla/src/argilla/_models/_record/_vector.py 
+++ b/argilla/src/argilla/_models/_record/_vector.py @@ -11,19 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re from typing import List +from pydantic import field_validator + from argilla._models import ResourceModel -import re -from pydantic import field_validator +__all__ = ["VectorModel", "VectorValue"] -__all__ = ["VectorModel"] +VectorValue = List[float] class VectorModel(ResourceModel): name: str - vector_values: List[float] + vector_values: VectorValue @field_validator("name") @classmethod diff --git a/argilla/src/argilla/records/_dataset_records.py b/argilla/src/argilla/records/_dataset_records.py index 70ddf81029..10332fbd3b 100644 --- a/argilla/src/argilla/records/_dataset_records.py +++ b/argilla/src/argilla/records/_dataset_records.py @@ -21,7 +21,7 @@ from argilla._api import RecordsAPI from argilla._helpers import LoggingMixin -from argilla._models import RecordModel, MetadataValue +from argilla._models import RecordModel, MetadataValue, VectorValue, FieldValue from argilla.client import Argilla from argilla.records._io import GenericIO, HFDataset, HFDatasetsIO, JsonIO from argilla.records._resource import Record @@ -405,13 +405,15 @@ def _infer_record_from_mapping( Returns: A Record object. 
""" - fields: Dict[str, str] = {} - responses: List[Response] = [] record_id: Optional[str] = None - suggestion_values = defaultdict(dict) - vectors: List[Vector] = [] + + fields: Dict[str, FieldValue] = {} + vectors: Dict[str, VectorValue] = {} metadata: Dict[str, MetadataValue] = {} + responses: List[Response] = [] + suggestion_values: Dict[str, dict] = defaultdict(dict) + schema = self.__dataset.schema for attribute, value in data.items(): @@ -466,7 +468,7 @@ def _infer_record_from_mapping( {"value": value, "question_name": attribute, "question_id": schema_item.id} ) elif isinstance(schema_item, VectorField): - vectors.append(Vector(name=attribute, values=value)) + vectors[attribute] = value elif isinstance(schema_item, MetadataPropertyBase): metadata[attribute] = value else: @@ -478,9 +480,9 @@ def _infer_record_from_mapping( return Record( id=record_id, fields=fields, - suggestions=suggestions, - responses=responses, vectors=vectors, metadata=metadata, + suggestions=suggestions, + responses=responses, _dataset=self.__dataset, ) diff --git a/argilla/src/argilla/records/_resource.py b/argilla/src/argilla/records/_resource.py index 4c20c463b6..fb4c322c52 100644 --- a/argilla/src/argilla/records/_resource.py +++ b/argilla/src/argilla/records/_resource.py @@ -23,6 +23,8 @@ SuggestionModel, VectorModel, MetadataValue, + FieldValue, + VectorValue, ) from argilla._resource import Resource from argilla.responses import Response, UserResponse @@ -54,9 +56,9 @@ class Record(Resource): def __init__( self, id: Optional[Union[UUID, str]] = None, - fields: Optional[Dict[str, Union[str, None]]] = None, + fields: Optional[Dict[str, FieldValue]] = None, metadata: Optional[Dict[str, MetadataValue]] = None, - vectors: Optional[List[Vector]] = None, + vectors: Optional[Dict[str, VectorValue]] = None, responses: Optional[List[Response]] = None, suggestions: Optional[List[Suggestion]] = None, _server_id: Optional[UUID] = None, @@ -93,7 +95,7 @@ def __init__( # Initialize the fields 
self.__fields = RecordFields(fields=self._model.fields) # Initialize the vectors - self.__vectors = RecordVectors(vectors=vectors, record=self) + self.__vectors = RecordVectors(vectors=vectors) # Initialize the metadata self.__metadata = RecordMetadata(metadata=metadata) self.__responses = RecordResponses(responses=responses, record=self) @@ -158,8 +160,8 @@ def api_model(self) -> RecordModel: id=self._model.id, external_id=self._model.external_id, fields=self.fields.to_dict(), - metadata=self.metadata.models, - vectors=self.vectors.models, + metadata=self.metadata.api_models(), + vectors=self.vectors.api_models(), responses=self.responses.api_models(), suggestions=self.suggestions.api_models(), ) @@ -181,19 +183,22 @@ def to_dict(self) -> Dict[str, Dict]: represented as a key-value pair in the dictionary of the respective key. i.e. `{"fields": {"prompt": "...", "response": "..."}, "responses": {"rating": "..."}, """ + id = str(self.id) if self.id else None + server_id = str(self._model.id) if self._model.id else None fields = self.fields.to_dict() - metadata = dict(self.metadata) + metadata = self.metadata.to_dict() suggestions = self.suggestions.to_dict() responses = self.responses.to_dict() vectors = self.vectors.to_dict() + return { - "id": self.id, + "id": id, "fields": fields, "metadata": metadata, "suggestions": suggestions, "responses": responses, "vectors": vectors, - "_server_id": str(self._model.id) if self._model.id else None, + "_server_id": server_id, } @classmethod @@ -219,7 +224,6 @@ def from_dict(cls, data: Dict[str, Dict], dataset: Optional["Dataset"] = None) - for question_name, _responses in responses.items() for value in _responses ] - vectors = [Vector(name=vector_name, values=values) for vector_name, values in vectors.items()] return cls( id=record_id, @@ -245,7 +249,7 @@ def from_model(cls, model: RecordModel, dataset: "Dataset") -> "Record": id=model.external_id, fields=model.fields, metadata={meta.name: meta.value for meta in 
model.metadata}, - vectors=[Vector.from_model(model=vector) for vector in model.vectors], + vectors={vector.name: vector.vector_values for vector in model.vectors}, # Responses and their models are not aligned 1-1. responses=[ response @@ -258,27 +262,62 @@ def from_model(cls, model: RecordModel, dataset: "Dataset") -> "Record": ) -class RecordFields: +class RecordFields(dict): """This is a container class for the fields of a Record. - It allows for accessing fields by attribute and iterating over them. + It allows for accessing fields by attribute and key name. """ - def __init__(self, fields: Dict[str, Union[str, None]]) -> None: - self.__fields = fields or {} - for key, value in self.__fields.items(): - setattr(self, key, value) + def __init__(self, fields: Optional[Dict[str, FieldValue]] = None) -> None: + super().__init__(fields or {}) - def __getitem__(self, key: str) -> Optional[str]: - return self.__fields.get(key) + def __getattr__(self, item: str): + return self[item] - def __iter__(self): - return iter(self.__fields) + def __setattr__(self, key: str, value: MetadataValue): + self[key] = value - def to_dict(self) -> Dict[str, Union[str, None]]: - return self.__fields + def to_dict(self) -> dict: + return dict(self.items()) - def __repr__(self) -> str: - return self.to_dict().__repr__() + +class RecordMetadata(dict): + """This is a container class for the metadata of a Record.""" + + def __init__(self, metadata: Optional[Dict[str, MetadataValue]] = None) -> None: + super().__init__(metadata or {}) + + def __getattr__(self, item: str): + return self[item] + + def __setattr__(self, key: str, value: MetadataValue): + self[key] = value + + def to_dict(self) -> dict: + return dict(self.items()) + + def api_models(self) -> List[MetadataModel]: + return [MetadataModel(name=key, value=value) for key, value in self.items()] + + +class RecordVectors(dict): + """This is a container class for the vectors of a Record. 
+ It allows for accessing suggestions by attribute and key name. + """ + + def __init__(self, vectors: Dict[str, VectorValue]) -> None: + super().__init__(vectors or {}) + + def __getattr__(self, item: str): + return self[item] + + def __setattr__(self, key: str, value: VectorValue): + self[key] = value + + def to_dict(self) -> Dict[str, List[float]]: + return dict(self.items()) + + def api_models(self) -> List[VectorModel]: + return [Vector(name=name, values=value).api_model() for name, value in self.items()] class RecordResponses(Iterable[Response]): @@ -309,6 +348,16 @@ def __getattr__(self, name) -> List[Response]: def __repr__(self) -> str: return {k: [{"value": v["value"]} for v in values] for k, values in self.to_dict().items()}.__repr__() + def to_dict(self) -> Dict[str, List[Dict]]: + """Converts the responses to a dictionary. + Returns: + A dictionary of responses. + """ + response_dict = defaultdict(list) + for response in self.__responses: + response_dict[response.question_name].append({"value": response.value, "user_id": str(response.user_id)}) + return response_dict + def api_models(self) -> List[UserResponseModel]: """Returns a list of ResponseModel objects.""" @@ -321,15 +370,6 @@ def api_models(self) -> List[UserResponseModel]: for responses in responses_by_user_id.values() ] - def to_dict(self) -> Dict[str, List[Dict]]: - """Converts the responses to a dictionary. - Returns: - A dictionary of responses. 
- """ - response_dict = defaultdict(list) - for response in self.__responses: - response_dict[response.question_name].append({"value": response.value, "user_id": response.user_id}) - return response_dict class RecordSuggestions(Iterable[Suggestion]): @@ -345,15 +385,15 @@ def __init__(self, suggestions: List[Suggestion], record: Record) -> None: suggestion.record = self.record setattr(self, suggestion.question_name, suggestion) - def api_models(self) -> List[SuggestionModel]: - return [suggestion.api_model() for suggestion in self.__suggestions] - def __iter__(self): return iter(self.__suggestions) def __getitem__(self, index: int): return self.__suggestions[index] + def __repr__(self) -> str: + return self.to_dict().__repr__() + def to_dict(self) -> Dict[str, List[str]]: """Converts the suggestions to a dictionary. Returns: @@ -368,48 +408,6 @@ def to_dict(self) -> Dict[str, List[str]]: } return suggestion_dict - def __repr__(self) -> str: - return self.to_dict().__repr__() - - -class RecordVectors: - """This is a container class for the vectors of a Record. - It allows for accessing suggestions by attribute and iterating over them. - """ - - def __init__(self, vectors: List[Vector], record: Record) -> None: - self.__vectors = vectors or [] - self.record = record - for vector in self.__vectors: - setattr(self, vector.name, vector.values) - - def __repr__(self) -> str: - return {vector.name: f"{len(vector.values)}" for vector in self.__vectors}.__repr__() - - @property - def models(self) -> List[VectorModel]: - return [vector.api_model() for vector in self.__vectors] - - def to_dict(self) -> Dict[str, List[float]]: - """Converts the vectors to a dictionary. - Returns: - A dictionary of vectors. 
- """ - return {vector.name: list(map(float, vector.values)) for vector in self.__vectors} - - -class RecordMetadata(dict): - """This is a container class for the metadata of a Record.""" - - def __init__(self, metadata: Optional[Dict[str, MetadataValue]] = None) -> None: - super().__init__(metadata or {}) - - def __getattr__(self, item: str): - return self[item] - - def __setattr__(self, key: str, value: MetadataValue): - self[key] = value + def api_models(self) -> List[SuggestionModel]: + return [suggestion.api_model() for suggestion in self.__suggestions] - @property - def models(self) -> List[MetadataModel]: - return [MetadataModel(name=key, value=value) for key, value in self.items()] diff --git a/argilla/tests/integration/test_export_dataset.py b/argilla/tests/integration/test_export_dataset.py index a2f81c4024..a78a81fe64 100644 --- a/argilla/tests/integration/test_export_dataset.py +++ b/argilla/tests/integration/test_export_dataset.py @@ -98,17 +98,17 @@ def test_import_dataset_from_disk(dataset: rg.Dataset, client): { "text": "1: Hello World, how are you?", "label": "positive", - "external_id": uuid.uuid4(), + "id": uuid.uuid4(), }, { "text": "2: Hello World, how are you?", "label": "negative", - "external_id": uuid.uuid4(), + "id": uuid.uuid4(), }, { "text": "3: Hello World, how are you?", "label": "positive", - "external_id": uuid.uuid4(), + "id": uuid.uuid4(), }, ] dataset.records.log(records=mock_data) diff --git a/argilla/tests/integration/test_export_records.py b/argilla/tests/integration/test_export_records.py index bffd569080..61f3b7c39e 100644 --- a/argilla/tests/integration/test_export_records.py +++ b/argilla/tests/integration/test_export_records.py @@ -190,17 +190,17 @@ def test_export_records_to_json(dataset: rg.Dataset): { "text": "Hello World, how are you?", "label": "positive", - "external_id": uuid.uuid4(), + "id": uuid.uuid4(), }, { "text": "Hello World, how are you?", "label": "negative", - "external_id": uuid.uuid4(), + "id": 
uuid.uuid4(), }, { "text": "Hello World, how are you?", "label": "positive", - "external_id": uuid.uuid4(), + "id": uuid.uuid4(), }, ] dataset.records.log(records=mock_data) diff --git a/argilla/tests/integration/test_metadata.py b/argilla/tests/integration/test_metadata.py index bfe07ef688..6b1b0ffd27 100644 --- a/argilla/tests/integration/test_metadata.py +++ b/argilla/tests/integration/test_metadata.py @@ -120,7 +120,7 @@ def test_add_record_with_metadata(dataset_with_metadata: Dataset): assert record.metadata.category == records[idx]["category"] assert record.metadata["category"] == records[idx]["category"] assert len(record.metadata) == 1 - models = record.metadata.models + models = record.metadata.api_models() assert models[0].value == records[idx]["category"] assert models[0].name == "category" @@ -137,6 +137,6 @@ def test_add_record_with_mapped_metadata(dataset_with_metadata: Dataset): assert record.metadata.category == records[idx]["my_category"] assert record.metadata["category"] == records[idx]["my_category"] assert len(record.metadata) == 1 - models = record.metadata.models + models = record.metadata.api_models() assert models[0].value == records[idx]["my_category"] assert models[0].name == "category" diff --git a/argilla/tests/unit/export/test_record_export_import_compatibillity.py b/argilla/tests/unit/export/test_record_export_import_compatibillity.py index 7018dc502f..4fd7ad53ac 100644 --- a/argilla/tests/unit/export/test_record_export_import_compatibillity.py +++ b/argilla/tests/unit/export/test_record_export_import_compatibillity.py @@ -22,22 +22,19 @@ @pytest.fixture -def user_id(): - return str(uuid.uuid4()) +def record(): - -@pytest.fixture -def record(user_id): return rg.Record( + id=uuid.uuid4(), fields={"text": "Hello World, how are you?"}, suggestions=[ rg.Suggestion("label", "positive", score=0.9), rg.Suggestion("topics", ["topic1", "topic2"], score=[0.9, 0.8]), ], - responses=[rg.Response("label", "positive", user_id=user_id)], + 
responses=[rg.Response("label", "positive", user_id=uuid.uuid4())], metadata={"source": "twitter", "language": "en"}, - vectors=[rg.Vector("text", [0, 0, 0])], - id=str(uuid.uuid4()), + vectors={"text": [0, 0, 0]}, + ) @@ -50,7 +47,9 @@ def test_export_record_to_from_dict(record): for key, value in record.metadata.items(): assert imported_record.metadata[key] == value assert record.fields.text == imported_record.fields.text - assert record.id == imported_record.id + # This is a consequence of how UUIDs are treated in python and could be + # problematic for users. + assert str(record.id) == imported_record.id def test_export_generic_io_via_json(record): diff --git a/argilla/tests/unit/test_resources/test_records.py b/argilla/tests/unit/test_resources/test_records.py index 8e662fe781..7782c64128 100644 --- a/argilla/tests/unit/test_resources/test_records.py +++ b/argilla/tests/unit/test_resources/test_records.py @@ -30,11 +30,11 @@ def test_record_repr(self): responses=[Response(question_name="question", value="answer", user_id=user_id)], ) assert ( - record.__repr__() == f"Record(id={record_id}," - "fields={'name': 'John', 'age': '30'}," - "metadata={'key': 'value'}," - "suggestions={'question': {'value': 'answer', 'score': None, 'agent': None}}," - f"responses={{'question': [{{'value': 'answer'}}]}})" + record.__repr__() == f"Record(id={record_id}," + "fields={'name': 'John', 'age': '30'}," + "metadata={'key': 'value'}," + "suggestions={'question': {'value': 'answer', 'score': None, 'agent': None}}," + f"responses={{'question': [{{'value': 'answer'}}]}})" ) def test_update_record_metadata_by_key(self): @@ -44,7 +44,7 @@ def test_update_record_metadata_by_key(self): record.metadata["new-key"] = "new_value" assert record.metadata == {"key": "new_value", "new-key": "new_value"} - assert record.metadata.models == [ + assert record.metadata.api_models() == [ MetadataModel(name="key", value="new_value"), MetadataModel(name="new-key", value="new_value"), ] @@ -56,7 +56,36 
@@ def test_update_record_metadata_by_attribute(self): record.metadata.new_key = "new_value" assert record.metadata == {"key": "new_value", "new_key": "new_value"} - assert record.metadata.models == [ + assert record.metadata.api_models() == [ MetadataModel(name="key", value="new_value"), MetadataModel(name="new_key", value="new_value"), ] + + def test_update_record_fields(self): + record = Record(fields={"name": "John"}) + + record.fields.update({"name": "Jane", "age": "30"}) + record.fields["new_field"] = "value" + + assert record.fields == {"name": "Jane", "age": "30", "new_field": "value"} + + def test_update_record_fields_by_attribute(self): + record = Record(fields={"name": "John"}) + + record.fields.name = "Jane" + record.fields.age = "30" + record.fields.new_field = "value" + + assert record.fields == {"name": "Jane", "age": "30", "new_field": "value"} + + def test_update_record_vectors(self): + record = Record(fields={"name": "John"}, vectors={"vector": [1.0, 2.0, 3.0]}) + + record.vectors["new-vector"] = [1.0, 2.0, 3.0] + assert record.vectors == {"vector": [1.0, 2.0, 3.0], "new-vector": [1.0, 2.0, 3.0]} + + def test_update_record_vectors_by_attribute(self): + record = Record(fields={"name": "John"}) + + record.vectors.vector = [1.0, 2.0, 3.0] + assert record.vectors == {"vector": [1.0, 2.0, 3.0]} From fa52a6ec4abdafc1ef0d245324696cdc3d7744bf Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Tue, 18 Jun 2024 13:48:40 +0200 Subject: [PATCH 2/3] [ENHANCEMENT] argilla: Remove attribute-like access (#5048) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description This PR addresses [discussion](https://github.com/argilla-io/argilla/pull/5026#discussion_r1642364666) from PR https://github.com/argilla-io/argilla/pull/5026 and removes attribute-like access for record fields, vectors and metadata. **Type of change** (Please delete options that are not relevant. 
Remember to title the PR according to the type of change) - [ ] New feature (non-breaking change which adds functionality) - [ ] Refactor (change restructuring the codebase without changing functionality) - [ ] Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** (Please describe the tests that you ran to verify your changes. And ideally, reference `tests`) - [ ] Test A - [ ] Test B **Checklist** - [ ] I added relevant documentation - [ ] I followed the style guidelines of this project - [ ] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [ ] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [ ] I have added relevant notes to the `CHANGELOG.md` file (See https://keepachangelog.com/) --- argilla/src/argilla/records/_resource.py | 20 ------------- argilla/tests/integration/test_add_records.py | 20 ++++++------- .../tests/integration/test_export_dataset.py | 2 +- .../tests/integration/test_export_records.py | 4 +-- argilla/tests/integration/test_metadata.py | 2 -- .../tests/integration/test_query_records.py | 4 +-- .../test_update_dataset_settings.py | 8 +++--- argilla/tests/integration/test_vectors.py | 18 ++++++------ ...est_record_export_import_compatibillity.py | 6 ++-- argilla/tests/unit/test_record_ingestion.py | 28 +++++++++---------- .../tests/unit/test_resources/test_records.py | 26 ----------------- 11 files changed, 45 insertions(+), 93 deletions(-) diff --git a/argilla/src/argilla/records/_resource.py b/argilla/src/argilla/records/_resource.py index fb4c322c52..a25ae66f08 100644 --- a/argilla/src/argilla/records/_resource.py +++ b/argilla/src/argilla/records/_resource.py @@ -270,12 +270,6 @@ class RecordFields(dict): def __init__(self, fields: Optional[Dict[str, FieldValue]] = None) -> None: super().__init__(fields 
or {}) - def __getattr__(self, item: str): - return self[item] - - def __setattr__(self, key: str, value: MetadataValue): - self[key] = value - def to_dict(self) -> dict: return dict(self.items()) @@ -286,12 +280,6 @@ class RecordMetadata(dict): def __init__(self, metadata: Optional[Dict[str, MetadataValue]] = None) -> None: super().__init__(metadata or {}) - def __getattr__(self, item: str): - return self[item] - - def __setattr__(self, key: str, value: MetadataValue): - self[key] = value - def to_dict(self) -> dict: return dict(self.items()) @@ -307,12 +295,6 @@ class RecordVectors(dict): def __init__(self, vectors: Dict[str, VectorValue]) -> None: super().__init__(vectors or {}) - def __getattr__(self, item: str): - return self[item] - - def __setattr__(self, key: str, value: VectorValue): - self[key] = value - def to_dict(self) -> Dict[str, List[float]]: return dict(self.items()) @@ -371,7 +353,6 @@ def api_models(self) -> List[UserResponseModel]: ] - class RecordSuggestions(Iterable[Suggestion]): """This is a container class for the suggestions of a Record. It allows for accessing suggestions by attribute and iterating over them. 
@@ -410,4 +391,3 @@ def to_dict(self) -> Dict[str, List[str]]: def api_models(self) -> List[SuggestionModel]: return [suggestion.api_model() for suggestion in self.__suggestions] - diff --git a/argilla/tests/integration/test_add_records.py b/argilla/tests/integration/test_add_records.py index d6c64b8c85..eebc2d2a9f 100644 --- a/argilla/tests/integration/test_add_records.py +++ b/argilla/tests/integration/test_add_records.py @@ -84,9 +84,9 @@ def test_add_records(client): assert dataset_records[0].id == str(mock_data[0]["id"]) assert dataset_records[1].id == str(mock_data[1]["id"]) assert dataset_records[2].id == str(mock_data[2]["id"]) - assert dataset_records[0].fields.text == mock_data[0]["text"] - assert dataset_records[1].fields.text == mock_data[1]["text"] - assert dataset_records[2].fields.text == mock_data[2]["text"] + assert dataset_records[0].fields["text"] == mock_data[0]["text"] + assert dataset_records[1].fields["text"] == mock_data[1]["text"] + assert dataset_records[2].fields["text"] == mock_data[2]["text"] def test_add_dict_records(client: Argilla): @@ -124,7 +124,7 @@ def test_add_dict_records(client: Argilla): for record, data in zip(ds.records, mock_data): assert record.id == data["id"] - assert record.fields.text == data["text"] + assert record.fields["text"] == data["text"] assert "label" not in record.__dict__ for record, data in zip(ds.records(batch_size=1, with_suggestions=True), mock_data): @@ -193,7 +193,7 @@ def test_add_records_with_suggestions(client) -> None: assert dataset_records[0].suggestions.topics.value == ["topic1", "topic2"] assert dataset_records[0].suggestions.topics.score == [0.9, 0.8] - assert dataset_records[1].fields.text == mock_data[1]["text"] + assert dataset_records[1].fields["text"] == mock_data[1]["text"] assert dataset_records[1].suggestions.comment.value == "I'm doing great, thank you!" 
assert dataset_records[1].suggestions.comment.score is None assert dataset_records[1].suggestions.topics.value == ["topic3"] @@ -258,7 +258,7 @@ def test_add_records_with_responses(client) -> None: for record, mock_record in zip(dataset_records, mock_data): assert record.id == str(mock_record["id"]) - assert record.fields.text == mock_record["text"] + assert record.fields["text"] == mock_record["text"] assert record.responses.label[0].value == mock_record["my_label"] assert record.responses.label[0].user_id == user.id @@ -319,7 +319,7 @@ def test_add_records_with_responses_and_suggestions(client) -> None: dataset_records = list(dataset.records(with_suggestions=True)) assert dataset_records[0].id == str(mock_data[0]["id"]) - assert dataset_records[1].fields.text == mock_data[1]["text"] + assert dataset_records[1].fields["text"] == mock_data[1]["text"] assert dataset_records[2].suggestions.label.value == "positive" assert dataset_records[2].responses.label[0].value == "negative" assert dataset_records[2].responses.label[0].user_id == user.id @@ -386,7 +386,7 @@ def test_add_records_with_fields_mapped(client) -> None: dataset_records = list(dataset.records(with_suggestions=True)) assert dataset_records[0].id == str(mock_data[0]["id"]) - assert dataset_records[1].fields.text == mock_data[1]["x"] + assert dataset_records[1].fields["text"] == mock_data[1]["x"] assert dataset_records[2].suggestions.label.value == "positive" assert dataset_records[2].suggestions.label.score == 0.5 assert dataset_records[2].responses.label[0].value == "negative" @@ -447,7 +447,7 @@ def test_add_records_with_id_mapped(client) -> None: dataset_records = list(dataset.records(with_suggestions=True)) assert dataset_records[0].id == str(mock_data[0]["uuid"]) - assert dataset_records[1].fields.text == mock_data[1]["x"] + assert dataset_records[1].fields["text"] == mock_data[1]["x"] assert dataset_records[2].suggestions.label.value == "positive" assert dataset_records[2].responses.label[0].value == 
"negative" assert dataset_records[2].responses.label[0].user_id == user.id @@ -571,7 +571,7 @@ def test_add_records_with_responses_and_same_schema_name(client: Argilla): dataset_records = list(dataset.records(with_responses=True)) - assert dataset_records[0].fields.text == mock_data[1]["text"] + assert dataset_records[0].fields["text"] == mock_data[1]["text"] assert dataset_records[1].responses.label[0].value == "negative" assert dataset_records[1].responses.label[0].user_id == user.id diff --git a/argilla/tests/integration/test_export_dataset.py b/argilla/tests/integration/test_export_dataset.py index a78a81fe64..0be286e028 100644 --- a/argilla/tests/integration/test_export_dataset.py +++ b/argilla/tests/integration/test_export_dataset.py @@ -118,7 +118,7 @@ def test_import_dataset_from_disk(dataset: rg.Dataset, client): new_dataset = rg.Dataset.from_disk(output_dir, client=client) for i, record in enumerate(new_dataset.records(with_suggestions=True)): - assert record.fields.text == mock_data[i]["text"] + assert record.fields["text"] == mock_data[i]["text"] assert record.suggestions.label.value == mock_data[i]["label"] assert new_dataset.settings.fields[0].name == "text" diff --git a/argilla/tests/integration/test_export_records.py b/argilla/tests/integration/test_export_records.py index 61f3b7c39e..7aa56e76a9 100644 --- a/argilla/tests/integration/test_export_records.py +++ b/argilla/tests/integration/test_export_records.py @@ -241,7 +241,7 @@ def test_export_records_from_json(dataset: rg.Dataset): dataset.records.from_json(path=temp_file) for i, record in enumerate(dataset.records(with_suggestions=True)): - assert record.fields.text == mock_data[i]["text"] + assert record.fields["text"] == mock_data[i]["text"] assert record.suggestions.label.value == mock_data[i]["label"] assert record.id == str(mock_data[i]["id"]) @@ -297,6 +297,6 @@ def test_import_records_from_hf_dataset(dataset: rg.Dataset) -> None: dataset.records.log(records=mock_hf_dataset) for i, record 
in enumerate(dataset.records(with_suggestions=True)): - assert record.fields.text == mock_data[i]["text"] + assert record.fields["text"] == mock_data[i]["text"] assert record.suggestions.label.value == mock_data[i]["label"] assert record.id == str(mock_data[i]["id"]) diff --git a/argilla/tests/integration/test_metadata.py b/argilla/tests/integration/test_metadata.py index 6b1b0ffd27..01f1a36aac 100644 --- a/argilla/tests/integration/test_metadata.py +++ b/argilla/tests/integration/test_metadata.py @@ -117,7 +117,6 @@ def test_add_record_with_metadata(dataset_with_metadata: Dataset): dataset_with_metadata.records.log(records) for idx, record in enumerate(dataset_with_metadata.records): - assert record.metadata.category == records[idx]["category"] assert record.metadata["category"] == records[idx]["category"] assert len(record.metadata) == 1 models = record.metadata.api_models() @@ -134,7 +133,6 @@ def test_add_record_with_mapped_metadata(dataset_with_metadata: Dataset): dataset_with_metadata.records.log(records, mapping={"my_category": "category"}) for idx, record in enumerate(dataset_with_metadata.records): - assert record.metadata.category == records[idx]["my_category"] assert record.metadata["category"] == records[idx]["my_category"] assert len(record.metadata) == 1 models = record.metadata.api_models() diff --git a/argilla/tests/integration/test_query_records.py b/argilla/tests/integration/test_query_records.py index 6e8ce13448..5c4de8c286 100644 --- a/argilla/tests/integration/test_query_records.py +++ b/argilla/tests/integration/test_query_records.py @@ -64,12 +64,12 @@ def test_query_records_by_text(client: Argilla, dataset: Dataset): assert len(records) == 1 assert records[0].id == "1" - assert records[0].fields.text == "First record" + assert records[0].fields["text"] == "First record" records = list(dataset.records(query="second")) assert len(records) == 1 assert records[0].id == "2" - assert records[0].fields.text == "Second record" + assert 
records[0].fields["text"] == "Second record" records = list(dataset.records(query="record")) assert len(records) == 2 diff --git a/argilla/tests/integration/test_update_dataset_settings.py b/argilla/tests/integration/test_update_dataset_settings.py index 173c4d5f1e..0a606481e5 100644 --- a/argilla/tests/integration/test_update_dataset_settings.py +++ b/argilla/tests/integration/test_update_dataset_settings.py @@ -41,12 +41,12 @@ def test_update_settings(self, client: Argilla, dataset: Dataset): dataset = client.datasets(dataset.name) settings = dataset.settings - assert settings.fields.text.use_markdown is True - assert settings.vectors.vector.dimensions == 10 + assert settings.fields["text"].use_markdown is True + assert settings.vectors["vector"].dimensions == 10 assert isinstance(settings.metadata.metadata, FloatMetadataProperty) - settings.vectors.vector.title = "A new title for vector" + settings.vectors["vector"].title = "A new title for vector" settings.update() dataset = client.datasets(dataset.name) - assert dataset.settings.vectors.vector.title == "A new title for vector" + assert dataset.settings.vectors["vector"].title == "A new title for vector" diff --git a/argilla/tests/integration/test_vectors.py b/argilla/tests/integration/test_vectors.py index e9edb49a1d..e0ca6f2acd 100644 --- a/argilla/tests/integration/test_vectors.py +++ b/argilla/tests/integration/test_vectors.py @@ -74,9 +74,9 @@ def test_vectors(client: rg.Argilla, dataset: rg.Dataset): assert dataset_records[0].id == str(mock_data[0]["id"]) assert dataset_records[1].id == str(mock_data[1]["id"]) assert dataset_records[2].id == str(mock_data[2]["id"]) - assert dataset_records[0].vectors.vector == mock_data[0]["vector"] - assert dataset_records[1].vectors.vector == mock_data[1]["vector"] - assert dataset_records[2].vectors.vector == mock_data[2]["vector"] + assert dataset_records[0].vectors["vector"] == mock_data[0]["vector"] + assert dataset_records[1].vectors["vector"] == 
mock_data[1]["vector"] + assert dataset_records[2].vectors["vector"] == mock_data[2]["vector"] def test_vectors_return_with_bool(client: rg.Argilla, dataset: rg.Dataset): @@ -106,9 +106,9 @@ def test_vectors_return_with_bool(client: rg.Argilla, dataset: rg.Dataset): assert dataset_records[0].id == str(mock_data[0]["id"]) assert dataset_records[1].id == str(mock_data[1]["id"]) assert dataset_records[2].id == str(mock_data[2]["id"]) - assert dataset_records[0].vectors.vector == mock_data[0]["vector"] - assert dataset_records[1].vectors.vector == mock_data[1]["vector"] - assert dataset_records[2].vectors.vector == mock_data[2]["vector"] + assert dataset_records[0].vectors["vector"] == mock_data[0]["vector"] + assert dataset_records[1].vectors["vector"] == mock_data[1]["vector"] + assert dataset_records[2].vectors["vector"] == mock_data[2]["vector"] def test_vectors_return_with_name(client: rg.Argilla, dataset: rg.Dataset): @@ -138,6 +138,6 @@ def test_vectors_return_with_name(client: rg.Argilla, dataset: rg.Dataset): assert dataset_records[0].id == str(mock_data[0]["id"]) assert dataset_records[1].id == str(mock_data[1]["id"]) assert dataset_records[2].id == str(mock_data[2]["id"]) - assert dataset_records[0].vectors.vector == mock_data[0]["vector"] - assert dataset_records[1].vectors.vector == mock_data[1]["vector"] - assert dataset_records[2].vectors.vector == mock_data[2]["vector"] + assert dataset_records[0].vectors["vector"] == mock_data[0]["vector"] + assert dataset_records[1].vectors["vector"] == mock_data[1]["vector"] + assert dataset_records[2].vectors["vector"] == mock_data[2]["vector"] diff --git a/argilla/tests/unit/export/test_record_export_import_compatibillity.py b/argilla/tests/unit/export/test_record_export_import_compatibillity.py index 4fd7ad53ac..70f61e50a7 100644 --- a/argilla/tests/unit/export/test_record_export_import_compatibillity.py +++ b/argilla/tests/unit/export/test_record_export_import_compatibillity.py @@ -46,7 +46,7 @@ def 
test_export_record_to_from_dict(record): assert record.suggestions[0].value == imported_record.suggestions[0].value for key, value in record.metadata.items(): assert imported_record.metadata[key] == value - assert record.fields.text == imported_record.fields.text + assert record.fields["text"] == imported_record.fields["text"] # This is a consequence of how UUIDs are treated in python and could be # problematic for users. assert str(record.id) == imported_record.id @@ -62,5 +62,5 @@ def test_export_generic_io_via_json(record): assert record.suggestions[0].value == imported_record.suggestions[0].value for key, value in record.metadata.items(): assert imported_record.metadata[key] == value - assert record.fields.text == imported_record.fields.text - assert record.vectors.text == imported_record.vectors.text + assert record.fields["text"] == imported_record.fields["text"] + assert record.vectors["text"] == imported_record.vectors["text"] diff --git a/argilla/tests/unit/test_record_ingestion.py b/argilla/tests/unit/test_record_ingestion.py index fb7167f66c..a7c98c51e8 100644 --- a/argilla/tests/unit/test_record_ingestion.py +++ b/argilla/tests/unit/test_record_ingestion.py @@ -43,7 +43,7 @@ def test_ingest_record_from_dict(dataset): }, ) - assert record.fields.prompt == "What is the capital of France?" + assert record.fields["prompt"] == "What is the capital of France?" assert record.suggestions.label.value == "positive" @@ -58,7 +58,7 @@ def test_ingest_record_from_dict_with_mapping(dataset): }, ) - assert record.fields.prompt == "What is the capital of France?" + assert record.fields["prompt"] == "What is the capital of France?" assert record.suggestions.label.value == "positive" @@ -70,7 +70,7 @@ def test_ingest_record_from_dict_with_suggestions(dataset): }, ) - assert record.fields.prompt == "Hello World, how are you?" + assert record.fields["prompt"] == "Hello World, how are you?" 
assert record.suggestions.label.value == "negative" @@ -88,7 +88,7 @@ def test_ingest_record_from_dict_with_suggestions_scores(dataset): }, ) - assert record.fields.prompt == "Hello World, how are you?" + assert record.fields["prompt"] == "Hello World, how are you?" assert record.suggestions.label.value == "negative" assert record.suggestions.label.score == 0.9 assert record.suggestions.label.agent == "model_name" @@ -108,7 +108,7 @@ def test_ingest_record_from_dict_with_suggestions_scores_and_agent(dataset): }, ) - assert record.fields.prompt == "Hello World, how are you?" + assert record.fields["prompt"] == "Hello World, how are you?" assert record.suggestions.label.value == "negative" assert record.suggestions.label.score == 0.9 assert record.suggestions.label.agent == "model_name" @@ -127,7 +127,7 @@ def test_ingest_record_from_dict_with_responses(dataset): user_id=user_id, ) - assert record.fields.prompt == "Hello World, how are you?" + assert record.fields["prompt"] == "Hello World, how are you?" assert record.responses.label[0].value == "negative" assert record.responses.label[0].user_id == user_id @@ -142,7 +142,7 @@ def test_ingest_record_from_dict_with_id_as_id(dataset): }, ) - assert record.fields.prompt == "Hello World, how are you?" + assert record.fields["prompt"] == "Hello World, how are you?" assert record.id == record_id @@ -159,7 +159,7 @@ def test_ingest_record_from_dict_with_id_and_mapping(dataset): }, ) - assert record.fields.prompt == "Hello World, how are you?" + assert record.fields["prompt"] == "Hello World, how are you?" assert record.id == record_id @@ -172,7 +172,7 @@ def test_ingest_record_from_dict_with_metadata(dataset): }, ) - assert record.fields.prompt == "Hello World, how are you?" + assert record.fields["prompt"] == "Hello World, how are you?" 
assert record.suggestions.label.value == "negative" assert record.metadata["score"] == 0.9 @@ -189,7 +189,7 @@ def test_ingest_record_from_dict_with_metadata_and_mapping(dataset): }, ) - assert record.fields.prompt == "Hello World, how are you?" + assert record.fields["prompt"] == "Hello World, how are you?" assert record.suggestions.label.value == "negative" assert record.metadata["score"] == 0.9 @@ -203,9 +203,9 @@ def test_ingest_record_from_dict_with_vectors(dataset): }, ) - assert record.fields.prompt == "Hello World, how are you?" + assert record.fields["prompt"] == "Hello World, how are you?" assert record.suggestions.label.value == "negative" - assert record.vectors.vector == [1, 2, 3] + assert record.vectors["vector"] == [1, 2, 3] def test_ingest_record_from_dict_with_vectors_and_mapping(dataset): @@ -220,6 +220,6 @@ def test_ingest_record_from_dict_with_vectors_and_mapping(dataset): }, ) - assert record.fields.prompt == "Hello World, how are you?" + assert record.fields["prompt"] == "Hello World, how are you?" 
assert record.suggestions.label.value == "negative" - assert record.vectors.vector == [1, 2, 3] + assert record.vectors["vector"] == [1, 2, 3] diff --git a/argilla/tests/unit/test_resources/test_records.py b/argilla/tests/unit/test_resources/test_records.py index 7782c64128..6a0ae1e056 100644 --- a/argilla/tests/unit/test_resources/test_records.py +++ b/argilla/tests/unit/test_resources/test_records.py @@ -49,18 +49,6 @@ def test_update_record_metadata_by_key(self): MetadataModel(name="new-key", value="new_value"), ] - def test_update_record_metadata_by_attribute(self): - record = Record(fields={"name": "John", "age": "30"}, metadata={"key": "value"}) - - record.metadata.key = "new_value" - record.metadata.new_key = "new_value" - - assert record.metadata == {"key": "new_value", "new_key": "new_value"} - assert record.metadata.api_models() == [ - MetadataModel(name="key", value="new_value"), - MetadataModel(name="new_key", value="new_value"), - ] - def test_update_record_fields(self): record = Record(fields={"name": "John"}) @@ -69,23 +57,9 @@ def test_update_record_fields(self): assert record.fields == {"name": "Jane", "age": "30", "new_field": "value"} - def test_update_record_fields_by_attribute(self): - record = Record(fields={"name": "John"}) - - record.fields.name = "Jane" - record.fields.age = "30" - record.fields.new_field = "value" - - assert record.fields == {"name": "Jane", "age": "30", "new_field": "value"} - def test_update_record_vectors(self): record = Record(fields={"name": "John"}, vectors={"vector": [1.0, 2.0, 3.0]}) record.vectors["new-vector"] = [1.0, 2.0, 3.0] assert record.vectors == {"vector": [1.0, 2.0, 3.0], "new-vector": [1.0, 2.0, 3.0]} - def test_update_record_vectors_by_attribute(self): - record = Record(fields={"name": "John"}) - - record.vectors.vector = [1.0, 2.0, 3.0] - assert record.vectors == {"vector": [1.0, 2.0, 3.0]} From 1ab171cf2ee001c8acfdba41fe507fc4ae941a5e Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Wed, 19 Jun 2024 
11:03:33 +0200 Subject: [PATCH 3/3] [ENHANCEMENT] docs: Add howto update record vectors (#5052) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description This PR adds how to update record vectors in how-to guides. Since the fields cannot be updated, the how-to guides skip them for now. **How Has This Been Tested** (Please describe the tests that you ran to verify your changes.) - [ ] `sphinx-autobuild` (read [Developer Documentation](https://docs.argilla.io/en/latest/community/developer_docs.html#building-the-documentation) for more details) **Checklist** - [ ] I added relevant documentation - [ ] I followed the style guidelines of this project - [ ] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [ ] My changes generate no new warnings - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [ ] I have added relevant notes to the `CHANGELOG.md` file (See https://keepachangelog.com/) --- argilla/docs/how_to_guides/record.md | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/argilla/docs/how_to_guides/record.md b/argilla/docs/how_to_guides/record.md index b5077a7048..bd7bcf794c 100644 --- a/argilla/docs/how_to_guides/record.md +++ b/argilla/docs/how_to_guides/record.md @@ -436,7 +436,7 @@ updated_data = [ dataset.records.log(records=updated_data) ``` -!!! note "Update the metadata" +=== "Update the metadata" The `metadata` of `Record` object is a python dictionary. So to update the metadata of a record, you can iterate over the records and update the metadata by key or using `metadata.update`. After that, you should update the records in the dataset. 
```python @@ -452,6 +452,22 @@ dataset.records.log(records=updated_data) dataset.records.log(records=updated_records) ``` +=== "Update vectors" + When a new vector field is added to the dataset settings, or some value for the existing record vectors must be updated, you can iterate over the records and update the vectors in the same way as the metadata. + + ```python + updated_records = [] + + for record in dataset.records(): + + record.vectors["new_vector"] = [...] + record.vectors["v"] = [...] + + updated_records.append(record) + + dataset.records.log(records=updated_records) + ``` + ## Delete records You can delete records in a dataset calling the `delete` method on the `Dataset` object. To delete records, you need to retrieve them from the server and get a list with those that you want to delete.