[apache#2587] feat(spark-connector): Support iceberg metadataColumns (apache#2717)

### What changes were proposed in this pull request?
Support retrieving Iceberg metadata columns such as `_spec_id`,
`_partition`, `_file`, `_pos`, and `_deleted`.
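
For illustration, a minimal sketch of selecting these metadata columns through Spark; the catalog, namespace, and table names (`iceberg_catalog.db.people`) and the session setup are placeholders, not part of this PR:

```java
// Minimal sketch, assuming a configured Iceberg catalog named "iceberg_catalog"
// and an existing table "db.people"; both names are hypothetical.
import org.apache.spark.sql.SparkSession;

public class MetadataColumnsExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("metadata-columns-demo").getOrCreate();
    // Metadata columns can be selected alongside regular data columns.
    spark
        .sql(
            "SELECT _spec_id, _partition, _file, _pos, _deleted, id, name "
                + "FROM iceberg_catalog.db.people")
        .show(false);
    spark.stop();
  }
}
```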

### Why are the changes needed?

Support retrieving Iceberg metadata columns; row-level operations depend on
this.
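
As a rough sketch of that dependency, mirroring the new integration test; the table name is a placeholder and `spark` is assumed to be the SparkSession from the sketch above:

```java
// Hedged sketch: a row-level DELETE together with the _deleted metadata column,
// mirroring testDeleteMetadataColumn; "iceberg_catalog.db.people" is hypothetical.
spark.sql("INSERT INTO iceberg_catalog.db.people VALUES (2, 'a', 1)");
spark.sql("SELECT _deleted FROM iceberg_catalog.db.people").show(); // one row, value false
spark.sql("DELETE FROM iceberg_catalog.db.people WHERE 1 = 1");
spark.sql("SELECT _deleted FROM iceberg_catalog.db.people").show(); // no rows remain
```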

Fix: apache#2587

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

New integration test.
caican00 authored and diqiu50 committed Jun 13, 2024
1 parent 1b7cc45 commit 5ea0ca1
Showing 6 changed files with 288 additions and 1 deletion.
@@ -5,23 +5,33 @@
package com.datastrato.gravitino.integration.test.spark.iceberg;

import com.datastrato.gravitino.integration.test.spark.SparkCommonIT;
import com.datastrato.gravitino.integration.test.util.spark.SparkMetadataColumnInfo;
import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo;
import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfoChecker;
import com.google.common.collect.ImmutableList;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.hadoop.fs.Path;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException;
import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
import org.apache.spark.sql.catalyst.expressions.Literal;
import org.apache.spark.sql.connector.catalog.CatalogPlugin;
import org.apache.spark.sql.connector.catalog.FunctionCatalog;
import org.apache.spark.sql.connector.catalog.Identifier;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
@@ -200,6 +210,186 @@ void testIcebergPartitions() {
});
}

@Test
void testIcebergMetadataColumns() throws NoSuchTableException {
testMetadataColumns();
testSpecAndPartitionMetadataColumns();
testPositionMetadataColumn();
testPartitionMetadataColumnWithUnPartitionedTable();
testFileMetadataColumn();
testDeleteMetadataColumn();
}

private void testMetadataColumns() {
String tableName = "test_metadata_columns";
dropTableIfExists(tableName);
String createTableSQL = getCreateSimpleTableString(tableName);
createTableSQL = createTableSQL + " PARTITIONED BY (name);";
sql(createTableSQL);

SparkTableInfo tableInfo = getTableInfo(tableName);

SparkMetadataColumnInfo[] metadataColumns = getIcebergMetadataColumns();
SparkTableInfoChecker checker =
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(getSimpleTableColumn())
.withMetadataColumns(metadataColumns);
checker.check(tableInfo);
}

private void testSpecAndPartitionMetadataColumns() {
String tableName = "test_spec_partition";
dropTableIfExists(tableName);
String createTableSQL = getCreateSimpleTableString(tableName);
createTableSQL = createTableSQL + " PARTITIONED BY (name);";
sql(createTableSQL);

SparkTableInfo tableInfo = getTableInfo(tableName);

SparkMetadataColumnInfo[] metadataColumns = getIcebergMetadataColumns();
SparkTableInfoChecker checker =
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(getSimpleTableColumn())
.withMetadataColumns(metadataColumns);
checker.check(tableInfo);

String insertData = String.format("INSERT into %s values(2,'a', 1);", tableName);
sql(insertData);

String expectedMetadata = "0,a";
String getMetadataSQL =
String.format("SELECT _spec_id, _partition FROM %s ORDER BY _spec_id", tableName);
List<String> queryResult = getTableMetadata(getMetadataSQL);
Assertions.assertEquals(1, queryResult.size());
Assertions.assertEquals(expectedMetadata, queryResult.get(0));
}

private void testPositionMetadataColumn() throws NoSuchTableException {
String tableName = "test_position_metadata_column";
dropTableIfExists(tableName);
String createTableSQL = getCreateSimpleTableString(tableName);
createTableSQL = createTableSQL + " PARTITIONED BY (name);";
sql(createTableSQL);

SparkTableInfo tableInfo = getTableInfo(tableName);

SparkMetadataColumnInfo[] metadataColumns = getIcebergMetadataColumns();
SparkTableInfoChecker checker =
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(getSimpleTableColumn())
.withMetadataColumns(metadataColumns);
checker.check(tableInfo);

List<Integer> ids = new ArrayList<>();
for (int id = 0; id < 200; id++) {
ids.add(id);
}
Dataset<Row> df =
getSparkSession()
.createDataset(ids, Encoders.INT())
.withColumnRenamed("value", "id")
.withColumn("name", new Column(Literal.create("a", DataTypes.StringType)))
.withColumn("age", new Column(Literal.create(1, DataTypes.IntegerType)));
df.coalesce(1).writeTo(tableName).append();

Assertions.assertEquals(200, getSparkSession().table(tableName).count());

String getMetadataSQL = String.format("SELECT _pos FROM %s", tableName);
List<String> expectedRows = ids.stream().map(String::valueOf).collect(Collectors.toList());
List<String> queryResult = getTableMetadata(getMetadataSQL);
Assertions.assertEquals(expectedRows.size(), queryResult.size());
Assertions.assertArrayEquals(expectedRows.toArray(), queryResult.toArray());
}

private void testPartitionMetadataColumnWithUnPartitionedTable() {
String tableName = "test_position_metadata_column_in_unpartitioned_table";
dropTableIfExists(tableName);
String createTableSQL = getCreateSimpleTableString(tableName);
sql(createTableSQL);

SparkTableInfo tableInfo = getTableInfo(tableName);

SparkMetadataColumnInfo[] metadataColumns = getIcebergMetadataColumns();
metadataColumns[1] =
new SparkMetadataColumnInfo(
"_partition", DataTypes.createStructType(new StructField[] {}), true);
SparkTableInfoChecker checker =
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(getSimpleTableColumn())
.withMetadataColumns(metadataColumns);
checker.check(tableInfo);

String insertData = String.format("INSERT into %s values(2,'a', 1);", tableName);
sql(insertData);

String getMetadataSQL = String.format("SELECT _partition FROM %s", tableName);
Assertions.assertEquals(1, getSparkSession().sql(getMetadataSQL).count());
Row row = getSparkSession().sql(getMetadataSQL).collectAsList().get(0);
Assertions.assertNotNull(row);
Assertions.assertNull(row.get(0));
}

private void testFileMetadataColumn() {
String tableName = "test_file_metadata_column";
dropTableIfExists(tableName);
String createTableSQL = getCreateSimpleTableString(tableName);
createTableSQL = createTableSQL + " PARTITIONED BY (name);";
sql(createTableSQL);

SparkTableInfo tableInfo = getTableInfo(tableName);

SparkMetadataColumnInfo[] metadataColumns = getIcebergMetadataColumns();
SparkTableInfoChecker checker =
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(getSimpleTableColumn())
.withMetadataColumns(metadataColumns);
checker.check(tableInfo);

String insertData = String.format("INSERT into %s values(2,'a', 1);", tableName);
sql(insertData);

String getMetadataSQL = String.format("SELECT _file FROM %s", tableName);
List<String> queryResult = getTableMetadata(getMetadataSQL);
Assertions.assertEquals(1, queryResult.size());
Assertions.assertTrue(queryResult.get(0).contains(tableName));
}

private void testDeleteMetadataColumn() {
String tableName = "test_delete_metadata_column";
dropTableIfExists(tableName);
String createTableSQL = getCreateSimpleTableString(tableName);
createTableSQL = createTableSQL + " PARTITIONED BY (name);";
sql(createTableSQL);

SparkTableInfo tableInfo = getTableInfo(tableName);

SparkMetadataColumnInfo[] metadataColumns = getIcebergMetadataColumns();
SparkTableInfoChecker checker =
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(getSimpleTableColumn())
.withMetadataColumns(metadataColumns);
checker.check(tableInfo);

String insertData = String.format("INSERT into %s values(2,'a', 1);", tableName);
sql(insertData);

String getMetadataSQL = String.format("SELECT _deleted FROM %s", tableName);
List<String> queryResult = getTableMetadata(getMetadataSQL);
Assertions.assertEquals(1, queryResult.size());
Assertions.assertEquals("false", queryResult.get(0));

sql(getDeleteSql(tableName, "1 = 1"));

List<String> queryResult1 = getTableMetadata(getMetadataSQL);
Assertions.assertEquals(0, queryResult1.size());
}

private List<SparkTableInfo.SparkColumnInfo> getIcebergSimpleTableColumn() {
return Arrays.asList(
SparkTableInfo.SparkColumnInfo.of("id", DataTypes.IntegerType, "id comment"),
@@ -216,4 +406,18 @@ private String getCreateIcebergSimpleTableString(String tableName) {
"CREATE TABLE %s (id INT COMMENT 'id comment', name STRING COMMENT '', ts TIMESTAMP)",
tableName);
}

private SparkMetadataColumnInfo[] getIcebergMetadataColumns() {
return new SparkMetadataColumnInfo[] {
new SparkMetadataColumnInfo("_spec_id", DataTypes.IntegerType, false),
new SparkMetadataColumnInfo(
"_partition",
DataTypes.createStructType(
new StructField[] {DataTypes.createStructField("name", DataTypes.StringType, true)}),
true),
new SparkMetadataColumnInfo("_file", DataTypes.StringType, false),
new SparkMetadataColumnInfo("_pos", DataTypes.LongType, false),
new SparkMetadataColumnInfo("_deleted", DataTypes.BooleanType, false)
};
}
}
@@ -0,0 +1,33 @@
/*
* Copyright 2024 Datastrato Pvt Ltd.
* This software is licensed under the Apache License version 2.
*/

package com.datastrato.gravitino.integration.test.util.spark;

import org.apache.spark.sql.connector.catalog.MetadataColumn;
import org.apache.spark.sql.types.DataType;

public class SparkMetadataColumnInfo implements MetadataColumn {
private final String name;
private final DataType dataType;
private final boolean isNullable;

public SparkMetadataColumnInfo(String name, DataType dataType, boolean isNullable) {
this.name = name;
this.dataType = dataType;
this.isNullable = isNullable;
}

public String name() {
return this.name;
}

public DataType dataType() {
return this.dataType;
}

public boolean isNullable() {
return this.isNullable;
}
}
@@ -17,6 +17,7 @@
import javax.ws.rs.NotSupportedException;
import lombok.Data;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.sql.connector.catalog.SupportsMetadataColumns;
import org.apache.spark.sql.connector.catalog.TableCatalog;
import org.apache.spark.sql.connector.expressions.ApplyTransform;
import org.apache.spark.sql.connector.expressions.BucketTransform;
@@ -42,6 +43,7 @@ public class SparkTableInfo {
private Transform bucket;
private List<Transform> partitions = new ArrayList<>();
private Set<String> partitionColumnNames = new HashSet<>();
private SparkMetadataColumnInfo[] metadataColumns;

public SparkTableInfo() {}

@@ -132,6 +134,18 @@ static SparkTableInfo create(SparkBaseTable baseTable) {
"Doesn't support Spark transform: " + transform.name());
}
});
if (baseTable instanceof SupportsMetadataColumns) {
SupportsMetadataColumns supportsMetadataColumns = (SupportsMetadataColumns) baseTable;
sparkTableInfo.metadataColumns =
Arrays.stream(supportsMetadataColumns.metadataColumns())
.map(
metadataColumn ->
new SparkMetadataColumnInfo(
metadataColumn.name(),
metadataColumn.dataType(),
metadataColumn.isNullable()))
.toArray(SparkMetadataColumnInfo[]::new);
}
return sparkTableInfo;
}

@@ -37,6 +37,7 @@ private enum CheckField {
BUCKET,
COMMENT,
TABLE_PROPERTY,
METADATA_COLUMN
}

public SparkTableInfoChecker withName(String name) {
@@ -135,6 +136,12 @@ public SparkTableInfoChecker withTableProperties(Map<String, String> properties)
return this;
}

public SparkTableInfoChecker withMetadataColumns(SparkMetadataColumnInfo[] metadataColumns) {
this.expectedTableInfo.setMetadataColumns(metadataColumns);
this.checkFields.add(CheckField.METADATA_COLUMN);
return this;
}

public void check(SparkTableInfo realTableInfo) {
checkFields.stream()
.forEach(
@@ -156,6 +163,22 @@ public void check(SparkTableInfo realTableInfo) {
case BUCKET:
Assertions.assertEquals(expectedTableInfo.getBucket(), realTableInfo.getBucket());
break;
case METADATA_COLUMN:
Assertions.assertEquals(
expectedTableInfo.getMetadataColumns().length,
realTableInfo.getMetadataColumns().length);
for (int i = 0; i < expectedTableInfo.getMetadataColumns().length; i++) {
Assertions.assertEquals(
expectedTableInfo.getMetadataColumns()[i].name(),
realTableInfo.getMetadataColumns()[i].name());
Assertions.assertEquals(
expectedTableInfo.getMetadataColumns()[i].dataType(),
realTableInfo.getMetadataColumns()[i].dataType());
Assertions.assertEquals(
expectedTableInfo.getMetadataColumns()[i].isNullable(),
realTableInfo.getMetadataColumns()[i].isNullable());
}
break;
case COMMENT:
Assertions.assertEquals(
expectedTableInfo.getComment(), realTableInfo.getComment());
@@ -119,6 +119,11 @@ protected List<String> getQueryData(String querySql) {
.collect(Collectors.toList());
}

// Column values in each returned row are joined by ','.
protected List<String> getTableMetadata(String getTableMetadataSql) {
return getQueryData(getTableMetadataSql);
}

// Create SparkTableInfo from SparkBaseTable retrieved from LogicalPlan.
protected SparkTableInfo getTableInfo(String tableName) {
Dataset ds = getSparkSession().sql("DESC TABLE EXTENDED " + tableName);
@@ -10,11 +10,14 @@
import com.datastrato.gravitino.spark.connector.SparkTransformConverter;
import com.datastrato.gravitino.spark.connector.table.SparkBaseTable;
import org.apache.spark.sql.connector.catalog.Identifier;
import org.apache.spark.sql.connector.catalog.MetadataColumn;
import org.apache.spark.sql.connector.catalog.SupportsDelete;
import org.apache.spark.sql.connector.catalog.SupportsMetadataColumns;
import org.apache.spark.sql.connector.catalog.TableCatalog;
import org.apache.spark.sql.sources.Filter;

public class SparkIcebergTable extends SparkBaseTable implements SupportsDelete {
public class SparkIcebergTable extends SparkBaseTable
implements SupportsDelete, SupportsMetadataColumns {

public SparkIcebergTable(
Identifier identifier,
@@ -39,4 +42,9 @@ public boolean canDeleteWhere(Filter[] filters) {
public void deleteWhere(Filter[] filters) {
((SupportsDelete) getSparkTable()).deleteWhere(filters);
}

@Override
public MetadataColumn[] metadataColumns() {
return ((SupportsMetadataColumns) getSparkTable()).metadataColumns();
}
}
