diff --git a/idom_router/router.py b/idom_router/router.py index efadfa6..4c70423 100644 --- a/idom_router/router.py +++ b/idom_router/router.py @@ -5,20 +5,24 @@ from typing import Any, Callable, Iterator, Sequence from urllib.parse import parse_qs -from idom import component, create_context, use_context, use_memo, use_state +from idom import ( + component, + create_context, + use_memo, + use_state, + use_context, + use_location, +) from idom.core.types import VdomAttributesAndChildren, VdomDict from idom.core.vdom import coalesce_attributes_and_children -from idom.types import ComponentType, Context, Location +from idom.types import ComponentType, Location, Context from idom.web.module import export, module_from_file +from idom.backend.hooks import ConnectionContext, use_connection +from idom.backend.types import Connection, Location from starlette.routing import compile_path as _compile_starlette_path from idom_router.types import RoutePattern, RouteCompiler, Route -try: - from typing import Protocol -except ImportError: # pragma: no cover - from typing_extensions import Protocol # type: ignore - def compile_starlette_route(route: str) -> RoutePattern: pattern, _, converters = _compile_starlette_path(route) @@ -30,8 +34,9 @@ def router( *routes: Route, compiler: RouteCompiler = compile_starlette_route, ) -> ComponentType | None: - initial_location = use_location() - location, set_location = use_state(initial_location) + old_conn = use_connection() + location, set_location = use_state(old_conn.location) + compiled_routes = use_memo( lambda: [(compiler(r), e) for r, e in _iter_routes(routes)], dependencies=routes, @@ -39,16 +44,19 @@ def router( for compiled_route, element in compiled_routes: match = compiled_route.pattern.match(location.pathname) if match: - return _LocationStateContext( - element, - value=_LocationState( - location, - set_location, - { - k: compiled_route.converters[k](v) - for k, v in match.groupdict().items() - }, + convs = compiled_route.converters + return ConnectionContext( + _route_state_context( + element, + value=_RouteState( + set_location, + { + k: convs[k](v) if k in convs else v + for k, v in match.groupdict().items() + }, + ), ), + value=Connection(old_conn.scope, location, old_conn.carrier), key=compiled_route.pattern.pattern, ) return None @@ -57,23 +65,18 @@ def router( @component def link(*attributes_or_children: VdomAttributesAndChildren, to: str) -> VdomDict: attributes, children = coalesce_attributes_and_children(attributes_or_children) - set_location = _use_location_state().set_location + set_location = _use_route_state().set_location attrs = { **attributes, "to": to, "onClick": lambda event: set_location(Location(**event)), } - return _Link(attrs, *children) - - -def use_location() -> Location: - """Get the current route location""" - return _use_location_state().location + return _link(attrs, *children) def use_params() -> dict[str, Any]: """Get parameters from the currently matching route pattern""" - return _use_location_state().params + return use_context(_route_state_context).params def use_query( @@ -94,6 +97,10 @@ def use_query( ) +def _use_route_state() -> _RouteState: + return use_context(_route_state_context) + + def _iter_routes(routes: Sequence[Route]) -> Iterator[tuple[str, Any]]: for r in routes: for path, element in _iter_routes(r.routes): @@ -101,22 +108,16 @@ def _iter_routes(routes: Sequence[Route]) -> Iterator[tuple[str, Any]]: yield r.path, r.element -def _use_location_state() -> _LocationState: - location_state = use_context(_LocationStateContext) - assert location_state is not None, "No location state. Did you use a Router?" - return location_state +_link = export( + module_from_file("idom-router", file=Path(__file__).parent / "bundle.js"), + "Link", +) @dataclass -class _LocationState: - location: Location +class _RouteState: set_location: Callable[[Location], None] params: dict[str, Any] -_LocationStateContext: Context[_LocationState | None] = create_context(None) - -_Link = export( - module_from_file("idom-router", file=Path(__file__).parent / "bundle.js"), - "Link", -) +_route_state_context: Context[_RouteState | None] = create_context(None) diff --git a/tests/test_router.py b/tests/test_router.py index f958854..dc8f53b 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -1,12 +1,12 @@ -import pytest -from idom import Ref, component, html +import re + +from idom import Ref, component, html, use_location from idom.testing import DisplayFixture from idom_router import ( Route, router, link, - use_location, use_params, use_query, ) @@ -45,7 +45,7 @@ def sample(): await display.goto("/missing") try: - root_element = display.root_element() + root_element = await display.root_element() except AttributeError: root_element = await display.page.wait_for_selector( f"#display-{display._next_view_id}", state="attached" @@ -162,10 +162,6 @@ def sample(): await display.page.wait_for_selector("#success") -def custom_path_compiler(path): - pattern = re.compile(path) - - async def test_custom_path_compiler(display: DisplayFixture): expected_params = {} @@ -178,26 +174,33 @@ def check_params(): def sample(): return router( Route( - "/first/{first:str}", + r"/first/(?P\d+)", check_params(), Route( - "/second/{second:str}", + r"/second/(?P[\d\.]+)", check_params(), Route( - "/third/{third:str}", + r"/third/(?P[\d,]+)", check_params(), ), ), ), - compiler=lambda path: RoutePattern(re.compile()), + compiler=lambda path: RoutePattern( + re.compile(rf"^{path}$"), + { + "first": int, + "second": float, + "third": lambda s: list(map(int, s.split(","))), + }, + ), ) await display.show(sample) for path, expected_params in [ - ("/first/1", {"first": "1"}), - ("/first/1/second/2", {"first": "1", "second": "2"}), - ("/first/1/second/2/third/3", {"first": "1", "second": "2", "third": "3"}), + ("/first/1", {"first": 1}), + ("/first/1/second/2.1", {"first": 1, "second": 2.1}), + ("/first/1/second/2.1/third/3,3", {"first": 1, "second": 2.1, "third": [3, 3]}), ]: await display.goto(path) await display.page.wait_for_selector("#success")