Skip to content

Commit

Permalink
fix: naming in argilla evaluator (#879)
Browse files Browse the repository at this point in the history
* WIP: feat: Add test for submit two different datasets

refactor: Move fixture for DefaultArgillaClient to evaluation/conftest.py

TASK: IL-541

* fix: argilla evaluator handling dataset names correctly

* feat: improve local test readability

* docs: update changelog and docs

* feat: Small improvements to tests and docstrings
 * Fixed outdated run parameters in Concepts.md
TASK: IL-541

---------

Co-authored-by: Merlin Kallenborn <[email protected]>
Co-authored-by: Sebastian Niehus <[email protected]>
  • Loading branch information
3 people authored May 30, 2024
1 parent f7c4d9f commit 88e5def
Show file tree
Hide file tree
Showing 9 changed files with 273 additions and 38 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
- We updated the graders to support python 3.12 and moved away from `nltk`-package:
- `BleuGrader` now uses `sacrebleu`-package.
- `RougeGrader` now uses the `rouge_score`-package.
- When using the `ArgillaEvaluator`, attempting to submit to a dataset which already exists will no longer append to the dataset. This makes it more in line with other evaluation concepts.
- Instead of appending to an active argilla dataset, you now need to create a new dataset, retrieve it and then finally combine both datasets in the aggregation step.
- The `ArgillaClient` now has methods `create_dataset` for less fault-ignoring dataset creation and `add_records` for performant uploads.

### New Features
- Add `skip_example_on_any_failure` flag to `evaluate_runs` (defaults to True). This allows to configure if you want to keep an example for evaluation, even if it failed for some run.
Expand All @@ -24,7 +27,8 @@
- We now support python 3.12

### Fixes
- The document index client now correctly URL-encodes document names in its queries.
- The document index client now correctly URL-encodes document names in its queries.
- The `ArgillaEvaluator` now properly supports `dataset_name`

### Deprecations
...
Expand Down
2 changes: 1 addition & 1 deletion Concepts.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ task:
```Python
class Task(ABC, Generic[Input, Output]):
@final
def run(self, input: Input, tracer: Tracer, trace_id: Optional[str] = None) -> Output:
def run(self, input: Input, tracer: Tracer) -> Output:
...
```

Expand Down
2 changes: 1 addition & 1 deletion scripts/test.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#!/usr/bin/env -S bash -eu -o pipefail

poetry run pytest -n 10
TQDM_DISABLE=1 poetry run pytest -n 10
134 changes: 113 additions & 21 deletions src/intelligence_layer/connectors/argilla/argilla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,17 @@
from abc import ABC, abstractmethod
from http import HTTPStatus
from itertools import chain, count, islice
from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Union, cast
from typing import (
Any,
Callable,
Iterable,
Mapping,
Optional,
Sequence,
TypeVar,
Union,
cast,
)
from uuid import uuid4

from pydantic import BaseModel
Expand Down Expand Up @@ -96,6 +106,29 @@ class ArgillaClient(ABC):
the intelligence layer to create feedback datasets or retrieve evaluation results.
"""

@abstractmethod
def create_dataset(
    self,
    workspace_id: str,
    dataset_name: str,
    fields: Sequence[Field],
    questions: Sequence[Question],
) -> str:
    """Creates and publishes a new feedback dataset in Argilla.

    Unlike `ensure_dataset_exists`, this raises an error if a dataset with
    the given name already exists instead of reusing it.

    Args:
        workspace_id: the id of the workspace the feedback-dataset should be created in.
            The user executing this request must have corresponding permissions for this workspace.
        dataset_name: the name of the feedback-dataset to be created.
        fields: all fields of this dataset.
        questions: all questions for this dataset.

    Returns:
        The id of the created dataset.
    """
    ...

@abstractmethod
def ensure_dataset_exists(
self,
Expand All @@ -119,14 +152,24 @@ def ensure_dataset_exists(

@abstractmethod
def add_record(self, dataset_id: str, record: RecordData) -> None:
    """Adds a new record to the given dataset.

    Args:
        dataset_id: id of the dataset the record is added to
        record: the actual record data (i.e. content for the dataset's fields)
    """
    ...

def add_records(self, dataset_id: str, records: Sequence[RecordData]) -> None:
    """Adds new records to the given dataset.

    Default implementation delegates to `add_record` once per record;
    subclasses may override this with a batched upload for performance.

    Args:
        dataset_id: id of the dataset the records are added to
        records: list containing the record data (i.e. content for the dataset's fields)
    """
    # Bug fix: the previous loop body used `return`, which exited after the
    # first record and silently dropped every remaining record.
    for record in records:
        self.add_record(dataset_id, record)

@abstractmethod
def evaluations(self, dataset_id: str) -> Iterable[ArgillaEvaluation]:
"""Returns all human-evaluated evaluations for the given dataset.
Expand All @@ -149,6 +192,18 @@ def split_dataset(self, dataset_id: str, n_splits: int) -> None:
...


T = TypeVar("T")


def batch_iterator(iterable: Iterable[T], batch_size: int) -> Iterable[list[T]]:
    """Lazily yield consecutive chunks of at most `batch_size` items from `iterable`."""
    source = iter(iterable)
    # islice drains up to batch_size items per pass; an empty chunk means exhaustion.
    while chunk := list(islice(source, batch_size)):
        yield chunk


class DefaultArgillaClient(ArgillaClient):
def __init__(
self,
Expand Down Expand Up @@ -196,7 +251,7 @@ def ensure_workspace_exists(self, workspace_name: str) -> str:
)
raise e

def ensure_dataset_exists(
def create_dataset(
self,
workspace_id: str,
dataset_name: str,
Expand All @@ -205,6 +260,39 @@ def ensure_dataset_exists(
) -> str:
try:
dataset_id: str = self._create_dataset(dataset_name, workspace_id)["id"]
for field in fields:
self._create_field(field.name, field.title, dataset_id)

for question in questions:
self._create_question(
question.name,
question.title,
question.description,
question.options,
dataset_id,
)
self._publish_dataset(dataset_id)
return dataset_id

except HTTPError as e:
if e.response.status_code == HTTPStatus.CONFLICT:
raise ValueError(
f"Cannot create dataset with name '{dataset_name}', either the given dataset name, already exists"
f"or field name or question name are duplicates."
)
raise e

def ensure_dataset_exists(
self,
workspace_id: str,
dataset_name: str,
fields: Sequence[Field],
questions: Sequence[Question],
) -> str:
try:
dataset_id: str = self.create_dataset(
workspace_id, dataset_name, fields, questions
)
except HTTPError as e:
if e.response.status_code == HTTPStatus.CONFLICT:
datasets = self._list_datasets(workspace_id)
Expand Down Expand Up @@ -249,9 +337,10 @@ def _ignore_failure_status(
raise e

def add_record(self, dataset_id: str, record: RecordData) -> None:
    """Adds a single record to the given dataset.

    Args:
        dataset_id: id of the dataset the record is added to.
        record: the record data (i.e. content for the dataset's fields).
    """
    self._create_records([record], dataset_id)

def add_records(self, dataset_id: str, records: Sequence[RecordData]) -> None:
    """Adds records to the given dataset in one batched upload.

    Args:
        dataset_id: id of the dataset the records are added to.
        records: the record data (i.e. content for the dataset's fields).
    """
    self._create_records(records, dataset_id)

def evaluations(self, dataset_id: str) -> Iterable[ArgillaEvaluation]:
def to_responses(
Expand Down Expand Up @@ -481,24 +570,27 @@ def _list_records(
record["example_id"] = example_id
yield from cast(Sequence[Mapping[str, Any]], records)

def _create_records(
    self,
    records: Sequence[RecordData],
    dataset_id: str,
) -> None:
    """Uploads records to the given dataset via the bulk records endpoint.

    Args:
        records: the records to upload.
        dataset_id: id of the dataset the records are added to.

    Raises:
        requests.HTTPError: if any batch upload is rejected by the API.
    """
    url = self.api_url + f"api/v1/datasets/{dataset_id}/records"
    # Upload in batches of 200: one request per record is slow, while a
    # single request for all records risks exceeding payload limits.
    for batch in batch_iterator(records, 200):
        data = {
            "items": [
                {
                    "fields": record.content,
                    "metadata": {
                        **record.metadata,
                        "example_id": record.example_id,
                    },
                }
                for record in batch
            ]
        }
        response = self.session.post(url, json=data)
        response.raise_for_status()

def delete_workspace(self, workspace_id: str) -> None:
for dataset in self._list_datasets(workspace_id)["items"]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import datetime
from itertools import combinations
from typing import Mapping, Optional, Sequence
from uuid import uuid4

from pydantic import BaseModel

Expand Down Expand Up @@ -132,12 +133,13 @@ def submit(
self,
*run_ids: str,
num_examples: Optional[int] = None,
dataset_name: Optional[str] = None,
abort_on_error: bool = False,
skip_example_on_any_failure: bool = True,
) -> PartialEvaluationOverview:
argilla_dataset_id = self._client.ensure_dataset_exists(
argilla_dataset_id = self._client.create_dataset(
self._workspace_id,
dataset_name="name",
dataset_name if dataset_name else str(uuid4()),
fields=list(self._evaluation_logic.fields.values()),
questions=self._evaluation_logic.questions,
)
Expand Down
55 changes: 49 additions & 6 deletions tests/connectors/argilla/test_argilla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def retry(
@fixture
def argilla_client() -> DefaultArgillaClient:
load_dotenv()
return DefaultArgillaClient(total_retries=8)
return DefaultArgillaClient(total_retries=1)


@fixture
Expand Down Expand Up @@ -78,6 +78,51 @@ def qa_dataset_id(argilla_client: DefaultArgillaClient, workspace_id: str) -> st
)


@pytest.mark.docker
def test_client_can_create_a_dataset(
    argilla_client: DefaultArgillaClient,
    workspace_id: str,
) -> None:
    """A freshly created dataset is listed in its workspace under the returned id."""
    dataset_id = argilla_client.create_dataset(
        workspace_id,
        dataset_name="name",
        fields=[Field(name="a", title="b")],
        questions=[
            Question(name="a", title="b", description="c", options=list(range(1, 5)))
        ],
    )
    # Fetch once and reuse the response; previously the API was queried twice
    # and len() was applied to the response mapping instead of its item list.
    datasets = argilla_client._list_datasets(workspace_id)
    assert len(datasets["items"]) == 1
    assert dataset_id == datasets["items"][0]["id"]


@pytest.mark.docker
def test_client_cannot_create_two_datasets_with_the_same_name(
    argilla_client: DefaultArgillaClient,
    workspace_id: str,
) -> None:
    """Creating a second dataset under an already-used name raises a ValueError."""
    dataset_name = str(uuid4())
    fields = [Field(name="a", title="b")]
    questions = [
        Question(name="a", title="b", description="c", options=list(range(1, 5)))
    ]

    argilla_client.create_dataset(
        workspace_id,
        dataset_name=dataset_name,
        fields=fields,
        questions=questions,
    )

    with pytest.raises(ValueError):
        argilla_client.create_dataset(
            workspace_id,
            dataset_name=dataset_name,
            fields=fields,
            questions=questions,
        )


@fixture
def qa_records(
argilla_client: ArgillaClient, qa_dataset_id: str
Expand All @@ -90,8 +135,7 @@ def qa_records(
)
for i in range(60)
]
for record in records:
argilla_client.add_record(qa_dataset_id, record)
argilla_client.add_records(qa_dataset_id, records)
return records


Expand All @@ -107,13 +151,12 @@ def long_qa_records(
)
for i in range(1024)
]
for record in records:
argilla_client.add_record(qa_dataset_id, record)
argilla_client.add_records(qa_dataset_id, records)
return records


@pytest.mark.docker
def test_error_on_non_existent_dataset(
def test_retrieving_records_on_non_existant_dataset_raises_errors(
argilla_client: DefaultArgillaClient,
) -> None:
with pytest.raises(HTTPError):
Expand Down
9 changes: 9 additions & 0 deletions tests/evaluation/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,15 @@ class StubArgillaClient(ArgillaClient):
_datasets: dict[str, list[RecordData]] = {}
_score = 3.0

def create_dataset(
    self,
    workspace_id: str,
    dataset_name: str,
    fields: Sequence[Field],
    questions: Sequence[Question],
) -> str:
    """Stub: delegates creation to `ensure_dataset_exists` and returns its id."""
    dataset_id = self.ensure_dataset_exists(
        workspace_id, dataset_name, fields, questions
    )
    return dataset_id

def ensure_dataset_exists(
self,
workspace_id: str,
Expand Down
Loading

0 comments on commit 88e5def

Please sign in to comment.