diff --git a/cratedb_sqlparse_py/README.md b/cratedb_sqlparse_py/README.md index 799dab5..86b6c33 100644 --- a/cratedb_sqlparse_py/README.md +++ b/cratedb_sqlparse_py/README.md @@ -27,7 +27,7 @@ query = """ SELECT * FROM SYS.SHARDS; INSERT INTO doc.tbl VALUES (1); """ -statements = sqlparse(query) +statements = sqlparse(query, raise_exception=True) print(len(statements)) # 2 @@ -43,17 +43,28 @@ print(select_query.type) print(select_query.tree) # (statement (query (queryNoWith (queryTerm (querySpec SELECT (selectItem *) FROM (relation (aliasedRelation (relationPrimary (table (qname (ident (unquotedIdent SYS)) . (ident (unquotedIdent (nonReserved SHARDS))))))))))))) -sqlparse('SUUULECT * FROM sys.shards') -# cratedb_sqlparse.parser.parser.ParsingException: line1:0 mismatched input 'SUUULECT' expecting {'SELECT', 'DEALLOCATE', ...} +sqlparse('SEEELECT * FROM sys.shards') +# cratedb_sqlparse.parser.parser.ParsingException: line1:0 mismatched input 'SEEELECT' expecting {'SELECT', 'DEALLOCATE', ...} ``` ## Development ```shell git clone https://github.com/crate/cratedb-sqlparse + cd cratedb-sqlparse/cratedb_sqlparse_py python3 -m venv .venv source .venv/bin/activate pip install --editable='.[develop,generate,release,test]' poe check ``` + +### Run only tests +```shell +poe test +``` + +### Run only one test +```shell +poe test -k test_sqlparse_collects_exceptions_2 +``` \ No newline at end of file diff --git a/cratedb_sqlparse_py/cratedb_sqlparse/parser.py b/cratedb_sqlparse_py/cratedb_sqlparse/parser.py index a272f07..4a8ca58 100644 --- a/cratedb_sqlparse_py/cratedb_sqlparse/parser.py +++ b/cratedb_sqlparse_py/cratedb_sqlparse/parser.py @@ -1,6 +1,7 @@ +import logging from typing import List -from antlr4 import CommonTokenStream, InputStream, Token +from antlr4 import CommonTokenStream, InputStream, RecognitionException, Token from antlr4.error.ErrorListener import ErrorListener from cratedb_sqlparse.generated_parser.SqlBaseLexer import SqlBaseLexer @@ -30,7 +31,51 @@ 
class ParsingException(Exception):
    """
    Syntax error reported by the parser for one statement.

    Keeps the details handed over by the ANTLR error listener, so callers can
    inspect where the error happened and render the original query with the
    offending token underlined.

    :param query: Text of the statement (or the part of it) the error belongs to.
    :param msg: Error message produced by the parser.
    :param offending_token: Token the parser could not handle. Its `line`, `column`,
        `start`, `stop` and `source` attributes are read by the properties below.
    :param e: Underlying parser exception; its class name is used as the
        representation of this exception.
    """

    def __init__(self, *, query: str, msg: str, offending_token: "Token", e: "RecognitionException"):
        # Initialise the base class so that `args` carries the message.
        super().__init__(msg)
        self.message = msg
        self.offending_token = offending_token
        self.e = e
        self.query = query

    @property
    def error_message(self) -> str:
        """Return the error as `<ErrorType>[line <line>:<column> <message>]`."""
        return f"{self!r}[line {self.line}:{self.column} {self.message}]"

    @property
    def original_query_with_error_marked(self) -> str:
        """
        Return the complete input with '^' characters below the offending token.

        The marker line is inserted right after the line that holds the token.
        """
        query = self.offending_token.source[1].strdata
        query_lines = query.split("\n")
        line_index = self.line - 1

        # Place the marker using the token's own column. Searching the line for
        # the token text would underline the first textual occurrence, which is
        # the wrong one whenever the same text appears earlier in the line.
        marker_width = self.offending_token.stop - self.offending_token.start + 1
        marker = " " * self.column + "^" * marker_width

        query_lines[line_index] = query_lines[line_index] + "\n" + marker
        return "\n".join(query_lines)

    @property
    def column(self) -> int:
        """Column of the offending token within its line, as reported by the token."""
        return self.offending_token.column

    @property
    def line(self) -> int:
        """Line of the offending token, as reported by the token."""
        return self.offending_token.line

    def __repr__(self):
        return f"{type(self.e).__qualname__}"

    def __str__(self):
        return repr(self)
