diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e307b15..1e31794 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -55,7 +55,7 @@ jobs: mypy pycrdt_websocket tests - name: Run tests run: | - pytest -v --color=yes + pytest -v --color=yes --timeout=60 check_release: runs-on: ubuntu-latest diff --git a/pyproject.toml b/pyproject.toml index 03a30c2..7878e16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ test = [ "mypy !=1.10.0", # see https://github.com/python/mypy/issues/17166 "pre-commit", "pytest", + "pytest-timeout", "httpx-ws >=0.5.2", "hypercorn >=0.16.0", "trio >=0.25.0", diff --git a/tests/conftest.py b/tests/conftest.py index 561a186..30129ea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,9 +9,9 @@ from hypercorn import Config from pycrdt import Doc from sniffio import current_async_library -from utils import StartStopContextManager, Websocket, ensure_server_running +from utils import StartStopContextManager, Websocket, connected_websockets, ensure_server_running -from pycrdt_websocket import ASGIServer, WebsocketProvider, WebsocketServer +from pycrdt_websocket import ASGIServer, WebsocketProvider, WebsocketServer, YRoom @pytest.fixture(params=("websocket_server_context_manager", "websocket_server_start_stop")) @@ -19,6 +19,26 @@ def websocket_server_api(request): return request.param +@pytest.fixture(params=("websocket_provider_context_manager", "websocket_provider_start_stop")) +def websocket_provider_api(request): + return request.param + + +@pytest.fixture(params=("yroom_context_manager", "yroom_start_stop")) +def yroom_api(request): + return request.param + + +@pytest.fixture(params=("real_websocket",)) +def websocket_provider_connect(request): + return request.param + + +@pytest.fixture(params=("ystore_context_manager", "ystore_start_stop")) +def ystore_api(request): + return request.param + + @pytest.fixture async def yws_server(request, unused_tcp_port, websocket_server_api): try: @@ -50,31 +70,32 @@ async def yws_server(request, unused_tcp_port, websocket_server_api): pass -@pytest.fixture(params=("websocket_provider_context_manager", "websocket_provider_start_stop")) -def websocket_provider_api(request): - return request.param - - @pytest.fixture -def yws_provider_factory(room_name, websocket_provider_api): +def yws_provider_factory(room_name, websocket_provider_api, websocket_provider_connect): @asynccontextmanager async def factory(): ydoc = Doc() - async with aconnect_ws(f"http://localhost:{pytest.port}/{room_name}") as websocket: + if websocket_provider_connect == "real_websocket": + server_websocket = None + connect = aconnect_ws(f"http://localhost:{pytest.port}/{room_name}") + else: + server_websocket, connect = connected_websockets() + async with connect as websocket: async with create_task_group() as tg: websocket_provider = WebsocketProvider(ydoc, Websocket(websocket, room_name)) if websocket_provider_api == "websocket_provider_start_stop": websocket_provider = StartStopContextManager(websocket_provider, tg) async with websocket_provider as websocket_provider: - yield ydoc + yield ydoc, server_websocket return factory @pytest.fixture async def yws_provider(yws_provider_factory): - async with yws_provider_factory() as ydoc: - yield ydoc + async with yws_provider_factory() as provider: + ydoc, server_websocket = provider + yield ydoc, server_websocket @pytest.fixture @@ -83,6 +104,20 @@ async def yws_providers(request, yws_provider_factory): yield [yws_provider_factory() for idx in range(number)] +@pytest.fixture +async def yroom(request, yroom_api): + async with create_task_group() as tg: + try: + kwargs = request.param + except AttributeError: + kwargs = {} + room = YRoom(**kwargs) + if yroom_api == "yroom_start_stop": + room = StartStopContextManager(room, tg) + async with room as room: + yield room + + @pytest.fixture def yjs_client(request): client_id = request.param diff --git a/tests/test_asgi.py b/tests/test_asgi.py index ac703c1..67a3fb1 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -10,13 +10,15 @@ async def test_asgi(yws_server, yws_providers): yws_provider1, yws_provider2 = yws_providers # client 1 - async with yws_provider1 as ydoc1: + async with yws_provider1 as yws_provider1: + ydoc1, _ = yws_provider1 ydoc1["map"] = ymap1 = Map() ymap1["key"] = "value" await sleep(0.1) # client 2 - async with yws_provider2 as ydoc2: + async with yws_provider2 as yws_provider2: + ydoc2, _ = yws_provider2 ymap2 = ydoc2.get("map", type=Map) await sleep(0.1) assert str(ymap2) == '{"key":"value"}' diff --git a/tests/test_pycrdt_yjs.py b/tests/test_pycrdt_yjs.py index eaf3b3e..fe0fca2 100644 --- a/tests/test_pycrdt_yjs.py +++ b/tests/test_pycrdt_yjs.py @@ -39,7 +39,7 @@ def watch(ydata, key: str | None = None, timeout: float = 1.0): @pytest.mark.parametrize("yjs_client", [0], indirect=True) async def test_pycrdt_yjs_0(yws_server, yws_provider, yjs_client): - ydoc = yws_provider + ydoc, _ = yws_provider ydoc["map"] = ymap = Map() for v_in in range(10): ymap["in"] = float(v_in) @@ -49,7 +49,7 @@ async def test_pycrdt_yjs_0(yws_server, yws_provider, yjs_client): @pytest.mark.parametrize("yjs_client", [1], indirect=True) async def test_pycrdt_yjs_1(yws_server, yws_provider, yjs_client): - ydoc = yws_provider + ydoc, _ = yws_provider ydoc["cells"] = ycells = Array() ydoc["state"] = ystate = Map() ycells_change = watch(ycells) diff --git a/tests/test_yroom.py b/tests/test_yroom.py index a0edbce..8fa6818 100644 --- a/tests/test_yroom.py +++ b/tests/test_yroom.py @@ -1,7 +1,8 @@ import pytest -from anyio import TASK_STATUS_IGNORED, sleep +from anyio import TASK_STATUS_IGNORED, create_task_group, sleep from anyio.abc import TaskStatus from pycrdt import Map +from utils import Websocket from pycrdt_websocket import exception_logger from pycrdt_websocket.yroom import YRoom @@ -9,6 +10,30 @@ pytestmark = pytest.mark.anyio +@pytest.mark.parametrize("websocket_provider_connect", ["fake_websocket"], indirect=True) +@pytest.mark.parametrize("yws_providers", [2], indirect=True) +async def test_yroom(yroom, yws_providers, websocket_provider_connect, room_name): + async with create_task_group() as tg: + yws_provider1, yws_provider2 = yws_providers + # client 1 + async with yws_provider1 as yws_provider1: + ydoc1, server_ws1 = yws_provider1 + tg.start_soon(yroom.serve, Websocket(server_ws1, room_name)) + ydoc1["map"] = ymap1 = Map() + ymap1["key"] = "value" + await sleep(0.1) + + # client 2 + async with yws_provider2 as yws_provider2: + ydoc2, server_ws2 = yws_provider2 + tg.start_soon(yroom.serve, Websocket(server_ws2, room_name)) + ymap2 = ydoc2.get("map", type=Map) + await sleep(0.1) + + assert str(ymap2) == '{"key":"value"}' + tg.cancel_scope.cancel() + + @pytest.mark.parametrize("websocket_server_api", ["websocket_server_start_stop"], indirect=True) @pytest.mark.parametrize("yws_server", [{"exception_handler": exception_logger}], indirect=True) async def test_yroom_restart(yws_server, yws_provider): @@ -19,7 +44,7 @@ async def raise_error(task_status: TaskStatus[None] = TASK_STATUS_IGNORED): task_status.started() raise RuntimeError("foo") - yroom.ydoc = yws_provider + yroom.ydoc, _ = yws_provider await server.start_room(yroom) yroom.ydoc["map"] = ymap1 = Map() ymap1["key"] = "value" diff --git a/tests/utils.py b/tests/utils.py index e33f970..1ad418d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,4 @@ -from anyio import Lock, connect_tcp +from anyio import Lock, connect_tcp, create_memory_object_stream from pycrdt import Array, Doc @@ -60,6 +60,45 @@ async def recv(self) -> bytes: return bytes(b) +class ClientWebsocket: + def __init__(self, server_websocket: "ServerWebsocket"): + self.server_websocket = server_websocket + self.send_stream, self.receive_stream = create_memory_object_stream[bytes](65536) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, exc_tb): + pass + + async def send_bytes(self, message: bytes) -> None: + await self.server_websocket.send_stream.send(message) + + async def receive_bytes(self) -> bytes: + return await self.receive_stream.receive() + + +class ServerWebsocket: + client_websocket: ClientWebsocket | None = None + + def __init__(self): + self.send_stream, self.receive_stream = create_memory_object_stream[bytes](65536) + + async def send_bytes(self, message: bytes) -> None: + assert self.client_websocket is not None + await self.client_websocket.send_stream.send(message) + + async def receive_bytes(self) -> bytes: + return await self.receive_stream.receive() + + +def connected_websockets() -> tuple[ServerWebsocket, ClientWebsocket]: + server_websocket = ServerWebsocket() + client_websocket = ClientWebsocket(server_websocket) + server_websocket.client_websocket = client_websocket + return server_websocket, client_websocket + + async def ensure_server_running(host: str, port: int) -> None: while True: try: