diff --git a/.github/workflows/ibis-backends.yml b/.github/workflows/ibis-backends.yml index 5930da928be8..96d5d2efa2d7 100644 --- a/.github/workflows/ibis-backends.yml +++ b/.github/workflows/ibis-backends.yml @@ -155,13 +155,6 @@ jobs: - druid services: - druid - # - name: oracle - # title: Oracle - # serial: true - # extras: - # - oracle - # services: - # - oracle - name: exasol title: Exasol serial: true @@ -169,6 +162,13 @@ jobs: - exasol services: - exasol + - name: oracle + title: Oracle + serial: true + extras: + - oracle + services: + - oracle # - name: flink # title: Flink # serial: true @@ -265,15 +265,15 @@ jobs: - druid services: - druid - # - os: windows-latest - # backend: - # name: oracle - # title: Oracle - # serial: true - # extras: - # - oracle - # services: - # - oracle + - os: windows-latest + backend: + name: oracle + title: Oracle + serial: true + extras: + - oracle + services: + - oracle # - os: windows-latest # backend: # name: flink diff --git a/ibis/backends/base/sqlglot/datatypes.py b/ibis/backends/base/sqlglot/datatypes.py index 273295f83b2f..a54959092e4e 100644 --- a/ibis/backends/base/sqlglot/datatypes.py +++ b/ibis/backends/base/sqlglot/datatypes.py @@ -563,6 +563,29 @@ class DruidType(SqlglotType): class OracleType(SqlglotType): dialect = "oracle" + default_decimal_precision = 38 + default_decimal_scale = 9 + + default_temporal_scale = 9 + + unknown_type_strings = FrozenDict({"raw": dt.binary}) + + @classmethod + def _from_sqlglot_FLOAT(cls) -> dt.Float64: + return dt.Float64(nullable=cls.default_nullable) + + @classmethod + def _from_sqlglot_DECIMAL(cls, precision=None, scale=None) -> dt.Decimal: + if scale is None or int(scale.this.this) == 0: + return dt.Int64(nullable=cls.default_nullable) + else: + return super()._from_sqlglot_DECIMAL(precision, scale) + + @classmethod + def _from_ibis_String(cls, dtype: dt.String) -> sge.DataType: + nullable = " NOT NULL" if not dtype.nullable else "" + return "VARCHAR2(4000)" + nullable + class SnowflakeType(SqlglotType): dialect = "snowflake" diff --git a/ibis/backends/conftest.py b/ibis/backends/conftest.py index 103ebc477ba2..b6de8b4b70d8 100644 --- a/ibis/backends/conftest.py +++ b/ibis/backends/conftest.py @@ -22,6 +22,7 @@ from ibis import util from ibis.backends.base import CanCreateDatabase, CanCreateSchema, _get_backend_names from ibis.conftest import WINDOWS +from ibis.util import promote_tuple if TYPE_CHECKING: from collections.abc import Iterable @@ -418,6 +419,13 @@ def pytest_runtest_call(item): # TODO: there has to be a better way than hacking `_fixtureinfo` item._fixtureinfo.argnames += ("backend", "snapshot") + def _filter_none_from_raises(kwargs): + # Filter out any None values from kwargs['raises'] + # to cover any missing backend error types as defined in ibis/backends/tests/errors.py + if (raises := kwargs.get("raises")) is not None: + kwargs["raises"] = tuple(filter(None, promote_tuple(raises))) + return kwargs + # Ibis hasn't exposed existing functionality # This xfails so that you know when it starts to pass for marker in item.iter_markers(name="notimpl"): @@ -429,6 +437,7 @@ def pytest_runtest_call(item): raise ValueError("notimpl requires a raises") kwargs = marker.kwargs.copy() kwargs.setdefault("reason", f"Feature not yet exposed in {backend}") + kwargs = _filter_none_from_raises(kwargs) item.add_marker(pytest.mark.xfail(**kwargs)) # Functionality is unavailable upstream (but could be) @@ -443,13 +452,16 @@ def pytest_runtest_call(item): kwargs = marker.kwargs.copy() 
kwargs.setdefault("reason", f"Feature not available upstream for {backend}") + kwargs = _filter_none_from_raises(kwargs) item.add_marker(pytest.mark.xfail(**kwargs)) for marker in item.iter_markers(name="never"): if backend in marker.args[0]: if "reason" not in marker.kwargs.keys(): raise ValueError("never requires a reason") - item.add_marker(pytest.mark.xfail(**marker.kwargs)) + kwargs = marker.kwargs.copy() + kwargs = _filter_none_from_raises(kwargs) + item.add_marker(pytest.mark.xfail(**kwargs)) # Something has been exposed as broken by a new test and it shouldn't be # imperative for a contributor to fix it just because they happened to @@ -464,10 +476,12 @@ def pytest_runtest_call(item): kwargs = marker.kwargs.copy() kwargs.setdefault("reason", f"Feature is failing on {backend}") + kwargs = _filter_none_from_raises(kwargs) item.add_marker(pytest.mark.xfail(**kwargs)) for marker in item.iter_markers(name="xfail_version"): kwargs = marker.kwargs.copy() + kwargs = _filter_none_from_raises(kwargs) if backend not in kwargs: continue @@ -549,10 +563,7 @@ def ddl_con(ddl_backend): return ddl_backend.connection -@pytest.fixture( - params=_get_backends_to_test(keep=("mssql", "oracle", "sqlite")), - scope="session", -) +@pytest.fixture(params=_get_backends_to_test(keep=("mssql", "sqlite")), scope="session") def alchemy_backend(request, data_dir, tmp_path_factory, worker_id): """Set up the SQLAlchemy-based backends.""" return _setup_backend(request, data_dir, tmp_path_factory, worker_id) diff --git a/ibis/backends/oracle/__init__.py b/ibis/backends/oracle/__init__.py index cd5ca715f28e..1f8c430c96cd 100644 --- a/ibis/backends/oracle/__init__.py +++ b/ibis/backends/oracle/__init__.py @@ -2,87 +2,43 @@ from __future__ import annotations -import atexit import contextlib -import sys +import re import warnings +from functools import cached_property +from operator import itemgetter from typing import TYPE_CHECKING, Any import oracledb import sqlglot as sg - +import sqlglot.expressions as sge + +import ibis +import ibis.common.exceptions as exc +import ibis.expr.datatypes as dt +import ibis.expr.operations as ops +import ibis.expr.schema as sch +import ibis.expr.types as ir from ibis import util - -# Wow, this is truly horrible -# Get out your clippers, it's time to shave a yak. -# -# 1. oracledb is only supported in sqlalchemy 2.0 -# 2. Ergo, module hacking is required to avoid doing a silly amount of work -# to create multiple lockfiles or port snowflake away from sqlalchemy -# 3. 
Also the version needs to be spoofed to be >= 7 or else the cx_Oracle
-#    dialect barfs
-oracledb.__version__ = oracledb.version = "7"
-
-sys.modules["cx_Oracle"] = oracledb
-
-import sqlalchemy as sa  # noqa: E402
-
-import ibis.common.exceptions as exc  # noqa: E402
-import ibis.expr.datatypes as dt  # noqa: E402
-import ibis.expr.operations as ops  # noqa: E402
-import ibis.expr.schema as sch  # noqa: E402
-from ibis.backends.base.sql.alchemy import (  # noqa: E402
-    AlchemyCompiler,
-    AlchemyExprTranslator,
-    BaseAlchemyBackend,
-)
-from ibis.backends.base.sqlglot import STAR, C  # noqa: E402
-from ibis.backends.oracle.datatypes import OracleType  # noqa: E402
-from ibis.backends.oracle.registry import operation_registry  # noqa: E402
-from ibis.expr.rewrites import rewrite_sample  # noqa: E402
+from ibis.backends.base.sqlglot import STAR, SQLGlotBackend
+from ibis.backends.base.sqlglot.compiler import TRUE, C
+from ibis.backends.oracle.compiler import OracleCompiler

 if TYPE_CHECKING:
     from collections.abc import Iterable

+    import pandas as pd
+    import pyarrow as pa

-class OracleExprTranslator(AlchemyExprTranslator):
-    _registry = operation_registry.copy()
-    _rewrites = AlchemyExprTranslator._rewrites.copy()
-    _dialect_name = "oracle"
-    _has_reduction_filter_syntax = False
-    _require_order_by = (
-        *AlchemyExprTranslator._require_order_by,
-        ops.Reduction,
-        ops.Lag,
-        ops.Lead,
-    )
-
-    _forbids_frame_clause = (
-        *AlchemyExprTranslator._forbids_frame_clause,
-        ops.Lag,
-        ops.Lead,
-    )
-
-    _quote_column_names = True
-    _quote_table_names = True
-
-    type_mapper = OracleType
-
-class OracleCompiler(AlchemyCompiler):
-    translator_class = OracleExprTranslator
-    support_values_syntax_in_select = False
-    supports_indexed_grouping_keys = False
-    null_limit = None
-    rewrites = AlchemyCompiler.rewrites | rewrite_sample
-
-
-class Backend(BaseAlchemyBackend):
+class Backend(SQLGlotBackend):
     name = "oracle"
-    compiler = OracleCompiler
-    supports_create_or_replace = False
-    supports_temporary_tables = True
-    _temporary_prefix = "GLOBAL TEMPORARY"
+    compiler = OracleCompiler()
+
+    @cached_property
+    def version(self):
+        matched = re.search(r"(\d+)\.(\d+)\.(\d+)", self.con.version)
+        return ".".join(matched.groups())

     def do_connect(
         self,
@@ -143,41 +99,70 @@ def do_connect(
         if dsn is None:
             dsn = oracledb.makedsn(host, port, service_name=service_name, sid=sid)

-        url = sa.engine.url.make_url(f"oracle://{user}:{password}@{dsn}")
-
-        engine = sa.create_engine(
-            url,
-            poolclass=sa.pool.StaticPool,
-            # We set the statement cache size to 0 because Oracle will otherwise
-            # attempt to reuse prepared statements even if the type of the bound variable
-            # has changed.
-            # This is apparently accepted behavior.
-            # https://python-oracledb.readthedocs.io/en/latest/user_guide/appendix_b.html#statement-caching-in-thin-and-thick-modes
-            connect_args={"stmtcachesize": 0},
-        )
-        super().do_connect(engine)
+        # We set the statement cache size to 0 because Oracle will otherwise
+        # attempt to reuse prepared statements even if the type of the bound variable
+        # has changed.
+        # This is apparently accepted behavior.
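+        # For example, re-executing the same statement text bound first to a
+        # Python int and then to a str could otherwise reuse a cached statement
+        # whose bind variable metadata no longer matches the new value.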
+ # https://python-oracledb.readthedocs.io/en/latest/user_guide/appendix_b.html#statement-caching-in-thin-and-thick-modes + self.con = oracledb.connect(dsn, user=user, password=password, stmtcachesize=0) - def normalize_name(name): - if name is None: - return None - elif not name: - return "" - elif name.lower() == name: - return sa.sql.quoted_name(name, quote=True) - else: - return name + # turn on autocommit + # TODO: it would be great if this worked but it doesn't seem to do the trick + # I had to hack in the commit lines to the compiler + # self.con.autocommit = True - self.con.dialect.normalize_name = normalize_name + # Set to ensure decimals come back as decimals + oracledb.defaults.fetch_decimals = True def _from_url(self, url: str, **kwargs): return self.do_connect(user=url.username, password=url.password, dsn=url.host) @property def current_database(self) -> str: - return self._scalar_query("SELECT * FROM global_name") + with self._safe_raw_sql(sg.select(STAR).from_("global_name")) as cur: + [(database,)] = cur.fetchall() + return database + + @contextlib.contextmanager + def begin(self): + con = self.con + cur = con.cursor() + try: + yield cur + except Exception: + con.rollback() + raise + else: + con.commit() + finally: + cur.close() + + @contextlib.contextmanager + def _safe_raw_sql(self, *args, **kwargs): + with contextlib.closing(self.raw_sql(*args, **kwargs)) as result: + yield result + + def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any: + with contextlib.suppress(AttributeError): + query = query.sql(dialect=self.name) + + con = self.con + cursor = con.cursor() - def list_tables(self, like=None, schema=None): + try: + cursor.execute(query, **kwargs) + except Exception: + con.rollback() + cursor.close() + raise + else: + con.commit() + return cursor + + def list_tables( + self, like: str | None = None, schema: str | None = None + ) -> list[str]: """List the tables in the database. Parameters @@ -186,17 +171,214 @@ def list_tables(self, like=None, schema=None): A pattern to use for listing tables. schema The schema to perform the list against. + """ + conditions = [TRUE] + + if schema is None: + schema = self.con.username.upper() + conditions = C.owner.eq(sge.convert(schema.upper())) - ::: {.callout-warning} - ## `schema` refers to database hierarchy + tables = ( + sg.select("table_name", "owner") + .from_(sg.table("all_tables")) + .distinct() + .where(conditions) + ) + views = ( + sg.select("view_name", "owner") + .from_(sg.table("all_views")) + .distinct() + .where(conditions) + ) + sql = tables.union(views).sql(self.name) - The `schema` parameter does **not** refer to the column names and - types of `table`. 
- ::: + with self._safe_raw_sql(sql) as cur: + out = cur.fetchall() + + return self._filter_with_like(map(itemgetter(0), out), like) + + def list_schemas( + self, like: str | None = None, database: str | None = None + ) -> list[str]: + if database is not None: + raise exc.UnsupportedArgumentError( + "No cross-database schema access in Oracle" + ) + + query = sg.select("username").from_("all_users").order_by("username") + + with self._safe_raw_sql(query) as con: + schemata = list(map(itemgetter(0), con)) + + return self._filter_with_like(schemata, like) + + def get_schema( + self, name: str, schema: str | None = None, database: str | None = None + ) -> sch.Schema: + if schema is None: + schema = self.con.username.upper() + stmt = ( + sg.select( + "column_name", + "data_type", + sg.column("nullable").eq(sge.convert("Y")).as_("nullable"), + ) + .from_(sg.table("all_tab_columns")) + .where(sg.column("table_name").eq(sge.convert(name))) + .where(sg.column("owner").eq(sge.convert(schema))) + ) + with self._safe_raw_sql(stmt) as cur: + result = cur.fetchall() + + if not result: + raise exc.IbisError(f"Table not found: {name!r}") + + type_mapper = self.compiler.type_mapper + fields = { + name: type_mapper.from_string(type_string, nullable=nullable) + for name, type_string, nullable in result + } + + return sch.Schema(fields) + + def create_table( + self, + name: str, + obj: pd.DataFrame | pa.Table | ir.Table | None = None, + *, + schema: ibis.Schema | None = None, + database: str | None = None, + temp: bool = False, + overwrite: bool = False, + ): + """Create a table in Oracle. + + Parameters + ---------- + name + Name of the table to create + obj + The data with which to populate the table; optional, but at least + one of `obj` or `schema` must be specified + schema + The schema of the table to create; optional, but at least one of + `obj` or `schema` must be specified + database + The name of the database in which to create the table; if not + passed, the current database is used. 
+ temp + Create a temporary table + overwrite + If `True`, replace the table if it already exists, otherwise fail + if the table exists """ - tables = self.inspector.get_table_names(schema=schema) - views = self.inspector.get_view_names(schema=schema) - return self._filter_with_like(tables + views, like) + if obj is None and schema is None: + raise ValueError("Either `obj` or `schema` must be specified") + + properties = [] + + if temp: + properties.append(sge.TemporaryProperty()) + + if obj is not None: + if not isinstance(obj, ir.Expr): + table = ibis.memtable(obj) + else: + table = obj + + self._run_pre_execute_hooks(table) + + query = self._to_sqlglot(table) + else: + query = None + + column_defs = [ + sge.ColumnDef( + this=sg.to_identifier(colname, quoted=self.compiler.quoted), + kind=self.compiler.type_mapper.from_ibis(typ), + constraints=( + None + if typ.nullable + else [sge.ColumnConstraint(kind=sge.NotNullColumnConstraint())] + ), + ) + for colname, typ in (schema or table.schema()).items() + ] + + if overwrite: + temp_name = util.gen_name(f"{self.name}_table") + else: + temp_name = name + + table = sg.table(temp_name, catalog=database, quoted=self.compiler.quoted) + target = sge.Schema(this=table, expressions=column_defs) + + create_stmt = sge.Create( + kind="TABLE", + this=target, + properties=sge.Properties(expressions=properties), + ) + + this = sg.table(name, catalog=database, quoted=self.compiler.quoted) + with self._safe_raw_sql(create_stmt) as cur: + if query is not None: + insert_stmt = sge.Insert(this=table, expression=query).sql(self.name) + cur.execute(insert_stmt) + + if overwrite: + cur.execute( + sge.Drop(kind="TABLE", this=this, exists=True).sql(self.name) + ) + cur.execute( + f"ALTER TABLE IF EXISTS {table.sql(self.name)} RENAME TO {this.sql(self.name)}" + ) + + if schema is None: + return self.table(name, schema=database) + + # preserve the input schema if it was provided + return ops.DatabaseTable( + name, schema=schema, source=self, namespace=ops.Namespace(database=database) + ).to_expr() + + def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: + schema = op.schema + + # only register if we haven't already done so + if (name := op.name) not in self.list_tables(): + quoted = self.compiler.quoted + column_defs = [ + sg.exp.ColumnDef( + this=sg.to_identifier(colname, quoted=quoted), + kind=self.compiler.type_mapper.from_ibis(typ), + constraints=( + None + if typ.nullable + else [ + sg.exp.ColumnConstraint( + kind=sg.exp.NotNullColumnConstraint() + ) + ] + ), + ) + for colname, typ in schema.items() + ] + + create_stmt = sg.exp.Create( + kind="TABLE", + this=sg.exp.Schema( + this=sg.to_identifier(name, quoted=quoted), expressions=column_defs + ), + ).sql(self.name, pretty=True) + + data = op.data.to_frame().itertuples(index=False) + specs = ", ".join(f":{i}" for i, _ in enumerate(schema)) + table = sg.table(name, quoted=quoted).sql(self.name) + insert_stmt = f"INSERT INTO {table} VALUES ({specs})" + with self.begin() as cur: + cur.execute(create_stmt) + for row in data: + cur.execute(insert_stmt, row) def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]: name = util.gen_name("oracle_metadata") @@ -211,6 +393,17 @@ def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]: if isinstance(sg_expr, sg.exp.Table): sg_expr = sg.select(STAR).from_(sg_expr) + # TODO(gforsyth): followup -- this should probably be made a default + # transform for quoting backends + def transformer(node): + if isinstance(node, sg.exp.Table): + return 
sg.table(node.name, quoted=True) + elif isinstance(node, sg.exp.Column): + return sg.column(col=node.name, quoted=True) + return node + + sg_expr = sg_expr.transform(transformer) + this = sg.table(name, quoted=True) create_view = sg.exp.Create(kind="VIEW", this=this, expression=sg_expr).sql( dialect @@ -232,13 +425,14 @@ def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]: ) with self.begin() as con: - con.exec_driver_sql(create_view) + con.execute(create_view) try: - results = con.exec_driver_sql(metadata_query).fetchall() + results = con.execute(metadata_query).fetchall() finally: # drop the view no matter what - con.exec_driver_sql(drop_view) + con.execute(drop_view) + # TODO: hand all this off to the type mapper for name, type_string, precision, scale, nullable in results: # NUMBER(null, null) --> FLOAT # (null, null) --> from_string() @@ -269,26 +463,33 @@ def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]: typ = dt.Decimal(precision=precision, scale=scale, nullable=nullable) else: - typ = OracleType.from_string(type_string, nullable=nullable) + typ = self.compiler.type_mapper.from_string( + type_string, nullable=nullable + ) yield name, typ - def _table_from_schema( - self, - name: str, - schema: sch.Schema, - temp: bool = False, - database: str | None = None, - **kwargs: Any, - ) -> sa.Table: - if temp: - kwargs["oracle_on_commit"] = "PRESERVE ROWS" - t = super()._table_from_schema(name, schema, temp, database, **kwargs) - if temp: - atexit.register(self._clean_up_tmp_table, t) - return t + def _fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame: + # TODO(gforsyth): this can probably be generalized a bit and put into + # the base backend (or a mixin) + import pandas as pd + + from ibis.backends.oracle.converter import OraclePandasData + + try: + df = pd.DataFrame.from_records( + cursor, columns=schema.names, coerce_float=True + ) + except Exception: + # clean up the cursor if we fail to create the DataFrame + # + # in the sqlite case failing to close the cursor results in + # artificially locked tables + cursor.close() + raise + df = OraclePandasData.convert_table(df, schema) + return df def _clean_up_tmp_table(self, name: str) -> None: - tmptable = self._get_sqla_table(name, autoload=False) with self.begin() as bind: # global temporary tables cannot be dropped without first truncating them # @@ -296,10 +497,10 @@ def _clean_up_tmp_table(self, name: str) -> None: # # ignore DatabaseError exceptions because the table may not exist # because it's already been deleted - with contextlib.suppress(sa.exc.DatabaseError): - bind.exec_driver_sql(f'TRUNCATE TABLE "{tmptable.name}"') - with contextlib.suppress(sa.exc.DatabaseError): - tmptable.drop(bind=bind) + with contextlib.suppress(oracledb.DatabaseError): + bind.execute(f'TRUNCATE TABLE "{name}"') + with contextlib.suppress(oracledb.DatabaseError): + bind.execute(f'DROP TABLE "{name}"') def _clean_up_cached_table(self, op): self._clean_up_tmp_table(op.name) diff --git a/ibis/backends/oracle/compiler.py b/ibis/backends/oracle/compiler.py new file mode 100644 index 000000000000..63b171458195 --- /dev/null +++ b/ibis/backends/oracle/compiler.py @@ -0,0 +1,590 @@ +from __future__ import annotations + +from functools import singledispatchmethod + +import sqlglot as sg +import sqlglot.expressions as sge +import toolz +from public import public +from sqlglot.dialects import Oracle +from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func + +import ibis +import 
ibis.common.exceptions as com +import ibis.expr.operations as ops +from ibis.backends.base.sqlglot.compiler import NULL, STAR, SQLGlotCompiler +from ibis.backends.base.sqlglot.datatypes import OracleType +from ibis.backends.base.sqlglot.rewrites import Window, replace_log2, replace_log10 +from ibis.common.patterns import replace +from ibis.expr.analysis import p, x, y +from ibis.expr.rewrites import rewrite_sample + + +def _create_sql(self, expression: sge.Create) -> str: + # TODO: should we use CREATE PRIVATE instead? That will set an implicit + # lower bound of Oracle 18c + properties = expression.args.get("properties") + temporary = any( + isinstance(prop, sge.TemporaryProperty) + for prop in (properties.expressions if properties else []) + ) + + kind = expression.args["kind"] + if (obj := kind.upper()) in ("TABLE", "VIEW") and temporary: + if expression.expression: + return f"CREATE GLOBAL TEMPORARY {obj} {self.sql(expression, 'this')} AS {self.sql(expression, 'expression')}" + else: + # TODO: why does autocommit not work here? need to specify the ON COMMIT part... + return f"CREATE GLOBAL TEMPORARY {obj} {self.sql(expression, 'this')} ON COMMIT PRESERVE ROWS" + + return create_with_partitions_sql(self, expression) + + +def _datatype_sql(self: Oracle.Generator, expression: sge.DataType) -> str: + # Use this to handle correctly formatting timestamp precision + # e.g. TIMESTAMP (scale) WITH TIME ZONE vs. TIMESTAMP WITH TIME ZONE(scale) + if expression.is_type("timestamptz"): + for exp in expression.expressions: + if isinstance(exp, sge.DataTypeParam): + return f"TIMESTAMP ({self.sql(exp, 'this')}) WITH TIME ZONE" + return "TIMESTAMP WITH TIME ZONE" + return self.datatype_sql(expression) + + +Oracle.Generator.TRANSFORMS |= { + sge.LogicalOr: rename_func("max"), + sge.LogicalAnd: rename_func("min"), + sge.VariancePop: rename_func("var_pop"), + sge.Variance: rename_func("var_samp"), + sge.Stddev: rename_func("stddev_pop"), + sge.StddevPop: rename_func("stddev_pop"), + sge.StddevSamp: rename_func("stddev_samp"), + sge.ApproxDistinct: rename_func("approx_count_distinct"), + sge.Create: _create_sql, + sge.Select: sg.transforms.preprocess([sg.transforms.eliminate_semi_and_anti_joins]), + sge.DataType: _datatype_sql, +} + + +@replace(p.WindowFunction(p.First(x, y))) +def rewrite_first(_, x, y): + if y is not None: + raise com.UnsupportedOperationError( + "`first` aggregate over window does not support `where`" + ) + return _.copy(func=ops.FirstValue(x)) + + +@replace(p.WindowFunction(p.Last(x, y))) +def rewrite_last(_, x, y): + if y is not None: + raise com.UnsupportedOperationError( + "`last` aggregate over window does not support `where`" + ) + return _.copy(func=ops.LastValue(x)) + + +@replace(p.WindowFunction(frame=x @ p.WindowFrame(order_by=()))) +def rewrite_empty_order_by_window(_, x): + return _.copy(frame=x.copy(order_by=(ibis.NA,))) + + +@replace(p.WindowFunction(p.RowNumber | p.NTile, x)) +def exclude_unsupported_window_frame_from_row_number(_, x): + return ops.Subtract(_.copy(frame=x.copy(start=None, end=None)), 1) + + +@replace( + p.WindowFunction( + p.Lag | p.Lead | p.PercentRank | p.CumeDist | p.Any | p.All, + x @ p.WindowFrame(start=None), + ) +) +def exclude_unsupported_window_frame_from_ops(_, x): + return _.copy(frame=x.copy(start=None, end=None)) + + +@public +class OracleCompiler(SQLGlotCompiler): + __slots__ = () + + dialect = "oracle" + quoted = True + type_mapper = OracleType + rewrites = ( + exclude_unsupported_window_frame_from_row_number, + 
exclude_unsupported_window_frame_from_ops, + rewrite_first, + rewrite_last, + rewrite_empty_order_by_window, + rewrite_sample, + replace_log2, + replace_log10, + *SQLGlotCompiler.rewrites, + ) + + NAN = sge.Literal.number("binary_double_nan") + """Backend's NaN literal.""" + + POS_INF = sge.Literal.number("binary_double_infinity") + """Backend's positive infinity literal.""" + + NEG_INF = sge.Literal.number("-binary_double_infinity") + """Backend's negative infinity literal.""" + + def _aggregate(self, funcname: str, *args, where): + func = self.f[funcname] + if where is not None: + args = tuple(self.if_(where, arg) for arg in args) + return func(*args) + + @staticmethod + def _generate_groups(groups): + return groups + + @singledispatchmethod + def visit_node(self, op, **kwargs): + return super().visit_node(op, **kwargs) + + @visit_node.register(ops.Equals) + def visit_Equals(self, op, *, left, right): + # Oracle didn't have proper boolean types until recently and we handle them + # as integers so we end up with things like "t0"."bool_col" = 1 (for True) + # but then if we are testing that a boolean column IS True, it gets rendered as + # "t0"."bool_col" = 1 = 1 + # so intercept that and change it to WHERE (bool_col = 1) + # TODO(gil): there must be a better way to do this + if op.dtype.is_boolean() and isinstance(right, sge.Boolean): + if right.this: + return left + else: + return sg.not_(left) + return super().visit_Equals(op, left=left, right=right) + + @visit_node.register(ops.IsNull) + def visit_IsNull(self, op, *, arg): + # TODO(gil): find a better way to handle this + # but CASE WHEN (bool_col = 1) IS NULL isn't valid and we can simply check if + # bool_col is null + if isinstance(arg, sge.EQ): + return arg.this.is_(NULL) + return arg.is_(NULL) + + @visit_node.register(ops.Literal) + def visit_Literal(self, op, *, value, dtype): + # avoid casting NULL -- oracle handling for these casts is... complicated + if value is None: + return NULL + elif dtype.is_timestamp() or dtype.is_time(): + if getattr(dtype, "timezone", None) is not None: + return self.f.to_timestamp_tz( + value.isoformat(), 'YYYY-MM-DD"T"HH24:MI:SS.FF6TZH:TZM' + ) + else: + return self.f.to_timestamp( + value.isoformat(), 'YYYY-MM-DD"T"HH24:MI:SS.FF6' + ) + elif dtype.is_date(): + return self.f.to_date( + f"{value.year:04d}-{value.month:02d}-{value.day:02d}", "FXYYYY-MM-DD" + ) + elif dtype.is_uuid(): + return sge.convert(str(value)) + elif dtype.is_interval(): + if dtype.unit.short in ("Y", "M"): + return self.f.numtoyminterval(value, dtype.unit.name) + elif dtype.unit.short in ("D", "h", "m", "s"): + return self.f.numtodsinterval(value, dtype.unit.name) + else: + raise com.UnsupportedOperationError( + f"Intervals with precision {dtype.unit.name} not supported in Oracle." + ) + + return super().visit_Literal(op, value=value, dtype=dtype) + + @visit_node.register(ops.Cast) + def visit_Cast(self, op, *, arg, to): + if to.is_interval(): + # CASTing to an INTERVAL in Oracle requires specifying digits of + # precision that are a pain. There are two helper functions that + # should be used instead. 
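+            # For example, NUMTODSINTERVAL(5, 'DAY') and NUMTOYMINTERVAL(2, 'MONTH')
+            # produce day-to-second and year-to-month intervals directly, with
+            # no explicit precision to manage.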
+ if to.unit.short in ("D", "h", "m", "s"): + return self.f.numtodsinterval(arg, to.unit.name) + elif to.unit.short in ("Y", "M"): + return self.f.numtoyminterval(arg, to.unit.name) + else: + raise com.UnsupportedArgumentError( + f"Interval {to.unit.name} not supported by Oracle" + ) + return self.cast(arg, to) + + @visit_node.register(ops.Limit) + def visit_Limit(self, op, *, parent, n, offset): + # push limit/offset into subqueries + if isinstance(parent, sge.Subquery) and parent.this.args.get("limit") is None: + result = parent.this + alias = parent.alias + else: + result = sg.select(STAR).from_(parent) + alias = None + + if isinstance(n, int): + result = result.limit(n) + elif n is not None: + raise com.UnsupportedArgumentError( + "No support for dynamic limit in the Oracle backend." + ) + # TODO: re-enable this for dynamic limits + # but it should be paired with offsets working + # result = result.where(C.ROWNUM <= sg.select(n).from_(parent).subquery()) + else: + assert n is None, n + if self.no_limit_value is not None: + result = result.limit(self.no_limit_value) + + assert offset is not None, "offset is None" + + if offset > 0: + raise com.UnsupportedArgumentError( + "No support for limit offsets in the Oracle backend." + ) + + if alias is not None: + return result.subquery(alias) + return result + + @visit_node.register(ops.Date) + def visit_Date(self, op, *, arg): + return sg.cast(arg, to="date") + + @visit_node.register(ops.IsNan) + def visit_IsNan(self, op, *, arg): + return arg.eq(self.NAN) + + @visit_node.register(ops.Log) + def visit_Log(self, op, *, arg, base): + return self.f.log(base, arg, dialect=self.dialect) + + @visit_node.register(ops.IsInf) + def visit_IsInf(self, op, *, arg): + return arg.isin(self.POS_INF, self.NEG_INF) + + @visit_node.register(ops.RandomScalar) + def visit_RandomScalar(self, op): + # Not using FuncGen here because of dotted function call + return sg.func("dbms_random.value") + + @visit_node.register(ops.Pi) + def visit_Pi(self, op): + return self.f.acos(-1) + + @visit_node.register(ops.Cot) + def visit_Cot(self, op, *, arg): + return 1 / self.f.tan(arg) + + @visit_node.register(ops.Degrees) + def visit_Degrees(self, op, *, arg): + return 180 * arg / self.visit_node(ops.Pi()) + + @visit_node.register(ops.Radians) + def visit_Radians(self, op, *, arg): + return self.visit_node(ops.Pi()) * arg / 180 + + @visit_node.register(ops.Modulus) + def visit_Modulus(self, op, *, left, right): + return self.f.mod(left, right) + + @visit_node.register(ops.Levenshtein) + def visit_Levenshtein(self, op, *, left, right): + # Not using FuncGen here because of dotted function call + return sg.func("utl_match.edit_distance", left, right) + + @visit_node.register(ops.StartsWith) + def visit_StartsWith(self, op, *, arg, start): + return self.f.substr(arg, 0, self.f.length(start)).eq(start) + + @visit_node.register(ops.EndsWith) + def visit_EndsWith(self, op, *, arg, end): + return self.f.substr(arg, -1 * self.f.length(end), self.f.length(end)).eq(end) + + @visit_node.register(ops.StringFind) + def visit_StringFind(self, op, *, arg, substr, start, end): + if end is not None: + raise NotImplementedError("`end` is not implemented") + + sub_string = substr + + if start is not None: + arg = self.f.substr(arg, start + 1) + pos = self.f.instr(arg, sub_string) + # TODO(gil): why, oh why, does this need an extra +1 on the end? 
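+            # Presumably: INSTR is 1-based and returns 0 when absent, and `pos`
+            # is relative to the slice taken above, so `pos - 1 + start` rebases
+            # it against the full string; the trailing +1 then restores the
+            # 1-based convention `ops.StringFind` uses (the user-facing `find`
+            # subtracts 1 afterwards) and maps the not-found -1 back to 0.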
+ return sg.case().when(pos > 0, pos - 1 + start).else_(-1) + 1 + + return self.f.instr(arg, sub_string) + + @visit_node.register(ops.StrRight) + def visit_StrRight(self, op, *, arg, nchars): + return self.f.substr(arg, -nchars) + + @visit_node.register(ops.RegexExtract) + def visit_RegexExtract(self, op, *, arg, pattern, index): + return self.if_( + index.eq(0), + self.f.regexp_substr(arg, pattern), + self.f.regexp_substr(arg, pattern, 1, 1, "cn", index), + ) + + @visit_node.register(ops.RegexReplace) + def visit_RegexReplace(self, op, *, arg, pattern, replacement): + return sge.RegexpReplace(this=arg, expression=pattern, replacement=replacement) + + @visit_node.register(ops.StringContains) + def visit_StringContains(self, op, *, haystack, needle): + return self.f.instr(haystack, needle) > 0 + + @visit_node.register(ops.StringJoin) + def visit_StringJoin(self, op, *, arg, sep): + return self.f.concat(*toolz.interpose(sep, arg)) + + ## Aggregate stuff + + @visit_node.register(ops.Correlation) + def visit_Correlation(self, op, *, left, right, where, how): + if how == "sample": + raise ValueError( + "Oracle only implements population correlation coefficient" + ) + return self.agg.corr(left, right, where=where) + + @visit_node.register(ops.Covariance) + def visit_Covariance(self, op, *, left, right, where, how): + if how == "sample": + return self.agg.covar_samp(left, right, where=where) + return self.agg.covar_pop(left, right, where=where) + + @visit_node.register(ops.ApproxMedian) + def visit_ApproxMedian(self, op, *, arg, where): + return self.visit_Quantile(op, arg=arg, quantile=0.5, where=where) + + @visit_node.register(ops.Quantile) + def visit_Quantile(self, op, *, arg, quantile, where): + suffix = "cont" if op.arg.dtype.is_numeric() else "disc" + funcname = f"percentile_{suffix}" + + if where is not None: + arg = self.if_(where, arg) + + expr = sge.WithinGroup( + this=self.f[funcname](quantile), + expression=sge.Order(expressions=[sge.Ordered(this=arg)]), + ) + return expr + + @visit_node.register(ops.CountDistinct) + def visit_CountDistinct(self, op, *, arg, where): + if where is not None: + arg = self.if_(where, arg) + + return sge.Count(this=sge.Distinct(expressions=[arg])) + + @visit_node.register(ops.CountStar) + def visit_CountStar(self, op, *, arg, where): + if where is not None: + return self.f.count(self.if_(where, 1, NULL)) + return self.f.count(STAR) + + @visit_node.register(ops.IdenticalTo) + def visit_IdenticalTo(self, op, *, left, right): + # sqlglot NullSafeEQ uses "is not distinct from" which isn't supported in oracle + return ( + sg.case() + .when(left.eq(right).or_(left.is_(NULL).and_(right.is_(NULL))), 0) + .else_(1) + .eq(0) + ) + + @visit_node.register(ops.Xor) + def visit_Xor(self, op, *, left, right): + return (left.or_(right)).and_(sg.not_(left.and_(right))) + + @visit_node.register(ops.TimestampTruncate) + @visit_node.register(ops.DateTruncate) + def visit_DateTruncate(self, op, *, arg, unit): + trunc_unit_mapping = { + "Y": "year", + "M": "MONTH", + "W": "IW", + "D": "DDD", + "h": "HH", + "m": "MI", + } + + timestamp_unit_mapping = { + "s": "SS", + "ms": "SS.FF3", + "us": "SS.FF6", + "ns": "SS.FF9", + } + + if (unyt := timestamp_unit_mapping.get(unit.short)) is not None: + # Oracle only has trunc(DATE) and that can't do sub-minute precision, but we can + # handle those separately. 
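+            # e.g. truncating to milliseconds formats the value with
+            # 'YYYY-MM-DD HH24:MI:SS.FF3' and parses it back with the same
+            # format, dropping any digits finer than the target unit.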
+ return self.f.to_timestamp( + self.f.to_char(arg, f"YYYY-MM-DD HH24:MI:{unyt}"), + f"YYYY-MM-DD HH24:MI:{unyt}", + ) + + if (unyt := trunc_unit_mapping.get(unit.short)) is None: + raise com.UnsupportedOperationError(f"Unsupported truncate unit {unit}") + + return self.f.trunc(arg, unyt) + + @visit_node.register(Window) + def visit_Window(self, op, *, how, func, start, end, group_by, order_by): + # Oracle has two (more?) types of analytic functions you can use inside OVER. + # + # The first group accepts an "analytic clause" which is decomposed into the + # PARTITION BY, ORDER BY and the windowing clause (e.g. ROWS BETWEEN + # UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING). These are the "full" window functions. + # + # The second group accepts an _optional_ PARTITION BY clause and a _required_ ORDER BY clause. + # If you try to pass, for instance, LEAD(col, 1) OVER() AS "val", this will error. + # + # The list of functions which accept the full analytic clause (and so + # accept a windowing clause) are those functions which are marked with + # an asterisk at the bottom of this page (yes, Oracle thinks this is + # a reasonable way to demarcate them): + # https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/Analytic-Functions.html + # + # (Side note: these unordered window function queries were not erroring + # in the SQLAlchemy Oracle backend but they were raising AssertionErrors. + # This is because the SQLAlchemy Oracle dialect automatically inserts an + # ORDER BY whether you ask it to or not.) + # + # If the windowing clause is omitted, the default is + # RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + # + # I (@gforsyth) believe that this is the windowing range applied to the + # analytic functions (like LEAD, LAG, CUME_DIST) which don't allow + # specifying a windowing clause. + # + # This allowance for specifying a windowing clause is handled below by + # explicitly listing the ops which correspond to the analytic functions + # that accept it. + + if type(op.func) in ( + # TODO: figure out REGR_* functions and also manage this list better + # Allowed windowing clause functions + ops.Mean, # "avg", + ops.Correlation, # "corr", + ops.Count, # "count", + ops.Covariance, # "covar_pop", "covar_samp", + ops.FirstValue, # "first_value", + ops.LastValue, # "last_value", + ops.Max, # "max", + ops.Min, # "min", + ops.NthValue, # "nth_value", + ops.StandardDev, # "stddev","stddev_pop","stddev_samp", + ops.Sum, # "sum", + ops.Variance, # "var_pop","var_samp","variance", + ): + if start is None: + start = {} + if end is None: + end = {} + + start_value = start.get("value", "UNBOUNDED") + start_side = start.get("side", "PRECEDING") + end_value = end.get("value", "UNBOUNDED") + end_side = end.get("side", "FOLLOWING") + + spec = sge.WindowSpec( + kind=how.upper(), + start=start_value, + start_side=start_side, + end=end_value, + end_side=end_side, + over="OVER", + ) + elif not order_by: + # For other analytic functions, ORDER BY is required + raise com.UnsupportedOperationError( + f"Function {op.func.name} cannot be used in Oracle without an order_by." + ) + else: + # and no windowing clause is supported, so set the spec to None. 
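+            # e.g. LEAD("x", 1) OVER (ORDER BY "y") is accepted, but appending
+            # ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW to that OVER
+            # clause is rejected, unlike for SUM/AVG/MIN/MAX above.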
+            spec = None
+
+        order = sge.Order(expressions=order_by) if order_by else None
+
+        spec = self._minimize_spec(op.start, op.end, spec)
+
+        return sge.Window(this=func, partition_by=group_by, order=order, spec=spec)
+
+    @visit_node.register(ops.Arbitrary)
+    @visit_node.register(ops.ArgMax)
+    @visit_node.register(ops.ArgMin)
+    @visit_node.register(ops.ArrayCollect)
+    @visit_node.register(ops.ArrayColumn)
+    @visit_node.register(ops.ArrayFlatten)
+    @visit_node.register(ops.ArrayMap)
+    @visit_node.register(ops.ArrayStringJoin)
+    @visit_node.register(ops.First)
+    @visit_node.register(ops.Last)
+    @visit_node.register(ops.Mode)
+    @visit_node.register(ops.MultiQuantile)
+    @visit_node.register(ops.RegexSplit)
+    @visit_node.register(ops.StringSplit)
+    @visit_node.register(ops.TimeTruncate)
+    @visit_node.register(ops.Bucket)
+    @visit_node.register(ops.TimestampBucket)
+    @visit_node.register(ops.TimeDelta)
+    @visit_node.register(ops.DateDelta)
+    @visit_node.register(ops.TimestampDelta)
+    @visit_node.register(ops.TimestampNow)
+    @visit_node.register(ops.TimestampFromYMDHMS)
+    @visit_node.register(ops.TimeFromHMS)
+    @visit_node.register(ops.IntervalFromInteger)
+    @visit_node.register(ops.DayOfWeekIndex)
+    @visit_node.register(ops.DayOfWeekName)
+    @visit_node.register(ops.DateDiff)
+    @visit_node.register(ops.ExtractEpochSeconds)
+    @visit_node.register(ops.ExtractWeekOfYear)
+    @visit_node.register(ops.ExtractDayOfYear)
+    @visit_node.register(ops.RowID)
+    def visit_Undefined(self, op, **_):
+        raise com.OperationNotDefinedError(type(op).__name__)
+
+
+_SIMPLE_OPS = {
+    ops.ApproxCountDistinct: "approx_count_distinct",
+    ops.BitAnd: "bit_and_agg",
+    ops.BitOr: "bit_or_agg",
+    ops.BitXor: "bit_xor_agg",
+    ops.BitwiseAnd: "bitand",
+    ops.Hash: "ora_hash",
+    ops.LPad: "lpad",
+    ops.RPad: "rpad",
+    ops.StringAscii: "ascii",
+    ops.Strip: "trim",
+}
+
+for _op, _name in _SIMPLE_OPS.items():
+    assert isinstance(type(_op), type), type(_op)
+    if issubclass(_op, ops.Reduction):
+
+        @OracleCompiler.visit_node.register(_op)
+        def _fmt(self, op, *, _name: str = _name, where, **kw):
+            return self.agg[_name](*kw.values(), where=where)
+
+    else:
+
+        @OracleCompiler.visit_node.register(_op)
+        def _fmt(self, op, *, _name: str = _name, **kw):
+            return self.f[_name](*kw.values())
+
+    setattr(OracleCompiler, f"visit_{_op.__name__}", _fmt)
+
+
+del _op, _name, _fmt
diff --git a/ibis/backends/oracle/converter.py b/ibis/backends/oracle/converter.py
new file mode 100644
index 000000000000..7755cb595340
--- /dev/null
+++ b/ibis/backends/oracle/converter.py
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+import datetime
+
+from ibis.formats.pandas import PandasData
+
+
+class OraclePandasData(PandasData):
+    @classmethod
+    def convert_Timestamp_element(cls, dtype):
+        return datetime.datetime.fromisoformat
+
+    @classmethod
+    def convert_Date_element(cls, dtype):
+        return datetime.date.fromisoformat
+
+    @classmethod
+    def convert_Time_element(cls, dtype):
+        return datetime.time.fromisoformat
diff --git a/ibis/backends/oracle/datatypes.py b/ibis/backends/oracle/datatypes.py
deleted file mode 100644
index 08cdc3be2e4f..000000000000
--- a/ibis/backends/oracle/datatypes.py
+++ /dev/null
@@ -1,46 +0,0 @@
-from __future__ import annotations
-
-import sqlalchemy.types as sat
-from sqlalchemy.dialects import oracle
-
-import ibis.expr.datatypes as dt
-from ibis.backends.base.sql.alchemy.datatypes import AlchemyType
-from ibis.backends.base.sqlglot.datatypes import OracleType as SqlglotOracleType
-
-
-class
OracleType(AlchemyType): - dialect = "oracle" - - @classmethod - def to_ibis(cls, typ, nullable=True): - if isinstance(typ, oracle.ROWID): - return dt.String(nullable=nullable) - elif isinstance(typ, (oracle.RAW, sat.BLOB)): - return dt.Binary(nullable=nullable) - elif isinstance(typ, sat.Float): - return dt.Float64(nullable=nullable) - elif isinstance(typ, sat.Numeric): - if typ.scale == 0: - # kind of a lie, should be int128 because 38 digits - return dt.Int64(nullable=nullable) - else: - return dt.Decimal( - precision=typ.precision or 38, - scale=typ.scale or 0, - nullable=nullable, - ) - else: - return super().to_ibis(typ, nullable=nullable) - - @classmethod - def from_ibis(cls, dtype): - if isinstance(dtype, dt.Float64): - return sat.Float(precision=53).with_variant(oracle.FLOAT(14), "oracle") - elif isinstance(dtype, dt.Float32): - return sat.Float(precision=23).with_variant(oracle.FLOAT(7), "oracle") - else: - return super().from_ibis(dtype) - - @classmethod - def from_string(cls, type_string, nullable=True): - return SqlglotOracleType.from_string(type_string, nullable=nullable) diff --git a/ibis/backends/oracle/registry.py b/ibis/backends/oracle/registry.py deleted file mode 100644 index 8c6b074bd21d..000000000000 --- a/ibis/backends/oracle/registry.py +++ /dev/null @@ -1,130 +0,0 @@ -from __future__ import annotations - -import sqlalchemy as sa -import toolz -from packaging.version import parse as vparse - -import ibis.expr.operations as ops -from ibis.backends.base.sql.alchemy import ( - fixed_arity, - reduction, - sqlalchemy_operation_registry, - sqlalchemy_window_functions_registry, - unary, -) -from ibis.backends.base.sql.alchemy.registry import ( - _gen_string_find, -) -from ibis.backends.base.sql.alchemy.registry import ( - _literal as _alchemy_literal, -) - -operation_registry = sqlalchemy_operation_registry.copy() - -operation_registry.update(sqlalchemy_window_functions_registry) - - -def _cot(t, op): - arg = t.translate(op.arg) - return 1.0 / sa.func.tan(arg, type_=t.get_sqla_type(op.arg.dtype)) - - -def _cov(t, op): - return t._reduction(getattr(sa.func, f"covar_{op.how[:4]}"), op) - - -def _corr(t, op): - if op.how == "sample": - raise ValueError( - f"{t.__class__.__name__} only implements population correlation " - "coefficient" - ) - return t._reduction(sa.func.corr, op) - - -def _literal(t, op): - dtype = op.dtype - value = op.value - - if value is None: - return sa.null() - elif ( - # handle UUIDs in sqlalchemy < 2 - vparse(sa.__version__) < vparse("2") and dtype.is_uuid() - ): - return sa.literal(str(value), type_=t.get_sqla_type(dtype)) - elif dtype.is_timestamp(): - if dtype.timezone is not None: - return sa.func.to_utc_timestamp_tz(value.isoformat(timespec="microseconds")) - return sa.func.to_timestamp( - # comma for sep here because T is a special character in Oracle - # the FX prefix means "requires an exact match" - value.isoformat(sep=",", timespec="microseconds"), - "FXYYYY-MM-DD,HH24:MI:SS.FF6", - ) - elif dtype.is_date(): - return sa.func.to_date(value.isoformat(), "FXYYYY-MM-DD") - elif dtype.is_time(): - raise NotImplementedError("Time values are not supported in Oracle") - return _alchemy_literal(t, op) - - -def _second(t, op): - # Oracle returns fractional seconds, so `floor` the result to match - # the behavior of other backends - return sa.func.floor(sa.extract("SECOND", t.translate(op.arg))) - - -def _string_join(t, op): - sep = t.translate(op.sep) - values = list(map(t.translate, op.arg)) - return sa.func.concat(*toolz.interpose(sep, values)) - - 
-def _median(t, op): - arg = op.arg - if (where := op.where) is not None: - arg = ops.IfElse(where, arg, None) - - if arg.dtype.is_numeric(): - return sa.func.median(t.translate(arg)) - return sa.cast( - sa.func.percentile_disc(0.5).within_group(t.translate(arg)), - t.get_sqla_type(op.dtype), - ) - - -operation_registry.update( - { - ops.Log2: unary(lambda arg: sa.func.log(2, arg)), - ops.Log10: unary(lambda arg: sa.func.log(10, arg)), - ops.Log: fixed_arity(lambda arg, base: sa.func.log(base, arg), 2), - ops.Power: fixed_arity(sa.func.power, 2), - ops.Cot: _cot, - ops.Pi: lambda *_: sa.func.ACOS(-1), - ops.RandomScalar: fixed_arity(sa.func.dbms_random.value, 0), - ops.Degrees: lambda t, op: 180 * t.translate(op.arg) / t.translate(ops.Pi()), - ops.Radians: lambda t, op: t.translate(ops.Pi()) * t.translate(op.arg) / 180, - # Aggregate Functions - ops.Covariance: _cov, - ops.Correlation: _corr, - ops.ApproxMedian: reduction(sa.func.approx_median), - ops.Median: _median, - # Temporal - ops.ExtractSecond: _second, - # String - ops.StrRight: fixed_arity(lambda arg, nchars: sa.func.substr(arg, -nchars), 2), - ops.StringJoin: _string_join, - ops.StringFind: _gen_string_find(sa.func.instr), - # Generic - ops.Hash: unary(sa.func.ora_hash), - ops.Literal: _literal, - ops.Levenshtein: fixed_arity(sa.func.utl_match.edit_distance, 2), - } -) - -_invalid_operations = set() - -operation_registry = { - k: v for k, v in operation_registry.items() if k not in _invalid_operations -} diff --git a/ibis/backends/oracle/tests/conftest.py b/ibis/backends/oracle/tests/conftest.py index addf6a6d3924..ee27b4cc0f1f 100644 --- a/ibis/backends/oracle/tests/conftest.py +++ b/ibis/backends/oracle/tests/conftest.py @@ -7,8 +7,8 @@ import subprocess from typing import TYPE_CHECKING, Any +import oracledb import pytest -import sqlalchemy as sa import ibis from ibis.backends.tests.base import ServiceBackendTest @@ -28,6 +28,9 @@ # ./createAppUser user pass ORACLE_DB # where ORACLE_DB is the same name you used in the Compose file. +# Set to ensure decimals come back as decimals +oracledb.defaults.fetch_decimals = True + class TestConf(ServiceBackendTest): check_dtype = False @@ -43,7 +46,7 @@ class TestConf(ServiceBackendTest): rounding_method = "half_to_even" data_volume = "/opt/oracle/data" service_name = "oracle" - deps = "oracledb", "sqlalchemy" + deps = ("oracledb",) @property def test_files(self) -> Iterable[Path]: @@ -88,12 +91,11 @@ def _load_data( ) init_oracle_database( - url=sa.engine.make_url( - f"oracle://{user}:{password}@{host}:{port:d}/{database}", - ), + dsn=oracledb.makedsn(host, port, service_name=database), + user=user, + password=password, database=database, schema=self.ddl_script, - connect_args=dict(service_name=database), ) # then call sqlldr to ingest @@ -138,42 +140,29 @@ def con(data_dir, tmp_path_factory, worker_id): def init_oracle_database( - url: sa.engine.url.URL, + user: str, + password: str, + dsn: str, database: str, schema: str | None = None, **kwargs: Any, -) -> sa.engine.Engine: +) -> None: """Initialise `database` at `url` with `schema`. - If `recreate`, drop the `database` at `url`, if it exists. 
- Parameters ---------- - url : url.sa.engine.url.URL - Connection url to the database database : str Name of the database to be dropped schema : TextIO File object containing schema to use - - Returns - ------- - sa.engine.Engine - SQLAlchemy engine object """ - try: - url.database = database - except AttributeError: - url = url.set(database=database) - engine = sa.create_engine(url, **kwargs) + con = oracledb.connect(dsn, user=user, password=password, stmtcachesize=0) if schema: - with engine.begin() as conn: + with con.cursor() as cursor: for stmt in schema: # XXX: maybe should just remove the comments in the sql file # so we don't end up writing an entire parser here. if not stmt.startswith("--"): - conn.exec_driver_sql(stmt) - - return engine + cursor.execute(stmt) diff --git a/ibis/backends/oracle/tests/test_client.py b/ibis/backends/oracle/tests/test_client.py index 4f7290aa9157..d1dd5e6ad5fd 100644 --- a/ibis/backends/oracle/tests/test_client.py +++ b/ibis/backends/oracle/tests/test_client.py @@ -54,6 +54,6 @@ def stats_one_way_anova(x, y, value: str) -> int: """ with con.begin() as c: expected = pd.DataFrame( - c.exec_driver_sql(sql).fetchall(), columns=["string_col", "df_w"] + c.execute(sql).fetchall(), columns=["string_col", "df_w"] ) - tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(result, expected, check_dtype=False) diff --git a/ibis/backends/oracle/tests/test_datatypes.py b/ibis/backends/oracle/tests/test_datatypes.py index 848bfd905e43..33efc4bdcb41 100644 --- a/ibis/backends/oracle/tests/test_datatypes.py +++ b/ibis/backends/oracle/tests/test_datatypes.py @@ -14,7 +14,7 @@ def test_blob_raw(con): con.drop_table("blob_raw_blobs_blob_raw", force=True) with con.begin() as bind: - bind.exec_driver_sql( + bind.execute( """CREATE TABLE "blob_raw_blobs_blob_raw" ("blob" BLOB, "raw" RAW(255))""" ) diff --git a/ibis/backends/tests/errors.py b/ibis/backends/tests/errors.py index a6c4064f9610..df1f290f7648 100644 --- a/ibis/backends/tests/errors.py +++ b/ibis/backends/tests/errors.py @@ -111,3 +111,8 @@ from pydruid.db.exceptions import ProgrammingError as PyDruidProgrammingError except ImportError: PyDruidProgrammingError = None + +try: + from oracledb.exceptions import DatabaseError as OracleDatabaseError +except ImportError: + OracleDatabaseError = None diff --git a/ibis/backends/tests/snapshots/test_interactive/test_default_limit/oracle/out.sql b/ibis/backends/tests/snapshots/test_interactive/test_default_limit/oracle/out.sql new file mode 100644 index 000000000000..2124da09f645 --- /dev/null +++ b/ibis/backends/tests/snapshots/test_interactive/test_default_limit/oracle/out.sql @@ -0,0 +1,5 @@ +SELECT + "t0"."id", + "t0"."bool_col" = 1 AS "bool_col" +FROM "functional_alltypes" "t0" +FETCH FIRST 11 ROWS ONLY \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_interactive/test_disable_query_limit/oracle/out.sql b/ibis/backends/tests/snapshots/test_interactive/test_disable_query_limit/oracle/out.sql new file mode 100644 index 000000000000..2124da09f645 --- /dev/null +++ b/ibis/backends/tests/snapshots/test_interactive/test_disable_query_limit/oracle/out.sql @@ -0,0 +1,5 @@ +SELECT + "t0"."id", + "t0"."bool_col" = 1 AS "bool_col" +FROM "functional_alltypes" "t0" +FETCH FIRST 11 ROWS ONLY \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_interactive/test_interactive_execute_on_repr/oracle/out.sql b/ibis/backends/tests/snapshots/test_interactive/test_interactive_execute_on_repr/oracle/out.sql new file mode 100644 index 
000000000000..7b50874f2771 --- /dev/null +++ b/ibis/backends/tests/snapshots/test_interactive/test_interactive_execute_on_repr/oracle/out.sql @@ -0,0 +1,19 @@ +SELECT + SUM("t1"."bigint_col") AS "Sum(bigint_col)" +FROM ( + SELECT + "t0"."id", + "t0"."bool_col" = 1 AS "bool_col", + "t0"."tinyint_col", + "t0"."smallint_col", + "t0"."int_col", + "t0"."bigint_col", + "t0"."float_col", + "t0"."double_col", + "t0"."date_string_col", + "t0"."string_col", + "t0"."timestamp_col", + "t0"."year", + "t0"."month" + FROM "functional_alltypes" "t0" +) "t1" \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_interactive/test_respect_set_limit/oracle/out.sql b/ibis/backends/tests/snapshots/test_interactive/test_respect_set_limit/oracle/out.sql new file mode 100644 index 000000000000..96217eecd9a1 --- /dev/null +++ b/ibis/backends/tests/snapshots/test_interactive/test_respect_set_limit/oracle/out.sql @@ -0,0 +1,10 @@ +SELECT + * +FROM ( + SELECT + "t0"."id", + "t0"."bool_col" = 1 AS "bool_col" + FROM "functional_alltypes" "t0" + FETCH FIRST 10 ROWS ONLY +) "t2" +FETCH FIRST 11 ROWS ONLY \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/oracle/out.sql b/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/oracle/out.sql index 69fb369f7226..036e3567f920 100644 --- a/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/oracle/out.sql +++ b/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/oracle/out.sql @@ -1,5 +1,5 @@ SELECT - CASE t0."continent" + CASE "t0"."continent" WHEN 'NA' THEN 'North America' WHEN 'SA' @@ -16,10 +16,10 @@ SELECT THEN 'Antarctica' ELSE 'Unknown continent' END AS "cont", - SUM(t0."population") AS "total_pop" -FROM "countries" t0 + SUM("t0"."population") AS "total_pop" +FROM "countries" "t0" GROUP BY - CASE t0."continent" + CASE "t0"."continent" WHEN 'NA' THEN 'North America' WHEN 'SA' diff --git a/ibis/backends/tests/snapshots/test_sql/test_isin_bug/oracle/out.sql b/ibis/backends/tests/snapshots/test_sql/test_isin_bug/oracle/out.sql index 13480df0fe70..e63d49015d77 100644 --- a/ibis/backends/tests/snapshots/test_sql/test_isin_bug/oracle/out.sql +++ b/ibis/backends/tests/snapshots/test_sql/test_isin_bug/oracle/out.sql @@ -1,13 +1,9 @@ SELECT - t0."x" IN ( + "t0"."x" IN ( SELECT - t1."x" - FROM ( - SELECT - t0."x" AS "x" - FROM "t" t0 - WHERE - t0."x" > 2 - ) t1 - ) AS "InColumn(x, x)" -FROM "t" t0 \ No newline at end of file + "t0"."x" + FROM "t" "t0" + WHERE + "t0"."x" > 2 + ) AS "InSubquery(x)" +FROM "t" "t0" \ No newline at end of file diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 8899cc82f98d..2cbb2449bbc3 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -20,6 +20,7 @@ GoogleBadRequest, ImpalaHiveServer2Error, MySQLNotSupportedError, + OracleDatabaseError, PolarsInvalidOperationError, Py4JError, PyDruidProgrammingError, @@ -297,7 +298,7 @@ def mean_and_std(v): ), pytest.mark.broken( ["oracle"], - raises=sa.exc.DatabaseError, + raises=OracleDatabaseError, reason="ORA-02000: missing AS keyword", ), pytest.mark.notimpl( @@ -317,7 +318,7 @@ def mean_and_std(v): ), pytest.mark.broken( ["oracle"], - raises=sa.exc.DatabaseError, + raises=OracleDatabaseError, reason="ORA-02000: missing AS keyword", ), pytest.mark.notimpl( @@ -349,7 +350,7 @@ def mean_and_std(v): ), pytest.mark.broken( ["oracle"], - raises=sa.exc.DatabaseError, + raises=OracleDatabaseError, reason="ORA-02000: 
missing AS keyword", ), pytest.mark.notimpl( @@ -369,7 +370,7 @@ def mean_and_std(v): ), pytest.mark.broken( ["oracle"], - raises=sa.exc.DatabaseError, + raises=OracleDatabaseError, reason="ORA-02000: missing AS keyword", ), pytest.mark.notimpl( @@ -389,7 +390,7 @@ def mean_and_std(v): marks=[ pytest.mark.broken( ["oracle"], - raises=sa.exc.DatabaseError, + raises=OracleDatabaseError, reason="ORA-02000: missing AS keyword", ), ], @@ -581,10 +582,10 @@ def mean_and_std(v): "dask", "datafusion", "druid", - "oracle", "impala", "mssql", "mysql", + "oracle", "pandas", "polars", "sqlite", @@ -641,10 +642,6 @@ def mean_and_std(v): raises=AttributeError, reason="'Series' object has no attribute 'bitand'", ), - pytest.mark.notimpl( - ["oracle"], - raises=sa.exc.DatabaseError, - ), ], ), param( @@ -664,11 +661,6 @@ def mean_and_std(v): raises=AttributeError, reason="'Series' object has no attribute 'bitor'", ), - pytest.mark.notyet( - ["oracle"], - raises=sa.exc.DatabaseError, - reason="ORA-00904: 'BIT_OR': invalid identifier", - ), ], ), param( @@ -688,11 +680,6 @@ def mean_and_std(v): raises=AttributeError, reason="'Series' object has no attribute 'bitxor'", ), - pytest.mark.notyet( - ["oracle"], - raises=sa.exc.DatabaseError, - reason="ORA-00904: 'BIT_XOR': invalid identifier", - ), ], ), param( @@ -801,7 +788,7 @@ def test_reduction_ops( ["bigquery", "druid", "mssql", "oracle", "sqlite", "flink"], raises=( sa.exc.OperationalError, - sa.exc.DatabaseError, + OracleDatabaseError, com.UnsupportedOperationError, com.OperationNotDefinedError, ), @@ -839,7 +826,6 @@ def test_count_distinct_star(alltypes, df, ibis_cond, pandas_cond): "mysql", "sqlite", "druid", - "oracle", "exasol", ], raises=com.OperationNotDefinedError, @@ -1154,9 +1140,6 @@ def test_median(alltypes, df): raises=ClickHouseDatabaseError, reason="doesn't support median of strings", ) -@pytest.mark.notyet( - ["oracle"], raises=sa.exc.DatabaseError, reason="doesn't support median of strings" -) @pytest.mark.broken( ["pyspark"], raises=AssertionError, reason="pyspark returns null for string median" ) @@ -1175,7 +1158,17 @@ def test_median(alltypes, df): "func", [ param(methodcaller("quantile", 0.5), id="quantile"), - param(methodcaller("median"), id="median"), + param( + methodcaller("median"), + id="median", + marks=[ + pytest.mark.notyet( + ["oracle"], + raises=OracleDatabaseError, + reason="doesn't support median of strings", + ) + ], + ), ], ) def test_string_quantile(alltypes, func): @@ -1204,9 +1197,6 @@ def test_string_quantile(alltypes, func): param( methodcaller("quantile", 0.5), id="quantile", - marks=[ - pytest.mark.notimpl(["oracle"], raises=com.OperationNotDefinedError) - ], ), ], ) @@ -1264,7 +1254,7 @@ def test_date_quantile(alltypes, func): ) @pytest.mark.notyet( ["oracle"], - raises=sa.exc.DatabaseError, + raises=OracleDatabaseError, reason="ORA-00904: 'GROUP_CONCAT': invalid identifier", ) @pytest.mark.notimpl(["exasol"], raises=ExaQueryError) @@ -1535,7 +1525,7 @@ def test_grouped_case(backend, con): @pytest.mark.notyet(["snowflake"], raises=SnowflakeProgrammingError) @pytest.mark.notyet(["trino"], raises=TrinoUserError) @pytest.mark.notyet(["mysql"], raises=MySQLNotSupportedError) -@pytest.mark.notyet(["oracle"], raises=sa.exc.DatabaseError) +@pytest.mark.notyet(["oracle"], raises=OracleDatabaseError) @pytest.mark.notyet(["pyspark"], raises=PySparkAnalysisException) def test_group_concat_over_window(backend, con): input_df = pd.DataFrame( diff --git a/ibis/backends/tests/test_asof_join.py 
b/ibis/backends/tests/test_asof_join.py index 3b71ebe88346..c10fe79dbff6 100644 --- a/ibis/backends/tests/test_asof_join.py +++ b/ibis/backends/tests/test_asof_join.py @@ -92,6 +92,7 @@ def time_keyed_right(time_keyed_df2): "impala", "bigquery", "exasol", + "oracle", ] ) def test_asof_join(con, time_left, time_right, time_df1, time_df2, direction, op): @@ -129,6 +130,7 @@ def test_asof_join(con, time_left, time_right, time_df1, time_df2, direction, op "impala", "bigquery", "exasol", + "oracle", ] ) def test_keyed_asof_join_with_tolerance( diff --git a/ibis/backends/tests/test_binary.py b/ibis/backends/tests/test_binary.py index b67818980c51..0a5790c64631 100644 --- a/ibis/backends/tests/test_binary.py +++ b/ibis/backends/tests/test_binary.py @@ -20,7 +20,7 @@ @pytest.mark.notimpl( - ["clickhouse", "impala", "druid"], + ["clickhouse", "impala", "druid", "oracle"], "Unsupported type: Binary(nullable=True)", raises=NotImplementedError, ) diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 9b0cd8e3a9b0..9e04c4f7d3a9 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -1379,6 +1379,10 @@ def test_persist_expression_repeated_cache(alltypes): reason="mssql supports temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@mark.notimpl( + ["oracle"], + reason="Oracle error message for a missing table/view doesn't include the name of the table", +) def test_persist_expression_release(con, alltypes): non_cached_table = alltypes.mutate( test_column="calculation", other_column="big calc 3" ) diff --git a/ibis/backends/tests/test_dot_sql.py b/ibis/backends/tests/test_dot_sql.py index 765baf9ad6e9..df67a506769c 100644 --- a/ibis/backends/tests/test_dot_sql.py +++ b/ibis/backends/tests/test_dot_sql.py @@ -29,6 +29,7 @@ @pytest.mark.notimpl(["flink"]) +@pytest.mark.notyet(["oracle"], reason="table quoting behavior") @dot_sql_never @pytest.mark.parametrize( "schema", diff --git a/ibis/backends/tests/test_export.py b/ibis/backends/tests/test_export.py index a1cf917dc0b4..4c57128b1864 100644 --- a/ibis/backends/tests/test_export.py +++ b/ibis/backends/tests/test_export.py @@ -15,6 +15,7 @@ DuckDBParserException, ExaQueryError, MySQLOperationalError, + OracleDatabaseError, PyDeltaTableError, PyDruidProgrammingError, PySparkArithmeticException, @@ -359,7 +360,7 @@ def test_table_to_csv_writer_kwargs(delimiter, tmp_path, awards_players): pytest.mark.notyet(["mssql"], raises=sa.exc.ProgrammingError), pytest.mark.notyet(["snowflake"], raises=SnowflakeProgrammingError), pytest.mark.notyet(["trino"], raises=TrinoUserError), - pytest.mark.notyet(["oracle"], raises=sa.exc.DatabaseError), + pytest.mark.notyet(["oracle"], raises=OracleDatabaseError), pytest.mark.notyet(["mysql"], raises=MySQLOperationalError), pytest.mark.notyet( ["pyspark"], diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index ca26588ef9d9..16c110d44759 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -25,6 +25,7 @@ GoogleBadRequest, ImpalaHiveServer2Error, MySQLProgrammingError, + OracleDatabaseError, PyDruidProgrammingError, SnowflakeProgrammingError, TrinoUserError, @@ -122,7 +123,7 @@ def test_scalar_fillna_nullif(con, expr, expected): ), ], ) -@pytest.mark.notimpl(["mssql", "oracle"]) +@pytest.mark.notimpl(["mssql"]) @pytest.mark.notyet(["flink"], "NaN is not supported in Flink SQL", raises=ValueError) def
test_isna(backend, alltypes, col, value, filt): table = alltypes.select(**{col: value}) @@ -364,7 +365,7 @@ def test_case_where(backend, alltypes, df): # TODO: some of these are notimpl (datafusion) others are probably never -@pytest.mark.notimpl(["mysql", "sqlite", "mssql", "druid", "oracle", "exasol"]) +@pytest.mark.notimpl(["mysql", "sqlite", "mssql", "druid", "exasol"]) @pytest.mark.notyet(["flink"], "NaN is not supported in Flink SQL", raises=ValueError) def test_select_filter_mutate(backend, alltypes, df): """Test that select, filter and mutate are executed in right order. @@ -416,11 +417,7 @@ def test_table_fillna_invalid(alltypes): "replacements", [ param({"int_col": 20}, id="int"), - param( - {"double_col": -1, "string_col": "missing"}, - id="double-int-str", - marks=[pytest.mark.notimpl(["oracle"])], - ), + param({"double_col": -1, "string_col": "missing"}, id="double-int-str"), param({"double_col": -1.5, "string_col": "missing"}, id="double-str"), ], ) @@ -438,7 +435,6 @@ def test_table_fillna_mapping(backend, alltypes, replacements): backend.assert_frame_equal(result, expected, check_dtype=False) -@pytest.mark.notimpl(["oracle"]) def test_table_fillna_scalar(backend, alltypes): table = alltypes.mutate( int_col=alltypes.int_col.nullif(1), @@ -1102,7 +1098,11 @@ def test_pivot_wider(backend): ) @pytest.mark.notimpl( ["druid", "impala", "oracle"], - raises=(NotImplementedError, sa.exc.ProgrammingError, com.OperationNotDefinedError), + raises=( + NotImplementedError, + OracleDatabaseError, + com.OperationNotDefinedError, + ), reason="arbitrary not implemented in the backend", ) @pytest.mark.notimpl( @@ -1167,7 +1167,7 @@ def test_distinct_on_keep(backend, on, keep): ) @pytest.mark.notimpl( ["druid", "impala", "oracle"], - raises=(NotImplementedError, sa.exc.ProgrammingError, com.OperationNotDefinedError), + raises=(NotImplementedError, OracleDatabaseError, com.OperationNotDefinedError), reason="arbitrary not implemented in the backend", ) @pytest.mark.notimpl( @@ -1407,6 +1407,7 @@ def test_try_cast_func(con, from_val, to_type, func): raises=ExaQueryError, reason="doesn't support OFFSET without ORDER BY", ), + pytest.mark.notyet(["oracle"], raises=com.UnsupportedArgumentError), ], ), param( @@ -1430,6 +1431,7 @@ def test_try_cast_func(con, from_val, to_type, func): raises=ImpalaHiveServer2Error, reason="impala doesn't support OFFSET without ORDER BY", ), + pytest.mark.notyet(["oracle"], raises=com.UnsupportedArgumentError), ], ), # positive stop @@ -1448,6 +1450,7 @@ def test_try_cast_func(con, from_val, to_type, func): raises=ExaQueryError, reason="doesn't support OFFSET without ORDER BY", ), + pytest.mark.notyet(["oracle"], raises=com.UnsupportedArgumentError), ], ), param( @@ -1461,6 +1464,7 @@ def test_try_cast_func(con, from_val, to_type, func): reason="mssql doesn't support OFFSET without LIMIT", ), pytest.mark.notyet(["exasol"], raises=ExaQueryError), + pytest.mark.notyet(["oracle"], raises=com.UnsupportedArgumentError), pytest.mark.notyet( ["impala"], raises=ImpalaHiveServer2Error, @@ -1514,6 +1518,11 @@ def test_static_table_slice(backend, slc, expected_count_fn): raises=SnowflakeProgrammingError, reason="backend doesn't support dynamic limit/offset", ) +@pytest.mark.notyet( + ["oracle"], + raises=com.UnsupportedArgumentError, + reason="Removed half-baked dynamic offset functionality for now", +) @pytest.mark.notyet( ["trino"], raises=TrinoUserError, @@ -1569,6 +1578,11 @@ def test_dynamic_table_slice(backend, slc, expected_count_fn): raises=SnowflakeProgrammingError, 
reason="backend doesn't support dynamic limit/offset", ) +@pytest.mark.notyet( + ["oracle"], + raises=com.UnsupportedArgumentError, + reason="Removed half-baked dynamic offset functionality for now", +) @pytest.mark.notimpl( ["trino"], raises=TrinoUserError, diff --git a/ibis/backends/tests/test_join.py b/ibis/backends/tests/test_join.py index 2104893321ef..f1b79bdc2ef2 100644 --- a/ibis/backends/tests/test_join.py +++ b/ibis/backends/tests/test_join.py @@ -188,7 +188,7 @@ def test_semi_join_topk(batting, awards_players, func): assert not expr.limit(5).execute().empty -@pytest.mark.notimpl(["dask", "druid", "exasol"]) +@pytest.mark.notimpl(["dask", "druid", "exasol", "oracle"]) @pytest.mark.notimpl( ["postgres"], raises=com.IbisTypeError, diff --git a/ibis/backends/tests/test_numeric.py b/ibis/backends/tests/test_numeric.py index 673200f26b9e..11d8617d9fec 100644 --- a/ibis/backends/tests/test_numeric.py +++ b/ibis/backends/tests/test_numeric.py @@ -23,6 +23,7 @@ GoogleBadRequest, ImpalaHiveServer2Error, MySQLOperationalError, + OracleDatabaseError, PsycoPg2DivisionByZero, Py4JError, PyDruidProgrammingError, @@ -248,7 +249,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "mssql": 1.1, "druid": decimal.Decimal("1.1"), "datafusion": decimal.Decimal("1.1"), - "oracle": 1.1, + "oracle": decimal.Decimal("1.1"), "flink": decimal.Decimal("1.1"), }, { @@ -290,7 +291,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "mssql": 1.1, "druid": decimal.Decimal("1.1"), "datafusion": decimal.Decimal("1.1"), - "oracle": 1.1, + "oracle": decimal.Decimal("1.1"), "flink": decimal.Decimal("1.1"), }, { @@ -336,6 +337,7 @@ def test_numeric_literal(con, backend, expr, expected_types): pytest.mark.notimpl(["exasol"], raises=ExaQueryError), pytest.mark.notimpl(["mysql"], raises=MySQLOperationalError), pytest.mark.notyet(["snowflake"], raises=SnowflakeProgrammingError), + pytest.mark.notyet(["oracle"], raises=OracleDatabaseError), pytest.mark.notyet(["impala"], raises=ImpalaHiveServer2Error), pytest.mark.broken( ["duckdb"], @@ -412,7 +414,7 @@ def test_numeric_literal(con, backend, expr, expected_types): pytest.mark.broken( ["oracle"], "(oracledb.exceptions.DatabaseError) DPY-4004: invalid number", - raises=sa.exc.DatabaseError, + raises=OracleDatabaseError, ), pytest.mark.notyet( ["trino"], @@ -487,7 +489,7 @@ def test_numeric_literal(con, backend, expr, expected_types): pytest.mark.broken( ["oracle"], "(oracledb.exceptions.DatabaseError) DPY-4004: invalid number", - raises=sa.exc.DatabaseError, + raises=OracleDatabaseError, ), pytest.mark.notyet( ["flink"], @@ -574,7 +576,7 @@ def test_numeric_literal(con, backend, expr, expected_types): pytest.mark.broken( ["oracle"], "(oracledb.exceptions.DatabaseError) DPY-4004: invalid number", - raises=sa.exc.DatabaseError, + raises=OracleDatabaseError, ), pytest.mark.notyet( ["flink"], @@ -683,7 +685,7 @@ def test_decimal_literal(con, backend, expr, expected_types, expected_result): ], ) @pytest.mark.notimpl( - ["sqlite", "mssql", "oracle", "flink", "druid"], raises=com.OperationNotDefinedError + ["sqlite", "mssql", "flink", "druid"], raises=com.OperationNotDefinedError ) @pytest.mark.notimpl(["mysql"], raises=(MySQLOperationalError, NotImplementedError)) def test_isnan_isinf( @@ -1087,7 +1089,7 @@ def test_floating_mod(backend, alltypes, df): marks=[ pytest.mark.notyet( "oracle", - raises=(sa.exc.DatabaseError, sa.exc.ArgumentError), + raises=OracleDatabaseError, reason="Oracle doesn't do integer division by zero", ), pytest.mark.notyet( @@ 
-1103,7 +1105,7 @@ def test_floating_mod(backend, alltypes, df): marks=[ pytest.mark.notyet( "oracle", - raises=(sa.exc.DatabaseError, sa.exc.ArgumentError), + raises=OracleDatabaseError, reason="Oracle doesn't do integer division by zero", ), pytest.mark.notyet( @@ -1119,7 +1121,7 @@ def test_floating_mod(backend, alltypes, df): marks=[ pytest.mark.notyet( "oracle", - raises=(sa.exc.DatabaseError, sa.exc.ArgumentError), + raises=OracleDatabaseError, reason="Oracle doesn't do integer division by zero", ), pytest.mark.notyet( @@ -1135,7 +1137,7 @@ def test_floating_mod(backend, alltypes, df): marks=[ pytest.mark.notyet( "oracle", - raises=(sa.exc.DatabaseError, sa.exc.ArgumentError), + raises=OracleDatabaseError, reason="Oracle doesn't do integer division by zero", ), pytest.mark.notyet( @@ -1153,7 +1155,7 @@ def test_floating_mod(backend, alltypes, df): marks=[ pytest.mark.notyet( "oracle", - raises=(sa.exc.DatabaseError, sa.exc.ArgumentError), + raises=OracleDatabaseError, reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), @@ -1165,7 +1167,7 @@ def test_floating_mod(backend, alltypes, df): marks=[ pytest.mark.notyet( "oracle", - raises=(sa.exc.DatabaseError, sa.exc.ArgumentError), + raises=OracleDatabaseError, reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), @@ -1177,7 +1179,7 @@ def test_floating_mod(backend, alltypes, df): marks=[ pytest.mark.notyet( "oracle", - raises=(sa.exc.DatabaseError, sa.exc.ArgumentError), + raises=OracleDatabaseError, reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), @@ -1189,7 +1191,7 @@ def test_floating_mod(backend, alltypes, df): marks=[ pytest.mark.notyet( "oracle", - raises=(sa.exc.DatabaseError, sa.exc.ArgumentError), + raises=OracleDatabaseError, reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), @@ -1264,6 +1266,7 @@ def test_divide_by_zero(backend, alltypes, df, column, denominator): "datafusion", "duckdb", "impala", + "oracle", "pandas", "pyspark", "polars", @@ -1401,7 +1404,14 @@ def test_constants(con, const): ) -@pytest.mark.parametrize("op", [and_, or_, xor]) +@pytest.mark.parametrize( + "op", + [ + and_, + param(or_, marks=[pytest.mark.notimpl(["oracle"], raises=OracleDatabaseError)]), + param(xor, marks=[pytest.mark.notimpl(["oracle"], raises=OracleDatabaseError)]), + ], +) @pytest.mark.parametrize( ("left_fn", "right_fn"), [ @@ -1410,7 +1420,6 @@ def test_constants(con, const): param(lambda t: t.int_col, lambda _: 3, id="col_scalar"), ], ) -@pytest.mark.notimpl(["oracle"], raises=sa.exc.DatabaseError) @pytest.mark.notimpl(["exasol"], raises=(sa.exc.DBAPIError, ExaQueryError)) @flink_no_bitwise def test_bitwise_columns(backend, con, alltypes, df, op, left_fn, right_fn): @@ -1447,7 +1456,7 @@ def test_bitwise_columns(backend, con, alltypes, df, op, left_fn, right_fn): param(rshift, lambda t: t.int_col, lambda _: 3, id="rshift_col_scalar"), ], ) -@pytest.mark.notimpl(["oracle"], raises=sa.exc.DatabaseError) +@pytest.mark.notimpl(["oracle"], raises=OracleDatabaseError) @pytest.mark.notimpl(["exasol"], raises=(sa.exc.DBAPIError, ExaQueryError)) @flink_no_bitwise def test_bitwise_shift(backend, alltypes, df, op, left_fn, right_fn): @@ -1466,13 +1475,30 @@ def test_bitwise_shift(backend, alltypes, df, op, left_fn, right_fn): @pytest.mark.parametrize( 
"op", - [and_, or_, xor, lshift, rshift], + [ + and_, + param( + or_, + marks=pytest.mark.notimpl(["oracle"], raises=OracleDatabaseError), + ), + param( + xor, + marks=pytest.mark.notimpl(["oracle"], raises=OracleDatabaseError), + ), + param( + lshift, + marks=[pytest.mark.notimpl(["oracle"], raises=OracleDatabaseError)], + ), + param( + rshift, + marks=[pytest.mark.notimpl(["oracle"], raises=OracleDatabaseError)], + ), + ], ) @pytest.mark.parametrize( ("left", "right"), [param(4, L(2), id="int_col"), param(L(4), 2, id="col_int")], ) -@pytest.mark.notimpl(["oracle"], raises=sa.exc.DatabaseError) @pytest.mark.notimpl(["exasol"], raises=ExaQueryError) @flink_no_bitwise def test_bitwise_scalars(con, op, left, right): @@ -1483,7 +1509,7 @@ def test_bitwise_scalars(con, op, left, right): @pytest.mark.notimpl(["datafusion", "exasol"], raises=com.OperationNotDefinedError) -@pytest.mark.notimpl(["oracle"], raises=sa.exc.DatabaseError) +@pytest.mark.notimpl(["oracle"], raises=OracleDatabaseError) @flink_no_bitwise def test_bitwise_not_scalar(con): expr = ~L(2) @@ -1493,7 +1519,7 @@ def test_bitwise_not_scalar(con): @pytest.mark.notimpl(["datafusion", "exasol"], raises=com.OperationNotDefinedError) -@pytest.mark.notimpl(["oracle"], raises=sa.exc.DatabaseError) +@pytest.mark.notimpl(["oracle"], raises=OracleDatabaseError) @flink_no_bitwise def test_bitwise_not_col(backend, alltypes, df): expr = (~alltypes.int_col).name("tmp") diff --git a/ibis/backends/tests/test_param.py b/ibis/backends/tests/test_param.py index e5c7270e1e6f..3947b1af201d 100644 --- a/ibis/backends/tests/test_param.py +++ b/ibis/backends/tests/test_param.py @@ -6,13 +6,12 @@ import numpy as np import pandas as pd import pytest -import sqlalchemy as sa from pytest import param import ibis import ibis.expr.datatypes as dt from ibis import _ -from ibis.backends.tests.errors import Py4JJavaError +from ibis.backends.tests.errors import OracleDatabaseError, Py4JJavaError @pytest.mark.parametrize( @@ -38,7 +37,7 @@ def test_floating_scalar_parameter(backend, alltypes, df, column, raw_value): [("2009-03-01", "2010-07-03"), ("2014-12-01", "2017-01-05")], ) @pytest.mark.notimpl(["mssql", "trino", "druid"]) -@pytest.mark.broken(["oracle"], raises=sa.exc.DatabaseError) +@pytest.mark.broken(["oracle"], raises=OracleDatabaseError) def test_date_scalar_parameter(backend, alltypes, start_string, end_string): start, end = ibis.param(dt.date), ibis.param(dt.date) diff --git a/ibis/backends/tests/test_sql.py b/ibis/backends/tests/test_sql.py index ea8dd9ccacb2..dfb272790c21 100644 --- a/ibis/backends/tests/test_sql.py +++ b/ibis/backends/tests/test_sql.py @@ -18,7 +18,7 @@ ibis.array([1]), marks=[ pytest.mark.never( - ["mssql", "oracle"], + ["mssql"], raises=sa.exc.CompileError, reason="arrays not supported in the backend", ), @@ -51,7 +51,7 @@ reason="structs not supported in the backend", ) no_struct_literals = pytest.mark.notimpl( - ["mssql", "oracle"], reason="struct literals are not yet implemented" + ["mssql"], reason="struct literals are not yet implemented" ) not_sql = pytest.mark.never( ["pandas", "dask"], diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index 9a7ec1052714..117fde198996 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -11,7 +11,11 @@ import ibis import ibis.common.exceptions as com import ibis.expr.datatypes as dt -from ibis.backends.tests.errors import ClickHouseDatabaseError, PyDruidProgrammingError +from ibis.backends.tests.errors import ( + 
ClickHouseDatabaseError, + OracleDatabaseError, + PyDruidProgrammingError, +) from ibis.common.annotations import ValidationError @@ -49,7 +53,7 @@ id="string-quote1", marks=pytest.mark.broken( ["oracle"], - raises=sa.exc.DatabaseError, + raises=OracleDatabaseError, reason="ORA-01741: illegal zero length identifier", ), ), @@ -69,7 +73,7 @@ id="string-quote2", marks=pytest.mark.broken( ["oracle"], - raises=sa.exc.DatabaseError, + raises=OracleDatabaseError, reason="ORA-25716", ), ), @@ -185,7 +189,7 @@ def uses_java_re(t): id="rlike", marks=[ pytest.mark.notimpl( - ["mssql", "oracle", "exasol"], raises=com.OperationNotDefinedError + ["mssql", "exasol"], raises=com.OperationNotDefinedError ), ], ), @@ -195,7 +199,7 @@ def uses_java_re(t): id="re_search_substring", marks=[ pytest.mark.notimpl( - ["mssql", "oracle", "exasol"], raises=com.OperationNotDefinedError + ["mssql", "exasol"], raises=com.OperationNotDefinedError ), ], ), @@ -205,7 +209,7 @@ def uses_java_re(t): id="re_search", marks=[ pytest.mark.notimpl( - ["mssql", "oracle", "exasol"], raises=com.OperationNotDefinedError + ["mssql", "exasol"], raises=com.OperationNotDefinedError ), ], ), @@ -217,7 +221,7 @@ def uses_java_re(t): id="re_search_posix", marks=[ pytest.mark.notimpl( - ["mssql", "oracle", "exasol"], + ["mssql", "exasol"], raises=com.OperationNotDefinedError, ), pytest.mark.never( @@ -233,7 +237,7 @@ def uses_java_re(t): id="re_extract", marks=[ pytest.mark.notimpl( - ["mssql", "oracle", "exasol"], raises=com.OperationNotDefinedError + ["mssql", "exasol"], raises=com.OperationNotDefinedError ), ], ), @@ -243,7 +247,7 @@ def uses_java_re(t): id="re_extract_group", marks=[ pytest.mark.notimpl( - ["mssql", "oracle", "exasol"], raises=com.OperationNotDefinedError + ["mssql", "exasol"], raises=com.OperationNotDefinedError ), ], ), @@ -255,7 +259,7 @@ def uses_java_re(t): id="re_extract_posix", marks=[ pytest.mark.notimpl( - ["mssql", "oracle", "exasol"], raises=com.OperationNotDefinedError + ["mssql", "exasol"], raises=com.OperationNotDefinedError ), pytest.mark.notimpl( ["druid"], reason="No posix support", raises=AssertionError @@ -268,7 +272,7 @@ def uses_java_re(t): id="re_extract_whole_group", marks=[ pytest.mark.notimpl( - ["mssql", "oracle", "exasol"], raises=com.OperationNotDefinedError + ["mssql", "exasol"], raises=com.OperationNotDefinedError ), ], ), @@ -280,7 +284,7 @@ def uses_java_re(t): id="re_extract_group_1", marks=[ pytest.mark.notimpl( - ["mssql", "oracle", "exasol"], raises=com.OperationNotDefinedError + ["mssql", "exasol"], raises=com.OperationNotDefinedError ), ], ), @@ -292,7 +296,7 @@ def uses_java_re(t): id="re_extract_group_2", marks=[ pytest.mark.notimpl( - ["mssql", "oracle", "exasol"], raises=com.OperationNotDefinedError + ["mssql", "exasol"], raises=com.OperationNotDefinedError ), ], ), @@ -304,7 +308,7 @@ def uses_java_re(t): id="re_extract_group_3", marks=[ pytest.mark.notimpl( - ["mssql", "oracle", "exasol"], raises=com.OperationNotDefinedError + ["mssql", "exasol"], raises=com.OperationNotDefinedError ), ], ), @@ -314,7 +318,7 @@ def uses_java_re(t): id="re_extract_group_at_beginning", marks=[ pytest.mark.notimpl( - ["mssql", "oracle", "exasol"], raises=com.OperationNotDefinedError + ["mssql", "exasol"], raises=com.OperationNotDefinedError ), ], ), @@ -324,7 +328,7 @@ def uses_java_re(t): id="re_extract_group_at_end", marks=[ pytest.mark.notimpl( - ["mssql", "oracle", "exasol"], raises=com.OperationNotDefinedError + ["mssql", "exasol"], raises=com.OperationNotDefinedError ), ], ), @@ -336,7 
+340,7 @@ def uses_java_re(t): id="re_replace_posix", marks=[ pytest.mark.notimpl( - ["mysql", "mssql", "druid", "oracle", "exasol"], + ["mysql", "mssql", "druid", "exasol"], raises=com.OperationNotDefinedError, ), ], @@ -347,7 +351,7 @@ def uses_java_re(t): id="re_replace", marks=[ pytest.mark.notimpl( - ["mysql", "mssql", "druid", "oracle", "exasol"], + ["mysql", "mssql", "druid", "exasol"], raises=com.OperationNotDefinedError, ), ], @@ -358,7 +362,7 @@ def uses_java_re(t): id="repeat_method", marks=pytest.mark.notimpl( ["oracle"], - raises=sa.exc.DatabaseError, + raises=OracleDatabaseError, reason="ORA-00904: REPEAT invalid identifier", ), ), @@ -368,7 +372,7 @@ def uses_java_re(t): id="repeat_left", marks=pytest.mark.notimpl( ["oracle"], - raises=sa.exc.DatabaseError, + raises=OracleDatabaseError, reason="ORA-00904: REPEAT invalid identifier", ), ), @@ -378,7 +382,7 @@ def uses_java_re(t): id="repeat_right", marks=pytest.mark.notimpl( ["oracle"], - raises=sa.exc.DatabaseError, + raises=OracleDatabaseError, reason="ORA-00904: REPEAT invalid identifier", ), ), @@ -388,7 +392,7 @@ def uses_java_re(t): id="translate", marks=[ pytest.mark.notimpl( - ["mssql", "mysql", "polars", "druid", "oracle"], + ["mssql", "mysql", "polars", "druid"], raises=com.OperationNotDefinedError, ), pytest.mark.notyet( @@ -763,7 +767,7 @@ def test_string(backend, alltypes, df, result_func, expected_func): @pytest.mark.notimpl( - ["mysql", "mssql", "druid", "oracle", "exasol"], + ["mysql", "mssql", "druid", "exasol"], raises=com.OperationNotDefinedError, ) @pytest.mark.broken( @@ -779,11 +783,6 @@ def test_re_replace_global(con): @pytest.mark.broken(["mssql"], raises=sa.exc.ProgrammingError) @pytest.mark.notimpl(["druid"], raises=ValidationError) -@pytest.mark.broken( - ["oracle"], - raises=sa.exc.DatabaseError, - reason="ORA-61801: only boolean column or attribute can be used as a predicate", -) def test_substr_with_null_values(backend, alltypes, df): table = alltypes.mutate( substr_col_null=ibis.case() @@ -917,7 +916,7 @@ def test_array_string_join(con): @pytest.mark.notimpl( - ["mssql", "mysql", "druid", "oracle", "exasol"], raises=com.OperationNotDefinedError + ["mssql", "mysql", "druid", "exasol"], raises=com.OperationNotDefinedError ) def test_subs_with_re_replace(con): expr = ibis.literal("hi").re_replace("i", "a").substitute({"d": "b"}, else_="k") @@ -964,11 +963,6 @@ def test_levenshtein(con, right): reason="doesn't allow boolean expressions in select statements", raises=sa.exc.ProgrammingError, ) -@pytest.mark.broken( - ["oracle"], - reason="sqlalchemy converts True to 1, which cannot be used in CASE WHEN statement", - raises=sa.exc.DatabaseError, -) @pytest.mark.parametrize( "expr", [ @@ -980,9 +974,7 @@ def test_no_conditional_percent_escape(con, expr): assert con.execute(expr) == "%" -@pytest.mark.notimpl( - ["dask", "mssql", "oracle", "exasol"], raises=com.OperationNotDefinedError -) +@pytest.mark.notimpl(["dask", "mssql", "exasol"], raises=com.OperationNotDefinedError) def test_non_match_regex_search_is_false(con): expr = ibis.literal("foo").re_search("bar") result = con.execute(expr) diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index 4f3468b93d00..ebac084a9607 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -27,6 +27,7 @@ ImpalaOperationalError, MySQLOperationalError, MySQLProgrammingError, + OracleDatabaseError, PolarsComputeError, PolarsPanicException, Py4JJavaError, @@ -74,7 +75,7 @@ def 
test_date_extract(backend, alltypes, df, attr, expr_fn): param( "quarter", marks=[ - pytest.mark.notyet(["oracle"], raises=sa.exc.DatabaseError), + pytest.mark.notyet(["oracle"], raises=OracleDatabaseError), pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError), ], ), @@ -314,7 +315,6 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): param( "W", marks=[ - pytest.mark.broken(["sqlite", "exasol"], raises=AssertionError), pytest.mark.notimpl(["mysql"], raises=com.UnsupportedOperationError), pytest.mark.broken( ["polars"], @@ -449,7 +449,7 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): reason="attempt to calculate the remainder with a divisor of zero", ), pytest.mark.notimpl( - ["flink", "exasol"], + ["flink"], raises=com.UnsupportedOperationError, reason=" unit is not supported in timestamp truncate", ), @@ -457,12 +457,12 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): ), ], ) -@pytest.mark.notimpl(["oracle", "exasol"], raises=com.OperationNotDefinedError) @pytest.mark.broken( ["druid"], raises=AttributeError, reason="AttributeError: 'StringColumn' object has no attribute 'truncate'", ) +@pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) def test_timestamp_truncate(backend, alltypes, df, unit): expr = alltypes.timestamp_col.truncate(unit).name("tmp") @@ -545,10 +545,6 @@ def test_timestamp_truncate(backend, alltypes, df, unit): @pytest.mark.broken( ["polars", "druid"], reason="snaps to the UNIX epoch", raises=AssertionError ) -@pytest.mark.notimpl( - ["oracle"], - raises=com.OperationNotDefinedError, -) @pytest.mark.broken( ["druid"], raises=AttributeError, @@ -977,11 +973,16 @@ def convert_to_offset(x): raises=Exception, reason="pyarrow.lib.ArrowNotImplementedError: Unsupported cast", ), + pytest.mark.broken( + ["oracle"], + raises=com.OperationNotDefinedError, + reason="Some wonkiness in sqlglot generation.", + ), ], ), ], ) -@pytest.mark.notimpl(["mssql", "oracle"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["mssql"], raises=com.OperationNotDefinedError) def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn): expr = expr_fn(alltypes, backend).name("tmp") expected = expected_fn(df, backend) @@ -1016,11 +1017,6 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn): raises=AssertionError, reason="DateTime column overflows, should use DateTime64", ), - pytest.mark.broken( - ["clickhouse"], - raises=AssertionError, - reason="DateTime column overflows, should use DateTime64", - ), pytest.mark.notimpl( ["flink"], # Note (mehmet): Following cannot be imported for backends other than Flink. @@ -1105,11 +1101,6 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn): raises=AssertionError, reason="DateTime column overflows, should use DateTime64", ), - pytest.mark.broken( - ["clickhouse"], - raises=AssertionError, - reason="DateTime column overflows, should use DateTime64", - ), pytest.mark.broken( ["flink"], # Note (mehmet): Following cannot be imported for backends other than Flink. 
@@ -1181,7 +1172,7 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn): ), ], ) -@pytest.mark.notimpl(["sqlite", "mssql", "oracle"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["sqlite", "mssql"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) def test_temporal_binop_pandas_timedelta( backend, con, alltypes, df, timedelta, temporal_fn @@ -1350,7 +1341,7 @@ def test_timestamp_comparison_filter_numpy(backend, con, alltypes, df, func_name @pytest.mark.notimpl( - ["sqlite", "snowflake", "mssql", "oracle", "exasol"], + ["sqlite", "snowflake", "mssql", "exasol"], raises=com.OperationNotDefinedError, ) @pytest.mark.broken( @@ -1373,7 +1364,7 @@ def test_interval_add_cast_scalar(backend, alltypes): @pytest.mark.notimpl( - ["sqlite", "snowflake", "mssql", "oracle", "exasol"], + ["sqlite", "snowflake", "mssql", "exasol"], raises=com.OperationNotDefinedError, ) @pytest.mark.notimpl( @@ -1426,9 +1417,7 @@ def test_interval_add_cast_column(backend, alltypes, df): ), ], ) -@pytest.mark.notimpl( - ["datafusion", "mssql", "oracle"], raises=com.OperationNotDefinedError -) +@pytest.mark.notimpl(["datafusion", "mssql"], raises=com.OperationNotDefinedError) @pytest.mark.broken( ["druid"], raises=AttributeError, @@ -1609,7 +1598,6 @@ def test_integer_to_timestamp(backend, con, unit): "datafusion", "mssql", "druid", - "oracle", ], raises=com.OperationNotDefinedError, ) @@ -1783,7 +1771,7 @@ def test_now_from_projection(alltypes): ["druid"], raises=PyDruidProgrammingError, reason="SQL parse failed" ) @pytest.mark.notimpl( - ["oracle"], raises=sa.exc.DatabaseError, reason="ORA-00936 missing expression" + ["oracle"], raises=OracleDatabaseError, reason="ORA-00936 missing expression" ) def test_date_literal(con, backend): expr = ibis.date(2022, 2, 4) @@ -1809,12 +1797,9 @@ def test_date_literal(con, backend): @pytest.mark.notimpl( - ["pandas", "dask", "pyspark", "mysql", "exasol"], + ["pandas", "dask", "pyspark", "mysql", "exasol", "oracle"], raises=com.OperationNotDefinedError, ) -@pytest.mark.notimpl( - ["oracle"], raises=sa.exc.DatabaseError, reason="ORA-00904: MAKE TIMESTAMP invalid" -) @pytest.mark.notyet(["impala"], raises=com.OperationNotDefinedError) def test_timestamp_literal(con, backend): expr = ibis.timestamp(2022, 2, 4, 16, 20, 0) @@ -1840,10 +1825,7 @@ def test_timestamp_literal(con, backend): "Timestamp(timezone='***', scale=None, nullable=True)." 
), ) -@pytest.mark.notyet(["impala"], raises=com.OperationNotDefinedError) -@pytest.mark.notimpl( - ["oracle"], raises=sa.exc.DatabaseError, reason="ORA-00904: MAKE TIMESTAMP invalid" -) +@pytest.mark.notyet(["impala", "oracle"], raises=com.OperationNotDefinedError) @pytest.mark.parametrize( ("timezone", "expected"), [ @@ -1895,13 +1877,12 @@ def test_timestamp_with_timezone_literal(con, timezone, expected): @pytest.mark.notimpl( - ["pandas", "datafusion", "dask", "pyspark", "polars", "mysql"], + ["pandas", "datafusion", "dask", "pyspark", "polars", "mysql", "oracle"], raises=com.OperationNotDefinedError, ) @pytest.mark.notyet( ["clickhouse", "impala", "exasol"], raises=com.OperationNotDefinedError ) -@pytest.mark.notimpl(["oracle"], raises=sa.exc.DatabaseError) @pytest.mark.notimpl(["druid"], raises=com.OperationNotDefinedError) def test_time_literal(con, backend): expr = ibis.time(16, 20, 0) @@ -1929,7 +1910,7 @@ def test_time_literal(con, backend): ["sqlite"], raises=AssertionError, reason="SQLite returns Timedelta from execution" ) @pytest.mark.notimpl(["dask"], raises=com.OperationNotDefinedError) -@pytest.mark.notyet(["oracle"], raises=sa.exc.DatabaseError) +@pytest.mark.notyet(["oracle"], raises=OracleDatabaseError) @pytest.mark.parametrize( "microsecond", [ @@ -2044,7 +2025,7 @@ def test_interval_literal(con, backend): reason="'StringColumn' object has no attribute 'year'", ) @pytest.mark.broken( - ["oracle"], raises=sa.exc.DatabaseError, reason="ORA-00936: missing expression" + ["oracle"], raises=OracleDatabaseError, reason="ORA-00936: missing expression" ) def test_date_column_from_ymd(backend, con, alltypes, df): c = alltypes.timestamp_col @@ -2065,10 +2046,7 @@ def test_date_column_from_ymd(backend, con, alltypes, df): raises=AttributeError, reason="StringColumn' object has no attribute 'year'", ) -@pytest.mark.notimpl( - ["oracle"], raises=sa.exc.DatabaseError, reason="ORA-00904 make timestamp invalid" -) -@pytest.mark.notyet(["impala"], raises=com.OperationNotDefinedError) +@pytest.mark.notyet(["impala", "oracle"], raises=com.OperationNotDefinedError) def test_timestamp_column_from_ymdhms(backend, con, alltypes, df): c = alltypes.timestamp_col expr = ibis.timestamp( @@ -2082,7 +2060,7 @@ def test_timestamp_column_from_ymdhms(backend, con, alltypes, df): @pytest.mark.notimpl( - ["oracle"], raises=sa.exc.DatabaseError, reason="ORA-01861 literal does not match" + ["oracle"], raises=OracleDatabaseError, reason="ORA-01861 literal does not match" ) def test_date_scalar_from_iso(con): expr = ibis.literal("2022-02-24") @@ -2095,7 +2073,7 @@ def test_date_scalar_from_iso(con): @pytest.mark.notimpl(["mssql"], raises=com.OperationNotDefinedError) @pytest.mark.notyet( ["oracle"], - raises=sa.exc.DatabaseError, + raises=OracleDatabaseError, reason="ORA-22849 type CLOB is not supported", ) @pytest.mark.notimpl(["exasol"], raises=AssertionError, strict=False) @@ -2128,7 +2106,11 @@ def test_timestamp_extract_milliseconds_with_big_value(con): raises=Exception, reason="Unsupported CAST from Int32 to Timestamp(Nanosecond, None)", ) -@pytest.mark.notimpl(["oracle"], raises=sa.exc.DatabaseError, reason="ORA-00932") +@pytest.mark.notimpl( + ["oracle"], + raises=OracleDatabaseError, + reason="ORA-00932", +) @pytest.mark.notimpl(["exasol"], raises=ExaQueryError) def test_integer_cast_to_timestamp_column(backend, alltypes, df): expr = alltypes.int_col.cast("timestamp") @@ -2137,8 +2119,8 @@ def test_integer_cast_to_timestamp_column(backend, alltypes, df): backend.assert_series_equal(result, 
expected.astype(result.dtype)) -@pytest.mark.notimpl(["oracle"], raises=sa.exc.DatabaseError) @pytest.mark.notimpl(["exasol"], raises=ExaQueryError) +@pytest.mark.notimpl(["oracle"], raises=OracleDatabaseError) def test_integer_cast_to_timestamp_scalar(alltypes, df): expr = alltypes.int_col.min().cast("timestamp") result = expr.execute() @@ -2185,7 +2167,7 @@ def build_date_col(t): @pytest.mark.notimpl(["mssql"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError) -@pytest.mark.notimpl(["oracle"], raises=sa.exc.DatabaseError) +@pytest.mark.notimpl(["oracle"], raises=OracleDatabaseError) @pytest.mark.parametrize( ("left_fn", "right_fn"), [ @@ -2321,7 +2303,7 @@ def test_large_timestamp(con): ) @pytest.mark.notimpl( ["oracle"], - raises=sa.exc.DatabaseError, + raises=OracleDatabaseError, reason="ORA-01843: invalid month was specified", ) def test_timestamp_precision_output(con, ts, scale, unit): @@ -2403,7 +2385,6 @@ def test_delta(con, start, end, unit, expected): "flink", "impala", "mysql", - "oracle", "pandas", "pyspark", "sqlite", @@ -2438,6 +2419,11 @@ def test_delta(con, start, end, unit, expected): raises=com.UnsupportedOperationError, reason="backend doesn't support sub-second interval precision", ), + pytest.mark.notimpl( + ["oracle"], + raises=com.UnsupportedOperationError, + reason="backend doesn't support sub-second interval precision", + ), ], id="milliseconds", ), @@ -2445,7 +2431,10 @@ def test_delta(con, start, end, unit, expected): {"seconds": 2}, "2s", marks=[ - pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) + pytest.mark.notimpl( + ["datafusion", "oracle"], + raises=com.OperationNotDefinedError, + ), ], id="seconds", ), @@ -2453,7 +2442,10 @@ def test_delta(con, start, end, unit, expected): {"minutes": 5}, "300s", marks=[ - pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) + pytest.mark.notimpl( + ["datafusion", "oracle"], + raises=com.OperationNotDefinedError, + ), ], id="minutes", ), @@ -2461,7 +2453,10 @@ def test_delta(con, start, end, unit, expected): {"hours": 2}, "2h", marks=[ - pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) + pytest.mark.notimpl( + ["datafusion", "oracle"], + raises=com.OperationNotDefinedError, + ), ], id="hours", ), @@ -2469,7 +2464,10 @@ def test_delta(con, start, end, unit, expected): {"days": 2}, "2D", marks=[ - pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) + pytest.mark.notimpl( + ["datafusion", "oracle"], + raises=com.OperationNotDefinedError, + ), ], id="days", ), diff --git a/ibis/backends/tests/test_window.py b/ibis/backends/tests/test_window.py index ce93febc7bab..6ef7c950c6eb 100644 --- a/ibis/backends/tests/test_window.py +++ b/ibis/backends/tests/test_window.py @@ -18,6 +18,7 @@ GoogleBadRequest, ImpalaHiveServer2Error, MySQLOperationalError, + OracleDatabaseError, Py4JJavaError, PyDruidProgrammingError, SnowflakeProgrammingError, @@ -256,7 +257,6 @@ def calc_zscore(s): id="cumany", marks=[ pytest.mark.notimpl(["dask"], raises=NotImplementedError), - pytest.mark.broken(["oracle"], raises=sa.exc.DatabaseError), ], ), param( @@ -270,7 +270,7 @@ def calc_zscore(s): id="cumnotany", marks=[ pytest.mark.notimpl(["dask"], raises=NotImplementedError), - pytest.mark.broken(["oracle"], raises=sa.exc.DatabaseError), + pytest.mark.broken(["oracle"], raises=OracleDatabaseError), ], ), param( @@ -284,7 +284,6 @@ def calc_zscore(s): id="cumall", marks=[ pytest.mark.notimpl(["dask"], 
raises=NotImplementedError), - pytest.mark.broken(["oracle"], raises=sa.exc.DatabaseError), ], ), param( @@ -298,7 +297,7 @@ def calc_zscore(s): id="cumnotall", marks=[ pytest.mark.notimpl(["dask"], raises=NotImplementedError), - pytest.mark.broken(["oracle"], raises=sa.exc.DatabaseError), + pytest.mark.broken(["oracle"], raises=OracleDatabaseError), ], ), param( @@ -780,7 +779,11 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): raises=AssertionError, strict=False, # sometimes it passes ), - pytest.mark.broken(["oracle"], raises=AssertionError), + pytest.mark.notyet( + ["oracle"], + raises=com.UnsupportedOperationError, + reason="oracle doesn't allow unordered analytic functions without a windowing clause", + ), pytest.mark.notimpl( ["flink"], raises=com.UnsupportedOperationError, @@ -815,7 +818,11 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): raises=AssertionError, strict=False, # sometimes it passes ), - pytest.mark.broken(["oracle"], raises=AssertionError), + pytest.mark.notyet( + ["oracle"], + raises=com.UnsupportedOperationError, + reason="oracle doesn't allow unordered analytic functions without a windowing clause", + ), pytest.mark.notimpl( ["flink"], raises=com.UnsupportedOperationError, @@ -1106,7 +1113,7 @@ def test_first_last(backend): ["mysql"], raises=MySQLOperationalError, reason="not supported by MySQL" ) @pytest.mark.notyet( - ["mssql", "oracle", "polars", "snowflake", "sqlite"], + ["mssql", "polars", "snowflake", "sqlite"], raises=com.OperationNotDefinedError, reason="not supported by the backend", )
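
A note on the recurring pattern in these test diffs: every Oracle xfail mark that previously pointed at `sa.exc.DatabaseError` now points at `OracleDatabaseError`, imported from `ibis/backends/tests/errors.py`, so the suite asserts against the `oracledb` driver's own exception rather than SQLAlchemy's wrapper. The contents of `errors.py` are not part of this diff, so the following is only a minimal sketch of how such a driver-specific alias can be declared; the guarded import and the `None` fallback are assumptions about that module, not shown here.

# Hypothetical sketch of an entry in ibis/backends/tests/errors.py.
# Assumption: the import is guarded so that pytest can still collect the
# test suite when the optional `oracledb` driver is not installed.
try:
    from oracledb.exceptions import DatabaseError as OracleDatabaseError
except ImportError:
    OracleDatabaseError = None  # placeholder when the driver is absent

With an alias like this in place, marks such as `pytest.mark.notyet(["oracle"], raises=OracleDatabaseError, reason="ORA-00936 missing expression")` pin an expected failure to the exact driver exception instead of a generic SQLAlchemy error class.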