Skip to content

Commit

Permalink
[ENHANCEMENT] Argilla SDK: Updating record fields and vectors (#5026)
Browse files Browse the repository at this point in the history
<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

This PR reviews the record attributes and normalizes how to work with
fields, vectors, and metadata. Now, all are treated as dictionaries and
users can update in the same way that working with dictionaries or
creating new attributes:

```python
record = Record(fields={"name": "John"})

record.fields.update({"name": "Jane", "age": "30"})
record.fields.new_field = "value"

record.vectors["new-vector"] = [1.0, 2.0, 3.0]
record.vectors.vector = [1.0, 2.0, 3.0]

record.metadata["new-key"] = "new_value"
record.metadata.key = "new_value"

```

Once this approach is approved, I will create a new PR changing the
docs.


**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [ ] New feature (non-breaking change which adds functionality)
- [x] Refactor (change restructuring the codebase without changing
functionality)
- [x] Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- [ ] Test A
- [ ] Test B

**Checklist**

- [ ] I added relevant documentation
- [ ] I followed the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)
  • Loading branch information
frascuchon authored Jun 18, 2024
1 parent 90f1ef6 commit 6452993
Show file tree
Hide file tree
Showing 10 changed files with 154 additions and 121 deletions.
4 changes: 2 additions & 2 deletions argilla/src/argilla/_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
from argilla._models._workspace import WorkspaceModel
from argilla._models._user import UserModel, Role
from argilla._models._dataset import DatasetModel
from argilla._models._record._record import RecordModel
from argilla._models._record._record import RecordModel, FieldValue
from argilla._models._record._suggestion import SuggestionModel
from argilla._models._record._response import UserResponseModel, ResponseStatus
from argilla._models._record._vector import VectorModel
from argilla._models._record._vector import VectorModel, VectorValue
from argilla._models._record._metadata import MetadataModel, MetadataValue
from argilla._models._search import (
SearchQueryModel,
Expand Down
9 changes: 6 additions & 3 deletions argilla/src/argilla/_models/_record/_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,20 @@

from pydantic import Field, field_serializer, field_validator

from argilla._models._resource import ResourceModel
from argilla._models._record._metadata import MetadataModel, MetadataValue
from argilla._models._record._response import UserResponseModel
from argilla._models._record._suggestion import SuggestionModel
from argilla._models._record._vector import VectorModel
from argilla._models._resource import ResourceModel

__all__ = ["RecordModel", "FieldValue"]

FieldValue = Union[str, None]

class RecordModel(ResourceModel):
"""Schema for the records of a `Dataset`"""

fields: Optional[Dict[str, Union[str, None]]] = None
fields: Optional[Dict[str, FieldValue]] = None
metadata: Optional[Union[List[MetadataModel], Dict[str, MetadataValue]]] = Field(default_factory=dict)
vectors: Optional[List[VectorModel]] = Field(default_factory=list)
responses: Optional[List[UserResponseModel]] = Field(default_factory=list)
Expand All @@ -49,7 +52,7 @@ def serialize_metadata(self, value: List[MetadataModel]) -> Dict[str, Any]:
return {metadata.name: metadata.value for metadata in value}

@field_serializer("fields", when_used="always")
def serialize_empty_fields(self, value: Dict[str, Union[str, None]]) -> Dict[str, Union[str, None]]:
def serialize_empty_fields(self, value: Dict[str, Union[str, None]]) -> Optional[Dict[str, Union[str, None]]]:
"""Serialize empty fields to None."""
if isinstance(value, dict) and len(value) == 0:
return None
Expand Down
10 changes: 6 additions & 4 deletions argilla/src/argilla/_models/_record/_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import List

from pydantic import field_validator

from argilla._models import ResourceModel

import re
from pydantic import field_validator
__all__ = ["VectorModel", "VectorValue"]

__all__ = ["VectorModel"]
VectorValue = List[float]


class VectorModel(ResourceModel):
name: str
vector_values: List[float]
vector_values: VectorValue

@field_validator("name")
@classmethod
Expand Down
18 changes: 10 additions & 8 deletions argilla/src/argilla/records/_dataset_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from argilla._api import RecordsAPI
from argilla._helpers import LoggingMixin
from argilla._models import RecordModel, MetadataValue
from argilla._models import RecordModel, MetadataValue, VectorValue, FieldValue
from argilla.client import Argilla
from argilla.records._io import GenericIO, HFDataset, HFDatasetsIO, JsonIO
from argilla.records._resource import Record
Expand Down Expand Up @@ -405,13 +405,15 @@ def _infer_record_from_mapping(
Returns:
A Record object.
"""
fields: Dict[str, str] = {}
responses: List[Response] = []
record_id: Optional[str] = None
suggestion_values = defaultdict(dict)
vectors: List[Vector] = []

fields: Dict[str, FieldValue] = {}
vectors: Dict[str, VectorValue] = {}
metadata: Dict[str, MetadataValue] = {}

responses: List[Response] = []
suggestion_values: Dict[str, dict] = defaultdict(dict)

schema = self.__dataset.schema

for attribute, value in data.items():
Expand Down Expand Up @@ -466,7 +468,7 @@ def _infer_record_from_mapping(
{"value": value, "question_name": attribute, "question_id": schema_item.id}
)
elif isinstance(schema_item, VectorField):
vectors.append(Vector(name=attribute, values=value))
vectors[attribute] = value
elif isinstance(schema_item, MetadataPropertyBase):
metadata[attribute] = value
else:
Expand All @@ -478,9 +480,9 @@ def _infer_record_from_mapping(
return Record(
id=record_id,
fields=fields,
suggestions=suggestions,
responses=responses,
vectors=vectors,
metadata=metadata,
suggestions=suggestions,
responses=responses,
_dataset=self.__dataset,
)
158 changes: 78 additions & 80 deletions argilla/src/argilla/records/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
SuggestionModel,
VectorModel,
MetadataValue,
FieldValue,
VectorValue,
)
from argilla._resource import Resource
from argilla.responses import Response, UserResponse
Expand Down Expand Up @@ -54,9 +56,9 @@ class Record(Resource):
def __init__(
self,
id: Optional[Union[UUID, str]] = None,
fields: Optional[Dict[str, Union[str, None]]] = None,
fields: Optional[Dict[str, FieldValue]] = None,
metadata: Optional[Dict[str, MetadataValue]] = None,
vectors: Optional[List[Vector]] = None,
vectors: Optional[Dict[str, VectorValue]] = None,
responses: Optional[List[Response]] = None,
suggestions: Optional[List[Suggestion]] = None,
_server_id: Optional[UUID] = None,
Expand Down Expand Up @@ -93,7 +95,7 @@ def __init__(
# Initialize the fields
self.__fields = RecordFields(fields=self._model.fields)
# Initialize the vectors
self.__vectors = RecordVectors(vectors=vectors, record=self)
self.__vectors = RecordVectors(vectors=vectors)
# Initialize the metadata
self.__metadata = RecordMetadata(metadata=metadata)
self.__responses = RecordResponses(responses=responses, record=self)
Expand Down Expand Up @@ -158,8 +160,8 @@ def api_model(self) -> RecordModel:
id=self._model.id,
external_id=self._model.external_id,
fields=self.fields.to_dict(),
metadata=self.metadata.models,
vectors=self.vectors.models,
metadata=self.metadata.api_models(),
vectors=self.vectors.api_models(),
responses=self.responses.api_models(),
suggestions=self.suggestions.api_models(),
)
Expand All @@ -181,19 +183,22 @@ def to_dict(self) -> Dict[str, Dict]:
represented as a key-value pair in the dictionary of the respective key. i.e.
`{"fields": {"prompt": "...", "response": "..."}, "responses": {"rating": "..."},
"""
id = str(self.id) if self.id else None
server_id = str(self._model.id) if self._model.id else None
fields = self.fields.to_dict()
metadata = dict(self.metadata)
metadata = self.metadata.to_dict()
suggestions = self.suggestions.to_dict()
responses = self.responses.to_dict()
vectors = self.vectors.to_dict()

return {
"id": self.id,
"id": id,
"fields": fields,
"metadata": metadata,
"suggestions": suggestions,
"responses": responses,
"vectors": vectors,
"_server_id": str(self._model.id) if self._model.id else None,
"_server_id": server_id,
}

@classmethod
Expand All @@ -219,7 +224,6 @@ def from_dict(cls, data: Dict[str, Dict], dataset: Optional["Dataset"] = None) -
for question_name, _responses in responses.items()
for value in _responses
]
vectors = [Vector(name=vector_name, values=values) for vector_name, values in vectors.items()]

return cls(
id=record_id,
Expand All @@ -245,7 +249,7 @@ def from_model(cls, model: RecordModel, dataset: "Dataset") -> "Record":
id=model.external_id,
fields=model.fields,
metadata={meta.name: meta.value for meta in model.metadata},
vectors=[Vector.from_model(model=vector) for vector in model.vectors],
vectors={vector.name: vector.vector_values for vector in model.vectors},
# Responses and their models are not aligned 1-1.
responses=[
response
Expand All @@ -258,27 +262,62 @@ def from_model(cls, model: RecordModel, dataset: "Dataset") -> "Record":
)


class RecordFields:
class RecordFields(dict):
"""This is a container class for the fields of a Record.
It allows for accessing fields by attribute and iterating over them.
It allows for accessing fields by attribute and key name.
"""

def __init__(self, fields: Dict[str, Union[str, None]]) -> None:
self.__fields = fields or {}
for key, value in self.__fields.items():
setattr(self, key, value)
def __init__(self, fields: Optional[Dict[str, FieldValue]] = None) -> None:
super().__init__(fields or {})

def __getitem__(self, key: str) -> Optional[str]:
return self.__fields.get(key)
def __getattr__(self, item: str):
return self[item]

def __iter__(self):
return iter(self.__fields)
def __setattr__(self, key: str, value: MetadataValue):
self[key] = value

def to_dict(self) -> Dict[str, Union[str, None]]:
return self.__fields
def to_dict(self) -> dict:
return dict(self.items())

def __repr__(self) -> str:
return self.to_dict().__repr__()

class RecordMetadata(dict):
"""This is a container class for the metadata of a Record."""

def __init__(self, metadata: Optional[Dict[str, MetadataValue]] = None) -> None:
super().__init__(metadata or {})

def __getattr__(self, item: str):
return self[item]

def __setattr__(self, key: str, value: MetadataValue):
self[key] = value

def to_dict(self) -> dict:
return dict(self.items())

def api_models(self) -> List[MetadataModel]:
return [MetadataModel(name=key, value=value) for key, value in self.items()]


class RecordVectors(dict):
"""This is a container class for the vectors of a Record.
It allows for accessing suggestions by attribute and key name.
"""

def __init__(self, vectors: Dict[str, VectorValue]) -> None:
super().__init__(vectors or {})

def __getattr__(self, item: str):
return self[item]

def __setattr__(self, key: str, value: VectorValue):
self[key] = value

def to_dict(self) -> Dict[str, List[float]]:
return dict(self.items())

def api_models(self) -> List[VectorModel]:
return [Vector(name=name, values=value).api_model() for name, value in self.items()]


class RecordResponses(Iterable[Response]):
Expand Down Expand Up @@ -309,6 +348,16 @@ def __getattr__(self, name) -> List[Response]:
def __repr__(self) -> str:
return {k: [{"value": v["value"]} for v in values] for k, values in self.to_dict().items()}.__repr__()

def to_dict(self) -> Dict[str, List[Dict]]:
"""Converts the responses to a dictionary.
Returns:
A dictionary of responses.
"""
response_dict = defaultdict(list)
for response in self.__responses:
response_dict[response.question_name].append({"value": response.value, "user_id": str(response.user_id)})
return response_dict

def api_models(self) -> List[UserResponseModel]:
"""Returns a list of ResponseModel objects."""

Expand All @@ -321,15 +370,6 @@ def api_models(self) -> List[UserResponseModel]:
for responses in responses_by_user_id.values()
]

def to_dict(self) -> Dict[str, List[Dict]]:
"""Converts the responses to a dictionary.
Returns:
A dictionary of responses.
"""
response_dict = defaultdict(list)
for response in self.__responses:
response_dict[response.question_name].append({"value": response.value, "user_id": response.user_id})
return response_dict


class RecordSuggestions(Iterable[Suggestion]):
Expand All @@ -345,15 +385,15 @@ def __init__(self, suggestions: List[Suggestion], record: Record) -> None:
suggestion.record = self.record
setattr(self, suggestion.question_name, suggestion)

def api_models(self) -> List[SuggestionModel]:
return [suggestion.api_model() for suggestion in self.__suggestions]

def __iter__(self):
return iter(self.__suggestions)

def __getitem__(self, index: int):
return self.__suggestions[index]

def __repr__(self) -> str:
return self.to_dict().__repr__()

def to_dict(self) -> Dict[str, List[str]]:
"""Converts the suggestions to a dictionary.
Returns:
Expand All @@ -368,48 +408,6 @@ def to_dict(self) -> Dict[str, List[str]]:
}
return suggestion_dict

def __repr__(self) -> str:
return self.to_dict().__repr__()


class RecordVectors:
"""This is a container class for the vectors of a Record.
It allows for accessing suggestions by attribute and iterating over them.
"""

def __init__(self, vectors: List[Vector], record: Record) -> None:
self.__vectors = vectors or []
self.record = record
for vector in self.__vectors:
setattr(self, vector.name, vector.values)

def __repr__(self) -> str:
return {vector.name: f"{len(vector.values)}" for vector in self.__vectors}.__repr__()

@property
def models(self) -> List[VectorModel]:
return [vector.api_model() for vector in self.__vectors]

def to_dict(self) -> Dict[str, List[float]]:
"""Converts the vectors to a dictionary.
Returns:
A dictionary of vectors.
"""
return {vector.name: list(map(float, vector.values)) for vector in self.__vectors}


class RecordMetadata(dict):
"""This is a container class for the metadata of a Record."""

def __init__(self, metadata: Optional[Dict[str, MetadataValue]] = None) -> None:
super().__init__(metadata or {})

def __getattr__(self, item: str):
return self[item]

def __setattr__(self, key: str, value: MetadataValue):
self[key] = value
def api_models(self) -> List[SuggestionModel]:
return [suggestion.api_model() for suggestion in self.__suggestions]

@property
def models(self) -> List[MetadataModel]:
return [MetadataModel(name=key, value=value) for key, value in self.items()]
Loading

0 comments on commit 6452993

Please sign in to comment.