Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(udf): fix memory leak in Java UDF #13789

Merged
merged 2 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions java/udf/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- Fix index-out-of-bounds error when a string or string list is large.
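- Fix memory leak: input and output batches are now closed after use in `doExchange`, and each exchange allocates from its own child allocator.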

## [0.1.1] - 2023-12-03

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@ class ScalarFunctionBatch extends UserDefinedFunctionBatch {
MethodHandle methodHandle;
Function<Object, Object>[] processInputs;

ScalarFunctionBatch(ScalarFunction function, BufferAllocator allocator) {
ScalarFunctionBatch(ScalarFunction function) {
this.function = function;
this.allocator = allocator;
var method = Reflection.getEvalMethod(function);
this.methodHandle = Reflection.getMethodHandle(method);
this.inputSchema = TypeUtils.methodToInputSchema(method);
Expand All @@ -38,7 +37,7 @@ class ScalarFunctionBatch extends UserDefinedFunctionBatch {
}

@Override
Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch) {
Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch, BufferAllocator allocator) {
var row = new Object[batch.getSchema().getFields().size() + 1];
row[0] = this.function;
var outputValues = new Object[batch.getRowCount()];
Expand All @@ -55,7 +54,7 @@ Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch) {
}
var outputVector =
TypeUtils.createVector(
this.outputSchema.getFields().get(0), this.allocator, outputValues);
this.outputSchema.getFields().get(0), allocator, outputValues);
var outputBatch = VectorSchemaRoot.of(outputVector);
return Collections.singleton(outputBatch).iterator();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ class TableFunctionBatch extends UserDefinedFunctionBatch {
Function<Object, Object>[] processInputs;
int chunkSize = 1024;

TableFunctionBatch(TableFunction function, BufferAllocator allocator) {
TableFunctionBatch(TableFunction function) {
this.function = function;
this.allocator = allocator;
var method = Reflection.getEvalMethod(function);
this.methodHandle = Reflection.getMethodHandle(method);
this.inputSchema = TypeUtils.methodToInputSchema(method);
Expand All @@ -39,7 +38,7 @@ class TableFunctionBatch extends UserDefinedFunctionBatch {
}

@Override
Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch) {
Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch, BufferAllocator allocator) {
var outputs = new ArrayList<VectorSchemaRoot>();
var row = new Object[batch.getSchema().getFields().size() + 1];
row[0] = this.function;
Expand All @@ -49,10 +48,9 @@ Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch) {
() -> {
var fields = this.outputSchema.getFields();
var indexVector =
TypeUtils.createVector(
fields.get(0), this.allocator, indexes.toArray());
TypeUtils.createVector(fields.get(0), allocator, indexes.toArray());
var valueVector =
TypeUtils.createVector(fields.get(1), this.allocator, values.toArray());
TypeUtils.createVector(fields.get(1), allocator, values.toArray());
indexes.clear();
values.clear();
var outputBatch = VectorSchemaRoot.of(indexVector, valueVector);
Expand Down
25 changes: 15 additions & 10 deletions java/udf/src/main/java/com/risingwave/functions/UdfProducer.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ class UdfProducer extends NoOpFlightProducer {
void addFunction(String name, UserDefinedFunction function) throws IllegalArgumentException {
UserDefinedFunctionBatch udf;
if (function instanceof ScalarFunction) {
udf = new ScalarFunctionBatch((ScalarFunction) function, this.allocator);
udf = new ScalarFunctionBatch((ScalarFunction) function);
} else if (function instanceof TableFunction) {
udf = new TableFunctionBatch((TableFunction) function, this.allocator);
udf = new TableFunctionBatch((TableFunction) function);
} else {
throw new IllegalArgumentException(
"Unknown function type: " + function.getClass().getName());
Expand Down Expand Up @@ -76,21 +76,26 @@ public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor

@Override
public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) {
try {
try (var allocator = this.allocator.newChildAllocator("exchange", 0, Long.MAX_VALUE)) {
var functionName = reader.getDescriptor().getPath().get(0);
logger.debug("call function: " + functionName);

var udf = this.functions.get(functionName);
try (var root = VectorSchemaRoot.create(udf.getOutputSchema(), this.allocator)) {
try (var root = VectorSchemaRoot.create(udf.getOutputSchema(), allocator)) {
var loader = new VectorLoader(root);
writer.start(root);
while (reader.next()) {
var outputBatches = udf.evalBatch(reader.getRoot());
while (outputBatches.hasNext()) {
var outputRoot = outputBatches.next();
var unloader = new VectorUnloader(outputRoot);
loader.load(unloader.getRecordBatch());
writer.putNext();
try (var input = reader.getRoot()) {
var outputBatches = udf.evalBatch(input, allocator);
while (outputBatches.hasNext()) {
try (var outputRoot = outputBatches.next()) {
var unloader = new VectorUnloader(outputRoot);
try (var outputBatch = unloader.getRecordBatch()) {
loader.load(outputBatch);
}
}
writer.putNext();
}
}
}
writer.completed();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
abstract class UserDefinedFunctionBatch {
protected Schema inputSchema;
protected Schema outputSchema;
protected BufferAllocator allocator;

/** Get the input schema of the function. */
Schema getInputSchema() {
Expand All @@ -44,9 +43,11 @@ Schema getOutputSchema() {
* Evaluate the function by processing a batch of input data.
*
* @param batch the input data batch to process
* @param allocator the allocator to use for allocating output data
* @return an iterator over the output data batches
*/
abstract Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch);
abstract Iterator<VectorSchemaRoot> evalBatch(
VectorSchemaRoot batch, BufferAllocator allocator);
}

/** Utility class for reflection. */
Expand Down
Loading