Skip to content

Commit

Permalink
Define QueryContext and extract get_tables() method, refs #2127
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Aug 7, 2023
1 parent 002289b commit 2f9038a
Showing 1 changed file with 199 additions and 39 deletions.
238 changes: 199 additions & 39 deletions datasette/views/database.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from asyncinject import Registry
import asyncio
import os
import hashlib
import itertools
Expand All @@ -11,6 +12,7 @@
from datasette.utils import (
add_cors_headers,
await_me_maybe,
call_with_supported_arguments,
derive_named_parameters,
format_bytes,
tilde_decode,
Expand Down Expand Up @@ -80,33 +82,7 @@ async def database_view(request, datasette):
}
)

tables = []
for table in table_counts:
table_visible, table_private = await datasette.check_visibility(
request.actor,
permissions=[
("view-table", (database, table)),
("view-database", database),
"view-instance",
],
)
if not table_visible:
continue
table_columns = await db.table_columns(table)
tables.append(
{
"name": table,
"columns": table_columns,
"primary_keys": await db.primary_keys(table),
"count": table_counts[table],
"hidden": table in hidden_table_names,
"fts_table": await db.fts_table(table),
"foreign_keys": all_foreign_keys[table],
"private": table_private,
}
)

tables.sort(key=lambda t: (t["hidden"], t["name"]))
tables = await get_tables(datasette, request, db)
canned_queries = []
for query in (await datasette.get_canned_queries(database, request.actor)).values():
query_visible, query_private = await datasette.check_visibility(
Expand Down Expand Up @@ -175,6 +151,96 @@ async def database_actions():
)


from dataclasses import dataclass, field


@dataclass
class QueryContext:
database: str = field(metadata={"help": "The name of the database being queried"})
query: dict = field(
metadata={"help": "The SQL query object containing the `sql` string"}
)
canned_query: str = field(
metadata={"help": "The name of the canned query if this is a canned query"}
)
private: bool = field(
metadata={"help": "Boolean indicating if this is a private database"}
)
urls: dict = field(
metadata={"help": "Object containing URL helpers like `database()`"}
)
canned_write: bool = field(
metadata={"help": "Boolean indicating if this canned query allows writes"}
)
db_is_immutable: bool = field(
metadata={"help": "Boolean indicating if this database is immutable"}
)
error: str = field(metadata={"help": "Any query error message"})
hide_sql: bool = field(
metadata={"help": "Boolean indicating if the SQL should be hidden"}
)
show_hide_link: str = field(
metadata={"help": "The URL to toggle showing/hiding the SQL"}
)
show_hide_text: str = field(
metadata={"help": "The text for the show/hide SQL link"}
)
editable: bool = field(
metadata={"help": "Boolean indicating if the SQL can be edited"}
)
allow_execute_sql: bool = field(
metadata={"help": "Boolean indicating if custom SQL can be executed"}
)
tables: list = field(metadata={"help": "List of table objects in the database"})
named_parameter_values: dict = field(
metadata={"help": "Dictionary of parameter names/values"}
)
csrftoken: callable = field(metadata={"help": "Function to generate a CSRF token"})
edit_sql_url: str = field(
metadata={"help": "URL to edit the SQL for a canned query"}
)
display_rows: list = field(metadata={"help": "List of result rows to display"})
columns: list = field(metadata={"help": "List of column names"})
renderers: dict = field(metadata={"help": "Dictionary of renderer name to URL"})
url_csv: str = field(metadata={"help": "URL for CSV export"})
metadata: dict = field(metadata={"help": "Metadata about the query/database"})


async def get_tables(datasette, request, db):
tables = []
database = db.name
table_counts = await db.table_counts(5)
hidden_table_names = set(await db.hidden_table_names())
all_foreign_keys = await db.get_all_foreign_keys()

for table in table_counts:
table_visible, table_private = await datasette.check_visibility(
request.actor,
permissions=[
("view-table", (database, table)),
("view-database", database),
"view-instance",
],
)
if not table_visible:
continue
table_columns = await db.table_columns(table)
tables.append(
{
"name": table,
"columns": table_columns,
"primary_keys": await db.primary_keys(table),
"count": table_counts[table],
"hidden": table in hidden_table_names,
"fts_table": await db.fts_table(table),
"foreign_keys": all_foreign_keys[table],
"private": table_private,
}
)
tables.sort(key=lambda t: (t["hidden"], t["name"]))
return tables


