Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(udf): support multiprocess pool for CPU-bound Python UDFs #13838

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions e2e_test/udf/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,32 @@ def gcd3(x: int, y: int, z: int) -> int:
return gcd(gcd(x, y), z)


# A deliberately CPU-heavy function used to exercise the multiprocess pool:
# computes x * x by repeated addition, so it takes O(x) time on purpose.
@udf(input_types=["BIGINT"], result_type="BIGINT", workers=10)
def square(x: int) -> int:
    total = 0  # renamed from `sum` to avoid shadowing the builtin
    for _ in range(x):
        total += x
    return total


# A deliberately slow table function used to exercise the multiprocess pool.
# Yields every prime p with x <= p < y, using naive trial division on purpose.
@udtf(input_types=["INT", "INT"], result_types="INT", workers=10)
def primes(x: int, y: int) -> Iterator[int]:
    def is_prime(n: int) -> bool:
        if n <= 1:
            return False
        # Trial-divide by every integer in [2, n); intentionally not optimized.
        return all(n % d != 0 for d in range(2, n))

    for candidate in range(x, y):
        if is_prime(candidate):
            yield candidate


@udf(input_types=["BYTEA"], result_type="STRUCT<VARCHAR, VARCHAR, SMALLINT, SMALLINT>")
def extract_tcp_info(tcp_packet: bytes):
src_addr, dst_addr = struct.unpack("!4s4s", tcp_packet[12:20])
Expand Down Expand Up @@ -213,6 +239,8 @@ def return_all_arrays(
server.add_function(sleep)
server.add_function(gcd)
server.add_function(gcd3)
server.add_function(square)
server.add_function(primes)
server.add_function(series)
server.add_function(split)
server.add_function(extract_tcp_info)
Expand Down
4 changes: 4 additions & 0 deletions src/udf/python/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Add `workers` option in `@udf` and `@udtf` to specify the number of worker processes for CPU-bound functions.

## [0.1.0] - 2023-12-01

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion src/udf/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
]
requires-python = ">=3.8"
dependencies = ["pyarrow"]
dependencies = ["pyarrow", "loky"]

[project.optional-dependencies]
test = ["pytest"]
131 changes: 92 additions & 39 deletions src/udf/python/risingwave/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@
import inspect
import traceback
import json
from concurrent.futures import ThreadPoolExecutor
import concurrent
from concurrent.futures import Executor, ThreadPoolExecutor
from decimal import Decimal
import signal

# using loky ProcessPoolExecutor instead of concurrent.futures,
# because the latter does not support pickling functions with decorators
from loky import get_reusable_executor


class UserDefinedFunction:
"""
Expand All @@ -33,8 +36,7 @@ class UserDefinedFunction:
_name: str
_input_schema: pa.Schema
_result_schema: pa.Schema
_io_threads: Optional[int]
_executor: Optional[ThreadPoolExecutor]
_executor: Optional[Executor]

def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
"""
Expand All @@ -49,14 +51,13 @@ class ScalarFunction(UserDefinedFunction):
or multiple scalar values to a new scalar value.
"""

def __init__(self, *args, **kwargs):
self._io_threads = kwargs.pop("io_threads")
self._executor = (
ThreadPoolExecutor(max_workers=self._io_threads)
if self._io_threads is not None
else None
)
super().__init__(*args, **kwargs)
def __init__(self, io_threads=None, workers=None):
    """
    Set up the optional executor used by eval_batch.

    - io_threads: size of a thread pool for I/O-bound functions (threads are
      cheaper: no pickling or inter-process data transfer).
    - workers: size of a loky process pool for CPU-bound functions (loky is
      used instead of concurrent.futures because the latter cannot pickle
      decorated functions).
    - If neither is given, rows are evaluated serially in the calling thread.

    NOTE(review): io_threads takes precedence when both are given — confirm
    that is the intended behavior rather than an error.
    """
    if io_threads is not None:
        self._executor = ThreadPoolExecutor(max_workers=io_threads)
    elif workers is not None:
        self._executor = get_reusable_executor(max_workers=workers)
    else:
        self._executor = None

def eval(self, *args) -> Any:
"""
Expand All @@ -71,17 +72,15 @@ def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
_process_func(pa.list_(type), False)(array)
for array, type in zip(inputs, self._input_schema.types)
]

