diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0e3a05ea4..0b8d4d92f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,9 @@
 - We updated the graders to support python 3.12 and moved away from `nltk`-package:
   - `BleuGrader` now uses `sacrebleu`-package.
   - `RougeGrader` now uses the `rouge_score`-package.
+ - When using the `ArgillaEvaluator`, submitting to a dataset that already exists no longer appends to that dataset. This makes it more in line with the other evaluation concepts.
+ - Instead of appending to an existing Argilla dataset, you now need to create a new dataset, retrieve it, and then combine both datasets in the aggregation step.
+ - The `ArgillaClient` now has a `create_dataset` method, which fails on name conflicts instead of silently reusing an existing dataset, and an `add_records` method for performant batch uploads.
 
 ### New Features
 - Add `how_to_implement_incremental_evaluation`.
@@ -23,7 +26,8 @@
 - We now support python 3.12
 
 ### Fixes
-- The document index client now correctly URL-encodes document names in its queries.
+ - The document index client now correctly URL-encodes document names in its queries.
+ - The `ArgillaEvaluator` now properly supports `dataset_name`.
 
 ### Deprecations
...
diff --git a/Concepts.md b/Concepts.md
index f45fe4ea2..e156f2970 100644
--- a/Concepts.md
+++ b/Concepts.md
@@ -41,7 +41,7 @@ task:
 ```Python
 class Task(ABC, Generic[Input, Output]):
     @final
-    def run(self, input: Input, tracer: Tracer, trace_id: Optional[str] = None) -> Output:
+    def run(self, input: Input, tracer: Tracer) -> Output:
         ...
 ```
diff --git a/scripts/test.sh b/scripts/test.sh
index b3d1512ff..9b1d85893 100755
--- a/scripts/test.sh
+++ b/scripts/test.sh
@@ -1,3 +1,3 @@
 #!/usr/bin/env -S bash -eu -o pipefail
 
-poetry run pytest -n 10
+TQDM_DISABLE=1 poetry run pytest -n 10
diff --git a/src/intelligence_layer/connectors/argilla/argilla_client.py b/src/intelligence_layer/connectors/argilla/argilla_client.py
index d0a8f5617..b9db07c23 100644
--- a/src/intelligence_layer/connectors/argilla/argilla_client.py
+++ b/src/intelligence_layer/connectors/argilla/argilla_client.py
@@ -3,7 +3,17 @@
 from abc import ABC, abstractmethod
 from http import HTTPStatus
 from itertools import chain, count, islice
-from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Union, cast
+from typing import (
+    Any,
+    Callable,
+    Iterable,
+    Mapping,
+    Optional,
+    Sequence,
+    TypeVar,
+    Union,
+    cast,
+)
 from uuid import uuid4
 
 from pydantic import BaseModel
@@ -96,6 +106,29 @@ class ArgillaClient(ABC):
     the intelligence layer to create feedback datasets or retrieve evaluation results.
     """
 
+    @abstractmethod
+    def create_dataset(
+        self,
+        workspace_id: str,
+        dataset_name: str,
+        fields: Sequence[Field],
+        questions: Sequence[Question],
+    ) -> str:
+        """Creates and publishes a new feedback dataset in Argilla.
+
+        Raises an error if the dataset name already exists.
+
+        Args:
+            workspace_id: the id of the workspace the feedback dataset should be created in.
+                The user executing this request must have corresponding permissions for this workspace.
+            dataset_name: the name of the feedback dataset to be created.
+            fields: all fields of this dataset.
+            questions: all questions for this dataset.
+
+        Returns:
+            The id of the created dataset.
+        """
+        ...
+
     @abstractmethod
     def ensure_dataset_exists(
         self,
@@ -119,14 +152,24 @@ def ensure_dataset_exists(
 
     @abstractmethod
     def add_record(self, dataset_id: str, record: RecordData) -> None:
-        """Adds a new record to be evalated to the given dataset.
+        """Adds a new record to the given dataset.
 
         Args:
             dataset_id: id of the dataset the record is added to
-            record: contains the actual record data (i.e. content for the dataset's fields)
+            record: the actual record data (i.e. content for the dataset's fields)
         """
         ...
 
+    def add_records(self, dataset_id: str, records: Sequence[RecordData]) -> None:
+        """Adds new records to the given dataset.
+
+        Args:
+            dataset_id: id of the dataset the records are added to
+            records: list containing the record data (i.e. content for the dataset's fields)
+        """
+        for record in records:
+            self.add_record(dataset_id, record)
+
     @abstractmethod
     def evaluations(self, dataset_id: str) -> Iterable[ArgillaEvaluation]:
         """Returns all human-evaluated evaluations for the given dataset.
@@ -149,6 +192,18 @@ def split_dataset(self, dataset_id: str, n_splits: int) -> None:
         ...
 
 
+T = TypeVar("T")
+
+
+def batch_iterator(iterable: Iterable[T], batch_size: int) -> Iterable[list[T]]:
+    iterator = iter(iterable)
+    while True:
+        batch = list(islice(iterator, batch_size))
+        if not batch:
+            break
+        yield batch
+
+
 class DefaultArgillaClient(ArgillaClient):
     def __init__(
         self,
@@ -196,7 +251,7 @@ def ensure_workspace_exists(self, workspace_name: str) -> str:
             )
             raise e
 
-    def ensure_dataset_exists(
+    def create_dataset(
         self,
         workspace_id: str,
         dataset_name: str,
@@ -205,6 +260,39 @@ def ensure_dataset_exists(
     ) -> str:
         try:
             dataset_id: str = self._create_dataset(dataset_name, workspace_id)["id"]
+            for field in fields:
+                self._create_field(field.name, field.title, dataset_id)
+
+            for question in questions:
+                self._create_question(
+                    question.name,
+                    question.title,
+                    question.description,
+                    question.options,
+                    dataset_id,
+                )
+            self._publish_dataset(dataset_id)
+            return dataset_id
+
+        except HTTPError as e:
+            if e.response.status_code == HTTPStatus.CONFLICT:
+                raise ValueError(
+                    f"Cannot create dataset with name '{dataset_name}': either the dataset name already exists, "
+                    f"or a field or question name is duplicated."
+                )
+            raise e
+
+    def ensure_dataset_exists(
+        self,
+        workspace_id: str,
+        dataset_name: str,
+        fields: Sequence[Field],
+        questions: Sequence[Question],
+    ) -> str:
+        try:
+            dataset_id: str = self.create_dataset(
+                workspace_id, dataset_name, fields, questions
+            )
         except HTTPError as e:
             if e.response.status_code == HTTPStatus.CONFLICT:
                 datasets = self._list_datasets(workspace_id)
@@ -249,9 +337,10 @@ def _ignore_failure_status(
             raise e
 
     def add_record(self, dataset_id: str, record: RecordData) -> None:
-        self._create_record(
-            record.content, record.metadata, record.example_id, dataset_id
-        )
+        self._create_records([record], dataset_id)
+
+    def add_records(self, dataset_id: str, records: Sequence[RecordData]) -> None:
+        self._create_records(records, dataset_id)
 
     def evaluations(self, dataset_id: str) -> Iterable[ArgillaEvaluation]:
         def to_responses(
@@ -481,24 +570,27 @@ def _list_records(
                     record["example_id"] = example_id
             yield from cast(Sequence[Mapping[str, Any]], records)
 
-    def _create_record(
+    def _create_records(
         self,
-        content: Mapping[str, str],
-        metadata: Mapping[str, str],
-        example_id: str,
+        records: Sequence[RecordData],
         dataset_id: str,
     ) -> None:
         url = self.api_url + f"api/v1/datasets/{dataset_id}/records"
-        data = {
-            "items": [
-                {
-                    "fields": content,
-                    "metadata": {**metadata, "example_id": example_id},
-                }
-            ]
-        }
-        response = self.session.post(url, json=data)
-        response.raise_for_status()
+        for batch in batch_iterator(records, 200):
+            data = {
+                "items": [
+                    {
+                        "fields": record.content,
+                        "metadata": {
+                            **record.metadata,
+                            "example_id": record.example_id,
+                        },
+                    }
+                    for record in batch
+                ]
+            }
+            response = self.session.post(url, json=data)
+            response.raise_for_status()
 
     def delete_workspace(self, workspace_id: str) -> None:
         for dataset in self._list_datasets(workspace_id)["items"]:
diff --git a/src/intelligence_layer/evaluation/evaluation/evaluator/argilla_evaluator.py b/src/intelligence_layer/evaluation/evaluation/evaluator/argilla_evaluator.py
index e6430e627..05a13e92c 100644
--- a/src/intelligence_layer/evaluation/evaluation/evaluator/argilla_evaluator.py
+++ b/src/intelligence_layer/evaluation/evaluation/evaluator/argilla_evaluator.py
@@ -3,6 +3,7 @@
 from datetime import datetime
 from itertools import combinations
 from typing import Mapping, Optional, Sequence
+from uuid import uuid4
 
 from pydantic import BaseModel
@@ -132,11 +133,12 @@ def submit(
         self,
         *run_ids: str,
         num_examples: Optional[int] = None,
+        dataset_name: Optional[str] = None,
         abort_on_error: bool = False,
     ) -> PartialEvaluationOverview:
-        argilla_dataset_id = self._client.ensure_dataset_exists(
+        argilla_dataset_id = self._client.create_dataset(
             self._workspace_id,
-            dataset_name="name",
+            dataset_name if dataset_name else str(uuid4()),
             fields=list(self._evaluation_logic.fields.values()),
             questions=self._evaluation_logic.questions,
         )
diff --git a/tests/connectors/argilla/test_argilla_client.py b/tests/connectors/argilla/test_argilla_client.py
index 383f8ed70..e2c16675a 100644
--- a/tests/connectors/argilla/test_argilla_client.py
+++ b/tests/connectors/argilla/test_argilla_client.py
@@ -46,7 +46,7 @@ def retry(
 @fixture
 def argilla_client() -> DefaultArgillaClient:
     load_dotenv()
-    return DefaultArgillaClient(total_retries=8)
+    return DefaultArgillaClient(total_retries=1)
 
 
 @fixture
@@ -78,6 +78,51 @@ def qa_dataset_id(argilla_client: DefaultArgillaClient, workspace_id: str) -> st
     )
 
 
+@pytest.mark.docker
+def test_client_can_create_a_dataset(
+    argilla_client: DefaultArgillaClient,
+    workspace_id: str,
+) -> None:
+    dataset_id = argilla_client.create_dataset(
+        workspace_id,
+        dataset_name="name",
+        fields=[Field(name="a", title="b")],
+        questions=[
+            Question(name="a", title="b", description="c", options=list(range(1, 5)))
+        ],
+    )
+    datasets = argilla_client._list_datasets(workspace_id)
+    assert len(datasets["items"]) == 1
+    assert dataset_id == datasets["items"][0]["id"]
+
+
+@pytest.mark.docker
+def test_client_cannot_create_two_datasets_with_the_same_name(
+    argilla_client: DefaultArgillaClient,
+    workspace_id: str,
+) -> None:
+    dataset_name = str(uuid4())
+    argilla_client.create_dataset(
+        workspace_id,
+        dataset_name=dataset_name,
+        fields=[Field(name="a", title="b")],
+        questions=[
+            Question(name="a", title="b", description="c", options=list(range(1, 5)))
+        ],
+    )
+    with pytest.raises(ValueError):
+        argilla_client.create_dataset(
+            workspace_id,
+            dataset_name=dataset_name,
+            fields=[Field(name="a", title="b")],
+            questions=[
+                Question(
+                    name="a", title="b", description="c", options=list(range(1, 5))
+                )
+            ],
+        )
+
+
 @fixture
 def qa_records(
     argilla_client: ArgillaClient, qa_dataset_id: str
@@ -90,8 +135,7 @@ def qa_records(
         )
         for i in range(60)
     ]
-    for record in records:
-        argilla_client.add_record(qa_dataset_id, record)
+    argilla_client.add_records(qa_dataset_id, records)
     return records
 
 
@@ -107,13 +151,12 @@ def long_qa_records(
         )
         for i in range(1024)
     ]
-    for record in records:
-        argilla_client.add_record(qa_dataset_id, record)
+    argilla_client.add_records(qa_dataset_id, records)
     return records
 
 
 @pytest.mark.docker
-def test_error_on_non_existent_dataset(
+def test_retrieving_records_on_non_existent_dataset_raises_error(
     argilla_client: DefaultArgillaClient,
 ) -> None:
     with pytest.raises(HTTPError):
diff --git a/tests/evaluation/conftest.py b/tests/evaluation/conftest.py
index c15362589..833d5df79 100644
--- a/tests/evaluation/conftest.py
+++ b/tests/evaluation/conftest.py
@@ -157,6 +157,15 @@ class StubArgillaClient(ArgillaClient):
     _datasets: dict[str, list[RecordData]] = {}
     _score = 3.0
 
+    def create_dataset(
+        self,
+        workspace_id: str,
+        dataset_name: str,
+        fields: Sequence[Field],
+        questions: Sequence[Question],
+    ) -> str:
+        return self.ensure_dataset_exists(workspace_id, dataset_name, fields, questions)
+
     def ensure_dataset_exists(
         self,
         workspace_id: str,
diff --git a/tests/evaluation/test_argilla_evaluator.py b/tests/evaluation/test_argilla_evaluator.py
index c0167d641..34326c294 100644
--- a/tests/evaluation/test_argilla_evaluator.py
+++ b/tests/evaluation/test_argilla_evaluator.py
@@ -87,8 +87,19 @@ class CustomException(Exception):
 
 class DummyArgillaClient(ArgillaClient):
-    _datasets: dict[str, list[RecordData]] = {}
-    _score = 3.0
+    def __init__(self) -> None:
+        self._datasets: dict[str, list[RecordData]] = {}
+        self._names: dict[str, str] = {}
+        self._score = 3.0
+
+    def create_dataset(
+        self,
+        workspace_id: str,
+        dataset_name: str,
+        fields: Sequence[Field],
+        questions: Sequence[Question],
+    ) -> str:
+        return self.ensure_dataset_exists(workspace_id, dataset_name, fields, questions)
 
     def ensure_dataset_exists(
         self,
@@ -99,6 +110,7 @@ def ensure_dataset_exists(
     ) -> str:
         dataset_id = str(uuid4())
         self._datasets[dataset_id] = []
+        self._names[dataset_id] = dataset_name
         return dataset_id
 
     def add_record(self, dataset_id: str, record: RecordData) -> None:
@@ -129,6 +141,15 @@ class FailedEvaluationDummyArgillaClient(ArgillaClient):
     _upload_count = 0
     _datasets: dict[str, list[RecordData]] = {}
 
+    def create_dataset(
+        self,
+        workspace_id: str,
+        dataset_name: str,
+        fields: Sequence[Field],
+        questions: Sequence[Question],
+    ) -> str:
+        return self.ensure_dataset_exists(workspace_id, dataset_name, fields, questions)
+
     def ensure_dataset_exists(
         self,
         workspace_id: str,
@@ -165,8 +186,8 @@ def split_dataset(self, dataset_id: str, n_splits: int) -> None:
 
 
 @fixture
-def arg() -> StubArgillaClient:
-    return StubArgillaClient()
+def dummy_client() -> DummyArgillaClient:
+    return DummyArgillaClient()
 
 
 @fixture()
@@ -233,13 +254,14 @@ def test_argilla_evaluator_can_submit_evals_to_argilla(
     in_memory_run_repository: InMemoryRunRepository,
     async_in_memory_evaluation_repository: AsyncInMemoryEvaluationRepository,
 ) -> None:
+    client = DummyArgillaClient()
     evaluator = ArgillaEvaluator(
         in_memory_dataset_repository,
         in_memory_run_repository,
         async_in_memory_evaluation_repository,
         "dummy-string-task",
         DummyStringTaskArgillaEvaluationLogic(),
-        DummyArgillaClient(),
+        client,
         workspace_id="workspace-id",
     )
 
@@ -264,7 +286,7 @@ def test_argilla_evaluator_can_submit_evals_to_argilla(
     )
 
     assert len(list(async_in_memory_evaluation_repository.evaluation_overviews())) == 1
-    assert len(DummyArgillaClient()._datasets[partial_evaluation_overview.id]) == 1
+    assert len(client._datasets[partial_evaluation_overview.id]) == 1
 
 
 def test_argilla_evaluator_correctly_lists_failed_eval_counts(
@@ -344,3 +366,57 @@ def test_argilla_aggregation_logic_works() -> None:
         aggregation.scores["player_3"].elo_standard_error
         > aggregation.scores["player_1"].elo_standard_error
     )
+
+
+def test_argilla_evaluator_has_distinct_names_for_datasets(
+    string_argilla_runner: Runner[DummyStringInput, DummyStringOutput],
+    string_dataset_id: str,
+    in_memory_dataset_repository: InMemoryDatasetRepository,
+    in_memory_run_repository: InMemoryRunRepository,
+    async_in_memory_evaluation_repository: AsyncInMemoryEvaluationRepository,
+    dummy_client: DummyArgillaClient,
+) -> None:
+    workspace_id = "workspace_id"
+    evaluator = ArgillaEvaluator(
+        in_memory_dataset_repository,
+        in_memory_run_repository,
+        async_in_memory_evaluation_repository,
+        "dummy-string-task",
+        DummyStringTaskArgillaEvaluationLogic(),
+        dummy_client,
+        workspace_id,
+    )
+
+    run_overview = string_argilla_runner.run_dataset(string_dataset_id)
+    evaluator.submit(run_overview.id)
+    evaluator.submit(run_overview.id)
+
+    assert len(dummy_client._datasets) == 2
+    assert len(set(dummy_client._names.values())) == 2
+
+
+def test_argilla_evaluator_can_take_name(
+    string_argilla_runner: Runner[DummyStringInput, DummyStringOutput],
+    string_dataset_id: str,
+    in_memory_dataset_repository: InMemoryDatasetRepository,
+    in_memory_run_repository: InMemoryRunRepository,
+    async_in_memory_evaluation_repository: AsyncInMemoryEvaluationRepository,
+    dummy_client: DummyArgillaClient,
+) -> None:
+    workspace_id = "workspace_id"
+    evaluator = ArgillaEvaluator(
+        in_memory_dataset_repository,
+        in_memory_run_repository,
+        async_in_memory_evaluation_repository,
+        "dummy-string-task",
+        DummyStringTaskArgillaEvaluationLogic(),
+        dummy_client,
+        workspace_id,
+    )
+
+    dataset_name = str(uuid4())
+    run_overview = string_argilla_runner.run_dataset(string_dataset_id)
+    dataset_id = evaluator.submit(run_overview.id, dataset_name=dataset_name).id
+
+    assert len(dummy_client._datasets) == 1
+    assert dummy_client._names[dataset_id] == dataset_name
diff --git a/tests/evaluation/test_instruct_comparison_argilla_evaluator.py b/tests/evaluation/test_instruct_comparison_argilla_evaluator.py
index d6e85e2be..a8c2ad9df 100644
--- a/tests/evaluation/test_instruct_comparison_argilla_evaluator.py
+++ b/tests/evaluation/test_instruct_comparison_argilla_evaluator.py
@@ -37,6 +37,15 @@ class ArgillaFake(ArgillaClient):
     def __init__(self) -> None:
         self.records: dict[str, list[RecordData]] = defaultdict(list)
 
+    def create_dataset(
+        self,
+        workspace_id: str,
+        dataset_name: str,
+        fields: Sequence[Field],
+        questions: Sequence[Question],
+    ) -> str:
+        return self.ensure_dataset_exists(workspace_id, dataset_name, fields, questions)
+
     def ensure_dataset_exists(
         self,
         workspace_id: str,
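
For reviewers, a minimal usage sketch of the API this diff introduces, assuming a reachable Argilla instance whose URL and API key the `DefaultArgillaClient` picks up from the environment; the workspace name, field, question, and record values below are placeholders, not part of the change:

```python
from uuid import uuid4

from intelligence_layer.connectors.argilla.argilla_client import (
    DefaultArgillaClient,
    Field,
    Question,
    RecordData,
)

client = DefaultArgillaClient()
workspace_id = client.ensure_workspace_exists("my-workspace")  # placeholder workspace

# create_dataset raises a ValueError on a name conflict instead of reusing the
# existing dataset, so a fresh (e.g. uuid-based) name is needed per submission.
dataset_id = client.create_dataset(
    workspace_id,
    dataset_name=f"eval-{uuid4()}",
    fields=[Field(name="input", title="Input")],
    questions=[
        Question(name="quality", title="Quality", description="1-4", options=list(range(1, 5)))
    ],
)

# add_records uploads in batches (200 records per request in DefaultArgillaClient).
records = [
    RecordData(content={"input": f"example {i}"}, example_id=str(i), metadata={})
    for i in range(500)
]
client.add_records(dataset_id, records)
```

The `ArgillaEvaluator` exposes the same naming behaviour through `submit(run_id, dataset_name=...)`; when `dataset_name` is omitted, a random UUID is used as the dataset name.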