Skip to content

Commit

Permalink
Async support for magic parameters
Browse files Browse the repository at this point in the history
Closes #2441
  • Loading branch information
simonw committed Nov 15, 2024
1 parent b0b600b commit dce7189
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 4 deletions.
25 changes: 22 additions & 3 deletions datasette/views/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,10 @@ async def post(self, request, datasette):
or request.args.get("_json")
or params.get("_json")
)
params_for_query = MagicParameters(params, request, datasette)
params_for_query = MagicParameters(
canned_query["sql"], params, request, datasette
)
await params_for_query.execute_params()
ok = None
redirect_url = None
try:
Expand Down Expand Up @@ -523,7 +526,8 @@ async def get(self, request, datasette):
validate_sql_select(sql)
else:
# Canned queries can run magic parameters
params_for_query = MagicParameters(params, request, datasette)
params_for_query = MagicParameters(sql, params, request, datasette)
await params_for_query.execute_params()
results = await datasette.execute(
database, sql, params_for_query, truncate=True, **extra_args
)
Expand Down Expand Up @@ -792,14 +796,26 @@ async def query_actions():


class MagicParameters(dict):
def __init__(self, data, request, datasette):
def __init__(self, sql, data, request, datasette):
super().__init__(data)
self._sql = sql
self._request = request
self._magics = dict(
itertools.chain.from_iterable(
pm.hook.register_magic_parameters(datasette=datasette)
)
)
self._prepared = {}

async def execute_params(self):
for key in derive_named_parameters(self._sql):
if key.startswith("_") and key.count("_") >= 2:
prefix, suffix = key[1:].split("_", 1)
if prefix in self._magics:
result = await await_me_maybe(
self._magics[prefix](suffix, self._request)
)
self._prepared[key] = result

def __len__(self):
# Workaround for 'Incorrect number of bindings' error
Expand All @@ -808,6 +824,9 @@ def __len__(self):

def __getitem__(self, key):
if key.startswith("_") and key.count("_") >= 2:
if key in self._prepared:
return self._prepared[key]
# Try the other route
prefix, suffix = key[1:].split("_", 1)
if prefix in self._magics:
try:
Expand Down
6 changes: 5 additions & 1 deletion docs/plugin_hooks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1315,7 +1315,7 @@ Magic parameters all take this format: ``_prefix_rest_of_parameter``. The prefix

To register a new function, return it as a tuple of ``(string prefix, function)`` from this hook. The function you register should take two arguments: ``key`` and ``request``, where ``key`` is the ``rest_of_parameter`` portion of the parameter and ``request`` is the current :ref:`internals_request`.

This example registers two new magic parameters: ``:_request_http_version`` returning the HTTP version of the current request, and ``:_uuid_new`` which returns a new UUID:
This example registers two new magic parameters: ``:_request_http_version`` returning the HTTP version of the current request, and ``:_uuid_new`` which returns a new UUID. It also registers an `:_asynclookup_key` parameter, demonstrating that these functions can be asynchronous:

.. code-block:: python
Expand All @@ -1337,11 +1337,15 @@ This example registers two new magic parameters: ``:_request_http_version`` retu
raise KeyError
async def asynclookup(key, request):
return await do_something_async(key)
@hookimpl
def register_magic_parameters(datasette):
return [
("request", request),
("uuid", uuid),
("asynclookup", asynclookup),
]
.. _plugin_hook_forbidden:
Expand Down
4 changes: 4 additions & 0 deletions tests/plugins/my_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,13 @@ def request(key, request):
else:
raise KeyError

async def asyncrequest(key, request):
return key

return [
("request", request),
("uuid", uuid),
("asyncrequest", asyncrequest),
]


Expand Down
7 changes: 7 additions & 0 deletions tests/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,9 @@ def test_hook_register_magic_parameters(restore_working_directory):
"get_uuid": {
"sql": "select :_uuid_new",
},
"asyncrequest": {
"sql": "select :_asyncrequest_key",
},
}
}
}
Expand All @@ -871,6 +874,10 @@ def test_hook_register_magic_parameters(restore_working_directory):
assert 200 == response_get.status
new_uuid = response_get.json[0][":_uuid_new"]
assert 4 == new_uuid.count("-")
# And test the async one
response_async = client.get("/data/asyncrequest.json?_shape=array")
assert 200 == response_async.status
assert response_async.json[0][":_asyncrequest_key"] == "key"


def test_hook_forbidden(restore_working_directory):
Expand Down

0 comments on commit dce7189

Please sign in to comment.