diff --git a/argilla/CHANGELOG.md b/argilla/CHANGELOG.md
index b0a76c59ff..1c4cf3fc37 100644
--- a/argilla/CHANGELOG.md
+++ b/argilla/CHANGELOG.md
@@ -16,6 +16,10 @@ These are the section headers that we use:
 
 ## [Unreleased]()
 
+### Fixed
+
+- Fixed an error where dataset settings were not properly loaded when iterating over datasets or accessing them by index. ([#5753](https://github.com/argilla-io/argilla/pull/5753))
+
 ## [2.5.0](https://github.com/argilla-io/argilla/compare/v2.4.0...v2.5.0)
 
 ### Added
diff --git a/argilla/src/argilla/client.py b/argilla/src/argilla/client.py
index 9d0c666304..96c1f01f53 100644
--- a/argilla/src/argilla/client.py
+++ b/argilla/src/argilla/client.py
@@ -312,7 +312,10 @@ class Datasets(Sequence["Dataset"], ResourceHTMLReprMixin):
     """A collection of datasets. It can be used to create a new dataset or to get an existing one."""
 
     class _Iterator(GenericIterator["Dataset"]):
-        pass
+        def __next__(self):
+            # Re-fetch each listed dataset so that its settings are fully loaded.
+            dataset = super().__next__()
+            return dataset.get()
 
     def __init__(self, client: "Argilla") -> None:
         self._client = client
@@ -366,7 +369,7 @@ def __getitem__(self, index: slice) -> Sequence["Dataset"]: ...
 
     def __getitem__(self, index) -> "Dataset":
         model = self._api.list()[index]
-        return self._from_model(model)
+        return self._from_model(model).get()
 
     def __len__(self) -> int:
         return len(self._api.list())
diff --git a/argilla/tests/integration/test_listing_datasets.py b/argilla/tests/integration/test_listing_datasets.py
index c228968405..f47766c852 100644
--- a/argilla/tests/integration/test_listing_datasets.py
+++ b/argilla/tests/integration/test_listing_datasets.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from argilla import Argilla, Dataset, Settings, TextField, TextQuestion, Workspace
+from argilla import Argilla, Dataset, Settings, TextField, TextQuestion, Workspace, TaskDistribution
 
 
 class TestDatasetsList:
@@ -33,3 +33,32 @@ def test_list_datasets(self, client: Argilla):
         for ds in datasets:
             if ds.name == "test_dataset":
                 assert ds == dataset, "The dataset was not loaded properly"
+
+    def test_list_dataset_with_custom_task_distribution(self, client: Argilla, workspace: Workspace):
+        """Settings must be fully loaded when a dataset is reached by iteration or by index."""
+        dataset = Dataset(
+            name="test_dataset",
+            workspace=workspace.name,
+            settings=Settings(
+                fields=[TextField(name="text")],
+                questions=[TextQuestion(name="text_question")],
+                distribution=TaskDistribution(min_submitted=4),
+            ),
+            client=client,
+        )
+        dataset.create()
+        datasets = client.datasets
+        assert len(datasets) > 0, "No datasets were found"
+
+        dataset_idx = None
+        for idx, ds in enumerate(datasets):
+            if ds.id == dataset.id:
+                dataset_idx = idx
+                assert ds.settings.distribution.min_submitted == 4, "The dataset was not loaded properly"
+                break
+
+        assert dataset_idx is not None, "The created dataset was not found while iterating"
+
+        ds = client.datasets[dataset_idx]
+        assert ds.id == dataset.id, "Indexing returned a different dataset"
+        assert ds.settings.distribution.min_submitted == 4, "The dataset was not loaded properly"