diff --git a/src/evidently/collector/app.py b/src/evidently/collector/app.py index 49e44847e9..5b1288e0b6 100644 --- a/src/evidently/collector/app.py +++ b/src/evidently/collector/app.py @@ -4,6 +4,7 @@ from typing import AsyncGenerator from typing import Dict from typing import List +from typing import Optional import pandas as pd import uvicorn @@ -148,7 +149,7 @@ async def create_snapshot(collector: CollectorConfig, storage: CollectorStorage) storage.log(collector.id, LogEvent(ok=True)) -def run(host: str = "0.0.0.0", port: int = 8001, config_path: str = CONFIG_PATH, secret: str = None): +def create_app(config_path: str = CONFIG_PATH, secret: Optional[str] = None) -> Litestar: service = CollectorServiceConfig.load_or_default(config_path) service.storage.init_all(service) @@ -205,7 +206,7 @@ async def check_service_snapshots_periodically(): stop_event.set() await task - app = Litestar( + return Litestar( route_handlers=[ create_collector, get_collector, @@ -222,12 +223,12 @@ async def check_service_snapshots_periodically(): guards=[is_authenticated], lifespan=[check_snapshots_factory_lifespan], ) - uvicorn.run(app, host=host, port=port) -def main(): - run() +def run(host: str = "0.0.0.0", port: int = 8001, config_path: str = CONFIG_PATH, secret: Optional[str] = None): + app = create_app(config_path, secret) + uvicorn.run(app, host=host, port=port) if __name__ == "__main__": - main() + run() diff --git a/tests/collector/test_app.py b/tests/collector/test_app.py new file mode 100644 index 0000000000..1e05618912 --- /dev/null +++ b/tests/collector/test_app.py @@ -0,0 +1,98 @@ +import socket +import time +from multiprocessing import Process +from typing import Optional + +import pandas as pd +import pytest + +from evidently.collector.app import run +from evidently.collector.client import CollectorClient +from evidently.collector.config import CollectorConfig +from evidently.collector.config import IntervalTrigger +from evidently.collector.config import ReportConfig +from evidently.test_suite import TestSuite +from evidently.tests import TestNumberOfOutRangeValues + +COLLECTOR_ID = "1" +PROJECT_ID = "2" + +HOST = "localhost" +PORT = 8080 +BASE_URL = f"http://{HOST}:{PORT}" + + +def create_test_data() -> pd.DataFrame: + return pd.DataFrame([{"values1": 5.0, "values2": 0.0}]) + + +def create_test_suite() -> TestSuite: + return TestSuite(tests=[TestNumberOfOutRangeValues("values1", left=5)], tags=["quality"]) + + +def create_report_config( + current_data: pd.DataFrame, + test_suite: TestSuite, + references: Optional[pd.DataFrame] = None, +): + test_suite.run(reference_data=references, current_data=current_data) + return ReportConfig.from_test_suite(test_suite) + + +def create_collector_config(data: pd.DataFrame, test_suite: TestSuite) -> CollectorConfig: + return CollectorConfig( + trigger=IntervalTrigger(interval=1), + report_config=create_report_config(current_data=data, test_suite=test_suite), + project_id=PROJECT_ID, + ) + + +def create_collector_client(base_url: str = BASE_URL) -> CollectorClient: + return CollectorClient(base_url=base_url) + + +def wait_server_start(host: str = HOST, port: int = PORT, timeout: float = 10.0, check_interval: float = 0.01) -> None: + start_time = time.time() + while time.time() - start_time < timeout: + time.sleep(check_interval) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.connect((host, port)) + except ConnectionRefusedError: + continue + return + raise TimeoutError() + + +@pytest.fixture +def server() -> None: + proc = Process(target=run, args=(HOST, PORT), daemon=True) + proc.start() + wait_server_start() + try: + yield + except Exception as e: + proc.kill() + raise e + + +def test_create_collector_handler_work_with_collector_client(server: None): + collector_config = create_collector_config(create_test_data(), create_test_suite()) + client = create_collector_client() + resp = client.create_collector(COLLECTOR_ID, collector_config) + assert resp["id"] == COLLECTOR_ID + assert resp["project_id"] == PROJECT_ID + + +def test_data_handler_work_with_collector_client(server: None): + collector_config = create_collector_config(create_test_data(), create_test_suite()) + client = create_collector_client() + client.create_collector(COLLECTOR_ID, collector_config) + client.send_data(COLLECTOR_ID, create_test_data()) + + +def test_references_handler_work_with_collector_client(server: None): + collector_config = create_collector_config(create_test_data(), create_test_suite()) + client = create_collector_client() + client.create_collector(COLLECTOR_ID, collector_config) + client.set_reference(COLLECTOR_ID, create_test_data())