def find_suitable_error(statement, errors):
    """
    Attach to `statement` the first collected error that belongs to it.

    An error belongs to the statement when its query text, without a trailing
    ';' and without surrounding whitespace, equals `statement.query`. The
    matched error is stored in `statement.exception` and removed from `errors`,
    so it cannot be assigned to another statement.

    :param statement: Object with a `query` attribute to compare against and an
        `exception` attribute that receives the matched error.
    :param errors: List of errors not yet matched to a statement; every item
        needs a `query` attribute. The list is modified in place.
    """
    for position, error in enumerate(errors):
        # The error's query text can carry the trailing ';' and surrounding
        # whitespace, which `statement.query` does not; normalize before comparing.
        error_query = error.query
        if error_query.endswith(";"):
            error_query = error_query[:-1]

        if error_query.strip() == statement.query:
            statement.exception = error
            # Remove by position so exactly this error is consumed.
            del errors[position]
            # Stop at the first match: when several statements have the same
            # text, each one must get its own error instead of the first
            # statement swallowing all of them.
            return
""" @@ -101,12 +187,42 @@ def sqlparse(query: str) -> List[Statement]: parser = SqlBaseParser(stream) parser.removeErrorListeners() - parser.addErrorListener(ExceptionErrorListener()) + error_listener = ExceptionErrorListener() if raise_exception else ExceptionCollectorListener() + parser.addErrorListener(error_listener) tree = parser.statements() - # At this point, all errors are already raised; it's seasonably safe to assume - # that the statements are valid. - statements = list(filter(lambda children: isinstance(children, SqlBaseParser.StatementContext), tree.children)) - - return [Statement(statement) for statement in statements] + statements_context: list[SqlBaseParser.StatementContext] = list( + filter(lambda children: isinstance(children, SqlBaseParser.StatementContext), tree.children) + ) + + statements = [] + for statement_context in statements_context: + _stmt = Statement(statement_context) + find_suitable_error(_stmt, error_listener.errors) + statements.append(_stmt) + + else: + # We might still have error(s) that we couldn't match with their origin statement, + # this happens when the query is composed of only one keyword, e.g. 'SELCT 1' + # the error.query will be 'SELCT' instead of 'SELCT 1'. + if len(error_listener.errors) == 1: + # This case has an edge case where we hypothetically assign the + # wrong error to a statement, for example: + # SELECT A FROM tbl1; + # SELEC 1; + # This would match both conditionals, this however is protected by + # by https://github.com/crate/cratedb-sqlparse/issues/28, but might + # change in the future. + error = error_listener.errors[0] + for _stmt in statements: + if _stmt.exception is None and error.query in _stmt.query: + _stmt.exception = error + break + + if len(error_listener.errors) > 1: + logging.error( + "Could not match errors to queries, too much ambiguity, open an issue with this error and the query." 
+ ) + + return statements diff --git a/cratedb_sqlparse_py/pyproject.toml b/cratedb_sqlparse_py/pyproject.toml index 1175853..4ec41f0 100644 --- a/cratedb_sqlparse_py/pyproject.toml +++ b/cratedb_sqlparse_py/pyproject.toml @@ -205,7 +205,7 @@ check = [ format = [ { cmd = "ruff format ." }, # Configure Ruff not to auto-fix (remove!) unused variables (F841) and `print` statements (T201). - { cmd = "ruff check --fix --ignore=ERA --ignore=F401 --ignore=F841 --ignore=T20 ." }, + { cmd = "ruff check --fix --ignore=ERA --ignore=F401 --ignore=F841 --ignore=T20 --ignore=E501 ." }, { cmd = "pyproject-fmt --keep-full-version pyproject.toml" }, ] diff --git a/cratedb_sqlparse_py/tests/test_exceptions.py b/cratedb_sqlparse_py/tests/test_exceptions.py new file mode 100644 index 0000000..44d4da3 --- /dev/null +++ b/cratedb_sqlparse_py/tests/test_exceptions.py @@ -0,0 +1,79 @@ +import pytest + + +def test_exception_message(): + from cratedb_sqlparse import sqlparse + + r = sqlparse(""" + SELEC 1; + SELECT A, B, C, D FROM tbl1; + SELECT D, A FROM tbl1 WHERE; + """) + expected_message = "InputMismatchException[line 2:9 mismatched input 'SELEC' expecting {'SELECT', 'DEALLOCATE', 'FETCH', 'END', 'WITH', 'CREATE', 'ALTER', 'KILL', 'CLOSE', 'BEGIN', 'START', 'COMMIT', 'ANALYZE', 'DISCARD', 'EXPLAIN', 'SHOW', 'OPTIMIZE', 'REFRESH', 'RESTORE', 'DROP', 'INSERT', 'VALUES', 'DELETE', 'UPDATE', 'SET', 'RESET', 'COPY', 'GRANT', 'DENY', 'REVOKE', 'DECLARE'}]" # noqa + expected_message_2 = "\n SELEC 1;\n ^^^^^\n SELECT A, B, C, D FROM tbl1;\n SELECT D, A FROM tbl1 WHERE;\n " # noqa + assert r[0].exception.error_message == expected_message + assert r[0].exception.original_query_with_error_marked == expected_message_2 + + +def test_sqlparse_raises_exception(): + from cratedb_sqlparse import ParsingException, sqlparse + + query = "SELCT 2" + + with pytest.raises(ParsingException): + sqlparse(query, raise_exception=True) + + +def test_sqlparse_collects_exception(): + from cratedb_sqlparse import 
sqlparse + + query = "SELCT 2" + + statements = sqlparse(query) + assert statements[0] + + +def test_sqlparse_collects_exceptions(): + from cratedb_sqlparse import sqlparse + + r = sqlparse(""" + SELECT A FROM tbl1 where ; + SELECT 1; + SELECT D, A FROM tbl1 WHERE; + """) + + assert len(r) == 3 + + assert r[0].exception is not None + assert r[1].exception is None + assert r[2].exception is not None + + +def test_sqlparse_collects_exceptions_2(): + from cratedb_sqlparse import sqlparse + + # Different combination of the query to validate + r = sqlparse(""" + SELEC 1; + SELECT A, B, C, D FROM tbl1; + SELECT D, A FROM tbl1 WHERE; + """) + + assert r[0].exception is not None + assert r[1].exception is None + assert r[2].exception is not None + + +def test_sqlparse_collects_exceptions_3(): + from cratedb_sqlparse import sqlparse + + # Different combination of the query to validate + r = sqlparse(""" + SELECT 1; + SELECT A, B, C, D FROM tbl1; + INSERT INTO doc.tbl VALUES (1,2, 'three', ['four']); + """) + + assert r[0].exception is None + assert r[1].exception is None + assert r[2].exception is None diff --git a/cratedb_sqlparse_py/tests/test_lexer.py b/cratedb_sqlparse_py/tests/test_lexer.py index a69588e..9016e7d 100644 --- a/cratedb_sqlparse_py/tests/test_lexer.py +++ b/cratedb_sqlparse_py/tests/test_lexer.py @@ -1,6 +1,3 @@ -import pytest - - def test_sqlparser_one_statement(query=None): from cratedb_sqlparse import sqlparse @@ -44,13 +41,20 @@ def test_sqlparse_dollar_string(): assert r[0].query == query -def test_sqlparse_raises_exception(): - from cratedb_sqlparse import ParsingException, sqlparse +def test_sqlparse_multiquery_edge_case(): + # Test for https://github.com/crate/cratedb-sqlparse/issues/28, + # if this ends up parsing 3 statements, we can change this test, + # it's here so we can programmatically track if the behavior changes. 
+ from cratedb_sqlparse import sqlparse - query = "SALUT MON AMIE" + query = """ + SELECT A FROM tbl1 where ; + SELEC 1; + SELECT D, A FROM tbl1 WHERE; +""" - with pytest.raises(ParsingException): - sqlparse(query) + statements = sqlparse(query) + assert len(statements) == 1 def test_sqlparse_is_case_insensitive():