Skip to content

Commit

Permalink
refactor(mysql): port to sqlglot (#7926)
Browse files Browse the repository at this point in the history
Port the MySQL backend to sqlglot.
  • Loading branch information
cpcloud authored and kszucs committed Feb 12, 2024
1 parent 6916c1d commit cba2f98
Show file tree
Hide file tree
Showing 37 changed files with 1,080 additions and 1,010 deletions.
61 changes: 27 additions & 34 deletions .github/workflows/ibis-backends.yml
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,15 @@ jobs:
extras:
- polars
- deltalake
# - name: mysql
# title: MySQL
# services:
# - mysql
# extras:
# - mysql
# - geospatial
# sys-deps:
# - libgeos-dev
- name: mysql
title: MySQL
services:
- mysql
extras:
- mysql
- geospatial
sys-deps:
- libgeos-dev
- name: postgres
title: PostgreSQL
extras:
Expand Down Expand Up @@ -188,17 +188,17 @@ jobs:
# extras:
# - risingwave
exclude:
# - os: windows-latest
# backend:
# name: mysql
# title: MySQL
# extras:
# - mysql
# - geospatial
# services:
# - mysql
# sys-deps:
# - libgeos-dev
- os: windows-latest
backend:
name: mysql
title: MySQL
extras:
- mysql
- geospatial
services:
- mysql
sys-deps:
- libgeos-dev
- os: windows-latest
backend:
name: clickhouse
Expand Down Expand Up @@ -317,13 +317,13 @@ jobs:
# extras:
# - risingwave
steps:
# - name: update and install system dependencies
# if: matrix.os == 'ubuntu-latest' && matrix.backend.sys-deps != null
# run: |
# set -euo pipefail
#
# sudo apt-get update -qq -y
# sudo apt-get install -qq -y build-essential ${{ join(matrix.backend.sys-deps, ' ') }}
- name: update and install system dependencies
if: matrix.os == 'ubuntu-latest' && matrix.backend.sys-deps != null
run: |
set -euo pipefail
sudo apt-get update -qq -y
sudo apt-get install -qq -y build-essential ${{ join(matrix.backend.sys-deps, ' ') }}
- name: install sqlite
if: matrix.os == 'windows-latest' && matrix.backend.name == 'sqlite'
Expand Down Expand Up @@ -669,13 +669,6 @@ jobs:
# - freetds-dev
# - unixodbc-dev
# - tdsodbc
# - name: mysql
# title: MySQL
# services:
# - mysql
# extras:
# - geospatial
# - mysql
# - name: sqlite
# title: SQLite
# extras:
Expand Down
2 changes: 2 additions & 0 deletions docker/mysql/startup.sql
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
-- Startup script for the MySQL service (docker/mysql/startup.sql).
-- Create an `ibis` account, restricted to connections from localhost, with password `ibis`.
CREATE USER 'ibis'@'localhost' IDENTIFIED BY 'ibis';
-- Create the `test_schema` schema; IF NOT EXISTS makes this a no-op when it already exists.
CREATE SCHEMA IF NOT EXISTS test_schema;
-- Let `ibis` (connecting from any host) create and drop objects in every schema.
-- NOTE(review): the 'ibis'@'%' account is not created in this script -- presumably it is
-- created elsewhere (e.g. by the container configuration); confirm.
GRANT CREATE, DROP ON *.* TO 'ibis'@'%';
-- Grant CREATE, SELECT and DROP on everything inside `test_schema` to the same account.
GRANT CREATE,SELECT,DROP ON `test_schema`.* TO 'ibis'@'%';
-- Reload the grant tables.
FLUSH PRIVILEGES;
25 changes: 0 additions & 25 deletions ibis/backends/base/sql/alchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from ibis import util
from ibis.backends.base import CanCreateSchema
from ibis.backends.base.sql import BaseSQLBackend
from ibis.backends.base.sql.alchemy.geospatial import geospatial_supported
from ibis.backends.base.sql.alchemy.query_builder import AlchemyCompiler
from ibis.backends.base.sql.alchemy.registry import (
fixed_arity,
Expand Down Expand Up @@ -204,28 +203,6 @@ def _safe_raw_sql(self, *args, **kwargs):
with self.begin() as con:
yield con.execute(*args, **kwargs)

# TODO(kszucs): move to ibis.formats.pandas
@staticmethod
def _to_geodataframe(df, schema):
"""Convert `df` to a `GeoDataFrame`.
Requires the geospatial libraries (`geopandas` and `geoalchemy2`) to be
installed. The conversion is only applied when `schema` contains a
geospatial column.
"""
import geopandas as gpd
from geoalchemy2 import shape

geom_col = None
for name, dtype in schema.items():
if dtype.is_geospatial():
if not geom_col:
geom_col = name
df[name] = df[name].map(shape.to_shape, na_action="ignore")
if geom_col:
df[geom_col] = gpd.array.GeometryArray(df[geom_col].values)
df = gpd.GeoDataFrame(df, geometry=geom_col)
return df

def fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame:
import pandas as pd

Expand All @@ -241,8 +218,6 @@ def fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame:
cursor.close()
raise
df = PandasData.convert_table(df, schema)
if not df.empty and geospatial_supported:
return self._to_geodataframe(df, schema)
return df

@contextlib.contextmanager
Expand Down
145 changes: 0 additions & 145 deletions ibis/backends/base/sql/alchemy/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import sqlalchemy as sa
import sqlalchemy.types as sat
import toolz
from sqlalchemy.ext.compiler import compiles

import ibis.expr.datatypes as dt
from ibis.backends.base.sql.alchemy.geospatial import geospatial_supported
from ibis.backends.base.sqlglot.datatypes import SqlglotType
from ibis.common.collections import FrozenDict
from ibis.formats import TypeMapper

if TYPE_CHECKING:
from collections.abc import Mapping

if geospatial_supported:
import geoalchemy2 as ga


class ArrayType(sat.UserDefinedType):
    """SQLAlchemy user-defined type describing an array of ``value_type``."""

    def __init__(self, value_type: sat.TypeEngine):
        # Normalize through ``sat.to_instance`` before storing the element type.
        self.value_type = sat.to_instance(value_type)

    # NOTE(review): the ``-> None`` annotation is inaccurate -- a callable is
    # returned whenever ``coltype`` starts with "array".
    def result_processor(self, dialect, coltype) -> None:
        """Return a converter for fetched array values, or ``None``.

        ``None`` is returned when the lower-cased ``coltype`` does not start
        with ``"array"``. Otherwise the returned callable passes ``None``
        through unchanged and applies the element type's own result
        processor (or the identity function when there is none) to every
        item, producing a list.
        """
        type_name = coltype.lower()
        if type_name.startswith("array"):
            # Strip the leading "array(" and the trailing ")" to obtain the
            # element type string handed to the element processor.
            element_coltype = coltype[len("array(") : -1]
            convert = self.value_type.result_processor(dialect, element_coltype)
            if not convert:
                convert = toolz.identity

            def process(value):
                if value is None:
                    return value
                return [convert(item) for item in value]

            return process
        return None


@compiles(ArrayType, "default")
def compiles_array(element, compiler, **kw):
    """Render ``element`` as ``ARRAY(<compiled value type>)``."""
    value_sql = compiler.process(element.value_type, **kw)
    return "ARRAY(" + value_sql + ")"


@compiles(sat.FLOAT, "duckdb")
def compiles_float(element, compiler, **kw):
    """Render ``sat.FLOAT`` for the ``duckdb`` dialect based on its precision.

    Returns ``"FLOAT"`` when the precision is ``None`` or within 1..24 and
    ``"DOUBLE"`` when it is above 24 and at most 53.

    Raises
    ------
    ValueError
        If the precision falls outside both ranges.
    """
    precision = element.precision
    if precision is not None:
        if 24 < precision <= 53:
            return "DOUBLE"
        if not 1 <= precision <= 24:
            raise ValueError(
                "FLOAT precision must be between 1 and 53 inclusive, or `None`"
            )
    return "FLOAT"


class StructType(sat.UserDefinedType):
    """SQLAlchemy user-defined type holding named, typed fields."""

    cache_ok = True

    def __init__(self, fields: Mapping[str, sat.TypeEngine]) -> None:
        # Normalize every field type through ``sat.to_instance`` and freeze
        # the resulting name -> type mapping.
        normalized = {}
        for field_name, field_type in fields.items():
            normalized[field_name] = sat.to_instance(field_type)
        self.fields = FrozenDict(normalized)


@compiles(StructType, "default")
def compiles_struct(element, compiler, **kw):
    """Render ``element`` as ``STRUCT(<quoted name> <compiled type>, ...)``."""
    quote = compiler.dialect.identifier_preparer.quote
    rendered_fields = []
    for field, typ in element.fields.items():
        # Each entry is the quoted field name followed by its compiled type.
        rendered_fields.append(f"{quote(field)} {compiler.process(typ, **kw)}")
    content = ", ".join(rendered_fields)
    return "STRUCT(" + content + ")"


class MapType(sat.UserDefinedType):
    """SQLAlchemy user-defined type describing a key/value mapping."""

    def __init__(self, key_type: sat.TypeEngine, value_type: sat.TypeEngine):
        # Both component types are normalized through ``sat.to_instance``.
        to_instance = sat.to_instance
        self.key_type = to_instance(key_type)
        self.value_type = to_instance(value_type)


@compiles(MapType, "default")
def compiles_map(element, compiler, **kw):
    """Render ``element`` as ``MAP(<compiled key type>, <compiled value type>)``."""
    # The key type is compiled first, then the value type.
    return (
        f"MAP({compiler.process(element.key_type, **kw)}, "
        f"{compiler.process(element.value_type, **kw)})"
    )


# Marker subclass of ``sat.Integer`` that adds no behavior of its own.
# NOTE(review): presumably stands for an unsigned 64-bit integer -- confirm against usage.
class UInt64(sat.Integer):
    pass
Expand All @@ -102,30 +25,14 @@ class UInt8(sat.Integer):
pass


@compiles(UInt64, "postgresql")
@compiles(UInt32, "postgresql")
@compiles(UInt16, "postgresql")
@compiles(UInt8, "postgresql")
@compiles(UInt64, "mssql")
@compiles(UInt32, "mssql")
@compiles(UInt16, "mssql")
@compiles(UInt8, "mssql")
@compiles(UInt64, "mysql")
@compiles(UInt32, "mysql")
@compiles(UInt16, "mysql")
@compiles(UInt8, "mysql")
@compiles(UInt64, "snowflake")
@compiles(UInt32, "snowflake")
@compiles(UInt16, "snowflake")
@compiles(UInt8, "snowflake")
@compiles(UInt64, "sqlite")
@compiles(UInt32, "sqlite")
@compiles(UInt16, "sqlite")
@compiles(UInt8, "sqlite")
@compiles(UInt64, "trino")
@compiles(UInt32, "trino")
@compiles(UInt16, "trino")
@compiles(UInt8, "trino")
def compile_uint(element, compiler, **kw):
dialect_name = compiler.dialect.name
raise TypeError(
Expand Down Expand Up @@ -220,17 +127,6 @@ class Unknown(sa.Text):
53: dt.Float64,
}

_GEOSPATIAL_TYPES = {
"POINT": dt.Point,
"LINESTRING": dt.LineString,
"POLYGON": dt.Polygon,
"MULTILINESTRING": dt.MultiLineString,
"MULTIPOINT": dt.MultiPoint,
"MULTIPOLYGON": dt.MultiPolygon,
"GEOMETRY": dt.Geometry,
"GEOGRAPHY": dt.Geography,
}


class AlchemyType(TypeMapper):
@classmethod
Expand Down Expand Up @@ -261,25 +157,6 @@ def from_ibis(cls, dtype: dt.DataType) -> sat.TypeEngine:
return sat.NUMERIC(dtype.precision, dtype.scale)
elif dtype.is_timestamp():
return sat.TIMESTAMP(timezone=bool(dtype.timezone))
elif dtype.is_array():
return ArrayType(cls.from_ibis(dtype.value_type))
elif dtype.is_struct():
fields = {k: cls.from_ibis(v) for k, v in dtype.fields.items()}
return StructType(fields)
elif dtype.is_map():
return MapType(
cls.from_ibis(dtype.key_type), cls.from_ibis(dtype.value_type)
)
elif dtype.is_geospatial():
if geospatial_supported:
if dtype.geotype == "geometry":
return ga.Geometry
elif dtype.geotype == "geography":
return ga.Geography
else:
return ga.types._GISType
else:
raise TypeError("geospatial types are not supported")
else:
return _to_sqlalchemy_types[type(dtype)]

Expand All @@ -306,32 +183,10 @@ def to_ibis(cls, typ: sat.TypeEngine, nullable: bool = True) -> dt.DataType:
return dt.Decimal(typ.precision, typ.scale, nullable=nullable)
elif isinstance(typ, sat.Numeric):
return dt.Decimal(typ.precision, typ.scale, nullable=nullable)
elif isinstance(typ, ArrayType):
return dt.Array(cls.to_ibis(typ.value_type), nullable=nullable)
elif isinstance(typ, sat.ARRAY):
ndim = typ.dimensions
if ndim is not None and ndim != 1:
raise NotImplementedError("Nested array types not yet supported")
return dt.Array(cls.to_ibis(typ.item_type), nullable=nullable)
elif isinstance(typ, StructType):
fields = {k: cls.to_ibis(v) for k, v in typ.fields.items()}
return dt.Struct(fields, nullable=nullable)
elif isinstance(typ, MapType):
return dt.Map(
cls.to_ibis(typ.key_type),
cls.to_ibis(typ.value_type),
nullable=nullable,
)
elif isinstance(typ, sa.DateTime):
timezone = "UTC" if typ.timezone else None
return dt.Timestamp(timezone, nullable=nullable)
elif isinstance(typ, sat.String):
return dt.String(nullable=nullable)
elif geospatial_supported and isinstance(typ, ga.types._GISType):
name = typ.geometry_type.upper()
try:
return _GEOSPATIAL_TYPES[name](geotype=typ.name, nullable=nullable)
except KeyError:
raise ValueError(f"Unrecognized geometry type: {name}")
else:
raise TypeError(f"Unable to convert type: {typ!r}")
10 changes: 0 additions & 10 deletions ibis/backends/base/sql/alchemy/geospatial.py

This file was deleted.

10 changes: 10 additions & 0 deletions ibis/backends/base/sqlglot/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@
typecode.TEXT: dt.String,
typecode.TIME: dt.Time,
typecode.TIMETZ: dt.Time,
typecode.TINYBLOB: dt.Binary,
typecode.TINYINT: dt.Int8,
typecode.TINYTEXT: dt.String,
typecode.UBIGINT: dt.UInt64,
typecode.UINT: dt.UInt32,
typecode.USMALLINT: dt.UInt16,
Expand Down Expand Up @@ -400,6 +402,10 @@ class DataFusionType(PostgresType):

class MySQLType(SqlglotType):
dialect = "mysql"
# these are mysql's defaults, see
# https://dev.mysql.com/doc/refman/8.0/en/fixed-point-types.html
default_decimal_precision = 10
default_decimal_scale = 0

unknown_type_strings = FrozenDict(
{
Expand Down Expand Up @@ -428,6 +434,10 @@ def _from_sqlglot_DATETIME(cls) -> dt.Timestamp:
def _from_sqlglot_TIMESTAMP(cls) -> dt.Timestamp:
return dt.Timestamp(timezone="UTC", nullable=cls.default_nullable)

@classmethod
def _from_ibis_String(cls, dtype: dt.String) -> sge.DataType:
    """Return a sqlglot ``TEXT`` data type for an ibis string type.

    ``dtype`` is not inspected; every string type yields the same result.
    """
    return sge.DataType(this=typecode.TEXT)


class DuckDBType(SqlglotType):
dialect = "duckdb"
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/base/sqlglot/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def rewrite_empty_order_by_window(_, y):

@replace(p.WindowFunction(p.RowNumber | p.NTile, y))
def exclude_unsupported_window_frame_from_row_number(_, y):
return ops.Subtract(_.copy(frame=y.copy(start=None, end=None)), 1)
return ops.Subtract(_.copy(frame=y.copy(start=None, end=0)), 1)


@replace(
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,6 @@ def ddl_con(ddl_backend):
keep=(
"exasol",
"mssql",
"mysql",
"oracle",
"risingwave",
"sqlite",
Expand Down
Loading

0 comments on commit cba2f98

Please sign in to comment.