From 5e653fdbdb055cd88e53a72c9d75ca294efde42e Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 12 Jun 2024 16:51:07 -0700 Subject: [PATCH] named_parameters(sql) sync function, refs #2354 Also refs #2353 and #2352 --- datasette/utils/__init__.py | 33 ++++++++++++++++++++++++--------- datasette/views/database.py | 6 ++---- docs/internals.rst | 10 +++++----- tests/test_utils.py | 8 ++++++-- 4 files changed, 37 insertions(+), 20 deletions(-) diff --git a/datasette/utils/__init__.py b/datasette/utils/__init__.py index 8075420278..b3b51c5fa7 100644 --- a/datasette/utils/__init__.py +++ b/datasette/utils/__init__.py @@ -1131,23 +1131,38 @@ class StartupError(Exception): pass -_re_named_parameter = re.compile(":([a-zA-Z0-9_]+)") +_single_line_comment_re = re.compile(r"--.*") +_multi_line_comment_re = re.compile(r"/\*.*?\*/", re.DOTALL) +_single_quote_re = re.compile(r"'(?:''|[^'])*'") +_double_quote_re = re.compile(r'"(?:\"\"|[^"])*"') +_named_param_re = re.compile(r":(\w+)") @documented -async def derive_named_parameters(db: "Database", sql: str) -> List[str]: +def named_parameters(sql: str) -> List[str]: """ Given a SQL statement, return a list of named parameters that are used in the statement e.g. for ``select * from foo where id=:id`` this would return ``["id"]`` """ - explain = "explain {}".format(sql.strip().rstrip(";")) - possible_params = _re_named_parameter.findall(sql) - try: - results = await db.execute(explain, {p: None for p in possible_params}) - return [row["p4"].lstrip(":") for row in results if row["opcode"] == "Variable"] - except (sqlite3.DatabaseError, AttributeError): - return possible_params + # Remove single-line comments + sql = _single_line_comment_re.sub("", sql) + # Remove multi-line comments + sql = _multi_line_comment_re.sub("", sql) + # Remove single-quoted strings + sql = _single_quote_re.sub("", sql) + # Remove double-quoted strings + sql = _double_quote_re.sub("", sql) + # Extract parameters from what is left + return _named_param_re.findall(sql) + + +async def derive_named_parameters(db: "Database", sql: str) -> List[str]: + """ + This undocumented but stable method exists for backwards compatibility + with plugins that were using it before it switched to named_parameters() + """ + return named_parameters(sql) def add_cors_headers(headers): diff --git a/datasette/views/database.py b/datasette/views/database.py index 2698a0ebaa..1d76c5e047 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -17,7 +17,7 @@ add_cors_headers, await_me_maybe, call_with_supported_arguments, - derive_named_parameters, + named_parameters as derive_named_parameters, format_bytes, make_slot_function, tilde_decode, @@ -484,9 +484,7 @@ async def get(self, request, datasette): if canned_query and canned_query.get("params"): named_parameters = canned_query["params"] if not named_parameters: - named_parameters = await derive_named_parameters( - datasette.get_database(database), sql - ) + named_parameters = derive_named_parameters(sql) named_parameter_values = { named_parameter: params.get(named_parameter) or "" for named_parameter in named_parameters diff --git a/docs/internals.rst b/docs/internals.rst index 795856594d..38e66a57b7 100644 --- a/docs/internals.rst +++ b/docs/internals.rst @@ -1256,14 +1256,14 @@ Utility function for calling ``await`` on a return value if it is awaitable, oth .. autofunction:: datasette.utils.await_me_maybe -.. _internals_utils_derive_named_parameters: +.. _internals_utils_named_parameters: -derive_named_parameters(db, sql) --------------------------------- +named_parameters(sql) +--------------------- -Derive the list of named parameters referenced in a SQL query, using an ``explain`` query executed against the provided database. +Derive the list of ``:named`` parameters referenced in a SQL query. -.. autofunction:: datasette.utils.derive_named_parameters +.. autofunction:: datasette.utils.named_parameters .. _internals_tilde_encoding: diff --git a/tests/test_utils.py b/tests/test_utils.py index 254b130055..88a4532ad1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -612,10 +612,14 @@ def test_parse_metadata(content, expected): ("select this is invalid :one, :two, :three", ["one", "two", "three"]), ), ) -async def test_derive_named_parameters(sql, expected): +@pytest.mark.parametrize("use_async_version", (False, True)) +async def test_named_parameters(sql, expected, use_async_version): ds = Datasette([], memory=True) db = ds.get_database("_memory") - params = await utils.derive_named_parameters(db, sql) + if use_async_version: + params = await utils.derive_named_parameters(db, sql) + else: + params = utils.named_parameters(sql) assert params == expected