diff --git a/e2e_test/udf/test.py b/e2e_test/udf/test.py index 999c42ec53011..e261de08940e5 100644 --- a/e2e_test/udf/test.py +++ b/e2e_test/udf/test.py @@ -47,6 +47,32 @@ def gcd3(x: int, y: int, z: int) -> int: return gcd(gcd(x, y), z) +# a function that takes a long time to execute +@udf(input_types=["BIGINT"], result_type="BIGINT", workers=10) +def square(x: int) -> int: + total = 0 + for _ in range(x): + total += x + return total + + +# a table function that takes a long time to execute +# returns all primes in range [x, y) +@udtf(input_types=["INT", "INT"], result_types="INT", workers=10) +def primes(x: int, y: int) -> Iterator[int]: + def is_prime(n: int) -> bool: + if n <= 1: + return False + for i in range(2, n): + if n % i == 0: + return False + return True + + for i in range(x, y): + if is_prime(i): + yield i + + @udf(input_types=["BYTEA"], result_type="STRUCT") def extract_tcp_info(tcp_packet: bytes): src_addr, dst_addr = struct.unpack("!4s4s", tcp_packet[12:20]) @@ -218,6 +244,8 @@ def return_all_arrays( server.add_function(sleep) server.add_function(gcd) server.add_function(gcd3) + server.add_function(square) + server.add_function(primes) server.add_function(series) server.add_function(split) server.add_function(extract_tcp_info) diff --git a/src/expr/udf/python/CHANGELOG.md b/src/expr/udf/python/CHANGELOG.md index a20411e69d83e..86d07f8bc9b58 100644 --- a/src/expr/udf/python/CHANGELOG.md +++ b/src/expr/udf/python/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Add `workers` option in `@udf` and `@udtf` to specify the number of worker processes for CPU bound functions. 
+ ## [0.1.1] - 2023-12-06 ### Fixed diff --git a/src/expr/udf/python/pyproject.toml b/src/expr/udf/python/pyproject.toml index b535355168363..04e933772d1ea 100644 --- a/src/expr/udf/python/pyproject.toml +++ b/src/expr/udf/python/pyproject.toml @@ -14,7 +14,7 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", ] requires-python = ">=3.8" -dependencies = ["pyarrow"] +dependencies = ["pyarrow", "loky"] [project.optional-dependencies] test = ["pytest"] diff --git a/src/expr/udf/python/risingwave/udf.py b/src/expr/udf/python/risingwave/udf.py index 803ab1acbcbfb..0110593774483 100644 --- a/src/expr/udf/python/risingwave/udf.py +++ b/src/expr/udf/python/risingwave/udf.py @@ -19,11 +19,14 @@ import inspect import traceback import json -from concurrent.futures import ThreadPoolExecutor -import concurrent +from concurrent.futures import Executor, ThreadPoolExecutor from decimal import Decimal import signal +# using loky ProcessPoolExecutor instead of concurrent.futures, +# because the latter does not support pickling functions with decorators +from loky import get_reusable_executor + class UserDefinedFunction: """ @@ -33,8 +36,7 @@ class UserDefinedFunction: _name: str _input_schema: pa.Schema _result_schema: pa.Schema - _io_threads: Optional[int] - _executor: Optional[ThreadPoolExecutor] + _executor: Optional[Executor] def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]: """ @@ -49,14 +51,13 @@ class ScalarFunction(UserDefinedFunction): or multiple scalar values to a new scalar value. 
""" - def __init__(self, *args, **kwargs): - self._io_threads = kwargs.pop("io_threads") - self._executor = ( - ThreadPoolExecutor(max_workers=self._io_threads) - if self._io_threads is not None - else None - ) - super().__init__(*args, **kwargs) + def __init__(self, io_threads=None, workers=None): + if io_threads is not None and io_threads > 1: + self._executor = ThreadPoolExecutor(max_workers=io_threads) + elif workers is not None: + self._executor = get_reusable_executor(max_workers=workers) + else: + self._executor = None def eval(self, *args) -> Any: """ @@ -71,17 +72,15 @@ def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]: _process_func(pa.list_(type), False)(array) for array, type in zip(inputs, self._input_schema.types) ] + + # evaluate the function for each row if self._executor is not None: - # evaluate the function for each row tasks = [ self._executor.submit(self._func, *[col[i] for col in inputs]) for i in range(batch.num_rows) ] - column = [ - future.result() for future in concurrent.futures.as_completed(tasks) - ] + column = [future.result() for future in tasks] else: - # evaluate the function for each row column = [ self.eval(*[col[i] for col in inputs]) for i in range(batch.num_rows) ] @@ -142,6 +141,14 @@ class TableFunction(UserDefinedFunction): BATCH_SIZE = 1024 + def __init__(self, io_threads=None, workers=None): + if io_threads is not None and io_threads > 1: + self._executor = ThreadPoolExecutor(max_workers=io_threads) + elif workers is not None: + self._executor = get_reusable_executor(max_workers=workers) + else: + self._executor = None + def eval(self, *args) -> Iterator: """ Method which defines the logic of the table function. 
@@ -182,12 +189,28 @@ def build(self) -> pa.RecordBatch: builder = RecordBatchBuilder(self._result_schema) # Iterate through rows in the input RecordBatch - for row_index in range(batch.num_rows): - row = tuple(column[row_index].as_py() for column in batch) - for result in self.eval(*row): - builder.append(row_index, result) - if builder.len() == self.BATCH_SIZE: - yield builder.build() + if self._executor is not None: + input_rows = [ + [col[i].as_py() for col in batch] for i in range(batch.num_rows) + ] + # XXX: make the function picklable + func = self._func + tasks = [ + self._executor.submit(lambda *args: list(func(*args)), *row) + for row in input_rows + ] + for row_index, task in enumerate(tasks): + for result in task.result(): + builder.append(row_index, result) + if builder.len() == self.BATCH_SIZE: + yield builder.build() + else: + for row_index in range(batch.num_rows): + row = tuple(column[row_index].as_py() for column in batch) + for result in self.eval(*row): + builder.append(row_index, result) + if builder.len() == self.BATCH_SIZE: + yield builder.build() if builder.len() != 0: yield builder.build() @@ -199,7 +222,10 @@ class UserDefinedScalarFunctionWrapper(ScalarFunction): _func: Callable - def __init__(self, func, input_types, result_type, name=None, io_threads=None): + def __init__( + self, func, input_types, result_type, name=None, io_threads=None, workers=None + ): + super().__init__(io_threads, workers) self._func = func self._input_schema = pa.schema( zip( @@ -211,13 +237,18 @@ def __init__(self, func, input_types, result_type, name=None, io_threads=None): self._name = name or ( func.__name__ if hasattr(func, "__name__") else func.__class__.__name__ ) - super().__init__(io_threads=io_threads) def __call__(self, *args): return self._func(*args) def eval(self, *args): - return self._func(*args) + try: + return self._func(*args) + except Exception: + print( + f"Error when calling function: {self._name}({', '.join([str(arg) for arg in 
args])})" + ) + raise class UserDefinedTableFunctionWrapper(TableFunction): @@ -227,7 +258,10 @@ class UserDefinedTableFunctionWrapper(TableFunction): _func: Callable - def __init__(self, func, input_types, result_types, name=None): + def __init__( + self, func, input_types, result_types, name=None, io_threads=None, workers=None + ): + super().__init__(io_threads, workers) self._func = func self._name = name or ( func.__name__ if hasattr(func, "__name__") else func.__class__.__name__ @@ -254,7 +288,13 @@ def __call__(self, *args): return self._func(*args) def eval(self, *args): - return self._func(*args) + try: + return self._func(*args) + except Exception: + print( + f"Error when calling function: {self._name}({', '.join([str(arg) for arg in args])})" + ) + raise def _to_list(x): @@ -269,6 +309,7 @@ def udf( result_type: Union[str, pa.DataType], name: Optional[str] = None, io_threads: Optional[int] = None, + workers: Optional[int] = None, ) -> Callable: """ Annotation for creating a user-defined scalar function. @@ -278,6 +319,7 @@ def udf( - result_type: A string or an Arrow data type that specifies the return value type. - name: An optional string specifying the function name. If not provided, the original name will be used. - io_threads: Number of I/O threads used per data chunk for I/O bound functions. + - workers: Number of worker processes used for CPU bound functions. 
Example: ``` @@ -295,22 +337,29 @@ def external_api(x): response = requests.get(my_endpoint + '?param=' + x) return response["data"] ``` + + CPU bound Example: + ``` + @udf(input_types=["BIGINT"], result_type="BIGINT", workers=10) + def square(x: int) -> int: + total = 0 + for _ in range(x): + total += x + return total + ``` """ - if io_threads is not None and io_threads > 1: - return lambda f: UserDefinedScalarFunctionWrapper( - f, input_types, result_type, name, io_threads=io_threads - ) - else: - return lambda f: UserDefinedScalarFunctionWrapper( - f, input_types, result_type, name - ) + return lambda f: UserDefinedScalarFunctionWrapper( + f, input_types, result_type, name, io_threads, workers + ) def udtf( input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]], result_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]], name: Optional[str] = None, + io_threads: Optional[int] = None, + workers: Optional[int] = None, ) -> Callable: """ Annotation for creating a user-defined table function. @@ -319,6 +368,8 @@ def udtf( - input_types: A list of strings or Arrow data types that specifies the input data types. - result_types A list of strings or Arrow data types that specifies the return value types. - name: An optional string specifying the function name. If not provided, the original name will be used. + - io_threads: Number of I/O threads used per data chunk for I/O bound functions. + - workers: Number of worker processes used for CPU bound functions. 
Example: ``` @@ -329,7 +380,9 @@ def series(n): ``` """ - return lambda f: UserDefinedTableFunctionWrapper(f, input_types, result_types, name) + return lambda f: UserDefinedTableFunctionWrapper( + f, input_types, result_types, name, io_threads, workers + ) class UdfServer(pa.flight.FlightServerBase): @@ -351,7 +404,7 @@ class UdfServer(pa.flight.FlightServerBase): _functions: Dict[str, UserDefinedFunction] def __init__(self, location="0.0.0.0:8815", **kwargs): - super(UdfServer, self).__init__("grpc://" + location, **kwargs) + super().__init__("grpc://" + location, **kwargs) self._location = location self._functions = {} @@ -420,7 +473,7 @@ def serve(self): f"\n\nlistening on {self._location}" ) signal.signal(signal.SIGTERM, lambda s, f: self.shutdown()) - super(UdfServer, self).serve() + super().serve() def _to_data_type(t: Union[str, pa.DataType]) -> pa.DataType: