Skip to content

Commit

Permalink
Add tests to CI (#120)
Browse files Browse the repository at this point in the history
  • Loading branch information
kasium authored Oct 30, 2023
1 parent 4503ce9 commit 9d4f0c5
Show file tree
Hide file tree
Showing 13 changed files with 284 additions and 234 deletions.
30 changes: 30 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
name: CI
concurrency: hana-tests

on:
push:
Expand All @@ -20,3 +21,32 @@ jobs:
run: pip install -e .[dev,test]
- name: run pre-commit
run: "pre-commit run --all"
ci-test:
if: ${{ github.event_name == 'pull_request' }}
strategy:
fail-fast: false
max-parallel: 1
matrix:
python-version: ["3.11"]
sqlalchemy-version: ["1.4.*", "2.0.*"]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install project
run: pip install -e .[test]
- name: Install sqlalchemy
run: pip install sqlalchemy==${{ matrix.sqlalchemy-version }}
- name: run tests (with coverage)
run: |
PYTEST_DBURI=$(python test/ci_setup.py setup ${{ secrets.TEST_DBURI }})
echo "::add-mask::$PYTEST_DBURI"
export PYTEST_ADDOPTS="--dburi $PYTEST_DBURI --dropfirst"
pytest -v --cov sqlalchemy_hana --cov-report html --cov-report xml test/
python test/ci_setup.py teardown ${{ secrets.TEST_DBURI }} $PYTEST_DBURI
- name: run diff-cover
run: "diff-cover --config-file pyproject.toml coverage.xml"
16 changes: 15 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ dependencies = ["sqlalchemy>=1.4.0,<3", "hdbcli"]

[project.optional-dependencies]
dev = ["isort==5.12.0", "black==23.9.1", "pre-commit==3.5.0", "flake8==6.1.0"]
test = ["pytest==7.4.2"]
test = [
"pytest==7.4.2",
"pytest-cov==4.1.0",
"coverage[toml]==7.3.2",
"diff-cover[toml]==8.0.0",
]

[project.entry-points."sqlalchemy.dialects"]
hana = "sqlalchemy_hana.dialect:HANAHDBCLIDialect"
Expand All @@ -55,3 +60,12 @@ swagger_plugin_for_sphinx = ["py.typed"]
[tool.isort]
profile = "black"
add_imports = ["from __future__ import annotations"]

[tool.pytest.ini_options]
log_level = "DEBUG"
xfail_strict = true
filterwarnings = ["ignore"]

[tool.diff_cover]
include_untracked = true
fail_under = 80
3 changes: 0 additions & 3 deletions setup.cfg

This file was deleted.

112 changes: 94 additions & 18 deletions sqlalchemy_hana/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
"sql",
"start",
"sysuuid",
"table",
"tablesample",
"top",
"trailing",
Expand Down Expand Up @@ -121,9 +122,9 @@ def visit_bindparam(self, bindparam, **kwargs):
return super(HANAStatementCompiler, self).visit_bindparam(bindparam, **kwargs)

def visit_sequence(self, seq, **kwargs):
return self.dialect.identifier_preparer.format_sequence(seq) + ".NEXTVAL"
return self.preparer.format_sequence(seq) + ".NEXTVAL"

def visit_empty_set_expr(self, element_types):
def visit_empty_set_expr(self, element_types, **kwargs):
return "SELECT %s FROM DUMMY WHERE 1 != 1" % (
", ".join(["1" for _ in element_types])
)
Expand Down Expand Up @@ -197,6 +198,12 @@ def visit_isnot_distinct_from_binary(self, binary, operator, **kw):
f"({left} IS NULL AND {right} IS NULL))"
)

def visit_is_true_unary_operator(self, element, operator, **kw):
return "%s = TRUE" % self.process(element.element, **kw)

def visit_is_false_unary_operator(self, element, operator, **kw):
return "%s = FALSE" % self.process(element.element, **kw)


class HANATypeCompiler(compiler.GenericTypeCompiler):
def visit_NUMERIC(self, type_):
Expand Down Expand Up @@ -225,7 +232,7 @@ def visit_unicode_text(self, type_, **kwargs):


class HANADDLCompiler(compiler.DDLCompiler):
def visit_unique_constraint(self, constraint):
def visit_unique_constraint(self, constraint, **kwargs):
if len(constraint) == 0:
return ""

Expand All @@ -240,7 +247,7 @@ def visit_unique_constraint(self, constraint):
text += self.define_constraint_deferrability(constraint)
return text

def visit_create_table(self, create):
def visit_create_table(self, create, **kwargs):
table = create.element