async def database_download(request, datasette):
database = tilde_decode(request.url_vars["database"])
await datasette.ensure_permissions(
Expand Down Expand Up @@ -233,7 +299,7 @@ async def query_view(
sql = params.pop("sql")
if "_shape" in params:
params.pop("_shape")

# extras come from original request.args to avoid being flattened
extras = request.args.getlist("_extra")

Expand All @@ -247,10 +313,7 @@ async def query_view(

async def fetch_data_for_csv(request, _next=None):
results = await db.execute(sql, params, truncate=True)
data = {
"rows": results.rows,
"columns": results.columns()
}
data = {"rows": results.rows, "columns": results.columns}
return data, None, None

return await stream_csv(datasette, fetch_data_for_csv, request, db.name)
Expand All @@ -264,8 +327,8 @@ async def fetch_data_for_csv(request, _next=None):
rows=rows,
sql=sql,
query_name=None,
database=resolved.db.name,
table=resolved.table,
database=database,
table=None,
request=request,
view_name="table",
# These will be deprecated in Datasette 1.0:
Expand Down Expand Up @@ -298,9 +361,7 @@ async def fetch_data_for_csv(request, _next=None):
request,
datasette.urls.path(path_with_format(request=request, format="json")),
)
data = {

}
data = {}
headers.update(
{
"Link": '{}; rel="alternate"; type="application/json+datasette"'.format(
Expand All @@ -310,6 +371,32 @@ async def fetch_data_for_csv(request, _next=None):
)
metadata = (datasette.metadata("databases") or {}).get(database, {})
datasette.update_with_inherited_metadata(metadata)

results = await db.execute(sql, params, truncate=True)
rows = results.rows
columns = results.columns

renderers = {}
for key, (_, can_render) in datasette.renderers.items():
it_can_render = call_with_supported_arguments(
can_render,
datasette=datasette,
columns=data.get("columns") or [],
rows=data.get("rows") or [],
sql=data.get("query", {}).get("sql", None),
query_name=data.get("query_name"),
database=database,
table=data.get("table"),
request=request,
# TODO: Fix this
view_name=None,
)
it_can_render = await await_me_maybe(it_can_render)
if it_can_render:
renderers[key] = datasette.urls.path(
path_with_format(request=request, format=key)
)

r = Response.html(
await datasette.render_template(
template,
Expand All @@ -319,13 +406,20 @@ async def fetch_data_for_csv(request, _next=None):
"database_color": lambda _: "#ff0000",
"metadata": metadata,
"columns": columns,
"display_rows": display_rows,
"display_rows": await display_rows(
datasette, database, request, rows, columns
),
"renderers": renderers,
"editable": True,
# TODO: permission check
"allow_execute_sql": True,
"tables": await get_tables(datasette, request, db),
},
request=request,
),
headers=headers,
)

# dict(
# data,
# append_querystring=append_querystring,
Expand Down Expand Up @@ -902,3 +996,69 @@ async def _table_columns(datasette, database_name):
for view_name in await db.view_names():
table_columns[view_name] = []
return table_columns


async def display_rows(datasette, database, request, rows, columns):
display_rows = []
truncate_cells = datasette.setting("truncate_cells_html")
for row in rows:
display_row = []
for column, value in zip(columns, row):
display_value = value
# Let the plugins have a go
# pylint: disable=no-member
plugin_display_value = None
for candidate in pm.hook.render_cell(
row=row,
value=value,
column=column,
table=None,
database=database,
datasette=datasette,
request=request,
):
candidate = await await_me_maybe(candidate)
if candidate is not None:
plugin_display_value = candidate
break
if plugin_display_value is not None:
display_value = plugin_display_value
else:
if value in ("", None):
display_value = markupsafe.Markup(" ")
elif is_url(str(display_value).strip()):
display_value = markupsafe.Markup(
'<a href="{url}">{truncated_url}</a>'.format(
url=markupsafe.escape(value.strip()),
truncated_url=markupsafe.escape(
truncate_url(value.strip(), truncate_cells)
),
)
)
elif isinstance(display_value, bytes):
blob_url = path_with_format(
request=request,
format="blob",
extra_qs={
"_blob_column": column,
"_blob_hash": hashlib.sha256(display_value).hexdigest(),
},
)
formatted = format_bytes(len(value))
display_value = markupsafe.Markup(
'<a class="blob-download" href="{}"{}>&lt;Binary:&nbsp;{:,}&nbsp;byte{}&gt;</a>'.format(
blob_url,
' title="{}"'.format(formatted)
if "bytes" not in formatted
else "",
len(value),
"" if len(value) == 1 else "s",
)
)
else:
display_value = str(value)
if truncate_cells and len(display_value) > truncate_cells:
display_value = display_value[:truncate_cells] + "\u2026"
display_row.append(display_value)
display_rows.append(display_row)
return display_rows

0 comments on commit 2f9038a

Please sign in to comment.