fix(clickhouse): add clickhouse connect driver (apache#23185)
villebro authored Feb 24, 2023
1 parent f0f27a4 commit d0c54cd
Showing 4 changed files with 454 additions and 27 deletions.
1 change: 0 additions & 1 deletion superset/db_engine_specs/base.py
@@ -72,7 +72,6 @@
from superset.utils.network import is_hostname_valid, is_port_open

if TYPE_CHECKING:
# prevent circular imports
from superset.connectors.sqla.models import TableColumn
from superset.models.core import Database
from superset.models.sql_lab import Query
314 changes: 292 additions & 22 deletions superset/db_engine_specs/clickhouse.py
@@ -14,29 +14,43 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
from typing import Any, cast, Dict, List, Optional, Type, TYPE_CHECKING

from flask import current_app
from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.validate import Range
from sqlalchemy import types
from sqlalchemy.engine.url import URL
from urllib3.exceptions import NewConnectionError

from superset.db_engine_specs.base import BaseEngineSpec
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import (
BaseEngineSpec,
BasicParametersMixin,
BasicParametersType,
BasicPropertiesType,
)
from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.extensions import cache_manager
from superset.utils.core import GenericDataType
from superset.utils.hashing import md5_sha_from_str
from superset.utils.network import is_hostname_valid, is_port_open

if TYPE_CHECKING:
# prevent circular imports
from superset.models.core import Database

logger = logging.getLogger(__name__)


class ClickHouseEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
"""Dialect for ClickHouse analytical DB."""

engine = "clickhouse"
engine_name = "ClickHouse"
class ClickHouseBaseEngineSpec(BaseEngineSpec):
"""Shared engine spec for ClickHouse."""

time_secondary_columns = True
time_groupby_inline = True
@@ -56,8 +70,78 @@ class ClickHouseEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
"P1Y": "toStartOfYear(toDateTime({col}))",
}

_show_functions_column = "name"
column_type_mappings = (
(
re.compile(r".*Enum.*", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r".*Array.*", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r".*UUID.*", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r".*Bool.*", re.IGNORECASE),
types.Boolean(),
GenericDataType.BOOLEAN,
),
(
re.compile(r".*String.*", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r".*Int\d+.*", re.IGNORECASE),
types.INTEGER(),
GenericDataType.NUMERIC,
),
(
re.compile(r".*Decimal.*", re.IGNORECASE),
types.DECIMAL(),
GenericDataType.NUMERIC,
),
(
re.compile(r".*DateTime.*", re.IGNORECASE),
types.DateTime(),
GenericDataType.TEMPORAL,
),
(
re.compile(r".*Date.*", re.IGNORECASE),
types.Date(),
GenericDataType.TEMPORAL,
),
)

@classmethod
def epoch_to_dttm(cls) -> str:
return "{col}"

@classmethod
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, types.Date):
return f"toDate('{dttm.date().isoformat()}')"
if isinstance(sqla_type, types.DateTime):
return f"""toDateTime('{dttm.isoformat(sep=" ", timespec="seconds")}')"""
return None


class ClickHouseEngineSpec(ClickHouseBaseEngineSpec):
"""Engine spec for clickhouse_sqlalchemy connector"""

engine = "clickhouse"
engine_name = "ClickHouse"

_show_functions_column = "name"
supports_file_upload = False

@classmethod
Expand All @@ -73,21 +157,9 @@ def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception:
return exception
return new_exception(str(exception))

@classmethod
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, types.Date):
return f"toDate('{dttm.date().isoformat()}')"
if isinstance(sqla_type, types.DateTime):
return f"""toDateTime('{dttm.isoformat(sep=" ", timespec="seconds")}')"""
return None

@classmethod
@cache_manager.cache.memoize()
def get_function_names(cls, database: "Database") -> List[str]:
def get_function_names(cls, database: Database) -> List[str]:
"""
Get a list of function names that are able to be called on the database.
Used for SQL Lab autocomplete.
@@ -123,3 +195,201 @@ def get_function_names(cls, database: "Database") -> List[str]:

# otherwise, return no function names to prevent errors
return []


class ClickHouseParametersSchema(Schema):
username = fields.String(allow_none=True, description=__("Username"))
password = fields.String(allow_none=True, description=__("Password"))
host = fields.String(required=True, description=__("Hostname or IP address"))
port = fields.Integer(
allow_none=True,
description=__("Database port"),
validate=Range(min=0, max=65535),
)
database = fields.String(allow_none=True, description=__("Database name"))
encryption = fields.Boolean(
default=True, description=__("Use an encrypted connection to the database")
)
query = fields.Dict(
keys=fields.Str(), values=fields.Raw(), description=__("Additional parameters")
)


try:
from clickhouse_connect.common import set_setting
from clickhouse_connect.datatypes.format import set_default_formats

# override default formats for compatibility
set_default_formats(
"FixedString",
"string",
"IPv*",
"string",
"signed",
"UUID",
"string",
"*Int256",
"string",
"*Int128",
"string",
)
set_setting(
"product_name",
f"superset/{current_app.config.get('VERSION_STRING', 'dev')}",
)
except ImportError: # ClickHouse Connect not installed, do nothing
pass


class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin):
"""Engine spec for clickhouse-connect connector"""

engine = "clickhousedb"
engine_name = "ClickHouse Connect"

default_driver = "connect"
_function_names: List[str] = []

sqlalchemy_uri_placeholder = (
"clickhousedb://user:password@host[:port][/dbname][?secure=value&=value...]"
)
parameters_schema = ClickHouseParametersSchema()
encryption_parameters = {"secure": "true"}

@classmethod
def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
return {}

@classmethod
def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception:
new_exception = cls.get_dbapi_exception_mapping().get(type(exception))
if new_exception == SupersetDBAPIDatabaseError:
return SupersetDBAPIDatabaseError("Connection failed")
if not new_exception:
return exception
return new_exception(str(exception))

@classmethod
def get_function_names(cls, database: Database) -> List[str]:
# pylint: disable=import-outside-toplevel,import-error
from clickhouse_connect.driver.exceptions import ClickHouseError

if cls._function_names:
return cls._function_names
try:
names = database.get_df(
"SELECT name FROM system.functions UNION ALL "
+ "SELECT name FROM system.table_functions LIMIT 10000"
)["name"].tolist()
cls._function_names = names
return names
except ClickHouseError:
logger.exception("Error retrieving system.functions")
return []

@classmethod
def get_datatype(cls, type_code: str) -> str:
# keep it lowercase, as ClickHouse types aren't typical SHOUTCASE ANSI SQL
return type_code

@classmethod
def build_sqlalchemy_uri(
cls,
parameters: BasicParametersType,
encrypted_extra: Optional[Dict[str, str]] = None,
) -> str:
url_params = parameters.copy()
if url_params.get("encryption"):
query = parameters.get("query", {}).copy()
query.update(cls.encryption_parameters)
url_params["query"] = query
if not url_params.get("database"):
url_params["database"] = "__default__"
url_params.pop("encryption", None)
return str(URL(f"{cls.engine}+{cls.default_driver}", **url_params))

@classmethod
def get_parameters_from_uri(
cls, uri: str, encrypted_extra: Optional[Dict[str, Any]] = None
) -> BasicParametersType:
url = make_url_safe(uri)
query = url.query
if "secure" in query:
encryption = url.query.get("secure") == "true"
query.pop("secure")
else:
encryption = False
return BasicParametersType(
username=url.username,
password=url.password,
host=url.host,
port=url.port,
database="" if url.database == "__default__" else cast(str, url.database),
query=dict(query),
encryption=encryption,
)

@classmethod
def validate_parameters(
cls, properties: BasicPropertiesType
) -> List[SupersetError]:
# pylint: disable=import-outside-toplevel,import-error
from clickhouse_connect.driver import default_port

parameters = properties.get("parameters", {})
host = parameters.get("host", None)
if not host:
return [
SupersetError(
"Hostname is required",
SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR,
ErrorLevel.WARNING,
{"missing": ["host"]},
)
]
if not is_hostname_valid(host):
return [
SupersetError(
"The hostname provided can't be resolved.",
SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR,
ErrorLevel.ERROR,
{"invalid": ["host"]},
)
]
port = parameters.get("port")
if port is None:
port = default_port("http", parameters.get("encryption", False))
try:
port = int(port)
except (ValueError, TypeError):
port = -1
if port <= 0 or port >= 65535:
return [
SupersetError(
"Port must be a valid integer between 0 and 65535 (inclusive).",
SupersetErrorType.CONNECTION_INVALID_PORT_ERROR,
ErrorLevel.ERROR,
{"invalid": ["port"]},
)
]
if not is_port_open(host, port):
return [
SupersetError(
"The port is closed.",
SupersetErrorType.CONNECTION_PORT_CLOSED_ERROR,
ErrorLevel.ERROR,
{"invalid": ["port"]},
)
]
return []

@staticmethod
def _mutate_label(label: str) -> str:
"""
Suffix with the first six characters from the md5 of the label to avoid
collisions with original column names
:param label: Expected expression label
:return: Conditionally mutated label
"""
return f"{label}_{md5_sha_from_str(label)[:6]}"
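As a rough illustration of how the new connect-driver form parameters become a SQLAlchemy URI, the sketch below calls the classmethod added in this diff. The host, credentials, and database are hypothetical, and it assumes a Superset checkout containing this patch, run from something like a Superset shell session so a Flask app context is available:

from superset.db_engine_specs.clickhouse import ClickHouseConnectEngineSpec

# Hypothetical values, shaped like the output of the "ClickHouse Connect" dynamic form.
parameters = {
    "username": "demo",
    "password": "demo_pw",      # hypothetical credential
    "host": "ch.example.com",   # hypothetical host
    "port": 8443,
    "database": "analytics",
    "encryption": True,         # translated into the secure=true query parameter
    "query": {},
}

# build_sqlalchemy_uri() appends secure=true when encryption is requested and
# falls back to the __default__ database when none is supplied.
uri = ClickHouseConnectEngineSpec.build_sqlalchemy_uri(parameters)
print(uri)
# Expected shape: clickhousedb+connect://demo:demo_pw@ch.example.com:8443/analytics?secure=true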

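The _mutate_label override above explains why column aliases issued through this connector carry a short hexadecimal suffix. A minimal standalone sketch of the same idea, assuming md5_sha_from_str is a plain MD5 hex digest as in superset.utils.hashing:

import hashlib


def mutate_label(label: str) -> str:
    """Append the first six hex characters of the label's MD5 digest."""
    digest = hashlib.md5(label.encode("utf-8")).hexdigest()
    return f"{label}_{digest[:6]}"


# A label such as "count" comes back as "count_" plus six hex characters,
# keeping generated aliases distinct from original column names.
print(mutate_label("count"))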