diff --git a/ibis/backends/base/sqlglot/__init__.py b/ibis/backends/base/sqlglot/__init__.py index bcfca1039e49..ad0b4eed1a02 100644 --- a/ibis/backends/base/sqlglot/__init__.py +++ b/ibis/backends/base/sqlglot/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib from typing import TYPE_CHECKING, Any, ClassVar import sqlglot as sg @@ -113,3 +114,86 @@ def sql( def _get_schema_using_query(self, query: str) -> sch.Schema: """Return an ibis Schema from a backend-specific SQL string.""" return sch.Schema.from_tuples(self._metadata(query)) + + def create_view( + self, + name: str, + obj: ir.Table, + *, + database: str | None = None, + overwrite: bool = False, + ) -> ir.Table: + src = sg.exp.Create( + this=sg.table(name, db=database), + kind="VIEW", + replace=overwrite, + expression=self._to_sqlglot(obj), + ) + self._register_in_memory_tables(obj) + external_tables = self._collect_in_memory_tables(obj) + with self._safe_raw_sql(src, external_tables=external_tables): + pass + return self.table(name, database=database) + + def _register_in_memory_tables(self, expr: ir.Expr) -> None: + for memtable in expr.op().find(ops.InMemoryTable): + self._register_in_memory_table(memtable) + + def drop_view( + self, name: str, *, database: str | None = None, force: bool = False + ) -> None: + src = sg.exp.Drop(this=sg.table(name, db=database), kind="VIEW", exists=force) + with contextlib.closing(self.raw_sql(src)): + pass + + def _get_temp_view_definition(self, name: str, definition: str) -> str: + yield sg.exp.Create( + this=sg.to_identifier(name, quoted=self.compiler.quoted), + kind="VIEW", + expression=definition, + replace=True, + properties=sg.exp.Properties(expressions=[sg.exp.TemporaryProperty()]), + ).sql(self.name) + + def _create_temp_view(self, table_name, source): + if table_name not in self._temp_views and table_name in self.list_tables(): + raise ValueError( + f"{table_name} already exists as a non-temporary table or view" + ) + with self._safe_raw_sql(self._get_temp_view_definition(table_name, source)): + pass + self._temp_views.add(table_name) + self._register_temp_view_cleanup(table_name) + + def _register_temp_view_cleanup(self, name: str) -> None: + """Register a clean up function for a temporary view. + + No-op by default. + + Parameters + ---------- + name + The temporary view to register for clean up. 
+ """ + + def _load_into_cache(self, name, expr): + self.create_table(name, expr, schema=expr.schema(), temp=True) + + def _clean_up_cached_table(self, op): + self.drop_table(op.name) + + def execute( + self, expr: ir.Expr, limit: str | None = "default", **kwargs: Any + ) -> Any: + """Execute an expression.""" + + self._run_pre_execute_hooks(expr) + table = expr.as_table() + sql = self.compile(table, limit=limit, **kwargs) + + schema = table.schema() + self._log(sql) + + with self._safe_raw_sql(sql) as cur: + result = self.fetch_from_cursor(cur, schema) + return expr.__pandas_result__(result) diff --git a/ibis/backends/base/sqlglot/compiler.py b/ibis/backends/base/sqlglot/compiler.py index 5b3b79c5e53c..b3aef001cd9e 100644 --- a/ibis/backends/base/sqlglot/compiler.py +++ b/ibis/backends/base/sqlglot/compiler.py @@ -4,6 +4,7 @@ import calendar import functools import itertools +import math import operator import string from collections.abc import Mapping @@ -144,6 +145,15 @@ class SQLGlotCompiler(abc.ABC): quoted: bool | None = None """Whether to always quote identifiers.""" + NAN = sg.exp.Literal.number("'NaN'::double") + """Backend's NaN literal.""" + + POS_INF = sg.exp.Literal.number("'Inf'::double") + """Backend's positive infinity literal.""" + + NEG_INF = sg.exp.Literal.number("'-Inf'::double") + """Backend's negative infinity literal.""" + def __init__(self) -> None: self.agg = AggGen(aggfunc=self._aggregate) self.f = FuncGen() @@ -217,10 +227,10 @@ def fn(node, _, **kwargs): return result alias_index = next(gen_alias_index) - alias = f"t{alias_index:d}" + alias = sg.to_identifier(f"t{alias_index:d}", quoted=quoted) try: - return result.subquery(sg.exp.TableAlias(this=alias, quoted=quoted)) + return result.subquery(alias) except AttributeError: return result.as_(alias, quoted=quoted) @@ -269,6 +279,16 @@ def visit_Literal(self, op, *, value, dtype, **kw): raise com.UnsupportedOperationError( f"Unsupported NULL for non-nullable type: {dtype!r}" ) + elif dtype.is_integer(): + return sg.exp.convert(value) + elif dtype.is_floating(): + if math.isnan(value): + return self.NAN + elif math.isinf(value): + return self.POS_INF if value < 0 else self.NEG_INF + return sg.exp.convert(value) + elif dtype.is_decimal(): + return self.cast(sg.exp.convert(str(value)), dtype) elif dtype.is_interval(): return sg.exp.Interval( this=sg.exp.convert(str(value)), unit=dtype.resolution.upper() diff --git a/ibis/backends/base/sqlglot/datatypes.py b/ibis/backends/base/sqlglot/datatypes.py index be88edcc940a..78938ab30434 100644 --- a/ibis/backends/base/sqlglot/datatypes.py +++ b/ibis/backends/base/sqlglot/datatypes.py @@ -405,26 +405,5 @@ class OracleType(SqlglotType): dialect = "oracle" -class SnowflakeType(SqlglotType): - dialect = "snowflake" - default_temporal_scale = 9 - - @classmethod - def _from_sqlglot_FLOAT(cls) -> dt.Float64: - return dt.Float64(nullable=cls.default_nullable) - - @classmethod - def _from_sqlglot_DECIMAL(cls, precision=None, scale=None) -> dt.Decimal: - if scale is None or int(scale.this.this) == 0: - return dt.Int64(nullable=cls.default_nullable) - else: - return super()._from_sqlglot_DECIMAL(precision, scale) - - @classmethod - def _from_sqlglot_ARRAY(cls, value_type=None) -> dt.Array: - assert value_type is None - return dt.Array(dt.json, nullable=cls.default_nullable) - - class SQLiteType(SqlglotType): dialect = "sqlite" diff --git a/ibis/backends/clickhouse/__init__.py b/ibis/backends/clickhouse/__init__.py index 7f131366b8c1..7eba44bbd092 100644 --- 
a/ibis/backends/clickhouse/__init__.py +++ b/ibis/backends/clickhouse/__init__.py @@ -709,35 +709,10 @@ def create_view( expression=self._to_sqlglot(obj), ) external_tables = self._collect_in_memory_tables(obj) - with closing(self.raw_sql(src, external_tables=external_tables)): + with self._safe_raw_sql(src, external_tables=external_tables): pass return self.table(name, database=database) - def drop_view( - self, name: str, *, database: str | None = None, force: bool = False - ) -> None: - src = sg.exp.Drop(this=sg.table(name, db=database), kind="VIEW", exists=force) - with closing(self.raw_sql(src)): - pass - - def _load_into_cache(self, name, expr): - self.create_table(name, expr, schema=expr.schema(), temp=True) - - def _clean_up_cached_table(self, op): - self.drop_table(op.name) - - def _create_temp_view(self, table_name, source): - if table_name not in self._temp_views and table_name in self.list_tables(): - raise ValueError( - f"{table_name} already exists as a non-temporary table or view" - ) - src = sg.exp.Create( - this=sg.table(table_name), kind="VIEW", replace=True, expression=source - ) - self.raw_sql(src) - self._temp_views.add(table_name) - self._register_temp_view_cleanup(table_name) - def _register_temp_view_cleanup(self, name: str) -> None: def drop(self, name: str, query: str): self.raw_sql(query) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index c2010f52258b..82de248431be 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -261,6 +261,10 @@ def get_schema( } ) + @contextlib.contextmanager + def _safe_raw_sql(self, *args, **kwargs): + yield self.raw_sql(*args, **kwargs) + def list_databases(self, like: str | None = None) -> list[str]: col = "catalog_name" query = sg.select(sg.exp.Distinct(expressions=[sg.column(col)])).from_( @@ -412,26 +416,6 @@ def _from_url(self, url: str, **kwargs) -> BaseBackend: self._convert_kwargs(kwargs) return self.connect(**kwargs) - def execute( - self, expr: ir.Expr, limit: str | None = "default", **kwargs: Any - ) -> Any: - """Execute an expression.""" - - self._run_pre_execute_hooks(expr) - table = expr.as_table() - sql = self.compile(table, limit=limit, **kwargs) - - schema = table.schema() - self._log(sql) - - try: - cur = self.con.execute(sql) - except duckdb.CatalogException as e: - raise exc.IbisError(e) - - result = self.fetch_from_cursor(cur, schema) - return expr.__pandas_result__(result) - def load_extension(self, extension: str, force_install: bool = False) -> None: """Install and load a duckdb extension by name or path. @@ -532,25 +516,6 @@ def _register_failure(self): f"please call one of {msg} directly" ) - def _create_temp_view(self, table_name, source): - if table_name not in self._temp_views and table_name in self.list_tables(): - raise ValueError( - f"{table_name} already exists as a non-temporary table or view" - ) - src = sg.exp.Create( - this=sg.exp.Identifier( - this=table_name, quoted=True - ), # CREATE ... 'table_name' - kind="VIEW", # VIEW - replace=True, # OR REPLACE - properties=sg.exp.Properties( - expressions=[sg.exp.TemporaryProperty()] # TEMPORARY - ), - expression=source, # AS ... 
- ) - self.raw_sql(src.sql("duckdb")) - self._temp_views.add(table_name) - @util.experimental def read_json( self, @@ -1352,9 +1317,6 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: except duckdb.NotImplementedException: self.con.register(name, data.to_pyarrow(schema)) - def _get_temp_view_definition(self, name: str, definition) -> str: - yield f"CREATE OR REPLACE TEMPORARY VIEW {name} AS {definition}" - def _register_udfs(self, expr: ir.Expr) -> None: import ibis.expr.operations as ops diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py index 8f50c9ef1745..03abd980b650 100644 --- a/ibis/backends/snowflake/__init__.py +++ b/ibis/backends/snowflake/__init__.py @@ -9,21 +9,19 @@ import json import os import platform -import re import shutil import sys import tempfile import textwrap import warnings +from operator import itemgetter from pathlib import Path from typing import TYPE_CHECKING, Any import pyarrow as pa import pyarrow_hotfix # noqa: F401 -import sqlalchemy as sa import sqlglot as sg from packaging.version import parse as vparse -from sqlalchemy.ext.compiler import compiles import ibis import ibis.common.exceptions as com @@ -31,13 +29,10 @@ import ibis.expr.operations as ops import ibis.expr.types as ir from ibis import util -from ibis.backends.base import CanCreateDatabase -from ibis.backends.base.sql.alchemy import ( - AlchemyCanCreateSchema, - AlchemyCompiler, - AlchemyCrossSchemaBackend, - AlchemyExprTranslator, -) +from ibis.backends.base import CanCreateDatabase, CanCreateSchema +from ibis.backends.base.sqlglot import SQLGlotBackend +from ibis.backends.snowflake.compiler import SnowflakeCompiler +from ibis.backends.snowflake.converter import SnowflakePandasData with warnings.catch_warnings(): if vparse(importlib.metadata.version("snowflake-connector-python")) >= vparse( @@ -48,11 +43,7 @@ message="You have an incompatible version of 'pyarrow' installed", category=UserWarning, ) - from snowflake.sqlalchemy import ARRAY, DOUBLE, OBJECT, URL - - from ibis.backends.snowflake.converter import SnowflakePandasData - from ibis.backends.snowflake.datatypes import SnowflakeType - from ibis.backends.snowflake.registry import operation_registry + import snowflake.connector as sc if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping @@ -62,47 +53,25 @@ import ibis.expr.schema as sch -class SnowflakeExprTranslator(AlchemyExprTranslator): - _registry = operation_registry - _rewrites = AlchemyExprTranslator._rewrites.copy() - _has_reduction_filter_syntax = False - _forbids_frame_clause = ( - *AlchemyExprTranslator._forbids_frame_clause, - ops.Lag, - ops.Lead, - ) - _require_order_by = (*AlchemyExprTranslator._require_order_by, ops.Reduction) - _dialect_name = "snowflake" - _quote_column_names = True - _quote_table_names = True - supports_unnest_in_select = False - type_mapper = SnowflakeType - - -class SnowflakeCompiler(AlchemyCompiler): - cheap_in_memory_tables = True - translator_class = SnowflakeExprTranslator - - _SNOWFLAKE_MAP_UDFS = { "ibis_udfs.public.object_merge": { - "inputs": {"obj1": OBJECT, "obj2": OBJECT}, - "returns": OBJECT, + "inputs": {"obj1": "OBJECT", "obj2": "OBJECT"}, + "returns": "OBJECT", "source": "return Object.assign(obj1, obj2)", }, "ibis_udfs.public.object_values": { - "inputs": {"obj": OBJECT}, - "returns": ARRAY, + "inputs": {"obj": "OBJECT"}, + "returns": "ARRAY", "source": "return Object.values(obj)", }, "ibis_udfs.public.object_from_arrays": { - "inputs": {"ks": ARRAY, "vs": 
ARRAY}, - "returns": OBJECT, + "inputs": {"ks": "ARRAY", "vs": "ARRAY"}, + "returns": "OBJECT", "source": "return Object.assign(...ks.map((k, i) => ({[k]: vs[i]})))", }, "ibis_udfs.public.array_zip": { - "inputs": {"arrays": ARRAY}, - "returns": ARRAY, + "inputs": {"arrays": "ARRAY"}, + "returns": "ARRAY", "source": """\ const longest = arrays.reduce((a, b) => a.length > b.length ? a : b, []); const keys = Array.from(Array(arrays.length).keys()).map(key => `f${key + 1}`); @@ -113,18 +82,17 @@ class SnowflakeCompiler(AlchemyCompiler): "ibis_udfs.public.array_repeat": { # Integer inputs are not allowed because JavaScript only supports # doubles - "inputs": {"value": ARRAY, "count": DOUBLE}, - "returns": ARRAY, + "inputs": {"value": "ARRAY", "count": "DOUBLE"}, + "returns": "ARRAY", "source": """return Array(count).fill(value).flat();""", }, } -class Backend(AlchemyCrossSchemaBackend, CanCreateDatabase, AlchemyCanCreateSchema): +class Backend(SQLGlotBackend, CanCreateDatabase, CanCreateSchema): name = "snowflake" - compiler = SnowflakeCompiler - supports_create_or_replace = True - supports_python_udfs = True + compiler = SnowflakeCompiler() + supports_python_udfs = False _latest_udf_python_version = (3, 10) @@ -134,29 +102,27 @@ def _convert_kwargs(self, kwargs): @property def version(self) -> str: - return self._scalar_query(sa.select(sa.func.current_version())) + with self._safe_raw_sql(sg.select(sg.func("current_version"))) as cur: + (version,) = cur.fetchone() + return version @property def current_schema(self) -> str: - with self.con.connect() as con: - return con.connection.schema + return self.con.schema @property def current_database(self) -> str: - with self.con.connect() as con: - return con.connection.database - - def _compile_sqla_type(self, typ) -> str: - return sa.types.to_instance(typ).compile(dialect=self.con.dialect) + return self.con.database def _make_udf(self, name: str, defn) -> str: - dialect = self.con.dialect - quote = dialect.preparer(dialect).quote_identifier signature = ", ".join( - f"{quote(argname)} {self._compile_sqla_type(typ)}" + "{} {}".format( + sg.to_identifier(argname, quoted=self.compiler.quoted).sql(self.name), + typ, + ) for argname, typ in defn["inputs"].items() ) - return_type = self._compile_sqla_type(defn["returns"]) + return_type = defn["returns"] return f"""\ CREATE OR REPLACE FUNCTION {name}({signature}) RETURNS {return_type} @@ -166,17 +132,7 @@ def _make_udf(self, name: str, defn) -> str: AS $$ {defn["source"]} $$""" - def do_connect( - self, - user: str, - account: str, - database: str, - password: str | None = None, - authenticator: str | None = None, - connect_args: Mapping[str, Any] | None = None, - create_object_udfs: bool = True, - **kwargs: Any, - ): + def do_connect(self, create_object_udfs: bool = True, **kwargs: Any): """Connect to Snowflake. Parameters @@ -205,87 +161,64 @@ def do_connect( Enable object UDF extensions defined by ibis on the first connection to the database. connect_args - Additional arguments passed to the SQLAlchemy engine creation call. + Additional arguments passed to the DBAPI connection call. kwargs - Additional arguments passed to the SQLAlchemy URL constructor. - See https://docs.snowflake.com/en/developer-guide/python-connector/sqlalchemy#additional-connection-parameters - for more details + Additional arguments passed to the URL constructor. """ - dbparams = dict(zip(("database", "schema"), database.split("/", 1))) - if dbparams.get("schema") is None: - raise ValueError( - "Schema must be non-None. 
Pass the schema as part of the " - f"database e.g., {dbparams['database']}/my_schema" - ) - - # snowflake-connector-python does not handle `None` for password, but - # accepts the empty string - url = URL( - account=account, user=user, password=password or "", **dbparams, **kwargs - ) - if connect_args is None: - connect_args = {} - - session_parameters = connect_args.setdefault("session_parameters", {}) + connect_args = kwargs.copy() + session_parameters = connect_args.pop("session_parameters", {}) # enable multiple SQL statements by default - session_parameters.setdefault("MULTI_STATEMENT_COUNT", "0") + session_parameters.setdefault("MULTI_STATEMENT_COUNT", 0) # don't format JSON output by default - session_parameters.setdefault("JSON_INDENT", "0") + session_parameters.setdefault("JSON_INDENT", 0) # overwrite session parameters that are required for ibis + snowflake # to work session_parameters.update( dict( # Use Arrow for query results - PYTHON_CONNECTOR_QUERY_RESULT_FORMAT="ARROW", + PYTHON_CONNECTOR_QUERY_RESULT_FORMAT="arrow_force", # JSON output must be strict for null versus undefined - STRICT_JSON_OUTPUT="TRUE", + STRICT_JSON_OUTPUT=True, # Timezone must be UTC TIMEZONE="UTC", ), ) - if authenticator is not None: - connect_args.setdefault("authenticator", authenticator) - - engine = sa.create_engine( - url, connect_args=connect_args, poolclass=sa.pool.StaticPool - ) - - @sa.event.listens_for(engine, "connect") - def connect(dbapi_connection, connection_record): - """Register UDFs on a `"connect"` event.""" - if create_object_udfs: - with dbapi_connection.cursor() as cur: - database, schema = cur.execute( - "SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()" - ).fetchone() - try: - cur.execute("CREATE DATABASE IF NOT EXISTS ibis_udfs") - # snowflake activates a database on creation, so reset - # it back to the original database and schema - cur.execute(f"USE SCHEMA {database}.{schema}") - for name, defn in _SNOWFLAKE_MAP_UDFS.items(): - cur.execute(self._make_udf(name, defn)) - except Exception as e: # noqa: BLE001 - warnings.warn( - f"Unable to create map UDFs, some functionality will not work: {e}" - ) - - super().do_connect(engine) - - def normalize_name(name): - if name is None: - return None - elif not name: - return "" - elif name.lower() == name: - return sa.sql.quoted_name(name, quote=True) - else: - return name - - self.con.dialect.normalize_name = normalize_name + con = sc.connect(**connect_args, session_parameters=session_parameters) + + if create_object_udfs: + database = con.database + schema = con.schema + create_stmt = sg.exp.Create( + kind="DATABASE", this=sg.to_identifier("ibis_udfs"), exists=True + ).sql(self.name) + use_stmt = sg.exp.Use( + kind="SCHEMA", + this=sg.table(schema, db=database, quoted=self.compiler.quoted), + ).sql(self.name) + + stmts = [ + create_stmt, + # snowflake activates a database on creation, so reset it back + # to the original database and schema + use_stmt, + *( + self._make_udf(name, defn) + for name, defn in _SNOWFLAKE_MAP_UDFS.items() + ), + ] + stmt = "; ".join(stmts) + with contextlib.closing(con.cursor()) as cur: + try: + cur.execute(stmt) + except Exception as e: # noqa: BLE001 + warnings.warn( + f"Unable to create map UDFs, some functionality will not work: {e}" + ) + self.con = con + self._temp_views: set[str] = set() def _get_udf_source(self, udf_node: ops.ScalarUDF): name = type(udf_node).__name__ @@ -311,7 +244,7 @@ def _get_udf_source(self, udf_node: ops.ScalarUDF): def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> str: 
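+        # NOTE: this is a format template for a session-scoped (TEMP) Python UDF; the placeholders ({name}, {signature}, {return_type}, ...) appear to be filled in from the pieces assembled by _get_udf_source above.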
return """\ -CREATE OR REPLACE TEMPORARY FUNCTION {name}({signature}) +CREATE OR REPLACE TEMP FUNCTION {name}({signature}) RETURNS {return_type} LANGUAGE PYTHON IMMUTABLE @@ -328,7 +261,7 @@ def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> str: def _compile_pandas_udf(self, udf_node: ops.ScalarUDF) -> str: return """\ -CREATE OR REPLACE TEMPORARY FUNCTION {name}({signature}) +CREATE OR REPLACE TEMP FUNCTION {name}({signature}) RETURNS {return_type} LANGUAGE PYTHON IMMUTABLE @@ -357,16 +290,15 @@ def to_pyarrow( *, params: Mapping[ir.Scalar, Any] | None = None, limit: int | str | None = None, - **_: Any, + **kwargs: Any, ) -> pa.Table: from ibis.backends.snowflake.converter import SnowflakePyArrowData self._run_pre_execute_hooks(expr) - query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params) - sql = query_ast.compile() - with self.begin() as con: - res = con.execute(sql).cursor.fetch_arrow_all() + sql = self.compile(expr, limit=limit, params=params, **kwargs) + with self._safe_raw_sql(sql) as cur: + res = cur.fetch_arrow_all() target_schema = expr.as_table().schema().to_pyarrow() if res is None: @@ -375,7 +307,7 @@ def to_pyarrow( return expr.__pyarrow_result__(res, data_mapper=SnowflakePyArrowData) def fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame: - if (table := cursor.cursor.fetch_arrow_all()) is None: + if (table := cursor.fetch_arrow_all()) is None: table = schema.to_pyarrow().empty_table() df = table.to_pandas(timestamp_as_object=True) df.columns = list(schema.names) @@ -387,20 +319,18 @@ def to_pandas_batches( *, params: Mapping[ir.Scalar, Any] | None = None, limit: int | str | None = None, - **_: Any, + **kwargs: Any, ) -> Iterator[pd.DataFrame | pd.Series | Any]: self._run_pre_execute_hooks(expr) - query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params) - sql = query_ast.compile() + sql = self.compile(expr, limit=limit, params=params, **kwargs) target_schema = expr.as_table().schema() converter = functools.partial( SnowflakePandasData.convert_table, schema=target_schema ) - with self.begin() as con, contextlib.closing(con.execute(sql)) as cur: + with self._safe_raw_sql(sql) as cur: yield from map( - expr.__pandas_result__, - map(converter, cur.cursor.fetch_pandas_batches()), + expr.__pandas_result__, map(converter, cur.fetch_pandas_batches()) ) def to_pyarrow_batches( @@ -410,11 +340,10 @@ def to_pyarrow_batches( params: Mapping[ir.Scalar, Any] | None = None, limit: int | str | None = None, chunk_size: int = 1_000_000, - **_: Any, + **kwargs: Any, ) -> pa.ipc.RecordBatchReader: self._run_pre_execute_hooks(expr) - query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params) - sql = query_ast.compile() + sql = self.compile(expr, limit=limit, params=params, **kwargs) target_schema = expr.as_table().schema().to_pyarrow() return pa.RecordBatchReader.from_batches( @@ -427,30 +356,49 @@ def to_pyarrow_batches( def _make_batch_iter( self, sql: str, *, target_schema: sch.Schema, chunk_size: int ) -> Iterator[pa.RecordBatch]: - with self.begin() as con, contextlib.closing(con.execute(sql)) as cur: + with self._safe_raw_sql(sql) as cur: yield from itertools.chain.from_iterable( t.rename_columns(target_schema.names) .cast(target_schema) .to_batches(max_chunksize=chunk_size) - for t in cur.cursor.fetch_arrow_batches() + for t in cur.fetch_arrow_batches() ) + def get_schema( + self, table_name: str, schema: str | None = None, database: str | None = None + ) -> Iterable[tuple[str, dt.DataType]]: + table = sg.table( + 
table_name, db=schema, catalog=database, quoted=self.compiler.quoted + ).sql(self.name) + with self._safe_raw_sql(f"DESCRIBE TABLE {table}") as cur: + result = cur.fetchall() + + fields = { + name: self.compiler.type_mapper.from_string(typ, nullable=nullable == "Y") + for name, typ, _, nullable, *_ in result + } + return ibis.schema(fields) + def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]: - with self.begin() as con: - con.exec_driver_sql(query) - result = con.exec_driver_sql("DESC RESULT last_query_id()").mappings().all() + with self._safe_raw_sql(f"{query}; DESC RESULT last_query_id()") as cur: + result = cur.fetchall() for field in result: name = field["name"] type_string = field["type"] is_nullable = field["null?"] == "Y" - yield name, SnowflakeType.from_string(type_string, nullable=is_nullable) + yield ( + name, + self.compiler.type_mapper.from_string( + type_string, nullable=is_nullable + ), + ) def list_databases(self, like: str | None = None) -> list[str]: - with self.begin() as con: - databases = [ - row["name"] for row in con.exec_driver_sql("SHOW DATABASES").mappings() - ] + with self._safe_raw_sql("SHOW DATABASES") as con: + rows = con.fetchall() + databases = list(map(itemgetter(1), rows)) return self._filter_with_like(databases, like) def list_schemas( @@ -459,10 +407,11 @@ def list_schemas( query = "SHOW SCHEMAS" if database is not None: - query += f" IN {self._quote(database)}" + query += f" IN {self.to_identifier(database).sql(self.name)}" - with self.begin() as con: - schemata = [row["name"] for row in con.exec_driver_sql(query).mappings()] + with self.con.cursor() as con: + schemata = list(map(itemgetter(1), con.execute(query))) return self._filter_with_like(schemata, like) @@ -517,12 +466,10 @@ def list_tables( tables_query += f" IN {database}" views_query += f" IN {database}" - with self.begin() as con: + with self.con.cursor() as cur: # TODO: considering doing this with a single query using information_schema - tables = [ - row["name"] for row in con.exec_driver_sql(tables_query).mappings() - ] - views = [row["name"] for row in con.exec_driver_sql(views_query).mappings()] + tables = list(map(itemgetter(1), cur.execute(tables_query))) + views = list(map(itemgetter(1), cur.execute(views_query))) return self._filter_with_like(tables + views, like=like) @@ -531,8 +478,8 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: raw_name = op.name - with self.begin() as con: - if con.exec_driver_sql(f"SHOW TABLES LIKE '{raw_name}'").scalar() is None: + with self.con.cursor() as con: + if not con.execute(f"SHOW TABLES LIKE '{raw_name}'").fetchone(): tmpdir = tempfile.TemporaryDirectory() try: path = os.path.join(tmpdir.name, f"{raw_name}.parquet") @@ -546,25 +493,26 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: with contextlib.suppress(Exception): shutil.rmtree(tmpdir.name) - def _get_temp_view_definition( - self, name: str, definition: sa.sql.compiler.Compiled - ) -> str: - yield f"CREATE OR REPLACE TEMPORARY VIEW {name} AS {definition}" - def create_database(self, name: str, force: bool = False) -> None: current_database = self.current_database current_schema = self.current_schema - name = self._quote(name) - if_not_exists = "IF NOT EXISTS " * force - with self.begin() as con: - con.exec_driver_sql(f"CREATE DATABASE {if_not_exists}{name}") + create_stmt = sg.exp.Create( + this=sg.to_identifier(name, quoted=self.compiler.quoted), + kind="DATABASE", + exists=force, + ) + use_stmt = sg.exp.Use( + kind="DATABASE", +
this=sg.table( + current_schema, db=current_database, quoted=self.compiler.quoted + ), + ).sql(self.name) + with self._safe_raw_sql(create_stmt) as cur: # Snowflake automatically switches to the new database after creating # it per # https://docs.snowflake.com/en/sql-reference/sql/create-database#general-usage-notes # so we switch back to the original database and schema - con.exec_driver_sql( - f"USE SCHEMA {self._quote(current_database)}.{self._quote(current_schema)}" - ) + cur.execute(use_stmt) def drop_database(self, name: str, force: bool = False) -> None: current_database = self.current_database @@ -572,27 +520,49 @@ def drop_database(self, name: str, force: bool = False) -> None: raise com.UnsupportedOperationError( "Dropping the current database is not supported because its behavior is undefined" ) - name = self._quote(name) - if_exists = "IF EXISTS " * force - with self.begin() as con: - con.exec_driver_sql(f"DROP DATABASE {if_exists}{name}") + drop_stmt = sg.exp.Drop( + this=sg.to_identifier(name, quoted=self.compiler.quoted), + kind="DATABASE", + exists=force, + ) + with self._safe_raw_sql(drop_stmt): + pass def create_schema( self, name: str, database: str | None = None, force: bool = False ) -> None: - name = ".".join(map(self._quote, filter(None, [database, name]))) - if_not_exists = "IF NOT EXISTS " * force current_database = self.current_database current_schema = self.current_schema - with self.begin() as con: - con.exec_driver_sql(f"CREATE SCHEMA {if_not_exists}{name}") + create_stmt = sg.exp.Create( + this=sg.table(name, db=database, quoted=self.compiler.quoted), + kind="SCHEMA", + exists=force, + ) + use_stmt = sg.exp.Use( + kind="SCHEMA", + this=sg.table( + current_schema, db=current_database, quoted=self.compiler.quoted + ), + ).sql(self.name) + with self._safe_raw_sql(create_stmt) as cur: # Snowflake automatically switches to the new schema after creating # it per # https://docs.snowflake.com/en/sql-reference/sql/create-schema#usage-notes # so we switch back to the original schema - con.exec_driver_sql( - f"USE SCHEMA {self._quote(current_database)}.{self._quote(current_schema)}" - ) + cur.execute(use_stmt) + + @contextlib.contextmanager + def _safe_raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any: + with contextlib.suppress(AttributeError): + query = query.sql(dialect=self.name) + + with self.con.cursor() as cur: + yield cur.execute(query, **kwargs) + + def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any: + with contextlib.suppress(AttributeError): + query = query.sql(dialect=self.name) + return self.con.execute(query, **kwargs) def drop_schema( self, name: str, database: str | None = None, force: bool = False @@ -604,10 +574,13 @@ def drop_schema( "Dropping the current schema is not supported because its behavior is undefined" ) - name = ".".join(map(self._quote, filter(None, [database, name]))) - if_exists = "IF EXISTS " * force - with self.begin() as con: - con.exec_driver_sql(f"DROP SCHEMA {if_exists}{name}") + drop_stmt = sg.exp.Drop( + this=sg.table(name, db=database, quoted=self.compiler.quoted), + kind="SCHEMA", + exists=force, + ) + with self._safe_raw_sql(drop_stmt): + pass def create_table( self, @@ -646,23 +619,35 @@ def create_table( if obj is None and schema is None: raise ValueError("Either `obj` or `schema` must be specified") - create_stmt = "CREATE" + column_defs = [ + sg.exp.ColumnDef( + this=sg.to_identifier(name, quoted=self.compiler.quoted), + kind=self.compiler.type_mapper.from_ibis(typ), + constraints=( + None + 
if typ.nullable + else [ + sg.exp.ColumnConstraint(kind=sg.exp.NotNullColumnConstraint()) + ] + ), + ) + for name, typ in (schema or {}).items() + ] + + target = sg.table(name, db=database, quoted=self.compiler.quoted) - if overwrite: - create_stmt += " OR REPLACE" + if column_defs: + target = sg.exp.Schema(this=target, expressions=column_defs) - if temp: - create_stmt += " TEMPORARY" + properties = [] - ident = self._quote(name) - create_stmt += f" TABLE {ident}" + if temp: + properties.append(sg.exp.TemporaryProperty()) - if schema is not None: - schema_sql = ", ".join( - f"{name} {SnowflakeType.to_string(typ) + ' NOT NULL' * (not typ.nullable)}" - for name, typ in zip(map(self._quote, schema.keys()), schema.values()) + if comment is not None: + properties.append( + sg.exp.SchemaCommentProperty(this=sg.exp.convert(comment)) ) - create_stmt += f" ({schema_sql})" if obj is not None: if not isinstance(obj, ir.Expr): @@ -672,29 +657,37 @@ def create_table( self._run_pre_execute_hooks(table) - query = self.compile(table).compile( - dialect=self.con.dialect, compile_kwargs=dict(literal_binds=True) - ) - create_stmt += f" AS {query}" - - if comment is not None: - create_stmt += f" COMMENT '{comment}'" + query = self._to_sqlglot(table) + else: + query = None + + create_stmt = sg.exp.Create( + kind="TABLE", + this=target, + replace=overwrite, + properties=sg.exp.Properties(expressions=properties) + if properties + else None, + expression=query, + ) - with self.begin() as con: - con.exec_driver_sql(create_stmt) + with self._safe_raw_sql(create_stmt): + pass return self.table(name, schema=database) def drop_table( - self, name: str, database: str | None = None, force: bool = False + self, + name: str, + database: str | None = None, + schema: str | None = None, + force: bool = False, ) -> None: - name = self._quote(name) - # TODO: handle database quoting - if database is not None: - name = f"{database}.{name}" - drop_stmt = "DROP TABLE" + (" IF EXISTS" * force) + f" {name}" - with self.begin() as con: - con.exec_driver_sql(drop_stmt) + drop_stmt = sg.exp.Drop( + kind="TABLE", this=sg.table(name, db=schema, catalog=database), exists=force + ) + with self._safe_raw_sql(drop_stmt): + pass def read_csv( self, path: str | Path, table_name: str | None = None, **kwargs: Any @@ -722,7 +715,7 @@ def read_csv( # https://docs.snowflake.com/en/sql-reference/sql/put#optional-parameters threads = min((os.cpu_count() or 2) // 2, 99) table = table_name or ibis.util.gen_name("read_csv_snowflake") - qtable = self._quote(table) + qtable = sg.to_identifier(table, quoted=self.compiler.quoted) parse_header = header = kwargs.pop("parse_header", True) skip_header = kwargs.pop("skip_header", True) @@ -737,59 +730,58 @@ def read_csv( f"{name.upper()} = {value!r}" for name, value in kwargs.items() ) - with self.begin() as con: + stmts = [ # create a temporary stage for the file - con.exec_driver_sql(f"CREATE TEMP STAGE {stage}") - + f"CREATE TEMP STAGE {stage}", # create a temporary file format for CSV schema inference - create_infer_fmt = ( + ( f"CREATE TEMP FILE FORMAT {file_format} TYPE = CSV PARSE_HEADER = {str(header).upper()}" + options - ) - con.exec_driver_sql(create_infer_fmt) - + ), # copy the local file to the stage - con.exec_driver_sql( - f"PUT 'file://{Path(path).absolute()}' @{stage} PARALLEL = {threads:d}" - ) + f"PUT 'file://{Path(path).absolute()}' @{stage} PARALLEL = {threads:d}", + ] + + with self.con.cursor() as con: + con.execute("; ".join(stmts)) # handle setting up the schema in python because 
snowflake is # broken for csv globs: it cannot parse the result of the following # query in USING TEMPLATE - fields = json.loads( - con.exec_driver_sql( - f""" - SELECT ARRAY_AGG(OBJECT_CONSTRUCT(*)) - WITHIN GROUP (ORDER BY ORDER_ID ASC) - FROM TABLE( - INFER_SCHEMA( - LOCATION => '@{stage}', - FILE_FORMAT => '{file_format}' - ) + (info,) = con.execute( + f""" + SELECT ARRAY_AGG(OBJECT_CONSTRUCT(*)) + WITHIN GROUP (ORDER BY ORDER_ID ASC) + FROM TABLE( + INFER_SCHEMA( + LOCATION => '@{stage}', + FILE_FORMAT => '{file_format}' ) - """ - ).scalar() - ) - fields = [ - (self._quote(field["COLUMN_NAME"]), field["TYPE"], field["NULLABLE"]) - for field in fields - ] + ) + """ + ).fetchall() columns = ", ".join( - f"{quoted_name} {typ}{' NOT NULL' * (not nullable)}" - for quoted_name, typ, nullable in fields + "{} {}{}".format( + self.to_identifier( + field["COLUMN_NAME"], quoted=self.compiler.quoted + ).sql(self.name), + field["TYPE"], + " NOT NULL" if not field["NULLABLE"] else "", + ) + for field in json.loads(info) ) - # create a temporary table using the stage and format inferred - # from the CSV - con.exec_driver_sql(f"CREATE TEMP TABLE {qtable} ({columns})") - - # load the CSV into the table - con.exec_driver_sql( + stmts = [ + # create a temporary table using the stage and format inferred + # from the CSV + f"CREATE TEMP TABLE {qtable} ({columns})", + # load the CSV into the table f""" COPY INTO {qtable} FROM @{stage} FILE_FORMAT = (TYPE = CSV SKIP_HEADER = {int(header)}{options}) - """ - ) + """, + ] + con.execute("; ".join(stmts)) return self.table(table) @@ -817,7 +809,7 @@ def read_json( stage = util.gen_name("read_json_stage") file_format = util.gen_name("read_json_format") table = table_name or util.gen_name("read_json_snowflake") - qtable = self._quote(table) + qtable = sg.to_identifier(table, quoted=self.compiler.quoted) threads = min((os.cpu_count() or 2) // 2, 99) kwargs.setdefault("strip_outer_array", True) @@ -827,42 +819,33 @@ def read_json( f"{name.upper()} = {value!r}" for name, value in kwargs.items() ) - with self.begin() as con: - con.exec_driver_sql( - f"CREATE TEMP FILE FORMAT {file_format} TYPE = JSON" + options - ) - - con.exec_driver_sql( - f"CREATE TEMP STAGE {stage} FILE_FORMAT = {file_format}" - ) - con.exec_driver_sql( - f"PUT 'file://{Path(path).absolute()}' @{stage} PARALLEL = {threads:d}" - ) - - con.exec_driver_sql( - f""" - CREATE TEMP TABLE {qtable} - USING TEMPLATE ( - SELECT ARRAY_AGG(OBJECT_CONSTRUCT(*)) - WITHIN GROUP (ORDER BY ORDER_ID ASC) - FROM TABLE( - INFER_SCHEMA( - LOCATION => '@{stage}', - FILE_FORMAT => '{file_format}' - ) + stmts = [ + f"CREATE TEMP FILE FORMAT {file_format} TYPE = JSON" + options, + f"CREATE TEMP STAGE {stage} FILE_FORMAT = {file_format}", + f"PUT 'file://{Path(path).absolute()}' @{stage} PARALLEL = {threads:d}", + f""" + CREATE TEMP TABLE {qtable} + USING TEMPLATE ( + SELECT ARRAY_AGG(OBJECT_CONSTRUCT(*)) + WITHIN GROUP (ORDER BY ORDER_ID ASC) + FROM TABLE( + INFER_SCHEMA( + LOCATION => '@{stage}', + FILE_FORMAT => '{file_format}' ) ) - """ ) - + """, # load the JSON file into the table - con.exec_driver_sql( - f""" - COPY INTO {qtable} - FROM @{stage} - MATCH_BY_COLUMN_NAME = {str(match_by_column_name).upper()} - """ - ) + f""" + COPY INTO {qtable} + FROM @{stage} + MATCH_BY_COLUMN_NAME = {str(match_by_column_name).upper()} + """, + ] + + with self._safe_raw_sql("; ".join(stmts)): + pass return self.table(table) @@ -898,7 +881,7 @@ def read_parquet( stage = util.gen_name("read_parquet_stage") table = table_name or 
util.gen_name("read_parquet_snowflake") - qtable = self._quote(table) + qtable = sg.to_identifier(table, quoted=self.compiler.quoted) threads = min((os.cpu_count() or 2) // 2, 99) options = " " * bool(kwargs) + " ".join( @@ -911,11 +894,16 @@ def read_parquet( # see # https://community.snowflake.com/s/article/How-to-load-logical-type-TIMESTAMP-data-from-Parquet-files-into-Snowflake names_types = [ - (name, SnowflakeType.to_string(typ), typ.nullable, typ.is_timestamp()) + ( + name, + self.compiler.type_mapper.to_string(typ), + typ.nullable, + typ.is_timestamp(), + ) for name, typ in schema.items() ] snowflake_schema = ", ".join( - f"{self._quote(col)} {typ}{' NOT NULL' * (not nullable)}" + f"{sg.to_identifier(col, quoted=self.compiler.quoted)} {typ}{' NOT NULL' * (not nullable)}" for col, typ, nullable, _ in names_types ) cols = ", ".join( @@ -923,31 +911,29 @@ def read_parquet( for col, typ, _, is_timestamp in names_types ) - with self.begin() as con: - con.exec_driver_sql( - f"CREATE TEMP STAGE {stage} FILE_FORMAT = (TYPE = PARQUET{options})" - ) - con.exec_driver_sql( - f"PUT 'file://{abspath}' @{stage} PARALLEL = {threads:d}" - ) - con.exec_driver_sql(f"CREATE TEMP TABLE {qtable} ({snowflake_schema})") - con.exec_driver_sql( - f"COPY INTO {qtable} FROM (SELECT {cols} FROM @{stage})" - ) - - return self.table(table) + stmts = [ + f"CREATE TEMP STAGE {stage} FILE_FORMAT = (TYPE = PARQUET{options})" + f"PUT 'file://{abspath}' @{stage} PARALLEL = {threads:d}", + f"CREATE TEMP TABLE {qtable} ({snowflake_schema})", + f"COPY INTO {qtable} FROM (SELECT {cols} FROM @{stage})", + ] + with self._safe_raw_sql("; ".join(stmts)): + pass -@compiles(sa.sql.Join, "snowflake") -def compile_join(element, compiler, **kw): - """Override compilation of LATERAL joins. + return self.table(table) - Snowflake doesn't support lateral joins with ON clauses as of - https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057 - even if they are trivial boolean literals. - """ - result = compiler.visit_join(element, **kw) - if element.right._is_lateral: - return re.sub(r"^(.+) ON true$", r"\1", result, flags=re.IGNORECASE | re.DOTALL) - return result +# @compiles(sa.sql.Join, "snowflake") +# def compile_join(element, compiler, **kw): +# """Override compilation of LATERAL joins. +# +# Snowflake doesn't support lateral joins with ON clauses as of +# https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057 +# even if they are trivial boolean literals. 
+# """ +# result = compiler.visit_join(element, **kw) +# +# if element.right._is_lateral: +# return re.sub(r"^(.+) ON true$", r"\1", result, flags=re.IGNORECASE | re.DOTALL) +# return result diff --git a/ibis/backends/snowflake/compiler.py b/ibis/backends/snowflake/compiler.py new file mode 100644 index 000000000000..041838fa4d86 --- /dev/null +++ b/ibis/backends/snowflake/compiler.py @@ -0,0 +1,428 @@ +from __future__ import annotations + +import itertools +from functools import reduce, singledispatchmethod + +import sqlglot as sg +from public import public + +import ibis.common.exceptions as com +import ibis.expr.operations as ops +from ibis.backends.base.sqlglot.compiler import NULL, SQLGlotCompiler +from ibis.backends.snowflake.datatypes import SnowflakeType +from ibis.common.deferred import _ + + +@public +class SnowflakeCompiler(SQLGlotCompiler): + __slots__ = () + + dialect = "snowflake" + quoted = True + type_mapper = SnowflakeType + + def _aggregate(self, funcname: str, *args, where): + if where is not None: + args = [self.if_(where, arg, NULL) for arg in args] + + func = self.f[funcname] + return func(*args) + + @singledispatchmethod + def visit_node(self, op, **kwargs): + return super().visit_node(op, **kwargs) + + @visit_node.register(ops.Literal) + def visit_Literal(self, op, *, value, dtype): + if value is None: + return super().visit_Literal(op, value=value, dtype=dtype) + elif dtype.is_timestamp(): + args = ( + value.year, + value.month, + value.day, + value.hour, + value.minute, + value.second, + value.microsecond * 1_000, + ) + if value.tzinfo is not None: + return self.f.timestamp_tz_from_parts(*args, dtype.timezone) + else: + return self.f.timestamp_from_parts(*args) + elif dtype.is_time(): + nanos = value.microsecond * 1_000 + return self.f.time_from_parts(value.hour, value.minute, value.second, nanos) + elif dtype.is_map() or dtype.is_struct(): + # TODO: handle conversion of keys and values to expressions + return self.f.object_construct_keep_null( + *itertools.chain.from_iterable(value.items()) + ) + elif dtype.is_uuid(): + return sg.exp.convert(str(value)) + return super().visit_node(op, value=value, dtype=dtype) + + @visit_node.register(ops.Cast) + def visit_Cast(self, op, *, arg, to): + if to.is_struct() or to.is_map(): + return self.if_(self.f.is_object(arg), arg, NULL) + elif to.is_array(): + return self.if_(self.f.is_array(arg), arg, NULL) + return self.cast(arg, to) + + @visit_node.register(ops.IsNan) + def visit_IsNan(self, op, *, arg): + return arg.eq(self.NAN) + + @visit_node.register(ops.IsInf) + def visit_IsInf(self, op, *, arg): + return arg.isin(self.POS_INF, self.NEG_INF) + + @visit_node.register(ops.JSONGetItem) + def visit_JSONGetItem(self, op, *, arg, index): + return self.f.get(arg, index) + + @visit_node.register(ops.StringFind) + def visit_StringFind(self, op, *, arg, substr, start, end): + args = [substr, arg] + if start is not None: + start += 1 + args.append(start) + return self.f.position(*args) + + def _gen_udf_name(self, name: str) -> sg.exp.Dot: + return sg.exp.Dot( + this=sg.to_identifier("ibis_udfs"), + expression=sg.exp.Dot(this="public", expression=sg.to_identifier(name)), + ) + + @visit_node.register(ops.Map) + def visit_Map(self, op, *, keys, values): + return self.if_( + sg.and_(self.f.is_array(keys), self.f.is_array(values)), + sg.func(self._gen_udf_name("object_from_arrays"), keys, values), + NULL, + ) + + @visit_node.register(ops.MapKeys) + def visit_MapKeys(self, op, *, arg): + return self.if_(self.f.is_object(arg), 
self.f.object_keys(arg), NULL) + + @visit_node.register(ops.MapValues) + def visit_MapValues(self, op, *, arg): + return self.if_( + self.f.is_object(arg), + sg.func(self._gen_udf_name("object_values"), arg), + NULL, + ) + + @visit_node.register(ops.MapGet) + def visit_MapGet(self, op, *, arg, key, default): + dtype = op.dtype + expr = self.f.coalesce(self.f.get(arg, key), self.f.to_variant(default)) + if dtype.is_json() or dtype.is_null(): + return expr + return self.cast(expr, dtype) + + @visit_node.register(ops.MapContains) + def visit_MapContains(self, op, *, arg, key): + return self.f.array_contains( + self.f.to_variant(key), + self.if_(self.f.is_object(arg), self.f.object_keys(arg), NULL), + ) + + @visit_node.register(ops.MapMerge) + def visit_MapMerge(self, op, *, left, right): + return self.if_( + sg.and_(self.f.is_object(left), self.f.is_object(right)), + sg.func(self._gen_udf_name("object_merge"), left, right), + NULL, + ) + + @visit_node.register(ops.MapLength) + def visit_MapLength(self, op, *, arg): + return self.if_( + self.f.is_object(arg), self.f.array_size(self.f.object_keys(arg)), NULL + ) + + @visit_node.register(ops.BitwiseBinary) + def visit_BitwiseOps(self, op, *, left, right): + funcname = type(op).__name__.lower().replace("wise", "") + return self.f[funcname](left, right) + + @visit_node.register(ops.Log2) + def visit_Log2(self, op, *, arg): + return self.f.log(2, arg) + + @visit_node.register(ops.Log10) + def visit_Log10(self, op, *, arg): + return self.f.log(10, arg) + + @visit_node.register(ops.Log) + def visit_Log(self, op, *, arg, base): + return self.f.log(base, arg, dialect=self.name) + + @visit_node.register(ops.RandomScalar) + def visit_RandomScalar(self, op): + return self.f.uniform( + self.f.to_double(0.0), self.f.to_double(1.0), self.f.random() + ) + + @visit_node.register(ops.ToJSONArray) + @visit_node.register(ops.ToJSONMap) + def visit_ToJSON(self, op, *, arg): + return self.cast(arg, op.dtype) + + @visit_node.register(ops.ApproxMedian) + def visit_ApproxMedian(self, op, *, arg): + return self.f.approx_percentile(arg, 0.5) + + @visit_node.register(ops.TimeDelta) + def visit_TimeDelta(self, op, *, part, left, right): + return self.f.timediff(part, right, left) + + @visit_node.register(ops.DateDelta) + def visit_DateDelta(self, op, *, part, left, right): + return self.f.datediff(part, right, left) + + @visit_node.register(ops.TimestampDelta) + def visit_TimestampDelta(self, op, *, part, left, right): + return self.f.timestampdiff(part, right, left) + + @visit_node.register(ops.IntegerRange) + def visit_IntegerRange(self, op, *, start, stop, step): + return self.if_( + step.ne(0), self.f.array_generate_range(start, stop, step), self.f.array() + ) + + @visit_node.register(ops.StructColumn) + def visit_StructColumn(self, op, *, names, values): + return self.f.object_construct_keep_null( + *itertools.chain.from_iterable(zip(names, values)) + ) + + @visit_node.register(ops.StructField) + def visit_StructField(self, op, *, arg, field): + return self.cast(self.f.get(arg, field), op.dtype) + + @visit_node.register(ops.RegexSearch) + def visit_RegexSearch(self, op, *, arg, pattern): + return self.f.regexp_instr(arg, pattern).ne(0) + + @visit_node.register(ops.TypeOf) + def visit_TypeOf(self, op, *, arg): + return self.f.typeof(self.f.to_variant(arg)) + + @visit_node.register(ops.ArrayRepeat) + def visit_ArrayRepeat(self, op, *, arg, times): + return sg.func(self._gen_udf_name("array_repeat"), arg, times) + + @visit_node.register(ops.ArrayUnion) + def 
visit_ArrayUnion(self, op, *, left, right): + return self.f.array_distinct(self.f.array_cat(left, right)) + + @visit_node.register(ops.ArrayContains) + def visit_ArrayContains(self, op, *, arg, other): + return self.f.array_contains(self.f.to_variant(other), arg) + + @visit_node.register(ops.ArrayCollect) + def visit_ArrayCollect(self, op, *, arg, where): + return self.agg.array_agg(self.f.ifnull(arg, self.f.parse_json("null"))) + + @visit_node.register(ops.ArrayConcat) + def visit_ArrayConcat(self, op, *, arg): + return reduce(self.f.array_cat, arg) + + @visit_node.register(ops.ArrayPosition) + def visit_ArrayPosition(self, op, *, arg, other): + # snowflake is zero-based here, so we don't need to subtract 1 from the + # result + return self.f.coalesce(self.f.array_position(self.f.to_variant(other), arg), -1) + + @visit_node.register(ops.RegexExtract) + def visit_RegexExtract(self, op, *, arg, pattern, index): + # https://docs.snowflake.com/en/sql-reference/functions/regexp_substr + return self.f.regexp_substr(arg, pattern, 1, 1, "ce", index) + + @visit_node.register(ops.ArrayZip) + def visit_ArrayZip(self, op, *, arg): + func = self._gen_udf_name("array_zip") + return func(self.f.array(*arg)) + + @visit_node.register(ops.DayOfWeekName) + def visit_DayOfWeekName(self, op, *, arg): + return sg.exp.Case( + this=self.f.dayname(arg), + ifs=[ + self.if_("Sun", "Sunday"), + self.if_("Mon", "Monday"), + self.if_("Tue", "Tuesday"), + self.if_("Wed", "Wednesday"), + self.if_("Thu", "Thursday"), + self.if_("Fri", "Friday"), + self.if_("Sat", "Saturday"), + ], + default=NULL, + ) + + @visit_node.register(ops.TimestampFromUNIX) + def visit_DayOfWeekName(self, op, *, arg): + timestamp_units_to_scale = {"s": 0, "ms": 3, "us": 6, "ns": 9} + return self.f.to_timestamp(arg, timestamp_units_to_scale[op.unit.short]) + + @visit_node.register(ops.First) + def visit_First(self, op, *, arg, where): + return self.f.get(self.agg.array_agg(arg, where=where), 0) + + @visit_node.register(ops.Last) + def visit_Last(self, op, *, arg, where): + expr = self.agg.array_agg(arg, where=where) + offset = self.f.array_size(expr) - 1 + return self.f.get(expr, offset) + + @visit_node.register(ops.GroupConcat) + def visit_GroupConcat(self, op, *, arg, where, sep): + if where is None: + return self.f.listagg(arg, sep) + + arg = self.if_(where, arg, None) + + return self.if_( + self.f.count_if(arg.is_(sg.not_(NULL))) != 0, self.f.listagg(arg, sep), NULL + ) + + @visit_node.register(ops.TimestampBucket) + def visit_TimestampBucket(self, op, *, arg, interval, offset): + if offset is not None: + raise com.UnsupportedOperationError( + "`offset` is not supported in the Snowflake backend for timestamp bucketing" + ) + + interval = op.interval + if not isinstance(interval, sg.exp.Literal): + raise com.UnsupportedOperationError( + f"Interval must be a literal for the Snowflake backend, got {type(interval)}" + ) + + return self.f.time_slice(arg, interval.value, interval.dtype.unit.name) + + @visit_node.register(ops.Arbitrary) + def visit_Arbitrary(self, op, *, arg, how, where): + if how == "first": + return self.f.get(self.agg.array_agg(arg), 0) + elif how == "last": + expr = self.agg.array_agg(arg) + return self.f.get(expr, self.f.array_size(expr) - 1) + else: + raise com.UnsupportedOperationError("how must be 'first' or 'last'") + + @visit_node.register(ops.ArraySlice) + def visit_ArraySclie(self, op, *, arg, start, stop): + if start is None: + start = 0 + + if stop is None: + stop = self.f.array_size(arg) + return self.f.array_slice(arg, 
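+        # Snowflake's ARRAY_SLICE(arg, from, to) is zero-based with an exclusive upper bound, so 0 and ARRAY_SIZE(arg) are safe defaults for slicing the whole array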
start, stop) + + @visit_node.register(ops.ExtractMicrosecond) + def visit_ExtractMicrosecond(self, op, *, arg): + return self.cast(self.f.extract("epoch_microsecond", arg) % 1_000_000, op.dtype) + + @visit_node.register(ops.ExtractMillisecond) + def visit_ExtractMillisecond(self, op, *, arg): + return self.cast(self.f.extract("epoch_millisecond", arg) % 1_000, op.dtype) + + @visit_node.register(ops.ExtractQuery) + def visit_ExtractQuery(self, op, *, arg, key): + parsed_url = self.f.parse_url(arg, 1) + if key is not None: + r = self.f.get(self.f.get(parsed_url, "parameters"), key) + else: + r = self.f.get(parsed_url, "query") + return self.f.nullif(self.f.as_varchar(r), "") + + @visit_node.register(ops.ExtractProtocol) + def visit_ExtractProtocol(self, op, *, arg): + return self.f.nullif( + self.f.as_varchar(self.f.get(self.f.parse_url(arg, 1), "scheme")), "" + ) + + @visit_node.register(ops.ExtractAuthority) + def visit_ExtractAuthority(self, op, *, arg): + return self.f.concat_ws( + ":", + self.f.as_varchar(self.f.get(self.f.parse_url(arg, 1), "host")), + self.f.as_varchar(self.f.get(self.f.parse_url(arg, 1), "port")), + ) + + @visit_node.register(ops.ExtractFile) + def visit_ExtractFile(self, op, *, arg, key): + return self.f.concat_ws( + "?", + self.visit_ExtractPath(op, arg=arg, key=key), + self.visit_ExtractQuery(op, arg=arg, key=None), + ) + + @visit_node.register(ops.ExtractPath) + def visit_ExtractPath(self, op, *, arg, key): + return "/" + self.f.as_varchar(self.f.get(self.f.parse_url(arg, 1), "path")) + + @visit_node.register(ops.ExtractFragment) + def visit_ExtractFragment(self, op, *, arg, key): + return self.f.nullif( + self.f.as_varchar(self.f.get(self.f.parse_url(arg, 1), "fragment")), "" + ) + + @visit_node.register(ops.Unnest) + def visit_Unnest(self, op, *, arg): + return sg.exp.Explode(this=arg) + + +_SIMPLE_OPS = { + ops.Mode: "mode", + ops.TimeFromHMS: "time_from_parts", + ops.ArrayIndex: "get", + ops.ArrayLength: "array_size", + ops.ArrayDistinct: "array_distinct", + ops.ArrayRemove: "array_remove", + ops.ArrayIntersect: "array_intersection", + ops.ArraySort: "array_sort", + ops.ArrayFlatten: "array_flatten", + ops.StringSplit: "split", + ops.All: "booland_agg", + ops.Any: "boolor_agg", + ops.BitAnd: "bitand_agg", + ops.BitOr: "bitor_agg", + ops.BitXor: "bitxor_agg", + ops.DateFromYMD: "date_from_parts", + ops.StringToTimestamp: "to_timestamp_tz", + ops.RegexReplace: "regex_replace", + ops.ArgMin: "min_by", + ops.ArgMax: "max_by", + ops.StartsWith: "startswith", + ops.EndsWith: "endswith", + ops.Hash: "hash", + ops.Median: "median", + ops.Levenshtein: "editdistance", + ops.TimestampFromYMDHMS: "timestamp_from_parts", +} + +for _op, _name in _SIMPLE_OPS.items(): + assert isinstance(type(_op), type), type(_op) + if issubclass(_op, ops.Reduction): + + @SnowflakeCompiler.visit_node.register(_op) + def _fmt(self, op, *, _name: str = _name, where, **kw): + return self.agg[_name](*kw.values(), where=where) + + else: + + @SnowflakeCompiler.visit_node.register(_op) + def _fmt(self, op, *, _name: str = _name, **kw): + return self.f[_name](*kw.values()) + + setattr(SnowflakeCompiler, f"visit_{_op.__name__}", _fmt) + + +del _op, _name, _fmt diff --git a/ibis/backends/snowflake/datatypes.py b/ibis/backends/snowflake/datatypes.py index f93b0cc9b855..0d12b3e94a74 100644 --- a/ibis/backends/snowflake/datatypes.py +++ b/ibis/backends/snowflake/datatypes.py @@ -1,78 +1,84 @@ from __future__ import annotations -import sqlalchemy.types as sat -from snowflake.sqlalchemy import ( - ARRAY, 
-    OBJECT,
-    TIMESTAMP_LTZ,
-    TIMESTAMP_NTZ,
-    TIMESTAMP_TZ,
-    VARIANT,
-)
-from sqlalchemy.ext.compiler import compiles
-
 import ibis.expr.datatypes as dt
-from ibis.backends.base.sql.alchemy.datatypes import AlchemyType
-from ibis.backends.base.sqlglot.datatypes import SnowflakeType as SqlglotSnowflakeType
-
-
-@compiles(sat.NullType, "snowflake")
-def compiles_nulltype(element, compiler, **kw):
-    return "VARIANT"
+from ibis.backends.base.sqlglot.datatypes import SqlglotType


-class SnowflakeType(AlchemyType):
+class SnowflakeType(SqlglotType):
     dialect = "snowflake"
+    default_temporal_scale = 9

     @classmethod
-    def from_ibis(cls, dtype):
-        if dtype.is_array():
-            return ARRAY
-        elif dtype.is_map() or dtype.is_struct():
-            return OBJECT
-        elif dtype.is_json():
-            return VARIANT
-        elif dtype.is_timestamp():
-            if dtype.timezone is None:
-                return TIMESTAMP_NTZ
-            else:
-                return TIMESTAMP_TZ
-        elif dtype.is_string():
-            # 16MB
-            return sat.VARCHAR(2**24)
-        elif dtype.is_binary():
-            # 8MB
-            return sat.VARBINARY(2**23)
-        else:
-            return super().from_ibis(dtype)
+    def _from_sqlglot_FLOAT(cls) -> dt.Float64:
+        return dt.Float64(nullable=cls.default_nullable)

     @classmethod
-    def to_ibis(cls, typ, nullable=True):
-        if isinstance(typ, (sat.REAL, sat.FLOAT, sat.Float)):
-            return dt.Float64(nullable=nullable)
-        elif isinstance(typ, TIMESTAMP_NTZ):
-            return dt.Timestamp(timezone=None, nullable=nullable)
-        elif isinstance(typ, (TIMESTAMP_LTZ, TIMESTAMP_TZ)):
-            return dt.Timestamp(timezone="UTC", nullable=nullable)
-        elif isinstance(typ, ARRAY):
-            return dt.Array(dt.json, nullable=nullable)
-        elif isinstance(typ, OBJECT):
-            return dt.Map(dt.string, dt.json, nullable=nullable)
-        elif isinstance(typ, VARIANT):
-            return dt.JSON(nullable=nullable)
-        elif isinstance(typ, sat.Numeric):
-            if (scale := typ.scale) == 0:
-                # kind of a lie, should be int128 because 38 digits
-                return dt.Int64(nullable=nullable)
-            else:
-                return dt.Decimal(
-                    precision=typ.precision or 38,
-                    scale=scale or 0,
-                    nullable=nullable,
-                )
+    def _from_sqlglot_DECIMAL(cls, precision=None, scale=None) -> dt.Decimal:
+        if scale is None or int(scale.this.this) == 0:
+            return dt.Int64(nullable=cls.default_nullable)
         else:
-            return super().to_ibis(typ, nullable=nullable)
+            return super()._from_sqlglot_DECIMAL(precision, scale)

     @classmethod
-    def from_string(cls, type_string, nullable=True):
-        return SqlglotSnowflakeType.from_string(type_string, nullable=nullable)
+    def _from_sqlglot_ARRAY(cls, value_type=None) -> dt.Array:
+        assert value_type is None
+        return dt.Array(dt.json, nullable=cls.default_nullable)
diff --git a/ibis/backends/snowflake/registry.py b/ibis/backends/snowflake/registry.py
index 8db1096a7831..23b143e9a10b 100644
--- a/ibis/backends/snowflake/registry.py
+++ b/ibis/backends/snowflake/registry.py
@@ -5,7 +5,6 @@
 import numpy as np
 import sqlalchemy as sa
-from snowflake.sqlalchemy import ARRAY, OBJECT, VARIANT
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.sql.elements import Cast
@@ -21,7 +20,6 @@
     unary,
     varargs,
 )
-from ibis.backends.postgres.registry import _literal as _postgres_literal
 from ibis.backends.postgres.registry import operation_registry as _operation_registry

 operation_registry = {
@@ -30,475 +28,28 @@
 }

-def _literal(t, op):
-    value = op.value
-    dtype = op.dtype
+def _unnest(t, op):
+    arg = t.translate(op.arg)
+    # HACK: https://community.snowflake.com/s/question/0D50Z000086MVhnSAG/has-anyone-found-a-way-to-unnest-an-array-without-loosing-the-null-values
+    sep = util.guid()
+    col = sa.func.nullif(
+        sa.func.split_to_table(sa.func.array_to_string(arg, sep), sep)
+        .table_valued("value")  # seq, index, value is supported but we only need value
+        .lateral()
+        .c["value"],
+        "",
+    )
+    return sa.cast(
+        sa.func.coalesce(sa.func.try_parse_json(col), sa.func.to_variant(col)),
+        type_=t.get_sqla_type(op.dtype),
+    )
-    if value is None:
-        return sa.null()
-
-    if dtype.is_floating():
-        if np.isnan(value):
-            return _SF_NAN
-
-        if np.isinf(value):
-            return _SF_NEG_INF if value < 0 else _SF_POS_INF
-    elif dtype.is_timestamp():
-        args = (
-            value.year,
-            value.month,
-            value.day,
-            value.hour,
-            value.minute,
-            value.second,
-            value.microsecond * 1_000,
-        )
-        if value.tzinfo is not None:
-            return sa.func.timestamp_tz_from_parts(*args, dtype.timezone)
-        else:
-            return sa.func.timestamp_from_parts(*args)
-    elif dtype.is_date():
-        return sa.func.date_from_parts(value.year, value.month, value.day)
-    elif dtype.is_time():
-        nanos = value.microsecond * 1_000
-        return sa.func.time_from_parts(value.hour, value.minute, value.second, nanos)
-    elif dtype.is_array():
-        return sa.func.array_construct(*value)
-    elif dtype.is_map() or dtype.is_struct():
-        return sa.func.object_construct_keep_null(
-            *itertools.chain.from_iterable(value.items())
-        )
-    elif dtype.is_uuid():
-        return sa.literal(str(value))
-    return _postgres_literal(t, op)
-
-
-def _table_column(t, op):
-    ctx = t.context
-    table = op.table
-
-    sa_table = get_sqla_table(ctx, table)
-    out_expr = get_col(sa_table, op)
-
-    if (dtype := op.dtype).is_timestamp() and (timezone := dtype.timezone) is not None:
-        out_expr = sa.func.convert_timezone(timezone, out_expr).label(op.name)
-
-    # If the column does not originate from the table set in the current SELECT
-    # context, we should format as a subquery
-    if t.permit_subquery and ctx.is_foreign_expr(table):
-        return sa.select(out_expr)
-
-    return out_expr
-
-
-def 
_string_find(t, op): - args = [t.translate(op.substr), t.translate(op.arg)] - if (start := op.start) is not None: - args.append(t.translate(start) + 1) - return sa.func.position(*args) - 1 - - -def _round(t, op): - args = [t.translate(op.arg)] - if (digits := op.digits) is not None: - args.append(t.translate(digits)) - return sa.func.round(*args) - - -def _day_of_week_name(arg): - return sa.case( - ("Sun", "Sunday"), - ("Mon", "Monday"), - ("Tue", "Tuesday"), - ("Wed", "Wednesday"), - ("Thu", "Thursday"), - ("Fri", "Friday"), - ("Sat", "Saturday"), - value=sa.func.dayname(arg), - else_=None, - ) - - -def _extract_url_query(t, op): - parsed_url = sa.func.parse_url(t.translate(op.arg), 1) - - if (key := op.key) is not None: - r = sa.func.get(sa.func.get(parsed_url, "parameters"), t.translate(key)) - else: - r = sa.func.get(parsed_url, "query") - - return sa.func.nullif(sa.func.as_varchar(r), "") - - -def _array_slice(t, op): - arg = t.translate(op.arg) - - if (start := op.start) is not None: - start = t.translate(start) - else: - start = 0 - - if (stop := op.stop) is not None: - stop = t.translate(stop) - else: - stop = sa.func.array_size(arg) - - return sa.func.array_slice(t.translate(op.arg), start, stop) - - -def _nth_value(t, op): - if not isinstance(nth := op.nth, ops.Literal): - raise TypeError(f"`nth` argument must be a literal Python int, got {type(nth)}") - return sa.func.nth_value(t.translate(op.arg), nth.value + 1) - - -def _arbitrary(t, op): - if (how := op.how) == "first": - return t._reduction(lambda x: sa.func.get(sa.func.array_agg(x), 0), op) - elif how == "last": - return t._reduction( - lambda x: sa.func.get( - sa.func.array_agg(x), sa.func.array_size(sa.func.array_agg(x)) - 1 - ), - op, - ) - else: - raise com.UnsupportedOperationError("how must be 'first' or 'last'") - - -@compiles(Cast, "snowflake") -def compiles_cast(element, compiler, **kw): - typ = compiler.visit_typeclause(element, **kw) - if typ in ("OBJECT", "ARRAY"): - arg = compiler.process(element.clause, **kw) - return f"IFF(IS_{typ}({arg}), {arg}, NULL)" - return compiler.visit_cast(element, **kw) - - -@compiles(sa.TEXT, "snowflake") -@compiles(sa.VARCHAR, "snowflake") -def compiles_string(element, compiler, **kw): - return "VARCHAR" - - -@compiles(OBJECT, "snowflake") -@compiles(ARRAY, "snowflake") -@compiles(VARIANT, "snowflake") -def compiles_object_type(element, compiler, **kw): - return type(element).__name__.upper() - - -def _unnest(t, op): - arg = t.translate(op.arg) - # HACK: https://community.snowflake.com/s/question/0D50Z000086MVhnSAG/has-anyone-found-a-way-to-unnest-an-array-without-loosing-the-null-values - sep = util.guid() - col = sa.func.nullif( - sa.func.split_to_table(sa.func.array_to_string(arg, sep), sep) - .table_valued("value") # seq, index, value is supported but we only need value - .lateral() - .c["value"], - "", - ) - return sa.cast( - sa.func.coalesce(sa.func.try_parse_json(col), sa.func.to_variant(col)), - type_=t.get_sqla_type(op.dtype), - ) - - -def _group_concat(t, op): - if (where := op.where) is None: - return sa.func.listagg(t.translate(op.arg), t.translate(op.sep)) - - where_sa = t.translate(where) - arg_sa = sa.func.iff(where_sa, t.translate(op.arg), None) - - return sa.func.iff( - sa.func.count_if(arg_sa != sa.null()) != 0, - sa.func.listagg(arg_sa, t.translate(op.sep)), - None, - ) - - -def _array_zip(t, op): - return sa.type_coerce( - sa.func.ibis_udfs.public.array_zip( - sa.func.array_construct(*map(t.translate, op.arg)) - ), - t.get_sqla_type(op.dtype), - ) - - -def 
_regex_extract(t, op): - arg = t.translate(op.arg) - pattern = t.translate(op.pattern) - index = t.translate(op.index) - # https://docs.snowflake.com/en/sql-reference/functions/regexp_substr - return sa.func.regexp_substr(arg, pattern, 1, 1, "ce", index) - - -def _map_get(t, op): - arg = op.arg - key = op.key - default = op.default - dtype = op.dtype - sqla_type = t.get_sqla_type(dtype) - expr = sa.func.coalesce( - sa.func.get(t.translate(arg), t.translate(key)), - sa.func.to_variant(t.translate(default)), - type_=sqla_type, - ) - if dtype.is_json() or dtype.is_null(): - return expr - - # cast if ibis thinks the value type is not JSON - # - # this ensures that we can get deserialized map values even though maps are - # always JSON in the value type inside snowflake - return sa.cast(expr, sqla_type) - - -def _timestamp_bucket(t, op): - if op.offset is not None: - raise com.UnsupportedOperationError( - "`offset` is not supported in the Snowflake backend for timestamp bucketing" - ) - - interval = op.interval - - if not isinstance(interval, ops.Literal): - raise com.UnsupportedOperationError( - f"Interval must be a literal for the Snowflake backend, got {type(interval)}" - ) - - return sa.func.time_slice( - t.translate(op.arg), interval.value, interval.dtype.unit.name - ) - - -_TIMESTAMP_UNITS_TO_SCALE = {"s": 0, "ms": 3, "us": 6, "ns": 9} - -_SF_POS_INF = sa.func.to_double("Inf") -_SF_NEG_INF = sa.func.to_double("-Inf") -_SF_NAN = sa.func.to_double("NaN") operation_registry.update( { - ops.JSONGetItem: fixed_arity(sa.func.get, 2), - ops.StringFind: _string_find, - ops.Map: fixed_arity( - lambda keys, values: sa.func.iff( - sa.func.is_array(keys) & sa.func.is_array(values), - sa.func.ibis_udfs.public.object_from_arrays(keys, values), - sa.null(), - ), - 2, - ), - ops.MapKeys: unary( - lambda arg: sa.func.iff( - sa.func.is_object(arg), sa.func.object_keys(arg), sa.null() - ) - ), - ops.MapValues: unary( - lambda arg: sa.func.iff( - sa.func.is_object(arg), - sa.func.ibis_udfs.public.object_values(arg), - sa.null(), - ) - ), - ops.MapGet: _map_get, - ops.MapContains: fixed_arity( - lambda arg, key: sa.func.array_contains( - sa.func.to_variant(key), - sa.func.iff( - sa.func.is_object(arg), sa.func.object_keys(arg), sa.null() - ), - ), - 2, - ), - ops.MapMerge: fixed_arity( - lambda a, b: sa.func.iff( - sa.func.is_object(a) & sa.func.is_object(b), - sa.func.ibis_udfs.public.object_merge(a, b), - sa.null(), - ), - 2, - ), - ops.MapLength: unary( - lambda arg: sa.func.array_size( - sa.func.iff(sa.func.is_object(arg), sa.func.object_keys(arg), sa.null()) - ) - ), - ops.BitwiseAnd: fixed_arity(sa.func.bitand, 2), - ops.BitwiseNot: unary(sa.func.bitnot), - ops.BitwiseOr: fixed_arity(sa.func.bitor, 2), - ops.BitwiseXor: fixed_arity(sa.func.bitxor, 2), - ops.BitwiseLeftShift: fixed_arity(sa.func.bitshiftleft, 2), - ops.BitwiseRightShift: fixed_arity(sa.func.bitshiftright, 2), - ops.Ln: unary(sa.func.ln), - ops.Log2: unary(lambda arg: sa.func.log(2, arg)), - ops.Log10: unary(lambda arg: sa.func.log(10, arg)), - ops.Log: fixed_arity(lambda arg, base: sa.func.log(base, arg), 2), - ops.IsInf: unary(lambda arg: arg.in_((_SF_POS_INF, _SF_NEG_INF))), - ops.IsNan: unary(lambda arg: arg == _SF_NAN), - ops.Literal: _literal, - ops.Round: _round, - ops.Modulus: fixed_arity(sa.func.mod, 2), - ops.Mode: reduction(sa.func.mode), - ops.IfElse: fixed_arity(sa.func.iff, 3), # numbers - ops.RandomScalar: fixed_arity( - lambda: sa.func.uniform( - sa.func.to_double(0.0), sa.func.to_double(1.0), sa.func.random() - ), - 0, - 
), - # time and dates - ops.TimeFromHMS: fixed_arity(sa.func.time_from_parts, 3), # columns - ops.DayOfWeekName: unary(_day_of_week_name), - ops.ExtractProtocol: unary( - lambda arg: sa.func.nullif( - sa.func.as_varchar(sa.func.get(sa.func.parse_url(arg, 1), "scheme")), "" - ) - ), - ops.ExtractAuthority: unary( - lambda arg: sa.func.concat_ws( - ":", - sa.func.as_varchar(sa.func.get(sa.func.parse_url(arg, 1), "host")), - sa.func.as_varchar(sa.func.get(sa.func.parse_url(arg, 1), "port")), - ) - ), - ops.ExtractFile: unary( - lambda arg: sa.func.concat_ws( - "?", - "/" - + sa.func.as_varchar(sa.func.get(sa.func.parse_url(arg, 1), "path")), - sa.func.as_varchar(sa.func.get(sa.func.parse_url(arg, 1), "query")), - ) - ), - ops.ExtractPath: unary( - lambda arg: ( - "/" + sa.func.as_varchar(sa.func.get(sa.func.parse_url(arg, 1), "path")) - ) - ), - ops.ExtractQuery: _extract_url_query, - ops.ExtractFragment: unary( - lambda arg: sa.func.nullif( - sa.func.as_varchar(sa.func.get(sa.func.parse_url(arg, 1), "fragment")), - "", - ) - ), - ops.ArrayIndex: fixed_arity(sa.func.get, 2), - ops.ArrayLength: fixed_arity(sa.func.array_size, 1), - ops.ArrayConcat: varargs( - lambda *args: functools.reduce(sa.func.array_cat, args) - ), - ops.ArrayColumn: lambda t, op: sa.func.array_construct( - *map(t.translate, op.cols) - ), - ops.ArraySlice: _array_slice, - ops.ArrayCollect: reduction( - lambda arg: sa.func.array_agg( - sa.func.ifnull(arg, sa.func.parse_json("null")), type_=ARRAY - ) - ), - ops.ArrayContains: fixed_arity( - lambda arr, el: sa.func.array_contains(sa.func.to_variant(el), arr), 2 - ), - ops.ArrayPosition: fixed_arity( - # snowflake is zero-based here, so we don't need to subtract 1 from the result - lambda lst, el: sa.func.coalesce( - sa.func.array_position(sa.func.to_variant(el), lst), -1 - ), - 2, - ), - ops.ArrayDistinct: fixed_arity(sa.func.array_distinct, 1), - ops.ArrayUnion: fixed_arity( - lambda left, right: sa.func.array_distinct(sa.func.array_cat(left, right)), - 2, - ), - ops.ArrayRemove: fixed_arity(sa.func.array_remove, 2), - ops.ArrayIntersect: fixed_arity(sa.func.array_intersection, 2), - ops.ArrayZip: _array_zip, - ops.ArraySort: unary(sa.func.array_sort), - ops.ArrayRepeat: fixed_arity(sa.func.ibis_udfs.public.array_repeat, 2), - ops.ArrayFlatten: fixed_arity(sa.func.array_flatten, 1), - ops.StringSplit: fixed_arity(sa.func.split, 2), - # snowflake typeof only accepts VARIANT, so we cast - ops.TypeOf: unary(lambda arg: sa.func.typeof(sa.func.to_variant(arg))), - ops.All: reduction(sa.func.booland_agg), - ops.Any: reduction(sa.func.boolor_agg), - ops.BitAnd: reduction(sa.func.bitand_agg), - ops.BitOr: reduction(sa.func.bitor_agg), - ops.BitXor: reduction(sa.func.bitxor_agg), - ops.DateFromYMD: fixed_arity(sa.func.date_from_parts, 3), - ops.StringToTimestamp: fixed_arity(sa.func.to_timestamp_tz, 2), - ops.RegexExtract: _regex_extract, - ops.RegexSearch: fixed_arity( - lambda arg, pattern: sa.func.regexp_instr(arg, pattern) != 0, 2 - ), - ops.RegexReplace: fixed_arity(sa.func.regexp_replace, 3), - ops.ExtractMicrosecond: fixed_arity( - lambda arg: sa.cast( - sa.extract("epoch_microsecond", arg) % 1000000, sa.SMALLINT - ), - 1, - ), - ops.ExtractMillisecond: fixed_arity( - lambda arg: sa.cast( - sa.extract("epoch_millisecond", arg) % 1000, sa.SMALLINT - ), - 1, - ), - ops.TimestampFromYMDHMS: fixed_arity(sa.func.timestamp_from_parts, 6), - ops.TimestampFromUNIX: lambda t, op: sa.func.to_timestamp( - t.translate(op.arg), _TIMESTAMP_UNITS_TO_SCALE[op.unit.short] - ), - 
ops.StructField: lambda t, op: sa.cast( - sa.func.get(t.translate(op.arg), op.field), t.get_sqla_type(op.dtype) - ), - ops.NthValue: _nth_value, - ops.Arbitrary: _arbitrary, - ops.First: reduction(lambda x: sa.func.get(sa.func.array_agg(x), 0)), - ops.Last: reduction( - lambda x: sa.func.get( - sa.func.array_agg(x), sa.func.array_size(sa.func.array_agg(x)) - 1 - ) - ), - ops.StructColumn: lambda t, op: sa.func.object_construct_keep_null( - *itertools.chain.from_iterable(zip(op.names, map(t.translate, op.values))) - ), ops.Unnest: _unnest, - ops.ArgMin: reduction(sa.func.min_by), - ops.ArgMax: reduction(sa.func.max_by), - ops.ToJSONArray: lambda t, op: t.translate(ops.Cast(op.arg, op.dtype)), - ops.ToJSONMap: lambda t, op: t.translate(ops.Cast(op.arg, op.dtype)), - ops.StartsWith: fixed_arity(sa.func.startswith, 2), - ops.EndsWith: fixed_arity(sa.func.endswith, 2), - ops.GroupConcat: _group_concat, - ops.Hash: unary(sa.func.hash), - ops.ApproxMedian: reduction(lambda x: sa.func.approx_percentile(x, 0.5)), - ops.Median: reduction(sa.func.median), - ops.TableColumn: _table_column, - ops.Levenshtein: fixed_arity(sa.func.editdistance, 2), - ops.TimeDelta: fixed_arity( - lambda part, left, right: sa.func.timediff(part, right, left), 3 - ), - ops.DateDelta: fixed_arity( - lambda part, left, right: sa.func.datediff(part, right, left), 3 - ), - ops.TimestampDelta: fixed_arity( - lambda part, left, right: sa.func.timestampdiff(part, right, left), 3 - ), - ops.TimestampBucket: _timestamp_bucket, - ops.IntegerRange: fixed_arity( - lambda start, stop, step: sa.func.iff( - step != 0, - sa.func.array_generate_range(start, stop, step), - sa.func.array_construct(), - ), - 3, - ), } ) diff --git a/ibis/backends/snowflake/tests/conftest.py b/ibis/backends/snowflake/tests/conftest.py index 4faac50223a0..507aa8524430 100644 --- a/ibis/backends/snowflake/tests/conftest.py +++ b/ibis/backends/snowflake/tests/conftest.py @@ -3,10 +3,12 @@ import concurrent.futures import os from typing import TYPE_CHECKING, Any +from urllib.parse import parse_qs, urlparse import pyarrow.parquet as pq import pyarrow_hotfix # noqa: F401 import pytest +import snowflake.connector as sc import sqlalchemy as sa import sqlglot as sg @@ -53,8 +55,8 @@ def copy_into(con, data_dir: Path, table: str) -> None: f"$1:{name}{'::VARCHAR' * typ.is_timestamp()}::{SnowflakeType.to_string(typ)}" for name, typ in schema.items() ) - con.exec_driver_sql(f"PUT {file.as_uri()} @{stage}/{file.name}") - con.exec_driver_sql( + con.execute(f"PUT {file.as_uri()} @{stage}/{file.name}") + con.execute( f""" COPY INTO {table} FROM (SELECT {columns} FROM @{stage}/{file.name}) @@ -66,7 +68,7 @@ def copy_into(con, data_dir: Path, table: str) -> None: class TestConf(BackendTest, RoundAwayFromZero): supports_map = True default_identifier_case_fn = staticmethod(str.upper) - deps = ("snowflake.connector", "snowflake.sqlalchemy") + deps = ("snowflake.connector",) supports_tpch = True def load_tpch(self) -> None: @@ -96,41 +98,61 @@ def add_catalog_and_schema(node): def _load_data(self, **_: Any) -> None: """Load test data into a Snowflake backend instance.""" - snowflake_url = _get_url() - - raw_url = sa.engine.make_url(snowflake_url) - _, schema = raw_url.database.rsplit("/", 1) - url = raw_url.set(database="") - con = sa.create_engine( - url, connect_args={"session_parameters": {"MULTI_STATEMENT_COUNT": "0"}} - ) - - dbschema = f"ibis_testing.{schema}" - - with con.begin() as c: - c.exec_driver_sql( - f"""\ -CREATE DATABASE IF NOT EXISTS ibis_testing; -USE DATABASE 
ibis_testing;
-CREATE SCHEMA IF NOT EXISTS {dbschema};
-USE SCHEMA {dbschema};
-CREATE TEMP STAGE ibis_testing;
-{self.script_dir.joinpath("snowflake.sql").read_text()}"""
-            )
-
-        with con.begin() as c:
-            # not much we can do to make this faster, but running these in
-            # multiple threads seems to save about 2x
-            with concurrent.futures.ThreadPoolExecutor() as exe:
-                for future in concurrent.futures.as_completed(
-                    exe.submit(copy_into, c, self.data_dir, table)
-                    for table in TEST_TABLES.keys()
-                ):
-                    future.result()
+        url = urlparse(_get_url())
+        _, schema = url.path[1:].split("/", 1)
+        (warehouse,) = parse_qs(url.query)["warehouse"]
+        connect_args = {
+            "user": url.username,
+            "password": url.password,
+            "account": url.hostname,
+            "warehouse": warehouse,
+        }
+
+        session_parameters = {
+            "MULTI_STATEMENT_COUNT": 0,
+            "JSON_INDENT": 0,
+            "PYTHON_CONNECTOR_QUERY_RESULT_FORMAT": "arrow_force",
+        }
+
+        con = sc.connect(**connect_args, session_parameters=session_parameters)
+
+        dbschema = f"ibis_testing.{schema}"
+
+        with con.cursor() as c:
+            c.execute(
+                f"""\
+CREATE DATABASE IF NOT EXISTS ibis_testing;
+USE DATABASE ibis_testing;
+CREATE SCHEMA IF NOT EXISTS {dbschema};
+USE SCHEMA {dbschema};
+CREATE TEMP STAGE ibis_testing;
+{self.script_dir.joinpath("snowflake.sql").read_text()}"""
+            )
+
+        with con.cursor() as c:
+            # not much we can do to make this faster, but running these in
+            # multiple threads seems to save about 2x
+            with concurrent.futures.ThreadPoolExecutor() as exe:
+                for future in concurrent.futures.as_completed(
+                    exe.submit(copy_into, c, self.data_dir, table)
+                    for table in TEST_TABLES.keys()
+                ):
+                    future.result()

     @staticmethod
     def connect(*, tmpdir, worker_id, **kw) -> BaseBackend:
-        return ibis.connect(_get_url(), **kw)
+        url = urlparse(_get_url())
+        database, schema = url.path[1:].split("/", 1)
+        (warehouse,) = parse_qs(url.query)["warehouse"]
+        connect_args = {
+            "user": url.username,
+            "password": url.password,
+            "account": url.hostname,
+            "database": database,
+            "schema": schema,
+            "warehouse": warehouse,
+        }
+        return ibis.snowflake.connect(**connect_args, **kw)


 @pytest.fixture(scope="session")
diff --git a/poetry.lock b/poetry.lock
index 519c4c006778..f72ea5c54df2 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3093,71 +3093,70 @@ files = [
 [[package]]
 name = "pandas"
-version = "2.1.3"
+version = "2.0.3"
 description = "Powerful data structures for data analysis, time series, and statistics"
 optional = false
-python-versions = ">=3.9"
+python-versions = ">=3.8"
 files = [
-    {file = "pandas-2.1.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:acf08a73b5022b479c1be155d4988b72f3020f308f7a87c527702c5f8966d34f"},
-    {file = "pandas-2.1.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3cc4469ff0cf9aa3a005870cb49ab8969942b7156e0a46cc3f5abd6b11051dfb"},
-    {file = "pandas-2.1.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35172bff95f598cc5866c047f43c7f4df2c893acd8e10e6653a4b792ed7f19bb"},
-    {file = "pandas-2.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:59dfe0e65a2f3988e940224e2a70932edc964df79f3356e5f2997c7d63e758b4"}, - {file = "pandas-2.1.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0296a66200dee556850d99b24c54c7dfa53a3264b1ca6f440e42bad424caea03"}, - {file = "pandas-2.1.3-cp310-cp310-win_amd64.whl", hash = "sha256:465571472267a2d6e00657900afadbe6097c8e1dc43746917db4dfc862e8863e"}, - {file = "pandas-2.1.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:04d4c58e1f112a74689da707be31cf689db086949c71828ef5da86727cfe3f82"}, - {file = "pandas-2.1.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7fa2ad4ff196768ae63a33f8062e6838efed3a319cf938fdf8b95e956c813042"}, - {file = "pandas-2.1.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4441ac94a2a2613e3982e502ccec3bdedefe871e8cea54b8775992485c5660ef"}, - {file = "pandas-2.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5ded6ff28abbf0ea7689f251754d3789e1edb0c4d0d91028f0b980598418a58"}, - {file = "pandas-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fca5680368a5139d4920ae3dc993eb5106d49f814ff24018b64d8850a52c6ed2"}, - {file = "pandas-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:de21e12bf1511190fc1e9ebc067f14ca09fccfb189a813b38d63211d54832f5f"}, - {file = "pandas-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a5d53c725832e5f1645e7674989f4c106e4b7249c1d57549023ed5462d73b140"}, - {file = "pandas-2.1.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7cf4cf26042476e39394f1f86868d25b265ff787c9b2f0d367280f11afbdee6d"}, - {file = "pandas-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:72c84ec1b1d8e5efcbff5312abe92bfb9d5b558f11e0cf077f5496c4f4a3c99e"}, - {file = "pandas-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f539e113739a3e0cc15176bf1231a553db0239bfa47a2c870283fd93ba4f683"}, - {file = "pandas-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:fc77309da3b55732059e484a1efc0897f6149183c522390772d3561f9bf96c00"}, - {file = "pandas-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:08637041279b8981a062899da0ef47828df52a1838204d2b3761fbd3e9fcb549"}, - {file = "pandas-2.1.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b99c4e51ef2ed98f69099c72c75ec904dd610eb41a32847c4fcbc1a975f2d2b8"}, - {file = "pandas-2.1.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f7ea8ae8004de0381a2376662c0505bb0a4f679f4c61fbfd122aa3d1b0e5f09d"}, - {file = "pandas-2.1.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fcd76d67ca2d48f56e2db45833cf9d58f548f97f61eecd3fdc74268417632b8a"}, - {file = "pandas-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1329dbe93a880a3d7893149979caa82d6ba64a25e471682637f846d9dbc10dd2"}, - {file = "pandas-2.1.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:321ecdb117bf0f16c339cc6d5c9a06063854f12d4d9bc422a84bb2ed3207380a"}, - {file = "pandas-2.1.3-cp39-cp39-win_amd64.whl", hash = "sha256:11a771450f36cebf2a4c9dbd3a19dfa8c46c4b905a3ea09dc8e556626060fe71"}, - {file = "pandas-2.1.3.tar.gz", hash = "sha256:22929f84bca106921917eb73c1521317ddd0a4c71b395bcf767a106e3494209f"}, + {file = "pandas-2.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4c7c9f27a4185304c7caf96dc7d91bc60bc162221152de697c98eb0b2648dd8"}, + {file = "pandas-2.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f167beed68918d62bffb6ec64f2e1d8a7d297a038f86d4aed056b9493fca407f"}, + {file = "pandas-2.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", 
hash = "sha256:ce0c6f76a0f1ba361551f3e6dceaff06bde7514a374aa43e33b588ec10420183"}, + {file = "pandas-2.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba619e410a21d8c387a1ea6e8a0e49bb42216474436245718d7f2e88a2f8d7c0"}, + {file = "pandas-2.0.3-cp310-cp310-win32.whl", hash = "sha256:3ef285093b4fe5058eefd756100a367f27029913760773c8bf1d2d8bebe5d210"}, + {file = "pandas-2.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:9ee1a69328d5c36c98d8e74db06f4ad518a1840e8ccb94a4ba86920986bb617e"}, + {file = "pandas-2.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b084b91d8d66ab19f5bb3256cbd5ea661848338301940e17f4492b2ce0801fe8"}, + {file = "pandas-2.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:37673e3bdf1551b95bf5d4ce372b37770f9529743d2498032439371fc7b7eb26"}, + {file = "pandas-2.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9cb1e14fdb546396b7e1b923ffaeeac24e4cedd14266c3497216dd4448e4f2d"}, + {file = "pandas-2.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9cd88488cceb7635aebb84809d087468eb33551097d600c6dad13602029c2df"}, + {file = "pandas-2.0.3-cp311-cp311-win32.whl", hash = "sha256:694888a81198786f0e164ee3a581df7d505024fbb1f15202fc7db88a71d84ebd"}, + {file = "pandas-2.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:6a21ab5c89dcbd57f78d0ae16630b090eec626360085a4148693def5452d8a6b"}, + {file = "pandas-2.0.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9e4da0d45e7f34c069fe4d522359df7d23badf83abc1d1cef398895822d11061"}, + {file = "pandas-2.0.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:32fca2ee1b0d93dd71d979726b12b61faa06aeb93cf77468776287f41ff8fdc5"}, + {file = "pandas-2.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:258d3624b3ae734490e4d63c430256e716f488c4fcb7c8e9bde2d3aa46c29089"}, + {file = "pandas-2.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eae3dc34fa1aa7772dd3fc60270d13ced7346fcbcfee017d3132ec625e23bb0"}, + {file = "pandas-2.0.3-cp38-cp38-win32.whl", hash = "sha256:f3421a7afb1a43f7e38e82e844e2bca9a6d793d66c1a7f9f0ff39a795bbc5e02"}, + {file = "pandas-2.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:69d7f3884c95da3a31ef82b7618af5710dba95bb885ffab339aad925c3e8ce78"}, + {file = "pandas-2.0.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5247fb1ba347c1261cbbf0fcfba4a3121fbb4029d95d9ef4dc45406620b25c8b"}, + {file = "pandas-2.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:81af086f4543c9d8bb128328b5d32e9986e0c84d3ee673a2ac6fb57fd14f755e"}, + {file = "pandas-2.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1994c789bf12a7c5098277fb43836ce090f1073858c10f9220998ac74f37c69b"}, + {file = "pandas-2.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ec591c48e29226bcbb316e0c1e9423622bc7a4eaf1ef7c3c9fa1a3981f89641"}, + {file = "pandas-2.0.3-cp39-cp39-win32.whl", hash = "sha256:04dbdbaf2e4d46ca8da896e1805bc04eb85caa9a82e259e8eed00254d5e0c682"}, + {file = "pandas-2.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:1168574b036cd8b93abc746171c9b4f1b83467438a5e45909fed645cf8692dbc"}, + {file = "pandas-2.0.3.tar.gz", hash = "sha256:c02f372a88e0d17f36d3093a644c73cfc1788e876a7c4bcb4020a77512e2043c"}, ] [package.dependencies] numpy = [ - {version = ">=1.23.2,<2", markers = "python_version == \"3.11\""}, - {version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""}, - {version = ">=1.22.4,<2", markers = "python_version < \"3.11\""}, + {version = ">=1.23.2", 
markers = "python_version >= \"3.11\""}, + {version = ">=1.20.3", markers = "python_version < \"3.10\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" tzdata = ">=2022.1" [package.extras] -all = ["PyQt5 (>=5.15.6)", "SQLAlchemy (>=1.4.36)", "beautifulsoup4 (>=4.11.1)", "bottleneck (>=1.3.4)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=0.8.1)", "fsspec (>=2022.05.0)", "gcsfs (>=2022.05.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.8.0)", "matplotlib (>=3.6.1)", "numba (>=0.55.2)", "numexpr (>=2.8.0)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.10)", "pandas-gbq (>=0.17.5)", "psycopg2 (>=2.9.3)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.5)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "pyxlsb (>=1.0.9)", "qtpy (>=2.2.0)", "s3fs (>=2022.05.0)", "scipy (>=1.8.1)", "tables (>=3.7.0)", "tabulate (>=0.8.10)", "xarray (>=2022.03.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.3)", "zstandard (>=0.17.0)"] -aws = ["s3fs (>=2022.05.0)"] -clipboard = ["PyQt5 (>=5.15.6)", "qtpy (>=2.2.0)"] -compression = ["zstandard (>=0.17.0)"] -computation = ["scipy (>=1.8.1)", "xarray (>=2022.03.0)"] -consortium-standard = ["dataframe-api-compat (>=0.1.7)"] -excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.0.10)", "pyxlsb (>=1.0.9)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.3)"] +all = ["PyQt5 (>=5.15.1)", "SQLAlchemy (>=1.4.16)", "beautifulsoup4 (>=4.9.3)", "bottleneck (>=1.3.2)", "brotlipy (>=0.7.0)", "fastparquet (>=0.6.3)", "fsspec (>=2021.07.0)", "gcsfs (>=2021.07.0)", "html5lib (>=1.1)", "hypothesis (>=6.34.2)", "jinja2 (>=3.0.0)", "lxml (>=4.6.3)", "matplotlib (>=3.6.1)", "numba (>=0.53.1)", "numexpr (>=2.7.3)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pandas-gbq (>=0.15.0)", "psycopg2 (>=2.8.6)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)", "python-snappy (>=0.6.0)", "pyxlsb (>=1.0.8)", "qtpy (>=2.2.0)", "s3fs (>=2021.08.0)", "scipy (>=1.7.1)", "tables (>=3.6.1)", "tabulate (>=0.8.9)", "xarray (>=0.21.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)", "zstandard (>=0.15.2)"] +aws = ["s3fs (>=2021.08.0)"] +clipboard = ["PyQt5 (>=5.15.1)", "qtpy (>=2.2.0)"] +compression = ["brotlipy (>=0.7.0)", "python-snappy (>=0.6.0)", "zstandard (>=0.15.2)"] +computation = ["scipy (>=1.7.1)", "xarray (>=0.21.0)"] +excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pyxlsb (>=1.0.8)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)"] feather = ["pyarrow (>=7.0.0)"] -fss = ["fsspec (>=2022.05.0)"] -gcp = ["gcsfs (>=2022.05.0)", "pandas-gbq (>=0.17.5)"] -hdf5 = ["tables (>=3.7.0)"] -html = ["beautifulsoup4 (>=4.11.1)", "html5lib (>=1.1)", "lxml (>=4.8.0)"] -mysql = ["SQLAlchemy (>=1.4.36)", "pymysql (>=1.0.2)"] -output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.8.10)"] +fss = ["fsspec (>=2021.07.0)"] +gcp = ["gcsfs (>=2021.07.0)", "pandas-gbq (>=0.15.0)"] +hdf5 = ["tables (>=3.6.1)"] +html = ["beautifulsoup4 (>=4.9.3)", "html5lib (>=1.1)", "lxml (>=4.6.3)"] +mysql = ["SQLAlchemy (>=1.4.16)", "pymysql (>=1.0.2)"] +output-formatting = ["jinja2 (>=3.0.0)", "tabulate (>=0.8.9)"] parquet = ["pyarrow (>=7.0.0)"] -performance = ["bottleneck (>=1.3.4)", "numba (>=0.55.2)", "numexpr (>=2.8.0)"] +performance = ["bottleneck (>=1.3.2)", "numba (>=0.53.1)", "numexpr (>=2.7.1)"] plot = ["matplotlib (>=3.6.1)"] -postgresql = ["SQLAlchemy (>=1.4.36)", "psycopg2 (>=2.9.3)"] -spss = ["pyreadstat (>=1.1.5)"] -sql-other 
= ["SQLAlchemy (>=1.4.36)"] -test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] -xml = ["lxml (>=4.8.0)"] +postgresql = ["SQLAlchemy (>=1.4.16)", "psycopg2 (>=2.8.6)"] +spss = ["pyreadstat (>=1.1.2)"] +sql-other = ["SQLAlchemy (>=1.4.16)"] +test = ["hypothesis (>=6.34.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"] +xml = ["lxml (>=4.6.3)"] [[package]] name = "parso" @@ -5204,7 +5203,9 @@ cryptography = ">=3.1.0,<42.0.0" filelock = ">=3.5,<4" idna = ">=2.5,<4" packaging = "*" +pandas = {version = ">=1.0.0,<2.1.0", optional = true, markers = "extra == \"pandas\""} platformdirs = ">=2.6.0,<4.0.0" +pyarrow = {version = "*", optional = true, markers = "extra == \"pandas\""} pyjwt = "<3.0.0" pyOpenSSL = ">=16.2.0,<24.0.0" pytz = "*" @@ -6058,4 +6059,4 @@ visualization = ["graphviz"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "eba866e109cc185bd1a41aba18a54c92c7cfde3151a81ebacf31cf3dd3255782" +content-hash = "512a78aeb0fdc60c4a897376f4957bb8e88742a3a607809cfa79c1dea83c2f16" diff --git a/pyproject.toml b/pyproject.toml index 871d02b8c97a..82e3eb4326f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,9 @@ pyspark = { version = ">=3,<3.4", optional = true } # pyspark is heavily broken regex = { version = ">=2021.7.6", optional = true } requests = { version = ">=2,<3", optional = true } shapely = { version = ">=2,<3", optional = true } -snowflake-connector-python = { version = ">=3.0.2,<4,!=3.3.0b1", optional = true } +snowflake-connector-python = { version = ">=3.0.2,<4,!=3.3.0b1", optional = true, extras = [ + "pandas", +] } sqlalchemy = { version = ">=1.4,<3", optional = true } sqlalchemy-views = { version = ">=0.3.1,<1", optional = true } trino = { version = ">=0.321,<1", optional = true, extras = ["sqlalchemy"] } diff --git a/requirements-dev.txt b/requirements-dev.txt index 2368a9510988..0b0fd9b17be7 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -107,7 +107,7 @@ numpy==1.26.2 ; python_version >= "3.9" and python_version < "4.0" oauthlib==3.2.2 ; python_version >= "3.9" and python_version < "4.0" oracledb==1.4.2 ; python_version >= "3.9" and python_version < "4.0" packaging==23.2 ; python_version >= "3.9" and python_version < "4.0" -pandas==2.1.3 ; python_version >= "3.9" and python_version < "4.0" +pandas==2.0.3 ; python_version >= "3.9" and python_version < "4.0" parso==0.8.3 ; python_version >= "3.9" and python_version < "4.0" parsy==2.1 ; python_version >= "3.9" and python_version < "4.0" partd==1.4.1 ; python_version >= "3.9" and python_version < "4.0" @@ -184,7 +184,7 @@ seaborn==0.13.0 ; python_version >= "3.10" and python_version < "3.13" setuptools==68.2.2 ; python_version >= "3.9" and python_version < "4.0" shapely==2.0.2 ; python_version >= "3.9" and python_version < "4.0" six==1.16.0 ; python_version >= "3.9" and python_version < "4.0" -snowflake-connector-python==3.5.0 ; python_version >= "3.9" and python_version < "4.0" +snowflake-connector-python[pandas]==3.5.0 ; python_version >= "3.9" and python_version < "4.0" sortedcontainers==2.4.0 ; python_version >= "3.9" and python_version < "4.0" sphobjinv==2.3.1 ; python_version >= "3.10" and python_version < "3.13" sqlalchemy-views==0.3.2 ; python_version >= "3.9" and python_version < "4.0"