From eca77d78b528066b2c120d06bd0c247fd463e565 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Mon, 7 Mar 2022 15:38:12 -0800 Subject: [PATCH 1/9] feat: helper functions for RLS --- superset/sql_parse.py | 118 ++++++++++++++++++++++++++++ tests/unit_tests/sql_parse_tests.py | 94 +++++++++++++++++++++- 2 files changed, 210 insertions(+), 2 deletions(-) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index b5b614cf25acd..e9ba133f3c9c9 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -27,8 +27,10 @@ IdentifierList, Parenthesis, remove_quotes, + Statement, Token, TokenList, + Where, ) from sqlparse.tokens import ( CTE, @@ -458,3 +460,119 @@ def validate_filter_clause(clause: str) -> None: ) if open_parens > 0: raise QueryClauseValidationException("Unclosed parenthesis in filter clause") + + +def has_table_query(statement: Statement) -> bool: + """ + Return if a stament as a query reading from a table. + + >>> has_table_query(sqlparse.parse("COUNT(*)")[0]) + False + >>> has_table_query(sqlparse.parse("SELECT * FROM table")[0]) + True + + Note that queries reading from constant values return false: + + >>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0]) + False + + """ + seen_source = False + tokens = statement.tokens[:] + while tokens: + token = tokens.pop(0) + if isinstance(token, TokenList): + tokens.extend(token.tokens) + + if token.ttype == Keyword and token.value.lower() in ("from", "join"): + seen_source = True + elif seen_source and ( + isinstance(token, sqlparse.sql.Identifier) or token.ttype == Keyword + ): + return True + elif seen_source and token.ttype not in (Whitespace, Punctuation): + seen_source = False + + return False + + +class InsertRLSState(str, Enum): + """ + State machine that scans for WHERE clauses. + """ + + SCANNING = "SCANNING" + SEEN_SOURCE = "SEEN_SOURCE" + FOUND_TABLE = "FOUND_TABLE" + + +def insert_rls(original_statement: Statement, table: str, rls: Statement) -> Statement: + """ + Update a statement applying a RLS associated with a given table. + """ + statement = sqlparse.parse(str(original_statement))[0] + + state = InsertRLSState.SCANNING + tokens = statement.tokens[:] + while tokens: + token = tokens.pop(0) + + # Found a source keyword (FROM/JOIN) + if token.ttype == Keyword and token.value.lower() in ("from", "join"): + state = InsertRLSState.SEEN_SOURCE + + # Found identifier/keyword after FROM/JOIN, test for table + elif state == InsertRLSState.SEEN_SOURCE and ( + isinstance(token, Identifier) or token.ttype == Keyword + ): + if token.value == table: + state = InsertRLSState.FOUND_TABLE + + # found table at the end of the statement; append a WHERE clause + if not tokens: + statement.tokens.extend( + [ + Token(Whitespace, " "), + Where( + [Token(Keyword, "WHERE"), Token(Whitespace, " "), rls] + ), + ] + ) + return statement + + # Found WHERE clause, insert RLS if not present + elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where): + if str(rls) not in {str(t) for t in token.tokens}: + token.tokens.extend( + [ + Token(Whitespace, " "), + Token(Keyword, "AND"), + Token(Whitespace, " "), + ] + + rls.tokens + ) + state = InsertRLSState.SCANNING + + # No WHERE clause found, insert one + elif state == InsertRLSState.FOUND_TABLE and token.ttype not in ( + Whitespace, + Punctuation, + ): + token.parent.insert_before( + token, Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls,]), + ) + token.parent.insert_before(token, Token(Whitespace, " ")) + state = InsertRLSState.SCANNING + + # Found nothing, leaving source + elif state == InsertRLSState.SEEN_SOURCE and token.ttype not in ( + Whitespace, + Punctuation, + ): + state = InsertRLSState.SCANNING + + # Add children nodes + elif isinstance(token, TokenList): + tokens.extend(token.tokens) + + return statement diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 9026eab212ac7..d00ea1b0919fd 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name +# pylint: disable=invalid-name, too-many-lines import unittest from typing import Set @@ -25,6 +25,8 @@ from superset.exceptions import QueryClauseValidationException from superset.sql_parse import ( + has_table_query, + insert_rls, ParsedQuery, strip_comments_from_sql, Table, @@ -1111,7 +1113,8 @@ def test_sqlparse_formatting(): """ assert sqlparse.format( - "SELECT extract(HOUR from from_unixtime(hour_ts) AT TIME ZONE 'America/Los_Angeles') from table", + "SELECT extract(HOUR from from_unixtime(hour_ts) " + "AT TIME ZONE 'America/Los_Angeles') from table", reindent=True, ) == ( "SELECT extract(HOUR\n from from_unixtime(hour_ts) " @@ -1189,3 +1192,90 @@ def test_sqlparse_issue_652(): stmt = sqlparse.parse(r"foo = '\' AND bar = 'baz'")[0] assert len(stmt.tokens) == 5 assert str(stmt.tokens[0]) == "foo = '\\'" + + +@pytest.mark.parametrize( + "sql,expected", + [ + ("SELECT * FROM table", True), + ("SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)", True), + ("(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)", True), + ("COUNT(*)", False), + ("SELECT a FROM (SELECT 1 AS a)", False), + ("SELECT a FROM (SELECT 1 AS a) JOIN table", True), + ("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", False), + ("SELECT * FROM other_table", True), + ], +) +def test_has_table_query(sql: str, expected: bool) -> None: + """ + Test if a given statement queries a table. + + This is used to prevent ad-hoc metrics from querying unauthorized tables, bypassing + row-level security. + """ + statement = sqlparse.parse(sql)[0] + assert has_table_query(statement) == expected + + +@pytest.mark.parametrize( + "sql,table,rls,expected", + [ + # append RLS to an existing WHERE clause + ( + "SELECT * FROM other_table WHERE 1=1", + "other_table", + "id=42", + "SELECT * FROM other_table WHERE 1=1 AND id=42", + ), + # "table" is a reserved word; since sqlparse is too aggressive when characterizing + # reserved words we need to support them even when not quoted + ( + "SELECT * FROM table WHERE 1=1", + "table", + "id=42", + "SELECT * FROM table WHERE 1=1 AND id=42", + ), + # RLS applies to a different table + ( + "SELECT * FROM table WHERE 1=1", + "other_table", + "id=42", + "SELECT * FROM table WHERE 1=1", + ), + ( + "SELECT * FROM other_table WHERE 1=1", + "table", + "id=42", + "SELECT * FROM other_table WHERE 1=1", + ), + # insert the WHERE clause if there isn't one + ("SELECT * FROM table", "table", "id=42", "SELECT * FROM table WHERE id=42",), + ( + "SELECT * FROM other_table", + "other_table", + "id=42", + "SELECT * FROM other_table WHERE id=42", + ), + ( + "SELECT * FROM table ORDER BY id", + "table", + "id=42", + "SELECT * FROM table WHERE id=42 ORDER BY id", + ), + # do not add RLS if already present + ( + "SELECT * FROM table WHERE 1=1 AND id=42", + "table", + "id=42", + "SELECT * FROM table WHERE 1=1 AND id=42", + ), + ], +) +def test_insert_rls(sql, table, rls, expected) -> None: + """ + Insert into a statement a given RLS condition associated with a table. + """ + statement = sqlparse.parse(sql)[0] + condition = sqlparse.parse(rls)[0] + assert str(insert_rls(statement, table, condition)).strip() == expected.strip() From 11c2bb3d59409e025dc83db744b3db9e0d1f86d3 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Mon, 7 Mar 2022 17:19:57 -0800 Subject: [PATCH 2/9] Add function to inject RLS --- superset/sql_parse.py | 96 +++++++++++++++++++++-------- tests/unit_tests/sql_parse_tests.py | 43 ++++++++++--- 2 files changed, 105 insertions(+), 34 deletions(-) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index e9ba133f3c9c9..0504966285719 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -464,7 +464,7 @@ def validate_filter_clause(clause: str) -> None: def has_table_query(statement: Statement) -> bool: """ - Return if a stament as a query reading from a table. + Return if a stament has a query reading from a table. >>> has_table_query(sqlparse.parse("COUNT(*)")[0]) False @@ -496,9 +496,27 @@ def has_table_query(statement: Statement) -> bool: return False +def add_table_name(rls: TokenList, table: str) -> None: + """ + Modify a RLS expression ensuring columns are fully qualified. + """ + tokens = rls.tokens[:] + while tokens: + token = tokens.pop(0) + + if isinstance(token, Identifier) and token.get_parent_name() is None: + token.tokens = [ + Token(Name, table), + Token(Punctuation, "."), + Token(Name, token.get_name()), + ] + elif isinstance(token, TokenList): + tokens.extend(token.tokens) + + class InsertRLSState(str, Enum): """ - State machine that scans for WHERE clauses. + State machine that scans for WHERE and ON clauses referencing tables. """ SCANNING = "SCANNING" @@ -506,16 +524,20 @@ class InsertRLSState(str, Enum): FOUND_TABLE = "FOUND_TABLE" -def insert_rls(original_statement: Statement, table: str, rls: Statement) -> Statement: +def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList: """ - Update a statement applying a RLS associated with a given table. + Update a statement inpalce applying an RLS associated with a given table. """ - statement = sqlparse.parse(str(original_statement))[0] + # make sure the identifier has the table name + add_table_name(rls, table) state = InsertRLSState.SCANNING - tokens = statement.tokens[:] - while tokens: - token = tokens.pop(0) + for token in token_list.tokens: + + # Recurse into child token list + if isinstance(token, TokenList): + i = token_list.tokens.index(token) + token_list.tokens[i] = insert_rls(token, table, rls) # Found a source keyword (FROM/JOIN) if token.ttype == Keyword and token.value.lower() in ("from", "join"): @@ -529,8 +551,8 @@ def insert_rls(original_statement: Statement, table: str, rls: Statement) -> Sta state = InsertRLSState.FOUND_TABLE # found table at the end of the statement; append a WHERE clause - if not tokens: - statement.tokens.extend( + if token == token_list[-1]: + token_list.tokens.extend( [ Token(Whitespace, " "), Where( @@ -538,7 +560,7 @@ def insert_rls(original_statement: Statement, table: str, rls: Statement) -> Sta ), ] ) - return statement + return token_list # Found WHERE clause, insert RLS if not present elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where): @@ -553,26 +575,46 @@ def insert_rls(original_statement: Statement, table: str, rls: Statement) -> Sta ) state = InsertRLSState.SCANNING - # No WHERE clause found, insert one - elif state == InsertRLSState.FOUND_TABLE and token.ttype not in ( - Whitespace, - Punctuation, + # Found ON clause, insert RLS if not present + elif ( + state == InsertRLSState.FOUND_TABLE + and token.ttype == Keyword + and token.value.upper() == "ON" ): - token.parent.insert_before( - token, Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls,]), + i = token_list.tokens.index(token) + token.parent.tokens[i + 1 : i + 1] = [ + Token(Whitespace, " "), + rls, + Token(Whitespace, " "), + Token(Keyword, "AND"), + ] + state = InsertRLSState.SCANNING + + # Found table but no WHERE clause found, insert one + elif state == InsertRLSState.FOUND_TABLE and token.ttype != Whitespace: + i = token_list.tokens.index(token) + + # Left pad with space, if needed + if i > 0 and token_list.tokens[i - 1].ttype != Whitespace: + token_list.tokens.insert(i, Token(Whitespace, " ")) + i += 1 + + # Insert predicate + token_list.tokens.insert( + i, Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls]), ) - token.parent.insert_before(token, Token(Whitespace, " ")) + + # Right pad with space, if needed + if ( + i < len(token_list.tokens) - 2 + and token_list.tokens[i + 2] != Whitespace + ): + token_list.tokens.insert(i + 1, Token(Whitespace, " ")) + state = InsertRLSState.SCANNING # Found nothing, leaving source - elif state == InsertRLSState.SEEN_SOURCE and token.ttype not in ( - Whitespace, - Punctuation, - ): + elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace: state = InsertRLSState.SCANNING - # Add children nodes - elif isinstance(token, TokenList): - tokens.extend(token.tokens) - - return statement + return token_list diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index d00ea1b0919fd..0894d350431af 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -1226,7 +1226,7 @@ def test_has_table_query(sql: str, expected: bool) -> None: "SELECT * FROM other_table WHERE 1=1", "other_table", "id=42", - "SELECT * FROM other_table WHERE 1=1 AND id=42", + "SELECT * FROM other_table WHERE 1=1 AND other_table.id=42", ), # "table" is a reserved word; since sqlparse is too aggressive when characterizing # reserved words we need to support them even when not quoted @@ -1234,7 +1234,7 @@ def test_has_table_query(sql: str, expected: bool) -> None: "SELECT * FROM table WHERE 1=1", "table", "id=42", - "SELECT * FROM table WHERE 1=1 AND id=42", + "SELECT * FROM table WHERE 1=1 AND table.id=42", ), # RLS applies to a different table ( @@ -1250,25 +1250,54 @@ def test_has_table_query(sql: str, expected: bool) -> None: "SELECT * FROM other_table WHERE 1=1", ), # insert the WHERE clause if there isn't one - ("SELECT * FROM table", "table", "id=42", "SELECT * FROM table WHERE id=42",), + ( + "SELECT * FROM table", + "table", + "id=42", + "SELECT * FROM table WHERE table.id=42", + ), ( "SELECT * FROM other_table", "other_table", "id=42", - "SELECT * FROM other_table WHERE id=42", + "SELECT * FROM other_table WHERE other_table.id=42", ), ( "SELECT * FROM table ORDER BY id", "table", "id=42", - "SELECT * FROM table WHERE id=42 ORDER BY id", + "SELECT * FROM table WHERE table.id=42 ORDER BY id", ), - # do not add RLS if already present + # do not add RLS if already present... ( - "SELECT * FROM table WHERE 1=1 AND id=42", + "SELECT * FROM table WHERE 1=1 AND table.id=42", "table", "id=42", + "SELECT * FROM table WHERE 1=1 AND table.id=42", + ), + # ...but when in doubt add it + ( "SELECT * FROM table WHERE 1=1 AND id=42", + "table", + "id=42", + "SELECT * FROM table WHERE 1=1 AND id=42 AND table.id=42", + ), + # test with joins + ( + "SELECT * FROM table JOIN other_table ON table.id = other_table.id", + "other_table", + "id=42", + ( + "SELECT * FROM table JOIN other_table ON other_table.id=42 " + "AND table.id = other_table.id" + ), + ), + # test with inner selects + ( + "SELECT * FROM (SELECT * FROM other_table)", + "other_table", + "id=42", + "SELECT * FROM (SELECT * FROM other_table WHERE other_table.id=42)", ), ], ) From 6942b9ec536b8282cfe8d2e0e7d7103f20ab7633 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Mon, 7 Mar 2022 19:54:51 -0800 Subject: [PATCH 3/9] Add UNION tests --- tests/unit_tests/sql_parse_tests.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 0894d350431af..d8aa48222e84c 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -1299,6 +1299,22 @@ def test_has_table_query(sql: str, expected: bool) -> None: "id=42", "SELECT * FROM (SELECT * FROM other_table WHERE other_table.id=42)", ), + # union + ( + "SELECT * FROM table UNION ALL SELECT * FROM other_table", + "table", + "id=42", + "SELECT * FROM table WHERE table.id=42 UNION ALL SELECT * FROM other_table", + ), + ( + "SELECT * FROM table UNION ALL SELECT * FROM other_table", + "other_table", + "id=42", + ( + "SELECT * FROM table UNION ALL " + "SELECT * FROM other_table WHERE other_table.id=42" + ), + ), ], ) def test_insert_rls(sql, table, rls, expected) -> None: From 746a919fd3a376c048c68bc643b9685009626e4b Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Tue, 8 Mar 2022 11:10:06 -0800 Subject: [PATCH 4/9] Add tests for schema --- superset/sql_parse.py | 19 ++++++++++++++++++- tests/unit_tests/sql_parse_tests.py | 19 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 0504966285719..82cdfbebe2c09 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -524,6 +524,23 @@ class InsertRLSState(str, Enum): FOUND_TABLE = "FOUND_TABLE" +def matches_table_name(token: Token, table: str) -> bool: + """ + Return the name of a table. + + A table should be represented as an identifier, but due to sqlparse's aggressive list + of keywords (spanning multiple dialects) often it gets classified as a keyword. + """ + candidate = token.value + + # match from right to left, splitting on the period, eg, schema.table == table + for left, right in zip(candidate.split(".")[::-1], table.split(".")[::-1]): + if left != right: + return False + + return True + + def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList: """ Update a statement inpalce applying an RLS associated with a given table. @@ -547,7 +564,7 @@ def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList: elif state == InsertRLSState.SEEN_SOURCE and ( isinstance(token, Identifier) or token.ttype == Keyword ): - if token.value == table: + if matches_table_name(token, table): state = InsertRLSState.FOUND_TABLE # found table at the end of the statement; append a WHERE clause diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index d8aa48222e84c..174c809d7f5dc 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -1315,6 +1315,25 @@ def test_has_table_query(sql: str, expected: bool) -> None: "SELECT * FROM other_table WHERE other_table.id=42" ), ), + # fully qualified table names + ( + "SELECT * FROM schema.table_name", + "table_name", + "id=42", + "SELECT * FROM schema.table_name WHERE table_name.id=42", + ), + ( + "SELECT * FROM schema.table_name", + "schema.table_name", + "id=42", + "SELECT * FROM schema.table_name WHERE schema.table_name.id=42", + ), + ( + "SELECT * FROM table_name", + "schema.table_name", + "id=42", + "SELECT * FROM table_name WHERE schema.table_name.id=42", + ), ], ) def test_insert_rls(sql, table, rls, expected) -> None: From 5b32f9b73adf73c65c6a764f5738f4eab2905345 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Tue, 8 Mar 2022 14:21:47 -0800 Subject: [PATCH 5/9] Add more tests; cleanup --- superset/sql_parse.py | 2 +- tests/unit_tests/sql_parse_tests.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 82cdfbebe2c09..d06976739ddee 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -484,7 +484,7 @@ def has_table_query(statement: Statement) -> bool: if isinstance(token, TokenList): tokens.extend(token.tokens) - if token.ttype == Keyword and token.value.lower() in ("from", "join"): + if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]): seen_source = True elif seen_source and ( isinstance(token, sqlparse.sql.Identifier) or token.ttype == Keyword diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 174c809d7f5dc..707cba690e25c 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -25,6 +25,7 @@ from superset.exceptions import QueryClauseValidationException from superset.sql_parse import ( + add_table_name, has_table_query, insert_rls, ParsedQuery, @@ -1343,3 +1344,18 @@ def test_insert_rls(sql, table, rls, expected) -> None: statement = sqlparse.parse(sql)[0] condition = sqlparse.parse(rls)[0] assert str(insert_rls(statement, table, condition)).strip() == expected.strip() + + +@pytest.mark.parametrize( + "rls,table,expected", + [ + ("id=42", "users", "users.id=42"), + ("users.id=42", "users", "users.id=42"), + ("schema.users.id=42", "users", "schema.users.id=42"), + ("false", "users", "false"), + ], +) +def test_add_table_name(rls, table, expected) -> None: + condition = sqlparse.parse(rls)[0] + add_table_name(condition, table) + assert str(condition) == expected From b5bfb9420d1ca49920ff12a8344acf520471d16e Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Tue, 8 Mar 2022 14:28:27 -0800 Subject: [PATCH 6/9] has_table_query via tree traversal --- superset/sql_parse.py | 50 +++++++++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index d06976739ddee..e5a8428611fe5 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -27,7 +27,6 @@ IdentifierList, Parenthesis, remove_quotes, - Statement, Token, TokenList, Where, @@ -462,7 +461,17 @@ def validate_filter_clause(clause: str) -> None: raise QueryClauseValidationException("Unclosed parenthesis in filter clause") -def has_table_query(statement: Statement) -> bool: +class InsertRLSState(str, Enum): + """ + State machine that scans for WHERE and ON clauses referencing tables. + """ + + SCANNING = "SCANNING" + SEEN_SOURCE = "SEEN_SOURCE" + FOUND_TABLE = "FOUND_TABLE" + + +def has_table_query(token_list: TokenList) -> bool: """ Return if a stament has a query reading from a table. @@ -477,21 +486,26 @@ def has_table_query(statement: Statement) -> bool: False """ - seen_source = False - tokens = statement.tokens[:] - while tokens: - token = tokens.pop(0) - if isinstance(token, TokenList): - tokens.extend(token.tokens) + state = InsertRLSState.SCANNING + for token in token_list.tokens: + + # # Recurse into child token list + if isinstance(token, TokenList) and has_table_query(token): + return True + # Found a source keyword (FROM/JOIN) if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]): - seen_source = True - elif seen_source and ( + state = InsertRLSState.SEEN_SOURCE + + # Found identifier/keyword after FROM/JOIN + elif state == InsertRLSState.SEEN_SOURCE and ( isinstance(token, sqlparse.sql.Identifier) or token.ttype == Keyword ): return True - elif seen_source and token.ttype not in (Whitespace, Punctuation): - seen_source = False + + # Found nothing, leaving source + elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace: + state = InsertRLSState.SCANNING return False @@ -514,16 +528,6 @@ def add_table_name(rls: TokenList, table: str) -> None: tokens.extend(token.tokens) -class InsertRLSState(str, Enum): - """ - State machine that scans for WHERE and ON clauses referencing tables. - """ - - SCANNING = "SCANNING" - SEEN_SOURCE = "SEEN_SOURCE" - FOUND_TABLE = "FOUND_TABLE" - - def matches_table_name(token: Token, table: str) -> bool: """ Return the name of a table. @@ -557,7 +561,7 @@ def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList: token_list.tokens[i] = insert_rls(token, table, rls) # Found a source keyword (FROM/JOIN) - if token.ttype == Keyword and token.value.lower() in ("from", "join"): + if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]): state = InsertRLSState.SEEN_SOURCE # Found identifier/keyword after FROM/JOIN, test for table From b21a19fe6464cbb47c7f5f7afda1cf33d7d8acff Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Thu, 10 Mar 2022 09:26:11 -0800 Subject: [PATCH 7/9] Wrap existing predicate in parenthesis --- superset/sql_parse.py | 84 ++++++++++++++++--------- tests/unit_tests/sql_parse_tests.py | 98 ++++++++++++++++++++++------- 2 files changed, 133 insertions(+), 49 deletions(-) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index e5a8428611fe5..48ca5938a4e90 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -23,6 +23,7 @@ import sqlparse from sqlparse.sql import ( + Comparison, Identifier, IdentifierList, Parenthesis, @@ -530,10 +531,13 @@ def add_table_name(rls: TokenList, table: str) -> None: def matches_table_name(token: Token, table: str) -> bool: """ - Return the name of a table. + Returns if the token represents a reference to the table. - A table should be represented as an identifier, but due to sqlparse's aggressive list - of keywords (spanning multiple dialects) often it gets classified as a keyword. + Tables can be fully qualified with periods. + + Note that in theory a table should be represented as an identifier, but due to + sqlparse's aggressive list of keywords (spanning multiple dialects) often it gets + classified as a keyword. """ candidate = token.value @@ -571,44 +575,59 @@ def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList: if matches_table_name(token, table): state = InsertRLSState.FOUND_TABLE - # found table at the end of the statement; append a WHERE clause - if token == token_list[-1]: - token_list.tokens.extend( - [ - Token(Whitespace, " "), - Where( - [Token(Keyword, "WHERE"), Token(Whitespace, " "), rls] - ), - ] - ) - return token_list - - # Found WHERE clause, insert RLS if not present + # Found WHERE clause, insert RLS. Note that we insert it even it already exists, + # to be on the safe side: it could be present in a clause like `1=1 OR RLS`. elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where): - if str(rls) not in {str(t) for t in token.tokens}: - token.tokens.extend( - [ - Token(Whitespace, " "), - Token(Keyword, "AND"), - Token(Whitespace, " "), - ] - + rls.tokens - ) + if token.tokens[1].ttype != Whitespace: + token.tokens.insert(1, Token(Whitespace, " ")) + token.tokens.insert(2, Token(Punctuation, "(")) + token.tokens.extend( + [ + Token(Punctuation, ")"), + Token(Whitespace, " "), + Token(Keyword, "AND"), + Token(Whitespace, " "), + ] + + rls.tokens + ) state = InsertRLSState.SCANNING - # Found ON clause, insert RLS if not present + # Found ON clause, insert RLS elif ( state == InsertRLSState.FOUND_TABLE and token.ttype == Keyword and token.value.upper() == "ON" ): - i = token_list.tokens.index(token) - token.parent.tokens[i + 1 : i + 1] = [ + tokens = [ Token(Whitespace, " "), rls, Token(Whitespace, " "), Token(Keyword, "AND"), + Token(Whitespace, " "), + Token(Punctuation, "("), + ] + i = token_list.tokens.index(token) + token.parent.tokens[i + 1 : i + 1] = tokens + i += len(tokens) + 2 + + # close parenthesis after last existing comparison + j = 0 + for j, sibling in enumerate(token_list.tokens[i:]): + if ( + sibling.ttype == Keyword + and not imt( + sibling, m=[(Keyword, "AND"), (Keyword, "OR"), (Keyword, "NOT")] + ) + or isinstance(sibling, Where) + ): + j -= 1 + break + token.parent.tokens[i + j + 1 : i + j + 1] = [ + Token(Whitespace, " "), + Token(Punctuation, ")"), + Token(Whitespace, " "), ] + state = InsertRLSState.SCANNING # Found table but no WHERE clause found, insert one @@ -638,4 +657,13 @@ def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList: elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace: state = InsertRLSState.SCANNING + # found table at the end of the statement; append a WHERE clause + if state == InsertRLSState.FOUND_TABLE: + if token_list.tokens[-1].ttype != Whitespace: + token_list.tokens.append(Token(Whitespace, " ")) + + token_list.tokens.append( + Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls]) + ) + return token_list diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 707cba690e25c..208d532e7e475 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -1222,22 +1222,32 @@ def test_has_table_query(sql: str, expected: bool) -> None: @pytest.mark.parametrize( "sql,table,rls,expected", [ - # append RLS to an existing WHERE clause + # Basic test: append RLS (some_table.id=42) to an existing WHERE clause. ( - "SELECT * FROM other_table WHERE 1=1", - "other_table", + "SELECT * FROM some_table WHERE 1=1", + "some_table", "id=42", - "SELECT * FROM other_table WHERE 1=1 AND other_table.id=42", + "SELECT * FROM some_table WHERE (1=1) AND some_table.id=42", ), - # "table" is a reserved word; since sqlparse is too aggressive when characterizing - # reserved words we need to support them even when not quoted + # Any existing predicates MUST to be wrapped in parenthesis because AND has higher + # precedence than OR. If the RLS it `1=0` and we didn't add parenthesis a user + # could bypass it by crafting a query with `WHERE TRUE OR FALSE`, since + # `WHERE TRUE OR FALSE AND 1=0` evaluates to `WHERE TRUE OR (FALSE AND 1=0)`. + ( + "SELECT * FROM some_table WHERE TRUE OR FALSE", + "some_table", + "1=0", + "SELECT * FROM some_table WHERE (TRUE OR FALSE) AND 1=0", + ), + # Here "table" is a reserved word; since sqlparse is too aggressive when + # characterizing reserved words we need to support them even when not quoted. ( "SELECT * FROM table WHERE 1=1", "table", "id=42", - "SELECT * FROM table WHERE 1=1 AND table.id=42", + "SELECT * FROM table WHERE (1=1) AND table.id=42", ), - # RLS applies to a different table + # RLS is only applied to queries reading from the associated table. ( "SELECT * FROM table WHERE 1=1", "other_table", @@ -1250,7 +1260,7 @@ def test_has_table_query(sql: str, expected: bool) -> None: "id=42", "SELECT * FROM other_table WHERE 1=1", ), - # insert the WHERE clause if there isn't one + # If there's no pre-existing WHERE clause we create one. ( "SELECT * FROM table", "table", @@ -1258,10 +1268,10 @@ def test_has_table_query(sql: str, expected: bool) -> None: "SELECT * FROM table WHERE table.id=42", ), ( - "SELECT * FROM other_table", - "other_table", + "SELECT * FROM some_table", + "some_table", "id=42", - "SELECT * FROM other_table WHERE other_table.id=42", + "SELECT * FROM some_table WHERE some_table.id=42", ), ( "SELECT * FROM table ORDER BY id", @@ -1269,38 +1279,82 @@ def test_has_table_query(sql: str, expected: bool) -> None: "id=42", "SELECT * FROM table WHERE table.id=42 ORDER BY id", ), - # do not add RLS if already present... + ( + "SELECT * FROM some_table;", + "some_table", + "id=42", + "SELECT * FROM some_table WHERE some_table.id=42;", + ), + ( + "SELECT * FROM some_table ;", + "some_table", + "id=42", + "SELECT * FROM some_table WHERE some_table.id=42;", + ), + ( + "SELECT * FROM some_table ", + "some_table", + "id=42", + "SELECT * FROM some_table WHERE some_table.id=42", + ), + # We add the RLS even if it's already present, to be conservative. It should have + # no impact on the query, and it's easier than testing if the RLS is already + # present (it could be present in an OR clause, eg). ( "SELECT * FROM table WHERE 1=1 AND table.id=42", "table", "id=42", - "SELECT * FROM table WHERE 1=1 AND table.id=42", + "SELECT * FROM table WHERE (1=1 AND table.id=42) AND table.id=42", + ), + ( + ( + "SELECT * FROM table JOIN other_table ON " + "table.id = other_table.id AND other_table.id=42" + ), + "other_table", + "id=42", + ( + "SELECT * FROM table JOIN other_table ON other_table.id=42 " + "AND ( table.id = other_table.id AND other_table.id=42 )" + ), ), - # ...but when in doubt add it ( "SELECT * FROM table WHERE 1=1 AND id=42", "table", "id=42", - "SELECT * FROM table WHERE 1=1 AND id=42 AND table.id=42", + "SELECT * FROM table WHERE (1=1 AND id=42) AND table.id=42", ), - # test with joins + # For joins we apply the RLS to the ON clause, since it's easier and prevents + # leaking information about number of rows on OUTER JOINs. ( "SELECT * FROM table JOIN other_table ON table.id = other_table.id", "other_table", "id=42", ( "SELECT * FROM table JOIN other_table ON other_table.id=42 " - "AND table.id = other_table.id" + "AND ( table.id = other_table.id )" + ), + ), + ( + ( + "SELECT * FROM table JOIN other_table ON table.id = other_table.id " + "WHERE 1=1" + ), + "other_table", + "id=42", + ( + "SELECT * FROM table JOIN other_table ON other_table.id=42 " + "AND ( table.id = other_table.id ) WHERE 1=1" ), ), - # test with inner selects + # Subqueries also work, as expected. ( "SELECT * FROM (SELECT * FROM other_table)", "other_table", "id=42", "SELECT * FROM (SELECT * FROM other_table WHERE other_table.id=42)", ), - # union + # As well as UNION. ( "SELECT * FROM table UNION ALL SELECT * FROM other_table", "table", @@ -1316,7 +1370,9 @@ def test_has_table_query(sql: str, expected: bool) -> None: "SELECT * FROM other_table WHERE other_table.id=42" ), ), - # fully qualified table names + # When comparing fully qualified table names (eg, schema.table) to simple names + # (eg, table) we are also conservative, assuming the schema is the same, since + # we don't have information on the default schema. ( "SELECT * FROM schema.table_name", "table_name", From e816857c53c9b9e192ec3bd71df8b15a7fe9b005 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Thu, 10 Mar 2022 09:37:23 -0800 Subject: [PATCH 8/9] Clean up logic --- superset/sql_parse.py | 42 +++++++++++------------------ tests/unit_tests/sql_parse_tests.py | 22 +++++++-------- 2 files changed, 26 insertions(+), 38 deletions(-) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 48ca5938a4e90..06d6ed5266128 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -23,7 +23,6 @@ import sqlparse from sqlparse.sql import ( - Comparison, Identifier, IdentifierList, Parenthesis, @@ -578,9 +577,7 @@ def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList: # Found WHERE clause, insert RLS. Note that we insert it even it already exists, # to be on the safe side: it could be present in a clause like `1=1 OR RLS`. elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where): - if token.tokens[1].ttype != Whitespace: - token.tokens.insert(1, Token(Whitespace, " ")) - token.tokens.insert(2, Token(Punctuation, "(")) + token.tokens[1:1] = [Token(Whitespace, " "), Token(Punctuation, "(")] token.tokens.extend( [ Token(Punctuation, ")"), @@ -592,7 +589,9 @@ def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList: ) state = InsertRLSState.SCANNING - # Found ON clause, insert RLS + # Found ON clause, insert RLS. The logic for ON is more complicated than the logic + # for WHERE because in the former the comparisons are siblings, while on the + # latter they are children. elif ( state == InsertRLSState.FOUND_TABLE and token.ttype == Keyword @@ -613,6 +612,7 @@ def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList: # close parenthesis after last existing comparison j = 0 for j, sibling in enumerate(token_list.tokens[i:]): + # scan until we hit a non-comparison keyword (like ORDER BY) or a WHERE if ( sibling.ttype == Keyword and not imt( @@ -633,23 +633,11 @@ def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList: # Found table but no WHERE clause found, insert one elif state == InsertRLSState.FOUND_TABLE and token.ttype != Whitespace: i = token_list.tokens.index(token) - - # Left pad with space, if needed - if i > 0 and token_list.tokens[i - 1].ttype != Whitespace: - token_list.tokens.insert(i, Token(Whitespace, " ")) - i += 1 - - # Insert predicate - token_list.tokens.insert( - i, Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls]), - ) - - # Right pad with space, if needed - if ( - i < len(token_list.tokens) - 2 - and token_list.tokens[i + 2] != Whitespace - ): - token_list.tokens.insert(i + 1, Token(Whitespace, " ")) + token_list.tokens[i:i] = [ + Token(Whitespace, " "), + Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls]), + Token(Whitespace, " "), + ] state = InsertRLSState.SCANNING @@ -659,11 +647,11 @@ def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList: # found table at the end of the statement; append a WHERE clause if state == InsertRLSState.FOUND_TABLE: - if token_list.tokens[-1].ttype != Whitespace: - token_list.tokens.append(Token(Whitespace, " ")) - - token_list.tokens.append( - Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls]) + token_list.tokens.extend( + [ + Token(Whitespace, " "), + Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls]), + ] ) return token_list diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 208d532e7e475..9a67e58ea2afc 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -1227,7 +1227,7 @@ def test_has_table_query(sql: str, expected: bool) -> None: "SELECT * FROM some_table WHERE 1=1", "some_table", "id=42", - "SELECT * FROM some_table WHERE (1=1) AND some_table.id=42", + "SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42", ), # Any existing predicates MUST to be wrapped in parenthesis because AND has higher # precedence than OR. If the RLS it `1=0` and we didn't add parenthesis a user @@ -1237,7 +1237,7 @@ def test_has_table_query(sql: str, expected: bool) -> None: "SELECT * FROM some_table WHERE TRUE OR FALSE", "some_table", "1=0", - "SELECT * FROM some_table WHERE (TRUE OR FALSE) AND 1=0", + "SELECT * FROM some_table WHERE ( TRUE OR FALSE) AND 1=0", ), # Here "table" is a reserved word; since sqlparse is too aggressive when # characterizing reserved words we need to support them even when not quoted. @@ -1245,7 +1245,7 @@ def test_has_table_query(sql: str, expected: bool) -> None: "SELECT * FROM table WHERE 1=1", "table", "id=42", - "SELECT * FROM table WHERE (1=1) AND table.id=42", + "SELECT * FROM table WHERE ( 1=1) AND table.id=42", ), # RLS is only applied to queries reading from the associated table. ( @@ -1277,25 +1277,25 @@ def test_has_table_query(sql: str, expected: bool) -> None: "SELECT * FROM table ORDER BY id", "table", "id=42", - "SELECT * FROM table WHERE table.id=42 ORDER BY id", + "SELECT * FROM table WHERE table.id=42 ORDER BY id", ), ( "SELECT * FROM some_table;", "some_table", "id=42", - "SELECT * FROM some_table WHERE some_table.id=42;", + "SELECT * FROM some_table WHERE some_table.id=42 ;", ), ( "SELECT * FROM some_table ;", "some_table", "id=42", - "SELECT * FROM some_table WHERE some_table.id=42;", + "SELECT * FROM some_table WHERE some_table.id=42 ;", ), ( "SELECT * FROM some_table ", "some_table", "id=42", - "SELECT * FROM some_table WHERE some_table.id=42", + "SELECT * FROM some_table WHERE some_table.id=42", ), # We add the RLS even if it's already present, to be conservative. It should have # no impact on the query, and it's easier than testing if the RLS is already @@ -1304,7 +1304,7 @@ def test_has_table_query(sql: str, expected: bool) -> None: "SELECT * FROM table WHERE 1=1 AND table.id=42", "table", "id=42", - "SELECT * FROM table WHERE (1=1 AND table.id=42) AND table.id=42", + "SELECT * FROM table WHERE ( 1=1 AND table.id=42) AND table.id=42", ), ( ( @@ -1322,7 +1322,7 @@ def test_has_table_query(sql: str, expected: bool) -> None: "SELECT * FROM table WHERE 1=1 AND id=42", "table", "id=42", - "SELECT * FROM table WHERE (1=1 AND id=42) AND table.id=42", + "SELECT * FROM table WHERE ( 1=1 AND id=42) AND table.id=42", ), # For joins we apply the RLS to the ON clause, since it's easier and prevents # leaking information about number of rows on OUTER JOINs. @@ -1352,14 +1352,14 @@ def test_has_table_query(sql: str, expected: bool) -> None: "SELECT * FROM (SELECT * FROM other_table)", "other_table", "id=42", - "SELECT * FROM (SELECT * FROM other_table WHERE other_table.id=42)", + "SELECT * FROM (SELECT * FROM other_table WHERE other_table.id=42 )", ), # As well as UNION. ( "SELECT * FROM table UNION ALL SELECT * FROM other_table", "table", "id=42", - "SELECT * FROM table WHERE table.id=42 UNION ALL SELECT * FROM other_table", + "SELECT * FROM table WHERE table.id=42 UNION ALL SELECT * FROM other_table", ), ( "SELECT * FROM table UNION ALL SELECT * FROM other_table", From 00598fa167f45fe4f336368ee7247d275f60557b Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Fri, 11 Mar 2022 11:39:04 -0800 Subject: [PATCH 9/9] Improve table matching --- superset/sql_parse.py | 15 ++++++++++----- tests/unit_tests/sql_parse_tests.py | 21 +++++++++++++++++++-- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 06d6ed5266128..f5523bab71e8d 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -528,7 +528,7 @@ def add_table_name(rls: TokenList, table: str) -> None: tokens.extend(token.tokens) -def matches_table_name(token: Token, table: str) -> bool: +def matches_table_name(candidate: Token, table: str) -> bool: """ Returns if the token represents a reference to the table. @@ -538,11 +538,16 @@ def matches_table_name(token: Token, table: str) -> bool: sqlparse's aggressive list of keywords (spanning multiple dialects) often it gets classified as a keyword. """ - candidate = token.value + if not isinstance(candidate, Identifier): + candidate = Identifier([Token(Name, candidate.value)]) + + target = sqlparse.parse(table)[0].tokens[0] + if not isinstance(target, Identifier): + target = Identifier([Token(Name, target.value)]) # match from right to left, splitting on the period, eg, schema.table == table - for left, right in zip(candidate.split(".")[::-1], table.split(".")[::-1]): - if left != right: + for left, right in zip(candidate.tokens[::-1], target.tokens[::-1]): + if left.value != right.value: return False return True @@ -550,7 +555,7 @@ def matches_table_name(token: Token, table: str) -> bool: def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList: """ - Update a statement inpalce applying an RLS associated with a given table. + Update a statement inplace applying an RLS associated with a given table. """ # make sure the identifier has the table name add_table_name(rls, table) diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 9a67e58ea2afc..aa811bdef757e 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -28,6 +28,7 @@ add_table_name, has_table_query, insert_rls, + matches_table_name, ParsedQuery, strip_comments_from_sql, Table, @@ -1206,6 +1207,7 @@ def test_sqlparse_issue_652(): ("SELECT a FROM (SELECT 1 AS a) JOIN table", True), ("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", False), ("SELECT * FROM other_table", True), + ("extract(HOUR from from_unixtime(hour_ts)", False), ], ) def test_has_table_query(sql: str, expected: bool) -> None: @@ -1393,7 +1395,7 @@ def test_has_table_query(sql: str, expected: bool) -> None: ), ], ) -def test_insert_rls(sql, table, rls, expected) -> None: +def test_insert_rls(sql: str, table: str, rls: str, expected: str) -> None: """ Insert into a statement a given RLS condition associated with a table. """ @@ -1411,7 +1413,22 @@ def test_insert_rls(sql, table, rls, expected) -> None: ("false", "users", "false"), ], ) -def test_add_table_name(rls, table, expected) -> None: +def test_add_table_name(rls: str, table: str, expected: str) -> None: condition = sqlparse.parse(rls)[0] add_table_name(condition, table) assert str(condition) == expected + + +@pytest.mark.parametrize( + "candidate,table,expected", + [ + ("table", "table", True), + ("schema.table", "table", True), + ("table", "schema.table", True), + ('schema."my table"', '"my table"', True), + ('schema."my.table"', '"my.table"', True), + ], +) +def test_matches_table_name(candidate: str, table: str, expected: bool) -> None: + token = sqlparse.parse(candidate)[0].tokens[0] + assert matches_table_name(token, table) == expected