diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py index daae44fa106e..44a3ff2b6d67 100644 --- a/ibis/backends/snowflake/__init__.py +++ b/ibis/backends/snowflake/__init__.py @@ -222,8 +222,8 @@ def do_connect(self, create_object_udfs: bool = True, **kwargs: Any): create_object_udfs=create_object_udfs, ) - def _setup_session(self, *, con, session_parameters, create_object_udfs: bool): - self.con = con + def _setup_session(self, *, session_parameters, create_object_udfs: bool): + con = self.con # enable multiple SQL statements by default session_parameters.setdefault("MULTI_STATEMENT_COUNT", 0) @@ -318,12 +318,15 @@ def from_snowpark(cls, session, *, create_object_udfs: bool = True) -> Backend: │ ansonca01 │ 16 │ └───────────┴───────┘ """ + import snowflake.connector + backend = cls(_from_snowpark=True) - backend._setup_session( - con=session._conn._conn, - session_parameters={}, - create_object_udfs=create_object_udfs, - ) + backend.con = session._conn._conn + with contextlib.suppress(snowflake.connector.errors.ProgrammingError): + # stored procs on snowflake don't allow session mutation it seems + backend._setup_session( + session_parameters={}, create_object_udfs=create_object_udfs + ) return backend def reconnect(self) -> None: diff --git a/ibis/backends/snowflake/tests/test_udf.py b/ibis/backends/snowflake/tests/test_udf.py index 45e2b21ce973..87696fc95788 100644 --- a/ibis/backends/snowflake/tests/test_udf.py +++ b/ibis/backends/snowflake/tests/test_udf.py @@ -1,11 +1,17 @@ from __future__ import annotations +import os +import shutil +import tempfile + import pandas.testing as tm import pytest from pytest import param +import ibis import ibis.expr.datatypes as dt from ibis import udf +from ibis.backends.tests.errors import SnowflakeProgrammingError @udf.scalar.builtin @@ -89,7 +95,6 @@ def test_builtin_agg_udf(con): def test_xgboost_model(con): - import ibis from ibis import _ @udf.scalar.pandas( @@ -150,3 +155,135 @@ def cases(value, mapping): assert not df.empty assert "predicted_price" in df.columns assert len(df) == diamonds.count().execute() + + +def touch_package(pkgpath): + # nix timestamps every file to the UNIX epoch for reproducibility + # so we modify the utime of the _copied_ code since snowflake has + # some annoying checks for zipping files that do not allow files + # older than 1980 ¯\_(ツ)_/¯ + os.utime(pkgpath, None) + for root, dirs, files in os.walk(pkgpath): + for path in dirs + files: + os.utime(os.path.join(root, path), None) + + +def add_packages(d, session): + import parsy + import pyarrow_hotfix + import rich + import sqlglot + + import ibis + + for module in (rich, parsy, sqlglot, pyarrow_hotfix): + pkgname = module.__name__ + pkgpath = os.path.join(d, pkgname) + shutil.copytree(os.path.dirname(module.__file__), pkgpath) + touch_package(pkgpath) + session.add_import(pkgname, import_path=pkgname) + + # no need to touch the package because we're using the local version + shutil.copytree(os.path.dirname(ibis.__file__), "ibis") + session.add_import("ibis", import_path="ibis") + + +@pytest.fixture +def snowpark_session(): + if not os.environ.get("SNOWFLAKE_SNOWPARK"): + pytest.skip("SNOWFLAKE_SNOWPARK is not set") + else: + sp = pytest.importorskip("snowflake.snowpark") + + if connection_name := os.environ.get("SNOWFLAKE_DEFAULT_CONNECTION_NAME"): + builder = sp.Session.builder.config("connection_name", connection_name) + else: + builder = sp.Session.builder.configs( + { + "user": os.environ["SNOWFLAKE_USER"], + "account": os.environ["SNOWFLAKE_ACCOUNT"], + "password": os.environ["SNOWFLAKE_PASSWORD"], + "warehouse": os.environ["SNOWFLAKE_WAREHOUSE"], + "database": os.environ["SNOWFLAKE_DATABASE"], + "schema": os.environ["SNOWFLAKE_SCHEMA"], + } + ) + + session = builder.create() + session.custom_package_usage_config["enabled"] = True + + pwd = os.getcwd() + + with tempfile.TemporaryDirectory() as d: + os.chdir(d) + + try: + add_packages(d, session) + yield session + finally: + os.chdir(pwd) + session.clear_imports() + + +@pytest.mark.parametrize( + "execute_as", + [ + "owner", + param("caller", marks=[pytest.mark.xfail(raises=SnowflakeProgrammingError)]), + ], +) +def test_ibis_inside_snowpark(snowpark_session, execute_as): + import snowflake.snowpark as sp + + def ibis_sproc(session): + import ibis.backends.snowflake + + con = ibis.backends.snowflake.Backend.from_snowpark(session) + + expr = ( + con.tables.functional_alltypes.group_by("string_col") + .agg(n=lambda t: t.count()) + .order_by("string_col") + ) + + return session.sql(ibis.to_sql(expr)) + + expected = ( + snowpark_session.table('"functional_alltypes"') + .group_by('"string_col"') + .count() + .rename("COUNT", '"n"') + .order_by('"string_col"') + .to_pandas() + ) + + local_result = ibis_sproc(snowpark_session).to_pandas() + + tm.assert_frame_equal(local_result, expected) + + name = ibis_sproc.__name__ + + snowpark_session.sproc.register( + ibis_sproc, + name=name, + execute_as=execute_as, + imports=["parsy", "rich", "sqlglot", "pyarrow_hotfix", "ibis"], + # empty struct here tells Snowflake to infer the return type from the + # return value of the function, which is required to be a Snowpark + # table in that case + return_type=sp.types.StructType(), + packages=[ + "snowflake-snowpark-python", + "toolz", + "atpublic", + "bidict", + "pyarrow", + "pandas", + "numpy", + ], + replace=True, + ) + + remote_result = snowpark_session.call(name).to_pandas() + + tm.assert_frame_equal(remote_result, local_result)