# evaluate the function for each row
if self._executor is not None:
# evaluate the function for each row
tasks = [
self._executor.submit(self._func, *[col[i] for col in inputs])
for i in range(batch.num_rows)
]
column = [
future.result() for future in concurrent.futures.as_completed(tasks)
]
column = [future.result() for future in tasks]
else:
# evaluate the function for each row
column = [
self.eval(*[col[i] for col in inputs]) for i in range(batch.num_rows)
]
Expand Down Expand Up @@ -135,6 +134,14 @@ class TableFunction(UserDefinedFunction):

BATCH_SIZE = 1024

def __init__(self, io_threads=None, workers=None):
    """
    Choose the executor for row evaluation: a thread pool when io_threads
    is given, a loky process pool when workers is given, else None
    (serial evaluation). io_threads wins if both are provided.
    """
    executor = None
    if io_threads is not None:
        executor = ThreadPoolExecutor(max_workers=io_threads)
    elif workers is not None:
        executor = get_reusable_executor(max_workers=workers)
    self._executor = executor

def eval(self, *args) -> Iterator:
"""
Method which defines the logic of the table function.
Expand Down Expand Up @@ -175,12 +182,28 @@ def build(self) -> pa.RecordBatch:
builder = RecordBatchBuilder(self._result_schema)

# Iterate through rows in the input RecordBatch
for row_index in range(batch.num_rows):
row = tuple(column[row_index].as_py() for column in batch)
for result in self.eval(*row):
builder.append(row_index, result)
if builder.len() == self.BATCH_SIZE:
yield builder.build()
if self._executor is not None:
input_rows = [
[col[i].as_py() for col in batch] for i in range(batch.num_rows)
]
# XXX: make the function picklable
func = self._func
tasks = [
self._executor.submit(lambda *args: list(func(*args)), *row)
for row in input_rows
]
for row_index, task in enumerate(tasks):
for result in task.result():
builder.append(row_index, result)
if builder.len() == self.BATCH_SIZE:
yield builder.build()
else:
for row_index in range(batch.num_rows):
row = tuple(column[row_index].as_py() for column in batch)
for result in self.eval(*row):
builder.append(row_index, result)
if builder.len() == self.BATCH_SIZE:
yield builder.build()
if builder.len() != 0:
yield builder.build()

Expand All @@ -192,7 +215,10 @@ class UserDefinedScalarFunctionWrapper(ScalarFunction):

_func: Callable

def __init__(self, func, input_types, result_type, name=None, io_threads=None):
def __init__(
self, func, input_types, result_type, name=None, io_threads=None, workers=None
):
super().__init__(io_threads, workers)
self._func = func
self._input_schema = pa.schema(
zip(
Expand All @@ -204,13 +230,18 @@ def __init__(self, func, input_types, result_type, name=None, io_threads=None):
self._name = name or (
func.__name__ if hasattr(func, "__name__") else func.__class__.__name__
)
super().__init__(io_threads=io_threads)

def __call__(self, *args):
    # Allow the wrapper to be invoked like the original undecorated function.
    return self._func(*args)

def eval(self, *args):
    """Evaluate the wrapped function on one row, logging the failing input."""
    try:
        return self._func(*args)
    except Exception:
        # Print the offending arguments so the failure is reproducible,
        # then re-raise bare to preserve the original traceback
        # (the diff artifact duplicating the old `return` line is removed).
        print(
            f"Error when calling function: {self._name}({', '.join([str(arg) for arg in args])})"
        )
        raise


class UserDefinedTableFunctionWrapper(TableFunction):
Expand All @@ -220,7 +251,10 @@ class UserDefinedTableFunctionWrapper(TableFunction):

_func: Callable

def __init__(self, func, input_types, result_types, name=None):
def __init__(
self, func, input_types, result_types, name=None, io_threads=None, workers=None
):
super().__init__(io_threads, workers)
self._func = func
self._name = name or (
func.__name__ if hasattr(func, "__name__") else func.__class__.__name__
Expand All @@ -247,7 +281,13 @@ def __call__(self, *args):
return self._func(*args)

def eval(self, *args):
    """Evaluate the wrapped table function on one row, logging the failing input."""
    try:
        return self._func(*args)
    except Exception:
        # Print the offending arguments so the failure is reproducible,
        # then re-raise bare to preserve the original traceback
        # (the diff artifact duplicating the old `return` line is removed).
        print(
            f"Error when calling function: {self._name}({', '.join([str(arg) for arg in args])})"
        )
        raise


def _to_list(x):
Expand All @@ -262,6 +302,7 @@ def udf(
result_type: Union[str, pa.DataType],
name: Optional[str] = None,
io_threads: Optional[int] = None,
workers: Optional[int] = None,
) -> Callable:
"""
Annotation for creating a user-defined scalar function.
Expand All @@ -271,6 +312,7 @@ def udf(
- result_type: A string or an Arrow data type that specifies the return value type.
- name: An optional string specifying the function name. If not provided, the original name will be used.
- io_threads: Number of I/O threads used per data chunk for I/O bound functions.
- workers: Number of worker processes used for CPU bound functions.

Example:
```
Expand All @@ -288,22 +330,29 @@ def external_api(x):
response = requests.get(my_endpoint + '?param=' + x)
return response["data"]
```

CPU bound Example:
```
@udf(input_types=["BIGINT"], result_type="BIGINT", workers=10)
def square(x: int) -> int:
sum = 0
for _ in range(x):
sum += x
return sum
```
"""

if io_threads is not None and io_threads > 1:
return lambda f: UserDefinedScalarFunctionWrapper(
f, input_types, result_type, name, io_threads=io_threads
)
else:
return lambda f: UserDefinedScalarFunctionWrapper(
f, input_types, result_type, name
)
return lambda f: UserDefinedScalarFunctionWrapper(
f, input_types, result_type, name, io_threads, workers
)


def udtf(
    input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]],
    result_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]],
    name: Optional[str] = None,
    io_threads: Optional[int] = None,
    workers: Optional[int] = None,
) -> Callable:
    """
    Annotation for creating a user-defined table function.

    Parameters:
    - input_types: A list of strings or Arrow data types that specifies the input data types.
    - result_types: A list of strings or Arrow data types that specifies the return value types.
    - name: An optional string specifying the function name. If not provided, the original name will be used.
    - io_threads: Number of I/O threads used per data chunk for I/O bound functions.
    - workers: Number of worker processes used for CPU bound functions.

    Example:
    ```
    @udtf(input_types='INT', result_types='INT')
    def series(n):
        for i in range(n):
            yield i
    ```
    """
    return lambda f: UserDefinedTableFunctionWrapper(
        f, input_types, result_types, name, io_threads, workers
    )


class UdfServer(pa.flight.FlightServerBase):
Expand All @@ -344,7 +397,7 @@ class UdfServer(pa.flight.FlightServerBase):
_functions: Dict[str, UserDefinedFunction]

def __init__(self, location="0.0.0.0:8815", **kwargs):
    # Register with Arrow Flight under the gRPC scheme; extra kwargs are
    # forwarded to pa.flight.FlightServerBase.
    super().__init__("grpc://" + location, **kwargs)
    # Bare host:port string, echoed in the serve() startup banner.
    self._location = location
    # Registry mapping function name -> UserDefinedFunction.
    self._functions = {}

Expand Down Expand Up @@ -413,7 +466,7 @@ def serve(self):
f"\n\nlistening on {self._location}"
)
signal.signal(signal.SIGTERM, lambda s, f: self.shutdown())
super(UdfServer, self).serve()
super().serve()


def _to_data_type(t: Union[str, pa.DataType]) -> pa.DataType:
Expand Down
Loading