diff --git a/docs/installation.rst b/docs/installation.rst index 0f372bdbed1d3..72fd5b0b40d4e 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -444,8 +444,8 @@ The connection string for Teradata looks like this :: Required environment variables: :: - export ODBCINI=/.../teradata/client/ODBC_64/odbc.ini - export ODBCINST=/.../teradata/client/ODBC_64/odbcinst.ini + export ODBCINI=/.../teradata/client/ODBC_64/odbc.ini + export ODBCINST=/.../teradata/client/ODBC_64/odbcinst.ini See `Teradata SQLAlchemy `_. @@ -816,6 +816,20 @@ in this dictionary are made available for users to use in their SQL. 'my_crazy_macro': lambda x: x*2, } +SQL Lab also includes a live query validation feature with pluggable backends. +You can configure which validation implementation is used with which database +engine by adding a block like the following to your config.py: + +.. code-block:: python + + FEATURE_FLAGS = { + 'SQL_VALIDATORS_BY_ENGINE': { + 'presto': 'PrestoDBSQLValidator', + } + } + +The available validators and names can be found in ``sql_validators/``. + **Scheduling queries** You can optionally allow your users to schedule queries directly in SQL Lab. @@ -972,7 +986,7 @@ Note that the above command will install Superset into ``default`` namespace of Custom OAuth2 configuration --------------------------- -Beyond FAB supported providers (github, twitter, linkedin, google, azure), its easy to connect Superset with other OAuth2 Authorization Server implementations that support "code" authorization. +Beyond FAB supported providers (github, twitter, linkedin, google, azure), it's easy to connect Superset with other OAuth2 Authorization Server implementations that support "code" authorization. The first step: Configure authorization in Superset ``superset_config.py``. @@ -991,10 +1005,10 @@ The first step: Configure authorization in Superset ``superset_config.py``. 
}, 'access_token_method':'POST', # HTTP Method to call access_token_url 'access_token_params':{ # Additional parameters for calls to access_token_url - 'client_id':'myClientId' + 'client_id':'myClientId' }, - 'access_token_headers':{ # Additional headers for calls to access_token_url - 'Authorization': 'Basic Base64EncodedClientIdAndSecret' + 'access_token_headers':{ # Additional headers for calls to access_token_url + 'Authorization': 'Basic Base64EncodedClientIdAndSecret' }, 'base_url':'https://myAuthorizationServer/oauth2AuthorizationServer/', 'access_token_url':'https://myAuthorizationServer/oauth2AuthorizationServer/token', @@ -1002,25 +1015,25 @@ The first step: Configure authorization in Superset ``superset_config.py``. } } ] - + # Will allow user self registration, allowing to create Flask users from Authorized User AUTH_USER_REGISTRATION = True - + # The default user self registration role AUTH_USER_REGISTRATION_ROLE = "Public" - + Second step: Create a `CustomSsoSecurityManager` that extends `SupersetSecurityManager` and overrides `oauth_user_info`: .. code-block:: python - + from superset.security import SupersetSecurityManager - + class CustomSsoSecurityManager(SupersetSecurityManager): def oauth_user_info(self, provider, response=None): logging.debug("Oauth2 provider: {0}.".format(provider)) if provider == 'egaSSO': - # As example, this line request a GET to base_url + '/' + userDetails with Bearer Authentication, + # As example, this line request a GET to base_url + '/' + userDetails with Bearer Authentication, # and expects that authorization server checks the token, and response with user details me = self.appbuilder.sm.oauth_remotes[provider].get('userDetails').data logging.debug("user_data: {0}".format(me)) @@ -1032,7 +1045,6 @@ This file must be located at the same directory than ``superset_config.py`` with Then we can add this two lines to ``superset_config.py``: .. 
code-block:: python - + from custom_sso_security_manager import CustomSsoSecurityManager CUSTOM_SECURITY_MANAGER = CustomSsoSecurityManager - diff --git a/superset/config.py b/superset/config.py index df14b0f2c1e55..5a35f0b64128d 100644 --- a/superset/config.py +++ b/superset/config.py @@ -420,6 +420,9 @@ class CeleryConfig(object): # Timeout duration for SQL Lab synchronous queries SQLLAB_TIMEOUT = 30 +# Timeout duration for SQL Lab query validation +SQLLAB_VALIDATION_TIMEOUT = 10 + # SQLLAB_DEFAULT_DBID SQLLAB_DEFAULT_DBID = None @@ -608,6 +611,10 @@ class CeleryConfig(object): # localtime (in the tz where the superset webserver is running) IS_EPOCH_S_TRULY_UTC = False +# Configure which SQL validator to use for each engine +SQL_VALIDATORS_BY_ENGINE = { + 'presto': 'PrestoDBSQLValidator', +} try: if CONFIG_PATH_ENV_VAR in os.environ: diff --git a/superset/sql_validators/__init__.py b/superset/sql_validators/__init__.py new file mode 100644 index 0000000000000..367aab6186beb --- /dev/null +++ b/superset/sql_validators/__init__.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional + +from . import base # noqa +from . 
import presto_db # noqa +from .base import SQLValidationAnnotation # noqa + + +def get_validator_by_name(name: str) -> Optional[base.BaseSQLValidator]: + return { + 'PrestoDBSQLValidator': presto_db.PrestoDBSQLValidator, + }.get(name) diff --git a/superset/sql_validators/base.py b/superset/sql_validators/base.py new file mode 100644 index 0000000000000..437001bed0888 --- /dev/null +++ b/superset/sql_validators/base.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# pylint: disable=too-few-public-methods + +from typing import ( + Any, + Dict, + List, + Optional, +) + + +class SQLValidationAnnotation: + """Represents a single annotation (error/warning) in an SQL querytext""" + def __init__( + self, + message: str, + line_number: Optional[int], + start_column: Optional[int], + end_column: Optional[int], + ): + self.message = message + self.line_number = line_number + self.start_column = start_column + self.end_column = end_column + + def to_dict(self) -> Dict: + """Return a dictionary representation of this annotation""" + return { + 'line_number': self.line_number, + 'start_column': self.start_column, + 'end_column': self.end_column, + 'message': self.message, + } + + +class BaseSQLValidator: + """BaseSQLValidator defines the interface for checking that a given sql + query is valid for a given database engine.""" + + name = 'BaseSQLValidator' + + @classmethod + def validate( + cls, + sql: str, + schema: str, + database: Any, + ) -> List[SQLValidationAnnotation]: + """Check that the given SQL querystring is valid for the given engine""" + raise NotImplementedError diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py new file mode 100644 index 0000000000000..87c2d8efeb805 --- /dev/null +++ b/superset/sql_validators/presto_db.py @@ -0,0 +1,186 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from contextlib import closing +import logging +import time +from typing import ( + Any, + Dict, + List, + Optional, +) + +from flask import g +from pyhive.exc import DatabaseError + +from superset import app, security_manager +from superset.sql_parse import ParsedQuery +from superset.sql_validators.base import ( + BaseSQLValidator, + SQLValidationAnnotation, +) +from superset.utils.core import sources + +MAX_ERROR_ROWS = 10 + +config = app.config + + +class PrestoSQLValidationError(Exception): + """Error in the process of asking Presto to validate SQL querytext""" + + +class PrestoDBSQLValidator(BaseSQLValidator): + """Validate SQL queries using Presto's built-in EXPLAIN subtype""" + + name = 'PrestoDBSQLValidator' + + @classmethod + def validate_statement( + cls, + statement, + database, + cursor, + user_name, + ) -> Optional[SQLValidationAnnotation]: + # pylint: disable=too-many-locals + db_engine_spec = database.db_engine_spec + parsed_query = ParsedQuery(statement) + sql = parsed_query.stripped() + + # Hook to allow environment-specific mutation (usually comments) to the SQL + # pylint: disable=invalid-name + SQL_QUERY_MUTATOR = config.get('SQL_QUERY_MUTATOR') + if SQL_QUERY_MUTATOR: + sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database) + + # Transform the final statement to an explain call before sending it on + # to presto to validate + sql = f'EXPLAIN (TYPE VALIDATE) {sql}' + + # Invoke the query against presto. 
NB this deliberately doesn't use the + # engine spec's handle_cursor implementation since we don't record + # these EXPLAIN queries done in validation as proper Query objects + # in the superset ORM. + try: + db_engine_spec.execute(cursor, sql) + polled = cursor.poll() + while polled: + logging.info('polling presto for validation progress') + stats = polled.get('stats', {}) + if stats: + state = stats.get('state') + if state == 'FINISHED': + break + time.sleep(0.2) + polled = cursor.poll() + db_engine_spec.fetch_data(cursor, MAX_ERROR_ROWS) + return None + except DatabaseError as db_error: + # The pyhive presto client yields EXPLAIN (TYPE VALIDATE) responses + # as though they were normal queries. In other words, it doesn't + # know that errors here are not exceptional. To map this back to + # ordinary control flow, we have to trap the category of exception + # raised by the underlying client, match the exception arguments + # pyhive provides against the shape of dictionary for a presto query + # invalid error, and restructure that error as an annotation we can + # return up. + + # Confirm the first element in the DatabaseError constructor is a + # dictionary with error information. This is currently provided by + # the pyhive client, but may break if their interface changes when + # we update at some point in the future. + if not db_error.args or not isinstance(db_error.args[0], dict): + raise PrestoSQLValidationError( + 'The pyhive presto client returned an unhandled ' + 'database error.', + ) from db_error + error_args: Dict[str, Any] = db_error.args[0] + + # Confirm the two fields we need to be able to present an annotation + # are present in the error response -- a message, and a location. 
+ if 'message' not in error_args: + raise PrestoSQLValidationError( + 'The pyhive presto client did not report an error message', + ) from db_error + if 'errorLocation' not in error_args: + raise PrestoSQLValidationError( + 'The pyhive presto client did not report an error location', + ) from db_error + + # Pylint is confused about the type of error_args, despite the hints + # and checks above. + # pylint: disable=invalid-sequence-index + message = error_args['message'] + err_loc = error_args['errorLocation'] + line_number = err_loc.get('lineNumber', None) + start_column = err_loc.get('columnNumber', None) + end_column = err_loc.get('columnNumber', None) + + return SQLValidationAnnotation( + message=message, + line_number=line_number, + start_column=start_column, + end_column=end_column, + ) + except Exception as e: + logging.exception(f'Unexpected error running validation query: {e}') + raise e + + @classmethod + def validate( + cls, + sql: str, + schema: str, + database: Any, + ) -> List[SQLValidationAnnotation]: + """ + Presto supports query-validation queries by running them with a + prepended explain. + + For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE + VALIDATE) SELECT 1 FROM default.mytable. 
+        """
+        user_name = g.user.username if g.user else None
+        parsed_query = ParsedQuery(sql)
+        statements = parsed_query.get_statements()
+
+        logging.info(f'Validating {len(statements)} statement(s)')
+        engine = database.get_sqla_engine(
+            schema=schema,
+            nullpool=True,
+            user_name=user_name,
+            source=sources.get('sql_lab', None),
+        )
+        # Sharing a single connection and cursor across the
+        # execution of all statements (if many)
+        annotations: List[SQLValidationAnnotation] = []
+        with closing(engine.raw_connection()) as conn:
+            with closing(conn.cursor()) as cursor:
+                for statement in statements:
+                    annotation = cls.validate_statement(
+                        statement,
+                        database,
+                        cursor,
+                        user_name,
+                    )
+                    if annotation:
+                        annotations.append(annotation)
+        logging.debug(f'Validation found {len(annotations)} error(s)')
+
+        return annotations
diff --git a/superset/views/core.py b/superset/views/core.py
index e22acb7410b63..eb25cd0d12dda 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -44,7 +44,7 @@
 from werkzeug.utils import secure_filename
 
 from superset import (
-    app, appbuilder, cache, conf, db, results_backend,
+    app, appbuilder, cache, conf, db, get_feature_flags, results_backend,
     security_manager, sql_lab, viz)
 from superset.connectors.connector_registry import ConnectorRegistry
 from superset.connectors.sqla.models import AnnotationDatasource, SqlaTable
@@ -56,6 +56,7 @@
 from superset.models.sql_lab import Query
 from superset.models.user_attributes import UserAttribute
 from superset.sql_parse import ParsedQuery
+from superset.sql_validators import get_validator_by_name
 from superset.utils import core as utils
 from superset.utils import dashboard_import_export
 from superset.utils.dates import now_as_float
@@ -2516,6 +2517,72 @@ def stop_query(self):
             pass
         return self.json_response('OK')
 
+    @has_access_api
+    @expose('/validate_sql_json/', methods=['POST', 'GET'])
+    @log_this
+    def validate_sql_json(self):
+        """Validates that arbitrary sql is
acceptable for the given database.
+        Returns a list of error/warning annotations as json.
+        """
+        sql = request.form.get('sql')
+        database_id = request.form.get('database_id')
+        schema = request.form.get('schema') or None
+        template_params = json.loads(
+            request.form.get('templateParams') or '{}')
+
+        if len(template_params) > 0:
+            # TODO: factor the Database object out of template rendering
+            #       or provide it as mydb so we can render template params
+            #       without having to also persist a Query ORM object.
+            return json_error_response(
+                'SQL validation does not support template parameters',
+                status=400)
+
+        session = db.session()
+        mydb = session.query(models.Database).filter_by(id=database_id).first()
+        if not mydb:  # unknown id: answer 400, do not fall through
+            return json_error_response(
+                'Database with id {} is missing.'.format(database_id),
+                status=400,
+            )
+
+        spec = mydb.db_engine_spec
+        validators_by_engine = get_feature_flags().get(
+            'SQL_VALIDATORS_BY_ENGINE')
+        if not validators_by_engine or spec.engine not in validators_by_engine:
+            return json_error_response(
+                'no SQL validator is configured for {}'.format(spec.engine),
+                status=400)
+        validator_name = validators_by_engine[spec.engine]
+        validator = get_validator_by_name(validator_name)
+        if not validator:
+            return json_error_response(
+                'No validator named {} found (configured for the {} engine)'
+                .format(validator_name, spec.engine))
+
+        try:
+            timeout = config.get('SQLLAB_VALIDATION_TIMEOUT')
+            timeout_msg = (
+                f'The query exceeded the {timeout} seconds timeout.')
+            with utils.timeout(seconds=timeout,
+                               error_message=timeout_msg):
+                errors = validator.validate(sql, schema, mydb)
+            payload = json.dumps(
+                [err.to_dict() for err in errors],
+                default=utils.pessimistic_json_iso_dttm_ser,
+                ignore_nan=True,
+                encoding=None,
+            )
+            return json_success(payload)
+        except Exception as e:
+            logging.exception(e)
+            msg = _(
+                'Failed to validate your SQL query text. 
Please check that ' + f'you have configured the {validator.name} validator ' + 'correctly and that any services it depends on are up. ' + f'Exception: {e}') + return json_error_response(f'{msg}') + @has_access_api @expose('/sql_json/', methods=['POST', 'GET']) @log_this diff --git a/tests/base_tests.py b/tests/base_tests.py index 8555915fd0fe5..6de082afc0066 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -190,6 +190,21 @@ def run_sql(self, sql, client_id=None, user_name=None, raise_on_error=False, raise Exception('run_sql failed') return resp + def validate_sql(self, sql, client_id=None, user_name=None, + raise_on_error=False): + if user_name: + self.logout() + self.login(username=(user_name if user_name else 'admin')) + dbid = get_main_database(db.session).id + resp = self.get_json_resp( + '/superset/validate_sql_json/', + raise_on_error=False, + data=dict(database_id=dbid, sql=sql, client_id=client_id), + ) + if raise_on_error and 'error' in resp: + raise Exception('validate_sql failed') + return resp + @patch.dict('superset._feature_flags', {'FOO': True}, clear=True) def test_existing_feature_flags(self): self.assertTrue(is_feature_enabled('FOO')) diff --git a/tests/sql_validator_tests.py b/tests/sql_validator_tests.py new file mode 100644 index 0000000000000..0e1310cb71af2 --- /dev/null +++ b/tests/sql_validator_tests.py @@ -0,0 +1,210 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for Sql Lab""" +import unittest +from unittest.mock import ( + MagicMock, + patch, +) + +from pyhive.exc import DatabaseError + +from superset import app +from superset.sql_validators import SQLValidationAnnotation +from superset.sql_validators.base import BaseSQLValidator +from superset.sql_validators.presto_db import ( + PrestoDBSQLValidator, + PrestoSQLValidationError, +) +from .base_tests import SupersetTestCase + +PRESTO_TEST_FEATURE_FLAGS = { + 'SQL_VALIDATORS_BY_ENGINE': { + 'presto': 'PrestoDBSQLValidator', + 'sqlite': 'PrestoDBSQLValidator', + 'postgresql': 'PrestoDBSQLValidator', + 'mysql': 'PrestoDBSQLValidator', + }, +} + + +class SqlValidatorEndpointTests(SupersetTestCase): + """Testing for Sql Lab querytext validation endpoint""" + + def tearDown(self): + self.logout() + + def test_validate_sql_endpoint_noconfig(self): + """Assert that validate_sql_json errors out when no validators are + configured for any db""" + self.login('admin') + + app.config['SQL_VALIDATORS_BY_ENGINE'] = {} + + resp = self.validate_sql( + 'SELECT * FROM ab_user', + client_id='1', + raise_on_error=False, + ) + self.assertIn('error', resp) + self.assertIn('no SQL validator is configured', resp['error']) + + @patch('superset.views.core.get_validator_by_name') + @patch.dict('superset._feature_flags', + PRESTO_TEST_FEATURE_FLAGS, + clear=True) + def test_validate_sql_endpoint_mocked(self, get_validator_by_name): + """Assert that, with a mocked validator, annotations make it back out + from the validate_sql_json endpoint as a list of json dictionaries""" + 
self.login('admin') + + validator = MagicMock() + get_validator_by_name.return_value = validator + validator.validate.return_value = [ + SQLValidationAnnotation( + message="I don't know what I expected, but it wasn't this", + line_number=4, + start_column=12, + end_column=42, + ), + ] + + resp = self.validate_sql( + 'SELECT * FROM somewhere_over_the_rainbow', + client_id='1', + raise_on_error=False, + ) + + self.assertEqual(1, len(resp)) + self.assertIn('expected,', resp[0]['message']) + + @patch('superset.views.core.get_validator_by_name') + @patch.dict('superset._feature_flags', + PRESTO_TEST_FEATURE_FLAGS, + clear=True) + def test_validate_sql_endpoint_failure(self, get_validator_by_name): + """Assert that validate_sql_json errors out when the selected validator + raises an unexpected exception""" + self.login('admin') + + validator = MagicMock() + get_validator_by_name.return_value = validator + validator.validate.side_effect = Exception('Kaboom!') + + resp = self.validate_sql( + 'SELECT * FROM ab_user', + client_id='1', + raise_on_error=False, + ) + self.assertIn('error', resp) + self.assertIn('Kaboom!', resp['error']) + + +class BaseValidatorTests(SupersetTestCase): + """Testing for the base sql validator""" + def setUp(self): + self.validator = BaseSQLValidator + + def test_validator_excepts(self): + with self.assertRaises(NotImplementedError): + self.validator.validate(None, None, None) + + +class PrestoValidatorTests(SupersetTestCase): + """Testing for the prestodb sql validator""" + def setUp(self): + self.validator = PrestoDBSQLValidator + self.database = MagicMock() # noqa + self.database_engine = self.database.get_sqla_engine.return_value + self.database_conn = self.database_engine.raw_connection.return_value + self.database_cursor = self.database_conn.cursor.return_value + self.database_cursor.poll.return_value = None + + def tearDown(self): + self.logout() + + PRESTO_ERROR_TEMPLATE = { + 'errorLocation': { + 'lineNumber': 10, + 'columnNumber': 20, + 
}, + 'message': "your query isn't how I like it", + } + + @patch('superset.sql_validators.presto_db.g') + def test_validator_success(self, flask_g): + flask_g.user.username = 'nobody' + sql = 'SELECT 1 FROM default.notarealtable' + schema = 'default' + + errors = self.validator.validate(sql, schema, self.database) + + self.assertEqual([], errors) + + @patch('superset.sql_validators.presto_db.g') + def test_validator_db_error(self, flask_g): + flask_g.user.username = 'nobody' + sql = 'SELECT 1 FROM default.notarealtable' + schema = 'default' + + fetch_fn = self.database.db_engine_spec.fetch_data + fetch_fn.side_effect = DatabaseError('dummy db error') + + with self.assertRaises(PrestoSQLValidationError): + self.validator.validate(sql, schema, self.database) + + @patch('superset.sql_validators.presto_db.g') + def test_validator_unexpected_error(self, flask_g): + flask_g.user.username = 'nobody' + sql = 'SELECT 1 FROM default.notarealtable' + schema = 'default' + + fetch_fn = self.database.db_engine_spec.fetch_data + fetch_fn.side_effect = Exception('a mysterious failure') + + with self.assertRaises(Exception): + self.validator.validate(sql, schema, self.database) + + @patch('superset.sql_validators.presto_db.g') + def test_validator_query_error(self, flask_g): + flask_g.user.username = 'nobody' + sql = 'SELECT 1 FROM default.notarealtable' + schema = 'default' + + fetch_fn = self.database.db_engine_spec.fetch_data + fetch_fn.side_effect = DatabaseError(self.PRESTO_ERROR_TEMPLATE) + + errors = self.validator.validate(sql, schema, self.database) + + self.assertEqual(1, len(errors)) + + def test_validate_sql_endpoint(self): + self.login('admin') + # NB this is effectively an integration test -- when there's a default + # validator for sqlite, this test will fail because the validator + # will no longer error out. 
+ resp = self.validate_sql( + 'SELECT * FROM ab_user', + client_id='1', + raise_on_error=False, + ) + self.assertIn('error', resp) + self.assertIn('no SQL validator is configured', resp['error']) + + +if __name__ == '__main__': + unittest.main()