Skip to content

Commit

Permalink
Merge branch 'develop' into feat/export-dataset-to-hub-feature-branch
Browse files Browse the repository at this point in the history
  • Loading branch information
jfcalvo authored Dec 16, 2024
2 parents 91dfd1d + ebb0fa2 commit bdd6232
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 3 deletions.
4 changes: 4 additions & 0 deletions argilla/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ These are the section headers that we use:

## [Unreleased]()

### Fixed

- Fixed error when iterating over datasets and settings are not properly loaded. ([#5753](https://github.com/argilla-io/argilla/pull/5753))

## [2.5.0](https://github.com/argilla-io/argilla/compare/v2.4.0...v2.5.0)

### Added
Expand Down
6 changes: 4 additions & 2 deletions argilla/src/argilla/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,9 @@ class Datasets(Sequence["Dataset"], ResourceHTMLReprMixin):
"""A collection of datasets. It can be used to create a new dataset or to get an existing one."""

class _Iterator(GenericIterator["Dataset"]):
pass
def __next__(self):
dataset = super().__next__()
return dataset.get()

def __init__(self, client: "Argilla") -> None:
self._client = client
Expand Down Expand Up @@ -366,7 +368,7 @@ def __getitem__(self, index: slice) -> Sequence["Dataset"]: ...

def __getitem__(self, index) -> "Dataset":
model = self._api.list()[index]
return self._from_model(model)
return self._from_model(model).get()

def __len__(self) -> int:
return len(self._api.list())
Expand Down
27 changes: 26 additions & 1 deletion argilla/tests/integration/test_listing_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla import Argilla, Dataset, Settings, TextField, TextQuestion, Workspace
from argilla import Argilla, Dataset, Settings, TextField, TextQuestion, Workspace, TaskDistribution


class TestDatasetsList:
Expand All @@ -33,3 +33,28 @@ def test_list_datasets(self, client: Argilla):
for ds in datasets:
if ds.name == "test_dataset":
assert ds == dataset, "The dataset was not loaded properly"

def test_list_dataset_with_custom_task_distribution(self, client: Argilla, workspace: Workspace):
dataset = Dataset(
name="test_dataset",
workspace=workspace.name,
settings=Settings(
fields=[TextField(name="text")],
questions=[TextQuestion(name="text_question")],
distribution=TaskDistribution(min_submitted=4),
),
client=client,
)
dataset.create()
datasets = client.datasets
assert len(datasets) > 0, "No datasets were found"

dataset_idx = 0
for idx, ds in enumerate(datasets):
if ds.id == dataset.id:
dataset_idx = idx
assert ds.settings.distribution.min_submitted == 4, "The dataset was not loaded properly"
break

ds = client.datasets[dataset_idx]
assert ds.settings.distribution.min_submitted == 4, "The dataset was not loaded properly"

0 comments on commit bdd6232

Please sign in to comment.