Skip to content

Commit

Permalink
[BUGFIX] argilla: Support export action with filtered records (#5054)
Browse files Browse the repository at this point in the history
<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

This PR adds support for export operations on filtered record results, so
users can export just the subset of records matched by a query:

```python
ds.records(query=...).to_list()
ds.records(query=...).to_json(...)
```

Closes #5053 

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [X] Bug fix (non-breaking change which fixes an issue)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- [ ] Test A
- [ ] Test B

**Checklist**

- [ ] I followed the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)
  • Loading branch information
frascuchon authored Jun 19, 2024
1 parent 29952ad commit 955d967
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 10 deletions.
29 changes: 19 additions & 10 deletions argilla/src/argilla/records/_dataset_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,19 @@ def _fetch_from_server_with_search(self) -> List[RecordModel]:
def _is_search_query(self) -> bool:
return bool(self.__query and (self.__query.query or self.__query.filter))

def to_list(self, flatten: bool = False) -> List[Dict[str, Any]]:
    """Export the records produced by this iterator as a list of dictionaries.

    The iterator is fully consumed (``list(self)``) and the resulting records
    are handed to ``GenericIO.to_list``.

    Args:
        flatten: Forwarded unchanged to ``GenericIO.to_list``. Defaults to
            ``False`` so that ``records(query=...).to_list()`` can be called
            without arguments, matching ``DatasetRecords.to_list``.

    Returns:
        The value returned by ``GenericIO.to_list`` for the iterated records.
    """
    return GenericIO.to_list(records=list(self), flatten=flatten)

def to_dict(self, flatten: bool = False, orient: str = "names") -> Dict[str, Any]:
    """Export the records produced by this iterator as a dictionary.

    The iterator is fully consumed (``list(self)``) and the resulting records
    are handed to ``GenericIO.to_dict``.

    Args:
        flatten: Forwarded unchanged to ``GenericIO.to_dict``. Defaults to
            ``False``, matching ``DatasetRecords.to_dict``.
        orient: Forwarded unchanged to ``GenericIO.to_dict``. Defaults to
            ``"names"``, matching ``DatasetRecords.to_dict``.

    Returns:
        The value returned by ``GenericIO.to_dict`` for the iterated records.
    """
    return GenericIO.to_dict(records=list(self), flatten=flatten, orient=orient)

def to_json(self, path: Union[Path, str]) -> Path:
    """Write the records produced by this iterator through ``JsonIO.to_json``.

    Args:
        path: Destination passed unchanged to ``JsonIO.to_json``.

    Returns:
        The path returned by ``JsonIO.to_json``.
    """
    # Consume the iterator up front so the exporter receives a concrete list.
    records = list(self)
    return JsonIO.to_json(records=records, path=path)

def to_datasets(self) -> "HFDataset":
    """Convert the records produced by this iterator via ``HFDatasetsIO.to_datasets``.

    Returns:
        The object returned by ``HFDatasetsIO.to_datasets`` for the iterated
        records.
    """
    # Consume the iterator up front so the exporter receives a concrete list.
    records = list(self)
    return HFDatasetsIO.to_datasets(records=records)


class DatasetRecords(Iterable[Record], LoggingMixin):
"""This class is used to work with records from a dataset and is accessed via `Dataset.records`.
Expand All @@ -142,7 +155,7 @@ def __init__(self, client: "Argilla", dataset: "Dataset"):
self._api = self.__client.api.records

def __iter__(self):
return DatasetRecordsIterator(self.__dataset, self.__client)
return DatasetRecordsIterator(self.__dataset, self.__client, with_suggestions=True, with_responses=True)

def __call__(
self,
Expand Down Expand Up @@ -286,9 +299,7 @@ def to_dict(self, flatten: bool = False, orient: str = "names") -> Dict[str, Any
A dictionary of records.
"""
records = list(self(with_suggestions=True, with_responses=True))
data = GenericIO.to_dict(records=records, flatten=flatten, orient=orient)
return data
return self().to_dict(flatten=flatten, orient=orient)

def to_list(self, flatten: bool = False) -> List[Dict[str, Any]]:
"""
Expand All @@ -300,8 +311,7 @@ def to_list(self, flatten: bool = False) -> List[Dict[str, Any]]:
Returns:
A list of dictionaries of records.
"""
records = list(self(with_suggestions=True, with_responses=True))
data = GenericIO.to_list(records=records, flatten=flatten)
data = self().to_list(flatten=flatten)
return data

def to_json(self, path: Union[Path, str]) -> Path:
Expand All @@ -315,8 +325,7 @@ def to_json(self, path: Union[Path, str]) -> Path:
The path to the file where the records were saved.
"""
records = list(self(with_suggestions=True, with_responses=True))
return JsonIO.to_json(records=records, path=path)
return self().to_json(path=path)

def from_json(self, path: Union[Path, str]) -> List[Record]:
"""Creates a DatasetRecords object from a disk path to a JSON file.
Expand All @@ -340,8 +349,8 @@ def to_datasets(self) -> HFDataset:
The dataset containing the records.
"""
records = list(self(with_suggestions=True, with_responses=True))
return HFDatasetsIO.to_datasets(records=records)

return self().to_datasets()

############################
# Private methods
Expand Down
31 changes: 31 additions & 0 deletions argilla/tests/integration/test_export_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,37 @@ def test_export_records_list_flattened(client: Argilla, dataset: rg.Dataset):
assert exported_records[0]["label.suggestion.score"] is None


def test_export_record_list_with_filtered_records(client: Argilla, dataset: rg.Dataset):
    """Export the records returned by a text query as a flattened list.

    Logs three records that share the same text, exports the result of the
    ``"hello"`` query with ``to_list(flatten=True)`` and checks the shape of
    the output and the contents of the first exported record.
    """
    text = "Hello World, how are you?"
    mock_data = [
        {"text": text, "label": label, "id": uuid.uuid4()}
        for label in ("positive", "negative", "positive")
    ]
    dataset.records.log(records=mock_data)

    search = rg.Query(query="hello")
    exported_records = dataset.records(query=search).to_list(flatten=True)

    # The assertion expects every logged record to come back from the query.
    assert len(exported_records) == len(mock_data)
    assert isinstance(exported_records, list)

    first = exported_records[0]
    assert isinstance(first, dict)
    for key in ("id", "text", "label.suggestion"):
        assert isinstance(first[key], str)
    assert first["text"] == text
    assert first["label.suggestion"] == "positive"
    assert first["label.suggestion.score"] is None


def test_export_records_list_nested(client: Argilla, dataset: rg.Dataset):
mock_data = [
{
Expand Down

0 comments on commit 955d967

Please sign in to comment.