fix: naming in argilla evaluator #879

Merged
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -13,6 +13,9 @@
- We updated the graders to support python 3.12 and moved away from `nltk`-package:
- `BleuGrader` now uses `sacrebleu`-package.
- `RougeGrader` now uses the `rouge_score`-package.
- When using the `ArgillaEvaluator`, submitting to a dataset that already exists no longer appends to that dataset. This brings it in line with the other evaluation concepts.
- Instead of appending to an active Argilla dataset, you now create a new dataset, retrieve it, and finally combine both datasets in the aggregation step.
- The `ArgillaClient` now has a `create_dataset` method, which fails instead of silently reusing an existing dataset, and an `add_records` method for performant batch uploads (see the sketch below).
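A minimal sketch of the new client surface. The workspace, dataset name, and record content are illustrative, and the client is assumed to pick up its API URL and key from the environment, as the test fixtures do via `load_dotenv()`:

```Python
from intelligence_layer.connectors.argilla.argilla_client import (
    DefaultArgillaClient,
    Field,
    Question,
    RecordData,
)

client = DefaultArgillaClient()  # assumes ARGILLA_API_URL / ARGILLA_API_KEY in the environment
workspace_id = client.ensure_workspace_exists("my-workspace")

# Unlike ensure_dataset_exists, create_dataset raises a ValueError if the name is taken.
dataset_id = client.create_dataset(
    workspace_id,
    dataset_name="qa-eval-run-1",
    fields=[Field(name="question", title="Question"), Field(name="answer", title="Answer")],
    questions=[
        Question(name="rating", title="Rating", description="Rate the answer.", options=list(range(1, 5)))
    ],
)

# add_records uploads in batches instead of issuing one request per record.
client.add_records(
    dataset_id,
    [RecordData(content={"question": "Q?", "answer": "A."}, example_id="0")],
)
```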

### New Features
- Add `how_to_implement_incremental_evaluation`.
@@ -23,7 +26,8 @@
- We now support python 3.12

### Fixes
- The document index client now correctly URL-encodes document names in its queries.
- The `ArgillaEvaluator` now properly supports `dataset_name`.

### Deprecations
...
2 changes: 1 addition & 1 deletion Concepts.md
@@ -41,7 +41,7 @@ task:
```Python
class Task(ABC, Generic[Input, Output]):
@final
-    def run(self, input: Input, tracer: Tracer, trace_id: Optional[str] = None) -> Output:
+    def run(self, input: Input, tracer: Tracer) -> Output:
...
```
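
A call site after the change, as a sketch: `my_task` and `my_input` stand in for a concrete task and its input, and `InMemoryTracer` is assumed to be importable from `intelligence_layer.core`.

```Python
from intelligence_layer.core import InMemoryTracer

output = my_task.run(my_input, InMemoryTracer())  # no trace_id argument anymore
```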

2 changes: 1 addition & 1 deletion scripts/test.sh
@@ -1,3 +1,3 @@
#!/usr/bin/env -S bash -eu -o pipefail

-poetry run pytest -n 10
+TQDM_DISABLE=1 poetry run pytest -n 10
134 changes: 113 additions & 21 deletions src/intelligence_layer/connectors/argilla/argilla_client.py
@@ -3,7 +3,17 @@
from abc import ABC, abstractmethod
from http import HTTPStatus
from itertools import chain, count, islice
-from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Union, cast
+from typing import (
+    Any,
+    Callable,
+    Iterable,
+    Mapping,
+    Optional,
+    Sequence,
+    TypeVar,
+    Union,
+    cast,
+)
from uuid import uuid4

from pydantic import BaseModel
@@ -96,6 +106,29 @@ class ArgillaClient(ABC):
the intelligence layer to create feedback datasets or retrieve evaluation results.
"""

@abstractmethod
def create_dataset(
self,
workspace_id: str,
dataset_name: str,
fields: Sequence[Field],
questions: Sequence[Question],
) -> str:
"""Creates and publishes a new feedback dataset in Argilla.

Raises an error if a dataset with the given name already exists.

Args:
workspace_id: the id of the workspace the feedback-dataset should be created in.
The user executing this request must have corresponding permissions for this workspace.
dataset_name: the name of the feedback-dataset to be created.
fields: all fields of this dataset.
questions: all questions for this dataset.
Returns:
The id of the created dataset.
"""
...

@abstractmethod
def ensure_dataset_exists(
self,
@@ -119,14 +152,24 @@ def ensure_dataset_exists(

@abstractmethod
def add_record(self, dataset_id: str, record: RecordData) -> None:
"""Adds a new record to be evalated to the given dataset.
"""Adds a new record to the given dataset.

Args:
dataset_id: id of the dataset the record is added to
-        record: contains the actual record data (i.e. content for the dataset's fields)
+        record: the actual record data (i.e. content for the dataset's fields)
"""
...

def add_records(self, dataset_id: str, records: Sequence[RecordData]) -> None:
    """Adds new records to the given dataset.

    Args:
        dataset_id: id of the dataset the records are added to
        records: list containing the record data (i.e. content for the dataset's fields)
    """
    for record in records:
        self.add_record(dataset_id, record)

@abstractmethod
def evaluations(self, dataset_id: str) -> Iterable[ArgillaEvaluation]:
"""Returns all human-evaluated evaluations for the given dataset.
@@ -149,6 +192,18 @@ def split_dataset(self, dataset_id: str, n_splits: int) -> None:
...


T = TypeVar("T")


def batch_iterator(iterable: Iterable[T], batch_size: int) -> Iterable[list[T]]:
iterator = iter(iterable)
while True:
batch = list(islice(iterator, batch_size))
if not batch:
break
yield batch
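
For reference, the helper consumes any iterable lazily and yields fixed-size chunks, with a shorter final batch; an illustration with toy values:

```Python
assert list(batch_iterator(range(5), 2)) == [[0, 1], [2, 3], [4]]
```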


class DefaultArgillaClient(ArgillaClient):
def __init__(
self,
@@ -196,7 +251,7 @@ def ensure_workspace_exists(self, workspace_name: str) -> str:
)
raise e

-    def ensure_dataset_exists(
+    def create_dataset(
self,
workspace_id: str,
dataset_name: str,
@@ -205,6 +260,39 @@
) -> str:
try:
dataset_id: str = self._create_dataset(dataset_name, workspace_id)["id"]
for field in fields:
self._create_field(field.name, field.title, dataset_id)

for question in questions:
self._create_question(
question.name,
question.title,
question.description,
question.options,
dataset_id,
)
self._publish_dataset(dataset_id)
return dataset_id

except HTTPError as e:
if e.response.status_code == HTTPStatus.CONFLICT:
    raise ValueError(
        f"Cannot create dataset with name '{dataset_name}': either a dataset with this name "
        f"already exists, or a field or question name is duplicated."
    )
raise e

def ensure_dataset_exists(
self,
workspace_id: str,
dataset_name: str,
fields: Sequence[Field],
questions: Sequence[Question],
) -> str:
try:
dataset_id: str = self.create_dataset(
workspace_id, dataset_name, fields, questions
)
except HTTPError as e:
if e.response.status_code == HTTPStatus.CONFLICT:
datasets = self._list_datasets(workspace_id)
@@ -249,9 +337,10 @@ def _ignore_failure_status(
raise e

def add_record(self, dataset_id: str, record: RecordData) -> None:
-    self._create_record(
-        record.content, record.metadata, record.example_id, dataset_id
-    )
+    self._create_records([record], dataset_id)

def add_records(self, dataset_id: str, records: Sequence[RecordData]) -> None:
self._create_records(records, dataset_id)

def evaluations(self, dataset_id: str) -> Iterable[ArgillaEvaluation]:
def to_responses(
@@ -481,24 +570,27 @@ def _list_records(
record["example_id"] = example_id
yield from cast(Sequence[Mapping[str, Any]], records)

-    def _create_record(
+    def _create_records(
        self,
-        content: Mapping[str, str],
-        metadata: Mapping[str, str],
-        example_id: str,
+        records: Sequence[RecordData],
dataset_id: str,
) -> None:
url = self.api_url + f"api/v1/datasets/{dataset_id}/records"
-        data = {
-            "items": [
-                {
-                    "fields": content,
-                    "metadata": {**metadata, "example_id": example_id},
-                }
-            ]
-        }
-        response = self.session.post(url, json=data)
-        response.raise_for_status()
+        for batch in batch_iterator(records, 200):
+            data = {
+                "items": [
+                    {
+                        "fields": record.content,
+                        "metadata": {
+                            **record.metadata,
+                            "example_id": record.example_id,
+                        },
+                    }
+                    for record in batch
+                ]
+            }
+            response = self.session.post(url, json=data)
+            response.raise_for_status()
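
Note the design choice here: records go out in fixed batches of 200 per POST, so arbitrarily large record sets never produce oversized requests, and each batch fails fast via `raise_for_status()`.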

def delete_workspace(self, workspace_id: str) -> None:
for dataset in self._list_datasets(workspace_id)["items"]:
@@ -3,6 +3,7 @@
from datetime import datetime
from itertools import combinations
from typing import Mapping, Optional, Sequence
from uuid import uuid4

from pydantic import BaseModel

@@ -132,11 +133,12 @@ def submit(
self,
*run_ids: str,
num_examples: Optional[int] = None,
dataset_name: Optional[str] = None,
abort_on_error: bool = False,
) -> PartialEvaluationOverview:
-        argilla_dataset_id = self._client.ensure_dataset_exists(
+        argilla_dataset_id = self._client.create_dataset(
            self._workspace_id,
-            dataset_name="name",
+            dataset_name if dataset_name else str(uuid4()),
fields=list(self._evaluation_logic.fields.values()),
questions=self._evaluation_logic.questions,
)
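
The new parameter in use, as a sketch; the evaluator construction and the run ids are assumed from an earlier evaluation run:

```Python
partial_overview = evaluator.submit(
    "run-1", "run-2",
    dataset_name="my-eval-dataset",  # omitted -> a random uuid4-based name is used
)
```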
55 changes: 49 additions & 6 deletions tests/connectors/argilla/test_argilla_client.py
@@ -46,7 +46,7 @@ def retry(
@fixture
def argilla_client() -> DefaultArgillaClient:
load_dotenv()
-    return DefaultArgillaClient(total_retries=8)
+    return DefaultArgillaClient(total_retries=1)


@fixture
@@ -78,6 +78,51 @@ def qa_dataset_id(argilla_client: DefaultArgillaClient, workspace_id: str) -> str:
)


@pytest.mark.docker
def test_client_can_create_a_dataset(
argilla_client: DefaultArgillaClient,
workspace_id: str,
) -> None:
dataset_id = argilla_client.create_dataset(
workspace_id,
dataset_name="name",
fields=[Field(name="a", title="b")],
questions=[
Question(name="a", title="b", description="c", options=list(range(1, 5)))
],
)
datasets = argilla_client._list_datasets(workspace_id)
assert len(datasets["items"]) == 1
assert dataset_id == datasets["items"][0]["id"]


@pytest.mark.docker
def test_client_cannot_create_two_datasets_with_the_same_name(
argilla_client: DefaultArgillaClient,
workspace_id: str,
) -> None:
dataset_name = str(uuid4())
argilla_client.create_dataset(
workspace_id,
dataset_name=dataset_name,
fields=[Field(name="a", title="b")],
questions=[
Question(name="a", title="b", description="c", options=list(range(1, 5)))
],
)
with pytest.raises(ValueError):
argilla_client.create_dataset(
workspace_id,
dataset_name=dataset_name,
fields=[Field(name="a", title="b")],
questions=[
Question(
name="a", title="b", description="c", options=list(range(1, 5))
)
],
)


@fixture
def qa_records(
argilla_client: ArgillaClient, qa_dataset_id: str
@@ -90,8 +135,7 @@ def qa_records(
)
for i in range(60)
]
-    for record in records:
-        argilla_client.add_record(qa_dataset_id, record)
+    argilla_client.add_records(qa_dataset_id, records)
return records


@@ -107,13 +151,12 @@ def long_qa_records(
)
for i in range(1024)
]
-    for record in records:
-        argilla_client.add_record(qa_dataset_id, record)
+    argilla_client.add_records(qa_dataset_id, records)
return records


@pytest.mark.docker
-def test_error_on_non_existent_dataset(
+def test_retrieving_records_on_non_existant_dataset_raises_errors(
argilla_client: DefaultArgillaClient,
) -> None:
with pytest.raises(HTTPError):
9 changes: 9 additions & 0 deletions tests/evaluation/conftest.py
@@ -157,6 +157,15 @@ class StubArgillaClient(ArgillaClient):
_datasets: dict[str, list[RecordData]] = {}
_score = 3.0

def create_dataset(
self,
workspace_id: str,
dataset_name: str,
fields: Sequence[Field],
questions: Sequence[Question],
) -> str:
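    # Stub: delegates to ensure_dataset_exists, so duplicate dataset names are not rejected here.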
return self.ensure_dataset_exists(workspace_id, dataset_name, fields, questions)

def ensure_dataset_exists(
self,
workspace_id: str,