Skip to content

Commit

Permalink
pass all tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wangrunji0408 committed Jun 9, 2023
1 parent b6813c3 commit 0effbc6
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 61 deletions.
30 changes: 2 additions & 28 deletions e2e_test/udf/python.slt
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,19 @@ create function hex_to_dec(varchar) returns decimal language python as hex_to_de
statement ok
create function array_access(varchar[], int) returns varchar language python as array_access using link 'http://localhost:8815';

skipif java
statement ok
create function jsonb_access(jsonb, int) returns jsonb language python as jsonb_access using link 'http://localhost:8815';

skipif java
statement ok
create function jsonb_concat(jsonb[]) returns jsonb language python as jsonb_concat using link 'http://localhost:8815';

skipif java
statement ok
create function jsonb_array_identity(jsonb[]) returns jsonb[] language python as jsonb_array_identity using link 'http://localhost:8815';

skipif java
statement ok
create function jsonb_array_struct_identity(struct<v jsonb[], len int>) returns struct<v jsonb[], len int>
language python as jsonb_array_struct_identity using link 'http://localhost:8815';

skipif java
query TTTTT rowsort
show functions
----
Expand All @@ -77,19 +72,6 @@ jsonb_concat jsonb[] jsonb python http://localhost:8815
series integer integer python http://localhost:8815
split varchar struct<word varchar,length integer> python http://localhost:8815

onlyif java
query TTTTT rowsort
show functions
----
array_access varchar[], integer varchar python http://localhost:8815
extract_tcp_info bytea struct<src_ip varchar,dst_ip varchar,src_port smallint,dst_port smallint> python http://localhost:8815
gcd integer, integer integer python http://localhost:8815
gcd integer, integer, integer integer python http://localhost:8815
hex_to_dec varchar numeric python http://localhost:8815
int_42 (empty) integer python http://localhost:8815
series integer integer python http://localhost:8815
split varchar struct<word varchar,length integer> python http://localhost:8815

query I
select int_42();
----
Expand All @@ -115,7 +97,6 @@ select array_access(ARRAY['a', 'b', 'c'], 2);
----
b

skipif java
query T
select jsonb_access(a::jsonb, 1) from
(values ('["a", "b", "c"]'), (null), ('[0, false]')) t(a);
Expand All @@ -124,21 +105,18 @@ select jsonb_access(a::jsonb, 1) from
NULL
false

skipif java
query T
select jsonb_concat(ARRAY['null'::jsonb, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb]);
----
[null, 1, "str", {}]

skipif java
query T
select jsonb_array_identity(ARRAY['null'::jsonb, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb]);
select jsonb_array_identity(ARRAY[null, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb]);
----
{NULL,1,"\"str\"","{}"}

skipif java
query T
select jsonb_array_struct_identity(ROW(ARRAY['null'::jsonb, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb], 4)::struct<v jsonb[], len int>);
select jsonb_array_struct_identity(ROW(ARRAY[null, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb], 4)::struct<v jsonb[], len int>);
----
({NULL,1,"\"str\"","{}"},4)

Expand Down Expand Up @@ -255,18 +233,14 @@ drop function hex_to_dec;
statement ok
drop function array_access;

skipif java
statement ok
drop function jsonb_access;

skipif java
statement ok
drop function jsonb_concat;

skipif java
statement ok
drop function jsonb_array_identity;

