Spark 3.5: Support metadata columns in staged scan (#8872)
Authored by zinking on Nov 16, 2023 · Parent 6ec3de3 · Commit bfe1d03
Showing 3 changed files with 183 additions and 6 deletions.
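In short: a staged scan (a scan over tasks previously registered with ScanTaskSetManager) can now carry a projected schema, so metadata columns such as _pos can be selected alongside the data columns. The snippet below is a condensed sketch of the usage exercised by the new test that follows, not code from the commit; the table name is a placeholder and an existing SparkSession named spark is assumed.

// Sketch condensed from TestMetaColumnProjectionWithStageScan below; "db.tbl" is illustrative.
Table table = Spark3Util.loadIcebergTable(spark, "db.tbl");
String fileSetID = UUID.randomUUID().toString();

// Plan the scan and stage its tasks under an ID (stageTasks copies them into a list).
try (CloseableIterable<ScanTask> tasks = table.newBatchScan().planFiles()) {
  ScanTaskSetManager.get().stageTasks(table, fileSetID, Lists.newArrayList(tasks));
}

// Read the staged task set back, projecting the _pos metadata column with the data columns.
Dataset<Row> df =
    spark
        .read()
        .format("iceberg")
        .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID)
        .load(table.location())
        .select("*", "_pos");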
TestMetaColumnProjectionWithStageScan.java (new file)
@@ -0,0 +1,127 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.spark.extensions;

import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.apache.iceberg.ScanTask;
import org.apache.iceberg.Table;
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.spark.ScanTaskSetManager;
import org.apache.iceberg.spark.Spark3Util;
import org.apache.iceberg.spark.SparkCatalogConfig;
import org.apache.iceberg.spark.SparkReadOptions;
import org.apache.iceberg.spark.source.SimpleRecord;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.assertj.core.api.Assertions;
import org.junit.After;
import org.junit.Test;
import org.junit.runners.Parameterized;

public class TestMetaColumnProjectionWithStageScan extends SparkExtensionsTestBase {

  public TestMetaColumnProjectionWithStageScan(
      String catalogName, String implementation, Map<String, String> config) {
    super(catalogName, implementation, config);
  }

  @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}")
  public static Object[][] parameters() {
    return new Object[][] {
      {
        SparkCatalogConfig.HADOOP.catalogName(),
        SparkCatalogConfig.HADOOP.implementation(),
        SparkCatalogConfig.HADOOP.properties()
      }
    };
  }

  @After
  public void removeTables() {
    sql("DROP TABLE IF EXISTS %s", tableName);
  }

  private <T extends ScanTask> void stageTask(
      Table tab, String fileSetID, CloseableIterable<T> tasks) {
    ScanTaskSetManager taskSetManager = ScanTaskSetManager.get();
    taskSetManager.stageTasks(tab, fileSetID, Lists.newArrayList(tasks));
  }

  @Test
  public void testReadStageTableMeta() throws Exception {
    sql(
        "CREATE TABLE %s (id bigint, data string) USING iceberg TBLPROPERTIES"
            + "('format-version'='2', 'write.delete.mode'='merge-on-read')",
        tableName);

    List<SimpleRecord> records =
        Lists.newArrayList(
            new SimpleRecord(1, "a"),
            new SimpleRecord(2, "b"),
            new SimpleRecord(3, "c"),
            new SimpleRecord(4, "d"));

    spark
        .createDataset(records, Encoders.bean(SimpleRecord.class))
        .coalesce(1)
        .writeTo(tableName)
        .append();

    Table table = Spark3Util.loadIcebergTable(spark, tableName);
    table.refresh();
    String tableLocation = table.location();

    try (CloseableIterable<ScanTask> tasks = table.newBatchScan().planFiles()) {
      String fileSetID = UUID.randomUUID().toString();
      stageTask(table, fileSetID, tasks);
      Dataset<Row> scanDF2 =
          spark
              .read()
              .format("iceberg")
              .option(SparkReadOptions.FILE_OPEN_COST, "0")
              .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID)
              .load(tableLocation);

      Assertions.assertThat(scanDF2.columns().length).isEqualTo(2);
    }

    try (CloseableIterable<ScanTask> tasks = table.newBatchScan().planFiles()) {
      String fileSetID = UUID.randomUUID().toString();
      stageTask(table, fileSetID, tasks);
      Dataset<Row> scanDF =
          spark
              .read()
              .format("iceberg")
              .option(SparkReadOptions.FILE_OPEN_COST, "0")
              .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID)
              .load(tableLocation)
              .select("*", "_pos");

      List<Row> rows = scanDF.collectAsList();
      ImmutableList<Object[]> expectedRows =
          ImmutableList.of(row(1L, "a", 0L), row(2L, "b", 1L), row(3L, "c", 2L), row(4L, "d", 3L));
      assertEquals("result should match", expectedRows, rowsToJava(rows));
    }
  }
}
SparkStagedScan.java
@@ -22,6 +22,7 @@
 import java.util.Objects;
 import org.apache.iceberg.ScanTask;
 import org.apache.iceberg.ScanTaskGroup;
+import org.apache.iceberg.Schema;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.exceptions.ValidationException;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
@@ -39,9 +40,8 @@ class SparkStagedScan extends SparkScan {

   private List<ScanTaskGroup<ScanTask>> taskGroups = null; // lazy cache of tasks

-  SparkStagedScan(SparkSession spark, Table table, SparkReadConf readConf) {
-    super(spark, table, readConf, table.schema(), ImmutableList.of(), null);
-
+  SparkStagedScan(SparkSession spark, Table table, Schema expectedSchema, SparkReadConf readConf) {
+    super(spark, table, readConf, expectedSchema, ImmutableList.of(), null);
     this.taskSetId = readConf.scanTaskSetId();
     this.splitSize = readConf.splitSize();
     this.splitLookback = readConf.splitLookback();
@@ -77,14 +77,16 @@ public boolean equals(Object other) {
     SparkStagedScan that = (SparkStagedScan) other;
     return table().name().equals(that.table().name())
         && Objects.equals(taskSetId, that.taskSetId)
+        && readSchema().equals(that.readSchema())
         && Objects.equals(splitSize, that.splitSize)
         && Objects.equals(splitLookback, that.splitLookback)
         && Objects.equals(openFileCost, that.openFileCost);
   }

   @Override
   public int hashCode() {
-    return Objects.hash(table().name(), taskSetId, splitSize, splitSize, openFileCost);
+    return Objects.hash(
+        table().name(), taskSetId, readSchema(), splitSize, splitSize, openFileCost);
   }

   @Override
SparkStagedScanBuilder.java
@@ -18,27 +18,75 @@
  */
 package org.apache.iceberg.spark.source;

+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.iceberg.MetadataColumns;
+import org.apache.iceberg.Schema;
 import org.apache.iceberg.Table;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.spark.SparkReadConf;
+import org.apache.iceberg.spark.SparkSchemaUtil;
+import org.apache.iceberg.types.TypeUtil;
+import org.apache.iceberg.types.Types;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.connector.read.Scan;
 import org.apache.spark.sql.connector.read.ScanBuilder;
+import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.util.CaseInsensitiveStringMap;

-class SparkStagedScanBuilder implements ScanBuilder {
+class SparkStagedScanBuilder implements ScanBuilder, SupportsPushDownRequiredColumns {

   private final SparkSession spark;
   private final Table table;
   private final SparkReadConf readConf;
+  private final List<String> metaColumns = Lists.newArrayList();
+
+  private Schema schema = null;

   SparkStagedScanBuilder(SparkSession spark, Table table, CaseInsensitiveStringMap options) {
     this.spark = spark;
     this.table = table;
     this.readConf = new SparkReadConf(spark, table, options);
+    this.schema = table.schema();
   }

   @Override
   public Scan build() {
-    return new SparkStagedScan(spark, table, readConf);
+    return new SparkStagedScan(spark, table, schemaWithMetadataColumns(), readConf);
   }
+
+  @Override
+  public void pruneColumns(StructType requestedSchema) {
+    StructType requestedProjection = removeMetaColumns(requestedSchema);
+    this.schema = SparkSchemaUtil.prune(schema, requestedProjection);
+
+    Stream.of(requestedSchema.fields())
+        .map(StructField::name)
+        .filter(MetadataColumns::isMetadataColumn)
+        .distinct()
+        .forEach(metaColumns::add);
+  }
+
+  private StructType removeMetaColumns(StructType structType) {
+    return new StructType(
+        Stream.of(structType.fields())
+            .filter(field -> MetadataColumns.nonMetadataColumn(field.name()))
+            .toArray(StructField[]::new));
+  }
+
+  private Schema schemaWithMetadataColumns() {
+    // metadata columns
+    List<Types.NestedField> fields =
+        metaColumns.stream()
+            .distinct()
+            .map(name -> MetadataColumns.metadataColumn(table, name))
+            .collect(Collectors.toList());
+    Schema meta = new Schema(fields);
+
+    // schema of rows returned by readers
+    return TypeUtil.join(schema, meta);
+  }
 }
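For context on how the builder arrives at the scan schema (an informal illustration, not part of the commit): Spark pushes the requested projection, possibly including metadata column names, into pruneColumns; the builder strips the metadata names before pruning the Iceberg schema, remembers them in metaColumns, and build() joins them back on via TypeUtil.join. Assuming a table with columns (id bigint, data string) and a requested projection of (data, _pos), the resulting read schema can be reproduced roughly as follows; the field IDs are illustrative and the sketch uses the Iceberg Schema/TypeUtil APIs directly rather than going through the builder.

// Informal sketch of the prune-and-rejoin flow.
// Needs: org.apache.iceberg.MetadataColumns, org.apache.iceberg.Schema,
// org.apache.iceberg.types.TypeUtil, org.apache.iceberg.types.Types.
Schema tableSchema =
    new Schema(
        Types.NestedField.required(1, "id", Types.LongType.get()),
        Types.NestedField.optional(2, "data", Types.StringType.get()));

// pruneColumns: "_pos" is filtered out first, so only "data" is pruned into the data schema.
Schema pruned = tableSchema.select("data");

// build(): the remembered metadata column is appended to the pruned schema.
Schema meta = new Schema(MetadataColumns.ROW_POSITION);
Schema readSchema = TypeUtil.join(pruned, meta);
// readSchema now contains: data (string), _pos (long)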
