diff --git a/ci/scripts/build-other.sh b/ci/scripts/build-other.sh index b04276ce7a7ae..13119c59bfa08 100755 --- a/ci/scripts/build-other.sh +++ b/ci/scripts/build-other.sh @@ -6,10 +6,13 @@ set -euo pipefail source ci/scripts/common.sh -echo "--- Build Java connector node" +echo "--- Build Java packages" cd java - mvn -B package -Dmaven.test.skip=true +cd .. + echo "--- Upload Java artifacts" -cp connector-node/assembly/target/risingwave-connector-1.0.0.tar.gz ./risingwave-connector.tar.gz +cp java/connector-node/assembly/target/risingwave-connector-1.0.0.tar.gz ./risingwave-connector.tar.gz +cp java/udf/target/risingwave-udf-example.jar ./risingwave-udf-example.jar buildkite-agent artifact upload ./risingwave-connector.tar.gz +buildkite-agent artifact upload ./risingwave-udf-example.jar diff --git a/ci/scripts/run-e2e-test.sh b/ci/scripts/run-e2e-test.sh index e464080992b77..1ffcf56faee7e 100755 --- a/ci/scripts/run-e2e-test.sh +++ b/ci/scripts/run-e2e-test.sh @@ -24,6 +24,7 @@ download_and_prepare_rw "$profile" common echo "--- Download artifacts" download-and-decompress-artifact e2e_test_generated ./ download-and-decompress-artifact risingwave_e2e_extended_mode_test-"$profile" target/debug/ +buildkite-agent artifact download risingwave-udf-example.jar ./ mv target/debug/risingwave_e2e_extended_mode_test-"$profile" target/debug/risingwave_e2e_extended_mode_test chmod +x ./target/debug/risingwave_e2e_extended_mode_test @@ -46,12 +47,18 @@ sqllogictest -p 4566 -d dev './e2e_test/visibility_mode/*.slt' --junit "batch-${ sqllogictest -p 4566 -d dev './e2e_test/database/prepare.slt' sqllogictest -p 4566 -d test './e2e_test/database/test.slt' -echo "--- e2e, ci-3streaming-2serving-3fe, udf" +echo "--- e2e, ci-3streaming-2serving-3fe, python udf" python3 e2e_test/udf/test.py & sleep 2 -sqllogictest -p 4566 -d dev './e2e_test/udf/python.slt' +sqllogictest -p 4566 -d dev './e2e_test/udf/udf.slt' pkill python3 +echo "--- e2e, ci-3streaming-2serving-3fe, java udf" +java 
-jar risingwave-udf-example.jar & +sleep 2 +sqllogictest -p 4566 -d dev './e2e_test/udf/udf.slt' +pkill java + echo "--- Kill cluster" cargo make ci-kill diff --git a/e2e_test/udf/test.py b/e2e_test/udf/test.py index db5e596bbe9cc..d35565d44bfdc 100644 --- a/e2e_test/udf/test.py +++ b/e2e_test/udf/test.py @@ -41,10 +41,10 @@ def series(n: int) -> Iterator[int]: yield i -@udtf(input_types="INT", result_types=["INT", "VARCHAR"]) -def series2(n: int) -> Iterator[Tuple[int, str]]: - for i in range(n): - yield i, f"#{i}" +@udtf(input_types="VARCHAR", result_types=["VARCHAR", "INT"]) +def split(string: str) -> Iterator[Tuple[str, int]]: + for s in string.split(" "): + yield s, len(s) @udf(input_types="VARCHAR", result_type="DECIMAL") @@ -100,7 +100,7 @@ def jsonb_array_struct_identity(v: Tuple[List[Any], int]) -> Tuple[List[Any], in server.add_function(gcd) server.add_function(gcd3) server.add_function(series) - server.add_function(series2) + server.add_function(split) server.add_function(extract_tcp_info) server.add_function(hex_to_dec) server.add_function(array_access) diff --git a/e2e_test/udf/python.slt b/e2e_test/udf/udf.slt similarity index 84% rename from e2e_test/udf/python.slt rename to e2e_test/udf/udf.slt index bd3bd0c4a253a..10a0370aa7853 100644 --- a/e2e_test/udf/python.slt +++ b/e2e_test/udf/udf.slt @@ -1,5 +1,7 @@ # Before running this test: # python3 e2e_test/udf/test.py +# or: +# cd src/udf/java && mvn package && java -jar target/risingwave-udf-example.jar # Create a function. statement ok @@ -35,7 +37,7 @@ create function series(int) returns table (x int) language python as series usin # Create a table function that returns multiple columns. 
statement ok -create function series2(int) returns table (x int, y varchar) language python as series2 using link 'http://localhost:8815'; +create function split(varchar) returns table (word varchar, length int) language python as split using link 'http://localhost:8815'; statement ok create function hex_to_dec(varchar) returns decimal language python as hex_to_dec using link 'http://localhost:8815'; @@ -70,7 +72,7 @@ jsonb_array_identity jsonb[] jsonb[] python http://localhost:8815 jsonb_array_struct_identity struct struct python http://localhost:8815 jsonb_concat jsonb[] jsonb python http://localhost:8815 series integer integer python http://localhost:8815 -series2 integer struct python http://localhost:8815 +split varchar struct python http://localhost:8815 query I select int_42(); @@ -111,12 +113,12 @@ select jsonb_concat(ARRAY['null'::jsonb, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb [null, 1, "str", {}] query T -select jsonb_array_identity(ARRAY['null'::jsonb, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb]); +select jsonb_array_identity(ARRAY[null, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb]); ---- {NULL,1,"\"str\"","{}"} query T -select jsonb_array_struct_identity(ROW(ARRAY['null'::jsonb, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb], 4)::struct); +select jsonb_array_struct_identity(ROW(ARRAY[null, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb], 4)::struct); ---- ({NULL,1,"\"str\"","{}"},4) @@ -130,18 +132,16 @@ select series(5); 4 query IT -select * from series2(3); +select * from split('rising wave'); ---- -0 #0 -1 #1 -2 #2 +rising 6 +wave 4 query T -select series2(3); +select split('rising wave'); ---- -(0,#0) -(1,#1) -(2,#2) +(rising,6) +(wave,4) query II select x, series(x) from series(4) t(x); @@ -153,16 +153,6 @@ select x, series(x) from series(4) t(x); 3 1 3 2 -query IT -select x, series2(x) from series(4) t(x); ----- -1 (0,#0) -2 (0,#0) -2 (1,#1) -3 (0,#0) -3 (1,#1) -3 (2,#2) - # test large output for table function query I select count(*) from series(1000000); @@ -229,3 
+219,30 @@ drop function gcd(); # Drop a function without arguments. Now the function name is unique. statement ok drop function gcd; + +statement ok +drop function extract_tcp_info; + +statement ok +drop function series; + +statement ok +drop function split; + +statement ok +drop function hex_to_dec; + +statement ok +drop function array_access; + +statement ok +drop function jsonb_access; + +statement ok +drop function jsonb_concat; + +statement ok +drop function jsonb_array_identity; + +statement ok +drop function jsonb_array_struct_identity; diff --git a/java/pom.xml b/java/pom.xml index b45029847799b..86cc708bfb2be 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -8,6 +8,7 @@ 1.0-SNAPSHOT proto + udf java-binding common-utils java-binding-integration-test diff --git a/java/udf/README.md b/java/udf/README.md new file mode 100644 index 0000000000000..74542f51a801d --- /dev/null +++ b/java/udf/README.md @@ -0,0 +1,8 @@ +## How to run example + +Make sure you have installed Java 11 and Maven 3 or later. 
+ +```sh +mvn package +java -jar target/risingwave-udf-example.jar +``` diff --git a/java/udf/pom.xml b/java/udf/pom.xml new file mode 100644 index 0000000000000..0db50f4071c53 --- /dev/null +++ b/java/udf/pom.xml @@ -0,0 +1,95 @@ + + 4.0.0 + com.risingwave.java + risingwave-udf + jar + 0.0.1 + risingwave-udf + http://maven.apache.org + + + 11 + 11 + + + + + org.junit.jupiter + junit-jupiter-engine + 5.9.1 + test + + + org.apache.arrow + arrow-vector + 12.0.0 + + + org.apache.arrow + flight-core + 12.0.0 + + + com.google.code.gson + gson + 2.10.1 + + + org.slf4j + slf4j-api + 2.0.7 + + + org.slf4j + slf4j-simple + 2.0.7 + + + + + + kr.motd.maven + os-maven-plugin + 1.7.0 + + + + + org.apache.maven.plugins + maven-surefire-plugin + 3.0.0-M7 + + --add-opens=java.base/java.nio=ALL-UNNAMED + + + + + org.apache.maven.plugins + maven-assembly-plugin + 3.4.2 + + + + com.risingwave.functions.example.UdfExample + + + + jar-with-dependencies + + risingwave-udf-example + false + + + + udf-example + package + + single + + + + + + + diff --git a/java/udf/src/main/java/com/risingwave/functions/DataTypeHint.java b/java/udf/src/main/java/com/risingwave/functions/DataTypeHint.java new file mode 100644 index 0000000000000..100a953bca75c --- /dev/null +++ b/java/udf/src/main/java/com/risingwave/functions/DataTypeHint.java @@ -0,0 +1,23 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.risingwave.functions; + +import java.lang.annotation.*; + +@Retention(RetentionPolicy.RUNTIME) +@Target({ ElementType.METHOD, ElementType.FIELD, ElementType.PARAMETER }) +public @interface DataTypeHint { + String value(); +} diff --git a/java/udf/src/main/java/com/risingwave/functions/ScalarFunction.java b/java/udf/src/main/java/com/risingwave/functions/ScalarFunction.java new file mode 100644 index 0000000000000..a40507602ce77 --- /dev/null +++ b/java/udf/src/main/java/com/risingwave/functions/ScalarFunction.java @@ -0,0 +1,60 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.risingwave.functions; + +/** + * Base interface for a user-defined scalar function. A user-defined scalar + * function maps zero, one, or multiple scalar values to a new scalar value. + * + *

+ * The behavior of a {@link ScalarFunction} can be defined by implementing a + * custom evaluation method. An evaluation method must be declared publicly, not + * static, and named eval. Multiple overloaded methods named + * eval are not supported yet. + * + *

+ * By default, input and output data types are automatically extracted using + * reflection. + * + *

+ * The following examples show how to specify a scalar function: + * + *

+ * {@code
+ * // a function that accepts two INT arguments and computes a sum
+ * class SumFunction implements ScalarFunction {
+ *     public Integer eval(Integer a, Integer b) {
+ *         return a + b;
+ *     }
+ * }
+ * 
+ * // a function that returns a struct type
+ * class StructFunction implements ScalarFunction {
+ *     public static class KeyValue {
+ *         public String key;
+ *         public int value;
+ *     }
+ * 
+ *     public KeyValue eval(int a) {
+ *         KeyValue kv = new KeyValue();
+ *         kv.key = String.valueOf(a);
+ *         kv.value = a;
+ *         return kv;
+ *     }
+ * }
+ * }
+ */ +public interface ScalarFunction extends UserDefinedFunction { +} diff --git a/java/udf/src/main/java/com/risingwave/functions/ScalarFunctionBatch.java b/java/udf/src/main/java/com/risingwave/functions/ScalarFunctionBatch.java new file mode 100644 index 0000000000000..e81a1019aa7db --- /dev/null +++ b/java/udf/src/main/java/com/risingwave/functions/ScalarFunctionBatch.java @@ -0,0 +1,63 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.risingwave.functions; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; + +import java.lang.invoke.MethodHandle; +import java.util.Collections; +import java.util.Iterator; +import java.util.function.Function; + +/** + * Batch-processing wrapper over a user-defined scalar function. 
+ */ +class ScalarFunctionBatch extends UserDefinedFunctionBatch { + ScalarFunction function; + MethodHandle methodHandle; + Function[] processInputs; + + ScalarFunctionBatch(ScalarFunction function, BufferAllocator allocator) { + this.function = function; + this.allocator = allocator; + var method = Reflection.getEvalMethod(function); + this.methodHandle = Reflection.getMethodHandle(method); + this.inputSchema = TypeUtils.methodToInputSchema(method); + this.outputSchema = TypeUtils.methodToOutputSchema(method); + this.processInputs = TypeUtils.methodToProcessInputs(method); + } + + @Override + Iterator evalBatch(VectorSchemaRoot batch) { + var row = new Object[batch.getSchema().getFields().size() + 1]; + row[0] = this.function; + var outputValues = new Object[batch.getRowCount()]; + for (int i = 0; i < batch.getRowCount(); i++) { + for (int j = 0; j < row.length - 1; j++) { + var val = batch.getVector(j).getObject(i); + row[j + 1] = this.processInputs[j].apply(val); + } + try { + outputValues[i] = this.methodHandle.invokeWithArguments(row); + } catch (Throwable e) { + throw new RuntimeException(e); + } + } + var outputVector = TypeUtils.createVector(this.outputSchema.getFields().get(0), this.allocator, outputValues); + var outputBatch = VectorSchemaRoot.of(outputVector); + return Collections.singleton(outputBatch).iterator(); + } +} diff --git a/java/udf/src/main/java/com/risingwave/functions/TableFunction.java b/java/udf/src/main/java/com/risingwave/functions/TableFunction.java new file mode 100644 index 0000000000000..45e2266b3ff38 --- /dev/null +++ b/java/udf/src/main/java/com/risingwave/functions/TableFunction.java @@ -0,0 +1,67 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.risingwave.functions; + +/** + * Base interface for a user-defined table function. A user-defined table + * function maps zero, one, or multiple scalar values to zero, one, or multiple + * rows (or structured types). If an output record consists of only one field, + * the structured record can be omitted, and a scalar value can be emitted that + * will be implicitly wrapped into a row by the runtime. + * + *

+ * The behavior of a {@link TableFunction} can be defined by implementing a + * custom evaluation method. An evaluation method must be declared publicly, not + * static, and named eval. The return type must be an Iterator. + * Multiple overloaded methods named eval are not supported yet. + * + *

+ * By default, input and output data types are automatically extracted using + * reflection. + * + *

+ * The following examples show how to specify a table function: + * + *

+ * {@code
+ * // a function that accepts an INT argument and emits the range from 0 to the
+ * // given number.
+ * class Series implements TableFunction {
+ *     public Iterator<Integer> eval(int n) {
+ *         return IntStream.range(0, n).iterator();
+ *     }
+ * }
+ * 
+ * // a function that accepts a String argument and emits the words of the
+ * // given string.
+ * class Split implements TableFunction {
+ *     public static class Row {
+ *         public String word;
+ *         public int length;
+ *     }
+ * 
+ *     public Iterator<Row> eval(String str) {
+ *         return Stream.of(str.split(" ")).map(s -> {
+ *             Row row = new Row();
+ *             row.word = s;
+ *             row.length = s.length();
+ *             return row;
+ *         }).iterator();
+ *     }
+ * }
+ * }
+ */ +public interface TableFunction extends UserDefinedFunction { +} diff --git a/java/udf/src/main/java/com/risingwave/functions/TableFunctionBatch.java b/java/udf/src/main/java/com/risingwave/functions/TableFunctionBatch.java new file mode 100644 index 0000000000000..83b7532e465cd --- /dev/null +++ b/java/udf/src/main/java/com/risingwave/functions/TableFunctionBatch.java @@ -0,0 +1,88 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.risingwave.functions; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; + +import java.lang.invoke.MethodHandle; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.function.Function; + +/** + * Batch-processing wrapper over a user-defined table function. 
+ */ +class TableFunctionBatch extends UserDefinedFunctionBatch { + TableFunction function; + MethodHandle methodHandle; + Function[] processInputs; + int chunkSize = 1024; + + TableFunctionBatch(TableFunction function, BufferAllocator allocator) { + this.function = function; + this.allocator = allocator; + var method = Reflection.getEvalMethod(function); + this.methodHandle = Reflection.getMethodHandle(method); + this.inputSchema = TypeUtils.methodToInputSchema(method); + this.outputSchema = TypeUtils.tableFunctionToOutputSchema(method); + this.processInputs = TypeUtils.methodToProcessInputs(method); + } + + @Override + Iterator evalBatch(VectorSchemaRoot batch) { + var outputs = new ArrayList(); + var row = new Object[batch.getSchema().getFields().size() + 1]; + row[0] = this.function; + var indexes = new ArrayList(); + var values = new ArrayList(); + Runnable buildChunk = () -> { + var fields = this.outputSchema.getFields(); + var indexVector = TypeUtils.createVector(fields.get(0), this.allocator, indexes.toArray()); + var valueVector = TypeUtils.createVector(fields.get(1), this.allocator, values.toArray()); + indexes.clear(); + values.clear(); + var outputBatch = VectorSchemaRoot.of(indexVector, valueVector); + outputs.add(outputBatch); + }; + for (int i = 0; i < batch.getRowCount(); i++) { + // prepare input row + for (int j = 0; j < row.length - 1; j++) { + var val = batch.getVector(j).getObject(i); + row[j + 1] = this.processInputs[j].apply(val); + } + // call function + Iterator iterator; + try { + iterator = (Iterator) this.methodHandle.invokeWithArguments(row); + } catch (Throwable e) { + throw new RuntimeException(e); + } + // push values + while (iterator.hasNext()) { + indexes.add(i); + values.add(iterator.next()); + // check if we need to flush + if (indexes.size() >= this.chunkSize) { + buildChunk.run(); + } + } + } + if (indexes.size() > 0) { + buildChunk.run(); + } + return outputs.iterator(); + } +} diff --git 
a/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java b/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java new file mode 100644 index 0000000000000..65d6f168d68b1 --- /dev/null +++ b/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java @@ -0,0 +1,387 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.risingwave.functions; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.*; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.types.*; +import org.apache.arrow.vector.types.pojo.*; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; +import java.math.BigDecimal; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; + +class TypeUtils { + /** + * Convert a string to an Arrow type. 
+ */ + static Field stringToField(String typeStr, String name) { + typeStr = typeStr.toUpperCase(); + if (typeStr.equals("BOOLEAN") || typeStr.equals("BOOL")) { + return Field.nullable(name, new ArrowType.Bool()); + } else if (typeStr.equals("SMALLINT") || typeStr.equals("INT2")) { + return Field.nullable(name, new ArrowType.Int(16, true)); + } else if (typeStr.equals("INT") || typeStr.equals("INTEGER") || typeStr.equals("INT4")) { + return Field.nullable(name, new ArrowType.Int(32, true)); + } else if (typeStr.equals("BIGINT") || typeStr.equals("INT8")) { + return Field.nullable(name, new ArrowType.Int(64, true)); + } else if (typeStr.equals("FLOAT4") || typeStr.equals("REAL")) { + return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)); + } else if (typeStr.equals("FLOAT8") || typeStr.equals("DOUBLE PRECISION")) { + return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)); + } else if (typeStr.startsWith("DECIMAL") || typeStr.startsWith("NUMERIC")) { + return Field.nullable(name, new ArrowType.Decimal(38, 28, 128)); + } else if (typeStr.equals("DATE")) { + return Field.nullable(name, new ArrowType.Date(DateUnit.DAY)); + } else if (typeStr.equals("TIME") || typeStr.equals("TIME WITHOUT TIME ZONE")) { + return Field.nullable(name, new ArrowType.Time(TimeUnit.MICROSECOND, 32)); + } else if (typeStr.equals("TIMESTAMP") || typeStr.equals("TIMESTAMP WITHOUT TIME ZONE")) { + return Field.nullable(name, new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)); + } else if (typeStr.startsWith("INTERVAL")) { + return Field.nullable(name, new ArrowType.Interval(IntervalUnit.MONTH_DAY_NANO)); + } else if (typeStr.equals("VARCHAR")) { + return Field.nullable(name, new ArrowType.Utf8()); + } else if (typeStr.equals("JSONB")) { + return Field.nullable(name, new ArrowType.LargeUtf8()); + } else if (typeStr.equals("BYTEA")) { + return Field.nullable(name, new ArrowType.Binary()); + } else if (typeStr.endsWith("[]")) { + 
Field innerField = stringToField(typeStr.substring(0, typeStr.length() - 2), ""); + return new Field(name, FieldType.nullable(new ArrowType.List()), Arrays.asList(innerField)); + } else if (typeStr.startsWith("STRUCT")) { + // extract "STRUCT" + var typeList = typeStr.substring(7, typeStr.length() - 1); + var fields = Arrays.stream(typeList.split(",")) + .map(s -> stringToField(s.trim(), "")) + .collect(Collectors.toList()); + return new Field(name, FieldType.nullable(new ArrowType.Struct()), fields); + } else { + throw new IllegalArgumentException("Unsupported type: " + typeStr); + } + } + + /** + * Convert a Java class to an Arrow type. + * + * @param param The Java class. + * @param hint An optional DataTypeHint annotation. + * @param name The name of the field. + * @return The Arrow type. + */ + static Field classToField(Class param, DataTypeHint hint, String name) { + if (hint != null) { + return stringToField(hint.value(), name); + } else if (param == Boolean.class || param == boolean.class) { + return Field.nullable(name, new ArrowType.Bool()); + } else if (param == Short.class || param == short.class) { + return Field.nullable(name, new ArrowType.Int(16, true)); + } else if (param == Integer.class || param == int.class) { + return Field.nullable(name, new ArrowType.Int(32, true)); + } else if (param == Long.class || param == long.class) { + return Field.nullable(name, new ArrowType.Int(64, true)); + } else if (param == Float.class || param == float.class) { + return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)); + } else if (param == Double.class || param == double.class) { + return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)); + } else if (param == BigDecimal.class) { + return Field.nullable(name, new ArrowType.Decimal(28, 0, 128)); + } else if (param == String.class) { + return Field.nullable(name, new ArrowType.Utf8()); + } else if (param == byte[].class) { + return 
Field.nullable(name, new ArrowType.Binary()); + } else if (param.isArray()) { + var innerField = classToField(param.getComponentType(), null, ""); + return new Field(name, FieldType.nullable(new ArrowType.List()), Arrays.asList(innerField)); + } else { + // struct type + var fields = new ArrayList(); + for (var field : param.getDeclaredFields()) { + var subhint = field.getAnnotation(DataTypeHint.class); + fields.add(classToField(field.getType(), subhint, field.getName())); + } + return new Field("", FieldType.nullable(new ArrowType.Struct()), fields); + // TODO: more types + // throw new IllegalArgumentException("Unsupported type: " + param); + } + } + + /** + * Get the input schema from a Java method. + */ + static Schema methodToInputSchema(Method method) { + var fields = new ArrayList(); + for (var param : method.getParameters()) { + var hint = param.getAnnotation(DataTypeHint.class); + fields.add(classToField(param.getType(), hint, param.getName())); + } + return new Schema(fields); + } + + /** + * Get the output schema of a scalar function from a Java method. + */ + static Schema methodToOutputSchema(Method method) { + var type = method.getReturnType(); + var hint = method.getAnnotation(DataTypeHint.class); + return new Schema(Arrays.asList(classToField(type, hint, ""))); + } + + /** + * Get the output schema of a table function from a Java class. 
+ */ + static Schema tableFunctionToOutputSchema(Method method) { + var hint = method.getAnnotation(DataTypeHint.class); + var type = method.getReturnType(); + if (!Iterator.class.isAssignableFrom(type)) { + throw new IllegalArgumentException("Table function must return Iterator"); + } + var typeArguments = ((ParameterizedType) method.getGenericReturnType()).getActualTypeArguments(); + type = (Class) typeArguments[0]; + var row_index = Field.nullable("row_index", new ArrowType.Int(32, true)); + return new Schema(Arrays.asList(row_index, classToField(type, hint, ""))); + } + + /** + * Return functions to process input values from a Java method. + */ + static Function[] methodToProcessInputs(Method method) { + var schema = methodToInputSchema(method); + var params = method.getParameters(); + @SuppressWarnings("unchecked") + Function[] funcs = new Function[schema.getFields().size()]; + for (int i = 0; i < schema.getFields().size(); i++) { + funcs[i] = processFunc(schema.getFields().get(i), params[i].getType()); + } + return funcs; + } + + /** + * Create an Arrow vector from an array of values. + */ + static FieldVector createVector(Field field, BufferAllocator allocator, Object[] values) { + var vector = field.createVector(allocator); + fillVector(vector, values); + return vector; + } + + /** + * Fill an Arrow vector with an array of values. 
+ */ + static void fillVector(FieldVector fieldVector, Object[] values) { + if (fieldVector instanceof SmallIntVector) { + var vector = (SmallIntVector) fieldVector; + vector.allocateNew(values.length); + for (int i = 0; i < values.length; i++) { + if (values[i] != null) { + vector.set(i, (short) values[i]); + } + } + } else if (fieldVector instanceof IntVector) { + var vector = (IntVector) fieldVector; + vector.allocateNew(values.length); + for (int i = 0; i < values.length; i++) { + if (values[i] != null) { + vector.set(i, (int) values[i]); + } + } + } else if (fieldVector instanceof BigIntVector) { + var vector = (BigIntVector) fieldVector; + vector.allocateNew(values.length); + for (int i = 0; i < values.length; i++) { + if (values[i] != null) { + vector.set(i, (long) values[i]); + } + } + } else if (fieldVector instanceof Float4Vector) { + var vector = (Float4Vector) fieldVector; + vector.allocateNew(values.length); + for (int i = 0; i < values.length; i++) { + if (values[i] != null) { + vector.set(i, (float) values[i]); + } + } + } else if (fieldVector instanceof Float8Vector) { + var vector = (Float8Vector) fieldVector; + vector.allocateNew(values.length); + for (int i = 0; i < values.length; i++) { + if (values[i] != null) { + vector.set(i, (double) values[i]); + } + } + } else if (fieldVector instanceof DecimalVector) { + var vector = (DecimalVector) fieldVector; + vector.allocateNew(values.length); + for (int i = 0; i < values.length; i++) { + if (values[i] != null) { + vector.set(i, (BigDecimal) values[i]); + } + } + } else if (fieldVector instanceof DateDayVector) { + var vector = (DateDayVector) fieldVector; + vector.allocateNew(values.length); + for (int i = 0; i < values.length; i++) { + if (values[i] != null) { + vector.set(i, (int) values[i]); + } + } + } else if (fieldVector instanceof TimeMicroVector) { + var vector = (TimeMicroVector) fieldVector; + vector.allocateNew(values.length); + for (int i = 0; i < values.length; i++) { + if (values[i] != 
null) { + vector.set(i, (long) values[i]); + } + } + } else if (fieldVector instanceof TimeStampMicroVector) { + var vector = (TimeStampMicroVector) fieldVector; + vector.allocateNew(values.length); + for (int i = 0; i < values.length; i++) { + if (values[i] != null) { + vector.set(i, (long) values[i]); + } + } + } else if (fieldVector instanceof VarCharVector) { + var vector = (VarCharVector) fieldVector; + vector.allocateNew(values.length); + for (int i = 0; i < values.length; i++) { + if (values[i] != null) { + vector.set(i, ((String) values[i]).getBytes()); + } + } + } else if (fieldVector instanceof LargeVarCharVector) { + var vector = (LargeVarCharVector) fieldVector; + vector.allocateNew(values.length); + for (int i = 0; i < values.length; i++) { + if (values[i] != null) { + vector.set(i, ((String) values[i]).getBytes()); + } + } + } else if (fieldVector instanceof VarBinaryVector) { + var vector = (VarBinaryVector) fieldVector; + vector.allocateNew(values.length); + for (int i = 0; i < values.length; i++) { + if (values[i] != null) { + vector.set(i, (byte[]) values[i]); + } + } + } else if (fieldVector instanceof ListVector) { + var vector = (ListVector) fieldVector; + vector.allocateNew(); + // we have to enumerate the inner type again + if (vector.getDataVector() instanceof LargeVarCharVector) { + var innerVector = (LargeVarCharVector) vector.getDataVector(); + for (int i = 0; i < values.length; i++) { + var array = (String[]) values[i]; + if (array != null) { + vector.startNewValue(i); + for (int j = 0; j < array.length; j++) { + if (array[j] != null) { + innerVector.setSafe(j, array[j].getBytes()); + } + } + vector.endValue(i, array.length); + } + } + } else { + throw new IllegalArgumentException("Unsupported type: " + fieldVector.getClass()); + } + } else if (fieldVector instanceof StructVector) { + var vector = (StructVector) fieldVector; + vector.allocateNew(); + for (var field : vector.getField().getChildren()) { + // extract field from values + var 
subvalues = new Object[values.length]; + if (values.length != 0) { + try { + var javaField = values[0].getClass().getDeclaredField(field.getName()); + for (int i = 0; i < values.length; i++) { + subvalues[i] = javaField.get(values[i]); + } + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + var subvector = vector.getChild(field.getName()); + fillVector(subvector, subvalues); + } + for (int i = 0; i < values.length; i++) { + vector.setIndexDefined(i); + } + } else { + throw new IllegalArgumentException("Unsupported type: " + fieldVector.getClass()); + } + fieldVector.setValueCount(values.length); + } + + /** + * Return a function that converts the object get from input array to the + * correct type. + */ + static Function processFunc(Field field, Class targetClass) { + if (field.getType() instanceof ArrowType.Utf8) { + // object is org.apache.arrow.vector.util.Text + return obj -> obj == null ? null : obj.toString(); + } else if (field.getType() instanceof ArrowType.LargeUtf8) { + // object is org.apache.arrow.vector.util.Text + return obj -> obj == null ? null : obj.toString(); + } else if (field.getType() instanceof ArrowType.List) { + // object is org.apache.arrow.vector.util.JsonStringArrayList + var subfield = field.getChildren().get(0); + var subfunc = processFunc(subfield, null); + if (subfield.getType() instanceof ArrowType.Utf8) { + return obj -> obj == null ? null : ((List) obj).stream().map(subfunc).toArray(String[]::new); + } else if (subfield.getType() instanceof ArrowType.LargeUtf8) { + return obj -> obj == null ? 
null : ((List) obj).stream().map(subfunc).toArray(String[]::new); + } + throw new IllegalArgumentException("Unsupported type: " + field.getType()); + } else if (field.getType() instanceof ArrowType.Struct) { + // object is org.apache.arrow.vector.util.JsonStringHashMap + var subfields = field.getChildren(); + @SuppressWarnings("unchecked") + Function[] subfunc = new Function[subfields.size()]; + for (int i = 0; i < subfields.size(); i++) { + subfunc[i] = processFunc(subfields.get(i), targetClass.getFields()[i].getType()); + } + return obj -> { + if (obj == null) + return null; + var map = (AbstractMap) obj; + try { + var row = targetClass.getDeclaredConstructor().newInstance(); + for (int i = 0; i < subfields.size(); i++) { + var field0 = targetClass.getFields()[i]; + var val = subfunc[i].apply(map.get(field0.getName())); + field0.set(row, val); + } + return row; + } catch (InstantiationException | IllegalAccessException | InvocationTargetException + | NoSuchMethodException e) { + throw new RuntimeException(e); + } + }; + } + return Function.identity(); + } +} diff --git a/java/udf/src/main/java/com/risingwave/functions/UdfServer.java b/java/udf/src/main/java/com/risingwave/functions/UdfServer.java new file mode 100644 index 0000000000000..35d4e7066f5b7 --- /dev/null +++ b/java/udf/src/main/java/com/risingwave/functions/UdfServer.java @@ -0,0 +1,167 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.risingwave.functions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; + +import org.apache.arrow.flight.*; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A server that exposes user-defined functions over Apache Arrow Flight. + */ +public class UdfServer implements AutoCloseable { + + private FlightServer server; + private UdfProducer producer; + private static final Logger logger = LoggerFactory.getLogger(UdfServer.class); + + public UdfServer(String host, int port) { + var location = Location.forGrpcInsecure(host, port); + var allocator = new RootAllocator(); + this.producer = new UdfProducer(allocator); + this.server = FlightServer.builder( + allocator, + location, + this.producer).build(); + } + + /** + * Add a user-defined function to the server. + * + * @param name the name of the function + * @param udf the function to add + * @throws IllegalArgumentException if a function with the same name already + * exists + */ + public void addFunction(String name, UserDefinedFunction udf) throws IllegalArgumentException { + logger.info("added function: " + name); + this.producer.addFunction(name, udf); + } + + /** + * Start the server. + */ + public void start() throws IOException { + this.server.start(); + logger.info("listening on " + this.server.getLocation().toSocketAddress()); + } + + /** + * Get the port the server is listening on. + */ + public int getPort() { + return this.server.getPort(); + } + + /** + * Wait for the server to terminate. 
+ */ + public void awaitTermination() throws InterruptedException { + this.server.awaitTermination(); + } + + /** + * Close the server. + */ + public void close() throws InterruptedException { + this.server.close(); + } +} + +class UdfProducer extends NoOpFlightProducer { + + private BufferAllocator allocator; + private HashMap functions = new HashMap<>(); + private static final Logger logger = LoggerFactory.getLogger(UdfServer.class); + + UdfProducer(BufferAllocator allocator) { + this.allocator = allocator; + } + + void addFunction(String name, UserDefinedFunction function) throws IllegalArgumentException { + UserDefinedFunctionBatch udf; + if (function instanceof ScalarFunction) { + udf = new ScalarFunctionBatch((ScalarFunction) function, this.allocator); + } else if (function instanceof TableFunction) { + udf = new TableFunctionBatch((TableFunction) function, this.allocator); + } else { + throw new IllegalArgumentException("Unknown function type: " + function.getClass().getName()); + } + if (functions.containsKey(name)) { + throw new IllegalArgumentException("Function already exists: " + name); + } + functions.put(name, udf); + } + + @Override + public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { + try { + var functionName = descriptor.getPath().get(0); + var udf = functions.get(functionName); + if (udf == null) { + throw new IllegalArgumentException("Unknown function: " + functionName); + } + var fields = new ArrayList(); + fields.addAll(udf.getInputSchema().getFields()); + fields.addAll(udf.getOutputSchema().getFields()); + var fullSchema = new Schema(fields); + var input_len = udf.getInputSchema().getFields().size(); + + return new FlightInfo(fullSchema, descriptor, Collections.emptyList(), 0, input_len); + } catch (Exception e) { + logger.error("Error occurred during getFlightInfo", e); + throw e; + } + } + + @Override + public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) { + try { + 
var functionName = reader.getDescriptor().getPath().get(0); + logger.debug("call function: " + functionName); + + var udf = this.functions.get(functionName); + try (var root = VectorSchemaRoot.create(udf.getOutputSchema(), this.allocator)) { + var loader = new VectorLoader(root); + writer.start(root); + while (reader.next()) { + var outputBatches = udf.evalBatch(reader.getRoot()); + while (outputBatches.hasNext()) { + var outputRoot = outputBatches.next(); + var unloader = new VectorUnloader(outputRoot); + loader.load(unloader.getRecordBatch()); + writer.putNext(); + } + } + writer.completed(); + } + } catch (Exception e) { + logger.error("Error occurred during UDF execution", e); + writer.error(e); + } + } +} diff --git a/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunction.java b/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunction.java new file mode 100644 index 0000000000000..492cd9e245f3c --- /dev/null +++ b/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunction.java @@ -0,0 +1,24 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.risingwave.functions; + +/** + * Base interface for all user-defined functions. 
+ * + * @see ScalarFunction + * @see TableFunction + */ +public interface UserDefinedFunction { +} diff --git a/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunctionBatch.java b/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunctionBatch.java new file mode 100644 index 0000000000000..cc12675840076 --- /dev/null +++ b/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunctionBatch.java @@ -0,0 +1,97 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.risingwave.functions; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Schema; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.util.ArrayList; +import java.util.Iterator; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; + +/** + * Base class for a batch-processing user-defined function. + */ +abstract class UserDefinedFunctionBatch { + protected Schema inputSchema; + protected Schema outputSchema; + protected BufferAllocator allocator; + + /** + * Get the input schema of the function. + */ + Schema getInputSchema() { + return inputSchema; + } + + /** + * Get the output schema of the function. + */ + Schema getOutputSchema() { + return outputSchema; + } + + /** + * Evaluate the function by processing a batch of input data. 
+ * + * @param batch the input data batch to process + * @return an iterator over the output data batches + */ + abstract Iterator evalBatch(VectorSchemaRoot batch); +} + +/** + * Utility class for reflection. + */ +class Reflection { + /** + * Get the method named eval. + */ + static Method getEvalMethod(UserDefinedFunction obj) { + var methods = new ArrayList(); + for (Method method : obj.getClass().getDeclaredMethods()) { + if (method.getName().equals("eval")) { + methods.add(method); + } + } + if (methods.size() != 1) { + throw new IllegalArgumentException( + "Exactly one eval method must be defined for class " + obj.getClass().getName()); + } + var method = methods.get(0); + if (Modifier.isStatic(method.getModifiers())) { + throw new IllegalArgumentException( + "The eval method should not be static for class " + obj.getClass().getName()); + } + return method; + } + + /** + * Get the method handle of the given method. + */ + static MethodHandle getMethodHandle(Method method) { + var lookup = MethodHandles.lookup(); + try { + return lookup.unreflect(method); + } catch (IllegalAccessException e) { + throw new IllegalArgumentException( + "The eval method must be public for class " + method.getDeclaringClass().getName()); + } + } +} diff --git a/java/udf/src/main/java/com/risingwave/functions/example/UdfExample.java b/java/udf/src/main/java/com/risingwave/functions/example/UdfExample.java new file mode 100644 index 0000000000000..bc913813f460f --- /dev/null +++ b/java/udf/src/main/java/com/risingwave/functions/example/UdfExample.java @@ -0,0 +1,185 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.risingwave.functions.example; + +import java.io.IOException; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.stream.Stream; +import java.util.stream.IntStream; + +import com.google.gson.Gson; + +import com.risingwave.functions.DataTypeHint; +import com.risingwave.functions.ScalarFunction; +import com.risingwave.functions.TableFunction; +import com.risingwave.functions.UdfServer; + +public class UdfExample { + public static void main(String[] args) throws IOException { + try (var server = new UdfServer("0.0.0.0", 8815)) { + server.addFunction("int_42", new Int42()); + server.addFunction("gcd", new Gcd()); + server.addFunction("gcd3", new Gcd3()); + server.addFunction("to_string", new ToString()); + server.addFunction("extract_tcp_info", new ExtractTcpInfo()); + server.addFunction("hex_to_dec", new HexToDec()); + server.addFunction("array_access", new ArrayAccess()); + server.addFunction("jsonb_access", new JsonbAccess()); + server.addFunction("jsonb_concat", new JsonbConcat()); + server.addFunction("jsonb_array_identity", new JsonbArrayIdentity()); + server.addFunction("jsonb_array_struct_identity", new JsonbArrayStructIdentity()); + server.addFunction("series", new Series()); + server.addFunction("split", new Split()); + + server.start(); + server.awaitTermination(); + } catch (Exception e) { + e.printStackTrace(); + } + } + + public static class Int42 implements ScalarFunction { + public int eval() { + return 42; + } + } + + public static class 
Gcd implements ScalarFunction { + public int eval(int a, int b) { + while (b != 0) { + int temp = b; + b = a % b; + a = temp; + } + return a; + } + } + + public static class Gcd3 implements ScalarFunction { + public int eval(int a, int b, int c) { + var gcd = new Gcd(); + return gcd.eval(gcd.eval(a, b), c); + } + } + + public static class ToString implements ScalarFunction { + public String eval(String s) { + return s; + } + } + + public static class ExtractTcpInfo implements ScalarFunction { + public static class TcpPacketInfo { + public String srcAddr; + public String dstAddr; + public short srcPort; + public short dstPort; + } + + public TcpPacketInfo eval(byte[] tcpPacket) { + var info = new TcpPacketInfo(); + var buffer = ByteBuffer.wrap(tcpPacket); + info.srcAddr = intToIpAddr(buffer.getInt(12)); + info.dstAddr = intToIpAddr(buffer.getInt(16)); + info.srcPort = buffer.getShort(20); + info.dstPort = buffer.getShort(22); + return info; + } + + static String intToIpAddr(int addr) { + return String.format("%d.%d.%d.%d", (addr >> 24) & 0xff, (addr >> 16) & 0xff, (addr >> 8) & 0xff, + addr & 0xff); + } + } + + public static class HexToDec implements ScalarFunction { + public BigDecimal eval(String hex) { + if (hex == null) { + return null; + } + return new BigDecimal(new BigInteger(hex, 16)); + } + } + + public static class ArrayAccess implements ScalarFunction { + public String eval(String[] array, int index) { + return array[index - 1]; + } + } + + public static class JsonbAccess implements ScalarFunction { + static Gson gson = new Gson(); + + public @DataTypeHint("JSONB") String eval(@DataTypeHint("JSONB") String json, int index) { + if (json == null) + return null; + var array = gson.fromJson(json, Object[].class); + if (index >= array.length || index < 0) + return null; + var obj = array[index]; + return gson.toJson(obj); + } + } + + public static class JsonbConcat implements ScalarFunction { + public @DataTypeHint("JSONB") String eval(@DataTypeHint("JSONB[]") 
String[] jsons) { + if (jsons == null) + return null; + return "[" + String.join(",", jsons) + "]"; + } + } + + public static class JsonbArrayIdentity implements ScalarFunction { + public @DataTypeHint("JSONB[]") String[] eval(@DataTypeHint("JSONB[]") String[] jsons) { + return jsons; + } + } + + public static class JsonbArrayStructIdentity implements ScalarFunction { + public static class Row { + public @DataTypeHint("JSONB[]") String[] v; + public int len; + } + + public Row eval(Row s) { + return s; + } + } + + public static class Series implements TableFunction { + public Iterator eval(int n) { + return IntStream.range(0, n).iterator(); + } + } + + public static class Split implements TableFunction { + public static class Row { + public String word; + public int length; + } + + public Iterator eval(String str) { + return Stream.of(str.split(" ")).map(s -> { + Row row = new Row(); + row.word = s; + row.length = s.length(); + return row; + }).iterator(); + } + } +} diff --git a/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java b/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java new file mode 100644 index 0000000000000..da598fe3dd6d5 --- /dev/null +++ b/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java @@ -0,0 +1,109 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.risingwave.functions; + +import java.io.IOException; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.risingwave.functions.example.*; + +/** + * Unit test for simple App. + */ +public class TestUdfServer { + private static UdfClient client; + private static UdfServer server; + private static BufferAllocator allocator = new RootAllocator(); + + @BeforeAll + public static void setup() throws IOException { + server = new UdfServer("localhost", 0); + server.addFunction("gcd", new UdfExample.Gcd()); + server.addFunction("to_string", new UdfExample.ToString()); + server.addFunction("series", new UdfExample.Series()); + server.start(); + + client = new UdfClient("localhost", server.getPort()); + } + + @AfterAll + public static void teardown() throws InterruptedException { + client.close(); + server.close(); + } + + @Test + public void gcd() throws Exception { + var c0 = new IntVector("", allocator); + c0.allocateNew(1); + c0.set(0, 15); + c0.setValueCount(1); + + var c1 = new IntVector("", allocator); + c1.allocateNew(1); + c1.set(0, 12); + c1.setValueCount(1); + + var input = VectorSchemaRoot.of(c0, c1); + + try (var stream = client.call("gcd", input)) { + var output = stream.getRoot(); + assertTrue(stream.next()); + assertEquals(output.contentToTSVString().trim(), "3"); + } + } + + @Test + public void to_string() throws Exception { + var c0 = new VarCharVector("", allocator); + c0.allocateNew(1); + c0.set(0, "string".getBytes()); + c0.setValueCount(1); + var input = VectorSchemaRoot.of(c0); + + try (var stream = 
client.call("to_string", input)) { + var output = stream.getRoot(); + assertTrue(stream.next()); + assertEquals(output.contentToTSVString().trim(), "string"); + } + } + + @Test + public void series() throws Exception { + var c0 = new IntVector("", allocator); + c0.allocateNew(3); + c0.set(0, 0); + c0.set(1, 1); + c0.set(2, 2); + c0.setValueCount(3); + + var input = VectorSchemaRoot.of(c0); + + try (var stream = client.call("series", input)) { + var output = stream.getRoot(); + assertTrue(stream.next()); + assertEquals(output.contentToTSVString(), "row_index\t\n1\t0\n2\t0\n2\t1\n"); + } + } +} diff --git a/java/udf/src/test/java/com/risingwave/functions/UdfClient.java b/java/udf/src/test/java/com/risingwave/functions/UdfClient.java new file mode 100644 index 0000000000000..4db47fc00f521 --- /dev/null +++ b/java/udf/src/test/java/com/risingwave/functions/UdfClient.java @@ -0,0 +1,54 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.risingwave.functions; + +import org.apache.arrow.flight.*; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class UdfClient implements AutoCloseable { + + private FlightClient client; + private static final Logger logger = LoggerFactory.getLogger(UdfClient.class); + + public UdfClient(String host, int port) { + var allocator = new RootAllocator(); + var location = Location.forGrpcInsecure(host, port); + this.client = FlightClient.builder(allocator, location).build(); + } + + public void close() throws InterruptedException { + this.client.close(); + } + + public FlightInfo getFlightInfo(String functionName) { + var descriptor = FlightDescriptor.command(functionName.getBytes()); + return client.getInfo(descriptor); + } + + public FlightStream call(String functionName, VectorSchemaRoot root) { + var descriptor = FlightDescriptor.path(functionName); + var readerWriter = client.doExchange(descriptor); + var writer = readerWriter.getWriter(); + var reader = readerWriter.getReader(); + + writer.start(root); + writer.putNext(); + writer.completed(); + return reader; + } +} diff --git a/src/batch/src/executor/table_function.rs b/src/batch/src/executor/table_function.rs index 74e40295126d1..3032fdd876ad9 100644 --- a/src/batch/src/executor/table_function.rs +++ b/src/batch/src/executor/table_function.rs @@ -13,7 +13,7 @@ // limitations under the License. 
use futures_async_stream::try_stream; -use risingwave_common::array::DataChunk; +use risingwave_common::array::{ArrayImpl, DataChunk}; use risingwave_common::catalog::{Field, Schema}; use risingwave_common::error::{Result, RwError}; use risingwave_common::types::DataType; @@ -53,8 +53,11 @@ impl TableFunctionExecutor { #[for_await] for chunk in self.table_function.eval(&dummy_chunk).await { let chunk = chunk?; - // remove the first column - yield chunk.split_column_at(1).1; + // remove the first column and expand the second column if its data type is struct + yield match chunk.column_at(1).as_ref() { + ArrayImpl::Struct(struct_array) => struct_array.into(), + _ => chunk.split_column_at(1).1, + }; } } } diff --git a/src/common/src/array/struct_array.rs b/src/common/src/array/struct_array.rs index d344571af85da..dfb2ad721cfd0 100644 --- a/src/common/src/array/struct_array.rs +++ b/src/common/src/array/struct_array.rs @@ -154,15 +154,6 @@ pub struct StructArray { heap_size: usize, } -impl StructArrayBuilder { - pub fn append_array_refs(&mut self, refs: Vec, len: usize) { - self.bitmap.append_n(len, true); - for (a, r) in self.children_array.iter_mut().zip_eq_fast(refs.iter()) { - a.append_array(r); - } - } -} - impl Array for StructArray { type Builder = StructArrayBuilder; type OwnedItem = StructValue; diff --git a/src/expr/macro/src/gen.rs b/src/expr/macro/src/gen.rs index a9cb555a0190d..a2f62c0e04efa 100644 --- a/src/expr/macro/src/gen.rs +++ b/src/expr/macro/src/gen.rs @@ -523,6 +523,18 @@ impl FunctionAttr { .map(|i| quote! { self.return_type.as_struct().types().nth(#i).unwrap().clone() }) .collect() }; + let build_value_array = if return_types.len() == 1 { + quote! { let [value_array] = value_arrays; } + } else { + quote! 
{ + let bitmap = value_arrays[0].null_bitmap().clone(); + let value_array = StructArray::new( + self.return_type.as_struct().clone(), + value_arrays.to_vec(), + bitmap, + ).into_ref(); + } + }; let const_arg = match &self.prebuild { Some(_) => quote! { &self.const_arg }, None => quote! {}, @@ -594,25 +606,24 @@ impl FunctionAttr { let #arrays: &#arg_arrays = #array_refs.as_ref().into(); )* - let mut index_builder = I64ArrayBuilder::new(self.chunk_size); + let mut index_builder = I32ArrayBuilder::new(self.chunk_size); #(let mut #builders = #builder_types::with_type(self.chunk_size, #return_types);)* for (i, (row, visible)) in multizip((#(#arrays.iter(),)*)).zip_eq_fast(input.vis().iter()).enumerate() { if let (#(Some(#inputs),)*) = row && visible { let iter = #fn_name(#(#inputs,)* #const_arg); for output in #iter { - index_builder.append(Some(i as i64)); + index_builder.append(Some(i as i32)); match #output { Some((#(#outputs),*)) => { #(#builders.append(Some(#outputs.as_scalar_ref()));)* } None => { #(#builders.append_null();)* } } if index_builder.len() == self.chunk_size { - let columns = vec![ - std::mem::replace(&mut index_builder, I64ArrayBuilder::new(self.chunk_size)).finish().into_ref(), - #(std::mem::replace(&mut #builders, #builder_types::with_type(self.chunk_size, #return_types)).finish().into_ref(),)* - ]; - yield DataChunk::new(columns, self.chunk_size); + let index_array = std::mem::replace(&mut index_builder, I32ArrayBuilder::new(self.chunk_size)).finish().into_ref(); + let value_arrays = [#(std::mem::replace(&mut #builders, #builder_types::with_type(self.chunk_size, #return_types)).finish().into_ref()),*]; + #build_value_array + yield DataChunk::new(vec![index_array, value_array], self.chunk_size); } } } @@ -620,11 +631,10 @@ impl FunctionAttr { if index_builder.len() > 0 { let len = index_builder.len(); - let columns = vec![ - index_builder.finish().into_ref(), - #(#builders.finish().into_ref(),)* - ]; - yield DataChunk::new(columns, len); + let 
index_array = index_builder.finish().into_ref(); + let value_arrays = [#(#builders.finish().into_ref()),*]; + #build_value_array + yield DataChunk::new(vec![index_array, value_array], len); } } } diff --git a/src/expr/src/expr/expr_nested_construct.rs b/src/expr/src/expr/expr_nested_construct.rs index c7723b15fa0c7..ece26ed138258 100644 --- a/src/expr/src/expr/expr_nested_construct.rs +++ b/src/expr/src/expr/expr_nested_construct.rs @@ -16,7 +16,7 @@ use std::convert::TryFrom; use std::sync::Arc; use risingwave_common::array::{ - ArrayBuilder, ArrayImpl, ArrayRef, DataChunk, ListArrayBuilder, ListValue, StructArrayBuilder, + ArrayBuilder, ArrayImpl, ArrayRef, DataChunk, ListArrayBuilder, ListValue, StructArray, StructValue, }; use risingwave_common::row::OwnedRow; @@ -45,11 +45,9 @@ impl Expression for NestedConstructExpression { columns.push(e.eval_checked(input).await?); } - if let DataType::Struct(_) = &self.data_type { - let mut builder = - StructArrayBuilder::with_type(input.capacity(), self.data_type.clone()); - builder.append_array_refs(columns, input.capacity()); - Ok(Arc::new(ArrayImpl::Struct(builder.finish()))) + if let DataType::Struct(ty) = &self.data_type { + let array = StructArray::new(ty.clone(), columns, input.vis().to_bitmap()); + Ok(Arc::new(ArrayImpl::Struct(array))) } else if let DataType::List { .. } = &self.data_type { let chunk = DataChunk::new(columns, input.vis().clone()); let mut builder = ListArrayBuilder::with_type(input.capacity(), self.data_type.clone()); diff --git a/src/expr/src/table_function/mod.rs b/src/expr/src/table_function/mod.rs index e996331ca5f70..c7f891632b651 100644 --- a/src/expr/src/table_function/mod.rs +++ b/src/expr/src/table_function/mod.rs @@ -12,16 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::sync::Arc; - use either::Either; use futures_async_stream::try_stream; use futures_util::stream::BoxStream; use futures_util::StreamExt; use itertools::Itertools; -use risingwave_common::array::{ - Array, ArrayBuilder, ArrayImpl, ArrayRef, DataChunk, I64ArrayBuilder, StructArray, -}; +use risingwave_common::array::{Array, ArrayBuilder, ArrayImpl, ArrayRef, DataChunk}; use risingwave_common::types::{DataType, DataTypeName, DatumRef}; use risingwave_pb::expr::project_set_select_item::SelectItem; use risingwave_pb::expr::table_function::PbType; @@ -51,10 +47,10 @@ pub trait TableFunction: std::fmt::Debug + Sync + Send { /// # Contract of the output /// - /// The returned `DataChunk` contains at least two columns: - /// - The first column is the row indexes of input chunk. It should be monotonically increasing. - /// - The remaining columns are the output values. More than one columns are allowed, which will - /// be transformed into a single `STRUCT` column later. + /// The returned `DataChunk` contains exact two columns: + /// - The first column is an I32Array containing row indexes of input chunk. It should be + /// monotonically increasing. + /// - The second column is the output values. The data type of the column is `return_type`. /// /// i.e., for the `i`-th input row, the output rows are `(i, output_1)`, `(i, output_2)`, ... /// @@ -219,12 +215,12 @@ impl ProjectSetSelectItem { /// let mut iter = TableFunctionOutputIter::new( /// futures_util::stream::iter([ /// DataChunk::from_pretty( -/// "I I +/// "i I /// 0 0 /// 1 1", /// ), /// DataChunk::from_pretty( -/// "I I +/// "i I /// 2 2 /// 3 3", /// ), @@ -245,7 +241,7 @@ impl ProjectSetSelectItem { /// ``` pub struct TableFunctionOutputIter<'a> { stream: BoxStream<'a, Result>, - chunk: Option<(ArrayRef, ArrayRef)>, + chunk: Option, index: usize, } @@ -264,9 +260,9 @@ impl<'a> TableFunctionOutputIter<'a> { /// Gets the current row. 
pub fn peek(&'a self) -> Option<(usize, DatumRef<'a>)> { - let (indexes, values) = self.chunk.as_ref()?; - let index = indexes.as_int64().value_at(self.index).unwrap() as usize; - let value = values.value_at(self.index); + let chunk = self.chunk.as_ref()?; + let index = chunk.column_at(0).as_int32().value_at(self.index).unwrap() as usize; + let value = chunk.column_at(1).value_at(self.index); Some((index, value)) } @@ -274,10 +270,10 @@ impl<'a> TableFunctionOutputIter<'a> { /// /// This method is cancellation safe. pub async fn next(&mut self) -> Result<()> { - let Some((indexes, _)) = &self.chunk else { + let Some(chunk) = &self.chunk else { return Ok(()); }; - if self.index + 1 == indexes.len() { + if self.index + 1 == chunk.capacity() { // note: for cancellation safety, do not mutate self before await. self.pop_from_stream().await?; self.index = 0; @@ -289,17 +285,7 @@ impl<'a> TableFunctionOutputIter<'a> { /// Gets the next chunk from stream. async fn pop_from_stream(&mut self) -> Result<()> { - let chunk = self.stream.next().await.transpose()?; - self.chunk = chunk.map(|c| { - let (c1, c2) = c.split_column_at(1); - let indexes = c1.column_at(0).clone(); - let values = if c2.columns().len() > 1 { - Arc::new(StructArray::from(c2).into()) - } else { - c2.column_at(0).clone() - }; - (indexes, values) - }); + self.chunk = self.stream.next().await.transpose()?; Ok(()) } } diff --git a/src/expr/src/table_function/repeat.rs b/src/expr/src/table_function/repeat.rs index b73d0df82a8c1..255c8d9c0a68b 100644 --- a/src/expr/src/table_function/repeat.rs +++ b/src/expr/src/table_function/repeat.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use risingwave_common::array::I32ArrayBuilder; + use super::*; /// Repeat an expression n times. 
@@ -43,10 +45,10 @@ impl RepeatN { async fn eval_inner<'a>(&'a self, input: &'a DataChunk) { let array = self.expr.eval(input).await?; - let mut index_builder = I64ArrayBuilder::new(0x100); + let mut index_builder = I32ArrayBuilder::new(0x100); let mut value_builder = self.return_type().create_array_builder(0x100); for (i, value) in array.iter().enumerate() { - index_builder.append_n(self.n, Some(i as i64)); + index_builder.append_n(self.n, Some(i as i32)); value_builder.append_n(self.n, value); } let len = index_builder.len(); diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs index 8cf6cfe39d8af..120e6f6d17c79 100644 --- a/src/frontend/src/handler/create_function.rs +++ b/src/frontend/src/handler/create_function.rs @@ -138,18 +138,10 @@ pub async fn handle_create_function( let args = arrow_schema::Schema::new(arg_types.iter().map(|t| to_field(t.into())).collect()); let returns = arrow_schema::Schema::new(match kind { Kind::Scalar(_) => vec![to_field(return_type.clone().into())], - Kind::Table(_) => { - let mut fields = vec![arrow_schema::Field::new( - "row_index", - arrow_schema::DataType::Int64, - true, - )]; - match &return_type { - DataType::Struct(s) => fields.extend(s.types().map(|t| to_field(t.clone().into()))), - _ => fields.push(to_field(return_type.clone().into())), - } - fields - } + Kind::Table(_) => vec![ + arrow_schema::Field::new("row_index", arrow_schema::DataType::Int32, true), + to_field(return_type.clone().into()), + ], _ => unreachable!(), }); client diff --git a/src/udf/python/risingwave/udf.py b/src/udf/python/risingwave/udf.py index aeb90bf2f3c6b..853003b871e64 100644 --- a/src/udf/python/risingwave/udf.py +++ b/src/udf/python/risingwave/udf.py @@ -131,11 +131,10 @@ def len(self) -> int: """Returns the number of rows in the RecordBatch being built.""" return len(self.columns[0]) - def append(self, index: int, args: Tuple): + def append(self, index: int, value: Any): """Appends a new row to 
the RecordBatch being built.""" self.columns[0].append(index) - for column, val in zip(self.columns[1:], args): - column.append(val) + self.columns[1].append(value) def build(self) -> pa.RecordBatch: """Builds the RecordBatch from the accumulated data and clears the state.""" @@ -153,10 +152,8 @@ def build(self) -> pa.RecordBatch: # Iterate through rows in the input RecordBatch for row_index in range(batch.num_rows): row = tuple(column[row_index].as_py() for column in batch) - for result_row in self.eval(*row): - if not isinstance(result_row, tuple): - result_row = (result_row,) - builder.append(row_index, result_row) + for result in self.eval(*row): + builder.append(row_index, result) if builder.len() == self.BATCH_SIZE: yield builder.build() if builder.len() != 0: @@ -200,6 +197,9 @@ class UserDefinedTableFunctionWrapper(TableFunction): def __init__(self, func, input_types, result_types, name=None): self._func = func + self._name = name or ( + func.__name__ if hasattr(func, "__name__") else func.__class__.__name__ + ) self._input_schema = pa.schema( zip( inspect.getfullargspec(func)[0], @@ -207,11 +207,15 @@ def __init__(self, func, input_types, result_types, name=None): ) ) self._result_schema = pa.schema( - [("row_index", pa.int64())] - + [("", _to_data_type(t)) for t in _to_list(result_types)] - ) - self._name = name or ( - func.__name__ if hasattr(func, "__name__") else func.__class__.__name__ + [ + ("row_index", pa.int32()), + ( + self._name, + pa.struct([("", _to_data_type(t)) for t in result_types]) + if isinstance(result_types, list) + else _to_data_type(result_types), + ), + ] ) def __call__(self, *args): diff --git a/src/udf/python/setup.py b/src/udf/python/setup.py index 85db2f059e608..6c169dfb10e3e 100644 --- a/src/udf/python/setup.py +++ b/src/udf/python/setup.py @@ -5,7 +5,7 @@ setup( name="risingwave", - version="0.0.7", + version="0.0.8", author="RisingWave Labs", description="RisingWave Python API", long_description=long_description, @@ -17,8 
+17,6 @@ "License :: OSI Approved :: Apache Software License", ], python_requires=">=3.8", - install_requires=['pyarrow'], - extras_require={ - 'test': ['pytest'] - }, + install_requires=["pyarrow"], + extras_require={"test": ["pytest"]}, )