Support Aggregate push down for incremental scan #10538

Merged · 2 commits · Jun 21, 2024
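For context (this note and sketch are not part of the merged diff): the change below allows Spark's min/max/count aggregate pushdown to apply to incremental append scans bounded by start-snapshot-id/end-snapshot-id, so those aggregates can be answered from column statistics in Iceberg metadata rather than by scanning data files, when the table's metrics mode permits. A minimal Java sketch of the kind of query that benefits; the table name, column name, and snapshot IDs are hypothetical placeholders.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;

public class IncrementalAggregatePushdownExample {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder().appName("incremental-aggregate-pushdown").getOrCreate();

    // Hypothetical snapshot IDs; substitute values from your own table's history.
    long startSnapshotId = 1111111111111111111L;
    long endSnapshotId = 2222222222222222222L;

    // Read only the data appended in (startSnapshotId, endSnapshotId] and aggregate it.
    // With this change, min/max/count over the incremental range can be pushed down
    // to Iceberg column statistics instead of scanning data files.
    Dataset<Row> agg =
        spark
            .read()
            .format("iceberg")
            .option("start-snapshot-id", String.valueOf(startSnapshotId))
            .option("end-snapshot-id", String.valueOf(endSnapshotId))
            .load("db.table")
            .agg(functions.min("data"), functions.max("data"), functions.count("data"));

    agg.show();
  }
}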
@@ -37,7 +37,6 @@
import org.apache.iceberg.StructLike;
import org.apache.iceberg.Table;
import org.apache.iceberg.TableProperties;
import org.apache.iceberg.TableScan;
import org.apache.iceberg.expressions.AggregateEvaluator;
import org.apache.iceberg.expressions.Binder;
import org.apache.iceberg.expressions.BoundAggregate;
@@ -232,15 +231,8 @@ public boolean pushAggregation(Aggregation aggregation) {
return false;
}

TableScan scan = table.newScan().includeColumnStats();
Snapshot snapshot = readSnapshot();
if (snapshot == null) {
LOG.info("Skipping aggregate pushdown: table snapshot is null");
return false;
}
scan = scan.useSnapshot(snapshot.snapshotId());
scan = configureSplitPlanning(scan);
scan = scan.filter(filterExpression());
org.apache.iceberg.Scan scan =
buildIcebergBatchScan(true /* include Column Stats */, schemaWithMetadataColumns());

try (CloseableIterable<FileScanTask> fileScanTasks = scan.planFiles()) {
List<FileScanTask> tasks = ImmutableList.copyOf(fileScanTasks);
@@ -282,11 +274,6 @@ private boolean canPushDownAggregation(Aggregation aggregation) {
return false;
}

if (readConf.startSnapshotId() != null) {
LOG.info("Skipping aggregate pushdown: incremental scan is not supported");
return false;
}

// If group by expression is the same as the partition, the statistics information can still
// be used to calculate min/max/count, will enable aggregate push down in next phase.
// TODO: enable aggregate push down for partition col group by expression
@@ -298,17 +285,6 @@ private boolean canPushDownAggregation(Aggregation aggregation) {
return true;
}

private Snapshot readSnapshot() {
Snapshot snapshot;
if (readConf.snapshotId() != null) {
snapshot = table.snapshot(readConf.snapshotId());
} else {
snapshot = SnapshotUtil.latestSnapshot(table, readConf.branch());
}

return snapshot;
}

private boolean metricsModeSupportsAggregatePushDown(List<BoundAggregate<?, ?>> aggregates) {
MetricsConfig config = MetricsConfig.forTable(table);
for (BoundAggregate aggregate : aggregates) {
@@ -387,6 +363,18 @@ public Scan build() {
}

private Scan buildBatchScan() {
Schema expectedSchema = schemaWithMetadataColumns();
return new SparkBatchQueryScan(
spark,
table,
buildIcebergBatchScan(false /* not include Column Stats */, expectedSchema),
readConf,
expectedSchema,
filterExpressions,
metricsReporter::scanReport);
}

private org.apache.iceberg.Scan buildIcebergBatchScan(boolean withStats, Schema expectedSchema) {
Long snapshotId = readConf.snapshotId();
Long asOfTimestamp = readConf.asOfTimestamp();
String branch = readConf.branch();
@@ -427,22 +415,30 @@ private Scan buildBatchScan() {
SparkReadOptions.END_TIMESTAMP);

if (startSnapshotId != null) {
return buildIncrementalAppendScan(startSnapshotId, endSnapshotId);
return buildIncrementalAppendScan(startSnapshotId, endSnapshotId, withStats, expectedSchema);
} else {
return buildBatchScan(snapshotId, asOfTimestamp, branch, tag);
return buildBatchScan(snapshotId, asOfTimestamp, branch, tag, withStats, expectedSchema);
}
}

private Scan buildBatchScan(Long snapshotId, Long asOfTimestamp, String branch, String tag) {
Schema expectedSchema = schemaWithMetadataColumns();

private org.apache.iceberg.Scan buildBatchScan(
Long snapshotId,
Long asOfTimestamp,
String branch,
String tag,
boolean withStats,
Schema expectedSchema) {
BatchScan scan =
newBatchScan()
.caseSensitive(caseSensitive)
.filter(filterExpression())
.project(expectedSchema)
.metricsReporter(metricsReporter);

if (withStats) {
scan = scan.includeColumnStats();
}

if (snapshotId != null) {
scan = scan.useSnapshot(snapshotId);
}
@@ -459,21 +455,11 @@ private Scan buildBatchScan(Long snapshotId, Long asOfTimestamp, String branch,
scan = scan.useRef(tag);
}

scan = configureSplitPlanning(scan);

return new SparkBatchQueryScan(
spark,
table,
scan,
readConf,
expectedSchema,
filterExpressions,
metricsReporter::scanReport);
return configureSplitPlanning(scan);
}

private Scan buildIncrementalAppendScan(long startSnapshotId, Long endSnapshotId) {
Schema expectedSchema = schemaWithMetadataColumns();

private org.apache.iceberg.Scan buildIncrementalAppendScan(
long startSnapshotId, Long endSnapshotId, boolean withStats, Schema expectedSchema) {
IncrementalAppendScan scan =
table
.newIncrementalAppendScan()
@@ -483,20 +469,15 @@ private Scan buildIncrementalAppendScan(long startSnapshotId, Long endSnapshotId
.project(expectedSchema)
.metricsReporter(metricsReporter);

if (withStats) {
scan = scan.includeColumnStats();
}

if (endSnapshotId != null) {
scan = scan.toSnapshot(endSnapshotId);
}

scan = configureSplitPlanning(scan);

return new SparkBatchQueryScan(
spark,
table,
scan,
readConf,
expectedSchema,
filterExpressions,
metricsReporter::scanReport);
return configureSplitPlanning(scan);
}

@SuppressWarnings("CyclomaticComplexity")
@@ -57,6 +57,7 @@
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
import org.apache.spark.sql.functions;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.TestTemplate;
@@ -290,29 +291,44 @@ public void testIncrementalScanOptions() throws IOException {
"Cannot set only end-snapshot-id for incremental scans. Please, set start-snapshot-id too.");

// test (1st snapshot, current snapshot] incremental scan.
List<SimpleRecord> result =
Dataset<Row> unboundedIncrementalResult =
spark
.read()
.format("iceberg")
.option("start-snapshot-id", snapshotIds.get(3).toString())
.load(tableLocation)
.load(tableLocation);
List<SimpleRecord> result1 =
unboundedIncrementalResult
.orderBy("id")
.as(Encoders.bean(SimpleRecord.class))
.collectAsList();
assertThat(result).as("Records should match").isEqualTo(expectedRecords.subList(1, 4));
assertThat(result1).as("Records should match").isEqualTo(expectedRecords.subList(1, 4));
assertThat(unboundedIncrementalResult.count())
.as("Unprocessed count should match record count")
.isEqualTo(3);

Row row1 = unboundedIncrementalResult.agg(functions.min("id"), functions.max("id")).head();
assertThat(row1.getInt(0)).as("min value should match").isEqualTo(2);
assertThat(row1.getInt(1)).as("max value should match").isEqualTo(4);

// test (2nd snapshot, 3rd snapshot] incremental scan.
Dataset<Row> resultDf =
Dataset<Row> incrementalResult =
spark
.read()
.format("iceberg")
.option("start-snapshot-id", snapshotIds.get(2).toString())
.option("end-snapshot-id", snapshotIds.get(1).toString())
.load(tableLocation);
List<SimpleRecord> result1 =
resultDf.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
assertThat(result1).as("Records should match").isEqualTo(expectedRecords.subList(2, 3));
assertThat(resultDf.count()).as("Unprocessed count should match record count").isEqualTo(1);
List<SimpleRecord> result2 =
incrementalResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
assertThat(result2).as("Records should match").isEqualTo(expectedRecords.subList(2, 3));
assertThat(incrementalResult.count())
.as("Unprocessed count should match record count")
.isEqualTo(1);

Row row2 = incrementalResult.agg(functions.min("id"), functions.max("id")).head();
assertThat(row2.getInt(0)).as("min value should match").isEqualTo(3);
assertThat(row2.getInt(1)).as("max value should match").isEqualTo(3);
}

@TestTemplate
@@ -35,8 +35,13 @@
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.spark.CatalogTestBase;
import org.apache.iceberg.spark.SparkReadOptions;
import org.apache.iceberg.spark.TestBase;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.execution.ExplainMode;
import org.apache.spark.sql.functions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.TestTemplate;
@@ -808,4 +813,49 @@ public void testInfinity() {
});
assertEquals("min/max/count push down", expected, actual);
}

@TestTemplate
public void testAggregatePushDownForIncrementalScan() {
sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName);
sql(
"INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ",
tableName);
long snapshotId1 = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
sql("INSERT INTO %s VALUES (4, 7777), (5, 8888)", tableName);
long snapshotId2 = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
sql("INSERT INTO %s VALUES (6, -7777), (7, 8888)", tableName);
long snapshotId3 = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
sql("INSERT INTO %s VALUES (8, 7777), (9, 9999)", tableName);

Dataset<Row> pushdownDs =
spark
.read()
.format("iceberg")
.option(SparkReadOptions.START_SNAPSHOT_ID, snapshotId2)
.option(SparkReadOptions.END_SNAPSHOT_ID, snapshotId3)
.load(tableName)
.agg(functions.min("data"), functions.max("data"), functions.count("data"));
String explain1 = pushdownDs.queryExecution().explainString(ExplainMode.fromString("simple"));
assertThat(explain1).contains("LocalTableScan", "min(data)", "max(data)", "count(data)");

List<Object[]> expected1 = Lists.newArrayList();
expected1.add(new Object[] {-7777, 8888, 2L});
assertEquals("min/max/count push down", expected1, rowsToJava(pushdownDs.collectAsList()));

Dataset<Row> unboundedPushdownDs =
spark
.read()
.format("iceberg")
.option(SparkReadOptions.START_SNAPSHOT_ID, snapshotId1)
.load(tableName)
.agg(functions.min("data"), functions.max("data"), functions.count("data"));
String explain2 =
unboundedPushdownDs.queryExecution().explainString(ExplainMode.fromString("simple"));
assertThat(explain2).contains("LocalTableScan", "min(data)", "max(data)", "count(data)");

List<Object[]> expected2 = Lists.newArrayList();
expected2.add(new Object[] {-7777, 9999, 6L});
assertEquals(
"min/max/count push down", expected2, rowsToJava(unboundedPushdownDs.collectAsList()));
}
}