diff --git a/datasette/app.py b/datasette/app.py index 4a8ead1ddf..16a29e200f 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -651,9 +651,12 @@ async def setup_db(): if not database.is_mutable: await database.table_counts(limit=60 * 60 * 1000) - return AsgiLifespan( + asgi = AsgiLifespan( AsgiTracer(DatasetteRouter(self, routes)), on_startup=setup_db ) + for wrapper in pm.hook.asgi_wrapper(datasette=self): + asgi = wrapper(asgi) + return asgi class DatasetteRouter(AsgiRouter): diff --git a/datasette/hookspecs.py b/datasette/hookspecs.py index 61523a31b6..42adaae8fd 100644 --- a/datasette/hookspecs.py +++ b/datasette/hookspecs.py @@ -5,6 +5,11 @@ hookimpl = HookimplMarker("datasette") +@hookspec +def asgi_wrapper(datasette): + "Returns an ASGI middleware callable to wrap our ASGI application with" + + @hookspec def prepare_connection(conn): "Modify SQLite connection in some way e.g. register custom SQL functions" diff --git a/docs/plugins.rst b/docs/plugins.rst index bd32b3a666..be3355462a 100644 --- a/docs/plugins.rst +++ b/docs/plugins.rst @@ -666,3 +666,44 @@ The plugin hook can then be used to register the new facet class like this: @hookimpl def register_facet_classes(): return [SpecialFacet] + + +.. _plugin_asgi_wrapper: + +asgi_wrapper(datasette) +~~~~~~~~~~~~~~~~~~~~~~~ + +Return an `ASGI `__ middleware wrapper function that will be applied to the Datasette ASGI application. + +This is a very powerful hook. You can use it to manipulate the entire Datasette response, or even to configure new URL routes that will be handled by your own custom code. + +You can write your ASGI code directly against the low-level specification, or you can use the middleware utilites provided by an ASGI framework such as `Starlette `__. + +This example plugin adds a ``x-databases`` HTTP header listing the currently attached databases: + +.. code-block:: python + + from datasette import hookimpl + from functools import wraps + + + @hookimpl + def asgi_wrapper(datasette): + def wrap_with_databases_header(app): + @wraps(app) + async def add_x_databases_header(scope, recieve, send): + async def wrapped_send(event): + if event["type"] == "http.response.start": + original_headers = event.get("headers") or [] + event = { + "type": event["type"], + "status": event["status"], + "headers": original_headers + [ + [b"x-databases", + ", ".join(datasette.databases.keys()).encode("utf-8")] + ], + } + await send(event) + await app(scope, recieve, wrapped_send) + return add_x_databases_header + return wrap_with_databases_header diff --git a/tests/fixtures.py b/tests/fixtures.py index 0330c8ed25..fab6509e16 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -372,6 +372,7 @@ def render_cell(value, column, table, database, datasette): PLUGIN2 = """ from datasette import hookimpl +from functools import wraps import jinja2 import json @@ -413,6 +414,28 @@ def render_cell(value, database): label=jinja2.escape(data["label"] or "") or " " ) ) + + +@hookimpl +def asgi_wrapper(datasette): + def wrap_with_databases_header(app): + @wraps(app) + async def add_x_databases_header(scope, recieve, send): + async def wrapped_send(event): + if event["type"] == "http.response.start": + original_headers = event.get("headers") or [] + event = { + "type": event["type"], + "status": event["status"], + "headers": original_headers + [ + [b"x-databases", + ", ".join(datasette.databases.keys()).encode("utf-8")] + ], + } + await send(event) + await app(scope, recieve, wrapped_send) + return add_x_databases_header + return wrap_with_databases_header """ TABLES = ( diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 56033bddc4..9bdd491a0b 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -162,3 +162,8 @@ def test_plugins_extra_body_script(app_client, path, expected_extra_body_script) json_data = r.search(app_client.get(path).body.decode("utf8")).group(1) actual_data = json.loads(json_data) assert expected_extra_body_script == actual_data + + +def test_plugins_asgi_wrapper(app_client): + response = app_client.get("/fixtures") + assert "fixtures" == response.headers["x-databases"]