diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 84dcacf..955503b 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -1,4 +1,5 @@
 name: CI
+concurrency: hana-tests
 
 on:
   push:
@@ -20,3 +21,32 @@ jobs:
         run: pip install -e .[dev,test]
       - name: run pre-commit
         run: "pre-commit run --all"
+  ci-test:
+    if: ${{ github.event_name == 'pull_request' }}
+    strategy:
+      fail-fast: false
+      max-parallel: 1
+      matrix:
+        python-version: ["3.11"]
+        sqlalchemy-version: ["1.4.*", "2.0.*"]
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+      - uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install project
+        run: pip install -e .[test]
+      - name: Install sqlalchemy
+        run: pip install sqlalchemy==${{ matrix.sqlalchemy-version }}
+      - name: run tests (with coverage)
+        run: |
+          PYTEST_DBURI=$(python test/ci_setup.py setup ${{ secrets.TEST_DBURI }})
+          echo "::add-mask::$PYTEST_DBURI"
+          export PYTEST_ADDOPTS="--dburi $PYTEST_DBURI --dropfirst"
+          pytest -v --cov sqlalchemy_hana --cov-report html --cov-report xml test/
+          python test/ci_setup.py teardown ${{ secrets.TEST_DBURI }} $PYTEST_DBURI
+      - name: run diff-cover
+        run: "diff-cover --config-file pyproject.toml coverage.xml"
diff --git a/pyproject.toml b/pyproject.toml
index 99a8d25..4041504 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,7 +32,12 @@ dependencies = ["sqlalchemy>=1.4.0,<3", "hdbcli"]
 
 [project.optional-dependencies]
 dev = ["isort==5.12.0", "black==23.9.1", "pre-commit==3.5.0", "flake8==6.1.0"]
-test = ["pytest==7.4.2"]
+test = [
+    "pytest==7.4.2",
+    "pytest-cov==4.1.0",
+    "coverage[toml]==7.3.2",
+    "diff-cover[toml]==8.0.0",
+]
 
 [project.entry-points."sqlalchemy.dialects"]
 hana = "sqlalchemy_hana.dialect:HANAHDBCLIDialect"
@@ -55,3 +60,12 @@ swagger_plugin_for_sphinx = ["py.typed"]
 [tool.isort]
 profile = "black"
 add_imports = ["from __future__ import annotations"]
+
+[tool.pytest.ini_options]
+log_level = "DEBUG"
+xfail_strict = true
+filterwarnings = ["ignore"]
+
+[tool.diff_cover]
+include_untracked = true
+fail_under = 80
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index 7012093..0000000
--- a/setup.cfg
+++ /dev/null
@@ -1,3 +0,0 @@
-[sqla_testing]
-requirement_cls=sqlalchemy_hana.requirements:Requirements
-profile_file=.profiles.txt
diff --git a/sqlalchemy_hana/dialect.py b/sqlalchemy_hana/dialect.py
index 4e15acc..ac1e8ab 100644
--- a/sqlalchemy_hana/dialect.py
+++ b/sqlalchemy_hana/dialect.py
@@ -82,6 +82,7 @@
     "sql",
     "start",
     "sysuuid",
+    "table",
     "tablesample",
     "top",
     "trailing",
@@ -121,9 +122,9 @@ def visit_bindparam(self, bindparam, **kwargs):
         return super(HANAStatementCompiler, self).visit_bindparam(bindparam, **kwargs)
 
     def visit_sequence(self, seq, **kwargs):
-        return self.dialect.identifier_preparer.format_sequence(seq) + ".NEXTVAL"
+        return self.preparer.format_sequence(seq) + ".NEXTVAL"
 
-    def visit_empty_set_expr(self, element_types):
+    def visit_empty_set_expr(self, element_types, **kwargs):
         return "SELECT %s FROM DUMMY WHERE 1 != 1" % (
             ", ".join(["1" for _ in element_types])
         )
@@ -197,6 +198,12 @@ def visit_isnot_distinct_from_binary(self, binary, operator, **kw):
             f"({left} IS NULL AND {right} IS NULL))"
         )
 
