From 3d982595d2ea7f53c51bb9392e237b170793f852 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 1 Oct 2020 15:33:32 -0700 Subject: [PATCH] Completed work on add column (validation etc), closes #4 --- datasette_edit_schema/__init__.py | 33 +++++++++++++++++++---- tests/test_edit_schema.py | 45 ++++++++++++++++++++++++++++++- 2 files changed, 72 insertions(+), 6 deletions(-) diff --git a/datasette_edit_schema/__init__.py b/datasette_edit_schema/__init__.py index fb4931e..35b3c23 100644 --- a/datasette_edit_schema/__init__.py +++ b/datasette_edit_schema/__init__.py @@ -1,5 +1,6 @@ from datasette import hookimpl from datasette.utils.asgi import Response, NotFound, Forbidden +from datasette.utils import sqlite3 from urllib.parse import quote_plus, unquote_plus import sqlite_utils @@ -210,15 +211,37 @@ async def add_column(request, datasette, database, table, formdata): name = formdata["name"] type = formdata["type"] + redirect = Response.redirect( + "/-/edit-schema/{}/{}".format(quote_plus(database.name), quote_plus(table)) + ) + + if not name: + datasette.add_message(request, "Column name is required", datasette.ERROR) + return redirect + + if type.upper() not in REV_TYPES: + datasette.add_message(request, "Invalid type: {}".format(type), datasette.ERROR) + return redirect + def do_add_column(conn): db = sqlite_utils.Database(conn) - db[table].add_column(name, type) + db[table].add_column(name, REV_TYPES[type.upper()]) - await datasette.databases[database.name].execute_write_fn(do_add_column, block=True) + error = None + try: + await datasette.databases[database.name].execute_write_fn( + do_add_column, block=True + ) + except sqlite3.OperationalError as e: + if "duplicate column name" in str(e): + error = "A column called '{}' already exists".format(name) + else: + error = str(e) - return Response.redirect( - "/{}/{}".format(quote_plus(database.name), quote_plus(table)) - ) + if error: + datasette.add_message(request, error, datasette.ERROR) + + return redirect async def rename_table(request, datasette, database, table, formdata): diff --git a/tests/test_edit_schema.py b/tests/test_edit_schema.py index fa429db..3625395 100644 --- a/tests/test_edit_schema.py +++ b/tests/test_edit_schema.py @@ -83,7 +83,7 @@ async def test_delete_table(db_path): @pytest.mark.asyncio @pytest.mark.parametrize( "col_type,expected_type", - [("text", str), ("integer", int), ("float", float), ("blob", bytes)], + [("text", str), ("integer", int), ("real", float), ("blob", bytes)], ) async def test_add_column(db_path, col_type, expected_type): ds = Datasette([db_path]) @@ -110,6 +110,12 @@ async def test_add_column(db_path, col_type, expected_type): cookies=cookies, ) assert 302 == response.status_code + if "ds_messages" in response.cookies: + messages = ds.unsign(response.cookies["ds_messages"], "messages") + # None of these should be errors + assert all(m[1] == Datasette.INFO for m in messages), "Got an error: {}".format( + messages + ) assert { "name": str, "description": str, @@ -117,6 +123,43 @@ async def test_add_column(db_path, col_type, expected_type): } == table.columns_dict +@pytest.mark.asyncio +@pytest.mark.parametrize( + "name,type,expected_error", + [ + ("name", "text", "A column called 'name' already exists"), + ("", "text", "Column name is required"), + ("]]]", "integer", 'unrecognized token: "]"'), + ("name", "blop", "Invalid type: blop"), + ], +) +async def test_add_column_errors(db_path, name, type, expected_error): + ds = Datasette([db_path]) + cookies = {"ds_actor": ds.sign({"a": {"id": "root"}}, "actor")} + async with httpx.AsyncClient(app=ds.app()) as client: + csrftoken = ( + await client.get( + "http://localhost/-/edit-schema/data/creatures", cookies=cookies + ) + ).cookies["ds_csrftoken"] + response = await client.post( + "http://localhost/-/edit-schema/data/creatures", + data={ + "add_column": "1", + "name": name, + "type": type, + "csrftoken": csrftoken, + }, + allow_redirects=False, + cookies=cookies, + ) + response.status_code == 302 + assert response.headers["location"] == "/-/edit-schema/data/creatures" + messages = ds.unsign(response.cookies["ds_messages"], "messages") + assert len(messages) == 1 + assert messages[0][0] == expected_error + + @pytest.mark.asyncio @pytest.mark.parametrize( "post_data,expected_columns_dict,expected_order",