diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 71c6ab20d9b6f..c9ed303473a6e 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -29,6 +29,8 @@ from sqlalchemy.sql import text from flask_babel import lazy_gettext as _ +from sqlalchemy.engine.url import make_url + from superset.utils import SupersetTemplateException from superset.utils import QueryStatus from superset import conf, cache_util, utils @@ -184,6 +186,28 @@ def select_star(cls, my_db, table_name, schema=None, limit=100, sql = sqlparse.format(sql, reindent=True) return sql + @classmethod + def modify_url_for_impersonation(cls, url, impersonate_user, username): + """ + Modify the SQL Alchemy URL object with the user to impersonate if applicable. + :param url: SQLAlchemy URL object + :param impersonate_user: Bool indicating if impersonation is enabled + :param username: Effective username + """ + if impersonate_user is not None and username is not None: + url.username = username + + @classmethod + def get_uri_for_impersonation(cls, uri, impersonate_user, username): + """ + Return a new URI string that allows for user impersonation. + :param uri: URI string + :param impersonate_user: Bool indicating if impersonation is enabled + :param username: Effective username + :return: New URI string + """ + return uri + class PostgresEngineSpec(BaseEngineSpec): engine = 'postgresql' @@ -677,6 +701,7 @@ def patch(cls): hive.constants = patched_constants hive.ttypes = patched_ttypes hive.Cursor.fetch_logs = patched_hive.fetch_logs + hive.Connection = patched_hive.ConnectionProxyUser @classmethod @cache_util.memoized_func( @@ -830,6 +855,35 @@ def _partition_query( cls, table_name, limit=0, order_by=None, filters=None): return "SHOW PARTITIONS {table_name}".format(**locals()) + @classmethod + def modify_url_for_impersonation(cls, url, impersonate_user, username): + """ + Modify the SQL Alchemy URL object with the user to impersonate if applicable. + :param url: SQLAlchemy URL object + :param impersonate_user: Bool indicating if impersonation is enabled + :param username: Effective username + """ + if impersonate_user is not None and "auth" in url.query.keys() and username is not None: + url.query["hive_server2_proxy_user"] = username + + @classmethod + def get_uri_for_impersonation(cls, uri, impersonate_user, username): + """ + Return a new URI string that allows for user impersonation. + :param uri: URI string + :param impersonate_user: Bool indicating if impersonation is enabled + :param username: Effective username + :return: New URI string + """ + new_uri = uri + url = make_url(uri) + backend_name = url.get_backend_name() + + # Must be Hive connection, enable impersonation, and set param auth=LDAP|KERBEROS + if backend_name == "hive" and "auth" in url.query.keys() and\ + impersonate_user is True and username is not None: + new_uri += "&hive_server2_proxy_user={0}".format(username) + return new_uri class MssqlEngineSpec(BaseEngineSpec): engine = 'mssql' diff --git a/superset/db_engines/hive.py b/superset/db_engines/hive.py index f14608410823a..334ae0a4b2d87 100644 --- a/superset/db_engines/hive.py +++ b/superset/db_engines/hive.py @@ -3,6 +3,28 @@ from thrift import Thrift +old_Connection = hive.Connection + +# TODO +# Monkey-patch of PyHive project's pyhive/hive.py which needed to change the constructor. +# Submitted a pull request on October 13, 2017 and waiting for it to be merged. +# https://github.com/dropbox/PyHive/pull/165 +class ConnectionProxyUser(hive.Connection): + + def __init__(self, host=None, port=None, username=None, database='default', auth=None, + configuration=None, kerberos_service_name=None, password=None, + thrift_transport=None, hive_server2_proxy_user=None): + configuration = configuration or {} + if auth is not None and auth in ('LDAP', 'KERBEROS'): + if hive_server2_proxy_user is not None: + configuration["hive.server2.proxy.user"] = hive_server2_proxy_user + # restore the old connection class, otherwise, will recurse on its own __init__ method + hive.Connection = old_Connection + hive.Connection.__init__(self, host=host, port=port, username=username, database=database, auth=auth, + configuration=configuration, kerberos_service_name=kerberos_service_name, password=password, + thrift_transport=thrift_transport) + + # TODO: contribute back to pyhive. def fetch_logs(self, max_rows=1024, orientation=ttypes.TFetchOrientation.FETCH_NEXT): diff --git a/superset/models/core.py b/superset/models/core.py index 7392e8796a571..1a795c2ebde22 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -13,6 +13,7 @@ from future.standard_library import install_aliases from copy import copy from datetime import datetime, date +from copy import deepcopy import pandas as pd import sqlalchemy as sqla @@ -47,6 +48,7 @@ stats_logger = config.get('STATS_LOGGER') metadata = Model.metadata # pylint: disable=no-member +PASSWORD_MASK = "X" * 10 def set_related_perm(mapper, connection, target): # noqa src_class = target.cls_model @@ -581,30 +583,56 @@ def backend(self): url = make_url(self.sqlalchemy_uri_decrypted) return url.get_backend_name() + @classmethod + def get_password_masked_url_from_uri(cls, uri): + url = make_url(uri) + return cls.get_password_masked_url(url) + + @classmethod + def get_password_masked_url(cls, url): + url_copy = deepcopy(url) + if url_copy.password is not None and url_copy.password != PASSWORD_MASK: + url_copy.password = PASSWORD_MASK + return url_copy + def set_sqlalchemy_uri(self, uri): - password_mask = "X" * 10 conn = sqla.engine.url.make_url(uri) - if conn.password != password_mask and not self.custom_password_store: + if conn.password != PASSWORD_MASK and not self.custom_password_store: # do not over-write the password with the password mask self.password = conn.password - conn.password = password_mask if conn.password else None + conn.password = PASSWORD_MASK if conn.password else None self.sqlalchemy_uri = str(conn) # hides the password + def get_effective_user(self, url, user_name=None): + """ + Get the effective user, especially during impersonation. + :param url: SQL Alchemy URL object + :param user_name: Default username + :return: The effective username + """ + effective_username = None + if self.impersonate_user: + effective_username = url.username + if user_name: + effective_username = user_name + elif hasattr(g, 'user') and g.user.username: + effective_username = g.user.username + return effective_username + def get_sqla_engine(self, schema=None, nullpool=False, user_name=None): extra = self.get_extra() - uri = make_url(self.sqlalchemy_uri_decrypted) + url = make_url(self.sqlalchemy_uri_decrypted) params = extra.get('engine_params', {}) if nullpool: params['poolclass'] = NullPool - uri = self.db_engine_spec.adjust_database_uri(uri, schema) - if self.impersonate_user: - eff_username = uri.username - if user_name: - eff_username = user_name - elif hasattr(g, 'user') and g.user.username: - eff_username = g.user.username - uri.username = eff_username - return create_engine(uri, **params) + url = self.db_engine_spec.adjust_database_uri(url, schema) + effective_username = self.get_effective_user(url, user_name) + self.db_engine_spec.modify_url_for_impersonation(url, self.impersonate_user, effective_username) + + masked_url = self.get_password_masked_url(url) + logging.info("Database.get_sqla_engine(). Masked URL: {0}".format(masked_url)) + + return create_engine(url, **params) def get_reserved_words(self): return self.get_sqla_engine().dialect.preparer.reserved_words @@ -688,6 +716,10 @@ def db_engine_spec(self): return db_engine_specs.engines.get( self.backend, db_engine_specs.BaseEngineSpec) + @classmethod + def get_db_engine_spec_for_backend(cls, backend): + return db_engine_specs.engines.get(backend, db_engine_specs.BaseEngineSpec) + def grains(self): """Defines time granularity database-specific expressions. diff --git a/superset/sql_lab.py b/superset/sql_lab.py index aeb71f6b79a0e..d9a8e800f7932 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -168,6 +168,7 @@ def handle_error(msg): session.merge(query) session.commit() logging.info("Set query to 'running'") + conn = None try: engine = database.get_sqla_engine( schema=query.schema, nullpool=not ctask.request.called_directly, user_name=user_name) @@ -183,20 +184,23 @@ def handle_error(msg): data = db_engine_spec.fetch_data(cursor, query.limit) except SoftTimeLimitExceeded as e: logging.exception(e) - conn.close() + if conn is not None: + conn.close() return handle_error( "SQL Lab timeout. This environment's policy is to kill queries " "after {} seconds.".format(SQLLAB_TIMEOUT)) except Exception as e: logging.exception(e) - conn.close() + if conn is not None: + conn.close() return handle_error(db_engine_spec.extract_error_message(e)) logging.info("Fetching cursor description") cursor_description = cursor.description - conn.commit() - conn.close() + if conn is not None: + conn.commit() + conn.close() if query.status == utils.QueryStatus.STOPPED: return json.dumps( diff --git a/superset/templates/superset/models/database/macros.html b/superset/templates/superset/models/database/macros.html index e66854c3c4f5c..ec20da1d8fefb 100644 --- a/superset/templates/superset/models/database/macros.html +++ b/superset/templates/superset/models/database/macros.html @@ -20,6 +20,7 @@ data = JSON.stringify({ uri: $("#sqlalchemy_uri").val(), name: $('#database_name').val(), + impersonate_user: $('#impersonate_user').is(':checked'), extras: JSON.parse($("#extra").val()), }) } catch(parse_error){ diff --git a/superset/views/core.py b/superset/views/core.py index 57710bbaa5c92..532faea7d366a 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -28,6 +28,7 @@ from flask_babel import lazy_gettext as _ from sqlalchemy import create_engine +from sqlalchemy.engine.url import make_url from werkzeug.routing import BaseConverter from superset import ( @@ -236,8 +237,10 @@ class DatabaseView(SupersetModelView, DeleteMixin): # noqa "(http://docs.sqlalchemy.org/en/rel_1_0/core/metadata.html" "#sqlalchemy.schema.MetaData) call. ", True), 'impersonate_user': _( - "All the queries in Sql Lab are going to be executed " - "on behalf of currently authorized user."), + "If Presto, all the queries in SQL Lab are going to be executed as the currently logged on user " + "who must have permission to run them.
" + "If Hive and hive.server2.enable.doAs is enabled, will run the queries as service account, " + "but impersonate the currently logged on user via hive.server2.proxy.user property."), } label_columns = { 'expose_in_sqllab': _("Expose in SQL Lab"), @@ -252,7 +255,7 @@ class DatabaseView(SupersetModelView, DeleteMixin): # noqa 'extra': _("Extra"), 'allow_run_sync': _("Allow Run Sync"), 'allow_run_async': _("Allow Run Async"), - 'impersonate_user': _("Impersonate queries to the database"), + 'impersonate_user': _("Impersonate the logged on user") } def pre_add(self, db): @@ -1415,8 +1418,10 @@ def add_slices(self, dashboard_id): def testconn(self): """Tests a sqla connection""" try: + username = g.user.username if g.user is not None else None uri = request.json.get('uri') db_name = request.json.get('name') + impersonate_user = request.json.get('impersonate_user') if db_name: database = ( db.session @@ -1428,6 +1433,15 @@ def testconn(self): # the password-masked uri was passed # use the URI associated with this database uri = database.sqlalchemy_uri_decrypted + + url = make_url(uri) + db_engine = models.Database.get_db_engine_spec_for_backend(url.get_backend_name()) + db_engine.patch() + uri = db_engine.get_uri_for_impersonation(uri, impersonate_user, username) + masked_url = database.get_password_masked_url_from_uri(uri) + + logging.info("Superset.testconn(). Masked URL: {0}".format(masked_url)) + connect_args = ( request.json .get('extras', {}) diff --git a/tests/core_tests.py b/tests/core_tests.py index e68dea96ad231..c5335538dbe74 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -276,13 +276,15 @@ def test_misc(self): assert self.get_resp('/health') == "OK" assert self.get_resp('/ping') == "OK" - def test_testconn(self): + def test_testconn(self, username='admin'): + self.login(username=username) database = self.get_main_database(db.session) # validate that the endpoint works with the password-masked sqlalchemy uri data = json.dumps({ 'uri': database.safe_sqlalchemy_uri(), - 'name': 'main' + 'name': 'main', + 'impersonate_user': False }) response = self.client.post('/superset/testconn', data=data, content_type='application/json') assert response.status_code == 200 @@ -291,7 +293,8 @@ def test_testconn(self): # validate that the endpoint works with the decrypted sqlalchemy uri data = json.dumps({ 'uri': database.sqlalchemy_uri_decrypted, - 'name': 'main' + 'name': 'main', + 'impersonate_user': False }) response = self.client.post('/superset/testconn', data=data, content_type='application/json') assert response.status_code == 200