Spark 3.5: Support Aggregate push down for incremental scan (#10538)
huaxingao authored Jun 21, 2024
commit a47937c · 1 parent e57b9f6
Showing 3 changed files with 109 additions and 62 deletions.
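In user terms, this commit lets Spark answer min/max/count aggregates from Iceberg column statistics even when the read is an incremental append scan (start-snapshot-id, optionally with end-snapshot-id); previously pushAggregation bailed out whenever start-snapshot-id was set. A minimal usage sketch, modeled on the new test added below; the table name, snapshot ids, and class name are placeholders, not part of the commit:

// Sketch only: assumes a running SparkSession with an Iceberg catalog and an existing
// table "db.events" written with column metrics; snapshot ids are placeholder values.
import org.apache.iceberg.spark.SparkReadOptions;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;

public class IncrementalAggExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("incremental-agg").getOrCreate();
    long startSnapshotId = 123L; // exclusive start of the incremental range (placeholder)
    long endSnapshotId = 456L; // inclusive end of the incremental range (placeholder)

    Dataset<Row> agg =
        spark
            .read()
            .format("iceberg")
            .option(SparkReadOptions.START_SNAPSHOT_ID, startSnapshotId)
            .option(SparkReadOptions.END_SNAPSHOT_ID, endSnapshotId)
            .load("db.events")
            .agg(functions.min("data"), functions.max("data"), functions.count("data"));

    // With this commit, the min/max/count above can be computed from the column statistics
    // of the appended data files instead of scanning the files themselves.
    agg.show();
  }
}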
@@ -37,7 +37,6 @@
 import org.apache.iceberg.StructLike;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.TableProperties;
-import org.apache.iceberg.TableScan;
 import org.apache.iceberg.expressions.AggregateEvaluator;
 import org.apache.iceberg.expressions.Binder;
 import org.apache.iceberg.expressions.BoundAggregate;
@@ -232,15 +231,8 @@ public boolean pushAggregation(Aggregation aggregation) {
       return false;
     }

-    TableScan scan = table.newScan().includeColumnStats();
-    Snapshot snapshot = readSnapshot();
-    if (snapshot == null) {
-      LOG.info("Skipping aggregate pushdown: table snapshot is null");
-      return false;
-    }
-    scan = scan.useSnapshot(snapshot.snapshotId());
-    scan = configureSplitPlanning(scan);
-    scan = scan.filter(filterExpression());
+    org.apache.iceberg.Scan scan =
+        buildIcebergBatchScan(true /* include Column Stats */, schemaWithMetadataColumns());

     try (CloseableIterable<FileScanTask> fileScanTasks = scan.planFiles()) {
       List<FileScanTask> tasks = ImmutableList.copyOf(fileScanTasks);
@@ -282,11 +274,6 @@ private boolean canPushDownAggregation(Aggregation aggregation) {
       return false;
     }

-    if (readConf.startSnapshotId() != null) {
-      LOG.info("Skipping aggregate pushdown: incremental scan is not supported");
-      return false;
-    }
-
     // If group by expression is the same as the partition, the statistics information can still
     // be used to calculate min/max/count, will enable aggregate push down in next phase.
     // TODO: enable aggregate push down for partition col group by expression
@@ -298,17 +285,6 @@ private boolean canPushDownAggregation(Aggregation aggregation) {
     return true;
   }

-  private Snapshot readSnapshot() {
-    Snapshot snapshot;
-    if (readConf.snapshotId() != null) {
-      snapshot = table.snapshot(readConf.snapshotId());
-    } else {
-      snapshot = SnapshotUtil.latestSnapshot(table, readConf.branch());
-    }
-
-    return snapshot;
-  }
-
   private boolean metricsModeSupportsAggregatePushDown(List<BoundAggregate<?, ?>> aggregates) {
     MetricsConfig config = MetricsConfig.forTable(table);
     for (BoundAggregate aggregate : aggregates) {
@@ -387,6 +363,18 @@ public Scan build() {
   }

   private Scan buildBatchScan() {
+    Schema expectedSchema = schemaWithMetadataColumns();
+    return new SparkBatchQueryScan(
+        spark,
+        table,
+        buildIcebergBatchScan(false /* not include Column Stats */, expectedSchema),
+        readConf,
+        expectedSchema,
+        filterExpressions,
+        metricsReporter::scanReport);
+  }
+
+  private org.apache.iceberg.Scan buildIcebergBatchScan(boolean withStats, Schema expectedSchema) {
     Long snapshotId = readConf.snapshotId();
     Long asOfTimestamp = readConf.asOfTimestamp();
     String branch = readConf.branch();
@@ -427,22 +415,30 @@ private Scan buildBatchScan() {
         SparkReadOptions.END_TIMESTAMP);

     if (startSnapshotId != null) {
-      return buildIncrementalAppendScan(startSnapshotId, endSnapshotId);
+      return buildIncrementalAppendScan(startSnapshotId, endSnapshotId, withStats, expectedSchema);
     } else {
-      return buildBatchScan(snapshotId, asOfTimestamp, branch, tag);
+      return buildBatchScan(snapshotId, asOfTimestamp, branch, tag, withStats, expectedSchema);
     }
   }

-  private Scan buildBatchScan(Long snapshotId, Long asOfTimestamp, String branch, String tag) {
-    Schema expectedSchema = schemaWithMetadataColumns();
-
+  private org.apache.iceberg.Scan buildBatchScan(
+      Long snapshotId,
+      Long asOfTimestamp,
+      String branch,
+      String tag,
+      boolean withStats,
+      Schema expectedSchema) {
     BatchScan scan =
         newBatchScan()
             .caseSensitive(caseSensitive)
             .filter(filterExpression())
             .project(expectedSchema)
             .metricsReporter(metricsReporter);

+    if (withStats) {
+      scan = scan.includeColumnStats();
+    }
+
     if (snapshotId != null) {
       scan = scan.useSnapshot(snapshotId);
     }
@@ -459,21 +455,11 @@ private Scan buildBatchScan(Long snapshotId, Long asOfTimestamp, String branch,
       scan = scan.useRef(tag);
     }

-    scan = configureSplitPlanning(scan);
-
-    return new SparkBatchQueryScan(
-        spark,
-        table,
-        scan,
-        readConf,
-        expectedSchema,
-        filterExpressions,
-        metricsReporter::scanReport);
+    return configureSplitPlanning(scan);
   }

-  private Scan buildIncrementalAppendScan(long startSnapshotId, Long endSnapshotId) {
-    Schema expectedSchema = schemaWithMetadataColumns();
-
+  private org.apache.iceberg.Scan buildIncrementalAppendScan(
+      long startSnapshotId, Long endSnapshotId, boolean withStats, Schema expectedSchema) {
     IncrementalAppendScan scan =
         table
             .newIncrementalAppendScan()
@@ -483,20 +469,15 @@ private Scan buildIncrementalAppendScan(long startSnapshotId, Long endSnapshotId
             .project(expectedSchema)
             .metricsReporter(metricsReporter);

+    if (withStats) {
+      scan = scan.includeColumnStats();
+    }
+
     if (endSnapshotId != null) {
       scan = scan.toSnapshot(endSnapshotId);
     }

-    scan = configureSplitPlanning(scan);
-
-    return new SparkBatchQueryScan(
-        spark,
-        table,
-        scan,
-        readConf,
-        expectedSchema,
-        filterExpressions,
-        metricsReporter::scanReport);
+    return configureSplitPlanning(scan);
   }

   @SuppressWarnings("CyclomaticComplexity")
@@ -57,6 +57,7 @@
 import org.apache.spark.sql.SaveMode;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.functions;
 import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.TestTemplate;
@@ -290,29 +291,44 @@ public void testIncrementalScanOptions() throws IOException {
             "Cannot set only end-snapshot-id for incremental scans. Please, set start-snapshot-id too.");

     // test (1st snapshot, current snapshot] incremental scan.
-    List<SimpleRecord> result =
+    Dataset<Row> unboundedIncrementalResult =
         spark
             .read()
             .format("iceberg")
             .option("start-snapshot-id", snapshotIds.get(3).toString())
-            .load(tableLocation)
+            .load(tableLocation);
+    List<SimpleRecord> result1 =
+        unboundedIncrementalResult
             .orderBy("id")
             .as(Encoders.bean(SimpleRecord.class))
             .collectAsList();
-    assertThat(result).as("Records should match").isEqualTo(expectedRecords.subList(1, 4));
+    assertThat(result1).as("Records should match").isEqualTo(expectedRecords.subList(1, 4));
+    assertThat(unboundedIncrementalResult.count())
+        .as("Unprocessed count should match record count")
+        .isEqualTo(3);

+    Row row1 = unboundedIncrementalResult.agg(functions.min("id"), functions.max("id")).head();
+    assertThat(row1.getInt(0)).as("min value should match").isEqualTo(2);
+    assertThat(row1.getInt(1)).as("max value should match").isEqualTo(4);
+
     // test (2nd snapshot, 3rd snapshot] incremental scan.
-    Dataset<Row> resultDf =
+    Dataset<Row> incrementalResult =
         spark
             .read()
             .format("iceberg")
             .option("start-snapshot-id", snapshotIds.get(2).toString())
             .option("end-snapshot-id", snapshotIds.get(1).toString())
             .load(tableLocation);
-    List<SimpleRecord> result1 =
-        resultDf.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
-    assertThat(result1).as("Records should match").isEqualTo(expectedRecords.subList(2, 3));
-    assertThat(resultDf.count()).as("Unprocessed count should match record count").isEqualTo(1);
+    List<SimpleRecord> result2 =
+        incrementalResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
+    assertThat(result2).as("Records should match").isEqualTo(expectedRecords.subList(2, 3));
+    assertThat(incrementalResult.count())
+        .as("Unprocessed count should match record count")
+        .isEqualTo(1);
+
+    Row row2 = incrementalResult.agg(functions.min("id"), functions.max("id")).head();
+    assertThat(row2.getInt(0)).as("min value should match").isEqualTo(3);
+    assertThat(row2.getInt(1)).as("max value should match").isEqualTo(3);
   }

   @TestTemplate
@@ -35,8 +35,13 @@
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.spark.CatalogTestBase;
+import org.apache.iceberg.spark.SparkReadOptions;
 import org.apache.iceberg.spark.TestBase;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.execution.ExplainMode;
+import org.apache.spark.sql.functions;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.TestTemplate;
@@ -808,4 +813,49 @@ public void testInfinity() {
         });
     assertEquals("min/max/count push down", expected, actual);
   }
+
+  @TestTemplate
+  public void testAggregatePushDownForIncrementalScan() {
+    sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ",
+        tableName);
+    long snapshotId1 = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
+    sql("INSERT INTO %s VALUES (4, 7777), (5, 8888)", tableName);
+    long snapshotId2 = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
+    sql("INSERT INTO %s VALUES (6, -7777), (7, 8888)", tableName);
+    long snapshotId3 = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
+    sql("INSERT INTO %s VALUES (8, 7777), (9, 9999)", tableName);
+
+    Dataset<Row> pushdownDs =
+        spark
+            .read()
+            .format("iceberg")
+            .option(SparkReadOptions.START_SNAPSHOT_ID, snapshotId2)
+            .option(SparkReadOptions.END_SNAPSHOT_ID, snapshotId3)
+            .load(tableName)
+            .agg(functions.min("data"), functions.max("data"), functions.count("data"));
+    String explain1 = pushdownDs.queryExecution().explainString(ExplainMode.fromString("simple"));
+    assertThat(explain1).contains("LocalTableScan", "min(data)", "max(data)", "count(data)");
+
+    List<Object[]> expected1 = Lists.newArrayList();
+    expected1.add(new Object[] {-7777, 8888, 2L});
+    assertEquals("min/max/count push down", expected1, rowsToJava(pushdownDs.collectAsList()));
+
+    Dataset<Row> unboundedPushdownDs =
+        spark
+            .read()
+            .format("iceberg")
+            .option(SparkReadOptions.START_SNAPSHOT_ID, snapshotId1)
+            .load(tableName)
+            .agg(functions.min("data"), functions.max("data"), functions.count("data"));
+    String explain2 =
+        unboundedPushdownDs.queryExecution().explainString(ExplainMode.fromString("simple"));
+    assertThat(explain2).contains("LocalTableScan", "min(data)", "max(data)", "count(data)");
+
+    List<Object[]> expected2 = Lists.newArrayList();
+    expected2.add(new Object[] {-7777, 9999, 6L});
+    assertEquals(
+        "min/max/count push down", expected2, rowsToJava(unboundedPushdownDs.collectAsList()));
+  }
 }
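For completeness, a way to confirm the pushdown from user code, mirroring the plan check in the new test above; `agg` refers to the aggregated incremental Dataset from the sketch near the top of this page:

// Sketch: "agg" is the aggregated incremental Dataset<Row> built earlier.
import org.apache.spark.sql.execution.ExplainMode;

String plan = agg.queryExecution().explainString(ExplainMode.fromString("simple"));
// When the pushdown succeeds, the plan contains a LocalTableScan carrying min(data),
// max(data) and count(data); without pushdown it would show a scan over data files.
boolean pushedDown = plan.contains("LocalTableScan");
System.out.println(pushedDown ? "aggregate pushed down" : "aggregate not pushed down");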
