feat(flink): implement UDF support for the backend (#8142)
deepyaman authored Feb 20, 2024
1 parent 0320d01 commit a3b1cc6
Showing 8 changed files with 70 additions and 16 deletions.
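
In short: the Flink backend now registers @udf.scalar.python and @udf.scalar.pandas functions as temporary PyFlink functions before executing a query. A minimal sketch of the workflow this enables, assuming a local streaming TableEnvironment and an already-registered table named "t" (both illustrative, not from the commit):

from pyflink.table import EnvironmentSettings, TableEnvironment

import ibis
from ibis import udf

table_env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode())
con = ibis.flink.connect(table_env)


@udf.scalar.python
def add_one(x: int) -> int:
    return x + 1


t = con.table("t")  # assumes "t" was registered in table_env beforehand
result = con.execute(add_one(t.x).name("y"))  # UDF registration happens inside execute()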
9 changes: 9 additions & 0 deletions docker/flink/Dockerfile
@@ -1,3 +1,12 @@
FROM flink:1.18.1

# ibis-flink requires PyFlink dependency
RUN wget -nv -P $FLINK_HOME/lib/ https://repo1.maven.org/maven2/org/apache/flink/flink-python/1.18.1/flink-python-1.18.1.jar

# install python3 and pip3
RUN apt-get update -y && \
    apt-get install -y python3 python3-pip python3-dev && rm -rf /var/lib/apt/lists/*
RUN ln -s /usr/bin/python3 /usr/bin/python

# install PyFlink
RUN pip3 install apache-flink==1.18.1
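
The three moving parts are pinned to one version on purpose: the base image, the flink-python jar, and the apache-flink wheel must agree. A quick sanity check one might run inside the container (a sketch; pyflink.version is the same attribute the backend's version property reports):

import pyflink.version

# the image, the jar, and the wheel are all pinned to 1.18.1 above
assert pyflink.version.__version__ == "1.18.1"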
2 changes: 1 addition & 1 deletion ibis/backends/base/__init__.py
@@ -1031,7 +1031,7 @@ def _run_pre_execute_hooks(self, expr: ir.Expr) -> None:
         self._register_udfs(expr)
         self._register_in_memory_tables(expr)
 
-    def _define_udf_translation_rules(self, expr):
+    def _define_udf_translation_rules(self, expr: ir.Expr):
         if self.supports_python_udfs:
             raise NotImplementedError(self.name)
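
That guard is the contract this commit fulfills for Flink: any backend advertising supports_python_udfs = True must override the UDF hooks, or the base implementation fails loudly. A sketch of the obligation (MyBackend is hypothetical and stands in for a BaseBackend subclass):

import ibis.expr.operations as ops
import ibis.expr.types as ir


class MyBackend:
    supports_python_udfs = True

    def _register_udfs(self, expr: ir.Expr) -> None:
        # register every ScalarUDF node found in the expression tree
        for udf_node in expr.op().find(ops.ScalarUDF):
            ...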
4 changes: 1 addition & 3 deletions ibis/backends/duckdb/__init__.py
@@ -1424,8 +1424,6 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
         self.con.register(name, op.data.to_pyarrow(op.schema))
 
     def _register_udfs(self, expr: ir.Expr) -> None:
-        import ibis.expr.operations as ops
-
         con = self.con
 
         for udf_node in expr.op().find(ops.ScalarUDF):
@@ -1439,7 +1437,7 @@ def _register_udfs(self, expr: ir.Expr) -> None:
             if registration_func is not None:
                 registration_func(con)
 
-    def _compile_udf(self, udf_node: ops.ScalarUDF) -> None:
+    def _compile_udf(self, udf_node: ops.ScalarUDF):
         func = udf_node.__func__
         name = type(udf_node).__name__
         type_mapper = self.compiler.type_mapper
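
The registration loop hinges on discovering ops.ScalarUDF nodes. A runnable sketch of how UDF call sites surface in an expression tree, using only public ibis APIs (the table and function here are made up):

import ibis
import ibis.expr.operations as ops
from ibis import udf


@udf.scalar.python
def add_one(x: int) -> int:
    return x + 1


t = ibis.table({"x": "int64"}, name="t")
expr = add_one(t.x).name("y")
# each call site of a @udf-decorated function becomes a ScalarUDF node
print(expr.op().find(ops.ScalarUDF))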
29 changes: 29 additions & 0 deletions ibis/backends/flink/__init__.py
@@ -23,6 +23,7 @@
     RenameTable,
 )
 from ibis.backends.tests.errors import Py4JJavaError
+from ibis.expr.operations.udf import InputType
 from ibis.util import gen_name
 
 if TYPE_CHECKING:
@@ -36,6 +37,8 @@
 
     from ibis.expr.api import Watermark
 
+_INPUT_TYPE_TO_FUNC_TYPE = {InputType.PYTHON: "general", InputType.PANDAS: "pandas"}
+
 
 class Backend(SQLGlotBackend, CanCreateDatabase, NoUrl):
     name = "flink"
@@ -301,6 +304,30 @@ def version(self) -> str:
 
         return pyflink.version.__version__
 
+    def _register_udfs(self, expr: ir.Expr) -> None:
+        for udf_node in expr.op().find(ops.ScalarUDF):
+            register_func = getattr(
+                self, f"_compile_{udf_node.__input_type__.name.lower()}_udf"
+            )
+            register_func(udf_node)
+
+    def _register_udf(self, udf_node: ops.ScalarUDF):
+        import pyflink.table.udf
+
+        from ibis.backends.flink.datatypes import FlinkType
+
+        name = type(udf_node).__name__
+        self._table_env.drop_temporary_function(name)
+        udf = pyflink.table.udf.udf(
+            udf_node.__func__,
+            result_type=FlinkType.from_ibis(udf_node.dtype),
+            func_type=_INPUT_TYPE_TO_FUNC_TYPE[udf_node.__input_type__],
+        )
+        self._table_env.create_temporary_function(name, udf)
+
+    _compile_pandas_udf = _register_udf
+    _compile_python_udf = _register_udf
+
     def compile(
         self, expr: ir.Expr, params: Mapping[ir.Expr, Any] | None = None, **_: Any
     ) -> Any:
@@ -312,6 +339,8 @@ def _to_sql(self, expr: ir.Expr, **kwargs: Any) -> str:
 
     def execute(self, expr: ir.Expr, **kwargs: Any) -> Any:
         """Execute an expression."""
+        self._register_udfs(expr)
+
         table_expr = expr.as_table()
         sql = self.compile(table_expr, **kwargs)
         df = self._table_env.sql_query(sql).to_pandas()
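
For readers new to PyFlink, this is roughly what _register_udf amounts to when spelled out against the public PyFlink API (the function and its name are illustrative). Dropping the temporary function before re-creating it is what keeps repeated execute() calls from failing on an already-existing registration:

from pyflink.table import DataTypes, EnvironmentSettings, TableEnvironment
from pyflink.table.udf import udf as pyflink_udf

table_env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode())

add_one = pyflink_udf(
    lambda x: x + 1,
    result_type=DataTypes.BIGINT(),
    func_type="general",  # InputType.PYTHON maps to "general"; InputType.PANDAS to "pandas"
)
table_env.create_temporary_function("add_one", add_one)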
1 change: 0 additions & 1 deletion ibis/backends/flink/compiler.py
@@ -63,7 +63,6 @@ class FlinkCompiler(SQLGlotCompiler):
         ops.ReductionVectorizedUDF,
         ops.RegexSplit,
         ops.RowID,
-        ops.ScalarUDF,
         ops.StringSplit,
         ops.Translate,
         ops.Unnest,
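
Removing ops.ScalarUDF from this unsupported-operations set is what allows compilation to proceed at all: the shared SQLGlot compiler then renders a UDF as an ordinary named function call, which resolves against the temporary function registered in the TableEnvironment.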
25 changes: 25 additions & 0 deletions ibis/backends/flink/tests/test_udf.py
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from ibis import udf
+
+
+def test_builtin_scalar_udf(con):
+    @udf.scalar.builtin
+    def parse_url(string1: str, string2: str) -> str:
+        ...
+
+    expr = parse_url("http://facebook.com/path1/p.php?k1=v1&k2=v2#Ref1", "HOST")
+    result = con.execute(expr)
+    assert result == "facebook.com"
+
+
+def test_builtin_agg_udf(con):
+    @udf.agg.builtin
+    def json_arrayagg(a) -> str:
+        """Glom together some JSON."""
+
+    ft = con.tables.functional_alltypes[:5]
+    expr = json_arrayagg(ft.string_col)
+    result = expr.execute()
+    expected = '["0","1","2","3","4"]'
+    assert result == expected
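
The empty bodies are deliberate: @udf.scalar.builtin and @udf.agg.builtin declare only the signature of a function the engine already provides (Flink's PARSE_URL and JSON_ARRAYAGG here), so no Python runs at query time. The same pattern should extend to other Flink SQL built-ins, for example (a sketch):

from ibis import udf


@udf.scalar.builtin
def regexp_replace(string: str, pattern: str, replacement: str) -> str:
    """Declared only; compiles to Flink's built-in REGEXP_REPLACE."""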
3 changes: 1 addition & 2 deletions ibis/backends/mysql/tests/test_client.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import json
 from datetime import date
 from operator import methodcaller
 
@@ -184,5 +183,5 @@ def json_arrayagg(a) -> str:
     ft = con.tables.functional_alltypes[:5]
     expr = json_arrayagg(ft.string_col)
     result = expr.execute()
-    expected = json.dumps(list(map(str, range(5))), separators=",:")
+    expected = '["0","1","2","3","4"]'
     assert result == expected
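
The new literal is exactly what json.dumps(list(map(str, range(5))), separators=",:") evaluated to, so the expectation is unchanged and the json import can be dropped.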
13 changes: 4 additions & 9 deletions ibis/backends/tests/test_udf.py
@@ -4,6 +4,7 @@
 
 import ibis.common.exceptions as com
 from ibis import _, udf
+from ibis.backends.tests.errors import Py4JJavaError
 
 no_python_udfs = mark.notimpl(
     [
@@ -12,7 +13,6 @@
         "dask",
         "druid",
         "exasol",
-        "flink",
         "impala",
         "mssql",
         "mysql",
@@ -27,7 +27,6 @@
 @no_python_udfs
 @mark.notimpl(["pyspark"])
 @mark.notyet(["datafusion"], raises=NotImplementedError)
-@mark.notyet(["flink"], raises=com.OperationNotDefinedError)
 def test_udf(batting):
     @udf.scalar.python
     def num_vowels(s: str, include_y: bool = False) -> int:
@@ -55,6 +54,7 @@ def num_vowels(s: str, include_y: bool = False) -> int:
     ["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
 )
 @mark.notimpl(["polars"])
+@mark.notimpl(["flink"], raises=Py4JJavaError)
 @mark.notyet(["datafusion"], raises=NotImplementedError)
 @mark.notyet(
     ["sqlite"], raises=com.IbisTypeError, reason="sqlite doesn't support map types"
@@ -84,7 +84,7 @@ def num_vowels_map(s: str, include_y: bool = False) -> dict[str, int]:
     ["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
 )
 @mark.notimpl(["polars"])
-@mark.notimpl(["flink"], raises=com.OperationNotDefinedError)
+@mark.notimpl(["flink"], raises=Py4JJavaError)
 @mark.notyet(["datafusion"], raises=NotImplementedError)
 @mark.notyet(["sqlite"], raises=TypeError, reason="sqlite doesn't support map types")
 def test_map_merge_udf(batting):
@@ -147,11 +147,6 @@ def add_one_pyarrow(s: int) -> int:  # s is series, int is the element type
     raises=NotImplementedError,
     reason="postgres only supports Python-native UDFs",
 )
-@mark.notimpl(
-    ["flink"],
-    raises=com.OperationNotDefinedError,
-    reason="No translation rule for Pandas or PyArrow",
-)
 @mark.parametrize(
     "add_one",
     [
@@ -169,7 +164,7 @@ def add_one_pyarrow(s: int) -> int:  # s is series, int is the element type
         add_one_pyarrow,
         marks=[
             mark.notyet(
-                ["snowflake", "sqlite", "pyspark"],
+                ["snowflake", "sqlite", "pyspark", "flink"],
                 raises=NotImplementedError,
                 reason="backend doesn't support pyarrow UDFs",
             )
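
Net effect on the shared test matrix: flink leaves the blanket no_python_udfs list, so Python-native and pandas UDF tests now run against it; the map-returning UDF tests are tracked as a concrete Py4JJavaError failure instead of "operation not defined"; and pyarrow UDFs are marked unsupported alongside snowflake, sqlite, and pyspark.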