# The table._prefixes list outlives the current compilation, meaning changing the list
Expand All @@ -267,13 +274,13 @@ def visit_create_table(self, create):

class HANAExecutionContext(default.DefaultExecutionContext):
def fire_sequence(self, seq, type_):
seq = self.dialect.identifier_preparer.format_sequence(seq)
seq = self.identifier_preparer.format_sequence(seq)
return self._execute_scalar("SELECT %s.NEXTVAL FROM DUMMY" % seq, type_)


class HANAInspector(reflection.Inspector):
def get_table_oid(self, table_name, schema=None):
return self.dialect.get_table_oid(
return self.get_table_oid(
self.bind, table_name, schema, info_cache=self.info_cache
)

Expand Down Expand Up @@ -320,6 +327,10 @@ class HANABaseDialect(default.DefaultDialect):
supports_default_values = False
supports_sane_multi_rowcount = False
isolation_level = None
div_is_floordiv = False
supports_schemas = True
supports_sane_rowcount = False
supports_is_distinct_from = False

max_identifier_length = 127

Expand Down Expand Up @@ -407,21 +418,51 @@ def denormalize_name(self, name):
name = name.upper()
return name

def has_table(self, connection, table_name, schema=None):
@reflection.cache
def has_table(self, connection, table_name, schema=None, **kwargs):
schema = schema or self.default_schema_name

result = connection.execute(
sql.text(
"SELECT 1 FROM SYS.TABLES "
"WHERE SCHEMA_NAME=:schema AND TABLE_NAME=:table",
"WHERE SCHEMA_NAME=:schema AND TABLE_NAME=:table "
"UNION ALL "
"SELECT 1 FROM SYS.VIEWS "
"WHERE SCHEMA_NAME=:schema AND VIEW_NAME=:table ",
).bindparams(
schema=self.denormalize_name(schema),
table=self.denormalize_name(table_name),
)
)
return bool(result.first())

@reflection.cache
def has_schema(self, connection, schema_name, **kwargs):
result = connection.execute(
sql.text(
"SELECT 1 FROM SYS.SCHEMAS WHERE SCHEMA_NAME=:schema",
).bindparams(schema=self.denormalize_name(schema_name))
)
return bool(result.first())

@reflection.cache
def has_index(self, connection, table_name, index_name, schema=None, **kwargs):
schema = schema or self.default_schema_name

result = connection.execute(
sql.text(
"SELECT 1 FROM SYS.INDEXES "
"WHERE SCHEMA_NAME=:schema AND TABLE_NAME=:table AND INDEX_NAME=:index"
).bindparams(
schema=self.denormalize_name(schema),
table=self.denormalize_name(table_name),
index=self.denormalize_name(index_name),
)
)
return bool(result.first())

def has_sequence(self, connection, sequence_name, schema=None):
@reflection.cache
def has_sequence(self, connection, sequence_name, schema=None, **kwargs):
schema = schema or self.default_schema_name
result = connection.execute(
sql.text(
Expand All @@ -434,11 +475,13 @@ def has_sequence(self, connection, sequence_name, schema=None):
)
return bool(result.first())

@reflection.cache
def get_schema_names(self, connection, **kwargs):
result = connection.execute(sql.text("SELECT SCHEMA_NAME FROM SYS.SCHEMAS"))

return list([self.normalize_name(name) for name, in result.fetchall()])

@reflection.cache
def get_table_names(self, connection, schema=None, **kwargs):
schema = schema or self.default_schema_name

Expand Down Expand Up @@ -487,8 +530,7 @@ def get_view_names(self, connection, schema=None, **kwargs):

def get_view_definition(self, connection, view_name, schema=None, **kwargs):
schema = schema or self.default_schema_name

return connection.execute(
result = connection.execute(
sql.text(
"SELECT DEFINITION FROM SYS.VIEWS "
"WHERE VIEW_NAME=:view_name AND SCHEMA_NAME=:schema LIMIT 1",
Expand All @@ -498,8 +540,14 @@ def get_view_definition(self, connection, view_name, schema=None, **kwargs):
)
).scalar()

if result is None:
raise exc.NoSuchTableError()
return result

def get_columns(self, connection, table_name, schema=None, **kwargs):
schema = schema or self.default_schema_name
if not self.has_table(connection, table_name, schema, **kwargs):
raise exc.NoSuchTableError()

result = connection.execute(
sql.text(
Expand Down Expand Up @@ -550,8 +598,22 @@ def get_columns(self, connection, table_name, schema=None, **kwargs):

return columns

@reflection.cache
def get_sequence_names(self, connection, schema=None, **kwargs):
schema = schema or self.default_schema_name

result = connection.execute(
sql.text(
"SELECT SEQUENCE_NAME FROM SYS.SEQUENCES "
"WHERE SCHEMA_NAME=:schema ORDER BY SEQUENCE_NAME"
).bindparams(schema=self.denormalize_name(schema))
)
return [self.normalize_name(row[0]) for row in result]

def get_foreign_keys(self, connection, table_name, schema=None, **kwargs):
lookup_schema = schema or self.default_schema_name
if not self.has_table(connection, table_name, lookup_schema, **kwargs):
raise exc.NoSuchTableError()

result = connection.execute(
sql.text(
Expand Down Expand Up @@ -579,7 +641,7 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kwargs):
foreign_key = {
"name": foreign_key_name,
"constrained_columns": [self.normalize_name(row[1])],
"referred_schema": schema,
"referred_schema": None,
"referred_table": self.normalize_name(row[3]),
"referred_columns": [self.normalize_name(row[4])],
"options": {"onupdate": row[5], "ondelete": row[6]},
Expand All @@ -591,10 +653,12 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kwargs):
foreign_keys[foreign_key_name] = foreign_key
foreign_keys_list.append(foreign_key)

return foreign_keys_list
return sorted(foreign_keys_list, key=lambda foreign_key: foreign_key["name"])

def get_indexes(self, connection, table_name, schema=None, **kwargs):
schema = schema or self.default_schema_name
if not self.has_table(connection, table_name, schema, **kwargs):
raise exc.NoSuchTableError()

result = connection.execute(
sql.text(
Expand All @@ -610,10 +674,11 @@ def get_indexes(self, connection, table_name, schema=None, **kwargs):

indexes = {}
for name, column, constraint in result.fetchall():
if name.startswith("_SYS"):
if constraint == "PRIMARY KEY":
continue

name = self.normalize_name(name)
if not name.startswith("_SYS"):
name = self.normalize_name(name)
column = self.normalize_name(column)

if name not in indexes:
Expand All @@ -629,10 +694,12 @@ def get_indexes(self, connection, table_name, schema=None, **kwargs):
else:
indexes[name]["column_names"].append(column)

return list(indexes.values())
return sorted(list(indexes.values()), key=lambda index: index["name"])

def get_pk_constraint(self, connection, table_name, schema=None, **kwargs):
schema = schema or self.default_schema_name
if not self.has_table(connection, table_name, schema, **kwargs):
raise exc.NoSuchTableError()

result = connection.execute(
sql.text(
Expand All @@ -659,12 +726,14 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kwargs):

def get_unique_constraints(self, connection, table_name, schema=None, **kwargs):
schema = schema or self.default_schema_name
if not self.has_table(connection, table_name, schema, **kwargs):
raise exc.NoSuchTableError()

result = connection.execute(
sql.text(
"SELECT CONSTRAINT_NAME, COLUMN_NAME FROM SYS.CONSTRAINTS "
"WHERE SCHEMA_NAME=:schema AND TABLE_NAME=:table AND "
"IS_UNIQUE_KEY='TRUE' AND IS_PRIMARY_KEY='FALSE'"
"IS_UNIQUE_KEY='TRUE' AND IS_PRIMARY_KEY='FALSE' "
"ORDER BY CONSTRAINT_NAME, POSITION"
).bindparams(
schema=self.denormalize_name(schema),
Expand Down Expand Up @@ -693,10 +762,15 @@ def get_unique_constraints(self, connection, table_name, schema=None, **kwargs):
constraints.append(constraint)
constraint["column_names"].append(self.normalize_name(column_name))

return constraints
return sorted(
constraints,
key=lambda constraint: (constraint["name"] is not None, constraint["name"]),
)

def get_check_constraints(self, connection, table_name, schema=None, **kwargs):
schema = schema or self.default_schema_name
if not self.has_table(connection, table_name, schema, **kwargs):
raise exc.NoSuchTableError()

result = connection.execute(
sql.text(
Expand Down Expand Up @@ -736,6 +810,8 @@ def get_table_oid(self, connection, table_name, schema=None, **kwargs):

def get_table_comment(self, connection, table_name, schema=None, **kwargs):
schema = schema or self.default_schema_name
if not self.has_table(connection, table_name, schema, **kwargs):
raise exc.NoSuchTableError()

result = connection.execute(
sql.text(
Expand Down
Loading

0 comments on commit 9d4f0c5

Please sign in to comment.