
Commit

feat: restructure tests according to src structure
NiklasKoehneckeAA committed May 30, 2024
1 parent 88e5def commit b898671
Showing 47 changed files with 134 additions and 164 deletions.
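The commit moves evaluation tests into subpackages that mirror the intelligence_layer.evaluation source layout (e.g. aggregation, run). A rough sketch of the resulting test tree, inferred only from the paths visible in this diff:

tests/
    conftest.py                      shared fixtures
    evaluation/
        conftest.py                  DummyString* models, DummyStringTask, and related fixtures now live here
        aggregation/
            conftest.py
            test_domain.py
        run/
            test_run.py              implied by the tests.evaluation.run.test_run dotted paths below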
31 changes: 1 addition & 30 deletions tests/conftest.py
@@ -4,7 +4,6 @@

from aleph_alpha_client import Client, Image
from dotenv import load_dotenv
from pydantic import BaseModel
from pytest import fixture

from intelligence_layer.connectors import (
@@ -17,13 +16,7 @@
QdrantInMemoryRetriever,
RetrieverType,
)
from intelligence_layer.core import (
LuminousControlModel,
NoOpTracer,
Task,
TaskSpan,
utc_now,
)
from intelligence_layer.core import LuminousControlModel, NoOpTracer, utc_now
from intelligence_layer.evaluation import (
AsyncInMemoryEvaluationRepository,
EvaluationOverview,
@@ -117,28 +110,6 @@ def to_document(document_chunk: DocumentChunk) -> Document:
return Document(text=document_chunk.text, metadata=document_chunk.metadata)


class DummyStringInput(BaseModel):
input: str = "dummy-input"


class DummyStringOutput(BaseModel):
output: str = "dummy-output"


class DummyStringEvaluation(BaseModel):
evaluation: str = "dummy-evaluation"


class DummyStringTask(Task[DummyStringInput, DummyStringOutput]):
def do_run(self, input: DummyStringInput, task_span: TaskSpan) -> DummyStringOutput:
return DummyStringOutput()


@fixture
def dummy_string_task() -> DummyStringTask:
return DummyStringTask()


@fixture
def in_memory_dataset_repository() -> InMemoryDatasetRepository:
return InMemoryDatasetRepository()
27 changes: 27 additions & 0 deletions tests/evaluation/aggregation/conftest.py
@@ -0,0 +1,27 @@
from pytest import fixture

from intelligence_layer.core import utc_now
from intelligence_layer.evaluation import AggregationOverview, EvaluationOverview
from tests.evaluation.conftest import DummyAggregatedEvaluation


@fixture
def dummy_aggregated_evaluation() -> DummyAggregatedEvaluation:
return DummyAggregatedEvaluation(score=0.5)


@fixture
def aggregation_overview(
evaluation_overview: EvaluationOverview,
dummy_aggregated_evaluation: DummyAggregatedEvaluation,
) -> AggregationOverview[DummyAggregatedEvaluation]:
return AggregationOverview(
evaluation_overviews=frozenset([evaluation_overview]),
id="aggregation-id",
start=utc_now(),
end=utc_now(),
successful_evaluation_count=5,
crashed_during_evaluation_count=3,
description="dummy-evaluator",
statistics=dummy_aggregated_evaluation,
)
File renamed without changes.
12 changes: 12 additions & 0 deletions tests/evaluation/aggregation/test_domain.py
@@ -0,0 +1,12 @@
import pytest

from intelligence_layer.evaluation.aggregation.domain import AggregationOverview
from intelligence_layer.evaluation.evaluation.domain import EvaluationFailed
from tests.evaluation.conftest import DummyAggregatedEvaluation


def test_raise_on_exception_for_evaluation_run_overview(
aggregation_overview: AggregationOverview[DummyAggregatedEvaluation],
) -> None:
with pytest.raises(EvaluationFailed):
aggregation_overview.raise_on_evaluation_failure()
File renamed without changes.
140 changes: 33 additions & 107 deletions tests/evaluation/conftest.py
@@ -8,18 +8,9 @@
from pydantic import BaseModel
from pytest import fixture

from intelligence_layer.connectors import (
ArgillaClient,
ArgillaEvaluation,
Field,
Question,
RecordData,
)
from intelligence_layer.core import Task, Tracer, utc_now
from intelligence_layer.core import Task, TaskSpan, Tracer
from intelligence_layer.evaluation import (
AggregationOverview,
DatasetRepository,
EvaluationOverview,
Example,
ExampleEvaluation,
FileAggregationRepository,
@@ -29,23 +20,50 @@
InMemoryRunRepository,
Runner,
)
from tests.conftest import DummyStringInput, DummyStringOutput

FAIL_IN_EVAL_INPUT = "fail in eval"
FAIL_IN_TASK_INPUT = "fail in task"


class DummyStringInput(BaseModel):
input: str = "dummy-input"


class DummyStringOutput(BaseModel):
output: str = "dummy-output"


class DummyStringEvaluation(BaseModel):
evaluation: str = "dummy-evaluation"


class DummyStringTask(Task[DummyStringInput, DummyStringOutput]):
def do_run(self, input: DummyStringInput, task_span: TaskSpan) -> DummyStringOutput:
return DummyStringOutput()


@fixture
def dummy_string_task() -> DummyStringTask:
return DummyStringTask()


@fixture
def string_dataset_id(
dummy_string_examples: Iterable[Example[DummyStringInput, DummyStringOutput]],
in_memory_dataset_repository: DatasetRepository,
) -> str:
return in_memory_dataset_repository.create_dataset(
examples=dummy_string_examples, dataset_name="test-dataset"
).id


class DummyTask(Task[str, str]):
def do_run(self, input: str, tracer: Tracer) -> str:
if input == FAIL_IN_TASK_INPUT:
raise RuntimeError(input)
return input


class DummyStringEvaluation(BaseModel):
same: bool


class DummyEvaluation(BaseModel):
result: str

@@ -93,38 +111,6 @@ def file_run_repository(tmp_path: Path) -> FileRunRepository:
return FileRunRepository(tmp_path)


@fixture
def string_dataset_id(
dummy_string_examples: Iterable[Example[DummyStringInput, DummyStringOutput]],
in_memory_dataset_repository: DatasetRepository,
) -> str:
return in_memory_dataset_repository.create_dataset(
examples=dummy_string_examples, dataset_name="test-dataset"
).id


@fixture
def dummy_aggregated_evaluation() -> DummyAggregatedEvaluation:
return DummyAggregatedEvaluation(score=0.5)


@fixture
def aggregation_overview(
evaluation_overview: EvaluationOverview,
dummy_aggregated_evaluation: DummyAggregatedEvaluation,
) -> AggregationOverview[DummyAggregatedEvaluation]:
return AggregationOverview(
evaluation_overviews=frozenset([evaluation_overview]),
id="aggregation-id",
start=utc_now(),
end=utc_now(),
successful_evaluation_count=5,
crashed_during_evaluation_count=3,
description="dummy-evaluator",
statistics=dummy_aggregated_evaluation,
)


@fixture
def dummy_string_example() -> Example[DummyStringInput, DummyStringOutput]:
return Example(input=DummyStringInput(), expected_output=DummyStringOutput())
@@ -150,66 +136,6 @@ def dummy_runner(
)


class StubArgillaClient(ArgillaClient):
_expected_workspace_id: str
_expected_fields: Sequence[Field]
_expected_questions: Sequence[Question]
_datasets: dict[str, list[RecordData]] = {}
_score = 3.0

def create_dataset(
self,
workspace_id: str,
dataset_name: str,
fields: Sequence[Field],
questions: Sequence[Question],
) -> str:
return self.ensure_dataset_exists(workspace_id, dataset_name, fields, questions)

def ensure_dataset_exists(
self,
workspace_id: str,
dataset_name: str,
fields: Sequence[Field],
questions: Sequence[Question],
) -> str:
if workspace_id != self._expected_workspace_id:
raise Exception("Incorrect workspace id")
elif fields != self._expected_fields:
raise Exception("Incorrect fields")
elif questions != self._expected_questions:
raise Exception("Incorrect questions")
dataset_id = str(uuid4())
self._datasets[dataset_id] = []
return dataset_id

def add_record(self, dataset_id: str, record: RecordData) -> None:
if dataset_id not in self._datasets:
raise Exception("Add record: dataset not found")
self._datasets[dataset_id].append(record)

def evaluations(self, dataset_id: str) -> Iterable[ArgillaEvaluation]:
dataset = self._datasets.get(dataset_id)
assert dataset
return [
ArgillaEvaluation(
example_id=record.example_id,
record_id="ignored",
responses={"human-score": self._score},
metadata=dict(),
)
for record in dataset
]

def split_dataset(self, dataset_id: str, n_splits: int) -> None:
raise NotImplementedError


@fixture
def stub_argilla_client() -> StubArgillaClient:
return StubArgillaClient()


@fixture()
def temp_file_system() -> Iterable[MemoryFileSystem]:
mfs = MemoryFileSystem()
(changes in another file; file path not captured in this view)
@@ -14,7 +14,7 @@
from intelligence_layer.evaluation.dataset.hugging_face_dataset_repository import (
HuggingFaceDatasetRepository,
)
from tests.conftest import DummyStringInput, DummyStringOutput
from tests.evaluation.conftest import DummyStringInput, DummyStringOutput


@fixture
File renamed without changes.
(changes in another file; file path not captured in this view)
@@ -27,13 +27,63 @@
Runner,
SuccessfulExampleOutput,
)
from tests.conftest import (
from tests.evaluation.conftest import (
DummyStringEvaluation,
DummyStringInput,
DummyStringOutput,
DummyStringTask,
)
from tests.evaluation.conftest import StubArgillaClient


class StubArgillaClient(ArgillaClient):
_expected_workspace_id: str
_expected_fields: Sequence[Field]
_expected_questions: Sequence[Question]
_datasets: dict[str, list[RecordData]] = {}
_score = 3.0

def ensure_dataset_exists(
self,
workspace_id: str,
dataset_name: str,
fields: Sequence[Field],
questions: Sequence[Question],
) -> str:
if workspace_id != self._expected_workspace_id:
raise Exception("Incorrect workspace id")
elif fields != self._expected_fields:
raise Exception("Incorrect fields")
elif questions != self._expected_questions:
raise Exception("Incorrect questions")
dataset_id = str(uuid4())
self._datasets[dataset_id] = []
return dataset_id

def add_record(self, dataset_id: str, record: RecordData) -> None:
if dataset_id not in self._datasets:
raise Exception("Add record: dataset not found")
self._datasets[dataset_id].append(record)

def evaluations(self, dataset_id: str) -> Iterable[ArgillaEvaluation]:
dataset = self._datasets.get(dataset_id)
assert dataset
return [
ArgillaEvaluation(
example_id=record.example_id,
record_id="ignored",
responses={"human-score": self._score},
metadata=dict(),
)
for record in dataset
]

def split_dataset(self, dataset_id: str, n_splits: int) -> None:
raise NotImplementedError


@fixture
def stub_argilla_client() -> StubArgillaClient:
return StubArgillaClient()


class DummyStringTaskArgillaEvaluationLogic(
File renamed without changes.
File renamed without changes.
(changes in another file; file path not captured in this view)
@@ -75,11 +75,11 @@ def test_run_evaluation(
[
"",
"--eval-logic",
"tests.evaluation.test_run.DummyEvaluationLogic",
"tests.evaluation.run.test_run.DummyEvaluationLogic",
"--aggregation-logic",
"tests.evaluation.test_run.DummyAggregationLogic",
"tests.evaluation.run.test_run.DummyAggregationLogic",
"--task",
"tests.evaluation.test_run.DummyTask",
"tests.evaluation.run.test_run.DummyTask",
"--dataset-repository-path",
str(dataset_path),
"--dataset-id",
@@ -112,11 +112,11 @@ def test_run_evaluation_with_task_with_client(
[
"",
"--eval-logic",
"tests.evaluation.test_run.DummyEvaluationLogic",
"tests.evaluation.run.test_run.DummyEvaluationLogic",
"--aggregation-logic",
"tests.evaluation.test_run.DummyAggregationLogic",
"tests.evaluation.run.test_run.DummyAggregationLogic",
"--task",
"tests.evaluation.test_run.DummyTaskWithClient",
"tests.evaluation.run.test_run.DummyTaskWithClient",
"--dataset-repository-path",
str(dataset_path),
"--dataset-id",
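The CLI exercised by these tests receives the task, evaluation logic, and aggregation logic as dotted import paths, so moving test_run.py into tests/evaluation/run/ means every one of those path strings must be updated with it. A minimal sketch of how such a dotted path is typically resolved into a class (illustration only; load_symbol is an assumed helper name, not part of intelligence_layer):

from importlib import import_module


def load_symbol(dotted_path: str) -> object:
    # e.g. "tests.evaluation.run.test_run.DummyTask" -> the DummyTask class.
    # Split off the trailing attribute, import the module, fetch the attribute.
    module_path, _, attribute = dotted_path.rpartition(".")
    return getattr(import_module(module_path), attribute)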
(changes in another file; file path not captured in this view)
@@ -20,7 +20,7 @@
TaskSpanTrace,
)
from intelligence_layer.evaluation.run.domain import FailedExampleRun
from tests.conftest import DummyStringInput
from tests.evaluation.conftest import DummyStringInput

test_repository_fixtures = [
"file_run_repository",
File renamed without changes.