Skip to content

Commit

Permalink
Switch to new SQLAlchemy dialect for CrateDB
Browse files Browse the repository at this point in the history
This includes the fix to the `get_table_names()` reflection method.
  • Loading branch information
amotl committed Jun 13, 2024
1 parent 0741181 commit f095f28
Show file tree
Hide file tree
Showing 4 changed files with 1 addition and 40 deletions.
1 change: 0 additions & 1 deletion cratedb_toolkit/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from .patch import patch_inspector
from .polyfill import check_uniqueness_factory, polyfill_autoincrement, polyfill_refresh_after_dml, refresh_table
37 changes: 0 additions & 37 deletions cratedb_toolkit/sqlalchemy/patch.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import calendar
import datetime as dt
import json
import typing as t
from decimal import Decimal
from uuid import UUID

import sqlalchemy as sa

try:
import numpy as np

Expand All @@ -15,40 +12,6 @@
has_numpy = False


def patch_inspector():
"""
When using `get_table_names()`, make sure the correct schema name gets used.
Apparently, SQLAlchemy does not honor the `search_path` of the engine, when
using the inspector?
FIXME: Bug in CrateDB SQLAlchemy dialect?
"""

def get_effective_schema(engine: sa.Engine):
schema_name_raw = engine.url.query.get("schema")
schema_name = None
if isinstance(schema_name_raw, str):
schema_name = schema_name_raw
elif isinstance(schema_name_raw, tuple):
schema_name = schema_name_raw[0]
return schema_name

try:
from sqlalchemy_cratedb import dialect
except ImportError: # pragma: nocover
from crate.client.sqlalchemy.dialect import CrateDialect as dialect

get_table_names_dist = dialect.get_table_names

def get_table_names(self, connection: sa.Connection, schema: t.Optional[str] = None, **kw: t.Any) -> t.List[str]:
if schema is None:
schema = get_effective_schema(connection.engine)
return get_table_names_dist(self, connection=connection, schema=schema, **kw)

dialect.get_table_names = get_table_names # type: ignore


def patch_encoder():
import crate.client.http

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ dependencies = [
"colorama<1",
"colorlog",
"crash",
"croud==1.8",
"fastapi<0.105",
'importlib-metadata; python_version < "3.8"',
'importlib-resources; python_version < "3.9"',
Expand Down
2 changes: 0 additions & 2 deletions tests/sqlalchemy/test_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pytest
import sqlalchemy as sa

from cratedb_toolkit.sqlalchemy import patch_inspector
from cratedb_toolkit.sqlalchemy.patch import CrateJsonEncoderWithNumPy
from tests.conftest import TESTDRIVE_DATA_SCHEMA

Expand Down Expand Up @@ -37,7 +36,6 @@ def test_inspector_patched(database):
This verifies that it still works, when it properly has been assigned to
the `?schema=` connection string URL parameter.
"""
patch_inspector()
tablename = f'"{TESTDRIVE_DATA_SCHEMA}"."foobar"'
inspector: sa.Inspector = sa.inspect(database.engine)
database.run_sql(f"CREATE TABLE {tablename} AS SELECT 1")
Expand Down

0 comments on commit f095f28

Please sign in to comment.