feat: restructure tests according to src structure #872

Merged (3 commits) on Jun 3, 2024
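Per the PR title, the tests/ tree is rearranged to mirror the src/intelligence_layer/ package layout. For orientation, a sketch of the paths visible in the diff below (two further test modules appear later with their file headers collapsed):

touched_paths = [
    "src/intelligence_layer/core/model.py",        # clearer error for unknown model names
    "tests/conftest.py",                            # DummyString* helpers moved out
    "tests/evaluation/conftest.py",                 # helpers and fixtures moved in
    "tests/evaluation/aggregation/conftest.py",     # new: aggregation fixtures
    "tests/evaluation/aggregation/test_domain.py",  # new: aggregation domain test
]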
4 changes: 3 additions & 1 deletion src/intelligence_layer/core/model.py
@@ -153,7 +153,9 @@ def __init__(
limited_concurrency_client_from_env() if client is None else client
)
if name not in [model["name"] for model in self._client.models()]:
raise ValueError(f"Invalid model name: {name}")
raise ValueError(
f"Could not find model: {name}. Either model name is invalid or model is currently down."
)
self._complete: Task[CompleteInput, CompleteOutput] = _Complete(
self._client, name
)
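A minimal pytest sketch of the behavior this hunk introduces. It assumes the constructor signature (name, client=None) implied by the surrounding diff and that LuminousControlModel takes this code path; the fake client and model name are illustrative only, not part of the PR.

import pytest

from intelligence_layer.core import LuminousControlModel


class FakeAlephAlphaClient:
    # Minimal stand-in exposing only the call the name check needs.
    def models(self) -> list[dict[str, str]]:
        return [{"name": "luminous-base-control"}]


def test_unknown_model_name_raises_descriptive_error() -> None:
    # The name check runs before any completion task is built, so the fake
    # client never has to answer anything beyond models().
    with pytest.raises(ValueError, match="Could not find model"):
        LuminousControlModel("not-a-real-model", client=FakeAlephAlphaClient())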
31 changes: 1 addition & 30 deletions tests/conftest.py
@@ -4,7 +4,6 @@

from aleph_alpha_client import Client, Image
from dotenv import load_dotenv
from pydantic import BaseModel
from pytest import fixture

from intelligence_layer.connectors import (
@@ -17,13 +16,7 @@
QdrantInMemoryRetriever,
RetrieverType,
)
from intelligence_layer.core import (
LuminousControlModel,
NoOpTracer,
Task,
TaskSpan,
utc_now,
)
from intelligence_layer.core import LuminousControlModel, NoOpTracer, utc_now
from intelligence_layer.evaluation import (
AsyncInMemoryEvaluationRepository,
EvaluationOverview,
@@ -117,28 +110,6 @@ def to_document(document_chunk: DocumentChunk) -> Document:
return Document(text=document_chunk.text, metadata=document_chunk.metadata)


class DummyStringInput(BaseModel):
input: str = "dummy-input"


class DummyStringOutput(BaseModel):
output: str = "dummy-output"


class DummyStringEvaluation(BaseModel):
evaluation: str = "dummy-evaluation"


class DummyStringTask(Task[DummyStringInput, DummyStringOutput]):
def do_run(self, input: DummyStringInput, task_span: TaskSpan) -> DummyStringOutput:
return DummyStringOutput()


@fixture
def dummy_string_task() -> DummyStringTask:
return DummyStringTask()


@fixture
def in_memory_dataset_repository() -> InMemoryDatasetRepository:
return InMemoryDatasetRepository()
27 changes: 27 additions & 0 deletions tests/evaluation/aggregation/conftest.py
@@ -0,0 +1,27 @@
from pytest import fixture

from intelligence_layer.core import utc_now
from intelligence_layer.evaluation import AggregationOverview, EvaluationOverview
from tests.evaluation.conftest import DummyAggregatedEvaluation


@fixture
def dummy_aggregated_evaluation() -> DummyAggregatedEvaluation:
return DummyAggregatedEvaluation(score=0.5)


@fixture
def aggregation_overview(
evaluation_overview: EvaluationOverview,
dummy_aggregated_evaluation: DummyAggregatedEvaluation,
) -> AggregationOverview[DummyAggregatedEvaluation]:
return AggregationOverview(
evaluation_overviews=frozenset([evaluation_overview]),
id="aggregation-id",
start=utc_now(),
end=utc_now(),
successful_evaluation_count=5,
crashed_during_evaluation_count=3,
description="dummy-evaluator",
statistics=dummy_aggregated_evaluation,
)
12 changes: 12 additions & 0 deletions tests/evaluation/aggregation/test_domain.py
@@ -0,0 +1,12 @@
import pytest

from intelligence_layer.evaluation.aggregation.domain import AggregationOverview
from intelligence_layer.evaluation.evaluation.domain import EvaluationFailed
from tests.evaluation.conftest import DummyAggregatedEvaluation


def test_raise_on_exception_for_evaluation_run_overview(
aggregation_overview: AggregationOverview[DummyAggregatedEvaluation],
) -> None:
with pytest.raises(EvaluationFailed):
aggregation_overview.raise_on_evaluation_failure()
140 changes: 33 additions & 107 deletions tests/evaluation/conftest.py
@@ -8,18 +8,9 @@
from pydantic import BaseModel
from pytest import fixture

from intelligence_layer.connectors import (
ArgillaClient,
ArgillaEvaluation,
Field,
Question,
RecordData,
)
from intelligence_layer.core import Task, Tracer, utc_now
from intelligence_layer.core import Task, TaskSpan, Tracer
from intelligence_layer.evaluation import (
AggregationOverview,
DatasetRepository,
EvaluationOverview,
Example,
ExampleEvaluation,
FileAggregationRepository,
@@ -29,23 +20,50 @@
InMemoryRunRepository,
Runner,
)
from tests.conftest import DummyStringInput, DummyStringOutput

FAIL_IN_EVAL_INPUT = "fail in eval"
FAIL_IN_TASK_INPUT = "fail in task"


class DummyStringInput(BaseModel):
input: str = "dummy-input"


class DummyStringOutput(BaseModel):
output: str = "dummy-output"


class DummyStringEvaluation(BaseModel):
evaluation: str = "dummy-evaluation"


class DummyStringTask(Task[DummyStringInput, DummyStringOutput]):
def do_run(self, input: DummyStringInput, task_span: TaskSpan) -> DummyStringOutput:
return DummyStringOutput()


@fixture
def dummy_string_task() -> DummyStringTask:
return DummyStringTask()


@fixture
def string_dataset_id(
dummy_string_examples: Iterable[Example[DummyStringInput, DummyStringOutput]],
in_memory_dataset_repository: DatasetRepository,
) -> str:
return in_memory_dataset_repository.create_dataset(
examples=dummy_string_examples, dataset_name="test-dataset"
).id


class DummyTask(Task[str, str]):
def do_run(self, input: str, tracer: Tracer) -> str:
if input == FAIL_IN_TASK_INPUT:
raise RuntimeError(input)
return input


class DummyStringEvaluation(BaseModel):
same: bool


class DummyEvaluation(BaseModel):
result: str

@@ -93,38 +111,6 @@ def file_run_repository(tmp_path: Path) -> FileRunRepository:
return FileRunRepository(tmp_path)


@fixture
def string_dataset_id(
dummy_string_examples: Iterable[Example[DummyStringInput, DummyStringOutput]],
in_memory_dataset_repository: DatasetRepository,
) -> str:
return in_memory_dataset_repository.create_dataset(
examples=dummy_string_examples, dataset_name="test-dataset"
).id


@fixture
def dummy_aggregated_evaluation() -> DummyAggregatedEvaluation:
return DummyAggregatedEvaluation(score=0.5)


@fixture
def aggregation_overview(
evaluation_overview: EvaluationOverview,
dummy_aggregated_evaluation: DummyAggregatedEvaluation,
) -> AggregationOverview[DummyAggregatedEvaluation]:
return AggregationOverview(
evaluation_overviews=frozenset([evaluation_overview]),
id="aggregation-id",
start=utc_now(),
end=utc_now(),
successful_evaluation_count=5,
crashed_during_evaluation_count=3,
description="dummy-evaluator",
statistics=dummy_aggregated_evaluation,
)


@fixture
def dummy_string_example() -> Example[DummyStringInput, DummyStringOutput]:
return Example(input=DummyStringInput(), expected_output=DummyStringOutput())
@@ -150,66 +136,6 @@ def dummy_runner(
)


class StubArgillaClient(ArgillaClient):
_expected_workspace_id: str
_expected_fields: Sequence[Field]
_expected_questions: Sequence[Question]
_datasets: dict[str, list[RecordData]] = {}
_score = 3.0

def create_dataset(
self,
workspace_id: str,
dataset_name: str,
fields: Sequence[Field],
questions: Sequence[Question],
) -> str:
return self.ensure_dataset_exists(workspace_id, dataset_name, fields, questions)

def ensure_dataset_exists(
self,
workspace_id: str,
dataset_name: str,
fields: Sequence[Field],
questions: Sequence[Question],
) -> str:
if workspace_id != self._expected_workspace_id:
raise Exception("Incorrect workspace id")
elif fields != self._expected_fields:
raise Exception("Incorrect fields")
elif questions != self._expected_questions:
raise Exception("Incorrect questions")
dataset_id = str(uuid4())
self._datasets[dataset_id] = []
return dataset_id

def add_record(self, dataset_id: str, record: RecordData) -> None:
if dataset_id not in self._datasets:
raise Exception("Add record: dataset not found")
self._datasets[dataset_id].append(record)

def evaluations(self, dataset_id: str) -> Iterable[ArgillaEvaluation]:
dataset = self._datasets.get(dataset_id)
assert dataset
return [
ArgillaEvaluation(
example_id=record.example_id,
record_id="ignored",
responses={"human-score": self._score},
metadata=dict(),
)
for record in dataset
]

def split_dataset(self, dataset_id: str, n_splits: int) -> None:
raise NotImplementedError


@fixture
def stub_argilla_client() -> StubArgillaClient:
return StubArgillaClient()


@fixture()
def temp_file_system() -> Iterable[MemoryFileSystem]:
mfs = MemoryFileSystem()
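Because the DummyString* helpers now live in tests/evaluation/conftest.py rather than the top-level tests/conftest.py, test modules under tests/evaluation/ update their imports, as the following diffs show. A minimal before/after sketch:

# Before this PR the helpers were imported from the top-level conftest:
#   from tests.conftest import DummyStringInput, DummyStringOutput
# After the restructuring they come from the evaluation package's conftest:
from tests.evaluation.conftest import DummyStringInput, DummyStringOutput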
Changes to a further test file (file header collapsed in this view)
@@ -14,7 +14,7 @@
from intelligence_layer.evaluation.dataset.hugging_face_dataset_repository import (
HuggingFaceDatasetRepository,
)
from tests.conftest import DummyStringInput, DummyStringOutput
from tests.evaluation.conftest import DummyStringInput, DummyStringOutput


@fixture
Changes to a further test file (file header collapsed in this view)
@@ -27,13 +27,72 @@
Runner,
SuccessfulExampleOutput,
)
from tests.conftest import (
from tests.evaluation.conftest import (
DummyStringEvaluation,
DummyStringInput,
DummyStringOutput,
DummyStringTask,
)
from tests.evaluation.conftest import StubArgillaClient


class StubArgillaClient(ArgillaClient):
_expected_workspace_id: str
_expected_fields: Sequence[Field]
_expected_questions: Sequence[Question]
_datasets: dict[str, list[RecordData]] = {}
_score = 3.0

def create_dataset(
self,
workspace_id: str,
dataset_name: str,
fields: Sequence[Field],
questions: Sequence[Question],
) -> str:
return self.ensure_dataset_exists(workspace_id, dataset_name, fields, questions)

def ensure_dataset_exists(
self,
workspace_id: str,
dataset_name: str,
fields: Sequence[Field],
questions: Sequence[Question],
) -> str:
if workspace_id != self._expected_workspace_id:
raise Exception("Incorrect workspace id")
elif fields != self._expected_fields:
raise Exception("Incorrect fields")
elif questions != self._expected_questions:
raise Exception("Incorrect questions")
dataset_id = str(uuid4())
self._datasets[dataset_id] = []
return dataset_id

def add_record(self, dataset_id: str, record: RecordData) -> None:
if dataset_id not in self._datasets:
raise Exception("Add record: dataset not found")
self._datasets[dataset_id].append(record)

def evaluations(self, dataset_id: str) -> Iterable[ArgillaEvaluation]:
dataset = self._datasets.get(dataset_id)
assert dataset
return [
ArgillaEvaluation(
example_id=record.example_id,
record_id="ignored",
responses={"human-score": self._score},
metadata=dict(),
)
for record in dataset
]

def split_dataset(self, dataset_id: str, n_splits: int) -> None:
raise NotImplementedError


@fixture
def stub_argilla_client() -> StubArgillaClient:
return StubArgillaClient()


class DummyStringTaskArgillaEvaluationLogic(
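For orientation, a sketch of how a test in this module might exercise the relocated StubArgillaClient; the concrete expectation values below are illustrative assumptions, not taken from the PR.

import pytest


def test_stub_argilla_client_checks_workspace_id() -> None:
    client = StubArgillaClient()
    # Expectations are normally wired up by fixtures; set them directly here.
    client._expected_workspace_id = "expected-workspace"
    client._expected_fields = []
    client._expected_questions = []

    # Matching expectations yield a fresh dataset id tracked by the stub ...
    dataset_id = client.ensure_dataset_exists("expected-workspace", "dummy-dataset", [], [])
    assert dataset_id in client._datasets

    # ... while a mismatching workspace id is rejected.
    with pytest.raises(Exception, match="Incorrect workspace id"):
        client.ensure_dataset_exists("other-workspace", "dummy-dataset", [], [])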
Changes to a further test file (file header collapsed in this view)
@@ -75,11 +75,11 @@ def test_run_evaluation(
[
"",
"--eval-logic",
"tests.evaluation.test_run.DummyEvaluationLogic",
"tests.evaluation.run.test_run.DummyEvaluationLogic",
"--aggregation-logic",
"tests.evaluation.test_run.DummyAggregationLogic",
"tests.evaluation.run.test_run.DummyAggregationLogic",
"--task",
"tests.evaluation.test_run.DummyTask",
"tests.evaluation.run.test_run.DummyTask",
"--dataset-repository-path",
str(dataset_path),
"--dataset-id",
@@ -112,11 +112,11 @@ def test_run_evaluation_with_task_with_client(
[
"",
"--eval-logic",
"tests.evaluation.test_run.DummyEvaluationLogic",
"tests.evaluation.run.test_run.DummyEvaluationLogic",
"--aggregation-logic",
"tests.evaluation.test_run.DummyAggregationLogic",
"tests.evaluation.run.test_run.DummyAggregationLogic",
"--task",
"tests.evaluation.test_run.DummyTaskWithClient",
"tests.evaluation.run.test_run.DummyTaskWithClient",
"--dataset-repository-path",
str(dataset_path),
"--dataset-id",
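Since the run-related tests moved into tests/evaluation/run/, the fully qualified names handed to the evaluation CLI gain the extra run package, as the hunks above show. A sketch of the updated argument list; the repository path and dataset id are placeholders:

cli_args = [
    "",
    "--eval-logic", "tests.evaluation.run.test_run.DummyEvaluationLogic",
    "--aggregation-logic", "tests.evaluation.run.test_run.DummyAggregationLogic",
    "--task", "tests.evaluation.run.test_run.DummyTask",
    "--dataset-repository-path", "/tmp/dataset-repository",  # placeholder
    "--dataset-id", "some-dataset-id",  # placeholder
]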