From 2f9038a831a3510d4c9ab39a12d96259b3a55bc7 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 6 Aug 2023 17:02:05 -0700 Subject: [PATCH] Define QueryContext and extract get_tables() method, refs #2127 --- datasette/views/database.py | 238 ++++++++++++++++++++++++++++++------ 1 file changed, 199 insertions(+), 39 deletions(-) diff --git a/datasette/views/database.py b/datasette/views/database.py index 2b7edc7aef..434cdb29ab 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -1,4 +1,5 @@ from asyncinject import Registry +import asyncio import os import hashlib import itertools @@ -11,6 +12,7 @@ from datasette.utils import ( add_cors_headers, await_me_maybe, + call_with_supported_arguments, derive_named_parameters, format_bytes, tilde_decode, @@ -80,33 +82,7 @@ async def database_view(request, datasette): } ) - tables = [] - for table in table_counts: - table_visible, table_private = await datasette.check_visibility( - request.actor, - permissions=[ - ("view-table", (database, table)), - ("view-database", database), - "view-instance", - ], - ) - if not table_visible: - continue - table_columns = await db.table_columns(table) - tables.append( - { - "name": table, - "columns": table_columns, - "primary_keys": await db.primary_keys(table), - "count": table_counts[table], - "hidden": table in hidden_table_names, - "fts_table": await db.fts_table(table), - "foreign_keys": all_foreign_keys[table], - "private": table_private, - } - ) - - tables.sort(key=lambda t: (t["hidden"], t["name"])) + tables = await get_tables(datasette, request, db) canned_queries = [] for query in (await datasette.get_canned_queries(database, request.actor)).values(): query_visible, query_private = await datasette.check_visibility( @@ -175,6 +151,96 @@ async def database_actions(): ) +from dataclasses import dataclass, field + + +@dataclass +class QueryContext: + database: str = field(metadata={"help": "The name of the database being queried"}) + query: dict = field( + metadata={"help": "The SQL query object containing the `sql` string"} + ) + canned_query: str = field( + metadata={"help": "The name of the canned query if this is a canned query"} + ) + private: bool = field( + metadata={"help": "Boolean indicating if this is a private database"} + ) + urls: dict = field( + metadata={"help": "Object containing URL helpers like `database()`"} + ) + canned_write: bool = field( + metadata={"help": "Boolean indicating if this canned query allows writes"} + ) + db_is_immutable: bool = field( + metadata={"help": "Boolean indicating if this database is immutable"} + ) + error: str = field(metadata={"help": "Any query error message"}) + hide_sql: bool = field( + metadata={"help": "Boolean indicating if the SQL should be hidden"} + ) + show_hide_link: str = field( + metadata={"help": "The URL to toggle showing/hiding the SQL"} + ) + show_hide_text: str = field( + metadata={"help": "The text for the show/hide SQL link"} + ) + editable: bool = field( + metadata={"help": "Boolean indicating if the SQL can be edited"} + ) + allow_execute_sql: bool = field( + metadata={"help": "Boolean indicating if custom SQL can be executed"} + ) + tables: list = field(metadata={"help": "List of table objects in the database"}) + named_parameter_values: dict = field( + metadata={"help": "Dictionary of parameter names/values"} + ) + csrftoken: callable = field(metadata={"help": "Function to generate a CSRF token"}) + edit_sql_url: str = field( + metadata={"help": "URL to edit the SQL for a canned query"} + ) + display_rows: list = field(metadata={"help": "List of result rows to display"}) + columns: list = field(metadata={"help": "List of column names"}) + renderers: dict = field(metadata={"help": "Dictionary of renderer name to URL"}) + url_csv: str = field(metadata={"help": "URL for CSV export"}) + metadata: dict = field(metadata={"help": "Metadata about the query/database"}) + + +async def get_tables(datasette, request, db): + tables = [] + database = db.name + table_counts = await db.table_counts(5) + hidden_table_names = set(await db.hidden_table_names()) + all_foreign_keys = await db.get_all_foreign_keys() + + for table in table_counts: + table_visible, table_private = await datasette.check_visibility( + request.actor, + permissions=[ + ("view-table", (database, table)), + ("view-database", database), + "view-instance", + ], + ) + if not table_visible: + continue + table_columns = await db.table_columns(table) + tables.append( + { + "name": table, + "columns": table_columns, + "primary_keys": await db.primary_keys(table), + "count": table_counts[table], + "hidden": table in hidden_table_names, + "fts_table": await db.fts_table(table), + "foreign_keys": all_foreign_keys[table], + "private": table_private, + } + ) + tables.sort(key=lambda t: (t["hidden"], t["name"])) + return tables + + async def database_download(request, datasette): database = tilde_decode(request.url_vars["database"]) await datasette.ensure_permissions( @@ -233,7 +299,7 @@ async def query_view( sql = params.pop("sql") if "_shape" in params: params.pop("_shape") - + # extras come from original request.args to avoid being flattened extras = request.args.getlist("_extra") @@ -247,10 +313,7 @@ async def query_view( async def fetch_data_for_csv(request, _next=None): results = await db.execute(sql, params, truncate=True) - data = { - "rows": results.rows, - "columns": results.columns() - } + data = {"rows": results.rows, "columns": results.columns} return data, None, None return await stream_csv(datasette, fetch_data_for_csv, request, db.name) @@ -264,8 +327,8 @@ async def fetch_data_for_csv(request, _next=None): rows=rows, sql=sql, query_name=None, - database=resolved.db.name, - table=resolved.table, + database=database, + table=None, request=request, view_name="table", # These will be deprecated in Datasette 1.0: @@ -298,9 +361,7 @@ async def fetch_data_for_csv(request, _next=None): request, datasette.urls.path(path_with_format(request=request, format="json")), ) - data = { - - } + data = {} headers.update( { "Link": '{}; rel="alternate"; type="application/json+datasette"'.format( @@ -310,6 +371,32 @@ async def fetch_data_for_csv(request, _next=None): ) metadata = (datasette.metadata("databases") or {}).get(database, {}) datasette.update_with_inherited_metadata(metadata) + + results = await db.execute(sql, params, truncate=True) + rows = results.rows + columns = results.columns + + renderers = {} + for key, (_, can_render) in datasette.renderers.items(): + it_can_render = call_with_supported_arguments( + can_render, + datasette=datasette, + columns=data.get("columns") or [], + rows=data.get("rows") or [], + sql=data.get("query", {}).get("sql", None), + query_name=data.get("query_name"), + database=database, + table=data.get("table"), + request=request, + # TODO: Fix this + view_name=None, + ) + it_can_render = await await_me_maybe(it_can_render) + if it_can_render: + renderers[key] = datasette.urls.path( + path_with_format(request=request, format=key) + ) + r = Response.html( await datasette.render_template( template, @@ -319,13 +406,20 @@ async def fetch_data_for_csv(request, _next=None): "database_color": lambda _: "#ff0000", "metadata": metadata, "columns": columns, - "display_rows": display_rows, + "display_rows": await display_rows( + datasette, database, request, rows, columns + ), + "renderers": renderers, + "editable": True, + # TODO: permission check + "allow_execute_sql": True, + "tables": await get_tables(datasette, request, db), }, request=request, ), headers=headers, ) - + # dict( # data, # append_querystring=append_querystring, @@ -902,3 +996,69 @@ async def _table_columns(datasette, database_name): for view_name in await db.view_names(): table_columns[view_name] = [] return table_columns + + +async def display_rows(datasette, database, request, rows, columns): + display_rows = [] + truncate_cells = datasette.setting("truncate_cells_html") + for row in rows: + display_row = [] + for column, value in zip(columns, row): + display_value = value + # Let the plugins have a go + # pylint: disable=no-member + plugin_display_value = None + for candidate in pm.hook.render_cell( + row=row, + value=value, + column=column, + table=None, + database=database, + datasette=datasette, + request=request, + ): + candidate = await await_me_maybe(candidate) + if candidate is not None: + plugin_display_value = candidate + break + if plugin_display_value is not None: + display_value = plugin_display_value + else: + if value in ("", None): + display_value = markupsafe.Markup(" ") + elif is_url(str(display_value).strip()): + display_value = markupsafe.Markup( + '{truncated_url}'.format( + url=markupsafe.escape(value.strip()), + truncated_url=markupsafe.escape( + truncate_url(value.strip(), truncate_cells) + ), + ) + ) + elif isinstance(display_value, bytes): + blob_url = path_with_format( + request=request, + format="blob", + extra_qs={ + "_blob_column": column, + "_blob_hash": hashlib.sha256(display_value).hexdigest(), + }, + ) + formatted = format_bytes(len(value)) + display_value = markupsafe.Markup( + '<Binary: {:,} byte{}>'.format( + blob_url, + ' title="{}"'.format(formatted) + if "bytes" not in formatted + else "", + len(value), + "" if len(value) == 1 else "s", + ) + ) + else: + display_value = str(value) + if truncate_cells and len(display_value) > truncate_cells: + display_value = display_value[:truncate_cells] + "\u2026" + display_row.append(display_value) + display_rows.append(display_row) + return display_rows