From cba2f9806680609a1668a0f756262a280773d2b0 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 6 Jan 2024 07:34:22 -0500 Subject: [PATCH] refactor(mysql): port to sqlglot (#7926) Port the MySQL backend to sqlglot. --- .github/workflows/ibis-backends.yml | 61 +-- docker/mysql/startup.sql | 2 + ibis/backends/base/sql/alchemy/__init__.py | 25 - ibis/backends/base/sql/alchemy/datatypes.py | 145 ----- ibis/backends/base/sql/alchemy/geospatial.py | 10 - ibis/backends/base/sqlglot/datatypes.py | 10 + ibis/backends/base/sqlglot/rewrites.py | 2 +- ibis/backends/conftest.py | 1 - ibis/backends/mssql/tests/test_client.py | 9 - ibis/backends/mysql/__init__.py | 500 +++++++++++++++--- ibis/backends/mysql/compiler.py | 430 ++++++++++++++- ibis/backends/mysql/converter.py | 26 + ibis/backends/mysql/datatypes.py | 101 ---- ibis/backends/mysql/registry.py | 265 ---------- ibis/backends/mysql/tests/conftest.py | 60 +-- ibis/backends/mysql/tests/test_client.py | 73 +-- ibis/backends/tests/errors.py | 7 + .../test_default_limit/mysql/out.sql | 5 + .../test_disable_query_limit/mysql/out.sql | 5 + .../mysql/out.sql | 19 + .../test_respect_set_limit/mysql/out.sql | 10 + .../test_group_by_has_index/mysql/out.sql | 8 +- .../test_sql/test_isin_bug/mysql/out.sql | 18 +- ibis/backends/tests/test_aggregation.py | 10 +- ibis/backends/tests/test_array.py | 27 +- ibis/backends/tests/test_asof_join.py | 16 +- ibis/backends/tests/test_client.py | 14 +- ibis/backends/tests/test_export.py | 3 +- ibis/backends/tests/test_generic.py | 16 +- ibis/backends/tests/test_join.py | 15 +- ibis/backends/tests/test_numeric.py | 67 +-- ibis/backends/tests/test_sql.py | 2 +- ibis/backends/tests/test_string.py | 14 +- ibis/backends/tests/test_temporal.py | 93 +--- ibis/backends/tests/test_window.py | 10 +- ibis/formats/pandas.py | 5 +- pyproject.toml | 6 +- 37 files changed, 1080 insertions(+), 1010 deletions(-) delete mode 100644 ibis/backends/base/sql/alchemy/geospatial.py create mode 100644 ibis/backends/mysql/converter.py delete mode 100644 ibis/backends/mysql/datatypes.py delete mode 100644 ibis/backends/mysql/registry.py create mode 100644 ibis/backends/tests/snapshots/test_interactive/test_default_limit/mysql/out.sql create mode 100644 ibis/backends/tests/snapshots/test_interactive/test_disable_query_limit/mysql/out.sql create mode 100644 ibis/backends/tests/snapshots/test_interactive/test_interactive_execute_on_repr/mysql/out.sql create mode 100644 ibis/backends/tests/snapshots/test_interactive/test_respect_set_limit/mysql/out.sql diff --git a/.github/workflows/ibis-backends.yml b/.github/workflows/ibis-backends.yml index 7b8d35caf3ac..d1824231d6e1 100644 --- a/.github/workflows/ibis-backends.yml +++ b/.github/workflows/ibis-backends.yml @@ -96,15 +96,15 @@ jobs: extras: - polars - deltalake - # - name: mysql - # title: MySQL - # services: - # - mysql - # extras: - # - mysql - # - geospatial - # sys-deps: - # - libgeos-dev + - name: mysql + title: MySQL + services: + - mysql + extras: + - mysql + - geospatial + sys-deps: + - libgeos-dev - name: postgres title: PostgreSQL extras: @@ -188,17 +188,17 @@ jobs: # extras: # - risingwave exclude: - # - os: windows-latest - # backend: - # name: mysql - # title: MySQL - # extras: - # - mysql - # - geospatial - # services: - # - mysql - # sys-deps: - # - libgeos-dev + - os: windows-latest + backend: + name: mysql + title: MySQL + extras: + - mysql + - geospatial + services: + - mysql + sys-deps: + - libgeos-dev - os: windows-latest backend: name: clickhouse @@ -317,13 +317,13 @@ jobs: # extras: # - risingwave steps: - # - name: update and install system dependencies - # if: matrix.os == 'ubuntu-latest' && matrix.backend.sys-deps != null - # run: | - # set -euo pipefail - # - # sudo apt-get update -qq -y - # sudo apt-get install -qq -y build-essential ${{ join(matrix.backend.sys-deps, ' ') }} + - name: update and install system dependencies + if: matrix.os == 'ubuntu-latest' && matrix.backend.sys-deps != null + run: | + set -euo pipefail + + sudo apt-get update -qq -y + sudo apt-get install -qq -y build-essential ${{ join(matrix.backend.sys-deps, ' ') }} - name: install sqlite if: matrix.os == 'windows-latest' && matrix.backend.name == 'sqlite' @@ -669,13 +669,6 @@ jobs: # - freetds-dev # - unixodbc-dev # - tdsodbc - # - name: mysql - # title: MySQL - # services: - # - mysql - # extras: - # - geospatial - # - mysql # - name: sqlite # title: SQLite # extras: diff --git a/docker/mysql/startup.sql b/docker/mysql/startup.sql index 06d40f979281..29982e2f7d89 100644 --- a/docker/mysql/startup.sql +++ b/docker/mysql/startup.sql @@ -1,3 +1,5 @@ CREATE USER 'ibis'@'localhost' IDENTIFIED BY 'ibis'; +CREATE SCHEMA IF NOT EXISTS test_schema; GRANT CREATE, DROP ON *.* TO 'ibis'@'%'; +GRANT CREATE,SELECT,DROP ON `test_schema`.* TO 'ibis'@'%'; FLUSH PRIVILEGES; diff --git a/ibis/backends/base/sql/alchemy/__init__.py b/ibis/backends/base/sql/alchemy/__init__.py index 6088e75447b8..ec64b484061a 100644 --- a/ibis/backends/base/sql/alchemy/__init__.py +++ b/ibis/backends/base/sql/alchemy/__init__.py @@ -22,7 +22,6 @@ from ibis import util from ibis.backends.base import CanCreateSchema from ibis.backends.base.sql import BaseSQLBackend -from ibis.backends.base.sql.alchemy.geospatial import geospatial_supported from ibis.backends.base.sql.alchemy.query_builder import AlchemyCompiler from ibis.backends.base.sql.alchemy.registry import ( fixed_arity, @@ -204,28 +203,6 @@ def _safe_raw_sql(self, *args, **kwargs): with self.begin() as con: yield con.execute(*args, **kwargs) - # TODO(kszucs): move to ibis.formats.pandas - @staticmethod - def _to_geodataframe(df, schema): - """Convert `df` to a `GeoDataFrame`. - - Required libraries for geospatial support must be installed and - a geospatial column is present in the dataframe. - """ - import geopandas as gpd - from geoalchemy2 import shape - - geom_col = None - for name, dtype in schema.items(): - if dtype.is_geospatial(): - if not geom_col: - geom_col = name - df[name] = df[name].map(shape.to_shape, na_action="ignore") - if geom_col: - df[geom_col] = gpd.array.GeometryArray(df[geom_col].values) - df = gpd.GeoDataFrame(df, geometry=geom_col) - return df - def fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame: import pandas as pd @@ -241,8 +218,6 @@ def fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame: cursor.close() raise df = PandasData.convert_table(df, schema) - if not df.empty and geospatial_supported: - return self._to_geodataframe(df, schema) return df @contextlib.contextmanager diff --git a/ibis/backends/base/sql/alchemy/datatypes.py b/ibis/backends/base/sql/alchemy/datatypes.py index 1608faee29aa..d78739264dc3 100644 --- a/ibis/backends/base/sql/alchemy/datatypes.py +++ b/ibis/backends/base/sql/alchemy/datatypes.py @@ -1,90 +1,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING - import sqlalchemy as sa import sqlalchemy.types as sat -import toolz from sqlalchemy.ext.compiler import compiles import ibis.expr.datatypes as dt -from ibis.backends.base.sql.alchemy.geospatial import geospatial_supported from ibis.backends.base.sqlglot.datatypes import SqlglotType -from ibis.common.collections import FrozenDict from ibis.formats import TypeMapper -if TYPE_CHECKING: - from collections.abc import Mapping - -if geospatial_supported: - import geoalchemy2 as ga - - -class ArrayType(sat.UserDefinedType): - def __init__(self, value_type: sat.TypeEngine): - self.value_type = sat.to_instance(value_type) - - def result_processor(self, dialect, coltype) -> None: - if not coltype.lower().startswith("array"): - return None - - inner_processor = ( - self.value_type.result_processor(dialect, coltype[len("array(") : -1]) - or toolz.identity - ) - - return lambda v: v if v is None else list(map(inner_processor, v)) - - -@compiles(ArrayType, "default") -def compiles_array(element, compiler, **kw): - return f"ARRAY({compiler.process(element.value_type, **kw)})" - - -@compiles(sat.FLOAT, "duckdb") -def compiles_float(element, compiler, **kw): - precision = element.precision - if precision is None or 1 <= precision <= 24: - return "FLOAT" - elif 24 < precision <= 53: - return "DOUBLE" - else: - raise ValueError( - "FLOAT precision must be between 1 and 53 inclusive, or `None`" - ) - - -class StructType(sat.UserDefinedType): - cache_ok = True - - def __init__(self, fields: Mapping[str, sat.TypeEngine]) -> None: - self.fields = FrozenDict( - {name: sat.to_instance(typ) for name, typ in fields.items()} - ) - - -@compiles(StructType, "default") -def compiles_struct(element, compiler, **kw): - quote = compiler.dialect.identifier_preparer.quote - content = ", ".join( - f"{quote(field)} {compiler.process(typ, **kw)}" - for field, typ in element.fields.items() - ) - return f"STRUCT({content})" - - -class MapType(sat.UserDefinedType): - def __init__(self, key_type: sat.TypeEngine, value_type: sat.TypeEngine): - self.key_type = sat.to_instance(key_type) - self.value_type = sat.to_instance(value_type) - - -@compiles(MapType, "default") -def compiles_map(element, compiler, **kw): - key_type = compiler.process(element.key_type, **kw) - value_type = compiler.process(element.value_type, **kw) - return f"MAP({key_type}, {value_type})" - class UInt64(sat.Integer): pass @@ -102,30 +25,14 @@ class UInt8(sat.Integer): pass -@compiles(UInt64, "postgresql") -@compiles(UInt32, "postgresql") -@compiles(UInt16, "postgresql") -@compiles(UInt8, "postgresql") @compiles(UInt64, "mssql") @compiles(UInt32, "mssql") @compiles(UInt16, "mssql") @compiles(UInt8, "mssql") -@compiles(UInt64, "mysql") -@compiles(UInt32, "mysql") -@compiles(UInt16, "mysql") -@compiles(UInt8, "mysql") -@compiles(UInt64, "snowflake") -@compiles(UInt32, "snowflake") -@compiles(UInt16, "snowflake") -@compiles(UInt8, "snowflake") @compiles(UInt64, "sqlite") @compiles(UInt32, "sqlite") @compiles(UInt16, "sqlite") @compiles(UInt8, "sqlite") -@compiles(UInt64, "trino") -@compiles(UInt32, "trino") -@compiles(UInt16, "trino") -@compiles(UInt8, "trino") def compile_uint(element, compiler, **kw): dialect_name = compiler.dialect.name raise TypeError( @@ -220,17 +127,6 @@ class Unknown(sa.Text): 53: dt.Float64, } -_GEOSPATIAL_TYPES = { - "POINT": dt.Point, - "LINESTRING": dt.LineString, - "POLYGON": dt.Polygon, - "MULTILINESTRING": dt.MultiLineString, - "MULTIPOINT": dt.MultiPoint, - "MULTIPOLYGON": dt.MultiPolygon, - "GEOMETRY": dt.Geometry, - "GEOGRAPHY": dt.Geography, -} - class AlchemyType(TypeMapper): @classmethod @@ -261,25 +157,6 @@ def from_ibis(cls, dtype: dt.DataType) -> sat.TypeEngine: return sat.NUMERIC(dtype.precision, dtype.scale) elif dtype.is_timestamp(): return sat.TIMESTAMP(timezone=bool(dtype.timezone)) - elif dtype.is_array(): - return ArrayType(cls.from_ibis(dtype.value_type)) - elif dtype.is_struct(): - fields = {k: cls.from_ibis(v) for k, v in dtype.fields.items()} - return StructType(fields) - elif dtype.is_map(): - return MapType( - cls.from_ibis(dtype.key_type), cls.from_ibis(dtype.value_type) - ) - elif dtype.is_geospatial(): - if geospatial_supported: - if dtype.geotype == "geometry": - return ga.Geometry - elif dtype.geotype == "geography": - return ga.Geography - else: - return ga.types._GISType - else: - raise TypeError("geospatial types are not supported") else: return _to_sqlalchemy_types[type(dtype)] @@ -306,32 +183,10 @@ def to_ibis(cls, typ: sat.TypeEngine, nullable: bool = True) -> dt.DataType: return dt.Decimal(typ.precision, typ.scale, nullable=nullable) elif isinstance(typ, sat.Numeric): return dt.Decimal(typ.precision, typ.scale, nullable=nullable) - elif isinstance(typ, ArrayType): - return dt.Array(cls.to_ibis(typ.value_type), nullable=nullable) - elif isinstance(typ, sat.ARRAY): - ndim = typ.dimensions - if ndim is not None and ndim != 1: - raise NotImplementedError("Nested array types not yet supported") - return dt.Array(cls.to_ibis(typ.item_type), nullable=nullable) - elif isinstance(typ, StructType): - fields = {k: cls.to_ibis(v) for k, v in typ.fields.items()} - return dt.Struct(fields, nullable=nullable) - elif isinstance(typ, MapType): - return dt.Map( - cls.to_ibis(typ.key_type), - cls.to_ibis(typ.value_type), - nullable=nullable, - ) elif isinstance(typ, sa.DateTime): timezone = "UTC" if typ.timezone else None return dt.Timestamp(timezone, nullable=nullable) elif isinstance(typ, sat.String): return dt.String(nullable=nullable) - elif geospatial_supported and isinstance(typ, ga.types._GISType): - name = typ.geometry_type.upper() - try: - return _GEOSPATIAL_TYPES[name](geotype=typ.name, nullable=nullable) - except KeyError: - raise ValueError(f"Unrecognized geometry type: {name}") else: raise TypeError(f"Unable to convert type: {typ!r}") diff --git a/ibis/backends/base/sql/alchemy/geospatial.py b/ibis/backends/base/sql/alchemy/geospatial.py deleted file mode 100644 index 41b86ca00e1b..000000000000 --- a/ibis/backends/base/sql/alchemy/geospatial.py +++ /dev/null @@ -1,10 +0,0 @@ -from __future__ import annotations - -from importlib.util import find_spec as _find_spec - -geospatial_supported = ( - _find_spec("geoalchemy2") is not None - and _find_spec("geopandas") is not None - and _find_spec("shapely") is not None -) -__all__ = ["geospatial_supported"] diff --git a/ibis/backends/base/sqlglot/datatypes.py b/ibis/backends/base/sqlglot/datatypes.py index 21f2242251c0..8b379d1d0db9 100644 --- a/ibis/backends/base/sqlglot/datatypes.py +++ b/ibis/backends/base/sqlglot/datatypes.py @@ -50,7 +50,9 @@ typecode.TEXT: dt.String, typecode.TIME: dt.Time, typecode.TIMETZ: dt.Time, + typecode.TINYBLOB: dt.Binary, typecode.TINYINT: dt.Int8, + typecode.TINYTEXT: dt.String, typecode.UBIGINT: dt.UInt64, typecode.UINT: dt.UInt32, typecode.USMALLINT: dt.UInt16, @@ -400,6 +402,10 @@ class DataFusionType(PostgresType): class MySQLType(SqlglotType): dialect = "mysql" + # these are mysql's defaults, see + # https://dev.mysql.com/doc/refman/8.0/en/fixed-point-types.html + default_decimal_precision = 10 + default_decimal_scale = 0 unknown_type_strings = FrozenDict( { @@ -428,6 +434,10 @@ def _from_sqlglot_DATETIME(cls) -> dt.Timestamp: def _from_sqlglot_TIMESTAMP(cls) -> dt.Timestamp: return dt.Timestamp(timezone="UTC", nullable=cls.default_nullable) + @classmethod + def _from_ibis_String(cls, dtype: dt.String) -> sge.DataType: + return sge.DataType(this=typecode.TEXT) + class DuckDBType(SqlglotType): dialect = "duckdb" diff --git a/ibis/backends/base/sqlglot/rewrites.py b/ibis/backends/base/sqlglot/rewrites.py index 522380d9111e..c6b02d23423c 100644 --- a/ibis/backends/base/sqlglot/rewrites.py +++ b/ibis/backends/base/sqlglot/rewrites.py @@ -176,7 +176,7 @@ def rewrite_empty_order_by_window(_, y): @replace(p.WindowFunction(p.RowNumber | p.NTile, y)) def exclude_unsupported_window_frame_from_row_number(_, y): - return ops.Subtract(_.copy(frame=y.copy(start=None, end=None)), 1) + return ops.Subtract(_.copy(frame=y.copy(start=None, end=0)), 1) @replace( diff --git a/ibis/backends/conftest.py b/ibis/backends/conftest.py index d9475b533ea5..afa2479a44a0 100644 --- a/ibis/backends/conftest.py +++ b/ibis/backends/conftest.py @@ -537,7 +537,6 @@ def ddl_con(ddl_backend): keep=( "exasol", "mssql", - "mysql", "oracle", "risingwave", "sqlite", diff --git a/ibis/backends/mssql/tests/test_client.py b/ibis/backends/mssql/tests/test_client.py index 12012ac929d6..b26c78a53c38 100644 --- a/ibis/backends/mssql/tests/test_client.py +++ b/ibis/backends/mssql/tests/test_client.py @@ -7,7 +7,6 @@ import ibis import ibis.expr.datatypes as dt from ibis import udf -from ibis.backends.base.sql.alchemy.geospatial import geospatial_supported DB_TYPES = [ # Exact numbers @@ -53,10 +52,6 @@ ] -skipif_no_geospatial_deps = pytest.mark.skipif( - not geospatial_supported, reason="geospatial dependencies not installed" -) - broken_sqlalchemy_autoload = pytest.mark.xfail( reason="scale not inferred by sqlalchemy autoload" ) @@ -65,10 +60,6 @@ @pytest.mark.parametrize( ("server_type", "expected_type"), DB_TYPES - + [ - param("GEOMETRY", dt.geometry, marks=[skipif_no_geospatial_deps]), - param("GEOGRAPHY", dt.geography, marks=[skipif_no_geospatial_deps]), - ] + [ param( "DATETIME2(4)", dt.timestamp(scale=4), marks=[broken_sqlalchemy_autoload] diff --git a/ibis/backends/mysql/__init__.py b/ibis/backends/mysql/__init__.py index 7ea409b73701..a52edbe5fba8 100644 --- a/ibis/backends/mysql/__init__.py +++ b/ibis/backends/mysql/__init__.py @@ -2,31 +2,101 @@ from __future__ import annotations +import atexit +import contextlib +import re import warnings -from typing import TYPE_CHECKING, Literal +from functools import cached_property, partial +from itertools import repeat +from operator import itemgetter +from typing import TYPE_CHECKING, Any +from urllib.parse import parse_qs, urlparse import pymysql -import sqlalchemy as sa -from sqlalchemy.dialects import mysql +import sqlglot as sg +import sqlglot.expressions as sge +import ibis +import ibis.common.exceptions as com +import ibis.expr.operations as ops import ibis.expr.schema as sch +import ibis.expr.types as ir from ibis import util from ibis.backends.base import CanCreateDatabase -from ibis.backends.base.sql.alchemy import BaseAlchemyBackend +from ibis.backends.base.sqlglot import SQLGlotBackend +from ibis.backends.base.sqlglot.compiler import TRUE, C from ibis.backends.mysql.compiler import MySQLCompiler -from ibis.backends.mysql.datatypes import MySQLDateTime, MySQLType if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Iterable, Mapping + + import pandas as pd + import pyarrow as pa import ibis.expr.datatypes as dt -class Backend(BaseAlchemyBackend, CanCreateDatabase): +class Backend(SQLGlotBackend, CanCreateDatabase): name = "mysql" - compiler = MySQLCompiler + compiler = MySQLCompiler() supports_create_or_replace = False + def _from_url(self, url: str, **kwargs): + """Connect to a backend using a URL `url`. + + Parameters + ---------- + url + URL with which to connect to a backend. + kwargs + Additional keyword arguments + + Returns + ------- + BaseBackend + A backend instance + """ + + url = urlparse(url) + database, *_ = url.path[1:].split("/", 1) + query_params = parse_qs(url.query) + connect_args = { + "user": url.username, + "password": url.password or "", + "host": url.hostname, + "database": database or "", + } + + for name, value in query_params.items(): + if len(value) > 1: + connect_args[name] = value + elif len(value) == 1: + connect_args[name] = value[0] + else: + raise com.IbisError(f"Invalid URL parameter: {name}") + + kwargs.update(connect_args) + self._convert_kwargs(kwargs) + + if "user" in kwargs and not kwargs["user"]: + del kwargs["user"] + + if "host" in kwargs and not kwargs["host"]: + del kwargs["host"] + + if "database" in kwargs and not kwargs["database"]: + del kwargs["database"] + + if "password" in kwargs and kwargs["password"] is None: + del kwargs["password"] + + return self.connect(**kwargs) + + @cached_property + def version(self): + matched = re.search(r"(\d+)\.(\d+)\.(\d+)", self.con.server_version) + return ".".join(matched.groups()) + def do_connect( self, host: str = "localhost", @@ -34,8 +104,7 @@ def do_connect( password: str | None = None, port: int = 3306, database: str | None = None, - url: str | None = None, - driver: Literal["pymysql"] = "pymysql", + autocommit: bool = True, **kwargs, ) -> None: """Create an Ibis client using the passed connection parameters. @@ -52,15 +121,10 @@ def do_connect( Port database Database to connect to - url - Complete SQLAlchemy connection string. If passed, the other - connection arguments are ignored. - driver - Python MySQL database driver + autocommit + Autocommit mode kwargs - Additional keyword arguments passed to `connect_args` in - `sqlalchemy.create_engine`. Use these to pass dialect specific - arguments. + Additional keyword arguments passed to `pymysql.connect` Examples -------- @@ -92,96 +156,362 @@ def do_connect( year : int32 month : int32 """ - if driver != "pymysql": - raise NotImplementedError("pymysql is currently the only supported driver") - alchemy_url = self._build_alchemy_url( - url=url, + con = pymysql.connect( + user=user, host=host, port=port, - user=user, password=password, database=database, - driver=f"mysql+{driver}", + autocommit=autocommit, + conv=pymysql.converters.conversions, + **kwargs, ) - engine = sa.create_engine( - alchemy_url, poolclass=sa.pool.StaticPool, connect_args=kwargs - ) - - @sa.event.listens_for(engine, "connect") - def connect(dbapi_connection, connection_record): - with dbapi_connection.cursor() as cur: - try: - cur.execute("SET @@session.time_zone = 'UTC'") - except (sa.exc.OperationalError, pymysql.err.OperationalError): - warnings.warn("Unable to set session timezone to UTC.") + with contextlib.closing(con.cursor()) as cur: + try: + cur.execute("SET @@session.time_zone = 'UTC'") + except Exception as e: # noqa: BLE001 + warnings.warn(f"Unable to set session timezone to UTC: {e}") - super().do_connect(engine) + self.con = con + self._temp_views = set() @property def current_database(self) -> str: - return self._scalar_query(sa.select(sa.func.database())) - - @staticmethod - def _new_sa_metadata(): - meta = sa.MetaData() - - @sa.event.listens_for(meta, "column_reflect") - def column_reflect(inspector, table, column_info): - if isinstance(column_info["type"], mysql.DATETIME): - column_info["type"] = MySQLDateTime() - if isinstance(column_info["type"], mysql.DOUBLE): - column_info["type"] = mysql.DOUBLE(asdecimal=False) - if isinstance(column_info["type"], mysql.FLOAT): - column_info["type"] = mysql.FLOAT(asdecimal=False) - - return meta + with self._safe_raw_sql(sg.select(self.compiler.f.database())) as cur: + [(database,)] = cur.fetchall() + return database def list_databases(self, like: str | None = None) -> list[str]: # In MySQL, "database" and "schema" are synonymous - databases = self.inspector.get_schema_names() + with self._safe_raw_sql("SHOW DATABASES") as cur: + databases = list(map(itemgetter(0), cur.fetchall())) return self._filter_with_like(databases, like) - def _metadata(self, table: str) -> Iterable[tuple[str, dt.DataType]]: - with self.begin() as con: - result = con.exec_driver_sql(f"DESCRIBE {table}").mappings().all() + def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]: + table = util.gen_name("mysql_metadata") - for field in result: - name = field["Field"] - type_string = field["Type"] - is_nullable = field["Null"] == "YES" - yield name, MySQLType.from_string(type_string, nullable=is_nullable) + with self.begin() as cur: + cur.execute(f"CREATE TEMPORARY TABLE {table} AS {query}") + try: + cur.execute(f"DESCRIBE {table}") + result = cur.fetchall() + finally: + cur.execute(f"DROP TABLE {table}") - def _get_schema_using_query(self, query: str): - table = f"__ibis_mysql_metadata_{util.guid()}" + type_mapper = self.compiler.type_mapper + return ( + (name, type_mapper.from_string(type_string, nullable=is_nullable == "YES")) + for name, type_string, is_nullable, *_ in result + ) - with self.begin() as con: - con.exec_driver_sql(f"CREATE TEMPORARY TABLE {table} AS {query}") - result = con.exec_driver_sql(f"DESCRIBE {table}").mappings().all() - con.exec_driver_sql(f"DROP TABLE {table}") + def get_schema( + self, name: str, schema: str | None = None, database: str | None = None + ) -> sch.Schema: + table = sg.table(name, db=schema, catalog=database, quoted=True).sql(self.name) - fields = {} - for field in result: - name = field["Field"] - type_string = field["Type"] - is_nullable = field["Null"] == "YES" - fields[name] = MySQLType.from_string(type_string, nullable=is_nullable) + with self.begin() as cur: + cur.execute(f"DESCRIBE {table}") + result = cur.fetchall() + + type_mapper = self.compiler.type_mapper + fields = { + name: type_mapper.from_string(type_string, nullable=is_nullable == "YES") + for name, type_string, is_nullable, *_ in result + } return sch.Schema(fields) - def _get_temp_view_definition( - self, name: str, definition: sa.sql.compiler.Compiled - ) -> str: - yield f"CREATE OR REPLACE VIEW {name} AS {definition}" + def _get_temp_view_definition(self, name: str, definition: str) -> str: + return sge.Create( + kind="VIEW", + replace=True, + this=sg.to_identifier(name, quoted=self.compiler.quoted), + expression=definition, + ) def create_database(self, name: str, force: bool = False) -> None: - name = self._quote(name) - if_exists = "IF NOT EXISTS " * force - with self.begin() as con: - con.exec_driver_sql(f"CREATE DATABASE {if_exists}{name}") + sql = sge.Create(kind="DATABASE", exist=force, this=sg.to_identifier(name)).sql( + self.name + ) + with self.begin() as cur: + cur.execute(sql) def drop_database(self, name: str, force: bool = False) -> None: - name = self._quote(name) - if_exists = "IF EXISTS " * force - with self.begin() as con: - con.exec_driver_sql(f"DROP DATABASE {if_exists}{name}") + sql = sge.Drop(kind="DATABASE", exist=force, this=sg.to_identifier(name)).sql( + self.name + ) + with self.begin() as cur: + cur.execute(sql) + + @contextlib.contextmanager + def begin(self): + con = self.con + cur = con.cursor() + try: + yield cur + except Exception: + con.rollback() + raise + else: + con.commit() + finally: + cur.close() + + @contextlib.contextmanager + def _safe_raw_sql(self, *args, **kwargs): + with contextlib.closing(self.raw_sql(*args, **kwargs)) as result: + yield result + + def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any: + with contextlib.suppress(AttributeError): + query = query.sql(dialect=self.name) + + con = self.con + cursor = con.cursor() + + try: + cursor.execute(query, **kwargs) + except Exception: + con.rollback() + cursor.close() + raise + else: + con.commit() + return cursor + + def list_tables( + self, like: str | None = None, schema: str | None = None + ) -> list[str]: + """List the tables in the database. + + Parameters + ---------- + like + A pattern to use for listing tables. + schema + The schema to perform the list against. + """ + conditions = [TRUE] + + if schema is not None: + conditions = C.table_schema.eq(sge.convert(schema)) + + col = "table_name" + sql = ( + sg.select(col) + .from_(sg.table("tables", db="information_schema")) + .distinct() + .where(*conditions) + .sql(self.name, pretty=True) + ) + + with self._safe_raw_sql(sql) as cur: + out = cur.fetchall() + + return self._filter_with_like(map(itemgetter(0), out), like) + + def execute( + self, expr: ir.Expr, limit: str | None = "default", **kwargs: Any + ) -> Any: + """Execute an expression.""" + + self._run_pre_execute_hooks(expr) + table = expr.as_table() + sql = self.compile(table, limit=limit, **kwargs) + + schema = table.schema() + + with self._safe_raw_sql(sql) as cur: + result = self._fetch_from_cursor(cur, schema) + return expr.__pandas_result__(result) + + def create_table( + self, + name: str, + obj: pd.DataFrame | pa.Table | ir.Table | None = None, + *, + schema: ibis.Schema | None = None, + database: str | None = None, + temp: bool = False, + overwrite: bool = False, + ) -> ir.Table: + if obj is None and schema is None: + raise ValueError("Either `obj` or `schema` must be specified") + + if database is not None and database != self.current_database: + raise com.UnsupportedOperationError( + "Creating tables in other databases is not supported by Postgres" + ) + else: + database = None + + properties = [] + + if temp: + properties.append(sge.TemporaryProperty()) + + if obj is not None: + if not isinstance(obj, ir.Expr): + table = ibis.memtable(obj) + else: + table = obj + + self._run_pre_execute_hooks(table) + + query = self._to_sqlglot(table) + else: + query = None + + column_defs = [ + sge.ColumnDef( + this=sg.to_identifier(colname, quoted=self.compiler.quoted), + kind=self.compiler.type_mapper.from_ibis(typ), + constraints=( + None + if typ.nullable + else [sge.ColumnConstraint(kind=sge.NotNullColumnConstraint())] + ), + ) + for colname, typ in (schema or table.schema()).items() + ] + + if overwrite: + temp_name = util.gen_name(f"{self.name}_table") + else: + temp_name = name + + table = sg.table(temp_name, catalog=database, quoted=self.compiler.quoted) + target = sge.Schema(this=table, expressions=column_defs) + + create_stmt = sge.Create( + kind="TABLE", + this=target, + properties=sge.Properties(expressions=properties), + ) + + this = sg.table(name, catalog=database, quoted=self.compiler.quoted) + with self._safe_raw_sql(create_stmt) as cur: + if query is not None: + insert_stmt = sge.Insert(this=table, expression=query).sql(self.name) + cur.execute(insert_stmt) + + if overwrite: + cur.execute( + sge.Drop(kind="TABLE", this=this, exists=True).sql(self.name) + ) + cur.execute( + f"ALTER TABLE IF EXISTS {table.sql(self.name)} RENAME TO {this.sql(self.name)}" + ) + + if schema is None: + return self.table(name, schema=database) + + # preserve the input schema if it was provided + return ops.DatabaseTable( + name, schema=schema, source=self, namespace=ops.Namespace(database=database) + ).to_expr() + + def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: + schema = op.schema + if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]: + raise com.IbisTypeError( + "MySQL cannot yet reliably handle `null` typed columns; " + f"got null typed columns: {null_columns}" + ) + + # only register if we haven't already done so + if (name := op.name) not in self.list_tables(): + quoted = self.compiler.quoted + column_defs = [ + sg.exp.ColumnDef( + this=sg.to_identifier(colname, quoted=quoted), + kind=self.compiler.type_mapper.from_ibis(typ), + constraints=( + None + if typ.nullable + else [ + sg.exp.ColumnConstraint( + kind=sg.exp.NotNullColumnConstraint() + ) + ] + ), + ) + for colname, typ in schema.items() + ] + + create_stmt = sg.exp.Create( + kind="TABLE", + this=sg.exp.Schema( + this=sg.to_identifier(name, quoted=quoted), expressions=column_defs + ), + properties=sg.exp.Properties(expressions=[sge.TemporaryProperty()]), + ) + create_stmt_sql = create_stmt.sql(self.name) + + columns = schema.keys() + df = op.data.to_frame() + data = df.itertuples(index=False) + cols = ", ".join( + ident.sql(self.name) + for ident in map(partial(sg.to_identifier, quoted=quoted), columns) + ) + specs = ", ".join(repeat("%s", len(columns))) + table = sg.table(name, quoted=quoted) + sql = f"INSERT INTO {table.sql(self.name)} ({cols}) VALUES ({specs})" + with self.begin() as cur: + cur.execute(create_stmt_sql) + + if not df.empty: + cur.executemany(sql, data) + + @util.experimental + def to_pyarrow_batches( + self, + expr: ir.Expr, + *, + params: Mapping[ir.Scalar, Any] | None = None, + limit: int | str | None = None, + chunk_size: int = 1_000_000, + **_: Any, + ) -> pa.ipc.RecordBatchReader: + import pyarrow as pa + + self._run_pre_execute_hooks(expr) + + schema = expr.as_table().schema() + with self._safe_raw_sql( + self.compile(expr, limit=limit, params=params) + ) as cursor: + df = self._fetch_from_cursor(cursor, schema) + table = pa.Table.from_pandas( + df, schema=schema.to_pyarrow(), preserve_index=False + ) + return table.to_reader(max_chunksize=chunk_size) + + def _fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame: + import pandas as pd + + from ibis.backends.mysql.converter import MySQLPandasData + + try: + df = pd.DataFrame.from_records( + cursor, columns=schema.names, coerce_float=True + ) + except Exception: + # clean up the cursor if we fail to create the DataFrame + # + # in the sqlite case failing to close the cursor results in + # artificially locked tables + cursor.close() + raise + df = MySQLPandasData.convert_table(df, schema) + return df + + def _register_temp_view_cleanup(self, name: str) -> None: + def drop(self, name: str, query: str): + self.raw_sql(query) + self._temp_views.discard(name) + + query = sge.Drop(this=sg.table(name), kind="VIEW", exists=True) + atexit.register(drop, self, name=name, query=query) diff --git a/ibis/backends/mysql/compiler.py b/ibis/backends/mysql/compiler.py index 529dfe84b211..d053c83c4300 100644 --- a/ibis/backends/mysql/compiler.py +++ b/ibis/backends/mysql/compiler.py @@ -1,28 +1,422 @@ from __future__ import annotations -import sqlalchemy as sa +import string +from functools import partial, reduce, singledispatchmethod -from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator -from ibis.backends.mysql.datatypes import MySQLType -from ibis.backends.mysql.registry import operation_registry -from ibis.expr.rewrites import rewrite_sample +import sqlglot as sg +import sqlglot.expressions as sge +from public import public +from sqlglot.dialects import MySQL +from sqlglot.dialects.dialect import rename_func +import ibis.common.exceptions as com +import ibis.expr.datatypes as dt +import ibis.expr.operations as ops +from ibis.backends.base.sqlglot.compiler import NULL, STAR, SQLGlotCompiler +from ibis.backends.base.sqlglot.datatypes import MySQLType +from ibis.backends.base.sqlglot.rewrites import ( + exclude_unsupported_window_frame_from_ops, + exclude_unsupported_window_frame_from_row_number, + rewrite_empty_order_by_window, + rewrite_first_to_first_value, + rewrite_last_to_last_value, +) +from ibis.common.patterns import replace +from ibis.expr.rewrites import p, rewrite_sample -class MySQLExprTranslator(AlchemyExprTranslator): - # https://dev.mysql.com/doc/refman/8.0/en/spatial-function-reference.html - _registry = operation_registry.copy() - _rewrites = AlchemyExprTranslator._rewrites.copy() - _integer_to_timestamp = sa.func.from_unixtime - native_json_type = False - _dialect_name = "mysql" +MySQL.Generator.TRANSFORMS |= { + sge.LogicalOr: rename_func("max"), + sge.LogicalAnd: rename_func("min"), + sge.VariancePop: rename_func("var_pop"), + sge.Variance: rename_func("var_samp"), + sge.Stddev: rename_func("stddev_pop"), + sge.StddevPop: rename_func("stddev_pop"), + sge.StddevSamp: rename_func("stddev_samp"), + sge.RegexpLike: ( + lambda _, e: f"({e.this.sql('mysql')} RLIKE {e.expression.sql('mysql')})" + ), +} + + +@replace(p.Limit) +def rewrite_limit(_, **kwargs): + """Rewrite limit for MySQL to include a large upper bound. + + From the MySQL docs @ https://dev.mysql.com/doc/refman/8.0/en/select.html + + > To retrieve all rows from a certain offset up to the end of the result + > set, you can use some large number for the second parameter. This statement + > retrieves all rows from the 96th row to the last: + > + > SELECT * FROM tbl LIMIT 95,18446744073709551615; + """ + if _.n is None and _.offset is not None: + some_large_number = (1 << 64) - 1 + return _.copy(n=some_large_number) + return _ + + +@public +class MySQLCompiler(SQLGlotCompiler): + __slots__ = () + + dialect = "mysql" type_mapper = MySQLType + rewrites = ( + rewrite_limit, + rewrite_sample, + rewrite_first_to_first_value, + rewrite_last_to_last_value, + exclude_unsupported_window_frame_from_ops, + exclude_unsupported_window_frame_from_row_number, + rewrite_empty_order_by_window, + *SQLGlotCompiler.rewrites, + ) + quoted = True + + @property + def NAN(self): + raise NotImplementedError("MySQL does not support NaN") + + @property + def POS_INF(self): + raise NotImplementedError("MySQL does not support Infinity") + + NEG_INF = POS_INF + + def _aggregate(self, funcname: str, *args, where): + func = self.f[funcname] + if where is not None: + args = tuple(self.if_(where, arg, NULL) for arg in args) + return func(*args) + + @singledispatchmethod + def visit_node(self, op, **kwargs): + return super().visit_node(op, **kwargs) + + @staticmethod + def _minimize_spec(start, end, spec): + if ( + start is None + and isinstance(getattr(end, "value", None), ops.Literal) + and end.value.value == 0 + and end.following + ): + return None + return spec + + @visit_node.register(ops.Cast) + def visit_Cast(self, op, *, arg, to): + from_ = op.arg.dtype + if (from_.is_json() or from_.is_string()) and to.is_json(): + # MariaDB does not support casting to JSON because it's an alias + # for TEXT (except when casting of course!) + return arg + elif from_.is_integer() and to.is_interval(): + return self.visit_IntervalFromInteger( + ops.IntervalFromInteger(op.arg, unit=to.unit), arg=arg, unit=to.unit + ) + elif from_.is_integer() and to.is_timestamp(): + return self.f.from_unixtime(arg) + return super().visit_Cast(op, arg=arg, to=to) + + @visit_node.register(ops.TimestampDiff) + def visit_TimestampDiff(self, op, *, left, right): + return self.f.timestampdiff( + sge.Var(this="SECOND"), right, left, dialect=self.dialect + ) + + @visit_node.register(ops.DateDiff) + def visit_DateDiff(self, op, *, left, right): + return self.f.timestampdiff( + sge.Var(this="DAY"), right, left, dialect=self.dialect + ) + + @visit_node.register(ops.ApproxCountDistinct) + def visit_ApproxCountDistinct(self, op, *, arg, where): + if where is not None: + arg = self.if_(where, arg) + return self.f.count(sge.Distinct(expressions=[arg])) + + @visit_node.register(ops.CountStar) + def visit_CountStar(self, op, *, arg, where): + if where is not None: + return self.f.sum(self.cast(where, op.dtype)) + return self.f.count(STAR) + + @visit_node.register(ops.CountDistinct) + def visit_CountDistinct(self, op, *, arg, where): + if where is not None: + arg = self.if_(where, arg) + return self.f.count(sge.Distinct(expressions=[arg])) + + @visit_node.register(ops.CountDistinctStar) + def visit_CountDistinctStar(self, op, *, arg, where): + if where is not None: + raise com.UnsupportedOperationError( + "Filtered table count distinct is not supported in MySQL" + ) + func = partial(sg.column, table=arg.alias_or_name, quoted=self.quoted) + return self.f.count( + sge.Distinct(expressions=list(map(func, op.arg.schema.keys()))) + ) + + @visit_node.register(ops.GroupConcat) + def visit_GroupConcat(self, op, *, arg, sep, where): + if not isinstance(op.sep, ops.Literal): + raise com.UnsupportedOperationError( + "Only string literal separators are supported" + ) + if where is not None: + arg = self.if_(where, arg) + return self.f.group_concat(arg, sep) + + @visit_node.register(ops.DayOfWeekIndex) + def visit_DayOfWeekIndex(self, op, *, arg): + return (self.f.dayofweek(arg) + 5) % 7 + + @visit_node.register(ops.Literal) + def visit_Literal(self, op, *, value, dtype): + # avoid casting NULL: the set of types allowed by MySQL and + # MariaDB when casting is a strict subset of allowed types in other + # contexts like CREATE TABLE + if value is None: + return NULL + return super().visit_Literal(op, value=value, dtype=dtype) + + def visit_NonNullLiteral(self, op, *, value, dtype): + if dtype.is_decimal() and not value.is_finite(): + raise com.UnsupportedOperationError( + "MySQL does not support NaN or infinity" + ) + elif dtype.is_binary(): + return self.f.unhex(value.hex()) + elif dtype.is_date(): + return self.f.date(value.isoformat()) + elif dtype.is_timestamp(): + return self.f.timestamp(value.isoformat()) + elif dtype.is_time(): + return self.f.maketime( + value.hour, value.minute, value.second + value.microsecond / 1e6 + ) + elif dtype.is_array() or dtype.is_struct() or dtype.is_map(): + raise com.UnsupportedBackendType( + "MySQL does not support arrays, structs or maps" + ) + elif dtype.is_string(): + return sge.convert(value.replace("\\", "\\\\")) + return None + + @visit_node.register(ops.JSONGetItem) + def visit_JSONGetItem(self, op, *, arg, index): + if op.index.dtype.is_integer(): + path = self.f.concat("$[", self.cast(index, dt.string), "]") + else: + path = self.f.concat("$.", index) + return self.f.json_extract(arg, path) + + @visit_node.register(ops.DateFromYMD) + def visit_DateFromYMD(self, op, *, year, month, day): + return self.f.str_to_date( + self.f.concat( + self.f.lpad(year, 4, "0"), + self.f.lpad(month, 2, "0"), + self.f.lpad(day, 2, "0"), + ), + "%Y%m%d", + ) + + @visit_node.register(ops.FindInSet) + def visit_FindInSet(self, op, *, needle, values): + return self.f.find_in_set(needle, self.f.concat_ws(",", values)) + + @visit_node.register(ops.EndsWith) + def visit_EndsWith(self, op, *, arg, end): + to = sge.DataType(this=sge.DataType.Type.BINARY) + return self.f.right(arg, self.f.char_length(end)).eq(sge.Cast(this=end, to=to)) + + @visit_node.register(ops.StartsWith) + def visit_StartsWith(self, op, *, arg, start): + to = sge.DataType(this=sge.DataType.Type.BINARY) + return self.f.left(arg, self.f.length(start)).eq(sge.Cast(this=start, to=to)) + + @visit_node.register(ops.RegexSearch) + def visit_RegexSearch(self, op, *, arg, pattern): + return arg.rlike(pattern) + + @visit_node.register(ops.RegexExtract) + def visit_RegexExtract(self, op, *, arg, pattern, index): + extracted = self.f.regexp_substr(arg, pattern) + return self.if_( + arg.rlike(pattern), + self.if_( + index.eq(0), + extracted, + self.f.regexp_replace( + extracted, pattern, rf"\\{index.sql(self.dialect)}" + ), + ), + NULL, + ) + + @visit_node.register(ops.Equals) + def visit_Equals(self, op, *, left, right): + if op.left.dtype.is_string(): + assert op.right.dtype.is_string(), op.right.dtype + to = sge.DataType(this=sge.DataType.Type.BINARY) + return sge.Cast(this=left, to=to).eq(right) + return super().visit_Equals(op, left=left, right=right) + + @visit_node.register(ops.StringContains) + def visit_StringContains(self, op, *, haystack, needle): + return self.f.instr(haystack, needle) > 0 + + @visit_node.register(ops.StringFind) + def visit_StringFind(self, op, *, arg, substr, start, end): + if end is not None: + raise NotImplementedError( + "`end` argument is not implemented for MySQL `StringValue.find`" + ) + substr = sge.Cast(this=substr, to=sge.DataType(this=sge.DataType.Type.BINARY)) + + if start is not None: + return self.f.locate(substr, arg, start + 1) + return self.f.locate(substr, arg) + + @visit_node.register(ops.Capitalize) + def visit_Capitalize(self, op, *, arg): + return self.f.concat( + self.f.upper(self.f.left(arg, 1)), self.f.lower(self.f.substr(arg, 2)) + ) + + def visit_LRStrip(self, op, *, arg, position): + return reduce( + lambda arg, char: self.f.trim(this=arg, position=position, expression=char), + map( + partial(self.cast, to=dt.string), + map(self.f.unhex, map(self.f.hex, string.whitespace.encode())), + ), + arg, + ) + + @visit_node.register(ops.DateTruncate) + @visit_node.register(ops.TimestampTruncate) + def visit_DateTimestampTruncate(self, op, *, arg, unit): + truncate_formats = { + "s": "%Y-%m-%d %H:%i:%s", + "m": "%Y-%m-%d %H:%i:00", + "h": "%Y-%m-%d %H:00:00", + "D": "%Y-%m-%d", + # 'W': 'week', + "M": "%Y-%m-01", + "Y": "%Y-01-01", + } + if (format := truncate_formats.get(unit.short)) is None: + raise com.UnsupportedOperationError(f"Unsupported truncate unit {op.unit}") + return self.f.date_format(arg, format) + + @visit_node.register(ops.TimeDelta) + @visit_node.register(ops.DateDelta) + def visit_DateTimeDelta(self, op, *, left, right, part): + return self.f.timestampdiff( + sge.Var(this=part.this), right, left, dialect=self.dialect + ) + + @visit_node.register(ops.ExtractMillisecond) + def visit_ExtractMillisecond(self, op, *, arg): + return self.f.floor(self.f.extract(sge.Var(this="microsecond"), arg) / 1_000) + + @visit_node.register(ops.ExtractMicrosecond) + def visit_ExtractMicrosecond(self, op, *, arg): + return self.f.floor(self.f.extract(sge.Var(this="microsecond"), arg)) + + @visit_node.register(ops.Strip) + def visit_Strip(self, op, *, arg): + return self.visit_LRStrip(op, arg=arg, position="BOTH") + + @visit_node.register(ops.LStrip) + def visit_LStrip(self, op, *, arg): + return self.visit_LRStrip(op, arg=arg, position="LEADING") + + @visit_node.register(ops.RStrip) + def visit_RStrip(self, op, *, arg): + return self.visit_LRStrip(op, arg=arg, position="TRAILING") + + @visit_node.register(ops.IntervalFromInteger) + def visit_IntervalFromInteger(self, op, *, arg, unit): + return sge.Interval(this=arg, unit=sge.convert(op.resolution.upper())) + + @visit_node.register(ops.TimestampAdd) + def visit_TimestampAdd(self, op, *, left, right): + if op.right.dtype.unit.short == "ms": + right = sge.Interval( + this=right.this * 1_000, unit=sge.Var(this="MICROSECOND") + ) + return self.f.date_add(left, right, dialect=self.dialect) + + @visit_node.register(ops.ApproxMedian) + @visit_node.register(ops.Arbitrary) + @visit_node.register(ops.ArgMax) + @visit_node.register(ops.ArgMin) + @visit_node.register(ops.ArrayCollect) + @visit_node.register(ops.Array) + @visit_node.register(ops.ArrayFlatten) + @visit_node.register(ops.ArrayMap) + @visit_node.register(ops.Covariance) + @visit_node.register(ops.First) + @visit_node.register(ops.Last) + @visit_node.register(ops.Levenshtein) + @visit_node.register(ops.Median) + @visit_node.register(ops.Mode) + @visit_node.register(ops.MultiQuantile) + @visit_node.register(ops.Quantile) + @visit_node.register(ops.RegexReplace) + @visit_node.register(ops.RegexSplit) + @visit_node.register(ops.RowID) + @visit_node.register(ops.StringSplit) + @visit_node.register(ops.StructColumn) + @visit_node.register(ops.TimestampBucket) + @visit_node.register(ops.TimestampDelta) + @visit_node.register(ops.Translate) + @visit_node.register(ops.Unnest) + def visit_Undefined(self, op, **_): + raise com.OperationNotDefinedError(type(op).__name__) + + +_SIMPLE_OPS = { + ops.BitAnd: "bit_and", + ops.BitOr: "bit_or", + ops.BitXor: "bit_xor", + ops.DayOfWeekName: "dayname", + ops.Log10: "log10", + ops.LPad: "lpad", + ops.RPad: "rpad", + ops.StringAscii: "ascii", + ops.StringContains: "instr", + ops.ExtractWeekOfYear: "weekofyear", + ops.ExtractEpochSeconds: "unix_timestamp", + ops.ExtractDayOfYear: "dayofyear", + ops.Strftime: "date_format", + ops.StringToTimestamp: "str_to_date", + ops.Log2: "log2", +} + + +for _op, _name in _SIMPLE_OPS.items(): + assert isinstance(type(_op), type), type(_op) + if issubclass(_op, ops.Reduction): + + @MySQLCompiler.visit_node.register(_op) + def _fmt(self, op, *, _name: str = _name, where, **kw): + return self.agg[_name](*kw.values(), where=where) + + else: + @MySQLCompiler.visit_node.register(_op) + def _fmt(self, op, *, _name: str = _name, **kw): + return self.f[_name](*kw.values()) -rewrites = MySQLExprTranslator.rewrites + setattr(MySQLCompiler, f"visit_{_op.__name__}", _fmt) -class MySQLCompiler(AlchemyCompiler): - translator_class = MySQLExprTranslator - support_values_syntax_in_select = False - null_limit = None - rewrites = AlchemyCompiler.rewrites | rewrite_sample +del _op, _name, _fmt diff --git a/ibis/backends/mysql/converter.py b/ibis/backends/mysql/converter.py new file mode 100644 index 000000000000..ffa277c56de4 --- /dev/null +++ b/ibis/backends/mysql/converter.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import datetime + +from ibis.formats.pandas import PandasData + + +class MySQLPandasData(PandasData): + @classmethod + def convert_Time(cls, s, dtype, pandas_type): + def convert(timedelta): + comps = timedelta.components + return datetime.time( + hour=comps.hours, + minute=comps.minutes, + second=comps.seconds, + microsecond=comps.microseconds, + ) + + return s.map(convert, na_action="ignore") + + @classmethod + def convert_Timestamp(cls, s, dtype, pandas_type): + if s.dtype == "object": + s = s.replace("0000-00-00 00:00:00", None) + return super().convert_Timestamp(s, dtype, pandas_type) diff --git a/ibis/backends/mysql/datatypes.py b/ibis/backends/mysql/datatypes.py deleted file mode 100644 index 05ef1cc5a496..000000000000 --- a/ibis/backends/mysql/datatypes.py +++ /dev/null @@ -1,101 +0,0 @@ -from __future__ import annotations - -import sqlalchemy.types as sat -from sqlalchemy.dialects import mysql - -import ibis.expr.datatypes as dt -from ibis.backends.base.sql.alchemy.datatypes import UUID, AlchemyType -from ibis.backends.base.sqlglot.datatypes import MySQLType as SqlglotMySQLType - - -class MySQLDateTime(mysql.DATETIME): - """Custom DATETIME type for MySQL that handles zero values.""" - - def result_processor(self, *_): - return lambda v: None if v == "0000-00-00 00:00:00" else v - - -_to_mysql_types = { - dt.Boolean: mysql.BOOLEAN, - dt.Int8: mysql.TINYINT, - dt.Int16: mysql.SMALLINT, - dt.Int32: mysql.INTEGER, - dt.Int64: mysql.BIGINT, - dt.Float16: mysql.FLOAT, - dt.Float32: mysql.FLOAT, - dt.Float64: mysql.DOUBLE, - dt.String: mysql.TEXT, - dt.JSON: mysql.JSON, - dt.Timestamp: MySQLDateTime, -} - -_from_mysql_types = { - mysql.BIGINT: dt.Int64, - mysql.BINARY: dt.Binary, - mysql.BLOB: dt.Binary, - mysql.BOOLEAN: dt.Boolean, - mysql.DATETIME: dt.Timestamp, - mysql.DOUBLE: dt.Float64, - mysql.FLOAT: dt.Float32, - mysql.INTEGER: dt.Int32, - mysql.JSON: dt.JSON, - mysql.LONGBLOB: dt.Binary, - mysql.LONGTEXT: dt.String, - mysql.MEDIUMBLOB: dt.Binary, - mysql.MEDIUMINT: dt.Int32, - mysql.MEDIUMTEXT: dt.String, - mysql.REAL: dt.Float64, - mysql.SMALLINT: dt.Int16, - mysql.TEXT: dt.String, - mysql.DATE: dt.Date, - mysql.TINYBLOB: dt.Binary, - mysql.TINYINT: dt.Int8, - mysql.VARBINARY: dt.Binary, - mysql.VARCHAR: dt.String, - mysql.ENUM: dt.String, - mysql.CHAR: dt.String, - mysql.TIME: dt.Time, - mysql.YEAR: dt.Int8, - MySQLDateTime: dt.Timestamp, - UUID: dt.UUID, -} - - -class MySQLType(AlchemyType): - dialect = "mysql" - - @classmethod - def from_ibis(cls, dtype): - try: - return _to_mysql_types[type(dtype)] - except KeyError: - return super().from_ibis(dtype) - - @classmethod - def to_ibis(cls, typ, nullable=True): - if isinstance(typ, (sat.NUMERIC, mysql.NUMERIC, mysql.DECIMAL)): - # https://dev.mysql.com/doc/refman/8.0/en/fixed-point-types.html - return dt.Decimal(typ.precision or 10, typ.scale or 0, nullable=nullable) - elif isinstance(typ, mysql.BIT): - if 1 <= (length := typ.length) <= 8: - return dt.Int8(nullable=nullable) - elif 9 <= length <= 16: - return dt.Int16(nullable=nullable) - elif 17 <= length <= 32: - return dt.Int32(nullable=nullable) - elif 33 <= length <= 64: - return dt.Int64(nullable=nullable) - else: - raise ValueError(f"Invalid MySQL BIT length: {length:d}") - elif isinstance(typ, mysql.TIMESTAMP): - return dt.Timestamp(timezone="UTC", nullable=nullable) - elif isinstance(typ, mysql.SET): - return dt.Array(dt.string, nullable=nullable) - elif dtype := _from_mysql_types.get(type(typ)): - return dtype(nullable=nullable) - else: - return super().to_ibis(typ, nullable=nullable) - - @classmethod - def from_string(cls, type_string, nullable=True): - return SqlglotMySQLType.from_string(type_string, nullable=nullable) diff --git a/ibis/backends/mysql/registry.py b/ibis/backends/mysql/registry.py deleted file mode 100644 index 9b326cc2e63b..000000000000 --- a/ibis/backends/mysql/registry.py +++ /dev/null @@ -1,265 +0,0 @@ -from __future__ import annotations - -import contextlib -import functools -import operator -import string - -import sqlalchemy as sa -from sqlalchemy.ext.compiler import compiles -from sqlalchemy.sql.functions import GenericFunction - -import ibis -import ibis.common.exceptions as com -import ibis.expr.operations as ops -from ibis.backends.base.sql.alchemy import ( - fixed_arity, - sqlalchemy_operation_registry, - sqlalchemy_window_functions_registry, - unary, -) -from ibis.backends.base.sql.alchemy.geospatial import geospatial_supported -from ibis.backends.base.sql.alchemy.registry import ( - geospatial_functions, -) - -operation_registry = sqlalchemy_operation_registry.copy() - -# NOTE: window functions are available from MySQL 8 and MariaDB 10.2 -operation_registry.update(sqlalchemy_window_functions_registry) - -if geospatial_supported: - operation_registry.update(geospatial_functions) - -_truncate_formats = { - "s": "%Y-%m-%d %H:%i:%s", - "m": "%Y-%m-%d %H:%i:00", - "h": "%Y-%m-%d %H:00:00", - "D": "%Y-%m-%d", - # 'W': 'week', - "M": "%Y-%m-01", - "Y": "%Y-01-01", -} - - -def _truncate(t, op): - sa_arg = t.translate(op.arg) - try: - fmt = _truncate_formats[op.unit.short] - except KeyError: - raise com.UnsupportedOperationError(f"Unsupported truncate unit {op.unit}") - return sa.func.date_format(sa_arg, fmt) - - -def _round(t, op): - sa_arg = t.translate(op.arg) - - if op.digits is None: - sa_digits = 0 - else: - sa_digits = t.translate(op.digits) - - return sa.func.round(sa_arg, sa_digits) - - -def _interval_from_integer(t, op): - if op.unit.short in {"ms", "ns"}: - raise com.UnsupportedOperationError( - f"MySQL does not allow operation with INTERVAL offset {op.unit}" - ) - - sa_arg = t.translate(op.arg) - text_unit = op.dtype.resolution.upper() - - # XXX: Is there a better way to handle this? I.e. can we somehow use - # the existing bind parameter produced by translate and reuse its name in - # the string passed to sa.text? - if isinstance(sa_arg, sa.sql.elements.BindParameter): - return sa.text(f"INTERVAL :arg {text_unit}").bindparams(arg=sa_arg.value) - return sa.text(f"INTERVAL {sa_arg} {text_unit}") - - -def _literal(_, op): - dtype = op.dtype - value = op.value - if value is None: - return sa.null() - if dtype.is_interval(): - if dtype.unit.short in {"ms", "ns"}: - raise com.UnsupportedOperationError( - f"MySQL does not allow operation with INTERVAL offset {dtype.unit}" - ) - text_unit = dtype.resolution.upper() - sa_text = sa.text(f"INTERVAL :value {text_unit}") - return sa_text.bindparams(value=value) - elif dtype.is_binary(): - # the cast to BINARY is necessary here, otherwise the data come back as - # Python strings - # - # This lets the database handle encoding rather than ibis - return sa.cast(sa.literal(value), type_=sa.BINARY()) - elif dtype.is_date(): - return sa.func.date(value.isoformat()) - elif dtype.is_timestamp(): - # TODO: timezones - return sa.func.timestamp(value.isoformat()) - elif dtype.is_time(): - return sa.func.maketime( - value.hour, value.minute, value.second + value.microsecond / 1e6 - ) - else: - with contextlib.suppress(AttributeError): - value = value.to_pydatetime() - - return sa.literal(value) - - -def _group_concat(t, op): - if op.where is not None: - arg = t.translate(ops.IfElse(op.where, op.arg, ibis.NA)) - else: - arg = t.translate(op.arg) - sep = t.translate(op.sep) - return sa.func.group_concat(arg.op("SEPARATOR")(sep)) - - -def _json_get_item(t, op): - arg = t.translate(op.arg) - index = t.translate(op.index) - if op.index.dtype.is_integer(): - path = "$[" + sa.cast(index, sa.TEXT) + "]" - else: - path = "$." + index - return sa.func.json_extract(arg, path) - - -def _regex_extract(arg, pattern, index): - return sa.func.IF( - arg.op("REGEXP")(pattern), - sa.func.IF( - index == 0, - sa.func.REGEXP_SUBSTR(arg, pattern), - sa.func.REGEXP_REPLACE( - sa.func.REGEXP_SUBSTR(arg, pattern), pattern, rf"\{index.value}" - ), - ), - None, - ) - - -def _string_find(t, op): - arg = t.translate(op.arg) - substr = t.translate(op.substr) - - if op_start := op.start: - start = t.translate(op_start) - return sa.func.locate(substr, arg, start) - 1 - - return sa.func.locate(substr, arg) - 1 - - -class _mysql_trim(GenericFunction): - inherit_cache = True - - def __init__(self, input, side: str) -> None: - super().__init__(input) - self.type = sa.VARCHAR() - self.side = side - - -@compiles(_mysql_trim, "mysql") -def compiles_mysql_trim(element, compiler, **kw): - arg = compiler.function_argspec(element, **kw) - side = element.side.upper() - # has to be called once for every whitespace character because mysql - # interprets `char` literally, not as a set of characters like Python - return functools.reduce( - lambda arg, char: f"TRIM({side} '{char}' FROM {arg})", string.whitespace, arg - ) - - -def _temporal_delta(t, op): - left = t.translate(op.left) - right = t.translate(op.right) - part = sa.literal_column(op.part.value.upper()) - return sa.func.timestampdiff(part, right, left) - - -operation_registry.update( - { - ops.Literal: _literal, - # static checks are not happy with using "if" as a property - ops.IfElse: fixed_arity(getattr(sa.func, "if"), 3), - # strings - ops.StringFind: _string_find, - ops.FindInSet: ( - lambda t, op: ( - sa.func.find_in_set( - t.translate(op.needle), - sa.func.concat_ws(",", *map(t.translate, op.values)), - ) - - 1 - ) - ), - # LIKE in mysql is case insensitive - ops.StartsWith: fixed_arity( - lambda arg, start: sa.type_coerce( - arg.op("LIKE BINARY")(sa.func.concat(start, "%")), sa.BOOLEAN() - ), - 2, - ), - ops.EndsWith: fixed_arity( - lambda arg, end: sa.type_coerce( - arg.op("LIKE BINARY")(sa.func.concat("%", end)), sa.BOOLEAN() - ), - 2, - ), - ops.RegexSearch: fixed_arity( - lambda x, y: sa.type_coerce(x.op("REGEXP")(y), sa.BOOLEAN()), 2 - ), - ops.RegexExtract: fixed_arity(_regex_extract, 3), - # math - ops.Log: fixed_arity(lambda arg, base: sa.func.log(base, arg), 2), - ops.Log2: unary(sa.func.log2), - ops.Log10: unary(sa.func.log10), - ops.Round: _round, - # dates and times - ops.DateAdd: fixed_arity(operator.add, 2), - ops.DateSub: fixed_arity(operator.sub, 2), - ops.DateDiff: fixed_arity(sa.func.datediff, 2), - ops.TimestampAdd: fixed_arity(operator.add, 2), - ops.TimestampSub: fixed_arity(operator.sub, 2), - ops.TimestampDiff: fixed_arity( - lambda left, right: sa.func.timestampdiff(sa.text("SECOND"), right, left), 2 - ), - ops.StringToTimestamp: fixed_arity( - lambda arg, format_str: sa.func.str_to_date(arg, format_str), 2 - ), - ops.DateTruncate: _truncate, - ops.TimestampTruncate: _truncate, - ops.IntervalFromInteger: _interval_from_integer, - ops.Strftime: fixed_arity(sa.func.date_format, 2), - ops.ExtractDayOfYear: unary(sa.func.dayofyear), - ops.ExtractEpochSeconds: unary(sa.func.UNIX_TIMESTAMP), - ops.ExtractWeekOfYear: unary(sa.func.weekofyear), - ops.ExtractMicrosecond: fixed_arity( - lambda arg: sa.func.floor(sa.extract("microsecond", arg)), 1 - ), - ops.ExtractMillisecond: fixed_arity( - lambda arg: sa.func.floor(sa.extract("microsecond", arg) / 1000), 1 - ), - ops.TimestampNow: fixed_arity(sa.func.now, 0), - # others - ops.GroupConcat: _group_concat, - ops.DayOfWeekIndex: fixed_arity( - lambda arg: (sa.func.dayofweek(arg) + 5) % 7, 1 - ), - ops.DayOfWeekName: fixed_arity(lambda arg: sa.func.dayname(arg), 1), - ops.JSONGetItem: _json_get_item, - ops.Strip: unary(lambda arg: _mysql_trim(arg, "both")), - ops.LStrip: unary(lambda arg: _mysql_trim(arg, "leading")), - ops.RStrip: unary(lambda arg: _mysql_trim(arg, "trailing")), - ops.TimeDelta: _temporal_delta, - ops.DateDelta: _temporal_delta, - } -) diff --git a/ibis/backends/mysql/tests/conftest.py b/ibis/backends/mysql/tests/conftest.py index 8c0ad8007051..c7cadc448fd3 100644 --- a/ibis/backends/mysql/tests/conftest.py +++ b/ibis/backends/mysql/tests/conftest.py @@ -4,11 +4,9 @@ from typing import TYPE_CHECKING, Any import pytest -import sqlalchemy as sa -from packaging.version import parse as parse_version import ibis -from ibis.backends.conftest import TEST_TABLES, init_database +from ibis.backends.conftest import TEST_TABLES from ibis.backends.tests.base import ServiceBackendTest if TYPE_CHECKING: @@ -26,29 +24,18 @@ class TestConf(ServiceBackendTest): # mysql has the same rounding behavior as postgres check_dtype = False returned_timestamp_unit = "s" - supports_arrays = False - supports_arrays_outside_of_select = supports_arrays + supports_arrays = supports_arrays_outside_of_select = False native_bool = False supports_structs = False rounding_method = "half_to_even" service_name = "mysql" - deps = "pymysql", "sqlalchemy" + deps = ("pymysql",) + supports_window_operations = True @property def test_files(self) -> Iterable[Path]: return self.data_dir.joinpath("csv").glob("*.csv") - @property - def supports_window_operations(self) -> bool: - con = self.connection - with con.begin() as c: - version = c.execute(sa.select(sa.func.version())).scalar() - - # mariadb supports window operations after version 10.2 - # mysql supports window operations after version 8 - min_version = "10.2" if "MariaDB" in version else "8.0" - return parse_version(con.version) >= parse_version(min_version) - def _load_data( self, *, @@ -68,16 +55,10 @@ def _load_data( script_dir Location of scripts defining schemas """ - engine = init_database( - url=sa.engine.make_url( - f"mysql+pymysql://{user}:{password}@{host}:{port:d}?local_infile=1", - ), - database=database, - schema=self.ddl_script, - isolation_level="AUTOCOMMIT", - recreate=False, - ) - with engine.begin() as con: + with self.connection.begin() as cur: + for stmt in self.ddl_script: + cur.execute(stmt) + for table in TEST_TABLES: csv_path = self.data_dir / "csv" / f"{table}.csv" lines = [ @@ -88,7 +69,7 @@ def _load_data( "LINES TERMINATED BY '\\n'", "IGNORE 1 LINES", ] - con.exec_driver_sql("\n".join(lines)) + cur.execute("\n".join(lines)) @staticmethod def connect(*, tmpdir, worker_id, **kw): @@ -98,31 +79,12 @@ def connect(*, tmpdir, worker_id, **kw): password=MYSQL_PASS, database=IBIS_TEST_MYSQL_DB, port=MYSQL_PORT, + local_infile=1, + autocommit=True, **kw, ) -@pytest.fixture(scope="session") -def setup_privs(): - engine = sa.create_engine(f"mysql+pymysql://root:@{MYSQL_HOST}:{MYSQL_PORT:d}") - with engine.begin() as con: - # allow the ibis user to use any database - con.exec_driver_sql("CREATE SCHEMA IF NOT EXISTS `test_schema`") - con.exec_driver_sql( - f"GRANT CREATE,SELECT,DROP ON `test_schema`.* TO `{MYSQL_USER}`@`%%`" - ) - yield - with engine.begin() as con: - con.exec_driver_sql("DROP SCHEMA IF EXISTS `test_schema`") - - @pytest.fixture(scope="session") def con(tmp_path_factory, data_dir, worker_id): return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection - - -@pytest.fixture(scope="session") -def con_nodb(): - return ibis.mysql.connect( - host=MYSQL_HOST, user=MYSQL_USER, password=MYSQL_PASS, port=MYSQL_PORT - ) diff --git a/ibis/backends/mysql/tests/test_client.py b/ibis/backends/mysql/tests/test_client.py index 4c73faeb9c3d..18ad4a39c8af 100644 --- a/ibis/backends/mysql/tests/test_client.py +++ b/ibis/backends/mysql/tests/test_client.py @@ -7,22 +7,14 @@ import pandas as pd import pandas.testing as tm import pytest -import sqlalchemy as sa -from packaging.version import parse as vparse +import sqlglot as sg from pytest import param -from sqlalchemy.dialects import mysql import ibis import ibis.expr.datatypes as dt from ibis import udf -from ibis.backends.base.sql.alchemy.geospatial import geospatial_supported from ibis.util import gen_name -if geospatial_supported: - import geoalchemy2 -else: - geoalchemy2 = None - MYSQL_TYPES = [ param("tinyint", dt.int8, id="tinyint"), param("int1", dt.int8, id="int1"), @@ -69,30 +61,20 @@ param("set('a', 'b', 'c', 'd')", dt.Array(dt.string), id="set"), param("mediumblob", dt.binary, id="mediumblob"), param("blob", dt.binary, id="blob"), - param( - "uuid", - dt.uuid, - marks=[ - pytest.mark.xfail( - condition=vparse(sa.__version__) < vparse("2"), - reason="geoalchemy2 0.14.x doesn't work", - ) - ], - id="uuid", - ), + param("uuid", dt.uuid, id="uuid"), ] @pytest.mark.parametrize(("mysql_type", "expected_type"), MYSQL_TYPES) def test_get_schema_from_query(con, mysql_type, expected_type): raw_name = ibis.util.guid() - name = con._quote(raw_name) + name = sg.to_identifier(raw_name, quoted=True).sql("mysql") expected_schema = ibis.schema(dict(x=expected_type)) # temporary tables get cleaned up by the db when the session ends, so we # don't need to explicitly drop the table with con.begin() as c: - c.exec_driver_sql(f"CREATE TEMPORARY TABLE {name} (x {mysql_type})") + c.execute(f"CREATE TEMPORARY TABLE {name} (x {mysql_type})") result_schema = con._get_schema_using_query(f"SELECT * FROM {name}") assert result_schema == expected_schema @@ -105,29 +87,23 @@ def test_get_schema_from_query(con, mysql_type, expected_type): def test_blob_type(con, coltype): tmp = f"tmp_{ibis.util.guid()}" with con.begin() as c: - c.exec_driver_sql(f"CREATE TEMPORARY TABLE {tmp} (a {coltype})") + c.execute(f"CREATE TEMPORARY TABLE {tmp} (a {coltype})") t = con.table(tmp) assert t.schema() == ibis.schema({"a": dt.binary}) @pytest.fixture(scope="session") -def tmp_t(con_nodb): - with con_nodb.begin() as c: - c.exec_driver_sql("CREATE TABLE IF NOT EXISTS test_schema.t (x INET6)") - yield - with con_nodb.begin() as c: - c.exec_driver_sql("DROP TABLE IF EXISTS test_schema.t") - - -@pytest.mark.usefixtures("setup_privs", "tmp_t") -@pytest.mark.xfail( - geospatial_supported and vparse(geoalchemy2.__version__) > vparse("0.13.3"), - reason="geoalchemy2 issues when using 0.14.x", - raises=sa.exc.OperationalError, -) -def test_get_schema_from_query_other_schema(con_nodb): - t = con_nodb.table("t", schema="test_schema") - assert t.schema() == ibis.schema({"x": dt.string}) +def tmp_t(con): + with con.begin() as c: + c.execute("CREATE TABLE IF NOT EXISTS test_schema.t (x INET6)") + yield "t" + with con.begin() as c: + c.execute("DROP TABLE IF EXISTS test_schema.t") + + +def test_get_schema_from_query_other_schema(con, tmp_t): + t = con.table(tmp_t, schema="test_schema") + assert t.schema() == ibis.schema({"x": dt.inet}) def test_zero_timestamp_data(con): @@ -137,11 +113,11 @@ def test_zero_timestamp_data(con): name CHAR(10) NULL, tradedate DATETIME NOT NULL, date DATETIME NULL - ); + ) """ with con.begin() as c: - c.exec_driver_sql(sql) - c.exec_driver_sql( + c.execute(sql) + c.execute( """ INSERT INTO ztmp_date_issue VALUES ('C', '2018-10-22', 0), @@ -166,12 +142,11 @@ def test_zero_timestamp_data(con): @pytest.fixture(scope="module") def enum_t(con): name = gen_name("mysql_enum_test") - t = sa.Table( - name, sa.MetaData(), sa.Column("sml", mysql.ENUM("small", "medium", "large")) - ) - with con.begin() as bind: - t.create(bind=bind) - bind.execute(t.insert().values(sml="small")) + with con.begin() as cur: + cur.execute( + f"CREATE TEMPORARY TABLE {name} (sml ENUM('small', 'medium', 'large'))" + ) + cur.execute(f"INSERT INTO {name} VALUES ('small')") yield con.table(name) con.drop_table(name, force=True) diff --git a/ibis/backends/tests/errors.py b/ibis/backends/tests/errors.py index ca70557f6321..229ad1577282 100644 --- a/ibis/backends/tests/errors.py +++ b/ibis/backends/tests/errors.py @@ -91,3 +91,10 @@ PsycoPg2SyntaxError = ( PsycoPg2IndeterminateDatatype ) = PsycoPg2InvalidTextRepresentation = PsycoPg2DivisionByZero = None + +try: + from pymysql.err import NotSupportedError as MySQLNotSupportedError + from pymysql.err import OperationalError as MySQLOperationalError + from pymysql.err import ProgrammingError as MySQLProgrammingError +except ImportError: + MySQLNotSupportedError = MySQLProgrammingError = MySQLOperationalError = None diff --git a/ibis/backends/tests/snapshots/test_interactive/test_default_limit/mysql/out.sql b/ibis/backends/tests/snapshots/test_interactive/test_default_limit/mysql/out.sql new file mode 100644 index 000000000000..b4d624684bfc --- /dev/null +++ b/ibis/backends/tests/snapshots/test_interactive/test_default_limit/mysql/out.sql @@ -0,0 +1,5 @@ +SELECT + `t0`.`id`, + `t0`.`bool_col` = 1 AS `bool_col` +FROM `functional_alltypes` AS `t0` +LIMIT 11 \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_interactive/test_disable_query_limit/mysql/out.sql b/ibis/backends/tests/snapshots/test_interactive/test_disable_query_limit/mysql/out.sql new file mode 100644 index 000000000000..b4d624684bfc --- /dev/null +++ b/ibis/backends/tests/snapshots/test_interactive/test_disable_query_limit/mysql/out.sql @@ -0,0 +1,5 @@ +SELECT + `t0`.`id`, + `t0`.`bool_col` = 1 AS `bool_col` +FROM `functional_alltypes` AS `t0` +LIMIT 11 \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_interactive/test_interactive_execute_on_repr/mysql/out.sql b/ibis/backends/tests/snapshots/test_interactive/test_interactive_execute_on_repr/mysql/out.sql new file mode 100644 index 000000000000..d93091fd3aba --- /dev/null +++ b/ibis/backends/tests/snapshots/test_interactive/test_interactive_execute_on_repr/mysql/out.sql @@ -0,0 +1,19 @@ +SELECT + SUM(`t1`.`bigint_col`) AS `Sum(bigint_col)` +FROM ( + SELECT + `t0`.`id`, + `t0`.`bool_col` = 1 AS `bool_col`, + `t0`.`tinyint_col`, + `t0`.`smallint_col`, + `t0`.`int_col`, + `t0`.`bigint_col`, + `t0`.`float_col`, + `t0`.`double_col`, + `t0`.`date_string_col`, + `t0`.`string_col`, + `t0`.`timestamp_col`, + `t0`.`year`, + `t0`.`month` + FROM `functional_alltypes` AS `t0` +) AS `t1` \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_interactive/test_respect_set_limit/mysql/out.sql b/ibis/backends/tests/snapshots/test_interactive/test_respect_set_limit/mysql/out.sql new file mode 100644 index 000000000000..1c3fb645041c --- /dev/null +++ b/ibis/backends/tests/snapshots/test_interactive/test_respect_set_limit/mysql/out.sql @@ -0,0 +1,10 @@ +SELECT + * +FROM ( + SELECT + `t0`.`id`, + `t0`.`bool_col` = 1 AS `bool_col` + FROM `functional_alltypes` AS `t0` + LIMIT 10 +) AS `t2` +LIMIT 11 \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/mysql/out.sql b/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/mysql/out.sql index fc16f2428d16..ac006b1d5f25 100644 --- a/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/mysql/out.sql +++ b/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/mysql/out.sql @@ -1,5 +1,5 @@ SELECT - CASE t0.continent + CASE `t0`.`continent` WHEN 'NA' THEN 'North America' WHEN 'SA' @@ -15,8 +15,8 @@ SELECT WHEN 'AN' THEN 'Antarctica' ELSE 'Unknown continent' - END AS cont, - SUM(t0.population) AS total_pop -FROM countries AS t0 + END AS `cont`, + SUM(`t0`.`population`) AS `total_pop` +FROM `countries` AS `t0` GROUP BY 1 \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_sql/test_isin_bug/mysql/out.sql b/ibis/backends/tests/snapshots/test_sql/test_isin_bug/mysql/out.sql index a3042e85b3e7..db5ddb124e86 100644 --- a/ibis/backends/tests/snapshots/test_sql/test_isin_bug/mysql/out.sql +++ b/ibis/backends/tests/snapshots/test_sql/test_isin_bug/mysql/out.sql @@ -1,13 +1,9 @@ SELECT - t0.x IN ( + `t0`.`x` IN ( SELECT - t1.x - FROM ( - SELECT - t0.x AS x - FROM t AS t0 - WHERE - t0.x > 2 - ) AS t1 - ) AS `InColumn(x, x)` -FROM t AS t0 \ No newline at end of file + `t0`.`x` + FROM `t` AS `t0` + WHERE + `t0`.`x` > 2 + ) AS `InSubquery(x)` +FROM `t` AS `t0` \ No newline at end of file diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 24fa0c702fd1..d6a3a4d0f0d1 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -18,6 +18,7 @@ ClickHouseDatabaseError, ExaQueryError, GoogleBadRequest, + MySQLNotSupportedError, PolarsInvalidOperationError, Py4JError, PySparkAnalysisException, @@ -948,7 +949,6 @@ def test_count_distinct_star(alltypes, df, ibis_cond, pandas_cond): "datafusion", "impala", "mssql", - "mysql", "polars", "sqlite", "druid", @@ -957,6 +957,7 @@ def test_count_distinct_star(alltypes, df, ibis_cond, pandas_cond): ], raises=com.OperationNotDefinedError, ), + pytest.mark.notyet(["mysql"], raises=com.UnsupportedBackendType), pytest.mark.notyet( ["snowflake"], reason="backend doesn't implement array of quantiles as input", @@ -1359,6 +1360,7 @@ def test_date_quantile(alltypes, func): "::", id="expr", marks=[ + pytest.mark.notyet(["mysql"], raises=com.UnsupportedOperationError), pytest.mark.notyet( ["bigquery"], raises=GoogleBadRequest, @@ -1367,10 +1369,6 @@ def test_date_quantile(alltypes, func): pytest.mark.broken( ["pyspark"], raises=TypeError, reason="Column is not iterable" ), - pytest.mark.broken( - ["mysql"], - raises=sa.exc.ProgrammingError, - ), ], ), ], @@ -1680,7 +1678,7 @@ def test_grouped_case(backend, con): @pytest.mark.notyet(["druid"], raises=sa.exc.ProgrammingError) @pytest.mark.notyet(["snowflake"], raises=SnowflakeProgrammingError) @pytest.mark.notyet(["trino"], raises=TrinoUserError) -@pytest.mark.notyet(["mysql"], raises=sa.exc.NotSupportedError) +@pytest.mark.notyet(["mysql"], raises=MySQLNotSupportedError) @pytest.mark.notyet(["oracle"], raises=sa.exc.DatabaseError) @pytest.mark.notyet(["pyspark"], raises=PySparkAnalysisException) def test_group_concat_over_window(backend, con): diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index ccd0626123de..9267de868535 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -20,6 +20,7 @@ from ibis.backends.tests.errors import ( ClickHouseDatabaseError, GoogleBadRequest, + MySQLOperationalError, PolarsComputeError, PsycoPg2IndeterminateDatatype, PsycoPg2SyntaxError, @@ -30,10 +31,11 @@ pytestmark = [ pytest.mark.never( - ["sqlite", "mysql", "mssql", "exasol"], + ["sqlite", "mssql", "exasol"], reason="No array support", raises=Exception, ), + pytest.mark.never(["mysql"], reason="No array support", raises=(com.UnsupportedBackendType, com.OperationNotDefinedError, MySQLOperationalError)), pytest.mark.notyet(["impala"], reason="No array support", raises=Exception), pytest.mark.notimpl(["druid", "oracle"], raises=Exception), ] @@ -162,7 +164,11 @@ def test_array_index(con, idx): pytest.mark.never( ["mysql"], reason="array types are unsupported", - raises=com.OperationNotDefinedError, + raises=( + com.OperationNotDefinedError, + MySQLOperationalError, + com.UnsupportedBackendType, + ), ), pytest.mark.never( ["sqlite"], reason="array types are unsupported", raises=NotImplementedError @@ -419,7 +425,6 @@ def test_array_slice(backend, start, stop): "polars", "snowflake", "sqlite", - "mysql", ], raises=com.OperationNotDefinedError, ) @@ -429,9 +434,7 @@ def test_array_slice(backend, start, stop): reason="Operation 'ArrayMap' is not implemented for this backend", ) @pytest.mark.notimpl( - ["sqlite"], - raises=NotImplementedError, - reason="Unsupported type: Array: ...", + ["sqlite"], raises=NotImplementedError, reason="Unsupported type: Array: ..." ) @pytest.mark.parametrize( ("input", "output"), @@ -487,7 +490,6 @@ def test_array_map(con, input, output, func): "pandas", "polars", "snowflake", - "mysql", ], raises=com.OperationNotDefinedError, ) @@ -641,7 +643,7 @@ def test_array_remove(con, a): @builtin_array @pytest.mark.notimpl( - ["dask", "datafusion", "impala", "mssql", "polars", "mysql"], + ["dask", "datafusion", "impala", "mssql", "polars"], raises=com.OperationNotDefinedError, ) @pytest.mark.notimpl( @@ -759,8 +761,9 @@ def test_array_union(con, a, b, expected_array): ) +@builtin_array @pytest.mark.notimpl( - ["dask", "datafusion", "impala", "mssql", "pandas", "polars", "mysql", "flink"], + ["dask", "datafusion", "impala", "mssql", "pandas", "polars", "flink"], raises=com.OperationNotDefinedError, ) @pytest.mark.notimpl( @@ -1087,7 +1090,6 @@ def test_unnest_empty_array(con): "polars", "snowflake", "sqlite", - "mysql", "dask", "pandas", ], @@ -1114,7 +1116,6 @@ def test_array_map_with_conflicting_names(backend, con): "polars", "snowflake", "sqlite", - "mysql", "dask", "pandas", ], @@ -1292,9 +1293,11 @@ def test_timestamp_range_zero_step(con, start, stop, step, tzinfo): def test_repr_timestamp_array(con, monkeypatch): monkeypatch.setattr(ibis.options, "interactive", True) - monkeypatch.setattr(ibis.options, "default_backend", con) assert ibis.options.interactive is True + + monkeypatch.setattr(ibis.options, "default_backend", con) assert ibis.options.default_backend is con + expr = ibis.array(pd.date_range("2010-01-01", "2010-01-03", freq="D").tolist()) assert "No translation rule" not in repr(expr) diff --git a/ibis/backends/tests/test_asof_join.py b/ibis/backends/tests/test_asof_join.py index 1250e55ca35f..172075a0a860 100644 --- a/ibis/backends/tests/test_asof_join.py +++ b/ibis/backends/tests/test_asof_join.py @@ -78,13 +78,9 @@ def time_keyed_right(time_keyed_df2): @pytest.mark.parametrize( - ("direction", "op"), - [ - ("backward", operator.ge), - ("forward", operator.le), - ], + ("direction", "op"), [("backward", operator.ge), ("forward", operator.le)] ) -@pytest.mark.notimpl(["datafusion", "snowflake", "trino", "postgres"]) +@pytest.mark.notyet(["datafusion", "snowflake", "trino", "postgres", "mysql"]) def test_asof_join(con, time_left, time_right, time_df1, time_df2, direction, op): on = op(time_left["time"], time_right["time"]) expr = time_left.asof_join(time_right, on=on, predicates="group") @@ -103,16 +99,12 @@ def test_asof_join(con, time_left, time_right, time_df1, time_df2, direction, op @pytest.mark.parametrize( - ("direction", "op"), - [ - ("backward", operator.ge), - ("forward", operator.le), - ], + ("direction", "op"), [("backward", operator.ge), ("forward", operator.le)] ) @pytest.mark.broken( ["clickhouse"], raises=AssertionError, reason="`time` is truncated to seconds" ) -@pytest.mark.notimpl(["datafusion", "snowflake", "trino", "postgres"]) +@pytest.mark.notyet(["datafusion", "snowflake", "trino", "postgres", "mysql"]) def test_keyed_asof_join_with_tolerance( con, time_keyed_left, diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 3904ff82fbba..ab0581a4d736 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -70,10 +70,7 @@ def _create_temp_table_with_schema(backend, con, temp_table_name, schema, data=N @pytest.mark.parametrize( "sch", [ - param( - None, - id="no schema", - ), + param(None, id="no schema"), param( ibis.schema( [ @@ -102,12 +99,7 @@ def test_create_table(backend, con, temp_table, lamduh, sch): } ) - obj = lamduh(df) - con.create_table( - temp_table, - obj, - schema=sch, - ) + con.create_table(temp_table, lamduh(df), schema=sch) result = ( con.table(temp_table).execute().sort_values("first_name").reset_index(drop=True) ) @@ -1124,7 +1116,7 @@ def test_repr_mimebundle(alltypes, interactive, expr_type, monkeypatch): @pytest.mark.never( - ["postgres", "mysql", "bigquery", "duckdb"], + ["postgres", "bigquery", "duckdb"], reason="These backends explicitly do support Geo operations", ) @pytest.mark.parametrize("op", [ops.GeoDistance, ops.GeoAsText, ops.GeoUnaryUnion]) diff --git a/ibis/backends/tests/test_export.py b/ibis/backends/tests/test_export.py index 4b678b7fb3c0..8f37499438c2 100644 --- a/ibis/backends/tests/test_export.py +++ b/ibis/backends/tests/test_export.py @@ -13,6 +13,7 @@ from ibis.backends.tests.errors import ( DuckDBNotImplementedException, DuckDBParserException, + MySQLOperationalError, PyDeltaTableError, PySparkAnalysisException, SnowflakeProgrammingError, @@ -356,7 +357,7 @@ def test_table_to_csv_writer_kwargs(delimiter, tmp_path, awards_players): pytest.mark.notyet(["snowflake"], raises=SnowflakeProgrammingError), pytest.mark.notyet(["trino"], raises=TrinoUserError), pytest.mark.notyet(["oracle"], raises=sa.exc.DatabaseError), - pytest.mark.notyet(["mysql"], raises=sa.exc.OperationalError), + pytest.mark.notyet(["mysql"], raises=MySQLOperationalError), pytest.mark.notyet( ["pyspark"], raises=PySparkAnalysisException, diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index f389e52c3411..a47d2c004907 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -24,7 +24,7 @@ GoogleBadRequest, ImpalaHiveServer2Error, Py4JJavaError, - PsycoPg2InvalidTextRepresentation, + MySQLProgrammingError, SnowflakeProgrammingError, TrinoUserError, ) @@ -1188,10 +1188,6 @@ def test_distinct_on_keep(backend, on, keep): idx=ibis.row_number().over(order_by=_.one, rows=(None, 0)) ) - requires_cache = backend.name() in ("mysql", "impala") - - if requires_cache: - t = t.cache() expr = t.distinct(on=on, keep=keep).order_by(ibis.asc("idx")) result = expr.execute() df = t.execute() @@ -1267,10 +1263,6 @@ def test_distinct_on_keep_is_none(backend, on): idx=ibis.row_number().over(order_by=_.one, rows=(None, 0)) ) - requires_cache = backend.name() in ("mysql", "impala") - - if requires_cache: - t = t.cache() expr = t.distinct(on=on, keep=None).order_by(ibis.asc("idx")) result = expr.execute() df = t.execute() @@ -1380,7 +1372,6 @@ def hash_256(col): "druid", "impala", "mssql", - "mysql", "oracle", "risingwave", "pyspark", @@ -1405,6 +1396,7 @@ def hash_256(col): pytest.mark.notyet(["trino"], raises=TrinoUserError), pytest.mark.broken(["polars"], reason="casts to 1672531200000000000"), pytest.mark.broken(["datafusion"], reason="casts to 1672531200000000"), + pytest.mark.broken(["mysql"], reason="returns 20230101000000"), ], ), ], @@ -1684,7 +1676,7 @@ def test_static_table_slice(backend, slc, expected_count_fn): ) @pytest.mark.notyet( ["mysql"], - raises=sa.exc.ProgrammingError, + raises=MySQLProgrammingError, reason="backend doesn't support dynamic limit/offset", ) @pytest.mark.notyet( @@ -1747,7 +1739,7 @@ def test_dynamic_table_slice(backend, slc, expected_count_fn): @pytest.mark.notyet( ["mysql"], - raises=sa.exc.ProgrammingError, + raises=MySQLProgrammingError, reason="backend doesn't support dynamic limit/offset", ) @pytest.mark.notyet( diff --git a/ibis/backends/tests/test_join.py b/ibis/backends/tests/test_join.py index 95087ba64d8e..1f264c34d455 100644 --- a/ibis/backends/tests/test_join.py +++ b/ibis/backends/tests/test_join.py @@ -256,11 +256,6 @@ def test_join_with_pandas_non_null_typed_columns(batting, awards_players): param( "outer", marks=[ - pytest.mark.notyet( - ["mysql"], - raises=sa.exc.ProgrammingError, - reason="MySQL doesn't support full outer joins natively", - ), pytest.mark.notyet( ["sqlite"], condition=vparse(sqlite3.sqlite_version) < vparse("3.39"), @@ -298,13 +293,9 @@ def test_join_with_trivial_predicate(awards_players, predicate, how, pandas_valu assert len(result) == len(expected) -outer_join_nullability_failures = [ - pytest.mark.notyet( - ["mysql"], - raises=sa.exc.ProgrammingError, - reason="mysql doesn't support full outer joins", - ) -] + [pytest.mark.notyet(["sqlite"])] * (vparse(sqlite3.sqlite_version) < vparse("3.39")) +outer_join_nullability_failures = [pytest.mark.notyet(["sqlite"])] * ( + vparse(sqlite3.sqlite_version) < vparse("3.39") +) @pytest.mark.notimpl( diff --git a/ibis/backends/tests/test_numeric.py b/ibis/backends/tests/test_numeric.py index 217f3cff10de..17c9cd7d348e 100644 --- a/ibis/backends/tests/test_numeric.py +++ b/ibis/backends/tests/test_numeric.py @@ -22,6 +22,7 @@ ExaQueryError, GoogleBadRequest, ImpalaHiveServer2Error, + MySQLOperationalError, PsycoPg2DivisionByZero, Py4JError, SnowflakeProgrammingError, @@ -267,7 +268,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "postgres": decimal.Decimal("1.1"), "pandas": decimal.Decimal("1.1"), "pyspark": decimal.Decimal("1.1"), - "mysql": 1.1, + "mysql": decimal.Decimal("1"), "mssql": 1.1, "druid": 1.1, "datafusion": decimal.Decimal("1.1"), @@ -320,7 +321,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "postgres": decimal.Decimal("1.1"), "pandas": decimal.Decimal("1.1"), "pyspark": decimal.Decimal("1.1"), - "mysql": 1.1, + "mysql": decimal.Decimal("1.1"), "clickhouse": decimal.Decimal("1.1"), "dask": decimal.Decimal("1.1"), "mssql": 1.1, @@ -369,7 +370,6 @@ def test_numeric_literal(con, backend, expr, expected_types): "postgres": decimal.Decimal("1.1"), "pandas": decimal.Decimal("1.1"), "pyspark": decimal.Decimal("1.1"), - "mysql": 1.1, "clickhouse": decimal.Decimal( "1.10000000000000003193790845333396190208" ), @@ -388,6 +388,7 @@ def test_numeric_literal(con, backend, expr, expected_types): }, marks=[ pytest.mark.notimpl(["exasol"], raises=ExaQueryError), + pytest.mark.notimpl(["mysql"], raises=MySQLOperationalError), pytest.mark.notyet(["snowflake"], raises=SnowflakeProgrammingError), pytest.mark.broken( ["impala"], @@ -456,12 +457,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "An error occurred while calling z:org.apache.spark.sql.functions.lit.", raises=Py4JError, ), - pytest.mark.broken( - ["mysql"], - "(pymysql.err.OperationalError) (1054, \"Unknown column 'Infinity' in 'field list'\")" - "[SQL: SELECT %(param_1)s AS `Decimal('Infinity')`]", - raises=sa.exc.OperationalError, - ), + pytest.mark.notyet(["mysql"], raises=com.UnsupportedOperationError), pytest.mark.broken( ["mssql"], "(pymssql._pymssql.ProgrammingError) (207, b\"Invalid column name 'Infinity'." @@ -542,12 +538,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "An error occurred while calling z:org.apache.spark.sql.functions.lit.", raises=Py4JError, ), - pytest.mark.broken( - ["mysql"], - "(pymysql.err.OperationalError) (1054, \"Unknown column 'Infinity' in 'field list'\")" - "[SQL: SELECT %(param_1)s AS `Decimal('-Infinity')`]", - raises=sa.exc.OperationalError, - ), + pytest.mark.notyet(["mysql"], raises=com.UnsupportedOperationError), pytest.mark.broken( ["mssql"], "(pymssql._pymssql.ProgrammingError) (207, b\"Invalid column name 'Infinity'." @@ -630,12 +621,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "An error occurred while calling z:org.apache.spark.sql.functions.lit.", raises=Py4JError, ), - pytest.mark.broken( - ["mysql"], - "(pymysql.err.OperationalError) (1054, \"Unknown column 'NaN' in 'field list'\")" - "[SQL: SELECT %(param_1)s AS `Decimal('NaN')`]", - raises=sa.exc.OperationalError, - ), + pytest.mark.notyet(["mysql"], raises=com.UnsupportedOperationError), pytest.mark.broken( ["mssql"], "(pymssql._pymssql.ProgrammingError) (207, b\"Invalid column name 'NaN'." @@ -744,25 +730,15 @@ def test_decimal_literal(con, backend, expr, expected_types, expected_result): id="float-literal", marks=[ pytest.mark.notimpl( - ["exasol"], - raises=com.OperationNotDefinedError, - ), - pytest.mark.notimpl( - ["druid"], - raises=com.OperationNotDefinedError, - ), + ["exasol", "druid"], raises=com.OperationNotDefinedError + ) ], ), param( lambda t: ibis.literal(np.nan), lambda t: np.nan, id="nan-literal", - marks=[ - pytest.mark.notimpl( - ["druid"], - raises=com.OperationNotDefinedError, - ) - ], + marks=[pytest.mark.notimpl(["druid"], raises=com.OperationNotDefinedError)], ), param( lambda t: ibis.literal(np.inf), @@ -770,13 +746,8 @@ def test_decimal_literal(con, backend, expr, expected_types, expected_result): id="inf-literal", marks=[ pytest.mark.notimpl( - ["exasol"], - raises=com.OperationNotDefinedError, - ), - pytest.mark.notimpl( - ["druid"], - raises=com.OperationNotDefinedError, - ), + ["exasol", "druid"], raises=com.OperationNotDefinedError + ) ], ), param( @@ -785,13 +756,8 @@ def test_decimal_literal(con, backend, expr, expected_types, expected_result): id="-inf-literal", marks=[ pytest.mark.notimpl( - ["exasol"], - raises=com.OperationNotDefinedError, - ), - pytest.mark.notimpl( - ["druid"], - raises=com.OperationNotDefinedError, - ), + ["exasol", "druid"], raises=com.OperationNotDefinedError + ) ], ), ], @@ -821,9 +787,9 @@ def test_decimal_literal(con, backend, expr, expected_types, expected_result): ], ) @pytest.mark.notimpl( - ["mysql", "sqlite", "mssql", "oracle", "flink"], - raises=com.OperationNotDefinedError, + ["sqlite", "mssql", "oracle", "flink"], raises=com.OperationNotDefinedError ) +@pytest.mark.notimpl(["mysql"], raises=(MySQLOperationalError, NotImplementedError)) def test_isnan_isinf( backend, con, @@ -1516,6 +1482,7 @@ def test_divide_by_zero(backend, alltypes, df, column, denominator): "snowflake", "trino", "postgres", + "mysql", ], reason="Not SQLAlchemy backends", ) diff --git a/ibis/backends/tests/test_sql.py b/ibis/backends/tests/test_sql.py index 0ca72f6a45a5..92c034b8124c 100644 --- a/ibis/backends/tests/test_sql.py +++ b/ibis/backends/tests/test_sql.py @@ -27,7 +27,7 @@ ) no_structs = pytest.mark.never( ["impala", "mysql", "sqlite", "mssql"], - raises=(NotImplementedError, sa.exc.CompileError), + raises=(NotImplementedError, sa.exc.CompileError, exc.UnsupportedBackendType), reason="structs not supported in the backend", ) no_struct_literals = pytest.mark.notimpl( diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index f72acf988793..2f71b71fdd8e 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -13,7 +13,6 @@ import ibis.expr.datatypes as dt from ibis.backends.tests.errors import ClickHouseDatabaseError, PySparkPythonException from ibis.common.annotations import ValidationError -from ibis.common.exceptions import OperationNotDefinedError @pytest.mark.parametrize( @@ -933,7 +932,7 @@ def test_substr_with_null_values(backend, alltypes, df): marks=[ pytest.mark.notyet( ["clickhouse", "snowflake", "trino"], - raises=OperationNotDefinedError, + raises=com.OperationNotDefinedError, reason="doesn't support `USERINFO`", ) ], @@ -946,7 +945,7 @@ def test_substr_with_null_values(backend, alltypes, df): marks=[ pytest.mark.notyet( ["snowflake"], - raises=OperationNotDefinedError, + raises=com.OperationNotDefinedError, reason="host is netloc", ), pytest.mark.broken( @@ -1011,12 +1010,15 @@ def test_capitalize(con): @pytest.mark.notimpl( ["dask", "pandas", "polars", "druid", "oracle", "flink"], - raises=OperationNotDefinedError, + raises=com.OperationNotDefinedError, ) @pytest.mark.notyet( - ["impala", "mssql", "mysql", "sqlite", "exasol"], + ["impala", "mssql", "sqlite", "exasol"], reason="no arrays", - raises=OperationNotDefinedError, + raises=com.OperationNotDefinedError, +) +@pytest.mark.never( + ["mysql"], raises=com.OperationNotDefinedError, reason="no array support" ) def test_array_string_join(con): s = ibis.array(["a", "b", "c"]) diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index e8a34654835e..4a395dfbac4b 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -10,6 +10,7 @@ import pandas as pd import pytest import sqlalchemy as sa +import sqlglot as sg from pytest import param import ibis @@ -24,6 +25,8 @@ GoogleBadRequest, ImpalaHiveServer2Error, ImpalaOperationalError, + MySQLOperationalError, + MySQLProgrammingError, PolarsComputeError, PolarsPanicException, Py4JJavaError, @@ -393,9 +396,9 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): param( "W", marks=[ - pytest.mark.notimpl(["mysql"], raises=com.UnsupportedOperationError), pytest.mark.notimpl(["impala"], raises=AssertionError), pytest.mark.broken(["sqlite"], raises=AssertionError), + pytest.mark.notimpl(["mysql"], raises=com.UnsupportedOperationError), pytest.mark.broken( ["polars"], raises=AssertionError, @@ -623,12 +626,8 @@ def test_timestamp_truncate(backend, alltypes, df, unit): param( "W", marks=[ - pytest.mark.notimpl( - ["mysql"], - raises=com.UnsupportedOperationError, - reason="Unsupported truncate unit W", - ), pytest.mark.broken(["impala"], raises=AssertionError), + pytest.mark.notyet(["mysql"], raises=com.UnsupportedOperationError), pytest.mark.never( ["flink"], raises=Py4JJavaError, @@ -824,7 +823,7 @@ def test_date_truncate(backend, alltypes, df, unit): pd.Timedelta, marks=[ pytest.mark.notimpl( - ["mysql", "clickhouse"], raises=com.UnsupportedOperationError + ["clickhouse"], raises=com.UnsupportedOperationError ), pytest.mark.notimpl( ["pyspark"], @@ -1028,7 +1027,6 @@ def convert_to_offset(x): [ "dask", "impala", - "mysql", "risingwave", "snowflake", "sqlite", @@ -1036,6 +1034,7 @@ def convert_to_offset(x): ], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl(["mysql"], raises=sg.ParseError), pytest.mark.notimpl( ["druid"], raises=ValidationError, @@ -1055,13 +1054,13 @@ def convert_to_offset(x): "sqlite", "risingwave", "polars", - "mysql", "impala", "snowflake", "bigquery", ], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl(["mysql"], raises=sg.ParseError), pytest.mark.notimpl( ["druid"], raises=ValidationError, @@ -1574,7 +1573,7 @@ def test_timestamp_comparison_filter_numpy(backend, con, alltypes, df, func_name @pytest.mark.notimpl( - ["sqlite", "snowflake", "mssql", "oracle"], + ["sqlite", "snowflake", "mssql", "oracle", "exasol"], raises=com.OperationNotDefinedError, ) @pytest.mark.broken( @@ -1587,7 +1586,6 @@ def test_timestamp_comparison_filter_numpy(backend, con, alltypes, df, func_name raises=Py4JJavaError, reason="ParseException: Encountered '+ INTERVAL CAST'", ) -@pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) def test_interval_add_cast_scalar(backend, alltypes): timestamp_date = alltypes.timestamp_col.date() delta = ibis.literal(10).cast("interval('D')") @@ -1601,7 +1599,7 @@ def test_interval_add_cast_scalar(backend, alltypes): ["pyspark"], reason="PySpark does not support casting columns to intervals" ) @pytest.mark.notimpl( - ["sqlite", "snowflake", "mssql", "oracle"], + ["sqlite", "snowflake", "mssql", "oracle", "exasol"], raises=com.OperationNotDefinedError, ) @pytest.mark.notimpl( @@ -1609,7 +1607,6 @@ def test_interval_add_cast_scalar(backend, alltypes): raises=AttributeError, reason="'StringColumn' object has no attribute 'date'", ) -@pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) def test_interval_add_cast_column(backend, alltypes, df): timestamp_date = alltypes.timestamp_col.date() delta = alltypes.bigint_col.cast("interval('D')") @@ -2024,16 +2021,6 @@ def test_now_from_projection(alltypes): @pytest.mark.notimpl( ["oracle"], raises=sa.exc.DatabaseError, reason="ORA-00936 missing expression" ) -@pytest.mark.broken( - ["mysql"], - raises=sa.exc.ProgrammingError, - reason=( - '(pymysql.err.ProgrammingError) (1064, "You have an error in your SQL syntax; ' - "check the manual that corresponds to your MariaDB server version for " - "the right syntax to use near ' 2, 4) AS `DateFromYMD(2022, 2, 4)`' at line 1\")" - "[SQL: SELECT date(%(param_1)s, %(param_2)s, %(param_3)s) AS `DateFromYMD(2022, 2, 4)`]" - ), -) @pytest.mark.notyet(["impala"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["exasol"], raises=ExaQueryError) @pytest.mark.notimpl( @@ -2065,7 +2052,9 @@ def test_date_literal(con, backend): } -@pytest.mark.notimpl(["pandas", "dask", "pyspark"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["pandas", "dask", "pyspark", "mysql"], raises=com.OperationNotDefinedError +) @pytest.mark.notimpl( ["druid"], raises=sa.exc.ProgrammingError, @@ -2076,11 +2065,6 @@ def test_date_literal(con, backend): "make_timestamp(, , , , , )" ), ) -@pytest.mark.broken( - ["mysql"], - raises=sa.exc.OperationalError, - reason="(pymysql.err.OperationalError) (1305, 'FUNCTION ibis_testing.make_timestamp does not exist')", -) @pytest.mark.notimpl( ["oracle"], raises=sa.exc.DatabaseError, reason="ORA-00904: MAKE TIMESTAMP invalid" ) @@ -2106,11 +2090,6 @@ def test_timestamp_literal(con, backend): @pytest.mark.notimpl( ["pandas", "mysql", "dask", "pyspark"], raises=com.OperationNotDefinedError ) -@pytest.mark.notimpl( - ["mysql"], - raises=sa.exc.OperationalError, - reason="FUNCTION ibis_testing.make_timestamp does not exist", -) @pytest.mark.notimpl( ["sqlite"], raises=com.UnsupportedOperationError, @@ -2181,29 +2160,11 @@ def test_timestamp_with_timezone_literal(con, timezone, expected): @pytest.mark.notimpl( - [ - "pandas", - "datafusion", - "dask", - "pyspark", - "polars", - ], + ["pandas", "datafusion", "dask", "pyspark", "polars", "mysql"], raises=com.OperationNotDefinedError, ) @pytest.mark.notyet(["clickhouse", "impala"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["oracle"], raises=sa.exc.DatabaseError) -@pytest.mark.broken( - [ - "mysql", - ], - raises=sa.exc.ProgrammingError, - reason=( - '(pymysql.err.ProgrammingError) (1064, "You have an error in your SQL syntax; check the manual that ' - "corresponds to your MariaDB server version for the right syntax to use near ' 20, 0) AS " - "`TimeFromHMS(16, 20, 0)`' at line 1\")" - "[SQL: SELECT time(%(param_1)s, %(param_2)s, %(param_3)s) AS `TimeFromHMS(16, 20, 0)`]" - ), -) @pytest.mark.broken( ["druid"], raises=sa.exc.ProgrammingError, reason="SQL parse failed" ) @@ -2323,7 +2284,7 @@ def test_extract_time_from_timestamp(con, microsecond): "AttributeError: 'TextClause' object has no attribute 'label'" "If SQLAlchemy >=2 is installed, test fails with the following exception:" "NotImplementedError", - raises=(NotImplementedError, AttributeError), + raises=MySQLProgrammingError, ) @pytest.mark.broken( ["bigquery", "duckdb"], @@ -2359,15 +2320,6 @@ def test_interval_literal(con, backend): @pytest.mark.notimpl(["pandas", "dask", "pyspark"], raises=com.OperationNotDefinedError) -@pytest.mark.broken( - ["mysql"], - raises=sa.exc.ProgrammingError, - reason=( - '(pymysql.err.ProgrammingError) (1064, "You have an error in your SQL syntax; check the manual ' - "that corresponds to your MariaDB server version for the right syntax to use near " - "' CAST(EXTRACT(month FROM t0.timestamp_col) AS SIGNED INTEGER), CAST(EXTRACT(d...' at line 1\")" - ), -) @pytest.mark.broken( ["druid"], raises=AttributeError, @@ -2393,17 +2345,14 @@ def test_date_column_from_ymd(backend, con, alltypes, df): backend.assert_series_equal(golden, result.timestamp_col) -@pytest.mark.notimpl(["pandas", "dask", "pyspark"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["pandas", "dask", "pyspark", "mysql"], raises=com.OperationNotDefinedError +) @pytest.mark.broken( ["druid"], raises=AttributeError, reason="StringColumn' object has no attribute 'year'", ) -@pytest.mark.broken( - ["mysql"], - raises=sa.exc.OperationalError, - reason="(pymysql.err.OperationalError) (1305, 'FUNCTION ibis_testing.make_timestamp does not exist')", -) @pytest.mark.notimpl( ["oracle"], raises=sa.exc.DatabaseError, reason="ORA-00904 make timestamp invalid" ) @@ -2689,6 +2638,11 @@ def test_large_timestamp(con): reason="doesn't support nanoseconds", raises=sa.exc.ProgrammingError, ), + pytest.mark.notyet( + ["mysql"], + reason="doesn't support nanoseconds", + raises=MySQLOperationalError, + ), pytest.mark.notyet( ["bigquery"], reason=( @@ -2711,7 +2665,6 @@ def test_large_timestamp(con): ), ], ) -@pytest.mark.notyet(["mysql"], raises=AssertionError) @pytest.mark.broken( ["druid"], raises=sa.exc.ProgrammingError, diff --git a/ibis/backends/tests/test_window.py b/ibis/backends/tests/test_window.py index 1393841ce83b..caf8dce62923 100644 --- a/ibis/backends/tests/test_window.py +++ b/ibis/backends/tests/test_window.py @@ -16,6 +16,7 @@ ClickHouseDatabaseError, GoogleBadRequest, ImpalaHiveServer2Error, + MySQLOperationalError, Py4JJavaError, PySparkAnalysisException, SnowflakeProgrammingError, @@ -864,7 +865,6 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): raises=com.UnsupportedOperationError, reason="Flink engine does not support generic window clause with no order by", ), - pytest.mark.broken(["mysql"], raises=sa.exc.OperationalError), pytest.mark.broken(["mssql"], raises=sa.exc.ProgrammingError), pytest.mark.notyet( ["snowflake"], @@ -917,7 +917,6 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): raises=com.UnsupportedOperationError, reason="Flink engine does not support generic window clause with no order by", ), - pytest.mark.broken(["mysql"], raises=sa.exc.OperationalError), pytest.mark.broken(["mssql"], raises=sa.exc.ProgrammingError), pytest.mark.notyet( ["snowflake"], @@ -1047,6 +1046,11 @@ def test_ungrouped_unbounded_window( reason="Feature is not yet implemented: window frame in `RANGE` mode is not supported yet", ) @pytest.mark.notyet(["mssql"], raises=sa.exc.ProgrammingError) +@pytest.mark.broken( + ["mysql"], + raises=MySQLOperationalError, + reason="https://github.com/tobymao/sqlglot/issues/2779", +) def test_grouped_bounded_range_window(backend, alltypes, df): # Explanation of the range window spec below: # @@ -1225,7 +1229,7 @@ def test_first_last(backend): ["impala"], raises=ImpalaHiveServer2Error, reason="not supported by Impala" ) @pytest.mark.notyet( - ["mysql"], raises=sa.exc.ProgrammingError, reason="not supported by MySQL" + ["mysql"], raises=MySQLOperationalError, reason="not supported by MySQL" ) @pytest.mark.notyet( ["mssql", "oracle", "polars", "snowflake", "sqlite"], diff --git a/ibis/formats/pandas.py b/ibis/formats/pandas.py index e202b48f0621..0522e965c1d8 100644 --- a/ibis/formats/pandas.py +++ b/ibis/formats/pandas.py @@ -152,9 +152,6 @@ def convert_table(cls, df, schema): def convert_column(cls, obj, dtype): pandas_type = PandasType.from_ibis(dtype) - if obj.dtype == pandas_type and dtype.is_primitive(): - return obj - method_name = f"convert_{dtype.__class__.__name__}" convert_method = getattr(cls, method_name, cls.convert_default) @@ -185,6 +182,8 @@ def convert_GeoSpatial(cls, s, dtype, pandas_type): @classmethod def convert_default(cls, s, dtype, pandas_type): + if s.dtype == pandas_type and dtype.is_primitive(): + return s try: return s.astype(pandas_type) except Exception: # noqa: BLE001 diff --git a/pyproject.toml b/pyproject.toml index 313a364eb6e6..76a4825f9a0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,6 @@ datafusion = { version = ">=0.6,<36", optional = true } db-dtypes = { version = ">=0.3,<2", optional = true } deltalake = { version = ">=0.9.0,<1", optional = true } duckdb = { version = ">=0.8.1,<1", optional = true } -geoalchemy2 = { version = ">=0.6.3,<1", optional = true } geopandas = { version = ">=0.6,<1", optional = true } google-cloud-bigquery = { version = ">=3,<4", optional = true } google-cloud-bigquery-storage = { version = ">=2,<3", optional = true } @@ -151,7 +150,6 @@ all = [ "db-dtypes", "duckdb", "deltalake", - "geoalchemy2", "geopandas", "google-cloud-bigquery", "google-cloud-bigquery-storage", @@ -189,10 +187,10 @@ druid = ["pydruid", "sqlalchemy"] duckdb = ["duckdb"] exasol = ["sqlalchemy", "sqlalchemy-exasol", "sqlalchemy-views"] flink = [] -geospatial = ["geoalchemy2", "geopandas", "shapely"] +geospatial = ["geopandas", "shapely"] impala = ["impyla", "sqlalchemy"] mssql = ["sqlalchemy", "pyodbc", "sqlalchemy-views"] -mysql = ["sqlalchemy", "pymysql", "sqlalchemy-views"] +mysql = ["pymysql"] oracle = ["sqlalchemy", "oracledb", "packaging", "sqlalchemy-views"] pandas = ["regex"] polars = ["polars", "packaging"]