diff --git a/superset/config.py b/superset/config.py index 3b63a48f929b5..1e7af7c4cab11 100644 --- a/superset/config.py +++ b/superset/config.py @@ -211,7 +211,7 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: # # e.g.: # -# class AesGcmEncryptedAdapter( # pylint: disable=too-few-public-methods +# class AesGcmEncryptedAdapter( # AbstractEncryptedFieldAdapter # ): # def create( diff --git a/superset/databases/api.py b/superset/databases/api.py index a6160b0d2f1fe..4c617eb720519 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -1083,8 +1083,8 @@ def available(self) -> Response: "preferred": engine_spec.engine_name in preferred_databases, } - if hasattr(engine_spec, "default_driver"): - payload["default_driver"] = engine_spec.default_driver # type: ignore + if engine_spec.default_driver: + payload["default_driver"] = engine_spec.default_driver # show configuration parameters for DBs that support it if ( diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py index a9f1633a18144..e9fe5eaf0c972 100644 --- a/superset/databases/commands/validate.py +++ b/superset/databases/commands/validate.py @@ -29,8 +29,7 @@ ) from superset.databases.dao import DatabaseDAO from superset.databases.utils import make_url_safe -from superset.db_engine_specs import get_engine_specs -from superset.db_engine_specs.base import BasicParametersMixin +from superset.db_engine_specs import get_engine_spec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.extensions import event_logger from superset.models.core import Database @@ -45,25 +44,13 @@ def __init__(self, parameters: Dict[str, Any]): def run(self) -> None: engine = self._properties["engine"] - engine_specs = get_engine_specs() + driver = self._properties.get("driver") if engine in BYPASS_VALIDATION_ENGINES: # Skip engines that are only validated onCreate return - if engine not in engine_specs: - raise InvalidEngineError( - SupersetError( - message=__( - 'Engine "%(engine)s" is not a valid engine.', - engine=engine, - ), - error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, - level=ErrorLevel.ERROR, - extra={"allowed": list(engine_specs), "provided": engine}, - ), - ) - engine_spec = engine_specs[engine] + engine_spec = get_engine_spec(engine, driver) if not hasattr(engine_spec, "parameters_schema"): raise InvalidEngineError( SupersetError( @@ -73,14 +60,6 @@ def run(self) -> None: ), error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, level=ErrorLevel.ERROR, - extra={ - "allowed": [ - name - for name, engine_spec in engine_specs.items() - if issubclass(engine_spec, BasicParametersMixin) - ], - "provided": engine, - }, ), ) diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index aa88822a854df..b6a0ab6983064 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -16,7 +16,7 @@ # under the License. import inspect import json -from typing import Any, Dict, Optional, Type +from typing import Any, Dict from flask import current_app from flask_babel import lazy_gettext as _ @@ -28,7 +28,7 @@ from superset import db from superset.databases.commands.exceptions import DatabaseInvalidError from superset.databases.utils import make_url_safe -from superset.db_engine_specs import BaseEngineSpec, get_engine_specs +from superset.db_engine_specs import get_engine_spec from superset.exceptions import CertificateException, SupersetSecurityException from superset.models.core import ConfigurationMethod, Database, PASSWORD_MASK from superset.security.analytics_db_safety import check_sqlalchemy_uri @@ -150,7 +150,7 @@ def sqlalchemy_uri_validator(value: str) -> str: [ _( "Invalid connection string, a valid string usually follows: " - "driver://user:password@database-host/database-name" + "backend+driver://user:password@database-host/database-name" ) ] ) from ex @@ -231,6 +231,7 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods """ engine = fields.String(allow_none=True, description="SQLAlchemy engine to use") + driver = fields.String(allow_none=True, description="SQLAlchemy driver to use") parameters = fields.Dict( keys=fields.String(), values=fields.Raw(), @@ -262,10 +263,20 @@ def build_sqlalchemy_uri( or parameters.pop("engine", None) or data.pop("backend", None) ) + driver = data.pop("driver", None) configuration_method = data.get("configuration_method") if configuration_method == ConfigurationMethod.DYNAMIC_FORM: - engine_spec = get_engine_spec(engine) + if not engine: + raise ValidationError( + [ + _( + "An engine must be specified when passing " + "individual parameters to a database." + ) + ] + ) + engine_spec = get_engine_spec(engine, driver) if not hasattr(engine_spec, "build_sqlalchemy_uri") or not hasattr( engine_spec, "parameters_schema" @@ -295,34 +306,12 @@ def build_sqlalchemy_uri( return data -def get_engine_spec(engine: Optional[str]) -> Type[BaseEngineSpec]: - if not engine: - raise ValidationError( - [ - _( - "An engine must be specified when passing " - "individual parameters to a database." - ) - ] - ) - engine_specs = get_engine_specs() - if engine not in engine_specs: - raise ValidationError( - [ - _( - 'Engine "%(engine)s" is not a valid engine.', - engine=engine, - ) - ] - ) - return engine_specs[engine] - - class DatabaseValidateParametersSchema(Schema): class Meta: # pylint: disable=too-few-public-methods unknown = EXCLUDE engine = fields.String(required=True, description="SQLAlchemy engine to use") + driver = fields.String(allow_none=True, description="SQLAlchemy driver to use") parameters = fields.Dict( keys=fields.String(), values=fields.Raw(allow_none=True), diff --git a/superset/db_engine_specs/__init__.py b/superset/db_engine_specs/__init__.py index dac700199557c..29e4877337b61 100644 --- a/superset/db_engine_specs/__init__.py +++ b/superset/db_engine_specs/__init__.py @@ -33,27 +33,34 @@ from collections import defaultdict from importlib import import_module from pathlib import Path -from typing import Any, Dict, List, Set, Type +from typing import Any, Dict, List, Optional, Set, Type import sqlalchemy.databases import sqlalchemy.dialects from pkg_resources import iter_entry_points from sqlalchemy.engine.default import DefaultDialect +from sqlalchemy.engine.url import URL from superset.db_engine_specs.base import BaseEngineSpec logger = logging.getLogger(__name__) -def is_engine_spec(attr: Any) -> bool: +def is_engine_spec(obj: Any) -> bool: + """ + Return true if a given object is a DB engine spec. + """ return ( - inspect.isclass(attr) - and issubclass(attr, BaseEngineSpec) - and attr != BaseEngineSpec + inspect.isclass(obj) + and issubclass(obj, BaseEngineSpec) + and obj != BaseEngineSpec ) def load_engine_specs() -> List[Type[BaseEngineSpec]]: + """ + Load all engine specs, native and 3rd party. + """ engine_specs: List[Type[BaseEngineSpec]] = [] # load standard engines @@ -78,20 +85,31 @@ def load_engine_specs() -> List[Type[BaseEngineSpec]]: return engine_specs -def get_engine_specs() -> Dict[str, Type[BaseEngineSpec]]: +def get_engine_spec(backend: str, driver: Optional[str] = None) -> Type[BaseEngineSpec]: + """ + Return the DB engine spec associated with a given SQLAlchemy URL. + + Note that if a driver is not specified the function returns the first DB engine spec + that supports the backend. Also, if a driver is specified but no DB engine explicitly + supporting that driver exists then a backend-only match is done, in order to allow new + drivers to work with Superset even if they are not listed in the DB engine spec + drivers. + """ engine_specs = load_engine_specs() - # build map from name/alias -> spec - engine_specs_map: Dict[str, Type[BaseEngineSpec]] = {} - for engine_spec in engine_specs: - names = [engine_spec.engine] - if engine_spec.engine_aliases: - names.extend(engine_spec.engine_aliases) + if driver is not None: + for engine_spec in engine_specs: + if engine_spec.supports_backend(backend, driver): + return engine_spec - for name in names: - engine_specs_map[name] = engine_spec + # check ignoring the driver, in order to support new drivers; this will return a + # random DB engine spec that supports the engine + for engine_spec in engine_specs: + if engine_spec.supports_backend(backend): + return engine_spec - return engine_specs_map + # default to the generic DB engine spec + return BaseEngineSpec # there's a mismatch between the dialect name reported by the driver in these diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 368770e2612f5..1bf2f4a3f7c86 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -183,9 +183,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods having to add the same aggregation in SELECT. """ + engine_name: Optional[str] = None # for user messages, overridden in child classes + + # These attributes map the DB engine spec to one or more SQLAlchemy dialects/drivers; + # see the ``supports_url`` and ``supports_backend`` methods below. engine = "base" # str as defined in sqlalchemy.engine.engine engine_aliases: Set[str] = set() - engine_name: Optional[str] = None # for user messages, overridden in child classes + drivers: Dict[str, str] = {} + default_driver: Optional[str] = None + _date_trunc_functions: Dict[str, str] = {} _time_grain_expressions: Dict[Optional[str], str] = {} column_type_mappings: Tuple[ColumnTypeMapping, ...] = ( @@ -355,6 +361,58 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]] ] = {} + @classmethod + def supports_url(cls, url: URL) -> bool: + """ + Returns true if the DB engine spec supports a given SQLAlchemy URL. + + As an example, if a given DB engine spec has: + + class PostgresDBEngineSpec: + engine = "postgresql" + engine_aliases = "postgres" + drivers = { + "psycopg2": "The default Postgres driver", + "asyncpg": "An asynchronous Postgres driver", + } + + It would be used for all the following SQLAlchemy URIs: + + - postgres://user:password@host/db + - postgresql://user:password@host/db + - postgres+asyncpg://user:password@host/db + - postgres+psycopg2://user:password@host/db + - postgresql+asyncpg://user:password@host/db + - postgresql+psycopg2://user:password@host/db + + Note that SQLAlchemy has a default driver even if one is not specified: + + >>> from sqlalchemy.engine.url import make_url + >>> make_url('postgres://').get_driver_name() + 'psycopg2' + + """ + backend = url.get_backend_name() + driver = url.get_driver_name() + return cls.supports_backend(backend, driver) + + @classmethod + def supports_backend(cls, backend: str, driver: Optional[str] = None) -> bool: + """ + Returns true if the DB engine spec supports a given SQLAlchemy backend/driver. + """ + # check the backend first + if backend != cls.engine and backend not in cls.engine_aliases: + return False + + # originally DB engine specs didn't declare any drivers and the check was made + # only on the engine; if that's the case, ignore the driver for backwards + # compatibility + if not cls.drivers or driver is None: + return True + + return driver in cls.drivers + @classmethod def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: """ @@ -394,7 +452,7 @@ def get_allow_cost_estimate( # pylint: disable=unused-argument @classmethod def get_text_clause(cls, clause: str) -> TextClause: """ - SQLALchemy wrapper to ensure text clauses are escaped properly + SQLAlchemy wrapper to ensure text clauses are escaped properly :param clause: string clause with potentially unescaped characters :return: text clause with escaped characters diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py index 79718c93f664c..90d90b9448fa7 100644 --- a/superset/db_engine_specs/databricks.py +++ b/superset/db_engine_specs/databricks.py @@ -47,18 +47,23 @@ class DatabricksHiveEngineSpec(HiveEngineSpec): - engine = "databricks" engine_name = "Databricks Interactive Cluster" - driver = "pyhive" + + engine = "databricks" + drivers = {"pyhive": "Hive driver for Interactive Cluster"} + default_driver = "pyhive" + _show_functions_column = "function" _time_grain_expressions = time_grain_expressions class DatabricksODBCEngineSpec(BaseEngineSpec): - engine = "databricks" engine_name = "Databricks SQL Endpoint" - driver = "pyodbc" + + engine = "databricks" + drivers = {"pyodbc": "ODBC driver for SQL endpoint"} + default_driver = "pyodbc" _time_grain_expressions = time_grain_expressions @@ -74,9 +79,11 @@ def epoch_to_dttm(cls) -> str: class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec): - engine = "databricks" engine_name = "Databricks Native Connector" - driver = "connector" + + engine = "databricks" + drivers = {"connector": "Native all-purpose driver"} + default_driver = "connector" @staticmethod def get_extra_params(database: "Database") -> Dict[str, Any]: diff --git a/superset/db_engine_specs/shillelagh.py b/superset/db_engine_specs/shillelagh.py index c6e6f618c7251..37301224484b7 100644 --- a/superset/db_engine_specs/shillelagh.py +++ b/superset/db_engine_specs/shillelagh.py @@ -20,7 +20,11 @@ class ShillelaghEngineSpec(SqliteEngineSpec): """Engine for shillelagh""" - engine = "shillelagh" engine_name = "Shillelagh" + engine = "shillelagh" + drivers = {"apsw": "SQLite driver"} + default_driver = "apsw" + sqlalchemy_uri_placeholder = "shillelagh://" + allows_joins = True allows_subqueries = True diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 2a23d1c969593..9aa89ce34a06f 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -35,7 +35,7 @@ from superset.models.core import Database try: - from trino.dbapi import Cursor # pylint: disable=unused-import + from trino.dbapi import Cursor except ImportError: pass diff --git a/superset/models/core.py b/superset/models/core.py index b5a4aa6537da2..ec7ec793212c1 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -46,7 +46,7 @@ from sqlalchemy.engine import Connection, Dialect, Engine from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL -from sqlalchemy.exc import ArgumentError +from sqlalchemy.exc import ArgumentError, NoSuchModuleError from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import relationship from sqlalchemy.pool import NullPool @@ -635,15 +635,20 @@ def get_all_schema_names( # pylint: disable=unused-argument @property def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]: - return self.get_db_engine_spec_for_backend(self.backend) + url = make_url_safe(self.sqlalchemy_uri_decrypted) + return self.get_db_engine_spec(url) @classmethod @memoized - def get_db_engine_spec_for_backend( - cls, backend: str - ) -> Type[db_engine_specs.BaseEngineSpec]: - engines = db_engine_specs.get_engine_specs() - return engines.get(backend, db_engine_specs.BaseEngineSpec) + def get_db_engine_spec(cls, url: URL) -> Type[db_engine_specs.BaseEngineSpec]: + backend = url.get_backend_name() + try: + driver = url.get_driver_name() + except NoSuchModuleError: + # can't load the driver, fallback for backwards compatibility + driver = None + + return db_engine_specs.get_engine_spec(backend, driver) def grains(self) -> Tuple[TimeGrain, ...]: """Defines time granularity database-specific expressions. diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 8ff12b2406b54..b53418fb16496 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -1425,7 +1425,7 @@ def test_test_connection_failed(self): expected_response = { "errors": [ { - "message": "Could not load database driver: AzureSynapseSpec", + "message": "Could not load database driver: MssqlEngineSpec", "error_type": "GENERIC_COMMAND_ERROR", "level": "warning", "extra": { diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py index 07f9bfcf318dc..f998444f31895 100644 --- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py +++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py @@ -20,7 +20,7 @@ import pytest from superset.connectors.sqla.models import TableColumn -from superset.db_engine_specs import get_engine_specs +from superset.db_engine_specs import load_engine_specs from superset.db_engine_specs.base import ( BaseEngineSpec, BasicParametersMixin, @@ -195,7 +195,7 @@ class DummyEngineSpec(BaseEngineSpec): def test_engine_time_grain_validity(self): time_grains = set(builtin_time_grains.keys()) # loop over all subclasses of BaseEngineSpec - for engine in get_engine_specs().values(): + for engine in load_engine_specs(): if engine is not BaseEngineSpec: # make sure time grain functions have been defined self.assertGreater(len(engine.get_time_grain_expressions()), 0) diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index e6eb4fc1d13ea..79a307a488515 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -20,7 +20,7 @@ from sqlalchemy import column, literal_column from sqlalchemy.dialects import postgresql -from superset.db_engine_specs import get_engine_specs +from superset.db_engine_specs import load_engine_specs from superset.db_engine_specs.postgres import PostgresEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.models.sql_lab import Query @@ -137,7 +137,11 @@ def test_engine_alias_name(self): """ DB Eng Specs (postgres): Test "postgres" in engine spec """ - self.assertIn("postgres", get_engine_specs()) + backends = set() + for engine in load_engine_specs(): + backends.add(engine.engine) + backends.update(engine.engine_aliases) + assert "postgres" in backends def test_extras_without_ssl(self): db = mock.Mock() diff --git a/tests/integration_tests/databases/schema_tests.py b/tests/unit_tests/databases/schema_tests.py similarity index 57% rename from tests/integration_tests/databases/schema_tests.py rename to tests/unit_tests/databases/schema_tests.py index 1f8ca067f6b0d..58a1f6389d4c1 100644 --- a/tests/integration_tests/databases/schema_tests.py +++ b/tests/unit_tests/databases/schema_tests.py @@ -15,31 +15,59 @@ # specific language governing permissions and limitations # under the License. -from unittest import mock +# pylint: disable=import-outside-toplevel, invalid-name, unused-argument, redefined-outer-name +from typing import TYPE_CHECKING + +import pytest from marshmallow import fields, Schema, ValidationError +from pytest_mock import MockFixture + +if TYPE_CHECKING: + from superset.databases.schemas import DatabaseParametersSchemaMixin + from superset.db_engine_specs.base import BasicParametersMixin -from superset.databases.schemas import DatabaseParametersSchemaMixin -from superset.db_engine_specs.base import BasicParametersMixin -from superset.models.core import ConfigurationMethod +# pylint: disable=too-few-public-methods +class InvalidEngine: + """ + An invalid DB engine spec. + """ -class DummySchema(Schema, DatabaseParametersSchemaMixin): - sqlalchemy_uri = fields.String() +@pytest.fixture +def dummy_schema() -> "DatabaseParametersSchemaMixin": + """ + Fixture providing a dummy schema. + """ + from superset.databases.schemas import DatabaseParametersSchemaMixin -class DummyEngine(BasicParametersMixin): - engine = "dummy" - default_driver = "dummy" + class DummySchema(Schema, DatabaseParametersSchemaMixin): + sqlalchemy_uri = fields.String() + return DummySchema() + + +@pytest.fixture +def dummy_engine(mocker: MockFixture) -> None: + """ + Fixture proving a dummy DB engine spec. + """ + from superset.db_engine_specs.base import BasicParametersMixin + + class DummyEngine(BasicParametersMixin): + engine = "dummy" + default_driver = "dummy" + + mocker.patch("superset.databases.schemas.get_engine_spec", return_value=DummyEngine) -class InvalidEngine: - pass +def test_database_parameters_schema_mixin( + dummy_engine: None, + dummy_schema: "Schema", +) -> None: + from superset.models.core import ConfigurationMethod -@mock.patch("superset.databases.schemas.get_engine_specs") -def test_database_parameters_schema_mixin(get_engine_specs): - get_engine_specs.return_value = {"dummy_engine": DummyEngine} payload = { "engine": "dummy_engine", "configuration_method": ConfigurationMethod.DYNAMIC_FORM, @@ -51,15 +79,18 @@ def test_database_parameters_schema_mixin(get_engine_specs): "database": "dbname", }, } - schema = DummySchema() - result = schema.load(payload) + result = dummy_schema.load(payload) assert result == { "configuration_method": ConfigurationMethod.DYNAMIC_FORM, "sqlalchemy_uri": "dummy+dummy://username:password@localhost:12345/dbname", } -def test_database_parameters_schema_mixin_no_engine(): +def test_database_parameters_schema_mixin_no_engine( + dummy_schema: "Schema", +) -> None: + from superset.models.core import ConfigurationMethod + payload = { "configuration_method": ConfigurationMethod.DYNAMIC_FORM, "parameters": { @@ -67,23 +98,28 @@ def test_database_parameters_schema_mixin_no_engine(): "password": "password", "host": "localhost", "port": 12345, - "dbname": "dbname", + "database": "dbname", }, } - schema = DummySchema() try: - schema.load(payload) + dummy_schema.load(payload) except ValidationError as err: assert err.messages == { "_schema": [ - "An engine must be specified when passing individual parameters to a database." + ( + "An engine must be specified when passing individual parameters to " + "a database." + ), ] } -@mock.patch("superset.databases.schemas.get_engine_specs") -def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs): - get_engine_specs.return_value = {} +def test_database_parameters_schema_mixin_invalid_engine( + dummy_engine: None, + dummy_schema: "Schema", +) -> None: + from superset.models.core import ConfigurationMethod + payload = { "engine": "dummy_engine", "configuration_method": ConfigurationMethod.DYNAMIC_FORM, @@ -92,21 +128,24 @@ def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs): "password": "password", "host": "localhost", "port": 12345, - "dbname": "dbname", + "database": "dbname", }, } - schema = DummySchema() try: - schema.load(payload) + dummy_schema.load(payload) except ValidationError as err: + print(err.messages) assert err.messages == { "_schema": ['Engine "dummy_engine" is not a valid engine.'] } -@mock.patch("superset.databases.schemas.get_engine_specs") -def test_database_parameters_schema_no_mixin(get_engine_specs): - get_engine_specs.return_value = {"invalid_engine": InvalidEngine} +def test_database_parameters_schema_no_mixin( + dummy_engine: None, + dummy_schema: "Schema", +) -> None: + from superset.models.core import ConfigurationMethod + payload = { "engine": "invalid_engine", "configuration_method": ConfigurationMethod.DYNAMIC_FORM, @@ -118,9 +157,8 @@ def test_database_parameters_schema_no_mixin(get_engine_specs): "database": "dbname", }, } - schema = DummySchema() try: - schema.load(payload) + dummy_schema.load(payload) except ValidationError as err: assert err.messages == { "_schema": [ @@ -132,9 +170,12 @@ def test_database_parameters_schema_no_mixin(get_engine_specs): } -@mock.patch("superset.databases.schemas.get_engine_specs") -def test_database_parameters_schema_mixin_invalid_type(get_engine_specs): - get_engine_specs.return_value = {"dummy_engine": DummyEngine} +def test_database_parameters_schema_mixin_invalid_type( + dummy_engine: None, + dummy_schema: "Schema", +) -> None: + from superset.models.core import ConfigurationMethod + payload = { "engine": "dummy_engine", "configuration_method": ConfigurationMethod.DYNAMIC_FORM, @@ -146,8 +187,7 @@ def test_database_parameters_schema_mixin_invalid_type(get_engine_specs): "database": "dbname", }, } - schema = DummySchema() try: - schema.load(payload) + dummy_schema.load(payload) except ValidationError as err: assert err.messages == {"port": ["Not a valid integer."]} diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 3338ddcb61441..5eb60dc6f93ef 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -59,7 +59,7 @@ def get_metrics( }, ] - database.get_db_engine_spec_for_backend = mocker.MagicMock( # type: ignore + database.get_db_engine_spec = mocker.MagicMock( # type: ignore return_value=CustomSqliteEngineSpec ) assert database.get_metrics("table") == [ @@ -70,3 +70,78 @@ def get_metrics( "verbose_name": "COUNT(DISTINCT user_id)", }, ] + + +def test_get_db_engine_spec(mocker: MockFixture) -> None: + """ + Tests for ``get_db_engine_spec``. + """ + from superset.db_engine_specs import BaseEngineSpec + from superset.models.core import Database + + # pylint: disable=abstract-method + class PostgresDBEngineSpec(BaseEngineSpec): + """ + A DB engine spec with drivers and a default driver. + """ + + engine = "postgresql" + engine_aliases = {"postgres"} + drivers = { + "psycopg2": "The default Postgres driver", + "asyncpg": "An async Postgres driver", + } + default_driver = "psycopg2" + + # pylint: disable=abstract-method + class OldDBEngineSpec(BaseEngineSpec): + """ + And old DB engine spec without drivers nor a default driver. + """ + + engine = "mysql" + + load_engine_specs = mocker.patch("superset.db_engine_specs.load_engine_specs") + load_engine_specs.return_value = [ + PostgresDBEngineSpec, + OldDBEngineSpec, + ] + + assert ( + Database(database_name="db", sqlalchemy_uri="postgresql://").db_engine_spec + == PostgresDBEngineSpec + ) + assert ( + Database( + database_name="db", sqlalchemy_uri="postgresql+psycopg2://" + ).db_engine_spec + == PostgresDBEngineSpec + ) + assert ( + Database( + database_name="db", sqlalchemy_uri="postgresql+asyncpg://" + ).db_engine_spec + == PostgresDBEngineSpec + ) + assert ( + Database( + database_name="db", sqlalchemy_uri="postgresql+fancynewdriver://" + ).db_engine_spec + == PostgresDBEngineSpec + ) + assert ( + Database(database_name="db", sqlalchemy_uri="mysql://").db_engine_spec + == OldDBEngineSpec + ) + assert ( + Database( + database_name="db", sqlalchemy_uri="mysql+mysqlconnector://" + ).db_engine_spec + == OldDBEngineSpec + ) + assert ( + Database( + database_name="db", sqlalchemy_uri="mysql+fancynewdriver://" + ).db_engine_spec + == OldDBEngineSpec + )