Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-79579: Improve DML query detection in sqlite3 #93623

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 56 additions & 10 deletions Lib/test/test_sqlite3/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,22 +746,44 @@ def test_execute_illegal_sql(self):
with self.assertRaises(sqlite.OperationalError):
self.cu.execute("select asdf")

def test_execute_too_much_sql(self):
self.assertRaisesRegex(sqlite.ProgrammingError,
"You can only execute one statement at a time",
self.cu.execute, "select 5+4; select 4+5")

def test_execute_too_much_sql2(self):
self.cu.execute("select 5+4; -- foo bar")
def test_execute_multiple_statements(self):
msg = "You can only execute one statement at a time"
dataset = (
"select 1; select 2",
"select 1; // c++ comments are not allowed",
"select 1; *not a comment",
"select 1; -*not a comment",
"select 1; /* */ a",
"select 1; /**/a",
"select 1; -",
"select 1; /",
"select 1; -\n- select 2",
"""select 1;
-- comment
select 2
""",
)
for query in dataset:
with self.subTest(query=query):
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
self.cu.execute(query)

def test_execute_too_much_sql3(self):
self.cu.execute("""
def test_execute_with_appended_comments(self):
dataset = (
"select 1; -- foo bar",
"select 1; --",
"select 1; /*", # Unclosed comments ending in \0 are skipped.
"""
select 5+4;

/*
foo
*/
""")
""",
)
for query in dataset:
with self.subTest(query=query):
self.cu.execute(query)

def test_execute_wrong_sql_arg(self):
with self.assertRaises(TypeError):
Expand Down Expand Up @@ -906,6 +928,30 @@ def test_rowcount_update_returning(self):
self.assertEqual(self.cu.fetchone()[0], 1)
self.assertEqual(self.cu.rowcount, 1)

def test_rowcount_prefixed_with_comment(self):
# gh-79579: rowcount is updated even if query is prefixed with comments
self.cu.execute("""
-- foo
insert into test(name) values ('foo'), ('foo')
""")
self.assertEqual(self.cu.rowcount, 2)
self.cu.execute("""
/* -- messy *r /* /* ** *- *--
*/
/* one more */ insert into test(name) values ('messy')
""")
self.assertEqual(self.cu.rowcount, 1)
self.cu.execute("/* bar */ update test set name='bar' where name='foo'")
self.assertEqual(self.cu.rowcount, 3)

def test_rowcount_vaccuum(self):
data = ((1,), (2,), (3,))
self.cu.executemany("insert into test(income) values(?)", data)
self.assertEqual(self.cu.rowcount, 3)
self.cx.commit()
self.cu.execute("vacuum")
self.assertEqual(self.cu.rowcount, -1)

def test_total_changes(self):
self.cu.execute("insert into test(name) values ('foo')")
self.cu.execute("insert into test(name) values ('foo')")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
:mod:`sqlite3` now correctly detects DML queries with leading comments.
Patch by Erlend E. Aasland.
119 changes: 45 additions & 74 deletions Modules/_sqlite/statement.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,7 @@
#include "util.h"

/* prototypes */
static int pysqlite_check_remaining_sql(const char* tail);

typedef enum {
LINECOMMENT_1,
IN_LINECOMMENT,
COMMENTSTART_1,
IN_COMMENT,
COMMENTEND_1,
NORMAL
} parse_remaining_sql_state;
static const char *lstrip_sql(const char *sql);

pysqlite_Statement *
pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
Expand Down Expand Up @@ -73,7 +64,7 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
return NULL;
}

if (pysqlite_check_remaining_sql(tail)) {
if (lstrip_sql(tail) != NULL) {
PyErr_SetString(connection->ProgrammingError,
"You can only execute one statement at a time.");
goto error;
Expand All @@ -82,20 +73,12 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
/* Determine if the statement is a DML statement.
SELECT is the only exception. See #9924. */
int is_dml = 0;
for (const char *p = sql_cstr; *p != 0; p++) {
switch (*p) {
case ' ':
case '\r':
case '\n':
case '\t':
continue;
}

const char *p = lstrip_sql(sql_cstr);
if (p != NULL) {
is_dml = (PyOS_strnicmp(p, "insert", 6) == 0)
|| (PyOS_strnicmp(p, "update", 6) == 0)
|| (PyOS_strnicmp(p, "delete", 6) == 0)
|| (PyOS_strnicmp(p, "replace", 7) == 0);
erlend-aasland marked this conversation as resolved.
Show resolved Hide resolved
break;
}

pysqlite_Statement *self = PyObject_GC_New(pysqlite_Statement,
Expand Down Expand Up @@ -139,73 +122,61 @@ stmt_traverse(pysqlite_Statement *self, visitproc visit, void *arg)
}

/*
* Checks if there is anything left in an SQL string after SQLite compiled it.
* This is used to check if somebody tried to execute more than one SQL command
* with one execute()/executemany() command, which the DB-API and we don't
* allow.
* Strip leading whitespace and comments from incoming SQL (null terminated C
* string) and return a pointer to the first non-whitespace, non-comment
* character.
*
* Returns 1 if there is more left than should be. 0 if ok.
* This is used to check if somebody tries to execute more than one SQL query
* with one execute()/executemany() command, which the DB-API don't allow.
*
* It is also used to harden DML query detection.
*/
static int pysqlite_check_remaining_sql(const char* tail)
static inline const char *
lstrip_sql(const char *sql)
{
const char* pos = tail;

parse_remaining_sql_state state = NORMAL;

for (;;) {
// This loop is borrowed from the SQLite source code.
for (const char *pos = sql; *pos; pos++) {
switch (*pos) {
case 0:
return 0;
case '-':
if (state == NORMAL) {
state = LINECOMMENT_1;
} else if (state == LINECOMMENT_1) {
state = IN_LINECOMMENT;
}
break;
case ' ':
case '\t':
break;
case '\f':
case '\n':
case 13:
if (state == IN_LINECOMMENT) {
state = NORMAL;
}
case '\r':
// Skip whitespace.
break;
case '/':
if (state == NORMAL) {
state = COMMENTSTART_1;
} else if (state == COMMENTEND_1) {
state = NORMAL;
} else if (state == COMMENTSTART_1) {
return 1;
case '-':
// Skip line comments.
if (pos[1] == '-') {
pos += 2;
while (pos[0] && pos[0] != '\n') {
pos++;
}
if (pos[0] == '\0') {
return NULL;
}
continue;
}
break;
case '*':
if (state == NORMAL) {
return 1;
} else if (state == LINECOMMENT_1) {
return 1;
} else if (state == COMMENTSTART_1) {
state = IN_COMMENT;
} else if (state == IN_COMMENT) {
state = COMMENTEND_1;
return pos;
case '/':
// Skip C style comments.
if (pos[1] == '*') {
pos += 2;
while (pos[0] && (pos[0] != '*' || pos[1] != '/')) {
pos++;
}
if (pos[0] == '\0') {
return NULL;
}
pos++;
continue;
}
break;
return pos;
default:
if (state == COMMENTEND_1) {
state = IN_COMMENT;
} else if (state == IN_LINECOMMENT) {
} else if (state == IN_COMMENT) {
} else {
return 1;
}
return pos;
}

pos++;
}

return 0;
return NULL;
}

static PyType_Slot stmt_slots[] = {
Expand Down