skipif java
statement ok
drop function jsonb_array_struct_identity;
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,13 @@ class ScalarFunctionBatch extends UserDefinedFunctionBatch {
Method method;
Function<Object, Object>[] processInputs;

@SuppressWarnings("unchecked")
ScalarFunctionBatch(ScalarFunction function, BufferAllocator allocator) {
this.function = function;
this.allocator = allocator;
this.method = Reflection.getEvalMethod(function);
this.inputSchema = TypeUtils.methodToInputSchema(this.method);
this.outputSchema = TypeUtils.methodToOutputSchema(this.method);
this.processInputs = this.inputSchema.getFields().stream()
.map(TypeUtils::processFunc)
.toArray(Function[]::new);
this.processInputs = TypeUtils.methodToProcessInputs(this.method);
}

@Override
Expand All @@ -69,7 +66,7 @@ Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch) {
for (int i = 0; i < batch.getRowCount(); i++) {
for (int j = 0; j < row.length; j++) {
var val = batch.getVector(j).getObject(i);
row[j] = val == null ? null : this.processInputs[j].apply(val);
row[j] = this.processInputs[j].apply(val);
}
try {
outputValues[i] = this.method.invoke(this.function, row);
Expand All @@ -93,16 +90,13 @@ class TableFunctionBatch extends UserDefinedFunctionBatch {
Function<Object, Object>[] processInputs;
int chunk_size = 1024;

@SuppressWarnings("unchecked")
TableFunctionBatch(TableFunction<?> function, BufferAllocator allocator) {
this.function = function;
this.allocator = allocator;
this.method = Reflection.getEvalMethod(function);
this.inputSchema = TypeUtils.methodToInputSchema(this.method);
this.outputSchema = TypeUtils.tableFunctionToOutputSchema(function.getClass());
this.processInputs = this.inputSchema.getFields().stream()
.map(TypeUtils::processFunc)
.toArray(Function[]::new);
this.processInputs = TypeUtils.methodToProcessInputs(this.method);
}

@Override
Expand All @@ -126,7 +120,7 @@ Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch) {
// prepare input row
for (int j = 0; j < row.length; j++) {
var val = batch.getVector(j).getObject(i);
row[j] = val == null ? null : this.processInputs[j].apply(val);
row[j] = this.processInputs[j].apply(val);
}
// call function
var size_before = this.function.size();
Expand Down
100 changes: 82 additions & 18 deletions src/udf/java/src/main/java/com/risingwave/functions/TypeUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.*;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.types.*;
import org.apache.arrow.vector.types.pojo.*;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.math.BigDecimal;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -68,11 +71,14 @@ static Field stringToField(String typeStr, String name) {
* Convert a Java class to an Arrow type.
*
* @param param The Java class.
* @param hint An optional DataTypeHint annotation.
* @param name The name of the field.
* @return The Arrow type.
*/
static Field classToField(Class<?> param, String name) {
if (param == Boolean.class || param == boolean.class) {
static Field classToField(Class<?> param, DataTypeHint hint, String name) {
if (hint != null) {
return stringToField(hint.value(), name);
} else if (param == Boolean.class || param == boolean.class) {
return Field.nullable(name, new ArrowType.Bool());
} else if (param == Short.class || param == short.class) {
return Field.nullable(name, new ArrowType.Int(16, true));
Expand All @@ -91,13 +97,14 @@ static Field classToField(Class<?> param, String name) {
} else if (param == byte[].class) {
return Field.nullable(name, new ArrowType.Binary());
} else if (param.isArray()) {
var innerField = classToField(param.getComponentType(), "");
var innerField = classToField(param.getComponentType(), null, "");
return new Field(name, FieldType.nullable(new ArrowType.List()), Arrays.asList(innerField));
} else {
// struct type
var fields = new ArrayList<Field>();
for (var field : param.getDeclaredFields()) {
fields.add(classToField(field.getType(), field.getName()));
var subhint = field.getAnnotation(DataTypeHint.class);
fields.add(classToField(field.getType(), subhint, field.getName()));
}
return new Field("", FieldType.nullable(new ArrowType.Struct()), fields);
// TODO: more types
Expand All @@ -112,11 +119,7 @@ static Schema methodToInputSchema(Method method) {
var fields = new ArrayList<Field>();
for (var param : method.getParameters()) {
var hint = param.getAnnotation(DataTypeHint.class);
if (hint != null) {
fields.add(stringToField(hint.value(), param.getName()));
} else {
fields.add(classToField(param.getType(), param.getName()));
}
fields.add(classToField(param.getType(), hint, param.getName()));
}
return new Schema(fields);
}
Expand All @@ -127,20 +130,34 @@ static Schema methodToInputSchema(Method method) {
static Schema methodToOutputSchema(Method method) {
var type = method.getReturnType();
var hint = method.getAnnotation(DataTypeHint.class);
var field = hint != null ? stringToField(hint.value(), "") : classToField(type, "");
return new Schema(Arrays.asList(field));
return new Schema(Arrays.asList(classToField(type, hint, "")));
}

/**
* Get the output schema of a table function from a Java class.
*/
static Schema tableFunctionToOutputSchema(Class<?> type) {
var hint = type.getAnnotation(DataTypeHint.class);
var parameterizedType = (ParameterizedType) type.getGenericSuperclass();
var typeArguments = parameterizedType.getActualTypeArguments();
type = (Class<?>) typeArguments[0];

var row_index = Field.nullable("row_index", new ArrowType.Int(32, true));
return new Schema(Arrays.asList(row_index, classToField(type, "")));
return new Schema(Arrays.asList(row_index, classToField(type, hint, "")));
}

/**
 * Build the per-parameter input converters for a Java method.
 *
 * The i-th returned function is obtained from {@code processFunc}, given the
 * i-th field of the method's input schema and the declared Java type of the
 * i-th parameter.
 */
static Function<Object, Object>[] methodToProcessInputs(Method method) {
    var inputFields = methodToInputSchema(method).getFields();
    var parameters = method.getParameters();
    @SuppressWarnings("unchecked")
    Function<Object, Object>[] converters = new Function[inputFields.size()];
    int index = 0;
    for (var inputField : inputFields) {
        converters[index] = processFunc(inputField, parameters[index].getType());
        index++;
    }
    return converters;
}

/**
Expand Down Expand Up @@ -252,6 +269,27 @@ static void fillVector(FieldVector fieldVector, Object[] values) {
vector.set(i, (byte[]) values[i]);
}
}
} else if (fieldVector instanceof ListVector) {
    // Fill a list column. Only lists of strings (LargeVarChar children) are supported here.
    var vector = (ListVector) fieldVector;
    vector.allocateNew();
    // we have to enumerate the inner type again
    if (vector.getDataVector() instanceof LargeVarCharVector) {
        var innerVector = (LargeVarCharVector) vector.getDataVector();
        for (int i = 0; i < values.length; i++) {
            var array = (String[]) values[i];
            if (array == null) {
                // a null array is not started, exactly as before
                continue;
            }
            // All rows share one child vector, so elements must be written
            // relative to the offset at which this row's list begins, not at
            // the row-local index (which would overwrite earlier rows).
            int start = vector.startNewValue(i);
            for (int j = 0; j < array.length; j++) {
                if (array[j] != null) {
                    // encode explicitly as UTF-8 instead of the platform default charset
                    innerVector.setSafe(start + j,
                            array[j].getBytes(java.nio.charset.StandardCharsets.UTF_8));
                }
            }
            vector.endValue(i, array.length);
        }
    } else {
        throw new IllegalArgumentException("Unsupported type: " + fieldVector.getClass());
    }
} else if (fieldVector instanceof StructVector) {
var vector = (StructVector) fieldVector;
vector.allocateNew();
Expand Down Expand Up @@ -284,22 +322,48 @@ static void fillVector(FieldVector fieldVector, Object[] values) {
* Return a function that converts the object get from input array to the
* correct type.
*/
static Function<Object, Object> processFunc(Field field) {
static Function<Object, Object> processFunc(Field field, Class<?> targetClass) {
if (field.getType() instanceof ArrowType.Utf8) {
// object is org.apache.arrow.vector.util.Text
return obj -> obj.toString();
return obj -> obj == null ? null : obj.toString();
} else if (field.getType() instanceof ArrowType.LargeUtf8) {
// object is org.apache.arrow.vector.util.Text
return obj -> obj.toString();
return obj -> obj == null ? null : obj.toString();
} else if (field.getType() instanceof ArrowType.List) {
// object is org.apache.arrow.vector.util.JsonStringArrayList
var subfield = field.getChildren().get(0);
var subfunc = processFunc(subfield);
var subfunc = processFunc(subfield, null);
if (subfield.getType() instanceof ArrowType.Utf8) {
return obj -> ((List<?>) obj).stream().map(subfunc).toArray(String[]::new);
return obj -> obj == null ? null : ((List<?>) obj).stream().map(subfunc).toArray(String[]::new);
} else if (subfield.getType() instanceof ArrowType.LargeUtf8) {
return obj -> ((List<?>) obj).stream().map(subfunc).toArray(String[]::new);
return obj -> obj == null ? null : ((List<?>) obj).stream().map(subfunc).toArray(String[]::new);
}
throw new IllegalArgumentException("Unsupported type: " + field.getType());
} else if (field.getType() instanceof ArrowType.Struct) {
    // object is org.apache.arrow.vector.util.JsonStringHashMap
    if (targetClass == null) {
        // The list branch recurses with a null class, so a struct nested in a
        // list lands here; report it as unsupported instead of a bare NPE.
        throw new IllegalArgumentException("Unsupported type: " + field.getType());
    }
    var subfields = field.getChildren();
    // Resolve each schema child to the public Java field of the same name,
    // once, up front. Pairing by name avoids relying on the (unspecified)
    // order and contents of Class.getFields() matching the schema.
    var members = new java.lang.reflect.Field[subfields.size()];
    @SuppressWarnings("unchecked")
    Function<Object, Object>[] subfunc = new Function[subfields.size()];
    for (int i = 0; i < subfields.size(); i++) {
        var name = subfields.get(i).getName();
        try {
            members[i] = targetClass.getField(name);
        } catch (NoSuchFieldException e) {
            throw new IllegalArgumentException(
                    "No public field '" + name + "' in " + targetClass.getName(), e);
        }
        subfunc[i] = processFunc(subfields.get(i), members[i].getType());
    }
    return obj -> {
        if (obj == null)
            return null;
        var map = (AbstractMap<?, ?>) obj;
        try {
            // requires a no-argument constructor on the target class
            var row = targetClass.getDeclaredConstructor().newInstance();
            for (int i = 0; i < members.length; i++) {
                var val = subfunc[i].apply(map.get(members[i].getName()));
                members[i].set(row, val);
            }
            return row;
        } catch (InstantiationException | IllegalAccessException | InvocationTargetException
                | NoSuchMethodException e) {
            throw new RuntimeException(e);
        }
    };
}
return Function.identity();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ public static void main(String[] args) throws IOException {
server.addFunction("hex_to_dec", new HexToDec());
server.addFunction("array_access", new ArrayAccess());
server.addFunction("jsonb_access", new JsonbAccess());
server.addFunction("jsonb_concat", new JsonbConcat());
server.addFunction("jsonb_array_identity", new JsonbArrayIdentity());
server.addFunction("jsonb_array_struct_identity", new JsonbArrayStructIdentity());
server.addFunction("series", new Series());
server.addFunction("split", new Split());

Expand Down Expand Up @@ -114,6 +117,31 @@ public static class JsonbAccess extends ScalarFunction {
}
}

/**
 * Scalar UDF that joins the textual form of each JSONB element into one
 * JSON array string.
 */
public static class JsonbConcat extends ScalarFunction {
    /**
     * @param jsons the JSONB values as strings; may be null
     * @return the elements separated by commas and wrapped in square brackets,
     *         or null when {@code jsons} is null
     */
    public static @DataTypeHint("JSONB") String eval(@DataTypeHint("JSONB[]") String[] jsons) {
        if (jsons == null) {
            return null;
        }
        var joined = new StringBuilder("[");
        for (int i = 0; i < jsons.length; i++) {
            if (i > 0) {
                joined.append(",");
            }
            // appending a null String writes the text "null", as String.join does
            joined.append(jsons[i]);
        }
        return joined.append("]").toString();
    }
}

/**
 * Scalar UDF whose argument and result are both hinted as JSONB[].
 */
public static class JsonbArrayIdentity extends ScalarFunction {
    /**
     * @param jsons the input array; may be null
     * @return the same array reference that was passed in
     */
    public static @DataTypeHint("JSONB[]") String[] eval(@DataTypeHint("JSONB[]") String[] jsons) {
        return jsons;
    }
}

/**
 * Scalar UDF that takes a {@link Row} struct and returns it unchanged.
 */
public static class JsonbArrayStructIdentity extends ScalarFunction {
    /**
     * Struct value with two public fields: a string array hinted as JSONB[]
     * and an int.
     */
    public static class Row {
        // array of JSONB values carried as strings
        public @DataTypeHint("JSONB[]") String[] v;
        // NOTE(review): presumably the length of v; nothing here enforces it
        public int len;
    }

    /**
     * @param s the input struct
     * @return the same object that was passed in
     */
    public static Row eval(Row s) {
        return s;
    }
}

public static class Series extends TableFunction<Integer> {
public void eval(int n) {
for (int i = 0; i < n; i++) {
Expand Down
8 changes: 3 additions & 5 deletions src/udf/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="risingwave",
version="0.0.7",
version="0.0.8",
author="RisingWave Labs",
description="RisingWave Python API",
long_description=long_description,
Expand All @@ -17,8 +17,6 @@
"License :: OSI Approved :: Apache Software License",
],
python_requires=">=3.8",
install_requires=['pyarrow'],
extras_require={
'test': ['pytest']
},
install_requires=["pyarrow"],
extras_require={"test": ["pytest"]},
)

0 comments on commit 0effbc6

Please sign in to comment.