+    def visit_is_true_unary_operator(self, element, operator, **kw):
+        return "%s = TRUE" % self.process(element.element, **kw)
+
+    def visit_is_false_unary_operator(self, element, operator, **kw):
+        return "%s = FALSE" % self.process(element.element, **kw)
+
 
 class HANATypeCompiler(compiler.GenericTypeCompiler):
     def visit_NUMERIC(self, type_):
@@ -225,7 +232,7 @@ def visit_unicode_text(self, type_, **kwargs):
 
 
 class HANADDLCompiler(compiler.DDLCompiler):
-    def visit_unique_constraint(self, constraint):
+    def visit_unique_constraint(self, constraint, **kwargs):
         if len(constraint) == 0:
             return ""
 
@@ -240,7 +247,7 @@ def visit_unique_constraint(self, constraint):
         text += self.define_constraint_deferrability(constraint)
         return text
 
-    def visit_create_table(self, create):
+    def visit_create_table(self, create, **kwargs):
         table = create.element
 
         # The table._prefixes list outlives the current compilation, meaning changing the list
@@ -267,13 +274,13 @@ def visit_create_table(self, create):
 
 
 class HANAExecutionContext(default.DefaultExecutionContext):
     def fire_sequence(self, seq, type_):
-        seq = self.dialect.identifier_preparer.format_sequence(seq)
+        seq = self.identifier_preparer.format_sequence(seq)
         return self._execute_scalar("SELECT %s.NEXTVAL FROM DUMMY" % seq, type_)
 
 
 class HANAInspector(reflection.Inspector):
     def get_table_oid(self, table_name, schema=None):
