From f45c44ac8f3a9a2182d76c6bda44a06676499e4b Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 3 Jun 2020 06:39:19 -0700 Subject: [PATCH] First unit test for writable canned queries, refs #698 Message also now shows number of affected rows after executing query. --- datasette/views/database.py | 13 +++++++---- tests/fixtures.py | 37 +++++++++++++++++++++++++------ tests/test_canned_write.py | 43 +++++++++++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 11 deletions(-) create mode 100644 tests/test_canned_write.py diff --git a/datasette/views/database.py b/datasette/views/database.py index 5104c95980..2c893ca9b6 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -138,14 +138,18 @@ async def data( if write: if request.method == "POST": params = await request.post_vars() - write_ok = await self.ds.databases[database].execute_write( + cursor = await self.ds.databases[database].execute_write( sql, params, block=True ) - self.ds.add_message(request, "Query executed") - return self.redirect( - request, request.path + self.ds.add_message( + request, + "Query executed, {} row{} affected".format( + cursor.rowcount, "" if cursor.rowcount == 1 else "s" + ), ) + return self.redirect(request, request.path) else: + async def extra_template(): return { "request": request, @@ -156,6 +160,7 @@ async def extra_template(): "success_message": request.args.get("_success") or "", "canned_write": True, } + return ( { "database": database, diff --git a/tests/fixtures.py b/tests/fixtures.py index daff016880..78a54c68c9 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -14,7 +14,7 @@ import tempfile import textwrap import time -from urllib.parse import unquote, quote +from urllib.parse import unquote, quote, urlencode # This temp file is used by one of the plugin config tests @@ -54,10 +54,26 @@ def __init__(self, asgi_app): async def get( self, path, allow_redirects=True, redirect_count=0, method="GET", cookies=None ): - return await self._get(path, allow_redirects, redirect_count, method, cookies) + return await self._request( + path, allow_redirects, redirect_count, method, cookies + ) - async def _get( - self, path, allow_redirects=True, redirect_count=0, method="GET", cookies=None + @async_to_sync + async def post( + self, path, post_data=None, allow_redirects=True, redirect_count=0, cookies=None + ): + return await self._request( + path, allow_redirects, redirect_count, "POST", cookies, post_data + ) + + async def _request( + self, + path, + allow_redirects=True, + redirect_count=0, + method="GET", + cookies=None, + post_data=None, ): query_string = b"" if "?" in path: @@ -83,7 +99,13 @@ async def _get( "headers": headers, } instance = ApplicationCommunicator(self.asgi_app, scope) - await instance.send_input({"type": "http.request"}) + + if post_data: + body = urlencode(post_data, doseq=True).encode("utf-8") + await instance.send_input({"type": "http.request", "body": body}) + else: + await instance.send_input({"type": "http.request"}) + # First message back should be response.start with headers and status messages = [] start = await instance.receive_output(2) @@ -110,7 +132,7 @@ async def _get( redirect_count, self.max_redirects ) location = response.headers["Location"] - return await self._get( + return await self._request( location, allow_redirects=True, redirect_count=redirect_count + 1 ) return response @@ -128,6 +150,7 @@ def make_app_client( inspect_data=None, static_mounts=None, template_dir=None, + metadata=None, ): with tempfile.TemporaryDirectory() as tmpdir: filepath = os.path.join(tmpdir, filename) @@ -161,7 +184,7 @@ def make_app_client( immutables=immutables, memory=memory, cors=cors, - metadata=METADATA, + metadata=metadata or METADATA, plugins_dir=PLUGINS_DIR, config=config, inspect_data=inspect_data, diff --git a/tests/test_canned_write.py b/tests/test_canned_write.py new file mode 100644 index 0000000000..4ffae59324 --- /dev/null +++ b/tests/test_canned_write.py @@ -0,0 +1,43 @@ +import pytest +from .fixtures import make_app_client + + +@pytest.fixture(scope="session") +def canned_write_client(): + for client in make_app_client( + extra_databases={"data.db": "create table names (name text)"}, + metadata={ + "databases": { + "data": { + "queries": { + "add_name": { + "sql": "insert into names (name) values (:name)", + "write": True, + }, + "delete_name": { + "sql": "delete from names where rowid = :rowid", + "write": True, + }, + "update_name": { + "sql": "update names set name = :name where rowid = :rowid", + "params": ["rowid", "name"], + "write": True, + }, + } + } + } + }, + ): + yield client + + +def test_insert(canned_write_client): + response = canned_write_client.post( + "/data/add_name", {"name": "Hello"}, allow_redirects=False + ) + assert 302 == response.status + assert "/data/add_name" == response.headers["Location"] + messages = canned_write_client.ds.unsign( + response.cookies["ds_messages"], "messages" + ) + assert [["Query executed, 1 row affected", 1]] == messages