Skip to content

Commit

Permalink
gh-79579: Improve DML query detection in sqlite3 (#93623)
Browse files Browse the repository at this point in the history
The fix involves using pysqlite_check_remaining_sql(), not only to check
for multiple statements, but now also to strip leading comments and
whitespace from SQL statements, so we can improve DML query detection.

pysqlite_check_remaining_sql() is renamed lstrip_sql(), to more
accurately reflect its function, and hardened to handle more SQL comment
corner cases.
  • Loading branch information
erlend-aasland authored Jun 14, 2022
1 parent e566ce5 commit 4674007
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 84 deletions.
66 changes: 56 additions & 10 deletions Lib/test/test_sqlite3/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,22 +746,44 @@ def test_execute_illegal_sql(self):
with self.assertRaises(sqlite.OperationalError):
self.cu.execute("select asdf")

def test_execute_too_much_sql(self):
self.assertRaisesRegex(sqlite.ProgrammingError,
"You can only execute one statement at a time",
self.cu.execute, "select 5+4; select 4+5")

def test_execute_too_much_sql2(self):
self.cu.execute("select 5+4; -- foo bar")
def test_execute_multiple_statements(self):
msg = "You can only execute one statement at a time"
dataset = (
"select 1; select 2",
"select 1; // c++ comments are not allowed",
"select 1; *not a comment",
"select 1; -*not a comment",
"select 1; /* */ a",
"select 1; /**/a",
"select 1; -",
"select 1; /",
"select 1; -\n- select 2",
"""select 1;
-- comment
select 2
""",
)
for query in dataset:
with self.subTest(query=query):
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
self.cu.execute(query)

def test_execute_too_much_sql3(self):
self.cu.execute("""
def test_execute_with_appended_comments(self):
dataset = (
"select 1; -- foo bar",
"select 1; --",
"select 1; /*", # Unclosed comments ending in \0 are skipped.
"""
select 5+4;
/*
foo
*/
""")
""",
)
for query in dataset:
with self.subTest(query=query):
self.cu.execute(query)

def test_execute_wrong_sql_arg(self):
with self.assertRaises(TypeError):
Expand Down Expand Up @@ -906,6 +928,30 @@ def test_rowcount_update_returning(self):
self.assertEqual(self.cu.fetchone()[0], 1)
self.assertEqual(self.cu.rowcount, 1)

def test_rowcount_prefixed_with_comment(self):
# gh-79579: rowcount is updated even if query is prefixed with comments
self.cu.execute("""
-- foo
insert into test(name) values ('foo'), ('foo')
""")
self.assertEqual(self.cu.rowcount, 2)
self.cu.execute("""
/* -- messy *r /* /* ** *- *--
*/
/* one more */ insert into test(name) values ('messy')
""")
self.assertEqual(self.cu.rowcount, 1)
self.cu.execute("/* bar */ update test set name='bar' where name='foo'")
self.assertEqual(self.cu.rowcount, 3)

def test_rowcount_vaccuum(self):
data = ((1,), (2,), (3,))
self.cu.executemany("insert into test(income) values(?)", data)
self.assertEqual(self.cu.rowcount, 3)
self.cx.commit()
self.cu.execute("vacuum")
self.assertEqual(self.cu.rowcount, -1)

def test_total_changes(self):
self.cu.execute("insert into test(name) values ('foo')")
self.cu.execute("insert into test(name) values ('foo')")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
:mod:`sqlite3` now correctly detects DML queries with leading comments.
Patch by Erlend E. Aasland.
119 changes: 45 additions & 74 deletions Modules/_sqlite/statement.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,7 @@
#include "util.h"

/* prototypes */
static int pysqlite_check_remaining_sql(const char* tail);

typedef enum {
LINECOMMENT_1,
IN_LINECOMMENT,
COMMENTSTART_1,
IN_COMMENT,
COMMENTEND_1,
NORMAL
} parse_remaining_sql_state;
static const char *lstrip_sql(const char *sql);

pysqlite_Statement *
pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
Expand Down Expand Up @@ -73,7 +64,7 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
return NULL;
}

