Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUGFIX] argilla: Support export action with filtered records #5054

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions argilla/src/argilla/records/_dataset_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,19 @@ def _fetch_from_server_with_search(self) -> List[RecordModel]:
def _is_search_query(self) -> bool:
return bool(self.__query and (self.__query.query or self.__query.filter))

def to_list(self, flatten: bool) -> List[Dict[str, Any]]:
return GenericIO.to_list(records=list(self), flatten=flatten)

def to_dict(self, flatten: bool, orient: str) -> Dict[str, Any]:
data = GenericIO.to_dict(records=list(self), flatten=flatten, orient=orient)
return data

def to_json(self, path: Union[Path, str]) -> Path:
return JsonIO.to_json(records=list(self), path=path)

def to_datasets(self) -> "HFDataset":
return HFDatasetsIO.to_datasets(records=list(self))


class DatasetRecords(Iterable[Record], LoggingMixin):
"""This class is used to work with records from a dataset and is accessed via `Dataset.records`.
Expand All @@ -142,7 +155,7 @@ def __init__(self, client: "Argilla", dataset: "Dataset"):
self._api = self.__client.api.records

def __iter__(self):
return DatasetRecordsIterator(self.__dataset, self.__client)
return DatasetRecordsIterator(self.__dataset, self.__client, with_suggestions=True, with_responses=True)

def __call__(
self,
Expand Down Expand Up @@ -286,9 +299,7 @@ def to_dict(self, flatten: bool = False, orient: str = "names") -> Dict[str, Any
A dictionary of records.

"""
records = list(self(with_suggestions=True, with_responses=True))
data = GenericIO.to_dict(records=records, flatten=flatten, orient=orient)
return data
return self().to_dict(flatten=flatten, orient=orient)

def to_list(self, flatten: bool = False) -> List[Dict[str, Any]]:
"""
Expand All @@ -300,8 +311,7 @@ def to_list(self, flatten: bool = False) -> List[Dict[str, Any]]:
Returns:
A list of dictionaries of records.
"""
records = list(self(with_suggestions=True, with_responses=True))
data = GenericIO.to_list(records=records, flatten=flatten)
data = self().to_list(flatten=flatten)
return data

def to_json(self, path: Union[Path, str]) -> Path:
Expand All @@ -315,8 +325,7 @@ def to_json(self, path: Union[Path, str]) -> Path:
The path to the file where the records were saved.

"""
records = list(self(with_suggestions=True, with_responses=True))
return JsonIO.to_json(records=records, path=path)
return self().to_json(path=path)

def from_json(self, path: Union[Path, str]) -> List[Record]:
"""Creates a DatasetRecords object from a disk path to a JSON file.
Expand All @@ -340,8 +349,8 @@ def to_datasets(self) -> HFDataset:
The dataset containing the records.

"""
records = list(self(with_suggestions=True, with_responses=True))
return HFDatasetsIO.to_datasets(records=records)

return self().to_datasets()

############################
# Private methods
Expand Down
31 changes: 31 additions & 0 deletions argilla/tests/integration/test_export_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,37 @@ def test_export_records_list_flattened(client: Argilla, dataset: rg.Dataset):
assert exported_records[0]["label.suggestion.score"] is None


def test_export_record_list_with_filtered_records(client: Argilla, dataset: rg.Dataset):
mock_data = [
{
"text": "Hello World, how are you?",
"label": "positive",
"id": uuid.uuid4(),
},
{
"text": "Hello World, how are you?",
"label": "negative",
"id": uuid.uuid4(),
},
{
"text": "Hello World, how are you?",
"label": "positive",
"id": uuid.uuid4(),
},
]
dataset.records.log(records=mock_data)
exported_records = dataset.records(query=rg.Query(query="hello")).to_list(flatten=True)
assert len(exported_records) == len(mock_data)
assert isinstance(exported_records, list)
assert isinstance(exported_records[0], dict)
assert isinstance(exported_records[0]["id"], str)
assert isinstance(exported_records[0]["text"], str)
assert isinstance(exported_records[0]["label.suggestion"], str)
assert exported_records[0]["text"] == "Hello World, how are you?"
assert exported_records[0]["label.suggestion"] == "positive"
assert exported_records[0]["label.suggestion.score"] is None


def test_export_records_list_nested(client: Argilla, dataset: rg.Dataset):
mock_data = [
{
Expand Down
Loading