From a47937c0c1fcafe57d7dc83551d8c9a3ce0ab1b9 Mon Sep 17 00:00:00 2001
From: Huaxin Gao
Date: Fri, 21 Jun 2024 10:34:13 -0700
Subject: [PATCH] Spark 3.5: Support Aggregate push down for incremental scan (#10538)

---
 .../spark/source/SparkScanBuilder.java       | 89 ++++++++-----------
 .../spark/source/TestDataSourceOptions.java  | 32 +++++--
 .../spark/sql/TestAggregatePushDown.java     | 50 +++++++++++
 3 files changed, 109 insertions(+), 62 deletions(-)

diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java
index d6f34231ae75..b430e6fca233 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java
@@ -37,7 +36,6 @@
 import org.apache.iceberg.StructLike;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.TableProperties;
-import org.apache.iceberg.TableScan;
 import org.apache.iceberg.expressions.AggregateEvaluator;
 import org.apache.iceberg.expressions.Binder;
 import org.apache.iceberg.expressions.BoundAggregate;
@@ -232,15 +231,8 @@ public boolean pushAggregation(Aggregation aggregation) {
       return false;
     }
 
-    TableScan scan = table.newScan().includeColumnStats();
-    Snapshot snapshot = readSnapshot();
-    if (snapshot == null) {
-      LOG.info("Skipping aggregate pushdown: table snapshot is null");
-      return false;
-    }
-    scan = scan.useSnapshot(snapshot.snapshotId());
-    scan = configureSplitPlanning(scan);
-    scan = scan.filter(filterExpression());
+    org.apache.iceberg.Scan scan =
+        buildIcebergBatchScan(true /* include Column Stats */, schemaWithMetadataColumns());
 
     try (CloseableIterable<FileScanTask> fileScanTasks = scan.planFiles()) {
       List<FileScanTask> tasks = ImmutableList.copyOf(fileScanTasks);
@@ -282,11 +274,6 @@ private boolean canPushDownAggregation(Aggregation aggregation) {
       return false;
     }
 
-    if (readConf.startSnapshotId() != null) {
-      LOG.info("Skipping aggregate pushdown: incremental scan is not supported");
-      return false;
-    }
-
     // If group by expression is the same as the partition, the statistics information can still
     // be used to calculate min/max/count, will enable aggregate push down in next phase.
     // TODO: enable aggregate push down for partition col group by expression
@@ -298,17 +285,6 @@ private boolean canPushDownAggregation(Aggregation aggregation) {
     return true;
   }
 
-  private Snapshot readSnapshot() {
-    Snapshot snapshot;
-    if (readConf.snapshotId() != null) {
-      snapshot = table.snapshot(readConf.snapshotId());
-    } else {
-      snapshot = SnapshotUtil.latestSnapshot(table, readConf.branch());
-    }
-
-    return snapshot;
-  }
-
   private boolean metricsModeSupportsAggregatePushDown(List<BoundAggregate<?, ?>> aggregates) {
     MetricsConfig config = MetricsConfig.forTable(table);
     for (BoundAggregate aggregate : aggregates) {
@@ -387,6 +363,18 @@ public Scan build() {
   }
 
   private Scan buildBatchScan() {
+    Schema expectedSchema = schemaWithMetadataColumns();
+    return new SparkBatchQueryScan(
+        spark,
+        table,
+        buildIcebergBatchScan(false /* not include Column Stats */, expectedSchema),
+        readConf,
+        expectedSchema,
+        filterExpressions,
+        metricsReporter::scanReport);
+  }
+
+  private org.apache.iceberg.Scan buildIcebergBatchScan(boolean withStats, Schema expectedSchema) {
     Long snapshotId = readConf.snapshotId();
     Long asOfTimestamp = readConf.asOfTimestamp();
     String branch = readConf.branch();
@@ -427,15 +415,19 @@ private Scan buildBatchScan() {
         SparkReadOptions.END_TIMESTAMP);
 
     if (startSnapshotId != null) {
-      return buildIncrementalAppendScan(startSnapshotId, endSnapshotId);
+      return buildIncrementalAppendScan(startSnapshotId, endSnapshotId, withStats, expectedSchema);
     } else {
-      return buildBatchScan(snapshotId, asOfTimestamp, branch, tag);
+      return buildBatchScan(snapshotId, asOfTimestamp, branch, tag, withStats, expectedSchema);
     }
   }
 
-  private Scan buildBatchScan(Long snapshotId, Long asOfTimestamp, String branch, String tag) {
-    Schema expectedSchema = schemaWithMetadataColumns();
-
+  private org.apache.iceberg.Scan buildBatchScan(
+      Long snapshotId,
+      Long asOfTimestamp,
+      String branch,
+      String tag,
+      boolean withStats,
+      Schema expectedSchema) {
     BatchScan scan =
         newBatchScan()
             .caseSensitive(caseSensitive)
@@ -443,6 +435,10 @@ private Scan buildBatchScan(Long snapshotId, Long asOfTimestamp, String branch,
             .project(expectedSchema)
             .metricsReporter(metricsReporter);
 
+    if (withStats) {
+      scan = scan.includeColumnStats();
+    }
+
     if (snapshotId != null) {
       scan = scan.useSnapshot(snapshotId);
     }
@@ -459,21 +455,11 @@ private Scan buildBatchScan(Long snapshotId, Long asOfTimestamp, String branch,
       scan = scan.useRef(tag);
     }
 
-    scan = configureSplitPlanning(scan);
-
-    return new SparkBatchQueryScan(
-        spark,
-        table,
-        scan,
-        readConf,
-        expectedSchema,
-        filterExpressions,
-        metricsReporter::scanReport);
+    return configureSplitPlanning(scan);
   }
 
-  private Scan buildIncrementalAppendScan(long startSnapshotId, Long endSnapshotId) {
-    Schema expectedSchema = schemaWithMetadataColumns();
-
+  private org.apache.iceberg.Scan buildIncrementalAppendScan(
+      long startSnapshotId, Long endSnapshotId, boolean withStats, Schema expectedSchema) {
     IncrementalAppendScan scan =
         table
             .newIncrementalAppendScan()
@@ -483,20 +469,15 @@ private Scan buildIncrementalAppendScan(long startSnapshotId, Long endSnapshotId
             .project(expectedSchema)
             .metricsReporter(metricsReporter);
 
+    if (withStats) {
+      scan = scan.includeColumnStats();
+    }
+
     if (endSnapshotId != null) {
       scan = scan.toSnapshot(endSnapshotId);
     }
 
-    scan = configureSplitPlanning(scan);
-
-    return new SparkBatchQueryScan(
-        spark,
-        table,
-        scan,
-        readConf,
-        expectedSchema,
-        filterExpressions,
-        metricsReporter::scanReport);
+    return configureSplitPlanning(scan);
   }
 
   @SuppressWarnings("CyclomaticComplexity")
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java
index ff6ddea32360..627fe15f2819 100644
--- a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java
@@ -57,6 +57,7 @@
 import org.apache.spark.sql.SaveMode;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.functions;
 import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.TestTemplate;
@@ -290,29 +291,44 @@ public void testIncrementalScanOptions() throws IOException {
             "Cannot set only end-snapshot-id for incremental scans. Please, set start-snapshot-id too.");
 
     // test (1st snapshot, current snapshot] incremental scan.
-    List<SimpleRecord> result =
+    Dataset<Row> unboundedIncrementalResult =
         spark
             .read()
             .format("iceberg")
             .option("start-snapshot-id", snapshotIds.get(3).toString())
-            .load(tableLocation)
+            .load(tableLocation);
+    List<SimpleRecord> result1 =
+        unboundedIncrementalResult
             .orderBy("id")
             .as(Encoders.bean(SimpleRecord.class))
             .collectAsList();
-    assertThat(result).as("Records should match").isEqualTo(expectedRecords.subList(1, 4));
+    assertThat(result1).as("Records should match").isEqualTo(expectedRecords.subList(1, 4));
+    assertThat(unboundedIncrementalResult.count())
+        .as("Unprocessed count should match record count")
+        .isEqualTo(3);
+
+    Row row1 = unboundedIncrementalResult.agg(functions.min("id"), functions.max("id")).head();
+    assertThat(row1.getInt(0)).as("min value should match").isEqualTo(2);
+    assertThat(row1.getInt(1)).as("max value should match").isEqualTo(4);
 
     // test (2nd snapshot, 3rd snapshot] incremental scan.
-    Dataset<Row> resultDf =
+    Dataset<Row> incrementalResult =
         spark
             .read()
             .format("iceberg")
             .option("start-snapshot-id", snapshotIds.get(2).toString())
             .option("end-snapshot-id", snapshotIds.get(1).toString())
             .load(tableLocation);
-    List<SimpleRecord> result1 =
-        resultDf.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
-    assertThat(result1).as("Records should match").isEqualTo(expectedRecords.subList(2, 3));
-    assertThat(resultDf.count()).as("Unprocessed count should match record count").isEqualTo(1);
+    List<SimpleRecord> result2 =
+        incrementalResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
+    assertThat(result2).as("Records should match").isEqualTo(expectedRecords.subList(2, 3));
+    assertThat(incrementalResult.count())
+        .as("Unprocessed count should match record count")
+        .isEqualTo(1);
+
+    Row row2 = incrementalResult.agg(functions.min("id"), functions.max("id")).head();
+    assertThat(row2.getInt(0)).as("min value should match").isEqualTo(3);
+    assertThat(row2.getInt(1)).as("max value should match").isEqualTo(3);
   }
 
   @TestTemplate
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
index 05515946c145..7e9bdeec8af0 100644
--- a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
@@ -35,8 +35,13 @@
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.spark.CatalogTestBase;
+import org.apache.iceberg.spark.SparkReadOptions;
 import org.apache.iceberg.spark.TestBase;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.execution.ExplainMode;
+import org.apache.spark.sql.functions;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.TestTemplate;
@@ -808,4 +813,49 @@ public void testInfinity() {
         });
     assertEquals("min/max/count push down", expected, actual);
   }
+
+  @TestTemplate
+  public void testAggregatePushDownForIncrementalScan() {
+    sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ",
+        tableName);
+    long snapshotId1 = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
+    sql("INSERT INTO %s VALUES (4, 7777), (5, 8888)", tableName);
+    long snapshotId2 = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
+    sql("INSERT INTO %s VALUES (6, -7777), (7, 8888)", tableName);
+    long snapshotId3 = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
+    sql("INSERT INTO %s VALUES (8, 7777), (9, 9999)", tableName);
+
+    Dataset<Row> pushdownDs =
+        spark
+            .read()
+            .format("iceberg")
+            .option(SparkReadOptions.START_SNAPSHOT_ID, snapshotId2)
+            .option(SparkReadOptions.END_SNAPSHOT_ID, snapshotId3)
+            .load(tableName)
+            .agg(functions.min("data"), functions.max("data"), functions.count("data"));
+    String explain1 = pushdownDs.queryExecution().explainString(ExplainMode.fromString("simple"));
+    assertThat(explain1).contains("LocalTableScan", "min(data)", "max(data)", "count(data)");
+
+    List<Object[]> expected1 = Lists.newArrayList();
+    expected1.add(new Object[] {-7777, 8888, 2L});
+    assertEquals("min/max/count push down", expected1, rowsToJava(pushdownDs.collectAsList()));
+
+    Dataset<Row> unboundedPushdownDs =
+        spark
+            .read()
+            .format("iceberg")
+            .option(SparkReadOptions.START_SNAPSHOT_ID, snapshotId1)
+            .load(tableName)
+            .agg(functions.min("data"), functions.max("data"), functions.count("data"));
+    String explain2 =
+        unboundedPushdownDs.queryExecution().explainString(ExplainMode.fromString("simple"));
+    assertThat(explain2).contains("LocalTableScan", "min(data)", "max(data)", "count(data)");
+
+    List<Object[]> expected2 = Lists.newArrayList();
+    expected2.add(new Object[] {-7777, 9999, 6L});
+    assertEquals(
+        "min/max/count push down", expected2, rowsToJava(unboundedPushdownDs.collectAsList()));
+  }
 }