diff --git a/datasette_write/__init__.py b/datasette_write/__init__.py index 0f7a081..330db68 100644 --- a/datasette_write/__init__.py +++ b/datasette_write/__init__.py @@ -9,16 +9,9 @@ async def write(request, datasette): request.actor, "datasette-write", default=False ): raise Forbidden("Permission denied for datasette-write") - databases = [ - db - for db in datasette.databases.values() - if db.is_mutable and db.name != "_internal" - ] + database_name = request.url_vars["database"] if request.method == "GET": - selected_database = request.args.get("database") or "" - if not selected_database or selected_database == "_internal": - selected_database = databases[0].name - database = datasette.get_database(selected_database) + database = datasette.get_database(database_name) tables = await database.table_names() views = await database.view_names() sql = request.args.get("sql") or "" @@ -30,10 +23,14 @@ async def write(request, datasette): await datasette.render_template( "datasette_write.html", { - "databases": databases, "sql_from_args": sql, +<<<<<<< HEAD "selected_database": selected_database, "parameters": parameters, +======= + "database_name": database_name, + "parameters": await derive_parameters(database, sql), +>>>>>>> main "tables": tables, "views": views, "redirect_to": request.args.get("_redirect_to") @@ -45,12 +42,8 @@ async def write(request, datasette): ) elif request.method == "POST": formdata = await request.post_vars() - database_name = formdata["database"] sql = formdata["sql"] - try: - database = [db for db in databases if db.name == database_name][0] - except IndexError: - return Response.html("Database not found", status_code=404) + database = datasette.get_database(database_name) result = None message = None @@ -101,6 +94,19 @@ async def write(request, datasette): return Response.html("Bad method", status_code=405) +async def write_redirect(request, datasette): + if not await datasette.permission_allowed( + request.actor, "datasette-write", default=False + ): + raise Forbidden("Permission denied for datasette-write") + + db = request.args.get("database") or "" + if not db: + db = datasette.get_database().name + + return Response.redirect(datasette.urls.database(db) + "/-/write") + + async def derive_parameters(db, sql): parameters = await derive_named_parameters(db, sql) return [ @@ -133,7 +139,8 @@ async def write_derive_parameters(datasette, request): @hookimpl def register_routes(): return [ - (r"^/-/write$", write), + (r"^/(?P[^/]+)/-/write$", write), + (r"^/-/write$", write_redirect), (r"^/-/write/derive-parameters$", write_derive_parameters), ] @@ -144,20 +151,6 @@ def permission_allowed(actor, action): return True -@hookimpl -def menu_links(datasette, actor): - async def inner(): - if await datasette.permission_allowed(actor, "datasette-write", default=False): - return [ - { - "href": datasette.urls.path("/-/write"), - "label": "Execute SQL write", - }, - ] - - return inner - - @hookimpl def database_actions(datasette, actor, database): async def inner(): @@ -166,14 +159,7 @@ async def inner(): ): return [ { - "href": datasette.urls.path( - "/-/write?" - + urlencode( - { - "database": database, - } - ) - ), + "href": datasette.urls.database(database) + "/-/write", "label": "Execute SQL write", "description": "Run queries like insert/update/delete against this database", }, diff --git a/datasette_write/templates/datasette_write.html b/datasette_write/templates/datasette_write.html index 7beed99..d51af03 100644 --- a/datasette_write/templates/datasette_write.html +++ b/datasette_write/templates/datasette_write.html @@ -36,14 +36,15 @@ {% endblock %} +{% block crumbs %} +{{ crumbs.nav(request=request, database=database_name) }} +{% endblock %} + {% block content %} -

Write to the database with SQL

+

Write to {{ database_name }} with SQL

-
+ -

@@ -66,7 +67,7 @@

Write to the database with SQL

{% if tables %}

Tables: {% for table in tables %} - {{ table }}{% if not loop.last %}, {% endif %} + {{ table }}{% if not loop.last %}, {% endif %} {% endfor %}

{% endif %} @@ -74,7 +75,7 @@

Write to the database with SQL

{% if views %}

Views: {% for view in views %} - {{ view }}{% if not loop.last %}, {% endif %} + {{ view }}{% if not loop.last %}, {% endif %} {% endfor %}

{% endif %} diff --git a/setup.py b/setup.py index 273f411..c6045de 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from setuptools import setup import os -VERSION = "0.3.1" +VERSION = "0.3.2" def get_long_description(): diff --git a/tests/test_write.py b/tests/test_write.py index c4fe801..14a011e 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -24,14 +24,14 @@ def ds(tmp_path_factory): @pytest.mark.asyncio async def test_permission_denied(ds): - response = await ds.client.get("/-/write") + response = await ds.client.get("/test/-/write") assert 403 == response.status_code @pytest.mark.asyncio async def test_permission_granted_to_root(ds): response = await ds.client.get( - "/-/write", + "/test/-/write", cookies={"ds_actor": ds.sign({"a": {"id": "root"}}, "actor")}, ) assert response.status_code == 200 @@ -40,7 +40,7 @@ async def test_permission_granted_to_root(ds): # Should have database action menu option too: anon_response = (await ds.client.get("/test")).text - fragment = ">Execute SQL write<" + fragment = 'Execute SQL write' assert fragment not in anon_response root_response = ( await ds.client.get( @@ -50,21 +50,10 @@ async def test_permission_granted_to_root(ds): assert fragment in root_response -@pytest.mark.asyncio -@pytest.mark.parametrize("database", ["test", "test2"]) -async def test_select_database(ds, database): - response = await ds.client.get( - "/-/write?database={}".format(database), - cookies={"ds_actor": ds.sign({"a": {"id": "root"}}, "actor")}, - ) - assert response.status_code == 200 - assert ''.format(database) in response.text - - @pytest.mark.asyncio async def test_populate_sql_from_query_string(ds): response = await ds.client.get( - "/-/write?sql=select+1", + "/test/-/write?sql=select+1", cookies={"ds_actor": ds.sign({"a": {"id": "root"}}, "actor")}, ) assert response.status_code == 200 @@ -121,19 +110,18 @@ async def test_populate_sql_from_query_string(ds): async def test_execute_write(ds, database, sql, params, expected_message): # Get csrftoken cookies = {"ds_actor": ds.sign({"a": {"id": "root"}}, "actor")} - response = await ds.client.get("/-/write", cookies=cookies) + response = await ds.client.get("/{}/-/write".format(database), cookies=cookies) assert 200 == response.status_code csrftoken = response.cookies["ds_csrftoken"] cookies["ds_csrftoken"] = csrftoken data = { "sql": sql, "csrftoken": csrftoken, - "database": database, } data.update(params) # write to database response2 = await ds.client.post( - "/-/write", + "/{}/-/write".format(database), data=data, cookies=cookies, )