Skip to content

Commit

Permalink
feat(snowflake): support udf arguments for reading from staged files
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and gforsyth committed Dec 8, 2023
1 parent 45ee391 commit 529a3a2
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 26 deletions.
1 change: 0 additions & 1 deletion .github/workflows/ibis-backends-cloud.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ jobs:
run: just download-data

- uses: google-github-actions/auth@v2
if: matrix.backend.name == 'bigquery'
with:
credentials_json: ${{ secrets.GCP_CREDENTIALS }}

Expand Down
65 changes: 42 additions & 23 deletions ibis/backends/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,29 +294,53 @@ def _get_udf_source(self, udf_node: ops.ScalarUDF):
for name, arg in zip(udf_node.argnames, udf_node.args)
)
return_type = self._compile_type(udf_node.dtype)
source = textwrap.dedent(inspect.getsource(udf_node.__func__)).strip()
source = "\n".join(
line for line in source.splitlines() if not line.startswith("@udf")
lines, _ = inspect.getsourcelines(udf_node.__func__)
source = textwrap.dedent(
"".join(
itertools.dropwhile(
lambda line: not line.lstrip().startswith("def "), lines
)
)
).strip()

config = udf_node.__config__

preamble_lines = [*self._UDF_PREAMBLE_LINES]

if imports := config.get("imports"):
preamble_lines.append(f"IMPORTS = ({', '.join(map(repr, imports))})")

packages = "({})".format(
", ".join(map(repr, ("pandas", *config.get("packages", ()))))
)
preamble_lines.append(f"PACKAGES = {packages}")

return dict(
source=source,
name=name,
signature=signature,
return_type=return_type,
comment=f"Generated by ibis {ibis.__version__} using Python {platform.python_version()}",
version=".".join(
map(str, min(sys.version_info[:2], self._latest_udf_python_version))
preamble="\n".join(preamble_lines).format(
name=name,
signature=signature,
return_type=return_type,
comment=f"Generated by ibis {ibis.__version__} using Python {platform.python_version()}",
version=".".join(
map(str, min(sys.version_info[:2], self._latest_udf_python_version))
),
),
)

_UDF_PREAMBLE_LINES = (
"CREATE OR REPLACE TEMPORARY FUNCTION {name}({signature})",
"RETURNS {return_type}",
"LANGUAGE PYTHON",
"IMMUTABLE",
"RUNTIME_VERSION = '{version}'",
"COMMENT = '{comment}'",
)

def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> str:
return """\
CREATE OR REPLACE TEMPORARY FUNCTION {name}({signature})
RETURNS {return_type}
LANGUAGE PYTHON
IMMUTABLE
RUNTIME_VERSION = '{version}'
COMMENT = '{comment}'
{preamble}
HANDLER = '{name}'
AS $$
from __future__ import annotations
Expand All @@ -327,14 +351,8 @@ def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> str:
$$""".format(**self._get_udf_source(udf_node))

def _compile_pandas_udf(self, udf_node: ops.ScalarUDF) -> str:
return """\
CREATE OR REPLACE TEMPORARY FUNCTION {name}({signature})
RETURNS {return_type}
LANGUAGE PYTHON
IMMUTABLE
RUNTIME_VERSION = '{version}'
COMMENT = '{comment}'
PACKAGES = ('pandas')
template = """\
{preamble}
HANDLER = 'wrapper'
AS $$
from __future__ import annotations
Expand All @@ -349,7 +367,8 @@ def _compile_pandas_udf(self, udf_node: ops.ScalarUDF) -> str:
@_snowflake.vectorized(input=pd.DataFrame)
def wrapper(df):
return {name}(*(col for _, col in df.items()))
$$""".format(**self._get_udf_source(udf_node))
$$"""
return template.format(**self._get_udf_source(udf_node))

def to_pyarrow(
self,
Expand Down
18 changes: 16 additions & 2 deletions ibis/backends/snowflake/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

import concurrent.futures
import os
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Any
from urllib.request import urlretrieve

import pyarrow.parquet as pq
import pyarrow_hotfix # noqa: F401
Expand All @@ -17,8 +20,6 @@
from ibis.formats.pyarrow import PyArrowSchema

if TYPE_CHECKING:
from pathlib import Path

from ibis.backends.base import BaseBackend


Expand Down Expand Up @@ -115,9 +116,22 @@ def _load_data(self, **_: Any) -> None:
CREATE SCHEMA IF NOT EXISTS {dbschema};
USE SCHEMA {dbschema};
CREATE TEMP STAGE ibis_testing;
CREATE STAGE IF NOT EXISTS models;
{self.script_dir.joinpath("snowflake.sql").read_text()}"""
)

with tempfile.TemporaryDirectory() as d:
path, _ = urlretrieve(
"https://storage.googleapis.com/ibis-testing-data/model.joblib",
os.path.join(d, "model.joblib"),
)

assert os.path.exists(path)
assert os.path.getsize(path) > 0

with con.begin() as c:
c.exec_driver_sql(f"PUT {Path(path).as_uri()} @MODELS")

with con.begin() as c:
# not much we can do to make this faster, but running these in
# multiple threads seems to save about 2x
Expand Down
64 changes: 64 additions & 0 deletions ibis/backends/snowflake/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,67 @@ def test_builtin_agg_udf(con):
expected = c.exec_driver_sql(query).cursor.fetch_pandas_all()

tm.assert_frame_equal(result, expected)


def test_xgboost_model(con):
    """End-to-end test: score rows with an XGBoost model loaded from a Snowflake stage.

    Defines a vectorized pandas UDF whose environment pulls in ``joblib`` and
    ``xgboost`` packages and imports the staged file ``@MODELS/model.joblib``
    (presumably uploaded by the test fixtures — confirm against conftest), then
    applies it to the DIAMONDS table and sanity-checks the result frame.
    """
    import ibis
    from ibis import _

    @udf.scalar.pandas(
        # `packages` are installed in the Snowflake UDF sandbox; `imports`
        # makes the staged model file available inside it.
        packages=("joblib", "xgboost"), imports=("@MODELS/model.joblib",)
    )
    def predict_price(
        carat_scaled: float, cut_encoded: int, color_encoded: int, clarity_encoded: int
    ) -> int:
        # NOTE: this body executes inside Snowflake, not in the test process.
        import sys

        import joblib
        import pandas as pd

        # Snowflake exposes the directory holding imported files through this
        # interpreter -X option; the value ends with a path separator, hence
        # the direct concatenation below.
        import_dir = sys._xoptions.get("snowflake_import_directory")
        model = joblib.load(f"{import_dir}model.joblib")
        # Each argument arrives as a pandas Series (vectorized UDF); assemble
        # them into the feature frame the model expects.
        df = pd.concat(
            [carat_scaled, cut_encoded, color_encoded, clarity_encoded], axis=1
        )
        # Column names must match the names used when the model was trained
        # — TODO confirm against the training pipeline for model.joblib.
        df.columns = ["CARAT_SCALED", "CUT_ENCODED", "COLOR_ENCODED", "CLARITY_ENCODED"]
        return model.predict(df)

    def cases(value, mapping):
        """This should really be a top-level function or method."""
        # Build a searched CASE expression: WHEN value == k THEN v, for each
        # (k, v) pair in `mapping`.
        expr = ibis.case()
        for k, v in mapping.items():
            expr = expr.when(value == k, v)
        return expr.end()

    diamonds = con.tables.DIAMONDS
    expr = diamonds.mutate(
        predicted_price=predict_price(
            # Standardize carat (z-score) to match the model's CARAT_SCALED input.
            (_.carat - _.carat.mean()) / _.carat.std(),
            # Ordinal-encode the categorical columns, 1-based, mirroring the
            # encoding the model was trained with.
            cases(
                _.cut,
                {
                    c: i
                    for i, c in enumerate(
                        ("Fair", "Good", "Very Good", "Premium", "Ideal"), start=1
                    )
                },
            ),
            cases(_.color, {c: i for i, c in enumerate("DEFGHIJ", start=1)}),
            cases(
                _.clarity,
                {
                    c: i
                    for i, c in enumerate(
                        ("I1", "IF", "SI1", "SI2", "VS1", "VS2", "VVS1", "VVS2"),
                        start=1,
                    )
                },
            ),
        )
    )

    df = expr.execute()

    # Sanity checks only: a prediction column exists and every input row
    # produced exactly one output row.
    assert not df.empty
    assert "predicted_price" in df.columns
    assert len(df) == diamonds.count().execute()

0 comments on commit 529a3a2

Please sign in to comment.