diff --git a/argilla/src/argilla/records/_dataset_records.py b/argilla/src/argilla/records/_dataset_records.py index 70ddf81029..00028ebd57 100644 --- a/argilla/src/argilla/records/_dataset_records.py +++ b/argilla/src/argilla/records/_dataset_records.py @@ -116,6 +116,19 @@ def _fetch_from_server_with_search(self) -> List[RecordModel]: def _is_search_query(self) -> bool: return bool(self.__query and (self.__query.query or self.__query.filter)) + def to_list(self, flatten: bool) -> List[Dict[str, Any]]: + return GenericIO.to_list(records=list(self), flatten=flatten) + + def to_dict(self, flatten: bool, orient: str) -> Dict[str, Any]: + data = GenericIO.to_dict(records=list(self), flatten=flatten, orient=orient) + return data + + def to_json(self, path: Union[Path, str]) -> Path: + return JsonIO.to_json(records=list(self), path=path) + + def to_datasets(self) -> "HFDataset": + return HFDatasetsIO.to_datasets(records=list(self)) + class DatasetRecords(Iterable[Record], LoggingMixin): """This class is used to work with records from a dataset and is accessed via `Dataset.records`. @@ -142,7 +155,7 @@ def __init__(self, client: "Argilla", dataset: "Dataset"): self._api = self.__client.api.records def __iter__(self): - return DatasetRecordsIterator(self.__dataset, self.__client) + return DatasetRecordsIterator(self.__dataset, self.__client, with_suggestions=True, with_responses=True) def __call__( self, @@ -286,9 +299,7 @@ def to_dict(self, flatten: bool = False, orient: str = "names") -> Dict[str, Any A dictionary of records. """ - records = list(self(with_suggestions=True, with_responses=True)) - data = GenericIO.to_dict(records=records, flatten=flatten, orient=orient) - return data + return self().to_dict(flatten=flatten, orient=orient) def to_list(self, flatten: bool = False) -> List[Dict[str, Any]]: """ @@ -300,8 +311,7 @@ def to_list(self, flatten: bool = False) -> List[Dict[str, Any]]: Returns: A list of dictionaries of records. """ - records = list(self(with_suggestions=True, with_responses=True)) - data = GenericIO.to_list(records=records, flatten=flatten) + data = self().to_list(flatten=flatten) return data def to_json(self, path: Union[Path, str]) -> Path: @@ -315,8 +325,7 @@ def to_json(self, path: Union[Path, str]) -> Path: The path to the file where the records were saved. """ - records = list(self(with_suggestions=True, with_responses=True)) - return JsonIO.to_json(records=records, path=path) + return self().to_json(path=path) def from_json(self, path: Union[Path, str]) -> List[Record]: """Creates a DatasetRecords object from a disk path to a JSON file. @@ -340,8 +349,8 @@ def to_datasets(self) -> HFDataset: The dataset containing the records. """ - records = list(self(with_suggestions=True, with_responses=True)) - return HFDatasetsIO.to_datasets(records=records) + + return self().to_datasets() ############################ # Private methods diff --git a/argilla/tests/integration/test_export_records.py b/argilla/tests/integration/test_export_records.py index bffd569080..3465d7e64b 100644 --- a/argilla/tests/integration/test_export_records.py +++ b/argilla/tests/integration/test_export_records.py @@ -105,6 +105,37 @@ def test_export_records_list_flattened(client: Argilla, dataset: rg.Dataset): assert exported_records[0]["label.suggestion.score"] is None +def test_export_record_list_with_filtered_records(client: Argilla, dataset: rg.Dataset): + mock_data = [ + { + "text": "Hello World, how are you?", + "label": "positive", + "id": uuid.uuid4(), + }, + { + "text": "Hello World, how are you?", + "label": "negative", + "id": uuid.uuid4(), + }, + { + "text": "Hello World, how are you?", + "label": "positive", + "id": uuid.uuid4(), + }, + ] + dataset.records.log(records=mock_data) + exported_records = dataset.records(query=rg.Query(query="hello")).to_list(flatten=True) + assert len(exported_records) == len(mock_data) + assert isinstance(exported_records, list) + assert isinstance(exported_records[0], dict) + assert isinstance(exported_records[0]["id"], str) + assert isinstance(exported_records[0]["text"], str) + assert isinstance(exported_records[0]["label.suggestion"], str) + assert exported_records[0]["text"] == "Hello World, how are you?" + assert exported_records[0]["label.suggestion"] == "positive" + assert exported_records[0]["label.suggestion.score"] is None + + def test_export_records_list_nested(client: Argilla, dataset: rg.Dataset): mock_data = [ {