diff --git a/.github/workflows/ibis-backends.yml b/.github/workflows/ibis-backends.yml index bc4c09caa48f..d85adf947d95 100644 --- a/.github/workflows/ibis-backends.yml +++ b/.github/workflows/ibis-backends.yml @@ -126,7 +126,7 @@ jobs: extras: - sqlite - name: datafusion - title: Datafusion + title: DataFusion extras: - datafusion - name: polars diff --git a/gen_redirects.py b/gen_redirects.py index d0fa9ee3271a..a74e5e380358 100644 --- a/gen_redirects.py +++ b/gen_redirects.py @@ -10,7 +10,8 @@ "/backends/{version}/BigQuery/": "/backends/bigquery/", "/backends/{version}/Clickhouse/": "/backends/clickhouse/", "/backends/{version}/Dask/": "/backends/dask/", - "/backends/{version}/Datafusion/": "/backends/datafusion/", + "/backends/{version}/DataFusion/": "/backends/datafusion/", + "/backends/{version}/Datafusion/": "/backends/datafusion/", # For backwards compatibility "/backends/{version}/Druid/": "/backends/druid/", "/backends/{version}/DuckDB/": "/backends/duckdb/", "/backends/{version}/Impala/": "/backends/impala/", @@ -30,7 +31,8 @@ "/docs/{version}/backends/BigQuery/": "/backends/bigquery/", "/docs/{version}/backends/Clickhouse/": "/backends/clickhouse/", "/docs/{version}/backends/Dask/": "/backends/dask/", - "/docs/{version}/backends/Datafusion/": "/backends/datafusion/", + "/docs/{version}/backends/DataFusion/": "/backends/datafusion/", + "/docs/{version}/backends/Datafusion/": "/backends/datafusion/", # For backwards compatibility "/docs/{version}/backends/Druid/": "/backends/druid/", "/docs/{version}/backends/DuckDB/": "/backends/duckdb/", "/docs/{version}/backends/Impala/": "/backends/impala/", @@ -73,7 +75,8 @@ "/backends/BigQuery/": "/backends/bigquery/", "/backends/Clickhouse/": "/backends/clickhouse/", "/backends/Dask/": "/backends/dask/", - "/backends/Datafusion/": "/backends/datafusion/", + "/backends/DataFusion/": "/backends/datafusion/", + "/backends/Datafusion/": "/backends/datafusion/", # For backwards compatibility "/backends/Druid/": 
"/backends/druid/", "/backends/DuckDB/": "/backends/duckdb/", "/backends/Impala/": "/backends/impala/", diff --git a/ibis/backends/__init__.py b/ibis/backends/__init__.py index bfc3f461adb4..b3c6a6e30107 100644 --- a/ibis/backends/__init__.py +++ b/ibis/backends/__init__.py @@ -767,6 +767,7 @@ class BaseBackend(abc.ABC, _FileIOHandler): def __init__(self, *args, **kwargs): self._con_args: tuple[Any] = args self._con_kwargs: dict[str, Any] = kwargs + self._can_reconnect: bool = True # expression cache self._query_cache = RefCountedCache( populate=self._load_into_cache, @@ -856,7 +857,10 @@ def _convert_kwargs(kwargs: MutableMapping) -> None: # TODO(kszucs): should call self.connect(*self._con_args, **self._con_kwargs) def reconnect(self) -> None: """Reconnect to the database already configured with connect.""" - self.do_connect(*self._con_args, **self._con_kwargs) + if self._can_reconnect: + self.do_connect(*self._con_args, **self._con_kwargs) + else: + raise exc.IbisError(f"Cannot reconnect to unconfigured {self.name} backend") def do_connect(self, *args, **kwargs) -> None: """Connect to database specified by `args` and `kwargs`.""" diff --git a/ibis/backends/bigquery/__init__.py b/ibis/backends/bigquery/__init__.py index 907e74d3b76c..beaf10490db6 100644 --- a/ibis/backends/bigquery/__init__.py +++ b/ibis/backends/bigquery/__init__.py @@ -478,6 +478,37 @@ def do_connect( self.partition_column = partition_column + @util.experimental + @classmethod + def from_connection( + cls, + client: bq.Client, + partition_column: str | None = "PARTITIONTIME", + storage_client: bqstorage.BigQueryReadClient | None = None, + dataset_id: str = "", + ) -> Backend: + """Create a BigQuery `Backend` from an existing ``Client``. + + Parameters + ---------- + client + A `Client` from the `google.cloud.bigquery` package. + partition_column + Identifier to use instead of default `_PARTITIONTIME` partition + column. Defaults to `'PARTITIONTIME'`. 
+ storage_client + A `BigQueryReadClient` from the `google.cloud.bigquery_storage_v1` + package. + dataset_id + A dataset id that lives inside of the project attached to `client`. + """ + return ibis.bigquery.connect( + client=client, + partition_column=partition_column, + storage_client=storage_client, + dataset_id=dataset_id, + ) + def disconnect(self) -> None: self.client.close() diff --git a/ibis/backends/clickhouse/__init__.py b/ibis/backends/clickhouse/__init__.py index b934708d2489..ec7ce922359d 100644 --- a/ibis/backends/clickhouse/__init__.py +++ b/ibis/backends/clickhouse/__init__.py @@ -163,6 +163,21 @@ def do_connect( **kwargs, ) + @util.experimental + @classmethod + def from_connection(cls, con: cc.driver.Client) -> Backend: + """Create an Ibis client from an existing ClickHouse Connect Client instance. + + Parameters + ---------- + con + An existing ClickHouse Connect Client instance. + """ + new_backend = cls() + new_backend._can_reconnect = False + new_backend.con = con + return new_backend + @property def version(self) -> str: return self.con.server_version diff --git a/ibis/backends/datafusion/__init__.py b/ibis/backends/datafusion/__init__.py index 1e06ead4b17e..f88956337f36 100644 --- a/ibis/backends/datafusion/__init__.py +++ b/ibis/backends/datafusion/__init__.py @@ -14,11 +14,13 @@ import sqlglot as sg import sqlglot.expressions as sge +import ibis import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.schema as sch import ibis.expr.types as ir +from ibis import util from ibis.backends import CanCreateCatalog, CanCreateDatabase, CanCreateSchema, NoUrl from ibis.backends.sql import SQLBackend from ibis.backends.sql.compilers import DataFusionCompiler @@ -77,12 +79,13 @@ def version(self): def do_connect( self, config: Mapping[str, str | Path] | SessionContext | None = None ) -> None: - """Create a Datafusion backend for use with Ibis. 
+ """Create a DataFusion `Backend` for use with Ibis. Parameters ---------- config - Mapping of table names to files. + Mapping of table names to files or a `SessionContext` + instance. Examples -------- @@ -112,6 +115,18 @@ def do_connect( for name, path in config.items(): self.register(path, table_name=name) + @util.experimental + @classmethod + def from_connection(cls, con: SessionContext) -> Backend: + """Create a DataFusion `Backend` from an existing `SessionContext` instance. + + Parameters + ---------- + con + A `SessionContext` instance. + """ + return ibis.datafusion.connect(con) + def disconnect(self) -> None: pass @@ -329,7 +344,7 @@ def register( table_name The name of the table kwargs - Datafusion-specific keyword arguments + DataFusion-specific keyword arguments Examples -------- @@ -423,7 +438,7 @@ def read_csv( An optional name to use for the created table. This defaults to a sequentially generated name. **kwargs - Additional keyword arguments passed to Datafusion loading function. + Additional keyword arguments passed to DataFusion loading function. Returns ------- @@ -451,7 +466,7 @@ def read_parquet( An optional name to use for the created table. This defaults to a sequentially generated name. **kwargs - Additional keyword arguments passed to Datafusion loading function. + Additional keyword arguments passed to DataFusion loading function. Returns ------- @@ -576,7 +591,7 @@ def create_table( temp: bool = False, overwrite: bool = False, ): - """Create a table in Datafusion. + """Create a table in DataFusion. Parameters ---------- @@ -697,7 +712,7 @@ def truncate_table( def _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite): """Workaround inability to overwrite tables in dataframe API. - Datafusion has helper methods for loading in-memory data, but these methods + DataFusion has helper methods for loading in-memory data, but these methods don't allow overwriting tables. 
The SQL interface allows creating tables from existing tables, so we register the data as a table using the dataframe API, then run a diff --git a/ibis/backends/druid/__init__.py b/ibis/backends/druid/__init__.py index 4c737e001ded..1b794279e105 100644 --- a/ibis/backends/druid/__init__.py +++ b/ibis/backends/druid/__init__.py @@ -12,6 +12,7 @@ import ibis.expr.datatypes as dt import ibis.expr.schema as sch +from ibis import util from ibis.backends.sql import SQLBackend from ibis.backends.sql.compilers import DruidCompiler from ibis.backends.sql.compilers.base import STAR @@ -81,6 +82,21 @@ def do_connect(self, **kwargs: Any) -> None: header = kwargs.pop("header", True) self.con = pydruid.db.connect(**kwargs, header=header) + @util.experimental + @classmethod + def from_connection(cls, con: pydruid.db.api.Connection) -> Backend: + """Create an Ibis client from an existing connection to a Druid database. + + Parameters + ---------- + con + An existing connection to a Druid database. + """ + new_backend = cls() + new_backend._can_reconnect = False + new_backend.con = con + return new_backend + @contextlib.contextmanager def _safe_raw_sql(self, query, *args, **kwargs): with contextlib.suppress(AttributeError): diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 8f23c3bce534..747dba0d25a8 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -473,6 +473,31 @@ def do_connect( self.con = duckdb.connect(str(database), config=config, read_only=read_only) + self._post_connect(extensions) + + @util.experimental + @classmethod + def from_connection( + cls, + con: duckdb.DuckDBPyConnection, + extensions: Sequence[str] | None = None, + ) -> Backend: + """Create an Ibis client from an existing connection to a DuckDB database. + + Parameters + ---------- + con + An existing connection to a DuckDB database. + extensions + A list of duckdb extensions to install/load upon connection. 
+ """ + new_backend = cls(extensions=extensions) + new_backend._can_reconnect = False + new_backend.con = con + new_backend._post_connect(extensions) + return new_backend + + def _post_connect(self, extensions: Sequence[str] | None = None) -> None: # Load any pre-specified extensions if extensions is not None: self._load_extensions(extensions) diff --git a/ibis/backends/exasol/__init__.py b/ibis/backends/exasol/__init__.py index e7c91383e2ca..5b1acd584718 100644 --- a/ibis/backends/exasol/__init__.py +++ b/ibis/backends/exasol/__init__.py @@ -97,6 +97,32 @@ def do_connect( quote_ident=True, **kwargs, ) + self._post_connect(timezone) + + @util.experimental + @classmethod + def from_connection( + cls, con: pyexasol.ExaConnection, timezone: str | None = None + ) -> Backend: + """Create an Ibis client from an existing connection to an Exasol database. + + Parameters + ---------- + con + An existing connection to an Exasol database. + timezone + The session timezone. + """ + if timezone is None: + timezone = (con.execute("SELECT SESSIONTIMEZONE").fetchone() or ("UTC",))[0] + + new_backend = cls(timezone=timezone) + new_backend._can_reconnect = False + new_backend.con = con + new_backend._post_connect(timezone) + return new_backend + + def _post_connect(self, timezone: str = "UTC") -> None: with self.begin() as con: con.execute(f"ALTER SESSION SET TIME_ZONE = {timezone!r}") diff --git a/ibis/backends/flink/__init__.py b/ibis/backends/flink/__init__.py index 44e4cc08f6c7..8445e5250161 100644 --- a/ibis/backends/flink/__init__.py +++ b/ibis/backends/flink/__init__.py @@ -11,6 +11,7 @@ import ibis.expr.operations as ops import ibis.expr.schema as sch import ibis.expr.types as ir +from ibis import util from ibis.backends import CanCreateDatabase, NoUrl from ibis.backends.flink.ddl import ( CreateDatabase, @@ -71,6 +72,18 @@ def do_connect(self, table_env: TableEnvironment) -> None: """ self._table_env = table_env + @util.experimental + @classmethod + def from_connection(cls, 
table_env: TableEnvironment) -> Backend: + """Create a Flink `Backend` from an existing table environment. + + Parameters + ---------- + table_env + A table environment. + """ + return ibis.flink.connect(table_env) + def disconnect(self) -> None: pass diff --git a/ibis/backends/impala/__init__.py b/ibis/backends/impala/__init__.py index 44d9759d2501..93122250c5ec 100644 --- a/ibis/backends/impala/__init__.py +++ b/ibis/backends/impala/__init__.py @@ -45,6 +45,7 @@ from pathlib import Path from urllib.parse import ParseResult + import impala.hiveserver2 as hs2 import pandas as pd import polars as pl import pyarrow as pa @@ -183,6 +184,25 @@ def do_connect( cur.ping() self.con = con + self._post_connect() + + @util.experimental + @classmethod + def from_connection(cls, con: hs2.HiveServer2Connection) -> Backend: + """Create an Impala `Backend` from an existing HS2 connection. + + Parameters + ---------- + con + An existing connection to HiveServer2 (HS2). + """ + new_backend = cls() + new_backend._can_reconnect = False + new_backend.con = con + new_backend._post_connect() + return new_backend + + def _post_connect(self) -> None: self.options = {} @cached_property diff --git a/ibis/backends/mssql/__init__.py b/ibis/backends/mssql/__init__.py index a259850767a0..1af80abf145e 100644 --- a/ibis/backends/mssql/__init__.py +++ b/ibis/backends/mssql/__init__.py @@ -112,7 +112,7 @@ def do_connect( if user is None and password is None: kwargs.setdefault("Trusted_Connection", "yes") - con = pyodbc.connect( + self.con = pyodbc.connect( user=user, server=f"{host},{port}", password=password, @@ -121,14 +121,31 @@ def do_connect( **kwargs, ) + self._post_connect() + + @util.experimental + @classmethod + def from_connection(cls, con: pyodbc.Connection) -> Backend: + """Create an Ibis client from an existing connection to a MSSQL database. + + Parameters + ---------- + con + An existing connection to a MSSQL database. 
+ """ + new_backend = cls() + new_backend._can_reconnect = False + new_backend.con = con + new_backend._post_connect() + return new_backend + + def _post_connect(self): # -155 is the code for datetimeoffset - con.add_output_converter(-155, datetimeoffset_to_datetime) + self.con.add_output_converter(-155, datetimeoffset_to_datetime) - with closing(con.cursor()) as cur: + with closing(self.con.cursor()) as cur: cur.execute("SET DATEFIRST 1") - self.con = con - def get_schema( self, name: str, *, catalog: str | None = None, database: str | None = None ) -> sch.Schema: diff --git a/ibis/backends/mysql/__init__.py b/ibis/backends/mysql/__init__.py index 0fbd6eb48ad0..f2f5067c6bdf 100644 --- a/ibis/backends/mysql/__init__.py +++ b/ibis/backends/mysql/__init__.py @@ -151,7 +151,7 @@ def do_connect( month : int32 """ - con = pymysql.connect( + self.con = pymysql.connect( user=user, host=host, port=port, @@ -162,14 +162,31 @@ def do_connect( **kwargs, ) - with contextlib.closing(con.cursor()) as cur: + self._post_connect() + + @util.experimental + @classmethod + def from_connection(cls, con: pymysql.Connection) -> Backend: + """Create an Ibis client from an existing connection to a MySQL database. + + Parameters + ---------- + con + An existing connection to a MySQL database. 
+ """ + new_backend = cls() + new_backend._can_reconnect = False + new_backend.con = con + new_backend._post_connect() + return new_backend + + def _post_connect(self) -> None: + with contextlib.closing(self.con.cursor()) as cur: try: cur.execute("SET @@session.time_zone = 'UTC'") except Exception as e: # noqa: BLE001 warnings.warn(f"Unable to set session timezone to UTC: {e}") - self.con = con - @property def current_database(self) -> str: with self._safe_raw_sql(sg.select(self.compiler.f.database())) as cur: diff --git a/ibis/backends/oracle/__init__.py b/ibis/backends/oracle/__init__.py index f3f4e22fb62d..2339859f0000 100644 --- a/ibis/backends/oracle/__init__.py +++ b/ibis/backends/oracle/__init__.py @@ -154,6 +154,25 @@ def do_connect( # https://python-oracledb.readthedocs.io/en/latest/user_guide/appendix_b.html#statement-caching-in-thin-and-thick-modes self.con = oracledb.connect(dsn, user=user, password=password, stmtcachesize=0) + self._post_connect() + + @util.experimental + @classmethod + def from_connection(cls, con: oracledb.Connection) -> Backend: + """Create an Ibis client from an existing connection to an Oracle database. + + Parameters + ---------- + con + An existing connection to an Oracle database. 
+ """ + new_backend = cls() + new_backend._can_reconnect = False + new_backend.con = con + new_backend._post_connect() + return new_backend + + def _post_connect(self) -> None: # turn on autocommit # TODO: it would be great if this worked but it doesn't seem to do the trick # I had to hack in the commit lines to the compiler diff --git a/ibis/backends/postgres/__init__.py b/ibis/backends/postgres/__init__.py index 39797ca982d6..e1963503b986 100644 --- a/ibis/backends/postgres/__init__.py +++ b/ibis/backends/postgres/__init__.py @@ -37,6 +37,7 @@ import pandas as pd import polars as pl + import psycopg2 import pyarrow as pa @@ -280,6 +281,25 @@ def do_connect( **kwargs, ) + self._post_connect() + + @util.experimental + @classmethod + def from_connection(cls, con: psycopg2.extensions.connection) -> Backend: + """Create an Ibis client from an existing connection to a PostgreSQL database. + + Parameters + ---------- + con + An existing connection to a PostgreSQL database. + """ + new_backend = cls() + new_backend._can_reconnect = False + new_backend.con = con + new_backend._post_connect() + return new_backend + + def _post_connect(self) -> None: with self.begin() as cur: cur.execute("SET TIMEZONE = UTC") diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py index 4f8bab7e5fb7..13dcc0264a2e 100644 --- a/ibis/backends/pyspark/__init__.py +++ b/ibis/backends/pyspark/__init__.py @@ -144,7 +144,7 @@ def do_connect( Parameters ---------- session - A SparkSession instance + A `SparkSession` instance. mode Can be either "batch" or "streaming". 
If "batch", every source, sink, and query executed within this connection will be interpreted as a batch @@ -185,6 +185,27 @@ def do_connect( for key, value in kwargs.items(): self._session.conf.set(key, value) + @util.experimental + @classmethod + def from_connection( + cls, session: SparkSession, mode: ConnectionMode = "batch", **kwargs + ) -> Backend: + """Create a PySpark `Backend` from an existing `SparkSession` instance. + + Parameters + ---------- + session + A `SparkSession` instance. + mode + Can be either "batch" or "streaming". If "batch", every source, sink, and + query executed within this connection will be interpreted as a batch + workload. If "streaming", every source, sink, and query executed within + this connection will be interpreted as a streaming workload. + kwargs + Additional keyword arguments used to configure the SparkSession. + """ + return ibis.pyspark.connect(session, mode, **kwargs) + def disconnect(self) -> None: self._session.stop() diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py index ea4b64407ab5..6a03a8eb4e49 100644 --- a/ibis/backends/snowflake/__init__.py +++ b/ibis/backends/snowflake/__init__.py @@ -44,6 +44,8 @@ import pandas as pd import polars as pl + import snowflake.connector + import snowflake.snowpark _SNOWFLAKE_MAP_UDFS = { @@ -83,7 +85,7 @@ class Backend(SQLBackend, CanCreateCatalog, CanCreateDatabase, CanCreateSchema): supports_python_udfs = True _latest_udf_python_version = (3, 10) - _top_level_methods = ("from_snowpark",) + _top_level_methods = ("from_connection", "from_snowpark") def __init__(self, *args, _from_snowpark: bool = False, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -197,7 +199,7 @@ def do_connect(self, create_object_udfs: bool = True, **kwargs: Any): `ibis.snowflake.connect(...)` can succeed, while subsequent API calls fail if the authentication fails for any reason. 
create_object_udfs - Enable object UDF extensions defined by ibis on the first + Enable object UDF extensions defined by Ibis on the first connection to the database. kwargs Additional arguments passed to the DBAPI connection call. @@ -274,7 +276,9 @@ def _setup_session(self, *, session_parameters, create_object_udfs: bool): @util.experimental @classmethod - def from_snowpark(cls, session, *, create_object_udfs: bool = True) -> Backend: + def from_snowpark( + cls, session: snowflake.snowpark.Session, *, create_object_udfs: bool = True + ) -> Backend: """Create an Ibis Snowflake backend from a Snowpark session. Parameters @@ -282,7 +286,7 @@ def from_snowpark(cls, session, *, create_object_udfs: bool = True) -> Backend: session A Snowpark session instance. create_object_udfs - Enable object UDF extensions defined by ibis on the first + Enable object UDF extensions defined by Ibis on the first connection to the database. Returns @@ -322,6 +326,67 @@ def from_snowpark(cls, session, *, create_object_udfs: bool = True) -> Backend: ) return backend + @util.experimental + @classmethod + def from_connection( + cls, + con: snowflake.connector.SnowflakeConnection | snowflake.snowpark.Session, + *, + create_object_udfs: bool = True, + ) -> Backend: + """Create an Ibis Snowflake backend from an existing connection. + + Parameters + ---------- + con + A Snowflake Connector for Python connection or a Snowpark + session instance. + create_object_udfs + Enable object UDF extensions defined by Ibis on the first + connection to the database. + + Returns + ------- + Backend + An Ibis Snowflake backend instance. 
+ + Examples + -------- + >>> import ibis + >>> ibis.options.interactive = True + >>> import snowflake.snowpark as sp # doctest: +SKIP + >>> session = sp.Session.builder.configs(...).create() # doctest: +SKIP + >>> con = ibis.snowflake.from_connection(session) # doctest: +SKIP + >>> batting = con.tables.BATTING # doctest: +SKIP + >>> batting[["playerID", "RBI"]].head() # doctest: +SKIP + ┏━━━━━━━━━━━┳━━━━━━━┓ + ┃ playerID ┃ RBI ┃ + ┡━━━━━━━━━━━╇━━━━━━━┩ + │ string │ int64 │ + ├───────────┼───────┤ + │ abercda01 │ 0 │ + │ addybo01 │ 13 │ + │ allisar01 │ 19 │ + │ allisdo01 │ 27 │ + │ ansonca01 │ 16 │ + └───────────┴───────┘ + """ + import snowflake.connector + + new_backend = cls() + new_backend._can_reconnect = False + new_backend.con = ( + con + if isinstance(con, snowflake.connector.SnowflakeConnection) + else con._conn._conn + ) + with contextlib.suppress(snowflake.connector.errors.ProgrammingError): + # stored procs on snowflake don't allow session mutation it seems + new_backend._setup_session( + session_parameters={}, create_object_udfs=create_object_udfs + ) + return new_backend + def reconnect(self) -> None: if self._from_snowpark: raise com.IbisError( diff --git a/ibis/backends/sql/__init__.py b/ibis/backends/sql/__init__.py index a8356a694547..27f81066f350 100644 --- a/ibis/backends/sql/__init__.py +++ b/ibis/backends/sql/__init__.py @@ -70,6 +70,8 @@ class SQLBackend(BaseBackend, _DatabaseSchemaHandler): compiler: ClassVar[SQLGlotCompiler] name: ClassVar[str] + _top_level_methods = ("from_connection",) + @property def dialect(self) -> sg.Dialect: return self.compiler.dialect @@ -547,6 +549,22 @@ def truncate_table( with self._safe_raw_sql(f"TRUNCATE TABLE {ident}"): pass + @util.experimental + @classmethod + def from_connection(cls, con: Any, **kwargs: Any) -> BaseBackend: + """Create an Ibis client from an existing connection. + + Parameters + ---------- + con + An existing connection. 
+ **kwargs + Extra arguments to be applied to the newly-created backend. + """ + raise NotImplementedError( + f"{cls.name} backend cannot be constructed from an existing connection" + ) + def disconnect(self): # This is part of the Python DB-API specification so should work for # _most_ sqlglot backends diff --git a/ibis/backends/sqlite/__init__.py b/ibis/backends/sqlite/__init__.py index fddcc0ee5543..09980edd52f2 100644 --- a/ibis/backends/sqlite/__init__.py +++ b/ibis/backends/sqlite/__init__.py @@ -73,7 +73,7 @@ def do_connect( files type_map An optional mapping from a string name of a SQLite "type" to the - corresponding ibis DataType that it represents. This can be used + corresponding Ibis DataType that it represents. This can be used to override schema inference for a given SQLite database. Examples @@ -84,13 +84,42 @@ def do_connect( """ _init_sqlite3() + self.con = sqlite3.connect(":memory:" if database is None else database) + + self._post_connect(type_map) + + @util.experimental + @classmethod + def from_connection( + cls, + con: sqlite3.Connection, + type_map: dict[str, str | dt.DataType] | None = None, + ) -> Backend: + """Create an Ibis client from an existing connection to a SQLite database. + + Parameters + ---------- + con + An existing connection to a SQLite database. + type_map + An optional mapping from a string name of a SQLite "type" to the + corresponding Ibis DataType that it represents. This can be used + to override schema inference for a given SQLite database. 
+ """ + new_backend = cls(type_map=type_map) + new_backend._can_reconnect = False + new_backend.con = con + new_backend._post_connect(type_map) + return new_backend + + def _post_connect( + self, type_map: dict[str, str | dt.DataType] | None = None + ) -> None: if type_map: self._type_map = {k.lower(): ibis.dtype(v) for k, v in type_map.items()} else: self._type_map = {} - self.con = sqlite3.connect(":memory:" if database is None else database) - register_all(self.con) self.con.execute("PRAGMA case_sensitive_like=ON") diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 6c1398c4cbc6..ccb01e1cc87c 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -1614,3 +1614,17 @@ def test_insert_using_col_name_not_position(con, first_row, second_row, monkeypa # Ideally we'd use a temp table for this test, but several backends don't # support them and it's nice to know that data are being inserted correctly. con.drop_table(table_name) + + +CON_ATTR = {"bigquery": "client", "flink": "_table_env", "pyspark": "_session"} +DEFAULT_CON_ATTR = "con" + + +@pytest.mark.parametrize("top_level", [True, False]) +@pytest.mark.never(["dask", "pandas", "polars"], reason="don't have connection concept") +def test_from_connection(con, top_level): + backend = getattr(ibis, con.name) if top_level else type(con) + con_attr = CON_ATTR.get(con.name, DEFAULT_CON_ATTR) + new_con = backend.from_connection(getattr(con, con_attr)) + result = int(new_con.execute(ibis.literal(1, type="int"))) + assert result == 1 diff --git a/ibis/backends/trino/__init__.py b/ibis/backends/trino/__init__.py index 38d1e8171253..579755926001 100644 --- a/ibis/backends/trino/__init__.py +++ b/ibis/backends/trino/__init__.py @@ -325,6 +325,21 @@ def do_connect( **kwargs, ) + @util.experimental + @classmethod + def from_connection(cls, con: trino.dbapi.Connection) -> Backend: + """Create an Ibis client from an existing connection to a Trino database. 
+ + Parameters + ---------- + con + An existing connection to a Trino database. + """ + new_backend = cls() + new_backend._can_reconnect = False + new_backend.con = con + return new_backend + def _get_schema_using_query(self, query: str) -> sch.Schema: name = util.gen_name(f"{self.name}_metadata") with self.begin() as cur: