diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java b/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java index 4b776254c81..e73b0ac275e 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java @@ -93,9 +93,10 @@ public interface Scan { * Get the remaining filter that is not guaranteed to be satisfied for the data Delta Kernel * returns. This filter is used by Delta Kernel to do data skipping when possible. * + * @param tableClient {@link TableClient} instance to use in Delta Kernel. * @return the remaining filter as a {@link Predicate}. */ - Optional getRemainingFilter(); + Optional getRemainingFilter(TableClient tableClient); /** * Get the scan state associated with the current scan. This state is common across all diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java index d059bf7ab4d..080fd05705d 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java @@ -55,9 +55,22 @@ public class ScanImpl implements Scan { private final Metadata metadata; private final LogReplay logReplay; private final Path dataPath; - private final Optional> partitionAndDataFilters; + + private final Optional filter; + private boolean accessedScanFiles; + private boolean areFiltersSplit; + // Subset of partition predicate the expression handler can support evaluating + private Optional metadataPredicate; + // Subset of the given query predicate the Kernel tries to use prune scan file list as best as + // it can, but can't guarantee that the all the scan files returned contains data on which this + // predicate evaluates to true. The connector needs to apply this filter on the data from the + // returned scan files to completely remove the data that doesn't satisfy given query predicate. + // + // The predicate could be on the data columns and/or unsupported predicate on partition columns + private Optional remainingPredicate; + public ScanImpl( StructType snapshotSchema, StructType readSchema, @@ -71,7 +84,7 @@ public ScanImpl( this.protocol = protocol; this.metadata = metadata; this.logReplay = logReplay; - this.partitionAndDataFilters = splitFilters(filter); + this.filter = filter; this.dataPath = dataPath; } @@ -88,7 +101,7 @@ public CloseableIterator getScanFiles(TableClient tableCl accessedScanFiles = true; // Generate data skipping filter and decide if we should read the stats column - Optional dataSkippingFilter = getDataSkippingFilter(); + Optional dataSkippingFilter = getDataSkippingFilter(tableClient); boolean shouldReadStats = dataSkippingFilter.isPresent(); // Get active AddFiles via log replay @@ -142,22 +155,38 @@ public Row getScanState(TableClient tableClient) { } @Override - public Optional getRemainingFilter() { - return getDataFilters(); + public Optional getRemainingFilter(TableClient tableClient) { + splitFilters(tableClient); + return remainingPredicate; } - private Optional> splitFilters(Optional filter) { - return filter.map(predicate -> - PartitionUtils.splitMetadataAndDataPredicates( - predicate, metadata.getPartitionColNames())); + private void splitFilters(TableClient tableClient) { + if (areFiltersSplit) { + return; + } + filter.map(predicate -> { + Tuple2 metadataAndNonMetadataFilters = + PartitionUtils.splitPredicates( + tableClient.getExpressionHandler(), + metadata.getSchema(), + predicate, + metadata.getPartitionColNames()); + + metadataPredicate = removeAlwaysTrue(Optional.of(metadataAndNonMetadataFilters._1)); + remainingPredicate = removeAlwaysTrue(Optional.of(metadataAndNonMetadataFilters._2)); + return null; + }); + areFiltersSplit = true; } - private Optional getDataFilters() { - return removeAlwaysTrue(partitionAndDataFilters.map(filters -> filters._2)); + private Optional getDataFilters(TableClient tableClient) { + splitFilters(tableClient); + return remainingPredicate; } - private Optional getPartitionsFilters() { - return removeAlwaysTrue(partitionAndDataFilters.map(filters -> filters._1)); + private Optional getPartitionsFilters(TableClient tableClient) { + splitFilters(tableClient); + return metadataPredicate; } /** @@ -171,7 +200,7 @@ private Optional removeAlwaysTrue(Optional predicate) { private CloseableIterator applyPartitionPruning( TableClient tableClient, CloseableIterator scanFileIter) { - Optional partitionPredicate = getPartitionsFilters(); + Optional partitionPredicate = getPartitionsFilters(tableClient); if (!partitionPredicate.isPresent()) { // There is no partition filter, return the scan file iterator as is. return scanFileIter; @@ -221,9 +250,9 @@ public void close() throws IOException { }; } - private Optional getDataSkippingFilter() { - return getDataFilters().flatMap(dataFilters -> - DataSkippingUtils.constructDataSkippingFilter(dataFilters, metadata.getDataSchema()) + private Optional getDataSkippingFilter(TableClient tableClient) { + return getDataFilters(tableClient).flatMap(dataFilters -> + DataSkippingUtils.constructDataSkippingFilter(dataFilters, metadata.getSchema()) ); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/ExpressionUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/ExpressionUtils.java index b42c4cd1d16..6e20d8e2208 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/ExpressionUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/ExpressionUtils.java @@ -18,8 +18,9 @@ import java.util.List; import static java.lang.String.format; -import io.delta.kernel.expressions.Expression; -import io.delta.kernel.expressions.Predicate; +import io.delta.kernel.expressions.*; +import static io.delta.kernel.expressions.AlwaysFalse.ALWAYS_FALSE; + import static io.delta.kernel.internal.util.Preconditions.checkArgument; public class ExpressionUtils { @@ -64,4 +65,22 @@ public static Expression getUnaryChild(Expression expression) { format("%s: expected one inputs, but got %s", expression, children.size())); return children.get(0); } + + /* + * Utility method to combine the given predicates with AND + */ + public static Predicate combineWithAndOp(Predicate left, Predicate right) { + String leftName = left.getName().toUpperCase(); + String rightName = right.getName().toUpperCase(); + if (leftName.equals("ALWAYS_FALSE") || rightName.equals("ALWAYS_FALSE")) { + return ALWAYS_FALSE; + } + if (leftName.equals("ALWAYS_TRUE")) { + return right; + } + if (rightName.equals("ALWAYS_TRUE")) { + return left; + } + return new And(left, right); + } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java index 63dc051bd15..b6fc45e97df 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java @@ -29,10 +29,10 @@ import io.delta.kernel.data.ColumnarBatch; import io.delta.kernel.expressions.*; import io.delta.kernel.types.*; -import static io.delta.kernel.expressions.AlwaysFalse.ALWAYS_FALSE; import static io.delta.kernel.expressions.AlwaysTrue.ALWAYS_TRUE; import io.delta.kernel.internal.InternalScanFileUtils; +import static io.delta.kernel.internal.util.ExpressionUtils.*; public class PartitionUtils { private PartitionUtils() {} @@ -103,30 +103,37 @@ public static ColumnarBatch withPartitionColumns( } /** - * Split the given predicate into predicate on partition columns and predicate on data columns. + * Split the given predicate into predicate that a part that can be guaranteed to be satisfied + * by kernel when returning the scan file and the best effort predicate that Kernel uses for + * skipping but doesn't guarantee the returned scan files has data that doesn't satisfy the + * predicate. * + * @param exprHandler * @param predicate * @param partitionColNames - * @return Tuple of partition column predicate and data column predicate. + * @return Tuple of guaranteed predicate and best effort predicate. */ - public static Tuple2 splitMetadataAndDataPredicates( - Predicate predicate, - Set partitionColNames) { + public static Tuple2 splitPredicates( + ExpressionHandler exprHandler, + StructType tableSchema, + Predicate predicate, + Set partitionColNames) { String predicateName = predicate.getName(); List children = predicate.getChildren(); if ("AND".equalsIgnoreCase(predicateName)) { - Predicate left = (Predicate) children.get(0); - Predicate right = (Predicate) children.get(1); + Predicate left = asPredicate(getLeft(predicate)); + Predicate right = asPredicate(getRight(predicate)); Tuple2 leftResult = - splitMetadataAndDataPredicates(left, partitionColNames); + splitPredicates(exprHandler, tableSchema, left, partitionColNames); Tuple2 rightResult = - splitMetadataAndDataPredicates(right, partitionColNames); + splitPredicates(exprHandler, tableSchema, right, partitionColNames); return new Tuple2<>( combineWithAndOp(leftResult._1, rightResult._1), combineWithAndOp(leftResult._2, rightResult._2)); } - if (hasNonPartitionColumns(children, partitionColNames)) { + if (hasNonPartitionColumns(children, partitionColNames) || + !exprHandler.isSupported(tableSchema, predicate, BooleanType.BOOLEAN)) { return new Tuple2(ALWAYS_TRUE, predicate); } else { return new Tuple2<>(predicate, ALWAYS_TRUE); @@ -215,21 +222,6 @@ private static boolean hasNonPartitionColumns( return false; } - private static Predicate combineWithAndOp(Predicate left, Predicate right) { - String leftName = left.getName().toUpperCase(); - String rightName = right.getName().toUpperCase(); - if (leftName.equals("ALWAYS_FALSE") || rightName.equals("ALWAYS_FALSE")) { - return ALWAYS_FALSE; - } - if (leftName.equals("ALWAYS_TRUE")) { - return right; - } - if (rightName.equals("ALWAYS_TRUE")) { - return left; - } - return new And(left, right); - } - private static Literal literalForPartitionValue(DataType dataType, String partitionValue) { if (partitionValue == null) { return Literal.ofNull(dataType); diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/PartitionUtilsSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/PartitionUtilsSuite.scala index ceb91ddc45b..84800d702e1 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/PartitionUtilsSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/PartitionUtilsSuite.scala @@ -15,140 +15,188 @@ */ package io.delta.kernel.internal.util -import java.util +import io.delta.kernel.client.ExpressionHandler +import io.delta.kernel.data.ColumnVector +import java.util import scala.collection.JavaConverters._ - import io.delta.kernel.expressions._ import io.delta.kernel.expressions.Literal._ -import io.delta.kernel.internal.util.PartitionUtils.{rewritePartitionPredicateOnScanFileSchema, splitMetadataAndDataPredicates} +import io.delta.kernel.internal.util.PartitionUtils.{rewritePartitionPredicateOnScanFileSchema, splitPredicates} import io.delta.kernel.types._ import org.scalatest.funsuite.AnyFunSuite class PartitionUtilsSuite extends AnyFunSuite { // Table schema - // Data columns: data1: int, data2: string, date3: struct(data31: boolean, data32: long) - // Partition columns: part1: int, part2: date, part3: string + // Data columns: d1: int, d2: string, date3: struct(d31: boolean, d32: long) + // Partition columns: p1: int, p2: date, p3: string val tableSchema = new StructType() - .add("data1", IntegerType.INTEGER) - .add("data2", StringType.STRING) - .add("data3", new StructType() - .add("data31", BooleanType.BOOLEAN) - .add("data32", LongType.LONG)) - .add("part1", IntegerType.INTEGER) - .add("part2", DateType.DATE) - .add("part3", StringType.STRING) + .add("d1", IntegerType.INTEGER) + .add("d2", StringType.STRING) + .add("d3", new StructType() + .add("d31", BooleanType.BOOLEAN) + .add("d32", LongType.LONG)) + .add("p1", IntegerType.INTEGER) + .add("p2", DateType.DATE) + .add("p3", StringType.STRING) private val partitionColsMetadata = new util.HashMap[String, StructField]() { { - put("part1", tableSchema.get("part1")) - put("part2", tableSchema.get("part2")) - put("part3", tableSchema.get("part3")) + put("p1", tableSchema.get("p1")) + put("p2", tableSchema.get("p2")) + put("p3", tableSchema.get("p3")) } } private val partitionCols: java.util.Set[String] = partitionColsMetadata.keySet() - // Test cases for verifying partition of predicate into data and partition predicates - // Map entry format (predicate -> (partition predicate, data predicate) + // Test cases for verifying query predicate is split into guaranteed and best effort predicates + // Map entry format (predicate -> (guaranteed predicate, best effort predicate) val partitionTestCases = Map[Predicate, (String, String)]( // single predicate on a data column - predicate("=", col("data1"), ofInt(12)) -> - ("ALWAYS_TRUE()", "(column(`data1`) = 12)"), + eq(col("d1"), int(12)) -> ("ALWAYS_TRUE()", "(column(`d1`) = 12)"), + // multiple predicates on data columns joined with AND - predicate("AND", - predicate("=", col("data1"), ofInt(12)), - predicate(">=", col("data2"), ofString("sss"))) -> - ("ALWAYS_TRUE()", "((column(`data1`) = 12) AND (column(`data2`) >= sss))"), + and(eq(col("d1"), int(12)), gte(col("d2"), str("sss"))) -> + ("ALWAYS_TRUE()", "((column(`d1`) = 12) AND (column(`d2`) >= sss))"), + // multiple predicates on data columns joined with OR - predicate("OR", - predicate("<=", col("data2"), ofString("sss")), - predicate("=", col("data3", "data31"), ofBoolean(true))) -> - ("ALWAYS_TRUE()", "((column(`data2`) <= sss) OR (column(`data3`.`data31`) = true))"), + or(lte(col("d2"), str("sss")), eq(col("d3", "d31"), ofBoolean(true))) -> + ("ALWAYS_TRUE()", "((column(`d2`) <= sss) OR (column(`d3`.`d31`) = true))"), + // single predicate on a partition column - predicate("=", col("part1"), ofInt(12)) -> - ("(column(`part1`) = 12)", "ALWAYS_TRUE()"), + eq(col("p1"), int(12)) -> ("(column(`p1`) = 12)", "ALWAYS_TRUE()"), + // multiple predicates on partition columns joined with AND - predicate("AND", - predicate("=", col("part1"), ofInt(12)), - predicate(">=", col("part3"), ofString("sss"))) -> - ("((column(`part1`) = 12) AND (column(`part3`) >= sss))", "ALWAYS_TRUE()"), + and(eq(col("p1"), int(12)), gte(col("p3"), str("sss"))) -> + ("((column(`p1`) = 12) AND (column(`p3`) >= sss))", "ALWAYS_TRUE()"), + // multiple predicates on partition columns joined with OR - predicate("OR", - predicate("<=", col("part3"), ofString("sss")), - predicate("=", col("part1"), ofInt(2781))) -> - ("((column(`part3`) <= sss) OR (column(`part1`) = 2781))", "ALWAYS_TRUE()"), + or(lte(col("p3"), str("sss")), eq(col("p1"), int(2781))) -> + ("((column(`p3`) <= sss) OR (column(`p1`) = 2781))", "ALWAYS_TRUE()"), // predicates (each on data and partition column) joined with AND - predicate("AND", - predicate("=", col("data1"), ofInt(12)), - predicate(">=", col("part3"), ofString("sss"))) -> - ("(column(`part3`) >= sss)", "(column(`data1`) = 12)"), + and(eq(col("d1"), int(12)), gte(col("p3"), str("sss"))) -> + ("(column(`p3`) >= sss)", "(column(`d1`) = 12)"), // predicates (each on data and partition column) joined with OR - predicate("OR", - predicate("=", col("data1"), ofInt(12)), - predicate(">=", col("part3"), ofString("sss"))) -> - ("ALWAYS_TRUE()", "((column(`data1`) = 12) OR (column(`part3`) >= sss))"), + or(eq(col("d1"), int(12)), gte(col("p3"), str("sss"))) -> + ("ALWAYS_TRUE()", "((column(`d1`) = 12) OR (column(`p3`) >= sss))"), // predicates (multiple on data and partition columns) joined with AND - predicate("AND", - predicate("AND", - predicate("=", col("data1"), ofInt(12)), - predicate(">=", col("data2"), ofString("sss"))), - predicate("AND", - predicate("=", col("part1"), ofInt(12)), - predicate(">=", col("part3"), ofString("sss")))) -> + and( + and(eq(col("d1"), int(12)), gte(col("d2"), str("sss"))), + and(eq(col("p1"), int(12)), gte(col("p3"), str("sss")))) -> ( - "((column(`part1`) = 12) AND (column(`part3`) >= sss))", - "((column(`data1`) = 12) AND (column(`data2`) >= sss))" + "((column(`p1`) = 12) AND (column(`p3`) >= sss))", + "((column(`d1`) = 12) AND (column(`d2`) >= sss))" ), // predicates (multiple on data and partition columns joined with OR) joined with AND - predicate("AND", - predicate("OR", - predicate("=", col("data1"), ofInt(12)), - predicate(">=", col("data2"), ofString("sss"))), - predicate("OR", - predicate("=", col("part1"), ofInt(12)), - predicate(">=", col("part3"), ofString("sss")))) -> + and( + or(eq(col("d1"), int(12)), gte(col("d2"), str("sss"))), + or(eq(col("p1"), int(12)), gte(col("p3"), str("sss")))) -> ( - "((column(`part1`) = 12) OR (column(`part3`) >= sss))", - "((column(`data1`) = 12) OR (column(`data2`) >= sss))" + "((column(`p1`) = 12) OR (column(`p3`) >= sss))", + "((column(`d1`) = 12) OR (column(`d2`) >= sss))" ), // predicates (multiple on data and partition columns joined with OR) joined with OR - predicate("OR", - predicate("OR", - predicate("=", col("data1"), ofInt(12)), - predicate(">=", col("data2"), ofString("sss"))), - predicate("OR", - predicate("=", col("part1"), ofInt(12)), - predicate(">=", col("part3"), ofString("sss")))) -> + or( + or(eq(col("d1"), int(12)), gte(col("d2"), str("sss"))), + or(eq(col("p1"), int(12)), gte(col("p3"), str("sss")))) -> ( "ALWAYS_TRUE()", - "(((column(`data1`) = 12) OR (column(`data2`) >= sss)) OR " + - "((column(`part1`) = 12) OR (column(`part3`) >= sss)))" + "(((column(`d1`) = 12) OR (column(`d2`) >= sss)) OR " + + "((column(`p1`) = 12) OR (column(`p3`) >= sss)))" ), // predicates (data and partitions compared in the same expression) - predicate("AND", - predicate("=", col("data1"), col("part1")), - predicate(">=", col("part3"), ofString("sss"))) -> + and(eq(col("d1"), col("p1")), gte(col("p3"), str("sss"))) -> ( - "(column(`part3`) >= sss)", - "(column(`data1`) = column(`part1`))" + "(column(`p3`) >= sss)", + "(column(`d1`) = column(`p1`))" ), // predicate only on data column but reverse order of literal and column - predicate("=", ofInt(12), col("data1")) -> - ("ALWAYS_TRUE()", "(12 = column(`data1`))") + eq(int(12), col("d1")) -> ("ALWAYS_TRUE()", "(12 = column(`d1`))"), + + // just an unsupported predicate + unsupported("p1") -> ("ALWAYS_TRUE()", "UNSUPPORTED(column(`p1`))"), + + // two unsupported predicates combined with AND and OR + and(unsupported("p1"), unsupported("d1")) -> + ("ALWAYS_TRUE()", "(UNSUPPORTED(column(`p1`)) AND UNSUPPORTED(column(`d1`)))"), + or(unsupported("p1"), unsupported("d1")) -> + ("ALWAYS_TRUE()", "(UNSUPPORTED(column(`p1`)) OR UNSUPPORTED(column(`d1`)))"), + + // supported and unsupported predicates combined with a AND + and(unsupported("p1"), gte(col("p3"), str("sss"))) -> + ("(column(`p3`) >= sss)", "UNSUPPORTED(column(`p1`))"), + and(unsupported("p1"), gte(col("d3"), str("sss"))) -> + ("ALWAYS_TRUE()", "(UNSUPPORTED(column(`p1`)) AND (column(`d3`) >= sss))"), + + // predicates (multiple supported and unsupported joined with AND) joined with AND + and( + and(eq(col("p1"), int(12)), gte(col("p3"), str("sss"))), + and(unsupported("p1"), unsupported("p3"))) -> + ( + "((column(`p1`) = 12) AND (column(`p3`) >= sss))", + "(UNSUPPORTED(column(`p1`)) AND UNSUPPORTED(column(`p3`)))" + ), + + // predicates (multiple supported and unsupported joined with AND) joined with AND + and( + and(eq(col("p1"), int(12)), gte(col("p3"), str("sss"))), + and(unsupported("d1"), unsupported("p3"))) -> + ( + "((column(`p1`) = 12) AND (column(`p3`) >= sss))", + "(UNSUPPORTED(column(`d1`)) AND UNSUPPORTED(column(`p3`)))" + ), + + // predicates (multiple supported and unsupported joined with AND) joined with AND + and( + and(eq(col("p1"), int(12)), gte(col("p3"), str("sss"))), + and(eq(col("p1"), int(12)), unsupported("p3"))) -> + ( + "(((column(`p1`) = 12) AND (column(`p3`) >= sss)) AND (column(`p1`) = 12))", + "UNSUPPORTED(column(`p3`))" + ), + + // predicates (multiple supported and unsupported joined with AND) joined with AND + and( + and(eq(col("p1"), int(14)), gte(col("p3"), str("sss"))), + and(eq(col("d1"), int(12)), unsupported("p3"))) -> + ( + "((column(`p1`) = 14) AND (column(`p3`) >= sss))", + "((column(`d1`) = 12) AND UNSUPPORTED(column(`p3`)))" + ), + + // predicates (multiple supported and unsupported joined with OR) joined with AND + and( + or(eq(col("p1"), int(12)), gte(col("p3"), str("sss"))), + or(unsupported("p1"), unsupported("p3"))) -> + ( + "((column(`p1`) = 12) OR (column(`p3`) >= sss))", + "(UNSUPPORTED(column(`p1`)) OR UNSUPPORTED(column(`p3`)))" + ), + + // predicates (multiple supported and unsupported joined with OR) joined with OR + or( + or(eq(col("p1"), int(12)), gte(col("p3"), str("sss"))), + or(unsupported("p1"), unsupported("p3"))) -> + ( + "ALWAYS_TRUE()", + "(((column(`p1`) = 12) OR (column(`p3`) >= sss)) OR " + + "(UNSUPPORTED(column(`p1`)) OR UNSUPPORTED(column(`p3`))))" + ) ) partitionTestCases.foreach { case (predicate, (partitionPredicate, dataPredicate)) => - test(s"split predicate into data and partition predicates: $predicate") { - val metadataAndDataPredicates = splitMetadataAndDataPredicates(predicate, partitionCols) + test(s"split predicate into guaranteed and best-effort predicates: $predicate") { + val metadataAndDataPredicates = + splitPredicates(defaultExprHandler, tableSchema, predicate, partitionCols) assert(metadataAndDataPredicates._1.toString === partitionPredicate) assert(metadataAndDataPredicates._2.toString === dataPredicate) } @@ -157,21 +205,21 @@ class PartitionUtilsSuite extends AnyFunSuite { // Map entry format: (given predicate -> expected rewritten predicate) val rewriteTestCases = Map( // single predicate on a partition column - predicate("=", col("part2"), ofTimestamp(12)) -> - "(partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part2), date) = 12)", + eq(col("p2"), ofTimestamp(12)) -> + "(partition_value(ELEMENT_AT(column(`add`.`partitionValues`), p2), date) = 12)", // multiple predicates on partition columns joined with AND - predicate("AND", - predicate("=", col("part1"), ofInt(12)), - predicate(">=", col("part3"), ofString("sss"))) -> - """((partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part1), integer) = 12) AND - |(ELEMENT_AT(column(`add`.`partitionValues`), part3) >= sss))""" + and( + eq(col("p1"), int(12)), + gte(col("p3"), str("sss"))) -> + """((partition_value(ELEMENT_AT(column(`add`.`partitionValues`), p1), integer) = 12) AND + |(ELEMENT_AT(column(`add`.`partitionValues`), p3) >= sss))""" .stripMargin.replaceAll("\n", " "), // multiple predicates on partition columns joined with OR - predicate("OR", - predicate("<=", col("part3"), ofString("sss")), - predicate("=", col("part1"), ofInt(2781))) -> - """((ELEMENT_AT(column(`add`.`partitionValues`), part3) <= sss) OR - |(partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part1), integer) = 2781))""" + or( + lte(col("p3"), str("sss")), + eq(col("p1"), int(2781))) -> + """((ELEMENT_AT(column(`add`.`partitionValues`), p3) <= sss) OR + |(partition_value(ELEMENT_AT(column(`add`.`partitionValues`), p1), integer) = 2781))""" .stripMargin.replaceAll("\n", " ") ) @@ -184,12 +232,47 @@ class PartitionUtilsSuite extends AnyFunSuite { } } - private def col(names: String*): Column = { - new Column(names.toArray) + val defaultExprHandler = new ExpressionHandler { + override def getEvaluator(inputSchema: StructType, expression: Expression, outputType: DataType) + : ExpressionEvaluator = + throw new UnsupportedOperationException("Not implemented") + + override def isSupported( + inputSchema: StructType, expression: Expression, outputType: DataType): Boolean = { + hasSupportedExpr(expression) + } + + override def getPredicateEvaluator(inputSchema: StructType, predicate: Predicate) + : PredicateEvaluator = + throw new UnsupportedOperationException("Not implemented") + + override def createSelectionVector(values: Array[Boolean], from: Int, to: Int): ColumnVector = + throw new UnsupportedOperationException("Not implemented") + + def hasSupportedExpr(expr: Expression): Boolean = { + expr match { + case _: Column | _: Literal => true + case pred: Predicate => + pred.getName.toUpperCase() match { + case "AND" | "OR" | "=" | "!=" | ">" | ">=" | "<" | "<=" => + !pred.getChildren.asScala.exists(!hasSupportedExpr(_)) + case _ => false + } + case _ => false + } + } } - private def predicate(name: String, children: Expression*): Predicate = { + private def col(names: String*): Column = new Column(names.toArray) + private def predicate(name: String, children: Expression*): Predicate = new Predicate(name, children.asJava) - } + private def and(left: Predicate, right: Predicate): Predicate = predicate("AND", left, right) + private def or(left: Predicate, right: Predicate): Predicate = predicate("OR", left, right) + private def eq(left: Expression, right: Expression): Predicate = predicate("=", left, right) + private def gte(column: Column, literal: Literal): Predicate = predicate(">=", column, literal) + private def lte(column: Column, literal: Literal): Predicate = predicate("<=", column, literal) + private def int(value: Int): Literal = Literal.ofInt(value) + private def str(value: String): Literal = Literal.ofString(value) + private def unsupported(colName: String): Predicate = predicate("UNSUPPORTED", col(colName)); } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/PartitionPruningSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/PartitionPruningSuite.scala index be93364027f..86cd91c2464 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/PartitionPruningSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/PartitionPruningSuite.scala @@ -156,6 +156,58 @@ class PartitionPruningSuite extends AnyFunSuite with TestUtils { ) -> ( predicate("=", col("as_value"), ofString("200")), Seq() + ), + + ( + "partition pruning: predicate with (unsupported expr OR data predicate)", + or( + predicate("=", col("as_value"), ofString("1")), // data col filter + predicate("unsupported") // unsupported expression + ) + ) -> ( + or( + predicate("=", col("as_value"), ofString("1")), // data col filter + predicate("unsupported") // unsupported expression + ), + Seq((18878, "0"), (18878, "1"), (null, "2")) + ), + + ( + "partition pruning: predicate with (unsupported expr OR partition predicate)", + or( + predicate("=", col("as_float"), ofFloat(1)), // partition col filter + predicate("unsupported") // unsupported expression + ) + ) -> ( + or( + predicate("=", col("as_float"), ofFloat(1)), // partition col filter + predicate("unsupported") // unsupported expression + ), + Seq((18878, "0"), (18878, "1"), (null, "2")) + ), + ( + "partition pruning: predicate with (unsupported expr AND data predicate)", + and( + predicate("=", col("as_value"), ofString("1")), // data col filter + predicate("unsupported") // unsupported expression + ) + ) -> ( + and( + predicate("=", col("as_value"), ofString("1")), // data col filter + predicate("unsupported") // unsupported expression + ), + Seq((18878, "0"), (18878, "1"), (null, "2")) + ), + + ( + "partition pruning: predicate with (unsupported expr AND partition predicate)", + and( + predicate("=", col("as_float"), ofFloat(1)), // partition col filter + predicate("unsupported") // unsupported expression + ) + ) -> ( + predicate("unsupported"), // unsupported expression + Seq((18878, "1")) ) ) diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala index 28cad3061f2..22d6c6fafc2 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala @@ -183,7 +183,7 @@ trait TestUtils extends Assertions with SQLHelper { val scan = scanBuilder.build() if (filter != null) { - val actRemainingPredicate = scan.getRemainingFilter() + val actRemainingPredicate = scan.getRemainingFilter(defaultTableClient) assert( actRemainingPredicate.toString === Optional.ofNullable(expectedRemainingFilter).toString) }