Skip to content

Commit

Permalink
feat(sqlparse): improve table parsing (#26476)
Browse files Browse the repository at this point in the history
(cherry picked from commit c0b57bd)
  • Loading branch information
betodealmeida authored and michael-s-molina committed Feb 1, 2024
1 parent 6cdaf47 commit 1d9cfda
Show file tree
Hide file tree
Showing 17 changed files with 265 additions and 120 deletions.
15 changes: 12 additions & 3 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ geographiclib==1.52
geopy==2.2.0
# via apache-superset
greenlet==2.0.2
# via shillelagh
# via
# shillelagh
# sqlalchemy
gunicorn==21.2.0
# via apache-superset
hashids==1.3.1
Expand All @@ -155,7 +157,10 @@ idna==3.2
# email-validator
# requests
importlib-metadata==6.6.0
# via apache-superset
# via
# apache-superset
# flask
# shillelagh
importlib-resources==5.12.0
# via limits
isodate==0.6.0
Expand Down Expand Up @@ -327,6 +332,8 @@ sqlalchemy-utils==0.38.3
# via
# apache-superset
# flask-appbuilder
sqlglot==20.8.0
# via apache-superset
sqlparse==0.4.4
# via apache-superset
sshtunnel==0.4.0
Expand Down Expand Up @@ -376,7 +383,9 @@ wtforms-json==0.3.5
xlsxwriter==3.0.7
# via apache-superset
zipp==3.15.0
# via importlib-metadata
# via
# importlib-metadata
# importlib-resources

# The following packages are considered to be unsafe in a requirements file:
# setuptools
6 changes: 2 additions & 4 deletions requirements/testing.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ db-dtypes==1.1.1
# via pandas-gbq
docker==6.1.1
# via -r requirements/testing.in
exceptiongroup==1.1.1
# via pytest
ephem==4.1.4
# via lunarcalendar
flask-testing==0.8.1
# via -r requirements/testing.in
fonttools==4.39.4
Expand Down Expand Up @@ -121,6 +117,8 @@ pyee==9.0.4
# via playwright
pyfakefs==5.2.2
# via -r requirements/testing.in
pyhive[presto]==0.7.0
# via apache-superset
pytest==7.3.1
# via
# -r requirements/testing.in
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def get_git_sha() -> str:
"slack_sdk>=3.19.0, <4",
"sqlalchemy>=1.4, <2",
"sqlalchemy-utils>=0.38.3, <0.39",
"sqlglot>=20,<21",
"sqlparse>=0.4.4, <0.5",
"tabulate>=0.8.9, <0.9",
"typing-extensions>=4, <5",
Expand Down
5 changes: 4 additions & 1 deletion superset/commands/dataset/duplicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ def run(self) -> Model:
table.normalize_columns = self._base_model.normalize_columns
table.always_filter_main_dttm = self._base_model.always_filter_main_dttm
table.is_sqllab_view = True
table.sql = ParsedQuery(self._base_model.sql).stripped()
table.sql = ParsedQuery(
self._base_model.sql,
engine=database.db_engine_spec.engine,
).stripped()
db.session.add(table)
cols = []
for config_ in self._base_model.columns:
Expand Down
5 changes: 4 additions & 1 deletion superset/commands/sql_lab/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,10 @@ def run(
limit = None
else:
sql = self._query.executed_sql
limit = ParsedQuery(sql).limit
limit = ParsedQuery(
sql,
engine=self._query.database.db_engine_spec.engine,
).limit
if limit is not None and self._query.limiting_factor in {
LimitingFactor.QUERY,
LimitingFactor.DROPDOWN,
Expand Down
2 changes: 1 addition & 1 deletion superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1458,7 +1458,7 @@ def get_from_clause(
return self.get_sqla_table(), None

from_sql = self.get_rendered_sql(template_processor)
parsed_query = ParsedQuery(from_sql)
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)
Expand Down
2 changes: 1 addition & 1 deletion superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
sql = dataset.get_template_processor().process_template(
dataset.sql, **dataset.template_params_dict
)
parsed_query = ParsedQuery(sql)
parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine)
if not db_engine_spec.is_readonly_query(parsed_query):
raise SupersetSecurityException(
SupersetError(
Expand Down
12 changes: 6 additions & 6 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,7 @@ def apply_limit_to_sql(
return database.compile_sqla_query(qry)

if cls.limit_method == LimitMethod.FORCE_LIMIT:
parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
sql = parsed_query.set_or_update_query_limit(limit, force=force)

return sql
Expand Down Expand Up @@ -981,7 +981,7 @@ def get_limit_from_sql(cls, sql: str) -> int | None:
:param sql: SQL query
:return: Value of limit clause in query
"""
parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
return parsed_query.limit

@classmethod
Expand All @@ -993,7 +993,7 @@ def set_or_update_query_limit(cls, sql: str, limit: int) -> str:
:param limit: New limit to insert/replace into query
:return: Query with new limit
"""
parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
return parsed_query.set_or_update_query_limit(limit)

@classmethod
Expand Down Expand Up @@ -1490,7 +1490,7 @@ def process_statement(cls, statement: str, database: Database) -> str:
:param database: Database instance
:return: Dictionary with different costs
"""
parsed_query = ParsedQuery(statement)
parsed_query = ParsedQuery(statement, engine=cls.engine)
sql = parsed_query.stripped()
sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"]
mutate_after_split = current_app.config["MUTATE_AFTER_SPLIT"]
Expand Down Expand Up @@ -1525,7 +1525,7 @@ def estimate_query_cost(
"Database does not support cost estimation"
)

parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
statements = parsed_query.get_statements()

costs = []
Expand Down Expand Up @@ -1586,7 +1586,7 @@ def execute( # pylint: disable=unused-argument
:return:
"""
if not cls.allows_sql_comments:
query = sql_parse.strip_comments_from_sql(query)
query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)

if cls.arraysize:
cursor.arraysize = cls.arraysize
Expand Down
2 changes: 1 addition & 1 deletion superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def estimate_query_cost(
if not cls.get_allow_cost_estimate(extra):
raise SupersetException("Database does not support cost estimation")

parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
statements = parsed_query.get_statements()
costs = []
for statement in statements:
Expand Down
2 changes: 1 addition & 1 deletion superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,7 @@ def get_from_clause(
"""

from_sql = self.get_rendered_sql(template_processor)
parsed_query = ParsedQuery(from_sql)
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)
Expand Down
6 changes: 4 additions & 2 deletions superset/models/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def username(self) -> str:

@property
def sql_tables(self) -> list[Table]:
return list(ParsedQuery(self.sql).tables)
return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables)

@property
def columns(self) -> list["TableColumn"]:
Expand Down Expand Up @@ -427,7 +427,9 @@ def url(self) -> str:

@property
def sql_tables(self) -> list[Table]:
return list(ParsedQuery(self.sql).tables)
return list(
ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables
)

@property
def last_run_humanized(self) -> str:
Expand Down
5 changes: 4 additions & 1 deletion superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1877,7 +1877,10 @@ def raise_for_access(
default_schema = database.get_default_schema_for_query(query)
tables = {
Table(table_.table, table_.schema or default_schema)
for table_ in sql_parse.ParsedQuery(query.sql).tables
for table_ in sql_parse.ParsedQuery(
query.sql,
engine=database.db_engine_spec.engine,
).tables
}
elif table:
tables = {table}
Expand Down
11 changes: 8 additions & 3 deletions superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-local
database: Database = query.database
db_engine_spec = database.db_engine_spec

parsed_query = ParsedQuery(sql_statement)
parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine)
if is_feature_enabled("RLS_IN_SQLLAB"):
# There are two ways to insert RLS: either replacing the table with a subquery
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
Expand All @@ -228,7 +228,8 @@ def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-local
database.id,
query.schema,
)
)
),
engine=db_engine_spec.engine,
)

sql = parsed_query.stripped()
Expand Down Expand Up @@ -419,7 +420,11 @@ def execute_sql_statements(
)

# Breaking down into multiple statements
parsed_query = ParsedQuery(rendered_query, strip_comments=True)
parsed_query = ParsedQuery(
rendered_query,
strip_comments=True,
engine=db_engine_spec.engine,
)
if not db_engine_spec.run_multiple_statements_as_one:
statements = parsed_query.get_statements()
logger.info(
Expand Down
Loading

0 comments on commit 1d9cfda

Please sign in to comment.