From 529a3a26fcd16e186f3e491a7ded7fa767c79d69 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 18 Nov 2023 11:15:26 -0500 Subject: [PATCH] feat(snowflake): support udf arguments for reading from staged files --- .github/workflows/ibis-backends-cloud.yml | 1 - ibis/backends/snowflake/__init__.py | 65 +++++++++++++++-------- ibis/backends/snowflake/tests/conftest.py | 18 ++++++- ibis/backends/snowflake/tests/test_udf.py | 64 ++++++++++++++++++++++ 4 files changed, 122 insertions(+), 26 deletions(-) diff --git a/.github/workflows/ibis-backends-cloud.yml b/.github/workflows/ibis-backends-cloud.yml index 4ea522e8deac..e0b74c35588a 100644 --- a/.github/workflows/ibis-backends-cloud.yml +++ b/.github/workflows/ibis-backends-cloud.yml @@ -78,7 +78,6 @@ jobs: run: just download-data - uses: google-github-actions/auth@v2 - if: matrix.backend.name == 'bigquery' with: credentials_json: ${{ secrets.GCP_CREDENTIALS }} diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py index 033334a6cc45..d5a3a7ca0161 100644 --- a/ibis/backends/snowflake/__init__.py +++ b/ibis/backends/snowflake/__init__.py @@ -294,29 +294,53 @@ def _get_udf_source(self, udf_node: ops.ScalarUDF): for name, arg in zip(udf_node.argnames, udf_node.args) ) return_type = self._compile_type(udf_node.dtype) - source = textwrap.dedent(inspect.getsource(udf_node.__func__)).strip() - source = "\n".join( - line for line in source.splitlines() if not line.startswith("@udf") + lines, _ = inspect.getsourcelines(udf_node.__func__) + source = textwrap.dedent( + "".join( + itertools.dropwhile( + lambda line: not line.lstrip().startswith("def "), lines + ) + ) + ).strip() + + config = udf_node.__config__ + + preamble_lines = [*self._UDF_PREAMBLE_LINES] + + if imports := config.get("imports"): + preamble_lines.append(f"IMPORTS = ({', '.join(map(repr, imports))})") + + packages = "({})".format( + ", ".join(map(repr, ("pandas", *config.get("packages", ())))) ) + preamble_lines.append(f"PACKAGES = {packages}") + return dict( source=source, name=name, - signature=signature, - return_type=return_type, - comment=f"Generated by ibis {ibis.__version__} using Python {platform.python_version()}", - version=".".join( - map(str, min(sys.version_info[:2], self._latest_udf_python_version)) + preamble="\n".join(preamble_lines).format( + name=name, + signature=signature, + return_type=return_type, + comment=f"Generated by ibis {ibis.__version__} using Python {platform.python_version()}", + version=".".join( + map(str, min(sys.version_info[:2], self._latest_udf_python_version)) + ), ), ) + _UDF_PREAMBLE_LINES = ( + "CREATE OR REPLACE TEMPORARY FUNCTION {name}({signature})", + "RETURNS {return_type}", + "LANGUAGE PYTHON", + "IMMUTABLE", + "RUNTIME_VERSION = '{version}'", + "COMMENT = '{comment}'", + ) + def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> str: return """\ -CREATE OR REPLACE TEMPORARY FUNCTION {name}({signature}) -RETURNS {return_type} -LANGUAGE PYTHON -IMMUTABLE -RUNTIME_VERSION = '{version}' -COMMENT = '{comment}' +{preamble} HANDLER = '{name}' AS $$ from __future__ import annotations @@ -327,14 +351,8 @@ def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> str: $$""".format(**self._get_udf_source(udf_node)) def _compile_pandas_udf(self, udf_node: ops.ScalarUDF) -> str: - return """\ -CREATE OR REPLACE TEMPORARY FUNCTION {name}({signature}) -RETURNS {return_type} -LANGUAGE PYTHON -IMMUTABLE -RUNTIME_VERSION = '{version}' -COMMENT = '{comment}' -PACKAGES = ('pandas') + template = """\ +{preamble} HANDLER = 'wrapper' AS $$ from __future__ import annotations @@ -349,7 +367,8 @@ def _compile_pandas_udf(self, udf_node: ops.ScalarUDF) -> str: @_snowflake.vectorized(input=pd.DataFrame) def wrapper(df): return {name}(*(col for _, col in df.items())) -$$""".format(**self._get_udf_source(udf_node)) +$$""" + return template.format(**self._get_udf_source(udf_node)) def to_pyarrow( self, diff --git a/ibis/backends/snowflake/tests/conftest.py b/ibis/backends/snowflake/tests/conftest.py index 0ff7ce7cf586..1ff557bb8063 100644 --- a/ibis/backends/snowflake/tests/conftest.py +++ b/ibis/backends/snowflake/tests/conftest.py @@ -2,7 +2,10 @@ import concurrent.futures import os +import tempfile +from pathlib import Path from typing import TYPE_CHECKING, Any +from urllib.request import urlretrieve import pyarrow.parquet as pq import pyarrow_hotfix # noqa: F401 @@ -17,8 +20,6 @@ from ibis.formats.pyarrow import PyArrowSchema if TYPE_CHECKING: - from pathlib import Path - from ibis.backends.base import BaseBackend @@ -115,9 +116,22 @@ def _load_data(self, **_: Any) -> None: CREATE SCHEMA IF NOT EXISTS {dbschema}; USE SCHEMA {dbschema}; CREATE TEMP STAGE ibis_testing; +CREATE STAGE IF NOT EXISTS models; {self.script_dir.joinpath("snowflake.sql").read_text()}""" ) + with tempfile.TemporaryDirectory() as d: + path, _ = urlretrieve( + "https://storage.googleapis.com/ibis-testing-data/model.joblib", + os.path.join(d, "model.joblib"), + ) + + assert os.path.exists(path) + assert os.path.getsize(path) > 0 + + with con.begin() as c: + c.exec_driver_sql(f"PUT {Path(path).as_uri()} @MODELS") + with con.begin() as c: # not much we can do to make this faster, but running these in # multiple threads seems to save about 2x diff --git a/ibis/backends/snowflake/tests/test_udf.py b/ibis/backends/snowflake/tests/test_udf.py index 7310dce83b00..4ee17d757684 100644 --- a/ibis/backends/snowflake/tests/test_udf.py +++ b/ibis/backends/snowflake/tests/test_udf.py @@ -92,3 +92,67 @@ def test_builtin_agg_udf(con): expected = c.exec_driver_sql(query).cursor.fetch_pandas_all() tm.assert_frame_equal(result, expected) + + +def test_xgboost_model(con): + import ibis + from ibis import _ + + @udf.scalar.pandas( + packages=("joblib", "xgboost"), imports=("@MODELS/model.joblib",) + ) + def predict_price( + carat_scaled: float, cut_encoded: int, color_encoded: int, clarity_encoded: int + ) -> int: + import sys + + import joblib + import pandas as pd + + import_dir = sys._xoptions.get("snowflake_import_directory") + model = joblib.load(f"{import_dir}model.joblib") + df = pd.concat( + [carat_scaled, cut_encoded, color_encoded, clarity_encoded], axis=1 + ) + df.columns = ["CARAT_SCALED", "CUT_ENCODED", "COLOR_ENCODED", "CLARITY_ENCODED"] + return model.predict(df) + + def cases(value, mapping): + """This should really be a top-level function or method.""" + expr = ibis.case() + for k, v in mapping.items(): + expr = expr.when(value == k, v) + return expr.end() + + diamonds = con.tables.DIAMONDS + expr = diamonds.mutate( + predicted_price=predict_price( + (_.carat - _.carat.mean()) / _.carat.std(), + cases( + _.cut, + { + c: i + for i, c in enumerate( + ("Fair", "Good", "Very Good", "Premium", "Ideal"), start=1 + ) + }, + ), + cases(_.color, {c: i for i, c in enumerate("DEFGHIJ", start=1)}), + cases( + _.clarity, + { + c: i + for i, c in enumerate( + ("I1", "IF", "SI1", "SI2", "VS1", "VS2", "VVS1", "VVS2"), + start=1, + ) + }, + ), + ) + ) + + df = expr.execute() + + assert not df.empty + assert "predicted_price" in df.columns + assert len(df) == diamonds.count().execute()