-        return self.dialect.get_table_oid(
+        return self.get_table_oid(
             self.bind, table_name, schema, info_cache=self.info_cache
         )
@@ -320,6 +327,10 @@ class HANABaseDialect(default.DefaultDialect):
     supports_default_values = False
     supports_sane_multi_rowcount = False
     isolation_level = None
+    div_is_floordiv = False
+    supports_schemas = True
+    supports_sane_rowcount = False
+    supports_is_distinct_from = False
 
     max_identifier_length = 127
 
@@ -407,21 +418,51 @@ def denormalize_name(self, name):
             name = name.upper()
         return name
 
-    def has_table(self, connection, table_name, schema=None):
+    @reflection.cache
+    def has_table(self, connection, table_name, schema=None, **kwargs):
         schema = schema or self.default_schema_name
 
         result = connection.execute(
             sql.text(
                 "SELECT 1 FROM SYS.TABLES "
-                "WHERE SCHEMA_NAME=:schema AND TABLE_NAME=:table",
+                "WHERE SCHEMA_NAME=:schema AND TABLE_NAME=:table "
+                "UNION ALL "
+                "SELECT 1 FROM SYS.VIEWS "
+                "WHERE SCHEMA_NAME=:schema AND VIEW_NAME=:table ",
+            ).bindparams(
+                schema=self.denormalize_name(schema),
+                table=self.denormalize_name(table_name),
+            )
+        )
+        return bool(result.first())
+
+    @reflection.cache
+    def has_schema(self, connection, schema_name, **kwargs):
+        result = connection.execute(
+            sql.text(
+                "SELECT 1 FROM SYS.SCHEMAS WHERE SCHEMA_NAME=:schema",
+            ).bindparams(schema=self.denormalize_name(schema_name))
+        )
+        return bool(result.first())
+
+    @reflection.cache
+    def has_index(self, connection, table_name, index_name, schema=None, **kwargs):
+        schema = schema or self.default_schema_name
+
+        result = connection.execute(
+            sql.text(
+                "SELECT 1 FROM SYS.INDEXES "
+                "WHERE SCHEMA_NAME=:schema AND TABLE_NAME=:table AND INDEX_NAME=:index"
             ).bindparams(
                 schema=self.denormalize_name(schema),
                 table=self.denormalize_name(table_name),
+                index=self.denormalize_name(index_name),
             )
         )
         return bool(result.first())
 
-    def has_sequence(self, connection, sequence_name, schema=None):
+    @reflection.cache
+    def has_sequence(self, connection, sequence_name, schema=None, **kwargs):
         schema = schema or self.default_schema_name
 
         result = connection.execute(
             sql.text(
@@ -434,11 +475,13 @@ def has_sequence(self, connection, sequence_name, schema=None):
         )
         return bool(result.first())
 
+    @reflection.cache
     def get_schema_names(self, connection, **kwargs):
         result = connection.execute(sql.text("SELECT SCHEMA_NAME FROM SYS.SCHEMAS"))
 
         return list([self.normalize_name(name) for name, in result.fetchall()])
 
+    @reflection.cache
     def get_table_names(self, connection, schema=None, **kwargs):
         schema = schema or self.default_schema_name
 
@@ -487,8 +530,7 @@ def get_view_names(self, connection, schema=None, **kwargs):
 
     def get_view_definition(self, connection, view_name, schema=None, **kwargs):
         schema = schema or self.default_schema_name
-
-        return connection.execute(
+        result = connection.execute(
             sql.text(
                 "SELECT DEFINITION FROM SYS.VIEWS "
                 "WHERE VIEW_NAME=:view_name AND SCHEMA_NAME=:schema LIMIT 1",
@@ -498,8 +540,14 @@ def get_view_definition(self, connection, view_name, schema=None, **kwargs):
             )
         ).scalar()
 
+        if result is None:
+            raise exc.NoSuchTableError()
+        return result
+
     def get_columns(self, connection, table_name, schema=None, **kwargs):
         schema = schema or self.default_schema_name
+        if not self.has_table(connection, table_name, schema, **kwargs):
+            raise exc.NoSuchTableError()
 
         result = connection.execute(
             sql.text(
@@ -550,8 +598,22 @@ def get_columns(self, connection, table_name, schema=None, **kwargs):
 
         return columns
 
+    @reflection.cache
+    def get_sequence_names(self, connection, schema=None, **kwargs):
+        schema = schema or self.default_schema_name
+
+        result = connection.execute(
+            sql.text(
+                "SELECT SEQUENCE_NAME FROM SYS.SEQUENCES "
+                "WHERE SCHEMA_NAME=:schema ORDER BY SEQUENCE_NAME"
+            ).bindparams(schema=self.denormalize_name(schema))
+        )
+        return [self.normalize_name(row[0]) for row in result]
+
     def get_foreign_keys(self, connection, table_name, schema=None, **kwargs):
         lookup_schema = schema or self.default_schema_name
+        if not self.has_table(connection, table_name, lookup_schema, **kwargs):
+            raise exc.NoSuchTableError()
 
         result = connection.execute(
             sql.text(
@@ -579,7 +641,7 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kwargs):
             foreign_key = {
                 "name": foreign_key_name,
                 "constrained_columns": [self.normalize_name(row[1])],
-                "referred_schema": schema,
+                "referred_schema": None,
                 "referred_table": self.normalize_name(row[3]),
                 "referred_columns": [self.normalize_name(row[4])],
                 "options": {"onupdate": row[5], "ondelete": row[6]},
@@ -591,10 +653,12 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kwargs):
                 foreign_keys[foreign_key_name] = foreign_key
                 foreign_keys_list.append(foreign_key)
 
-        return foreign_keys_list
+        return sorted(foreign_keys_list, key=lambda foreign_key: foreign_key["name"])
 
     def get_indexes(self, connection, table_name, schema=None, **kwargs):
         schema = schema or self.default_schema_name
+        if not self.has_table(connection, table_name, schema, **kwargs):
+            raise exc.NoSuchTableError()
 
         result = connection.execute(
             sql.text(
@@ -610,10 +674,11 @@ def get_indexes(self, connection, table_name, schema=None, **kwargs):
 
         indexes = {}
         for name, column, constraint in result.fetchall():
-            if name.startswith("_SYS"):
+            if constraint == "PRIMARY KEY":
                 continue
 
-            name = self.normalize_name(name)
+            if not name.startswith("_SYS"):
+                name = self.normalize_name(name)
             column = self.normalize_name(column)
 
             if name not in indexes:
@@ -629,10 +694,12 @@ def get_indexes(self, connection, table_name, schema=None, **kwargs):
             else:
                 indexes[name]["column_names"].append(column)
 
-        return list(indexes.values())
+        return sorted(list(indexes.values()), key=lambda index: index["name"])
 
     def get_pk_constraint(self, connection, table_name, schema=None, **kwargs):
         schema = schema or self.default_schema_name
+        if not self.has_table(connection, table_name, schema, **kwargs):
+            raise exc.NoSuchTableError()
 
         result = connection.execute(
             sql.text(
@@ -659,12 +726,14 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kwargs):
     def get_unique_constraints(self, connection, table_name, schema=None, **kwargs):
         schema = schema or self.default_schema_name
+        if not self.has_table(connection, table_name, schema, **kwargs):
+            raise exc.NoSuchTableError()
 
         result = connection.execute(
             sql.text(
                 "SELECT CONSTRAINT_NAME, COLUMN_NAME FROM SYS.CONSTRAINTS "
                 "WHERE SCHEMA_NAME=:schema AND TABLE_NAME=:table AND "
-                "IS_UNIQUE_KEY='TRUE' AND IS_PRIMARY_KEY='FALSE'"
+                "IS_UNIQUE_KEY='TRUE' AND IS_PRIMARY_KEY='FALSE' "
                 "ORDER BY CONSTRAINT_NAME, POSITION"
             ).bindparams(
                 schema=self.denormalize_name(schema),
                 table=self.denormalize_name(table_name),
@@ -693,10 +762,15 @@ def get_unique_constraints(self, connection, table_name, schema=None, **kwargs):
                 constraints.append(constraint)
             constraint["column_names"].append(self.normalize_name(column_name))
 
-        return constraints
+        return sorted(
+            constraints,
+            key=lambda constraint: (constraint["name"] is not None, constraint["name"]),
+        )
 
     def get_check_constraints(self, connection, table_name, schema=None, **kwargs):
         schema = schema or self.default_schema_name
+        if not self.has_table(connection, table_name, schema, **kwargs):
+            raise exc.NoSuchTableError()
 
         result = connection.execute(
             sql.text(
@@ -736,6 +810,8 @@ def get_table_oid(self, connection, table_name, schema=None, **kwargs):
 
     def get_table_comment(self, connection, table_name, schema=None, **kwargs):
         schema = schema or self.default_schema_name
+        if not self.has_table(connection, table_name, schema, **kwargs):
+            raise exc.NoSuchTableError()
 
         result = connection.execute(
             sql.text(
diff --git a/sqlalchemy_hana/requirements.py b/sqlalchemy_hana/requirements.py
index eec6f82..e83d2b7 100644
--- a/sqlalchemy_hana/requirements.py
+++ b/sqlalchemy_hana/requirements.py
@@ -1,54 +1,17 @@
 from __future__ import annotations
 
-import sys
-
-import sqlalchemy
 from sqlalchemy.testing import exclusions, requirements
 
 
 class Requirements(requirements.SuiteRequirements):
-    @property
-    def temporary_tables(self):
-        return exclusions.open()
-
-    @property
-    def temp_table_reflection(self):
-        return exclusions.open()
-
     @property
     def views(self):
         return exclusions.open()
 
-    @property
-    def deferrable_or_no_constraints(self):
-        """Target database must support derferable constraints."""
-        return exclusions.closed()
-
-    @property
-    def named_constraints(self):
-        return exclusions.open()
-
-    @property
-    def unique_constraint_reflection(self):
-        return exclusions.open()
-
     @property
     def reflects_pk_names(self):
         return exclusions.open()
 
-    @property
-    def self_referential_foreign_keys(self):
-        return exclusions.open()
-
-    @property
-    def empty_inserts(self):
-        """Empty value tuple in INSERT statement is not allowed"""
-        return exclusions.closed()
-
-    @property
-    def precision_numerics_enotation_large(self):
-        return exclusions.open()
-
     @property
     def precision_numerics_many_significant_digits(self):
         return exclusions.open()
@@ -64,13 +27,9 @@ def datetime_literals(self):
 
     @property
     def time_microseconds(self):
-        """No support for microseconds in datetime"""
+        # SAP HANA does not support microseconds in TIME
        return exclusions.closed()
 
-    @property
-    def datetime_microseconds(self):
-        return exclusions.open()
-
     @property
     def datetime_historic(self):
         return exclusions.open()
@@ -79,234 +38,179 @@ def datetime_historic(self):
     def date_historic(self):
         return exclusions.open()
 
-    @property
-    def text_type(self):
-        """Currently not supported by PYHDB"""
-        return exclusions.open()
-
-    @property
-    def schemas(self):
-        return exclusions.open()
-
     @property
     def percent_schema_names(self):
         return exclusions.open()
 
-    @property
-    def savepoints(self):
-        """No support for savepoints in transactions"""
-        return exclusions.closed()
-
     @property
     def selectone(self):
-        """HANA doesn't support 'SELECT 1' without 'FROM DUMMY'"""
-        return exclusions.closed()
-
-    @property
-    def order_by_col_from_union(self):
-        return exclusions.open()
-
-    @property
-    def broken_cx_oracle6_numerics(self):
-        return exclusions.closed()
-
-    @property
-    def mysql_zero_date(self):
-        return exclusions.closed()
-
-    @property
-    def mysql_non_strict(self):
+        # SAP HANA doesn't support 'SELECT 1' without 'FROM DUMMY'
         return exclusions.closed()
 
     @property
     def two_phase_transactions(self):
-        """Not supported by PYHDB"""
-        return exclusions.closed()
-
-    @property
-    def predictable_gc(self):
         return exclusions.open()
 
     @property
-    def cpython(self):
+    def autoincrement_without_sequence(self):
+        # Not supported yet
         return exclusions.closed()
 
     @property
-    def python3(self):
-        if sys.version_info < (3,):
-            return exclusions.closed()
+    def isolation_level(self):
         return exclusions.open()
 
-    @property
-    def identity(self):
-        return exclusions.closed()
-
-    @property
-    def sane_rowcount(self):
-        return exclusions.closed()
+    def get_isolation_levels(self, config):
+        return {
+            "default": "READ COMMITTED",
+            "supported": [
+                "READ COMMITTED",
+                "SERIALIZABLE",
+                "REPEATABLE READ",
+                "AUTOCOMMIT",
+            ],
+        }
 
     @property
-    def sane_multi_rowcount(self):
-        return exclusions.closed()
+    def autocommit(self):
+        return exclusions.open()
 
     @property
-    def check_constraints(self):
+    def comment_reflection(self):
         return exclusions.open()
 
     @property
-    def update_nowait(self):
-        return exclusions.closed()
+    def sequences_optional(self):
+        return exclusions.open()
 
     @property
-    def independent_connections(self):
+    def timestamp_microseconds(self):
         return exclusions.open()
 
     @property
-    def non_broken_pickle(self):
-        return exclusions.closed()
+    def temp_table_names(self):
+        return exclusions.open()
 
     @property
-    def independent_cursors(self):
+    def tuple_in(self):
         return exclusions.open()
 
     @property
-    def cross_schema_fk_reflection(self):
+    def foreign_key_constraint_option_reflection_ondelete(self):
+        # TODO fix
         return exclusions.closed()
 
     @property
-    def updateable_autoincrement_pks(self):
+    def foreign_key_constraint_option_reflection_onupdate(self):
+        # TODO fix
         return exclusions.closed()
 
     @property
-    def bound_limit_offset(self):
-        return exclusions.open()
-
-    @property
-    def isolation_level(self):
-        return exclusions.open()
-
-    def get_isolation_levels(self, config):
-        return {
-            "default": "READ COMMITTED",
-            "supported": [
-                "READ COMMITTED",
-                "SERIALIZABLE",
-                "REPEATABLE READ",
-                "AUTOCOMMIT",
-            ],
-        }
-
-    # Disable mysql tests
-    @property
-    def mssql_freetds(self):
+    def check_constraint_reflection(self):
+        # TODO fix
         return exclusions.closed()
 
-    # Disable postgresql tests
     @property
-    def postgresql_utf8_server_encoding(self):
+    def expressions_against_unbounded_text(self):
+        # not supported by SAP HANA
         return exclusions.closed()
 
     @property
-    def range_types(self):
+    def independent_readonly_connections(self):
+        # TODO check if supported
         return exclusions.closed()
 
     @property
-    def hstore(self):
+    def sql_expression_limit_offset(self):
+        # SAP HANA does not support expressions in LIMIT or OFFSET
         return exclusions.closed()
 
     @property
     def array_type(self):
+        # Not yet supported, #119
         return exclusions.closed()
 
     @property
-    def psycopg2_compatibility(self):
+    def unbounded_varchar(self):
+        # SAP HANA requires a length for (N)VARCHAR
         return exclusions.closed()
 
     @property
-    def postgresql_jsonb(self):
-        return exclusions.closed()
+    def unique_index_reflect_as_unique_constraints(self):
+        # SAP HANA reflects unique indexes as unique constraints
+        return exclusions.open()
 
     @property
-    def savepoints_w_release(self):
-        return exclusions.closed()
+    def unique_constraints_reflect_as_index(self):
+        # SAP HANA reflects unique constraints as indexes
+        return exclusions.open()
 
     @property
-    def non_broken_binary(self):
-        return exclusions.closed()
+    def intersect(self):
+        return exclusions.open()
 
     @property
-    def oracle5x(self):
-        return exclusions.closed()
+    def except_(self):
+        return exclusions.open()
 
     @property
-    def psycopg2_or_pg8000_compatibility(self):
-        return exclusions.closed()
+    def window_functions(self):
+        return exclusions.open()
 
     @property
-    def psycopg2_native_hstore(self):
-        return exclusions.closed()
+    def comment_reflection_full_unicode(self):
+        return exclusions.open()
 
     @property
-    def psycopg2_native_json(self):
-        return exclusions.closed()
+    def foreign_key_constraint_name_reflection(self):
+        return exclusions.open()
 
     @property
-    def two_phase_recovery(self):
-        return exclusions.closed()
+    def cross_schema_fk_reflection(self):
+        return exclusions.open()
 
     @property
-    def enforces_check_constraints(self):
+    def fk_constraint_option_reflection_onupdate_restrict(self):
+        # TODO fix
         return exclusions.closed()
 
     @property
-    def implicitly_named_constraints(self):
-        return exclusions.open()
+    def fk_constraint_option_reflection_ondelete_restrict(self):
+        # TODO fix
+        return exclusions.closed()
 
     @property
-    def autocommit(self):
+    def schema_create_delete(self):
         return exclusions.open()
 
     @property
-    def comment_reflection(self):
+    def savepoints(self):
         return exclusions.open()
 
     @property
-    def sequences_optional(self):
+    def has_temp_table(self):
         return exclusions.open()
 
     @property
-    def timestamp_microseconds(self):
+    def unicode_ddl(self):
         return exclusions.open()
 
     @property
-    def temp_table_names(self):
+    def update_from(self):
         return exclusions.open()
 
     @property
-    def tuple_in(self):
+    def delete_from(self):
         return exclusions.open()
 
     @property
-    def foreign_key_constraint_option_reflection(self):
+    def mod_operator_as_percent_sign(self):
         return exclusions.open()
 
     @property
-    def check_constraint_reflection(self):
-        if sqlalchemy.__version__.startswith("1.1."):
-            # Skip reflection tests in SQLAlchemy~=1.1.0 due missing normalization
-            return exclusions.closed()
+    def order_by_label_with_expression(self):
         return exclusions.open()
 
     @property
-    def implicit_decimal_binds(self):
-        # See SQLAlchemy ticket 4036
-        return exclusions.closed()
-
-    @property
-    def expressions_against_unbounded_text(self):
-        return exclusions.closed()
-
-    @property
-    def temporary_views(self):
-        # SAP HANA doesn't support temporary views only temporary tables.
-        return exclusions.closed()
+    def graceful_disconnects(self):
+        return exclusions.open()
diff --git a/test.cfg b/test.cfg
new file mode 100644
index 0000000..8d1ce9e
--- /dev/null
+++ b/test.cfg
@@ -0,0 +1,3 @@
+[sqla_testing]
+requirement_cls = sqlalchemy_hana.requirements:Requirements
+profile_file = .profiles.txt
diff --git a/test/ci_setup.py b/test/ci_setup.py
new file mode 100644
index 0000000..e4d32e6
--- /dev/null
+++ b/test/ci_setup.py
@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+import random
+import string
+import sys
+from contextlib import closing
+from urllib.parse import urlsplit
+
+from hdbcli import dbapi
+
+
+def random_string(length: int) -> str:
+    return "".join(
+        random.choices(
+            string.ascii_uppercase + string.ascii_lowercase + string.digits, k=length
+        )
+    )
+
+
+def setup(dburi: str) -> str:
+    url = urlsplit(dburi)
+    user = f"PYTEST_{random_string(10)}"
+    password = random_string(15)
+
+    with closing(
+        dbapi.connect(url.hostname, url.port, url.username, url.password)
+    ) as connection, closing(connection.cursor()) as cursor:
+        cursor.execute(
+            f'CREATE USER {user} PASSWORD "{password}" NO FORCE_FIRST_PASSWORD_CHANGE'
+        )
+        for schema in ["TEST_SCHEMA", "TEST_SCHEMA_2"]:
+            cursor.execute(f"SELECT 1 FROM SCHEMAS WHERE SCHEMA_NAME='{schema}'")
+            if cursor.fetchall():
+                cursor.execute(f"DROP SCHEMA {schema} CASCADE")
+            cursor.execute(f"CREATE SCHEMA {schema}")
+            cursor.execute(f"GRANT ALL PRIVILEGES ON SCHEMA {schema} TO {user}")
+        cursor.execute(f"GRANT CREATE SCHEMA TO {user}")
+
+    return f"hana://{user}:{password}@{url.hostname}:{url.port}"
+
+
+def teardown(dburi: str, test_dburi: str) -> None:
+    url = urlsplit(dburi)
+    test_user = urlsplit(test_dburi).username
+
+    with closing(
+        dbapi.connect(url.hostname, url.port, url.username, url.password)
+    ) as connection, closing(connection.cursor()) as cursor:
+        cursor.execute(f"DROP USER {test_user} CASCADE")
+
+
+if __name__ == "__main__":
+    if sys.argv[1] == "setup":
+        print(setup(sys.argv[2]))
+    elif sys.argv[1] == "teardown":
+        teardown(sys.argv[2], sys.argv[3])
+    else:
+        raise ValueError(f"Unknown mode {sys.argv[1]}")
diff --git a/test/conftest.py b/test/conftest.py
index d23c006..3358fbb 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -2,6 +2,7 @@
 
 import logging
 
+import pytest
 from sqlalchemy import Column, Sequence, event
 from sqlalchemy.dialects import registry
 
@@ -9,6 +10,7 @@
 
 registry.register("hana", "sqlalchemy_hana.dialect", "HANAHDBCLIDialect")
 registry.register("hana.hdbcli", "sqlalchemy_hana.dialect", "HANAHDBCLIDialect")
+pytest.register_assert_rewrite("sqlalchemy.testing.assertions")
 
 
 @event.listens_for(Column, "after_parent_attach")
@@ -18,4 +20,4 @@ def add_test_seq(column, table):
 
 
 # enable the SQLAlchemy plugin after our setup is done
-import sqlalchemy.testing.plugin.pytestplugin  # noqa: F401,E402
+from sqlalchemy.testing.plugin.pytestplugin import *  # noqa: F403,F401,E402
diff --git a/test/test_hana_connect_url.py b/test/test_hana_connect_url.py
index 3d74a23..a8642d8 100644
--- a/test/test_hana_connect_url.py
+++ b/test/test_hana_connect_url.py
@@ -5,7 +5,6 @@
 
 
 class HANAConnectUrlWithTenantTest(sqlalchemy.testing.fixtures.TestBase):
-    @sqlalchemy.testing.only_on("hana+hdbcli")
     def test_hdbcli_tenant_url_default_port(self):
         """If the URL includes a tenant database, the dialect pass the adjusted values
         to hdbcli.
@@ -23,7 +22,6 @@ def test_hdbcli_tenant_url_default_port(self):
         assert result_kwargs["password"] == "secret-password"
         assert result_kwargs["databaseName"] == "TENANT_NAME"
 
-    @sqlalchemy.testing.only_on("hana+hdbcli")
     def test_hdbcli_tenant_url_changed_port(self):
         """If the URL includes a tenant database, the dialect pass the adjusted values
         to hdbcli.
@@ -41,7 +39,6 @@ def test_hdbcli_tenant_url_changed_port(self):
 
 
 class HANAConnectUrlWithHDBUserStoreTest(sqlalchemy.testing.fixtures.TestBase):
-    @sqlalchemy.testing.only_on("hana+hdbcli")
     def test_parsing_userkey_hdbcli(self):
         """With HDBCLI, the user may reference to a local HDBUserStore key which holds
         the connection details. SQLAlchemy-HANA should only pass the userkey name to
@@ -55,7 +52,6 @@ def test_parsing_userkey_hdbcli(self):
 
 
 class HANAConnectUrlParsing(sqlalchemy.testing.fixtures.TestBase):
-    @sqlalchemy.testing.only_on("hana+hdbcli")
     def test_pass_uri_query_as_kwargs(self):
         """SQLAlchemy-HANA should passes all URL parameters to hdbcli."""
 
diff --git a/test/test_hana_connection.py b/test/test_hana_connection.py
index 3ee76c4..a486946 100644
--- a/test/test_hana_connection.py
+++ b/test/test_hana_connection.py
@@ -1,15 +1,13 @@
 from __future__ import annotations
 
+from unittest.mock import Mock
+
 import sqlalchemy.testing
-from sqlalchemy.testing.mock import Mock
+from hdbcli.dbapi import Error
 
 
 class HANAHDBCLIConnectionIsDisconnectedTest(sqlalchemy.testing.fixtures.TestBase):
-    __only_on__ = "hana+hdbcli"
-
     def test_detection_by_error_code(self):
-        from hdbcli.dbapi import Error
-
         dialect = sqlalchemy.testing.db.dialect
 
         assert dialect.is_disconnect(Error(-10709, "Connect failed"), None, None)
diff --git a/test/test_hana_sql.py b/test/test_hana_sql.py
index 075872f..d257678 100644
--- a/test/test_hana_sql.py
+++ b/test/test_hana_sql.py
@@ -7,8 +7,6 @@
 class HANACompileTest(
     sqlalchemy.testing.fixtures.TestBase, sqlalchemy.testing.AssertsCompiledSQL
 ):
-    __only_on__ = "hana"
-
     def test_sql_with_for_update(self):
         table1 = table("mytable", column("myid"), column("name"), column("description"))
 
diff --git a/test/test_isolation_level.py b/test/test_isolation_level.py
index ad302f3..1d95827 100644
--- a/test/test_isolation_level.py
+++ b/test/test_isolation_level.py
@@ -6,8 +6,6 @@
 
 
 class IsolationLevelTest(sqlalchemy.testing.fixtures.TestBase):
-    __only_on__ = "hana"
-
     def _default_isolation_level(self):
         return "READ COMMITTED"
 
diff --git a/test/test_sqlalchemy_dialect_suite.py b/test/test_sqlalchemy_dialect_suite.py
index 8e431f0..d292d8b 100644
--- a/test/test_sqlalchemy_dialect_suite.py
+++ b/test/test_sqlalchemy_dialect_suite.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
-import sqlalchemy
-from sqlalchemy.testing.suite import ComponentReflectionTest as _ComponentReflectionTest
+from sqlalchemy.testing.provision import temp_table_keyword_args
 from sqlalchemy.testing.suite import *  # noqa: F401, F403
 
 # Import dialect test suite provided by SQLAlchemy into SQLAlchemy-HANA test collection.
@@ -9,31 +8,8 @@
 # for compatibility with SAP HANA.
 
 
-class ComponentReflectionTest(_ComponentReflectionTest):
-    # Overwrite function so that temporary tables are correctly created with HANA's specific
-    # GLOBAL prefix.
-    @classmethod
-    def define_temp_tables(cls, metadata):
-        kw = {
-            "prefixes": ["GLOBAL TEMPORARY"],
-        }
-
-        sqlalchemy.Table(
-            "user_tmp",
-            metadata,
-            sqlalchemy.Column("id", sqlalchemy.INT, primary_key=True),
-            sqlalchemy.Column("name", sqlalchemy.VARCHAR(50)),
-            sqlalchemy.Column("foo", sqlalchemy.INT),
-            sqlalchemy.UniqueConstraint("name", name="user_tmp_uq"),
-            sqlalchemy.Index("user_tmp_ix", "foo"),
-            **kw,
-        )
-
-    # Overwrite function as SQLAlchemy assumes only the PostgreSQL dialect allows to retrieve the
-    # table object id but SQLAlchemy-HANA also provides this functionality.
-    @sqlalchemy.testing.provide_metadata
-    def _test_get_table_oid(self, table_name, schema=None):
-        meta = self.metadata
-        insp = sqlalchemy.inspect(meta.bind)
-        oid = insp.get_table_oid(table_name, schema)
-        self.assert_(isinstance(oid, int))
+@temp_table_keyword_args.for_db("*")
+def _temp_table_keyword_args(*args, **kwargs):
+    return {
+        "prefixes": ["GLOBAL TEMPORARY"],
+    }