From 3de7a3c389e8f8b047b883f4ac622b6ffdec1501 Mon Sep 17 00:00:00 2001
From: Merlin Kallenborn
Date: Tue, 28 May 2024 17:39:01 +0200
Subject: [PATCH 1/5] WIP: feat: Add test for submitting two different datasets

refactor: Move fixture for DefaultArgillaClient to evaluation/conftest.py

TASK: IL-541
---
 .../connectors/argilla/argilla_client.py      |  4 +++
 .../evaluation/evaluator/argilla_evaluator.py |  3 ++-
 .../connectors/argilla/test_argilla_client.py | 16 +++++++++++
 tests/evaluation/test_argilla_evaluator.py    | 27 +++++++++++++++++++
 4 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/src/intelligence_layer/connectors/argilla/argilla_client.py b/src/intelligence_layer/connectors/argilla/argilla_client.py
index d0a8f5617..08a91f1a6 100644
--- a/src/intelligence_layer/connectors/argilla/argilla_client.py
+++ b/src/intelligence_layer/connectors/argilla/argilla_client.py
@@ -238,6 +238,10 @@ def ensure_dataset_exists(
             lambda: self._publish_dataset(dataset_id),
         )
         return dataset_id
+
+
+
+
     def _ignore_failure_status(
         self, expected_failure: frozenset[HTTPStatus], f: Callable[[], None]
diff --git a/src/intelligence_layer/evaluation/evaluation/evaluator/argilla_evaluator.py b/src/intelligence_layer/evaluation/evaluation/evaluator/argilla_evaluator.py
index e6430e627..1e7c72530 100644
--- a/src/intelligence_layer/evaluation/evaluation/evaluator/argilla_evaluator.py
+++ b/src/intelligence_layer/evaluation/evaluation/evaluator/argilla_evaluator.py
@@ -132,11 +132,12 @@ def submit(
         self,
         *run_ids: str,
         num_examples: Optional[int] = None,
+        dataset_name: Optional[str] = None,
         abort_on_error: bool = False,
     ) -> PartialEvaluationOverview:
         argilla_dataset_id = self._client.ensure_dataset_exists(
             self._workspace_id,
-            dataset_name="name",
+            dataset_name,
             fields=list(self._evaluation_logic.fields.values()),
             questions=self._evaluation_logic.questions,
         )
diff --git a/tests/connectors/argilla/test_argilla_client.py b/tests/connectors/argilla/test_argilla_client.py
index 383f8ed70..60e65f905 100644
--- a/tests/connectors/argilla/test_argilla_client.py
+++ b/tests/connectors/argilla/test_argilla_client.py
@@ -296,3 +296,19 @@ def test_split_dataset_can_split_long_dataset(
     for old_metadata, new_metadata in zip(record_metadata, new_metadata_list):
         del new_metadata["split"]  # type: ignore
         assert old_metadata == new_metadata
+
+
+@pytest.mark.docker
+def test_client_creates_two_datasets_with_same_name(
+    argilla_client: DefaultArgillaClient,
+    workspace_id: str,
+) -> None:
+    id1 = argilla_client.ensure_dataset_exists(
+        workspace_id, dataset_name="name", fields=[], questions=[]
+    )
+
+    id2 = argilla_client.ensure_dataset_exists(
+        workspace_id, dataset_name="name", fields=[], questions=[]
+    )
+
+    assert id1 == id2
diff --git a/tests/evaluation/test_argilla_evaluator.py b/tests/evaluation/test_argilla_evaluator.py
index c0167d641..63f4622d5 100644
--- a/tests/evaluation/test_argilla_evaluator.py
+++ b/tests/evaluation/test_argilla_evaluator.py
@@ -344,3 +344,30 @@ def test_argilla_aggregation_logic_works() -> None:
         aggregation.scores["player_3"].elo_standard_error
         > aggregation.scores["player_1"].elo_standard_error
     )
+
+
+#def test_argilla_evaluator_has_distinct_names_for_datasets(
+# string_argilla_runner: Runner[DummyStringInput, DummyStringOutput],
+# string_dataset_id: str,
+# in_memory_dataset_repository: InMemoryDatasetRepository,
+# in_memory_run_repository: InMemoryRunRepository,
+# async_in_memory_evaluation_repository: AsyncInMemoryEvaluationRepository,
+# dummy_argilla_client: DummyArgillaClient,
+#) -> 
None:
+# workspace_id = "workspace_id"
+# evaluator = ArgillaEvaluator(
+# in_memory_dataset_repository,
+# in_memory_run_repository,
+# async_in_memory_evaluation_repository,
+# "dummy-string-task",
+# DummyStringTaskArgillaEvaluationLogic(),
+# dummy_argilla_client,
+# workspace_id,
+# )
+#
+# run_overview = string_argilla_runner.run_dataset(string_dataset_id)
+# _ = evaluator.submit(run_overview.id)
+# _ = evaluator.submit(run_overview.id)
+#
+# assert len(dummy_argilla_client._datasets) == 2
+# assert dummy_argilla_client._datasets.keys() != dummy_argilla_client

From 1310adcaa1a28803b39a1360fe6f24071721bfc1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Niklas=20K=C3=B6hnecke?=
Date: Wed, 29 May 2024 12:50:07 +0200
Subject: [PATCH 2/5] fix: argilla evaluator handling dataset names correctly

---
 .../connectors/argilla/argilla_client.py      | 137 +++++++++++++++---
 .../evaluation/evaluator/argilla_evaluator.py |   3 +-
 .../connectors/argilla/test_argilla_client.py |  70 ++++++---
 tests/evaluation/conftest.py                  |   9 ++
 tests/evaluation/test_argilla_evaluator.py    | 111 ++++++++++----
 ...t_instruct_comparison_argilla_evaluator.py |   9 ++
 6 files changed, 262 insertions(+), 77 deletions(-)

diff --git a/src/intelligence_layer/connectors/argilla/argilla_client.py b/src/intelligence_layer/connectors/argilla/argilla_client.py
index 08a91f1a6..e7c1add33 100644
--- a/src/intelligence_layer/connectors/argilla/argilla_client.py
+++ b/src/intelligence_layer/connectors/argilla/argilla_client.py
@@ -3,7 +3,17 @@
 from abc import ABC, abstractmethod
 from http import HTTPStatus
 from itertools import chain, count, islice
-from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Union, cast
+from typing import (
+    Any,
+    Callable,
+    Iterable,
+    Mapping,
+    Optional,
+    Sequence,
+    TypeVar,
+    Union,
+    cast,
+)
 from uuid import uuid4
 
 from pydantic import BaseModel
@@ -96,6 +106,29 @@ class ArgillaClient(ABC):
     the intelligence layer to create feedback datasets or retrieve evaluation results.
     """
 
+    @abstractmethod
+    def create_dataset(
+        self,
+        workspace_id: str,
+        dataset_name: str,
+        fields: Sequence[Field],
+        questions: Sequence[Question],
+    ) -> str:
+        """Creates and publishes a new feedback dataset in Argilla.
+
+        Raises an error if the name exists already.
+
+        Args:
+            workspace_id: the id of the workspace the feedback-dataset should be created in.
+                The user executing this request must have corresponding permissions for this workspace.
+            dataset_name: the name of the feedback-dataset to be created.
+            fields: all fields of this dataset.
+            questions: all questions for this dataset.
+        Returns:
+            The id of the dataset to be retrieved .
+        """
+        ...
+
     @abstractmethod
     def ensure_dataset_exists(
         self,
@@ -127,6 +160,16 @@ def add_record(self, dataset_id: str, record: RecordData) -> None:
         """
         ...
 
+    def add_records(self, dataset_id: str, records: Sequence[RecordData]) -> None:
+        """Adds new records to be evalated to the given dataset.
+
+        Args:
+            dataset_id: id of the dataset the record is added to
+            records: contains the actual record data (i.e. content for the dataset's fields)
+        """
+        for record in records:
+            self.add_record(dataset_id, record)
+
     @abstractmethod
     def evaluations(self, dataset_id: str) -> Iterable[ArgillaEvaluation]:
         """Returns all human-evaluated evaluations for the given dataset.
@@ -149,6 +192,18 @@ def split_dataset(self, dataset_id: str, n_splits: int) -> None:
         ...
 
+T = TypeVar("T") + + +def batch_iterator(iterable: Iterable[T], batch_size: int) -> Iterable[list[T]]: + iterator = iter(iterable) + while True: + batch = list(islice(iterator, batch_size)) + if not batch: + break + yield batch + + class DefaultArgillaClient(ArgillaClient): def __init__( self, @@ -196,7 +251,7 @@ def ensure_workspace_exists(self, workspace_name: str) -> str: ) raise e - def ensure_dataset_exists( + def create_dataset( self, workspace_id: str, dataset_name: str, @@ -204,7 +259,40 @@ def ensure_dataset_exists( questions: Sequence[Question], ) -> str: try: + print(workspace_id) dataset_id: str = self._create_dataset(dataset_name, workspace_id)["id"] + for field in fields: + self._create_field(field.name, field.title, dataset_id) + + for question in questions: + self._create_question( + question.name, + question.title, + question.description, + question.options, + dataset_id, + ) + self._publish_dataset(dataset_id) + return dataset_id + + except HTTPError as e: + if e.response.status_code == HTTPStatus.CONFLICT: + raise ValueError( + f"Cannot create dataset with name '{dataset_name}', dataset already exists." + ) + raise e + + def ensure_dataset_exists( + self, + workspace_id: str, + dataset_name: str, + fields: Sequence[Field], + questions: Sequence[Question], + ) -> str: + try: + dataset_id: str = self.create_dataset( + workspace_id, dataset_name, fields, questions + ) except HTTPError as e: if e.response.status_code == HTTPStatus.CONFLICT: datasets = self._list_datasets(workspace_id) @@ -238,10 +326,6 @@ def ensure_dataset_exists( lambda: self._publish_dataset(dataset_id), ) return dataset_id - - - - def _ignore_failure_status( self, expected_failure: frozenset[HTTPStatus], f: Callable[[], None] @@ -253,9 +337,10 @@ def _ignore_failure_status( raise e def add_record(self, dataset_id: str, record: RecordData) -> None: - self._create_record( - record.content, record.metadata, record.example_id, dataset_id - ) + self._create_records([record], dataset_id) + + def add_records(self, dataset_id: str, records: Sequence[RecordData]) -> None: + self._create_records(records, dataset_id) def evaluations(self, dataset_id: str) -> Iterable[ArgillaEvaluation]: def to_responses( @@ -405,6 +490,8 @@ def _list_datasets(self, workspace_id: str) -> Mapping[str, Any]: def _publish_dataset(self, dataset_id: str) -> None: url = self.api_url + f"api/v1/datasets/{dataset_id}/publish" response = self.session.put(url) + print(response) + print(response.content) response.raise_for_status() def _create_dataset( @@ -418,6 +505,7 @@ def _create_dataset( "allow_extra_metadata": True, } response = self.session.post(url, json=data) + print(response.content) response.raise_for_status() return cast(Mapping[str, Any], response.json()) @@ -485,24 +573,27 @@ def _list_records( record["example_id"] = example_id yield from cast(Sequence[Mapping[str, Any]], records) - def _create_record( + def _create_records( self, - content: Mapping[str, str], - metadata: Mapping[str, str], - example_id: str, + records: Sequence[RecordData], dataset_id: str, ) -> None: url = self.api_url + f"api/v1/datasets/{dataset_id}/records" - data = { - "items": [ - { - "fields": content, - "metadata": {**metadata, "example_id": example_id}, - } - ] - } - response = self.session.post(url, json=data) - response.raise_for_status() + for batch in batch_iterator(records, 200): + data = { + "items": [ + { + "fields": record.content, + "metadata": { + **record.metadata, + "example_id": record.example_id, + }, + } + for record in batch + ] + } + 
response = self.session.post(url, json=data)
+            response.raise_for_status()
 
     def delete_workspace(self, workspace_id: str) -> None:
         for dataset in self._list_datasets(workspace_id)["items"]:
diff --git a/src/intelligence_layer/evaluation/evaluation/evaluator/argilla_evaluator.py b/src/intelligence_layer/evaluation/evaluation/evaluator/argilla_evaluator.py
index 1e7c72530..b7923b769 100644
--- a/src/intelligence_layer/evaluation/evaluation/evaluator/argilla_evaluator.py
+++ b/src/intelligence_layer/evaluation/evaluation/evaluator/argilla_evaluator.py
@@ -3,6 +3,7 @@
 from datetime import datetime
 from itertools import combinations
 from typing import Mapping, Optional, Sequence
+from uuid import uuid4
 
 from pydantic import BaseModel
 
@@ -137,7 +138,7 @@ def submit(
     ) -> PartialEvaluationOverview:
         argilla_dataset_id = self._client.ensure_dataset_exists(
             self._workspace_id,
-            dataset_name,
+            dataset_name if dataset_name else str(uuid4()),
             fields=list(self._evaluation_logic.fields.values()),
             questions=self._evaluation_logic.questions,
         )
diff --git a/tests/connectors/argilla/test_argilla_client.py b/tests/connectors/argilla/test_argilla_client.py
index 60e65f905..8dfacf2af 100644
--- a/tests/connectors/argilla/test_argilla_client.py
+++ b/tests/connectors/argilla/test_argilla_client.py
@@ -46,7 +46,7 @@ def retry(
 @fixture
 def argilla_client() -> DefaultArgillaClient:
     load_dotenv()
-    return DefaultArgillaClient(total_retries=8)
+    return DefaultArgillaClient(total_retries=1)
 
 
 @fixture
@@ -78,6 +78,50 @@ def qa_dataset_id(argilla_client: DefaultArgillaClient, workspace_id: str) -> st
     )
 
 
+@pytest.mark.docker
+def test_client_can_create_a_dataset(
+    argilla_client: DefaultArgillaClient,
+    workspace_id: str,
+) -> None:
+    id = argilla_client.create_dataset(
+        workspace_id,
+        dataset_name="name",
+        fields=[Field(name="a", title="b")],
+        questions=[
+            Question(name="a", title="b", description="c", options=list(range(1, 5)))
+        ],
+    )
+    datasets = argilla_client._list_datasets(workspace_id)
+    assert len(datasets["items"]) == 1
+    assert id == datasets["items"][0]["id"]
+
+
+@pytest.mark.docker
+def test_client_cannot_create_two_datasets_with_the_same_name(
+    argilla_client: DefaultArgillaClient,
+    workspace_id: str,
+) -> None:
+    argilla_client.create_dataset(
+        workspace_id,
+        dataset_name="name",
+        fields=[Field(name="a", title="b")],
+        questions=[
+            Question(name="a", title="b", description="c", options=list(range(1, 5)))
+        ],
+    )
+    with pytest.raises(ValueError):
+        argilla_client.create_dataset(
+            workspace_id,
+            dataset_name="name",
+            fields=[Field(name="a", title="b")],
+            questions=[
+                Question(
+                    name="a", title="b", description="c", options=list(range(1, 5))
+                )
+            ],
+        )
+
+
 @fixture
 def qa_records(
     argilla_client: ArgillaClient, qa_dataset_id: str
@@ -90,8 +134,7 @@ def qa_records(
         )
         for i in range(60)
     ]
-    for record in records:
-        argilla_client.add_record(qa_dataset_id, record)
+    argilla_client.add_records(qa_dataset_id, records)
     return records
 
 
@@ -107,13 +150,12 @@ def long_qa_records(
         )
         for i in range(1024)
     ]
-    for record in records:
-        argilla_client.add_record(qa_dataset_id, record)
+    argilla_client.add_records(qa_dataset_id, records)
     return records
 
 
 @pytest.mark.docker
-def test_error_on_non_existent_dataset(
+def test_retrieving_records_on_non_existant_dataset_raises_errors(
     argilla_client: DefaultArgillaClient,
 ) -> None:
     with pytest.raises(HTTPError):
@@ -296,3 +338,19 @@ def test_split_dataset_can_split_long_dataset(
     for old_metadata, new_metadata in 
zip(record_metadata, new_metadata_list): del new_metadata["split"] # type: ignore assert old_metadata == new_metadata - - -@pytest.mark.docker -def test_client_creates_two_datasets_with_same_name( - argilla_client: DefaultArgillaClient, - workspace_id: str, -) -> None: - id1 = argilla_client.ensure_dataset_exists( - workspace_id, dataset_name="name", fields=[], questions=[] - ) - - id2 = argilla_client.ensure_dataset_exists( - workspace_id, dataset_name="name", fields=[], questions=[] - ) - - assert id1 == id2 diff --git a/tests/evaluation/conftest.py b/tests/evaluation/conftest.py index c15362589..833d5df79 100644 --- a/tests/evaluation/conftest.py +++ b/tests/evaluation/conftest.py @@ -157,6 +157,15 @@ class StubArgillaClient(ArgillaClient): _datasets: dict[str, list[RecordData]] = {} _score = 3.0 + def create_dataset( + self, + workspace_id: str, + dataset_name: str, + fields: Sequence[Field], + questions: Sequence[Question], + ) -> str: + return self.ensure_dataset_exists(workspace_id, dataset_name, fields, questions) + def ensure_dataset_exists( self, workspace_id: str, diff --git a/tests/evaluation/test_argilla_evaluator.py b/tests/evaluation/test_argilla_evaluator.py index 63f4622d5..b4e993c91 100644 --- a/tests/evaluation/test_argilla_evaluator.py +++ b/tests/evaluation/test_argilla_evaluator.py @@ -87,8 +87,20 @@ class CustomException(Exception): class DummyArgillaClient(ArgillaClient): - _datasets: dict[str, list[RecordData]] = {} - _score = 3.0 + def __init__(self) -> None: + self._datasets: dict[str, list[RecordData]] = {} + self._names: dict[str, str] = {} + self._score = 3.0 + + + def create_dataset( + self, + workspace_id: str, + dataset_name: str, + fields: Sequence[Field], + questions: Sequence[Question], + ) -> str: + return self.ensure_dataset_exists(workspace_id, dataset_name, fields, questions) def ensure_dataset_exists( self, @@ -99,6 +111,7 @@ def ensure_dataset_exists( ) -> str: dataset_id = str(uuid4()) self._datasets[dataset_id] = [] + self._names[dataset_id] = dataset_name return dataset_id def add_record(self, dataset_id: str, record: RecordData) -> None: @@ -129,6 +142,15 @@ class FailedEvaluationDummyArgillaClient(ArgillaClient): _upload_count = 0 _datasets: dict[str, list[RecordData]] = {} + def create_dataset( + self, + workspace_id: str, + dataset_name: str, + fields: Sequence[Field], + questions: Sequence[Question], + ) -> str: + return self.ensure_dataset_exists(workspace_id, dataset_name, fields, questions) + def ensure_dataset_exists( self, workspace_id: str, @@ -165,8 +187,8 @@ def split_dataset(self, dataset_id: str, n_splits: int) -> None: @fixture -def arg() -> StubArgillaClient: - return StubArgillaClient() +def dummy_client() -> DummyArgillaClient: + return DummyArgillaClient() @fixture() @@ -233,13 +255,14 @@ def test_argilla_evaluator_can_submit_evals_to_argilla( in_memory_run_repository: InMemoryRunRepository, async_in_memory_evaluation_repository: AsyncInMemoryEvaluationRepository, ) -> None: + client = DummyArgillaClient() evaluator = ArgillaEvaluator( in_memory_dataset_repository, in_memory_run_repository, async_in_memory_evaluation_repository, "dummy-string-task", DummyStringTaskArgillaEvaluationLogic(), - DummyArgillaClient(), + client, workspace_id="workspace-id", ) @@ -264,7 +287,7 @@ def test_argilla_evaluator_can_submit_evals_to_argilla( ) assert len(list(async_in_memory_evaluation_repository.evaluation_overviews())) == 1 - assert len(DummyArgillaClient()._datasets[partial_evaluation_overview.id]) == 1 + assert 
len(client._datasets[partial_evaluation_overview.id]) == 1
 
 
 def test_argilla_evaluator_correctly_lists_failed_eval_counts(
@@ -346,28 +369,54 @@ def test_argilla_aggregation_logic_works() -> None:
     )
 
 
-#def test_argilla_evaluator_has_distinct_names_for_datasets(
-# string_argilla_runner: Runner[DummyStringInput, DummyStringOutput],
-# string_dataset_id: str,
-# in_memory_dataset_repository: InMemoryDatasetRepository,
-# in_memory_run_repository: InMemoryRunRepository,
-# async_in_memory_evaluation_repository: AsyncInMemoryEvaluationRepository,
-# dummy_argilla_client: DummyArgillaClient,
-#) -> None:
-# workspace_id = "workspace_id"
-# evaluator = ArgillaEvaluator(
-# in_memory_dataset_repository,
-# in_memory_run_repository,
-# async_in_memory_evaluation_repository,
-# "dummy-string-task",
-# DummyStringTaskArgillaEvaluationLogic(),
-# dummy_argilla_client,
-# workspace_id,
-# )
-#
-# run_overview = string_argilla_runner.run_dataset(string_dataset_id)
-# _ = evaluator.submit(run_overview.id)
-# _ = evaluator.submit(run_overview.id)
-#
-# assert len(dummy_argilla_client._datasets) == 2
-# assert dummy_argilla_client._datasets.keys() != dummy_argilla_client
+def test_argilla_evaluator_has_distinct_names_for_datasets(
+    string_argilla_runner: Runner[DummyStringInput, DummyStringOutput],
+    string_dataset_id: str,
+    in_memory_dataset_repository: InMemoryDatasetRepository,
+    in_memory_run_repository: InMemoryRunRepository,
+    async_in_memory_evaluation_repository: AsyncInMemoryEvaluationRepository,
+    dummy_client: DummyArgillaClient,
+) -> None:
+    workspace_id = "workspace_id"
+    evaluator = ArgillaEvaluator(
+        in_memory_dataset_repository,
+        in_memory_run_repository,
+        async_in_memory_evaluation_repository,
+        "dummy-string-task",
+        DummyStringTaskArgillaEvaluationLogic(),
+        dummy_client,
+        workspace_id,
+    )
+
+    run_overview = string_argilla_runner.run_dataset(string_dataset_id)
+    _ = evaluator.submit(run_overview.id)
+    _ = evaluator.submit(run_overview.id)
+
+    assert len(dummy_client._datasets) == 2
+    assert len(set(dummy_client._names.values())) == 2
+
+
+def test_argilla_evaluator_can_take_name(
+    string_argilla_runner: Runner[DummyStringInput, DummyStringOutput],
+    string_dataset_id: str,
+    in_memory_dataset_repository: InMemoryDatasetRepository,
+    in_memory_run_repository: InMemoryRunRepository,
+    async_in_memory_evaluation_repository: AsyncInMemoryEvaluationRepository,
+    dummy_client: DummyArgillaClient,
+) -> None:
+    workspace_id = "workspace_id"
+    evaluator = ArgillaEvaluator(
+        in_memory_dataset_repository,
+        in_memory_run_repository,
+        async_in_memory_evaluation_repository,
+        "dummy-string-task",
+        DummyStringTaskArgillaEvaluationLogic(),
+        dummy_client,
+        workspace_id,
+    )
+
+    run_overview = string_argilla_runner.run_dataset(string_dataset_id)
+    dataset_id = evaluator.submit(run_overview.id, dataset_name="my-dataset").id
+
+    assert len(dummy_client._datasets) == 1
+    assert dummy_client._names[dataset_id] == "my-dataset"
diff --git a/tests/evaluation/test_instruct_comparison_argilla_evaluator.py b/tests/evaluation/test_instruct_comparison_argilla_evaluator.py
index d6e85e2be..a8c2ad9df 100644
--- a/tests/evaluation/test_instruct_comparison_argilla_evaluator.py
+++ b/tests/evaluation/test_instruct_comparison_argilla_evaluator.py
@@ -37,6 +37,15 @@ class ArgillaFake(ArgillaClient):
     def __init__(self) -> None:
         self.records: dict[str, list[RecordData]] = defaultdict(list)
 
+    def create_dataset(
+        self,
+        workspace_id: str,
+        dataset_name: str,
+        fields: Sequence[Field],
+        questions: 
Sequence[Question],
+    ) -> str:
+        return self.ensure_dataset_exists(workspace_id, dataset_name, fields, questions)
+
     def ensure_dataset_exists(
         self,
         workspace_id: str,

From a25be6c4dbae127674c656cb9ba5af94a299a844 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Niklas=20K=C3=B6hnecke?=
Date: Wed, 29 May 2024 12:50:20 +0200
Subject: [PATCH 3/5] feat: improve local test readability

---
 scripts/test.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/test.sh b/scripts/test.sh
index b3d1512ff..9b1d85893 100755
--- a/scripts/test.sh
+++ b/scripts/test.sh
@@ -1,3 +1,3 @@
 #!/usr/bin/env -S bash -eu -o pipefail
 
-poetry run pytest -n 10
+TQDM_DISABLE=1 poetry run pytest -n 10

From 382826112bf0dcc1b481eafc1908a37b9203f904 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Niklas=20K=C3=B6hnecke?=
Date: Wed, 29 May 2024 13:06:48 +0200
Subject: [PATCH 4/5] docs: update changelog and docs

---
 CHANGELOG.md                                             | 6 +++++-
 .../connectors/argilla/argilla_client.py                 | 8 ++------
 .../evaluation/evaluation/evaluator/argilla_evaluator.py | 2 +-
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0e3a05ea4..0b8d4d92f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,9 @@
 - We updated the graders to support python 3.12 and moved away from `nltk`-package:
   - `BleuGrader` now uses `sacrebleu`-package.
   - `RougeGrader` now uses the `rouge_score`-package.
+  - When using the `ArgillaEvaluator`, attempting to submit to a dataset that already exists will no longer append to that dataset. This makes it more in line with other evaluation concepts.
+  - Instead of appending to an active Argilla dataset, you now need to create a new dataset, retrieve it and then finally combine both datasets in the aggregation step.
+  - The `ArgillaClient` now has methods `create_dataset` for strict dataset creation that raises an error on name collisions, and `add_records` for performant batch uploads.
 
 ### New Features
 - Add `how_to_implement_incremental_evaluation`.
@@ -23,7 +26,8 @@
 - We now support python 3.12
 
 ### Fixes
-- The document index client now correctly URL-encodes document names in its queries.
+ - The document index client now correctly URL-encodes document names in its queries.
+ - The `ArgillaEvaluator` now properly supports `dataset_name`.
 
 ### Deprecations
 ...
diff --git a/src/intelligence_layer/connectors/argilla/argilla_client.py b/src/intelligence_layer/connectors/argilla/argilla_client.py
index e7c1add33..8868173c0 100644
--- a/src/intelligence_layer/connectors/argilla/argilla_client.py
+++ b/src/intelligence_layer/connectors/argilla/argilla_client.py
@@ -156,7 +156,7 @@ def add_record(self, dataset_id: str, record: RecordData) -> None:
 
         Args:
             dataset_id: id of the dataset the record is added to
-            record: contains the actual record data (i.e. content for the dataset's fields)
+            record: the actual record data (i.e. content for the dataset's fields)
         """
         ...
 
@@ -165,7 +165,7 @@ def add_records(self, dataset_id: str, records: Sequence[RecordData]) -> None:
 
         Args:
             dataset_id: id of the dataset the record is added to
-            records: contains the actual record data (i.e. content for the dataset's fields)
+            records: list containing the record data (i.e. 
content for the dataset's fields)
         """
         for record in records:
             self.add_record(dataset_id, record)
@@ -259,7 +259,6 @@ def create_dataset(
         questions: Sequence[Question],
     ) -> str:
         try:
-            print(workspace_id)
             dataset_id: str = self._create_dataset(dataset_name, workspace_id)["id"]
             for field in fields:
                 self._create_field(field.name, field.title, dataset_id)
@@ -490,8 +489,6 @@ def _list_datasets(self, workspace_id: str) -> Mapping[str, Any]:
     def _publish_dataset(self, dataset_id: str) -> None:
         url = self.api_url + f"api/v1/datasets/{dataset_id}/publish"
         response = self.session.put(url)
-        print(response)
-        print(response.content)
         response.raise_for_status()
 
     def _create_dataset(
@@ -505,7 +502,6 @@ def _create_dataset(
             "allow_extra_metadata": True,
         }
         response = self.session.post(url, json=data)
-        print(response.content)
         response.raise_for_status()
         return cast(Mapping[str, Any], response.json())
 
diff --git a/src/intelligence_layer/evaluation/evaluation/evaluator/argilla_evaluator.py b/src/intelligence_layer/evaluation/evaluation/evaluator/argilla_evaluator.py
index b7923b769..05a13e92c 100644
--- a/src/intelligence_layer/evaluation/evaluation/evaluator/argilla_evaluator.py
+++ b/src/intelligence_layer/evaluation/evaluation/evaluator/argilla_evaluator.py
@@ -136,7 +136,7 @@ def submit(
         dataset_name: Optional[str] = None,
         abort_on_error: bool = False,
     ) -> PartialEvaluationOverview:
-        argilla_dataset_id = self._client.ensure_dataset_exists(
+        argilla_dataset_id = self._client.create_dataset(
             self._workspace_id,
             dataset_name if dataset_name else str(uuid4()),
             fields=list(self._evaluation_logic.fields.values()),

From 1473472ef30ea7c3bb0d056f91b85201df7ef275 Mon Sep 17 00:00:00 2001
From: Sebastian Niehus
Date: Wed, 29 May 2024 14:42:36 +0200
Subject: [PATCH 5/5] feat: Small improvements to tests and docstrings

* Fixed outdated run parameters in Concepts.md

TASK: IL-541
---
 Concepts.md                                     |  2 +-
 .../connectors/argilla/argilla_client.py        |  9 +++++----
 tests/connectors/argilla/test_argilla_client.py |  9 +++++----
 tests/evaluation/test_argilla_evaluator.py      | 10 +++++-----
 4 files changed, 16 insertions(+), 14 deletions(-)

diff --git a/Concepts.md b/Concepts.md
index f45fe4ea2..e156f2970 100644
--- a/Concepts.md
+++ b/Concepts.md
@@ -41,7 +41,7 @@ task:
 ```Python
 class Task(ABC, Generic[Input, Output]):
     @final
-    def run(self, input: Input, tracer: Tracer, trace_id: Optional[str] = None) -> Output:
+    def run(self, input: Input, tracer: Tracer) -> Output:
         ...
 ```
 
diff --git a/src/intelligence_layer/connectors/argilla/argilla_client.py b/src/intelligence_layer/connectors/argilla/argilla_client.py
index 8868173c0..b9db07c23 100644
--- a/src/intelligence_layer/connectors/argilla/argilla_client.py
+++ b/src/intelligence_layer/connectors/argilla/argilla_client.py
@@ -125,7 +125,7 @@ def create_dataset(
         fields: all fields of this dataset.
         questions: all questions for this dataset.
     Returns:
-        The id of the dataset to be retrieved .
+        The id of the created dataset.
     """
     ...
 
@@ -152,7 +152,7 @@ def ensure_dataset_exists(
 
     @abstractmethod
     def add_record(self, dataset_id: str, record: RecordData) -> None:
-        """Adds a new record to be evalated to the given dataset.
+        """Adds a new record to the given dataset.
 
         Args:
             dataset_id: id of the dataset the record is added to
@@ -161,7 +161,7 @@ def add_record(self, dataset_id: str, record: RecordData) -> None:
         ...
 
     def add_records(self, dataset_id: str, records: Sequence[RecordData]) -> None:
-        """Adds new records to be evalated to the given dataset. 
+ """Adds new records to the given dataset. Args: dataset_id: id of the dataset the record is added to @@ -277,7 +277,8 @@ def create_dataset( except HTTPError as e: if e.response.status_code == HTTPStatus.CONFLICT: raise ValueError( - f"Cannot create dataset with name '{dataset_name}', dataset already exists." + f"Cannot create dataset with name '{dataset_name}', either the given dataset name, already exists" + f"or field name or question name are duplicates." ) raise e diff --git a/tests/connectors/argilla/test_argilla_client.py b/tests/connectors/argilla/test_argilla_client.py index 8dfacf2af..e2c16675a 100644 --- a/tests/connectors/argilla/test_argilla_client.py +++ b/tests/connectors/argilla/test_argilla_client.py @@ -83,7 +83,7 @@ def test_client_can_create_a_dataset( argilla_client: DefaultArgillaClient, workspace_id: str, ) -> None: - id = argilla_client.create_dataset( + dataset_id = argilla_client.create_dataset( workspace_id, dataset_name="name", fields=[Field(name="a", title="b")], @@ -93,7 +93,7 @@ def test_client_can_create_a_dataset( ) datasets = argilla_client._list_datasets(workspace_id) assert len(argilla_client._list_datasets(workspace_id)) == 1 - assert id == datasets["items"][0]["id"] + assert dataset_id == datasets["items"][0]["id"] @pytest.mark.docker @@ -101,9 +101,10 @@ def test_client_cannot_create_two_datasets_with_the_same_name( argilla_client: DefaultArgillaClient, workspace_id: str, ) -> None: + dataset_name = str(uuid4()) argilla_client.create_dataset( workspace_id, - dataset_name="name", + dataset_name=dataset_name, fields=[Field(name="a", title="b")], questions=[ Question(name="a", title="b", description="c", options=list(range(1, 5))) @@ -112,7 +113,7 @@ def test_client_cannot_create_two_datasets_with_the_same_name( with pytest.raises(ValueError): argilla_client.create_dataset( workspace_id, - dataset_name="name", + dataset_name=dataset_name, fields=[Field(name="a", title="b")], questions=[ Question( diff --git a/tests/evaluation/test_argilla_evaluator.py b/tests/evaluation/test_argilla_evaluator.py index b4e993c91..34326c294 100644 --- a/tests/evaluation/test_argilla_evaluator.py +++ b/tests/evaluation/test_argilla_evaluator.py @@ -91,7 +91,6 @@ def __init__(self) -> None: self._datasets: dict[str, list[RecordData]] = {} self._names: dict[str, str] = {} self._score = 3.0 - def create_dataset( self, @@ -389,8 +388,8 @@ def test_argilla_evaluator_has_distinct_names_for_datasets( ) run_overview = string_argilla_runner.run_dataset(string_dataset_id) - _ = evaluator.submit(run_overview.id) - _ = evaluator.submit(run_overview.id) + evaluator.submit(run_overview.id) + evaluator.submit(run_overview.id) assert len(dummy_client._datasets) == 2 assert dummy_client._datasets.keys() != dummy_client @@ -415,8 +414,9 @@ def test_argilla_evaluator_can_take_name( workspace_id, ) + dataset_name = str(uuid4()) run_overview = string_argilla_runner.run_dataset(string_dataset_id) - dataset_id = evaluator.submit(run_overview.id, dataset_name="my-dataset").id + dataset_id = evaluator.submit(run_overview.id, dataset_name=dataset_name).id assert len(dummy_client._datasets) == 1 - assert dummy_client._names[dataset_id] == "my-dataset" + assert dummy_client._names[dataset_id] == dataset_name