Skip to content

Commit

Permalink
refactor(druid): port to sqlglot
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Feb 12, 2024
1 parent c65d9f8 commit 85e2b16
Show file tree
Hide file tree
Showing 33 changed files with 627 additions and 649 deletions.
28 changes: 14 additions & 14 deletions .github/workflows/ibis-backends.yml
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,12 @@ jobs:
- trino
services:
- trino
# - name: druid
# title: Druid
# extras:
# - druid
# services:
# - druid
- name: druid
title: Druid
extras:
- druid
services:
- druid
# - name: oracle
# title: Oracle
# serial: true
Expand Down Expand Up @@ -264,14 +264,14 @@ jobs:
- trino
extras:
- trino
# - os: windows-latest
# backend:
# name: druid
# title: Druid
# extras:
# - druid
# services:
# - druid
- os: windows-latest
backend:
name: druid
title: Druid
extras:
- druid
services:
- druid
# - os: windows-latest
# backend:
# name: oracle
Expand Down
2 changes: 1 addition & 1 deletion docker/druid/environment
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ druid_extensions_loadList=["postgresql-metadata-storage", "druid-multi-stage-que
druid_zk_service_host=zookeeper

druid_worker_capacity=6
druid_generic_useDefaultValueForNull=true
druid_generic_useDefaultValueForNull=false

druid_metadata_storage_host=
druid_metadata_storage_type=postgresql
Expand Down
227 changes: 134 additions & 93 deletions ibis/backends/druid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,101 +4,104 @@

import contextlib
import json
import warnings
from typing import TYPE_CHECKING, Any
from urllib.parse import parse_qs, urlparse

import sqlalchemy as sa
import pydruid
import sqlglot as sg

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend
import ibis.expr.schema as sch
from ibis.backends.base.sqlglot import SQLGlotBackend
from ibis.backends.base.sqlglot.compiler import STAR
from ibis.backends.base.sqlglot.datatypes import DruidType
from ibis.backends.druid.compiler import DruidCompiler
from ibis.backends.druid.datatypes import (
DruidBinary,
DruidDateTime,
DruidString,
DruidType,
)

if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Iterable, Mapping

import pandas as pd
import pyarrow as pa

class Backend(BaseAlchemyBackend):
import ibis.expr.types as ir


class Backend(SQLGlotBackend):
name = "druid"
compiler = DruidCompiler
compiler = DruidCompiler()
supports_create_or_replace = False
supports_in_memory_tables = True

@property
def current_database(self) -> str:
# https://druid.apache.org/docs/latest/querying/sql-metadata-tables.html#schemata-table
return "druid"
def version(self) -> str:
with self._safe_raw_sql("SELECT version()") as result:
[(version,)] = result.fetchall()
return version

def do_connect(
self,
host: str = "localhost",
port: int = 8082,
database: str | None = "druid/v2/sql",
**_: Any,
) -> None:
"""Create an Ibis client using the passed connection parameters.
def _from_url(self, url: str, **kwargs):
"""Connect to a backend using a URL `url`.
Parameters
----------
host
Hostname
port
Port
database
Database to connect to
url
URL with which to connect to a backend.
kwargs
Additional keyword arguments
Returns
-------
BaseBackend
A backend instance
"""
url = sa.engine.url.make_url(f"druid://{host}:{port}/{database}?header=true")

self.database_name = "default" # not sure what should go here

engine = sa.create_engine(url, poolclass=sa.pool.StaticPool)

super().do_connect(engine)

# workaround a broken pydruid `has_table` implementation
engine.dialect.has_table = self._has_table
url = urlparse(url)
query_params = parse_qs(url.query)
kwargs = {
"user": url.username,
"password": url.password,
"host": url.hostname,
"path": url.path,
"port": url.port,
} | kwargs

for name, value in query_params.items():
if len(value) > 1:
kwargs[name] = value
elif len(value) == 1:
kwargs[name] = value[0]
else:
raise com.IbisError(f"Invalid URL parameter: {name}")

# don't double percent signs
engine.dialect.identifier_preparer._double_percents = False
self._convert_kwargs(kwargs)

@staticmethod
def _new_sa_metadata():
meta = sa.MetaData()
return self.connect(**kwargs)

@sa.event.listens_for(meta, "column_reflect")
def column_reflect(inspector, table, column_info):
if isinstance(typ := column_info["type"], sa.DateTime):
column_info["type"] = DruidDateTime()
elif isinstance(typ, (sa.LargeBinary, sa.BINARY, sa.VARBINARY)):
column_info["type"] = DruidBinary()
elif isinstance(typ, sa.String):
column_info["type"] = DruidString()
@property
def current_database(self) -> str:
# https://druid.apache.org/docs/latest/querying/sql-metadata-tables.html#schemata-table
return "druid"

return meta
def do_connect(self, **kwargs: Any) -> None:
"""Create an Ibis client using the passed connection parameters."""
header = kwargs.pop("header", True)
self.con = pydruid.db.connect(**kwargs, header=header)
self._temp_views = set()

@contextlib.contextmanager
def _safe_raw_sql(self, query, *args, **kwargs):
query = query.compile(
dialect=self.con.dialect, compile_kwargs=dict(literal_binds=True)
)
with contextlib.suppress(AttributeError):
query = query.sql(dialect=self.compiler.dialect)

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Dialect druid:rest will not make use of SQL compilation caching",
category=sa.exc.SAWarning,
)
with self.begin() as con:
yield con.execute(query, *args, **kwargs)
with contextlib.closing(self.con.cursor()) as cur:
cur.execute(query, *args, **kwargs)
yield cur

def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
result = self._scalar_query(f"EXPLAIN PLAN FOR {query}")
with self._safe_raw_sql(f"EXPLAIN PLAN FOR {query}") as result:
[(row, *_)] = result.fetchall()

(plan,) = json.loads(result)
(plan,) = json.loads(row)
for column in plan["signature"]:
name, typ = column["name"], column["type"]
if name == "__time":
Expand All @@ -107,33 +110,71 @@ def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
dtype = DruidType.from_string(typ)
yield name, dtype

def _get_temp_view_definition(
self, name: str, definition: sa.sql.compiler.Compiled
) -> str:
raise NotImplementedError()

def _has_table(self, connection, table_name: str, schema) -> bool:
t = sa.table(
"TABLES", sa.column("TABLE_NAME", sa.TEXT), schema="INFORMATION_SCHEMA"
def get_schema(
self, table_name: str, schema: str | None = None, database: str | None = None
) -> sch.Schema:
name_type_pairs = self._metadata(
sg.select(STAR)
.from_(sg.table(table_name, db=schema, catalog=database))
.sql(self.compiler.dialect)
)
query = sa.select(
sa.func.sum(sa.cast(t.c.TABLE_NAME == table_name, sa.INTEGER))
).compile(dialect=self.con.dialect)

return bool(connection.execute(query).scalar())

def _get_sqla_table(
self, name: str, autoload: bool = True, **kwargs: Any
) -> sa.Table:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="|".join( # noqa: FLY002
(
"Did not recognize type",
"Dialect druid:rest will not make use of SQL compilation caching",
)
),
category=sa.exc.SAWarning,
return sch.Schema.from_tuples(name_type_pairs)

def _fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame:
import pandas as pd

from ibis.formats.pandas import PandasData

try:
df = pd.DataFrame.from_records(
cursor, columns=schema.names, coerce_float=True
)
return super()._get_sqla_table(name, autoload=autoload, **kwargs)
except Exception:
# clean up the cursor if we fail to create the DataFrame
cursor.close()
raise
df = PandasData.convert_table(df, schema)
return df

def create_table(
self,
name: str,
obj: pd.DataFrame | pa.Table | ir.Table | None = None,
*,
schema: sch.Schema | None = None,
database: str | None = None,
temp: bool = False,
overwrite: bool = False,
) -> ir.Table:
raise NotImplementedError()

def list_tables(
self, like: str | None = None, database: str | None = None
) -> list[str]:
t = sg.table("TABLES", db="INFORMATION_SCHEMA", quoted=True)
c = self.compiler
query = sg.select(sg.column("TABLE_NAME", quoted=True)).from_(t).sql(c.dialect)

with self._safe_raw_sql(query) as result:
tables = result.fetchall()
return self._filter_with_like([table.TABLE_NAME for table in tables], like=like)

def _register_in_memory_tables(self, expr):
"""No-op. Table are inlined, for better or worse."""

def _cursor_batches(
self,
expr: ir.Expr,
params: Mapping[ir.Scalar, Any] | None = None,
limit: int | str | None = None,
chunk_size: int = 1 << 20,
) -> Iterable[list]:
self._run_pre_execute_hooks(expr)

dtypes = expr.as_table().schema().values()

with self._safe_raw_sql(
self.compile(expr, limit=limit, params=params)
) as cursor:
while batch := cursor.fetchmany(chunk_size):
yield (tuple(map(dt.normalize, dtypes, row)) for row in batch)
Loading

0 comments on commit 85e2b16

Please sign in to comment.