if (pysqlite_check_remaining_sql(tail)) {
if (lstrip_sql(tail) != NULL) {
PyErr_SetString(connection->ProgrammingError,
"You can only execute one statement at a time.");
goto error;
Expand All @@ -82,20 +73,12 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
/* Determine if the statement is a DML statement.
SELECT is the only exception. See #9924. */
int is_dml = 0;
for (const char *p = sql_cstr; *p != 0; p++) {
switch (*p) {
case ' ':
case '\r':
case '\n':
case '\t':
continue;
}

const char *p = lstrip_sql(sql_cstr);
if (p != NULL) {
is_dml = (PyOS_strnicmp(p, "insert", 6) == 0)
|| (PyOS_strnicmp(p, "update", 6) == 0)
|| (PyOS_strnicmp(p, "delete", 6) == 0)
|| (PyOS_strnicmp(p, "replace", 7) == 0);
break;
}

pysqlite_Statement *self = PyObject_GC_New(pysqlite_Statement,
Expand Down Expand Up @@ -139,73 +122,61 @@ stmt_traverse(pysqlite_Statement *self, visitproc visit, void *arg)
}

/*
* Checks if there is anything left in an SQL string after SQLite compiled it.
* This is used to check if somebody tried to execute more than one SQL command
* with one execute()/executemany() command, which the DB-API and we don't
* allow.
* Strip leading whitespace and comments from incoming SQL (null terminated C
* string) and return a pointer to the first non-whitespace, non-comment
* character.
*
* Returns 1 if there is more left than should be. 0 if ok.
* This is used to check if somebody tries to execute more than one SQL query
* with one execute()/executemany() command, which the DB-API don't allow.
*
* It is also used to harden DML query detection.
*/
static int pysqlite_check_remaining_sql(const char* tail)
static inline const char *
lstrip_sql(const char *sql)
{
const char* pos = tail;

parse_remaining_sql_state state = NORMAL;

for (;;) {
// This loop is borrowed from the SQLite source code.
for (const char *pos = sql; *pos; pos++) {
switch (*pos) {
case 0:
return 0;
case '-':
if (state == NORMAL) {
state = LINECOMMENT_1;
} else if (state == LINECOMMENT_1) {
state = IN_LINECOMMENT;
}
break;
case ' ':
case '\t':
break;
case '\f':
case '\n':
case 13:
if (state == IN_LINECOMMENT) {
state = NORMAL;
}
case '\r':
// Skip whitespace.
break;
case '/':
if (state == NORMAL) {
state = COMMENTSTART_1;
} else if (state == COMMENTEND_1) {
state = NORMAL;
} else if (state == COMMENTSTART_1) {
return 1;
case '-':
// Skip line comments.
if (pos[1] == '-') {
pos += 2;
while (pos[0] && pos[0] != '\n') {
pos++;
}
if (pos[0] == '\0') {
return NULL;
}
continue;
}
break;
case '*':
if (state == NORMAL) {
return 1;
} else if (state == LINECOMMENT_1) {
return 1;
} else if (state == COMMENTSTART_1) {
state = IN_COMMENT;
} else if (state == IN_COMMENT) {
state = COMMENTEND_1;
return pos;
case '/':
// Skip C style comments.
if (pos[1] == '*') {
pos += 2;
while (pos[0] && (pos[0] != '*' || pos[1] != '/')) {
pos++;
}
if (pos[0] == '\0') {
return NULL;
}
pos++;
continue;
}
break;
return pos;
default:
if (state == COMMENTEND_1) {
state = IN_COMMENT;
} else if (state == IN_LINECOMMENT) {
} else if (state == IN_COMMENT) {
} else {
return 1;
}
return pos;
}

pos++;
}

return 0;
return NULL;
}

static PyType_Slot stmt_slots[] = {
Expand Down

0 comments on commit 4674007

Please sign in to comment.