diff --git a/datasette/__init__.py b/datasette/__init__.py index 64fb4ff7d0..271e09ada0 100644 --- a/datasette/__init__.py +++ b/datasette/__init__.py @@ -1,6 +1,7 @@ -from datasette.permissions import Permission +from datasette.permissions import Permission # noqa from datasette.version import __version_info__, __version__ # noqa from datasette.utils.asgi import Forbidden, NotFound, Request, Response # noqa from datasette.utils import actor_matches_allow # noqa +from datasette.views import Context # noqa from .hookspecs import hookimpl # noqa from .hookspecs import hookspec # noqa diff --git a/datasette/app.py b/datasette/app.py index 69074e8fdf..39c2bb6de9 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -1,7 +1,8 @@ import asyncio -from typing import Sequence, Union, Tuple, Optional, Dict, Iterable +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union import asgi_csrf import collections +import dataclasses import datetime import functools import glob @@ -33,6 +34,7 @@ from jinja2.environment import Template from jinja2.exceptions import TemplateNotFound +from .views import Context from .views.base import ureg from .views.database import database_download, database_view, TableCreateView from .views.index import IndexView @@ -1115,7 +1117,11 @@ def _register_renderers(self): ) async def render_template( - self, templates, context=None, request=None, view_name=None + self, + templates: Union[List[str], str, Template], + context: Optional[Union[Dict[str, Any], Context]] = None, + request: Optional[Request] = None, + view_name: Optional[str] = None, ): if not self._startup_invoked: raise Exception("render_template() called before await ds.invoke_startup()") @@ -1126,6 +1132,8 @@ async def render_template( if isinstance(templates, str): templates = [templates] template = self.jinja_env.select_template(templates) + if dataclasses.is_dataclass(context): + context = dataclasses.asdict(context) body_scripts = [] # pylint: disable=no-member for extra_script in pm.hook.extra_body_script( diff --git a/datasette/views/__init__.py b/datasette/views/__init__.py index e69de29bb2..e3b1b7f44b 100644 --- a/datasette/views/__init__.py +++ b/datasette/views/__init__.py @@ -0,0 +1,3 @@ +class Context: + "Base class for all documented contexts" + pass diff --git a/tests/test_internals_datasette.py b/tests/test_internals_datasette.py index 3d5bb2da58..d59ff72976 100644 --- a/tests/test_internals_datasette.py +++ b/tests/test_internals_datasette.py @@ -1,10 +1,12 @@ """ Tests for the datasette.app.Datasette class """ -from datasette import Forbidden +import dataclasses +from datasette import Forbidden, Context from datasette.app import Datasette, Database from itsdangerous import BadSignature import pytest +from typing import Optional @pytest.fixture @@ -136,6 +138,22 @@ async def test_datasette_render_template_no_request(): assert "Error " in rendered +@pytest.mark.asyncio +async def test_datasette_render_template_with_dataclass(): + @dataclasses.dataclass + class ExampleContext(Context): + title: str + status: int + error: str + + context = ExampleContext(title="Hello", status=200, error="Error message") + ds = Datasette(memory=True) + await ds.invoke_startup() + rendered = await ds.render_template("error.html", context) + assert "