diff --git a/docs/source/java/dataset.rst b/docs/source/java/dataset.rst index 35ffa81058072..a4381e0814638 100644 --- a/docs/source/java/dataset.rst +++ b/docs/source/java/dataset.rst @@ -132,12 +132,10 @@ within method ``Scanner::schema()``: .. _java-dataset-projection: -Projection -========== +Projection (Subset of Columns) +============================== -User can specify projections in ScanOptions. For ``FileSystemDataset``, only -column projection is allowed for now, which means, only column names -in the projection list will be accepted. For example: +User can specify projections in ScanOptions. For example: .. code-block:: Java @@ -159,6 +157,27 @@ Or use shortcut construtor: Then all columns will be emitted during scanning. +Projection (Produce New Columns) and Filters +============================================ + +User can specify projections (new columns) or filters in ScanOptions using Substrait. For example: + +.. code-block:: Java + + ByteBuffer substraitExpressionFilter = getSubstraitExpressionFilter(); + ByteBuffer substraitExpressionProject = getSubstraitExpressionProjection(); + // Use Substrait APIs to create an Expression and serialize to a ByteBuffer + ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitFilter(substraitExpressionFilter) + .substraitProjection(substraitExpressionProject) + .build(); + +.. seealso:: + + :doc:`Executing Projections and Filters Using Extended Expressions <substrait>` + Projections and Filters using Substrait. + Read Data from HDFS =================== diff --git a/docs/source/java/substrait.rst b/docs/source/java/substrait.rst index 41effedbf01d9..d8d49a96e88f8 100644 --- a/docs/source/java/substrait.rst +++ b/docs/source/java/substrait.rst @@ -22,8 +22,10 @@ Substrait The ``arrow-dataset`` module can execute Substrait_ plans via the :doc:`Acero <../cpp/streaming_execution>` query engine. 
-Executing Substrait Plans -========================= +.. contents:: + +Executing Queries Using Substrait Plans +======================================= Plans can reference data in files via URIs, or "named tables" that must be provided along with the plan. @@ -102,6 +104,349 @@ Here is an example of a Java program that queries a Parquet file using Java Subs 0 ALGERIA 0 haggle. carefully final deposits detect slyly agai 1 ARGENTINA 1 al foxes promise slyly according to the regular accounts. bold requests alon +Executing Projections and Filters Using Extended Expressions +============================================================ + +Dataset also supports projections and filters with Substrait's `Extended Expression`_. +This requires the substrait-java library. + +This Java program: + +- Loads a Parquet file containing the "nation" table from the TPC-H benchmark. +- Projects two new columns: + - ``N_NAME || ' - ' || N_COMMENT`` + - ``N_REGIONKEY + 10`` +- Applies a filter: ``N_NATIONKEY > 18`` + +.. 
code-block:: Java + + import io.substrait.extension.ExtensionCollector; + import io.substrait.proto.Expression; + import io.substrait.proto.ExpressionReference; + import io.substrait.proto.ExtendedExpression; + import io.substrait.proto.FunctionArgument; + import io.substrait.proto.SimpleExtensionDeclaration; + import io.substrait.proto.SimpleExtensionURI; + import io.substrait.type.NamedStruct; + import io.substrait.type.Type; + import io.substrait.type.TypeCreator; + import io.substrait.type.proto.TypeProtoConverter; + import java.nio.ByteBuffer; + import java.util.ArrayList; + import java.util.Arrays; + import java.util.Base64; + import java.util.HashMap; + import java.util.List; + import java.util.Optional; + import org.apache.arrow.dataset.file.FileFormat; + import org.apache.arrow.dataset.file.FileSystemDatasetFactory; + import org.apache.arrow.dataset.jni.NativeMemoryPool; + import org.apache.arrow.dataset.scanner.ScanOptions; + import org.apache.arrow.dataset.scanner.Scanner; + import org.apache.arrow.dataset.source.Dataset; + import org.apache.arrow.dataset.source.DatasetFactory; + import org.apache.arrow.memory.BufferAllocator; + import org.apache.arrow.memory.RootAllocator; + import org.apache.arrow.vector.ipc.ArrowReader; + + public class ClientSubstraitExtendedExpressionsCookbook { + + public static void main(String[] args) throws Exception { + // project and filter dataset using extended expression definition - 03 Expressions: + // Expression 01 - CONCAT: N_NAME || ' - ' || N_COMMENT = col 1 || ' - ' || col 3 + // Expression 02 - ADD: N_REGIONKEY + 10 = col 2 + 10 + // Expression 03 - FILTER: N_NATIONKEY > 18 = col 0 > 18 + projectAndFilterDataset(); + } + + public static void projectAndFilterDataset() { + String uri = "file:///Users/data/tpch_parquet/nation.parquet"; + ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitFilter(getSubstraitExpressionFilter()) + 
.substraitProjection(getSubstraitExpressionProjection()) + .build(); + try ( + BufferAllocator allocator = new RootAllocator(); + DatasetFactory datasetFactory = new FileSystemDatasetFactory( + allocator, NativeMemoryPool.getDefault(), + FileFormat.PARQUET, uri); + Dataset dataset = datasetFactory.finish(); + Scanner scanner = dataset.newScan(options); + ArrowReader reader = scanner.scanBatches() + ) { + while (reader.loadNextBatch()) { + System.out.println( + reader.getVectorSchemaRoot().contentToTSVString()); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static ByteBuffer getSubstraitExpressionProjection() { + // Expression: N_REGIONKEY + 10 = col 2 + 10 + Expression.Builder selectionBuilderProjectOne = Expression.newBuilder(). + setSelection( + Expression.FieldReference.newBuilder(). + setDirectReference( + Expression.ReferenceSegment.newBuilder(). + setStructField( + Expression.ReferenceSegment.StructField.newBuilder().setField( + 2) + ) + ) + ); + Expression.Builder literalBuilderProjectOne = Expression.newBuilder() + .setLiteral( + Expression.Literal.newBuilder().setI32(10) + ); + io.substrait.proto.Type outputProjectOne = TypeCreator.NULLABLE.I32.accept( + new TypeProtoConverter(new ExtensionCollector())); + Expression.Builder expressionBuilderProjectOne = Expression. + newBuilder(). + setScalarFunction( + Expression. + ScalarFunction. + newBuilder(). + setFunctionReference(0). + setOutputType(outputProjectOne). + addArguments( + 0, + FunctionArgument.newBuilder().setValue( + selectionBuilderProjectOne) + ). + addArguments( + 1, + FunctionArgument.newBuilder().setValue( + literalBuilderProjectOne) + ) + ); + ExpressionReference.Builder expressionReferenceBuilderProjectOne = ExpressionReference.newBuilder(). 
+ setExpression(expressionBuilderProjectOne) + .addOutputNames("ADD_TEN_TO_COLUMN_N_REGIONKEY"); + + // Expression: name || name = N_NAME || "-" || N_COMMENT = col 1 || col 3 + Expression.Builder selectionBuilderProjectTwo = Expression.newBuilder(). + setSelection( + Expression.FieldReference.newBuilder(). + setDirectReference( + Expression.ReferenceSegment.newBuilder(). + setStructField( + Expression.ReferenceSegment.StructField.newBuilder().setField( + 1) + ) + ) + ); + Expression.Builder selectionBuilderProjectTwoConcatLiteral = Expression.newBuilder() + .setLiteral( + Expression.Literal.newBuilder().setString(" - ") + ); + Expression.Builder selectionBuilderProjectOneToConcat = Expression.newBuilder(). + setSelection( + Expression.FieldReference.newBuilder(). + setDirectReference( + Expression.ReferenceSegment.newBuilder(). + setStructField( + Expression.ReferenceSegment.StructField.newBuilder().setField( + 3) + ) + ) + ); + io.substrait.proto.Type outputProjectTwo = TypeCreator.NULLABLE.STRING.accept( + new TypeProtoConverter(new ExtensionCollector())); + Expression.Builder expressionBuilderProjectTwo = Expression. + newBuilder(). + setScalarFunction( + Expression. + ScalarFunction. + newBuilder(). + setFunctionReference(1). + setOutputType(outputProjectTwo). + addArguments( + 0, + FunctionArgument.newBuilder().setValue( + selectionBuilderProjectTwo) + ). + addArguments( + 1, + FunctionArgument.newBuilder().setValue( + selectionBuilderProjectTwoConcatLiteral) + ). + addArguments( + 2, + FunctionArgument.newBuilder().setValue( + selectionBuilderProjectOneToConcat) + ) + ); + ExpressionReference.Builder expressionReferenceBuilderProjectTwo = ExpressionReference.newBuilder(). 
+ setExpression(expressionBuilderProjectTwo) + .addOutputNames("CONCAT_COLUMNS_N_NAME_AND_N_COMMENT"); + + List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", + "N_REGIONKEY", "N_COMMENT"); + List dataTypes = Arrays.asList( + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING, + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING + ); + NamedStruct of = NamedStruct.of( + columnNames, + Type.Struct.builder().fields(dataTypes).nullable(false).build() + ); + // Extensions URI + HashMap extensionUris = new HashMap<>(); + extensionUris.put( + "key-001", + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) + .setUri("/functions_arithmetic.yaml") + .build() + ); + // Extensions + ArrayList extensions = new ArrayList<>(); + SimpleExtensionDeclaration extensionFunctionAdd = SimpleExtensionDeclaration.newBuilder() + .setExtensionFunction( + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(0) + .setName("add:i32_i32") + .setExtensionUriReference(1)) + .build(); + SimpleExtensionDeclaration extensionFunctionGreaterThan = SimpleExtensionDeclaration.newBuilder() + .setExtensionFunction( + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("concat:vchar") + .setExtensionUriReference(2)) + .build(); + extensions.add(extensionFunctionAdd); + extensions.add(extensionFunctionGreaterThan); + // Extended Expression + ExtendedExpression.Builder extendedExpressionBuilder = + ExtendedExpression.newBuilder(). + addReferredExpr(0, + expressionReferenceBuilderProjectOne). + addReferredExpr(1, + expressionReferenceBuilderProjectTwo). 
+ setBaseSchema(of.toProto(new TypeProtoConverter( + new ExtensionCollector()))); + extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); + extendedExpressionBuilder.addAllExtensions(extensions); + ExtendedExpression extendedExpression = extendedExpressionBuilder.build(); + byte[] extendedExpressions = Base64.getDecoder().decode( + Base64.getEncoder().encodeToString( + extendedExpression.toByteArray())); + ByteBuffer substraitExpressionProjection = ByteBuffer.allocateDirect( + extendedExpressions.length); + substraitExpressionProjection.put(extendedExpressions); + return substraitExpressionProjection; + } + + private static ByteBuffer getSubstraitExpressionFilter() { + // Expression: Filter: N_NATIONKEY > 18 = col 0 > 18 + Expression.Builder selectionBuilderFilterOne = Expression.newBuilder(). + setSelection( + Expression.FieldReference.newBuilder(). + setDirectReference( + Expression.ReferenceSegment.newBuilder(). + setStructField( + Expression.ReferenceSegment.StructField.newBuilder().setField( + 0) + ) + ) + ); + Expression.Builder literalBuilderFilterOne = Expression.newBuilder() + .setLiteral( + Expression.Literal.newBuilder().setI32(18) + ); + io.substrait.proto.Type outputFilterOne = TypeCreator.NULLABLE.BOOLEAN.accept( + new TypeProtoConverter(new ExtensionCollector())); + Expression.Builder expressionBuilderFilterOne = Expression. + newBuilder(). + setScalarFunction( + Expression. + ScalarFunction. + newBuilder(). + setFunctionReference(1). + setOutputType(outputFilterOne). + addArguments( + 0, + FunctionArgument.newBuilder().setValue( + selectionBuilderFilterOne) + ). + addArguments( + 1, + FunctionArgument.newBuilder().setValue( + literalBuilderFilterOne) + ) + ); + ExpressionReference.Builder expressionReferenceBuilderFilterOne = ExpressionReference.newBuilder(). 
+ setExpression(expressionBuilderFilterOne) + .addOutputNames("COLUMN_N_NATIONKEY_GREATER_THAN_18"); + + List columnNames = Arrays.asList("N_NATIONKEY", "N_NAME", + "N_REGIONKEY", "N_COMMENT"); + List dataTypes = Arrays.asList( + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING, + TypeCreator.NULLABLE.I32, + TypeCreator.NULLABLE.STRING + ); + NamedStruct of = NamedStruct.of( + columnNames, + Type.Struct.builder().fields(dataTypes).nullable(false).build() + ); + // Extensions URI + HashMap extensionUris = new HashMap<>(); + extensionUris.put( + "key-001", + SimpleExtensionURI.newBuilder() + .setExtensionUriAnchor(1) + .setUri("/functions_comparison.yaml") + .build() + ); + // Extensions + ArrayList extensions = new ArrayList<>(); + SimpleExtensionDeclaration extensionFunctionLowerThan = SimpleExtensionDeclaration.newBuilder() + .setExtensionFunction( + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("gt:any_any") + .setExtensionUriReference(1)) + .build(); + extensions.add(extensionFunctionLowerThan); + // Extended Expression + ExtendedExpression.Builder extendedExpressionBuilder = + ExtendedExpression.newBuilder(). + addReferredExpr(0, + expressionReferenceBuilderFilterOne). + setBaseSchema(of.toProto(new TypeProtoConverter( + new ExtensionCollector()))); + extendedExpressionBuilder.addAllExtensionUris(extensionUris.values()); + extendedExpressionBuilder.addAllExtensions(extensions); + ExtendedExpression extendedExpression = extendedExpressionBuilder.build(); + byte[] extendedExpressions = Base64.getDecoder().decode( + Base64.getEncoder().encodeToString( + extendedExpression.toByteArray())); + ByteBuffer substraitExpressionFilter = ByteBuffer.allocateDirect( + extendedExpressions.length); + substraitExpressionFilter.put(extendedExpressions); + return substraitExpressionFilter; + } + } + +.. 
code-block:: text + + ADD_TEN_TO_COLUMN_N_REGIONKEY CONCAT_COLUMNS_N_NAME_AND_N_COMMENT + 13 ROMANIA - ular asymptotes are about the furious multipliers. express dependencies nag above the ironically ironic account + 14 SAUDI ARABIA - ts. silent requests haggle. closely express packages sleep across the blithely + 12 VIETNAM - hely enticingly express accounts. even, final + 13 RUSSIA - requests against the platelets use never according to the quickly regular pint + 13 UNITED KINGDOM - eans boost carefully special requests. accounts are. carefull + 11 UNITED STATES - y final packages. slow foxes cajole quickly. quickly silent platelets breach ironic accounts. unusual pinto be + .. _`Substrait`: https://substrait.io/ .. _`Substrait Java`: https://github.com/substrait-io/substrait-java -.. _`Acero`: https://arrow.apache.org/docs/cpp/streaming_execution.html \ No newline at end of file +.. _`Acero`: https://arrow.apache.org/docs/cpp/streaming_execution.html +.. _`Extended Expression`: https://github.com/substrait-io/substrait/blob/main/site/docs/expressions/extended_expression.md diff --git a/java/dataset/src/main/cpp/jni_wrapper.cc b/java/dataset/src/main/cpp/jni_wrapper.cc index 5640bc4349670..49e0f1720909f 100644 --- a/java/dataset/src/main/cpp/jni_wrapper.cc +++ b/java/dataset/src/main/cpp/jni_wrapper.cc @@ -29,6 +29,8 @@ #include "arrow/filesystem/path_util.h" #include "arrow/filesystem/s3fs.h" #include "arrow/engine/substrait/util.h" +#include "arrow/engine/substrait/serde.h" +#include "arrow/engine/substrait/relation.h" #include "arrow/ipc/api.h" #include "arrow/util/iterator.h" #include "jni_util.h" @@ -200,7 +202,6 @@ arrow::Result> SchemaFromColumnNames( return arrow::Status::Invalid("Partition column '", ref.ToString(), "' is not in dataset schema"); } } - return schema(std::move(columns))->WithMetadata(input->metadata()); } } // namespace @@ -317,6 +318,14 @@ std::shared_ptr GetTableByName(const std::vector& nam return it->second; } +std::shared_ptr 
LoadArrowBufferFromByteBuffer(JNIEnv* env, jobject byte_buffer) { + const auto *buff = reinterpret_cast(env->GetDirectBufferAddress(byte_buffer)); + int length = env->GetDirectBufferCapacity(byte_buffer); + std::shared_ptr buffer = JniGetOrThrow(arrow::AllocateBuffer(length)); + std::memcpy(buffer->mutable_data(), buff, length); + return buffer; +} + /* * Class: org_apache_arrow_dataset_jni_NativeMemoryPool * Method: getDefaultMemoryPool @@ -455,11 +464,12 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_closeDataset /* * Class: org_apache_arrow_dataset_jni_JniWrapper * Method: createScanner - * Signature: (J[Ljava/lang/String;JJ)J + * Signature: (J[Ljava/lang/String;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;JJ)J */ JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScanner( - JNIEnv* env, jobject, jlong dataset_id, jobjectArray columns, jlong batch_size, - jlong memory_pool_id) { + JNIEnv* env, jobject, jlong dataset_id, jobjectArray columns, + jobject substrait_projection, jobject substrait_filter, + jlong batch_size, jlong memory_pool_id) { JNI_METHOD_START arrow::MemoryPool* pool = reinterpret_cast(memory_pool_id); if (pool == nullptr) { @@ -474,6 +484,40 @@ JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScann std::vector column_vector = ToStringVector(env, columns); JniAssertOkOrThrow(scanner_builder->Project(column_vector)); } + if (substrait_projection != nullptr) { + std::shared_ptr buffer = LoadArrowBufferFromByteBuffer(env, + substrait_projection); + std::vector project_exprs; + std::vector project_names; + arrow::engine::BoundExpressions bounded_expression = + JniGetOrThrow(arrow::engine::DeserializeExpressions(*buffer)); + for(arrow::engine::NamedExpression& named_expression : + bounded_expression.named_expressions) { + project_exprs.push_back(std::move(named_expression.expression)); + project_names.push_back(std::move(named_expression.name)); + } + 
JniAssertOkOrThrow(scanner_builder->Project(std::move(project_exprs), std::move(project_names))); + } + if (substrait_filter != nullptr) { + std::shared_ptr buffer = LoadArrowBufferFromByteBuffer(env, + substrait_filter); + std::optional filter_expr = std::nullopt; + arrow::engine::BoundExpressions bounded_expression = + JniGetOrThrow(arrow::engine::DeserializeExpressions(*buffer)); + for(arrow::engine::NamedExpression& named_expression : + bounded_expression.named_expressions) { + filter_expr = named_expression.expression; + if (named_expression.expression.type()->id() == arrow::Type::BOOL) { + filter_expr = named_expression.expression; + } else { + JniThrow("There is no filter expression in the expression provided"); + } + } + if (filter_expr == std::nullopt) { + JniThrow("The filter expression has not been provided"); + } + JniAssertOkOrThrow(scanner_builder->Filter(*filter_expr)); + } JniAssertOkOrThrow(scanner_builder->BatchSize(batch_size)); auto scanner = JniGetOrThrow(scanner_builder->Finish()); @@ -748,10 +792,7 @@ JNIEXPORT void JNICALL arrow::engine::ConversionOptions conversion_options; conversion_options.named_table_provider = std::move(table_provider); // mapping arrow::Buffer - auto *buff = reinterpret_cast(env->GetDirectBufferAddress(plan)); - int length = env->GetDirectBufferCapacity(plan); - std::shared_ptr buffer = JniGetOrThrow(arrow::AllocateBuffer(length)); - std::memcpy(buffer->mutable_data(), buff, length); + std::shared_ptr buffer = LoadArrowBufferFromByteBuffer(env, plan); // execute plan std::shared_ptr reader_out = JniGetOrThrow(arrow::engine::ExecuteSerializedPlan(*buffer, nullptr, nullptr, conversion_options)); diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java index 93cc5d7a37040..a7df5be42f13b 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java +++ 
b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java @@ -17,6 +17,8 @@ package org.apache.arrow.dataset.jni; +import java.nio.ByteBuffer; + /** * JNI wrapper for Dataset API's native implementation. */ @@ -66,15 +68,19 @@ private JniWrapper() { /** * Create Scanner from a Dataset and get the native pointer of the Dataset. + * * @param datasetId the native pointer of the arrow::dataset::Dataset instance. * @param columns desired column names. * Columns not in this list will not be emitted when performing scan operation. Null equals * to "all columns". + * @param substraitProjection substrait extended expression to evaluate for project new columns + * @param substraitFilter substrait extended expression to evaluate for apply filter * @param batchSize batch size of scanned record batches. * @param memoryPool identifier of memory pool used in the native scanner. * @return the native pointer of the arrow::dataset::Scanner instance. */ - public native long createScanner(long datasetId, String[] columns, long batchSize, long memoryPool); + public native long createScanner(long datasetId, String[] columns, ByteBuffer substraitProjection, + ByteBuffer substraitFilter, long batchSize, long memoryPool); /** * Get a serialized schema from native instance of a Scanner. 
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeDataset.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeDataset.java index 30ff1a9302f7a..d9abad9971c4e 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeDataset.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeDataset.java @@ -40,8 +40,12 @@ public synchronized NativeScanner newScan(ScanOptions options) { if (closed) { throw new NativeInstanceReleasedException(); } + long scannerId = JniWrapper.get().createScanner(datasetId, options.getColumns().orElse(null), + options.getSubstraitProjection().orElse(null), + options.getSubstraitFilter().orElse(null), options.getBatchSize(), context.getMemoryPool().getNativeInstanceId()); + return new NativeScanner(context, scannerId); } diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanOptions.java b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanOptions.java index f5a1af384b24e..995d05ac3b314 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanOptions.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanOptions.java @@ -17,6 +17,7 @@ package org.apache.arrow.dataset.scanner; +import java.nio.ByteBuffer; import java.util.Optional; import org.apache.arrow.util.Preconditions; @@ -25,8 +26,10 @@ * Options used during scanning. */ public class ScanOptions { - private final Optional columns; private final long batchSize; + private final Optional columns; + private final Optional substraitProjection; + private final Optional substraitFilter; /** * Constructor. 
@@ -56,6 +59,8 @@ public ScanOptions(long batchSize, Optional columns) { Preconditions.checkNotNull(columns); this.batchSize = batchSize; this.columns = columns; + this.substraitProjection = Optional.empty(); + this.substraitFilter = Optional.empty(); } public ScanOptions(long batchSize) { @@ -69,4 +74,77 @@ public Optional getColumns() { public long getBatchSize() { return batchSize; } + + public Optional getSubstraitProjection() { + return substraitProjection; + } + + public Optional getSubstraitFilter() { + return substraitFilter; + } + + /** + * Builder for Options used during scanning. + */ + public static class Builder { + private final long batchSize; + private Optional columns; + private ByteBuffer substraitProjection; + private ByteBuffer substraitFilter; + + /** + * Constructor. + * @param batchSize Maximum row number of each returned {@link org.apache.arrow.vector.ipc.message.ArrowRecordBatch} + */ + public Builder(long batchSize) { + this.batchSize = batchSize; + } + + /** + * Set the Projected columns. Empty for scanning all columns. + * + * @param columns Projected columns. Empty for scanning all columns. + * @return the ScanOptions configured. + */ + public Builder columns(Optional columns) { + Preconditions.checkNotNull(columns); + this.columns = columns; + return this; + } + + /** + * Set the Substrait extended expression for Projection new columns. + * + * @param substraitProjection Expressions to evaluate for project new columns. + * @return the ScanOptions configured. + */ + public Builder substraitProjection(ByteBuffer substraitProjection) { + Preconditions.checkNotNull(substraitProjection); + this.substraitProjection = substraitProjection; + return this; + } + + /** + * Set the Substrait extended expression for Filter. + * + * @param substraitFilter Expressions to evaluate for apply Filter. + * @return the ScanOptions configured. 
+ */ + public Builder substraitFilter(ByteBuffer substraitFilter) { + Preconditions.checkNotNull(substraitFilter); + this.substraitFilter = substraitFilter; + return this; + } + + public ScanOptions build() { + return new ScanOptions(this); + } + } + + private ScanOptions(Builder builder) { + batchSize = builder.batchSize; + columns = builder.columns; + substraitProjection = Optional.ofNullable(builder.substraitProjection); + substraitFilter = Optional.ofNullable(builder.substraitFilter); + } } diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/substrait/TestAceroSubstraitConsumer.java b/java/dataset/src/test/java/org/apache/arrow/dataset/substrait/TestAceroSubstraitConsumer.java index c23b7e002880a..0fba72892cdc6 100644 --- a/java/dataset/src/test/java/org/apache/arrow/dataset/substrait/TestAceroSubstraitConsumer.java +++ b/java/dataset/src/test/java/org/apache/arrow/dataset/substrait/TestAceroSubstraitConsumer.java @@ -18,6 +18,8 @@ package org.apache.arrow.dataset.substrait; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; import java.nio.ByteBuffer; import java.nio.file.Files; @@ -27,6 +29,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import org.apache.arrow.dataset.ParquetWriteSupport; import org.apache.arrow.dataset.TestDataset; @@ -85,7 +88,7 @@ public void testRunQueryLocalFiles() throws Exception { } @Test - public void testRunQueryNamedTableNation() throws Exception { + public void testRunQueryNamedTable() throws Exception { //Query: //SELECT id, name FROM Users //Isthmus: @@ -123,7 +126,7 @@ public void testRunQueryNamedTableNation() throws Exception { } @Test(expected = RuntimeException.class) - public void testRunQueryNamedTableNationWithException() throws Exception { + public void testRunQueryNamedTableWithException() throws Exception { //Query: //SELECT id, name FROM Users //Isthmus: @@ 
-160,7 +163,7 @@ public void testRunQueryNamedTableNationWithException() throws Exception { } @Test - public void testRunBinaryQueryNamedTableNation() throws Exception { + public void testRunBinaryQueryNamedTable() throws Exception { //Query: //SELECT id, name FROM Users //Isthmus: @@ -187,9 +190,7 @@ public void testRunBinaryQueryNamedTableNation() throws Exception { Map mapTableToArrowReader = new HashMap<>(); mapTableToArrowReader.put("USERS", reader); // get binary plan - byte[] plan = Base64.getDecoder().decode(binaryPlan); - ByteBuffer substraitPlan = ByteBuffer.allocateDirect(plan.length); - substraitPlan.put(plan); + ByteBuffer substraitPlan = getByteBuffer(binaryPlan); // run query try (ArrowReader arrowReader = new AceroSubstraitConsumer(rootAllocator()).runQuery( substraitPlan, @@ -204,4 +205,256 @@ public void testRunBinaryQueryNamedTableNation() throws Exception { } } } + + @Test + public void testRunExtendedExpressionsFilter() throws Exception { + final Schema schema = new Schema(Arrays.asList( + Field.nullable("id", new ArrowType.Int(32, true)), + Field.nullable("name", new ArrowType.Utf8()) + ), null); + // Substrait Extended Expression: Filter: + // Expression 01: WHERE ID < 20 + String base64EncodedSubstraitFilter = "Ch4IARIaL2Z1bmN0aW9uc19jb21wYXJpc29uLnlhbWwSEhoQCAIQAhoKbHQ6YW55X2F" + + "ueRo3ChwaGggCGgQKAhABIggaBhIECgISACIGGgQKAigUGhdmaWx0ZXJfaWRfbG93ZXJfdGhhbl8yMCIaCgJJRAoETkFNRRIOCgQqAhA" + + "BCgRiAhABGAI="; + ByteBuffer substraitExpressionFilter = getByteBuffer(base64EncodedSubstraitFilter); + ParquetWriteSupport writeSupport = ParquetWriteSupport + .writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 19, "value_19", 1, "value_1", + 11, "value_11", 21, "value_21", 45, "value_45"); + ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitFilter(substraitExpressionFilter) + .build(); + try ( + DatasetFactory datasetFactory = new FileSystemDatasetFactory(rootAllocator(), 
NativeMemoryPool.getDefault(), + FileFormat.PARQUET, writeSupport.getOutputURI()); + Dataset dataset = datasetFactory.finish(); + Scanner scanner = dataset.newScan(options); + ArrowReader reader = scanner.scanBatches() + ) { + assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); + int rowcount = 0; + while (reader.loadNextBatch()) { + rowcount += reader.getVectorSchemaRoot().getRowCount(); + assertTrue(reader.getVectorSchemaRoot().getVector("id").toString().equals("[19, 1, 11]")); + assertTrue(reader.getVectorSchemaRoot().getVector("name").toString() + .equals("[value_19, value_1, value_11]")); + } + assertEquals(3, rowcount); + } + } + + @Test + public void testRunExtendedExpressionsFilterWithProjectionsInsteadOfFilterException() throws Exception { + final Schema schema = new Schema(Arrays.asList( + Field.nullable("id", new ArrowType.Int(32, true)), + Field.nullable("name", new ArrowType.Utf8()) + ), null); + // Substrait Extended Expression: Project New Column: + // Expression ADD: id + 2 + // Expression CONCAT: name + '-' + name + String base64EncodedSubstraitFilter = "Ch4IARIaL2Z1bmN0aW9uc19hcml0aG1ldGljLnlhbWwSERoPCAEaC2FkZDppM" + + "zJfaTMyEhQaEggCEAEaDGNvbmNhdDp2Y2hhchoxChoaGBoEKgIQASIIGgYSBAoCEgAiBhoECgIoAhoTYWRkX3R3b190b19jb2x1" + + "bW5fYRpGCi0aKwgBGgRiAhABIgoaCBIGCgQSAggBIgkaBwoFYgMgLSAiChoIEgYKBBICCAEaFWNvbmNhdF9jb2x1bW5fYV9hbmR" + + "fYiIaCgJJRAoETkFNRRIOCgQqAhABCgRiAhABGAI="; + ByteBuffer substraitExpressionFilter = getByteBuffer(base64EncodedSubstraitFilter); + ParquetWriteSupport writeSupport = ParquetWriteSupport + .writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 19, "value_19", 1, "value_1", + 11, "value_11", 21, "value_21", 45, "value_45"); + ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitFilter(substraitExpressionFilter) + .build(); + try ( + DatasetFactory datasetFactory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), 
+ FileFormat.PARQUET, writeSupport.getOutputURI()); + Dataset dataset = datasetFactory.finish() + ) { + Exception e = assertThrows(RuntimeException.class, () -> dataset.newScan(options)); + assertTrue(e.getMessage().startsWith("There is no filter expression in the expression provided")); + } + } + + @Test + public void testRunExtendedExpressionsFilterWithEmptyFilterException() throws Exception { + final Schema schema = new Schema(Arrays.asList( + Field.nullable("id", new ArrowType.Int(32, true)), + Field.nullable("name", new ArrowType.Utf8()) + ), null); + String base64EncodedSubstraitFilter = ""; + ByteBuffer substraitExpressionFilter = getByteBuffer(base64EncodedSubstraitFilter); + ParquetWriteSupport writeSupport = ParquetWriteSupport + .writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 19, "value_19", 1, "value_1", + 11, "value_11", 21, "value_21", 45, "value_45"); + ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitFilter(substraitExpressionFilter) + .build(); + try ( + DatasetFactory datasetFactory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), + FileFormat.PARQUET, writeSupport.getOutputURI()); + Dataset dataset = datasetFactory.finish() + ) { + Exception e = assertThrows(RuntimeException.class, () -> dataset.newScan(options)); + assertTrue(e.getMessage().contains("no anonymous struct type was provided to which names could be attached.")); + } + } + + @Test + public void testRunExtendedExpressionsProjection() throws Exception { + final Schema schema = new Schema(Arrays.asList( + Field.nullable("add_two_to_column_a", new ArrowType.Int(32, true)), + Field.nullable("concat_column_a_and_b", new ArrowType.Utf8()) + ), null); + // Substrait Extended Expression: Project New Column: + // Expression ADD: id + 2 + // Expression CONCAT: name + '-' + name + String binarySubstraitExpressionProject = "Ch4IARIaL2Z1bmN0aW9uc19hcml0aG1ldGljLnlhbWwSERoPCAEaC2FkZDppM" + + 
"zJfaTMyEhQaEggCEAEaDGNvbmNhdDp2Y2hhchoxChoaGBoEKgIQASIIGgYSBAoCEgAiBhoECgIoAhoTYWRkX3R3b190b19jb2x1" + + "bW5fYRpGCi0aKwgBGgRiAhABIgoaCBIGCgQSAggBIgkaBwoFYgMgLSAiChoIEgYKBBICCAEaFWNvbmNhdF9jb2x1bW5fYV9hbmR" + + "fYiIaCgJJRAoETkFNRRIOCgQqAhABCgRiAhABGAI="; + ByteBuffer substraitExpressionProject = getByteBuffer(binarySubstraitExpressionProject); + ParquetWriteSupport writeSupport = ParquetWriteSupport + .writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 19, "value_19", 1, "value_1", + 11, "value_11", 21, "value_21", 45, "value_45"); + ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitProjection(substraitExpressionProject) + .build(); + try ( + DatasetFactory datasetFactory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), + FileFormat.PARQUET, writeSupport.getOutputURI()); + Dataset dataset = datasetFactory.finish(); + Scanner scanner = dataset.newScan(options); + ArrowReader reader = scanner.scanBatches() + ) { + assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); + int rowcount = 0; + while (reader.loadNextBatch()) { + assertTrue(reader.getVectorSchemaRoot().getVector("add_two_to_column_a").toString() + .equals("[21, 3, 13, 23, 47]")); + assertTrue(reader.getVectorSchemaRoot().getVector("concat_column_a_and_b").toString() + .equals("[value_19 - value_19, value_1 - value_1, value_11 - value_11, " + + "value_21 - value_21, value_45 - value_45]")); + rowcount += reader.getVectorSchemaRoot().getRowCount(); + } + assertEquals(5, rowcount); + } + } + + @Test + public void testRunExtendedExpressionsProjectionWithFilterInsteadOfProjectionException() throws Exception { + final Schema schema = new Schema(Arrays.asList( + Field.nullable("filter_id_lower_than_20", new ArrowType.Bool()) + ), null); + // Substrait Extended Expression: Filter: + // Expression 01: WHERE ID < 20 + String binarySubstraitExpressionFilter = 
"Ch4IARIaL2Z1bmN0aW9uc19jb21wYXJpc29uLnlhbWwSEhoQCAIQAhoKbHQ6YW55X2F" + + "ueRo3ChwaGggCGgQKAhABIggaBhIECgISACIGGgQKAigUGhdmaWx0ZXJfaWRfbG93ZXJfdGhhbl8yMCIaCgJJRAoETkFNRRIOCgQqAhA" + + "BCgRiAhABGAI="; + ByteBuffer substraitExpressionFilter = getByteBuffer(binarySubstraitExpressionFilter); + ParquetWriteSupport writeSupport = ParquetWriteSupport + .writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 19, "value_19", 1, "value_1", + 11, "value_11", 21, "value_21", 45, "value_45"); + ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitProjection(substraitExpressionFilter) + .build(); + try ( + DatasetFactory datasetFactory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), + FileFormat.PARQUET, writeSupport.getOutputURI()); + Dataset dataset = datasetFactory.finish(); + Scanner scanner = dataset.newScan(options); + ArrowReader reader = scanner.scanBatches() + ) { + assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); + int rowcount = 0; + while (reader.loadNextBatch()) { + assertTrue(reader.getVectorSchemaRoot().getVector("filter_id_lower_than_20").toString() + .equals("[true, true, true, false, false]")); + rowcount += reader.getVectorSchemaRoot().getRowCount(); + } + assertEquals(5, rowcount); + } + } + + @Test + public void testRunExtendedExpressionsProjectionWithEmptyProjectionException() throws Exception { + final Schema schema = new Schema(Arrays.asList( + Field.nullable("id", new ArrowType.Int(32, true)), + Field.nullable("name", new ArrowType.Utf8()) + ), null); + String base64EncodedSubstraitFilter = ""; + ByteBuffer substraitExpressionProjection = getByteBuffer(base64EncodedSubstraitFilter); + ParquetWriteSupport writeSupport = ParquetWriteSupport + .writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 19, "value_19", 1, "value_1", + 11, "value_11", 21, "value_21", 45, "value_45"); + ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 
32768) + .columns(Optional.empty()) + .substraitProjection(substraitExpressionProjection) + .build(); + try ( + DatasetFactory datasetFactory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), + FileFormat.PARQUET, writeSupport.getOutputURI()); + Dataset dataset = datasetFactory.finish() + ) { + Exception e = assertThrows(RuntimeException.class, () -> dataset.newScan(options)); + assertTrue(e.getMessage().contains("no anonymous struct type was provided to which names could be attached.")); + } + } + + @Test + public void testRunExtendedExpressionsProjectAndFilter() throws Exception { + final Schema schema = new Schema(Arrays.asList( + Field.nullable("add_two_to_column_a", new ArrowType.Int(32, true)), + Field.nullable("concat_column_a_and_b", new ArrowType.Utf8()) + ), null); + // Substrait Extended Expression: Project New Column: + // Expression ADD: id + 2 + // Expression CONCAT: name + '-' + name + String binarySubstraitExpressionProject = "Ch4IARIaL2Z1bmN0aW9uc19hcml0aG1ldGljLnlhbWwSERoPCAEaC2FkZDppM" + + "zJfaTMyEhQaEggCEAEaDGNvbmNhdDp2Y2hhchoxChoaGBoEKgIQASIIGgYSBAoCEgAiBhoECgIoAhoTYWRkX3R3b190b19jb2x1" + + "bW5fYRpGCi0aKwgBGgRiAhABIgoaCBIGCgQSAggBIgkaBwoFYgMgLSAiChoIEgYKBBICCAEaFWNvbmNhdF9jb2x1bW5fYV9hbmR" + + "fYiIaCgJJRAoETkFNRRIOCgQqAhABCgRiAhABGAI="; + ByteBuffer substraitExpressionProject = getByteBuffer(binarySubstraitExpressionProject); + // Substrait Extended Expression: Filter: + // Expression 01: WHERE ID < 20 + String base64EncodedSubstraitFilter = "Ch4IARIaL2Z1bmN0aW9uc19jb21wYXJpc29uLnlhbWwSEhoQCAIQAhoKbHQ6YW55X2F" + + "ueRo3ChwaGggCGgQKAhABIggaBhIECgISACIGGgQKAigUGhdmaWx0ZXJfaWRfbG93ZXJfdGhhbl8yMCIaCgJJRAoETkFNRRIOCgQqAhA" + + "BCgRiAhABGAI="; + ByteBuffer substraitExpressionFilter = getByteBuffer(base64EncodedSubstraitFilter); + ParquetWriteSupport writeSupport = ParquetWriteSupport + .writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 19, "value_19", 1, "value_1", + 11, "value_11", 21, "value_21", 45, "value_45"); + 
ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .substraitProjection(substraitExpressionProject) + .substraitFilter(substraitExpressionFilter) + .build(); + try ( + DatasetFactory datasetFactory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), + FileFormat.PARQUET, writeSupport.getOutputURI()); + Dataset dataset = datasetFactory.finish(); + Scanner scanner = dataset.newScan(options); + ArrowReader reader = scanner.scanBatches() + ) { + assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); + int rowcount = 0; + while (reader.loadNextBatch()) { + assertTrue(reader.getVectorSchemaRoot().getVector("add_two_to_column_a").toString() + .equals("[21, 3, 13]")); + assertTrue(reader.getVectorSchemaRoot().getVector("concat_column_a_and_b").toString() + .equals("[value_19 - value_19, value_1 - value_1, value_11 - value_11]")); + rowcount += reader.getVectorSchemaRoot().getRowCount(); + } + assertEquals(3, rowcount); + } + } + + private static ByteBuffer getByteBuffer(String base64EncodedSubstrait) { + byte[] decodedSubstrait = Base64.getDecoder().decode(base64EncodedSubstrait); + ByteBuffer substraitExpression = ByteBuffer.allocateDirect(decodedSubstrait.length); + substraitExpression.put(decodedSubstrait); + return substraitExpression; + } }