From 4d334aa063d3b06d3654fb7e74ad1d4049e46f95 Mon Sep 17 00:00:00 2001 From: caican00 Date: Wed, 1 May 2024 22:46:56 +0800 Subject: [PATCH] [#2543] feat(spark-connector): support row-level operations for Iceberg tables --- integration-test/build.gradle.kts | 4 + .../integration/test/spark/SparkCommonIT.java | 103 +++++++++++++ .../test/spark/hive/SparkHiveCatalogIT.java | 5 + .../spark/iceberg/SparkIcebergCatalogIT.java | 139 ++++++++++++++++++ .../test/util/spark/SparkTableInfo.java | 42 +++++- .../test/util/spark/SparkUtilIT.java | 8 +- .../spark/connector/ConnectorConstants.java | 1 + .../spark/connector/catalog/BaseCatalog.java | 42 ++++-- .../connector/hive/GravitinoHiveCatalog.java | 13 +- .../spark/connector/hive/SparkHiveTable.java | 52 ++++++- .../iceberg/GravitinoIcebergCatalog.java | 13 +- .../connector/iceberg/SparkIcebergTable.java | 67 ++++++--- .../plugin/GravitinoDriverPlugin.java | 21 ++- .../spark/connector/utils/ConnectorUtil.java | 26 ++++ .../SparkBaseTableHelper.java} | 68 ++------- .../connector/utils/TestConnectorUtil.java | 31 ++++ 16 files changed, 518 insertions(+), 117 deletions(-) create mode 100644 spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/utils/ConnectorUtil.java rename spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/{table/SparkBaseTable.java => utils/SparkBaseTableHelper.java} (67%) create mode 100644 spark-connector/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/utils/TestConnectorUtil.java diff --git a/integration-test/build.gradle.kts b/integration-test/build.gradle.kts index 384f8417b18..95ce862da68 100644 --- a/integration-test/build.gradle.kts +++ b/integration-test/build.gradle.kts @@ -13,6 +13,8 @@ plugins { val scalaVersion: String = project.properties["scalaVersion"] as? 
String ?: extra["defaultScalaVersion"].toString() val sparkVersion: String = libs.versions.spark.get() +val sparkMajorVersion: String = sparkVersion.substringBeforeLast(".") +val kyuubiVersion: String = libs.versions.kyuubi.get() val icebergVersion: String = libs.versions.iceberg.get() val scalaCollectionCompatVersion: String = libs.versions.scala.collection.compat.get() @@ -114,6 +116,8 @@ dependencies { exclude("io.dropwizard.metrics") exclude("org.rocksdb") } + testImplementation("org.apache.iceberg:iceberg-spark-runtime-${sparkMajorVersion}_$scalaVersion:$icebergVersion") + testImplementation("org.apache.kyuubi:kyuubi-spark-connector-hive_$scalaVersion:$kyuubiVersion") testImplementation(libs.okhttp3.loginterceptor) testImplementation(libs.postgresql.driver) diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkCommonIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkCommonIT.java index 9dab1b46839..498a245228f 100644 --- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkCommonIT.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkCommonIT.java @@ -68,11 +68,39 @@ protected static String getDeleteSql(String tableName, String condition) { return String.format("DELETE FROM %s where %s", tableName, condition); } + private static String getUpdateTableSql(String tableName, String setClause, String whereClause) { + return String.format("UPDATE %s SET %s WHERE %s", tableName, setClause, whereClause); + } + + private static String getRowLevelUpdateTableSql( + String targetTableName, String selectClause, String sourceTableName, String onClause) { + return String.format( + "MERGE INTO %s " + + "USING (%s) %s " + + "ON %s " + + "WHEN MATCHED THEN UPDATE SET * " + + "WHEN NOT MATCHED THEN INSERT *", + targetTableName, selectClause, sourceTableName, onClause); + } + + private static String getRowLevelDeleteTableSql( + String targetTableName, String selectClause, String sourceTableName, String onClause) { + return String.format( + "MERGE INTO %s " + + "USING (%s) %s " + + "ON %s " + + "WHEN MATCHED THEN DELETE " + + "WHEN NOT MATCHED THEN INSERT *", + targetTableName, selectClause, sourceTableName, onClause); + } + // Whether supports [CLUSTERED BY col_name3 SORTED BY col_name INTO num_buckets BUCKETS] protected abstract boolean supportsSparkSQLClusteredBy(); protected abstract boolean supportsPartition(); + protected abstract boolean supportsDelete(); + // Use a custom database not the original default database because SparkCommonIT couldn't // read&write data to tables in default database. 
The main reason is default database location is // determined by `hive.metastore.warehouse.dir` in hive-site.xml which is local HDFS address @@ -702,6 +730,28 @@ void testTableOptions() { checkTableReadWrite(tableInfo); } + @Test + @EnabledIf("supportsDelete") + void testDeleteOperation() { + String tableName = "test_row_level_delete_table"; + dropTableIfExists(tableName); + createSimpleTable(tableName); + + SparkTableInfo table = getTableInfo(tableName); + checkTableColumns(tableName, getSimpleTableColumn(), table); + sql( + String.format( + "INSERT INTO %s VALUES (1, '1', 1),(2, '2', 2),(3, '3', 3),(4, '4', 4),(5, '5', 5)", + tableName)); + List queryResult1 = getTableData(tableName); + Assertions.assertEquals(5, queryResult1.size()); + Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", queryResult1)); + sql(getDeleteSql(tableName, "id <= 4")); + List queryResult2 = getTableData(tableName); + Assertions.assertEquals(1, queryResult2.size()); + Assertions.assertEquals("5,5,5", queryResult2.get(0)); + } + protected void checkTableReadWrite(SparkTableInfo table) { String name = table.getTableIdentifier(); boolean isPartitionTable = table.isPartitionTable(); @@ -760,6 +810,49 @@ protected String getExpectedTableData(SparkTableInfo table) { .collect(Collectors.joining(",")); } + protected void checkTableRowLevelUpdate(String tableName) { + writeToEmptyTableAndCheckData(tableName); + String updatedValues = "id = 6, name = '6', age = 6"; + sql(getUpdateTableSql(tableName, updatedValues, "id = 5")); + List queryResult = getQueryData(getSelectAllSqlWithOrder(tableName)); + Assertions.assertEquals(5, queryResult.size()); + Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;6,6,6", String.join(";", queryResult)); + } + + protected void checkTableRowLevelDelete(String tableName) { + writeToEmptyTableAndCheckData(tableName); + sql(getDeleteSql(tableName, "id <= 2")); + List queryResult = getQueryData(getSelectAllSqlWithOrder(tableName)); + Assertions.assertEquals(3, queryResult.size()); + Assertions.assertEquals("3,3,3;4,4,4;5,5,5", String.join(";", queryResult)); + } + + protected void checkTableDeleteByMergeInto(String tableName) { + writeToEmptyTableAndCheckData(tableName); + + String sourceTableName = "source_table"; + String selectClause = + "SELECT 1 AS id, '1' AS name, 1 AS age UNION ALL SELECT 6 AS id, '6' AS name, 6 AS age"; + String onClause = String.format("%s.id = %s.id", tableName, sourceTableName); + sql(getRowLevelDeleteTableSql(tableName, selectClause, sourceTableName, onClause)); + List queryResult = getQueryData(getSelectAllSqlWithOrder(tableName)); + Assertions.assertEquals(5, queryResult.size()); + Assertions.assertEquals("2,2,2;3,3,3;4,4,4;5,5,5;6,6,6", String.join(";", queryResult)); + } + + protected void checkTableUpdateByMergeInto(String tableName) { + writeToEmptyTableAndCheckData(tableName); + + String sourceTableName = "source_table"; + String selectClause = + "SELECT 1 AS id, '2' AS name, 2 AS age UNION ALL SELECT 6 AS id, '6' AS name, 6 AS age"; + String onClause = String.format("%s.id = %s.id", tableName, sourceTableName); + sql(getRowLevelUpdateTableSql(tableName, selectClause, sourceTableName, onClause)); + List queryResult = getQueryData(getSelectAllSqlWithOrder(tableName)); + Assertions.assertEquals(6, queryResult.size()); + Assertions.assertEquals("1,2,2;2,2,2;3,3,3;4,4,4;5,5,5;6,6,6", String.join(";", queryResult)); + } + protected String getCreateSimpleTableString(String tableName) { return getCreateSimpleTableString(tableName, false); } @@ 
-801,6 +894,16 @@ protected void checkTableColumns( .check(tableInfo); } + private void writeToEmptyTableAndCheckData(String tableName) { + sql( + String.format( + "INSERT INTO %s VALUES (1, '1', 1),(2, '2', 2),(3, '3', 3),(4, '4', 4),(5, '5', 5)", + tableName)); + List queryResult = getTableData(tableName); + Assertions.assertEquals(5, queryResult.size()); + Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", queryResult)); + } + // partition expression may contain "'", like a='s'/b=1 private String getPartitionExpression(SparkTableInfo table, String delimiter) { return table.getPartitionedColumns().stream() diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/hive/SparkHiveCatalogIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/hive/SparkHiveCatalogIT.java index 1f34c87c10f..f42e0332dd1 100644 --- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/hive/SparkHiveCatalogIT.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/hive/SparkHiveCatalogIT.java @@ -55,6 +55,11 @@ protected boolean supportsPartition() { return true; } + @Override + protected boolean supportsDelete() { + return false; + } + @Test public void testCreateHiveFormatPartitionTable() { String tableName = "hive_partition_table"; diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java index b94d6eb5e17..f7da5564809 100644 --- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java @@ -9,6 +9,7 @@ import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo; import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfoChecker; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import java.io.File; import java.util.ArrayList; import java.util.Arrays; @@ -18,10 +19,13 @@ import java.util.Map; import java.util.stream.Collectors; import org.apache.hadoop.fs.Path; +import org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions; +import org.apache.spark.SparkConf; import org.apache.spark.sql.Column; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException; import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; @@ -30,13 +34,21 @@ import org.apache.spark.sql.connector.catalog.FunctionCatalog; import org.apache.spark.sql.connector.catalog.Identifier; import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; +import org.apache.spark.sql.internal.StaticSQLConf; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.junit.platform.commons.util.StringUtils; +import scala.Tuple3; public abstract class SparkIcebergCatalogIT extends SparkCommonIT { + private static final String ICEBERG_FORMAT_VERSION = "format-version"; + private static final String 
ICEBERG_DELETE_MODE = "write.delete.mode"; + private static final String ICEBERG_UPDATE_MODE = "write.update.mode"; + private static final String ICEBERG_MERGE_MODE = "write.merge.mode"; + @Override protected String getCatalogName() { return "iceberg"; @@ -57,6 +69,11 @@ protected boolean supportsPartition() { return true; } + @Override + protected boolean supportsDelete() { + return true; + } + @Override protected String getTableLocation(SparkTableInfo table) { return String.join(File.separator, table.getTableLocation(), "data"); @@ -216,6 +233,24 @@ void testIcebergMetadataColumns() throws NoSuchTableException { testDeleteMetadataColumn(); } + @Test + void testInjectSparkExtensions() { + SparkSession sparkSession = getSparkSession(); + SparkConf conf = sparkSession.sparkContext().getConf(); + Assertions.assertTrue(conf.contains(StaticSQLConf.SPARK_SESSION_EXTENSIONS().key())); + String extensions = conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS().key()); + Assertions.assertTrue(StringUtils.isNotBlank(extensions)); + Assertions.assertEquals(IcebergSparkSessionExtensions.class.getName(), extensions); + } + + @Test + void testIcebergTableRowLevelOperations() { + testIcebergDeleteOperation(); + testIcebergUpdateOperation(); + testIcebergMergeIntoDeleteOperation(); + testIcebergMergeIntoUpdateOperation(); + } + private void testMetadataColumns() { String tableName = "test_metadata_columns"; dropTableIfExists(tableName); @@ -386,6 +421,88 @@ private void testDeleteMetadataColumn() { Assertions.assertEquals(0, queryResult1.size()); } + private void testIcebergDeleteOperation() { + getIcebergTablePropertyValues() + .forEach( + tuple -> { + String tableName = + String.format("test_iceberg_%s_%s_delete_operation", tuple._1(), tuple._2()); + dropTableIfExists(tableName); + createIcebergTableWithTabProperties( + tableName, + tuple._1(), + ImmutableMap.of( + ICEBERG_FORMAT_VERSION, + String.valueOf(tuple._2()), + ICEBERG_DELETE_MODE, + tuple._3())); + checkTableColumns(tableName, getSimpleTableColumn(), getTableInfo(tableName)); + checkTableRowLevelDelete(tableName); + }); + } + + private void testIcebergUpdateOperation() { + getIcebergTablePropertyValues() + .forEach( + tuple -> { + String tableName = + String.format("test_iceberg_%s_%s_update_operation", tuple._1(), tuple._2()); + dropTableIfExists(tableName); + createIcebergTableWithTabProperties( + tableName, + tuple._1(), + ImmutableMap.of( + ICEBERG_FORMAT_VERSION, + String.valueOf(tuple._2()), + ICEBERG_UPDATE_MODE, + tuple._3())); + checkTableColumns(tableName, getSimpleTableColumn(), getTableInfo(tableName)); + checkTableRowLevelUpdate(tableName); + }); + } + + private void testIcebergMergeIntoDeleteOperation() { + getIcebergTablePropertyValues() + .forEach( + tuple -> { + String tableName = + String.format( + "test_iceberg_%s_%s_mergeinto_delete_operation", tuple._1(), tuple._2()); + dropTableIfExists(tableName); + createIcebergTableWithTabProperties( + tableName, + tuple._1(), + ImmutableMap.of( + ICEBERG_FORMAT_VERSION, + String.valueOf(tuple._2()), + ICEBERG_MERGE_MODE, + tuple._3())); + checkTableColumns(tableName, getSimpleTableColumn(), getTableInfo(tableName)); + checkTableDeleteByMergeInto(tableName); + }); + } + + private void testIcebergMergeIntoUpdateOperation() { + getIcebergTablePropertyValues() + .forEach( + tuple -> { + String tableName = + String.format( + "test_iceberg_%s_%s_mergeinto_update_operation", tuple._1(), tuple._2()); + dropTableIfExists(tableName); + createIcebergTableWithTabProperties( + tableName, + 
tuple._1(), + ImmutableMap.of( + ICEBERG_FORMAT_VERSION, + String.valueOf(tuple._2()), + ICEBERG_MERGE_MODE, + tuple._3())); + checkTableColumns(tableName, getSimpleTableColumn(), getTableInfo(tableName)); + checkTableUpdateByMergeInto(tableName); + }); + } + private List getIcebergSimpleTableColumn() { return Arrays.asList( SparkTableInfo.SparkColumnInfo.of("id", DataTypes.IntegerType, "id comment"), @@ -416,4 +533,26 @@ private SparkMetadataColumnInfo[] getIcebergMetadataColumns() { new SparkMetadataColumnInfo("_deleted", DataTypes.BooleanType, false) }; } + + private List<Tuple3<Boolean, Integer, String>> getIcebergTablePropertyValues() { + return Arrays.asList( + new Tuple3<>(false, 1, "copy-on-write"), + new Tuple3<>(false, 2, "merge-on-read"), + new Tuple3<>(true, 1, "copy-on-write"), + new Tuple3<>(true, 2, "merge-on-read")); + } + + private void createIcebergTableWithTabProperties( + String tableName, boolean isPartitioned, ImmutableMap<String, String> tblProperties) { + String partitionedClause = isPartitioned ? " PARTITIONED BY (name) " : ""; + String tblPropertiesStr = + tblProperties.entrySet().stream() + .map(e -> String.format("'%s'='%s'", e.getKey(), e.getValue())) + .collect(Collectors.joining(",")); + String createSql = + String.format( + "CREATE TABLE %s (id INT COMMENT 'id comment', name STRING COMMENT '', age INT) %s TBLPROPERTIES(%s)", + tableName, partitionedClause, tblPropertiesStr); + sql(createSql); + } } diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java index ee08de46ee9..43d3b85adfb 100644 --- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java @@ -6,7 +6,9 @@ package com.datastrato.gravitino.integration.test.util.spark; import com.datastrato.gravitino.spark.connector.ConnectorConstants; -import com.datastrato.gravitino.spark.connector.table.SparkBaseTable; +import com.datastrato.gravitino.spark.connector.SparkTransformConverter; +import com.datastrato.gravitino.spark.connector.hive.SparkHiveTable; +import com.datastrato.gravitino.spark.connector.iceberg.SparkIcebergTable; import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; @@ -18,6 +20,7 @@ import lombok.Data; import org.apache.commons.lang3.StringUtils; import org.apache.spark.sql.connector.catalog.SupportsMetadataColumns; +import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableCatalog; import org.apache.spark.sql.connector.expressions.ApplyTransform; import org.apache.spark.sql.connector.expressions.BucketTransform; @@ -29,6 +32,7 @@ import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.connector.expressions.YearsTransform; import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructType; import org.junit.jupiter.api.Assertions; /** SparkTableInfo is used to check the result in tests. 
*/ @@ -89,7 +93,7 @@ void addPartition(Transform partition) { } } - static SparkTableInfo create(SparkBaseTable baseTable) { + static SparkTableInfo create(Table baseTable) { SparkTableInfo sparkTableInfo = new SparkTableInfo(); String identifier = baseTable.name(); String[] items = identifier.split("\\."); @@ -98,7 +102,7 @@ static SparkTableInfo create(SparkBaseTable baseTable) { sparkTableInfo.tableName = items[1]; sparkTableInfo.database = items[0]; sparkTableInfo.columns = - Arrays.stream(baseTable.schema().fields()) + Arrays.stream(getSchema(baseTable).fields()) .map( sparkField -> new SparkColumnInfo( @@ -110,7 +114,7 @@ static SparkTableInfo create(SparkBaseTable baseTable) { sparkTableInfo.comment = baseTable.properties().remove(ConnectorConstants.COMMENT); sparkTableInfo.tableProperties = baseTable.properties(); boolean supportsBucketPartition = - baseTable.getSparkTransformConverter().isSupportsBucketPartition(); + getSparkTransformConverter(baseTable).isSupportsBucketPartition(); Arrays.stream(baseTable.partitioning()) .forEach( transform -> { @@ -149,10 +153,6 @@ static SparkTableInfo create(SparkBaseTable baseTable) { return sparkTableInfo; } - private static boolean isBucketPartition(boolean supportsBucketPartition, Transform transform) { - return supportsBucketPartition && !(transform instanceof SortedBucketTransform); - } - public List getUnPartitionedColumns() { return columns.stream() .filter(column -> !partitionColumnNames.contains(column.name)) @@ -165,6 +165,32 @@ public List getPartitionedColumns() { .collect(Collectors.toList()); } + private static boolean isBucketPartition(boolean supportsBucketPartition, Transform transform) { + return supportsBucketPartition && !(transform instanceof SortedBucketTransform); + } + + private static SparkTransformConverter getSparkTransformConverter(Table baseTable) { + if (baseTable instanceof SparkHiveTable) { + return ((SparkHiveTable) baseTable).getSparkTransformConverter(); + } else if (baseTable instanceof SparkIcebergTable) { + return ((SparkIcebergTable) baseTable).getSparkTransformConverter(); + } else { + throw new IllegalArgumentException( + "Doesn't support Spark table: " + baseTable.getClass().getName()); + } + } + + private static StructType getSchema(Table baseTable) { + if (baseTable instanceof SparkHiveTable) { + return ((SparkHiveTable) baseTable).schema(); + } else if (baseTable instanceof SparkIcebergTable) { + return ((SparkIcebergTable) baseTable).schema(); + } else { + throw new IllegalArgumentException( + "Doesn't support Spark table: " + baseTable.getClass().getName()); + } + } + @Data public static class SparkColumnInfo { private String name; diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkUtilIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkUtilIT.java index 2603fbe8f73..bad6fa0cb62 100644 --- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkUtilIT.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkUtilIT.java @@ -20,7 +20,6 @@ package com.datastrato.gravitino.integration.test.util.spark; import com.datastrato.gravitino.integration.test.util.AbstractIT; -import com.datastrato.gravitino.spark.connector.table.SparkBaseTable; import java.sql.Timestamp; import java.text.SimpleDateFormat; import java.util.Arrays; @@ -130,8 +129,7 @@ protected SparkTableInfo getTableInfo(String tableName) { CommandResult result = 
(CommandResult) ds.logicalPlan(); DescribeRelation relation = (DescribeRelation) result.commandLogicalPlan(); ResolvedTable table = (ResolvedTable) relation.child(); - SparkBaseTable baseTable = (SparkBaseTable) table.table(); - return SparkTableInfo.create(baseTable); + return SparkTableInfo.create(table.table()); } protected void dropTableIfExists(String tableName) { @@ -159,6 +157,10 @@ protected void insertTableAsSelect(String tableName, String newName) { sql(String.format("INSERT INTO TABLE %s SELECT * FROM %s", newName, tableName)); } + protected static String getSelectAllSqlWithOrder(String tableName) { + return String.format("SELECT * FROM %s ORDER BY id", tableName); + } + private static String getSelectAllSql(String tableName) { return String.format("SELECT * FROM %s", tableName); } diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/ConnectorConstants.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/ConnectorConstants.java index 3a49a21470f..9758ff42196 100644 --- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/ConnectorConstants.java +++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/ConnectorConstants.java @@ -14,6 +14,7 @@ public class ConnectorConstants { public static final String LOCATION = "location"; public static final String DOT = "."; + public static final String COMMA = ","; private ConnectorConstants() {} } diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/BaseCatalog.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/BaseCatalog.java index f5994b4ce86..1cfc98de6ef 100644 --- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/BaseCatalog.java +++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/BaseCatalog.java @@ -19,7 +19,6 @@ import com.datastrato.gravitino.spark.connector.SparkTransformConverter; import com.datastrato.gravitino.spark.connector.SparkTransformConverter.DistributionAndSortOrdersInfo; import com.datastrato.gravitino.spark.connector.SparkTypeConverter; -import com.datastrato.gravitino.spark.connector.table.SparkBaseTable; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import java.util.Arrays; @@ -27,6 +26,7 @@ import java.util.Map; import java.util.Optional; import javax.ws.rs.NotSupportedException; +import lombok.SneakyThrows; import org.apache.commons.lang3.StringUtils; import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException; import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; @@ -93,15 +93,17 @@ protected abstract TableCatalog createAndInitSparkCatalog( * * @param identifier Spark's table identifier * @param gravitinoTable Gravitino table to do DDL operations + * @param sparkTable specific Spark table to do IO operations * @param sparkCatalog specific Spark catalog to do IO operations * @param propertiesConverter transform properties between Gravitino and Spark * @param sparkTransformConverter sparkTransformConverter convert transforms between Gravitino and * Spark * @return a specific Spark table */ - protected abstract SparkBaseTable createSparkTable( + protected abstract Table createSparkTable( Identifier identifier, com.datastrato.gravitino.rel.Table gravitinoTable, + Table 
sparkTable, TableCatalog sparkCatalog, PropertiesConverter propertiesConverter, SparkTransformConverter sparkTransformConverter); @@ -162,10 +164,10 @@ public Identifier[] listTables(String[] namespace) throws NoSuchNamespaceExcepti } } + @SneakyThrows @Override public Table createTable( - Identifier ident, Column[] columns, Transform[] transforms, Map properties) - throws TableAlreadyExistsException, NoSuchNamespaceException { + Identifier ident, Column[] columns, Transform[] transforms, Map properties) { NameIdentifier gravitinoIdentifier = NameIdentifier.of(metalakeName, catalogName, getDatabase(ident), ident.name()); com.datastrato.gravitino.rel.Column[] gravitinoColumns = @@ -184,7 +186,7 @@ public Table createTable( sparkTransformConverter.toGravitinoPartitionings(transforms); try { - com.datastrato.gravitino.rel.Table table = + com.datastrato.gravitino.rel.Table gravitinoTable = gravitinoCatalogClient .asTableCatalog() .createTable( @@ -195,12 +197,20 @@ public Table createTable( partitionings, distributionAndSortOrdersInfo.getDistribution(), distributionAndSortOrdersInfo.getSortOrders()); + Table sparkTable = sparkCatalog.loadTable(ident); return createSparkTable( - ident, table, sparkCatalog, propertiesConverter, sparkTransformConverter); + ident, + gravitinoTable, + sparkTable, + sparkCatalog, + propertiesConverter, + sparkTransformConverter); } catch (NoSuchSchemaException e) { throw new NoSuchNamespaceException(ident.namespace()); } catch (com.datastrato.gravitino.exceptions.TableAlreadyExistsException e) { throw new TableAlreadyExistsException(ident); + } catch (NoSuchTableException e) { + throw new NoSuchTableException(ident); } } @@ -208,13 +218,19 @@ public Table createTable( public Table loadTable(Identifier ident) throws NoSuchTableException { try { String database = getDatabase(ident); - com.datastrato.gravitino.rel.Table table = + com.datastrato.gravitino.rel.Table gravitinoTable = gravitinoCatalogClient .asTableCatalog() .loadTable(NameIdentifier.of(metalakeName, catalogName, database, ident.name())); + Table sparkTable = sparkCatalog.loadTable(ident); // Will create a catalog specific table return createSparkTable( - ident, table, sparkCatalog, propertiesConverter, sparkTransformConverter); + ident, + gravitinoTable, + sparkTable, + sparkCatalog, + propertiesConverter, + sparkTransformConverter); } catch (com.datastrato.gravitino.exceptions.NoSuchTableException e) { throw new NoSuchTableException(ident); } @@ -235,14 +251,20 @@ public Table alterTable(Identifier ident, TableChange... 
changes) throws NoSuchT .map(BaseCatalog::transformTableChange) .toArray(com.datastrato.gravitino.rel.TableChange[]::new); try { - com.datastrato.gravitino.rel.Table table = + com.datastrato.gravitino.rel.Table gravitinoTable = gravitinoCatalogClient .asTableCatalog() .alterTable( NameIdentifier.of(metalakeName, catalogName, getDatabase(ident), ident.name()), gravitinoTableChanges); + Table sparkTable = sparkCatalog.loadTable(ident); return createSparkTable( - ident, table, sparkCatalog, propertiesConverter, sparkTransformConverter); + ident, + gravitinoTable, + sparkTable, + sparkCatalog, + propertiesConverter, + sparkTransformConverter); } catch (com.datastrato.gravitino.exceptions.NoSuchTableException e) { throw new NoSuchTableException(ident); } diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/hive/GravitinoHiveCatalog.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/hive/GravitinoHiveCatalog.java index 6ffca1ff9f4..a1cefdaf3a9 100644 --- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/hive/GravitinoHiveCatalog.java +++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/hive/GravitinoHiveCatalog.java @@ -10,7 +10,6 @@ import com.datastrato.gravitino.spark.connector.PropertiesConverter; import com.datastrato.gravitino.spark.connector.SparkTransformConverter; import com.datastrato.gravitino.spark.connector.catalog.BaseCatalog; -import com.datastrato.gravitino.spark.connector.table.SparkBaseTable; import com.google.common.base.Preconditions; import java.util.HashMap; import java.util.Map; @@ -42,14 +41,20 @@ protected TableCatalog createAndInitSparkCatalog( } @Override - protected SparkBaseTable createSparkTable( + protected org.apache.spark.sql.connector.catalog.Table createSparkTable( Identifier identifier, Table gravitinoTable, - TableCatalog sparkCatalog, + org.apache.spark.sql.connector.catalog.Table sparkTable, + TableCatalog sparkHiveCatalog, PropertiesConverter propertiesConverter, SparkTransformConverter sparkTransformConverter) { return new SparkHiveTable( - identifier, gravitinoTable, sparkCatalog, propertiesConverter, sparkTransformConverter); + identifier, + gravitinoTable, + sparkTable, + sparkHiveCatalog, + propertiesConverter, + sparkTransformConverter); } @Override diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/hive/SparkHiveTable.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/hive/SparkHiveTable.java index 91f9468178b..ac656e4e639 100644 --- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/hive/SparkHiveTable.java +++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/hive/SparkHiveTable.java @@ -8,23 +8,61 @@ import com.datastrato.gravitino.rel.Table; import com.datastrato.gravitino.spark.connector.PropertiesConverter; import com.datastrato.gravitino.spark.connector.SparkTransformConverter; -import com.datastrato.gravitino.spark.connector.table.SparkBaseTable; +import com.datastrato.gravitino.spark.connector.utils.SparkBaseTableHelper; +import com.google.common.annotations.VisibleForTesting; +import java.util.Map; +import org.apache.kyuubi.spark.connector.hive.HiveTable; +import org.apache.kyuubi.spark.connector.hive.HiveTableCatalog; +import org.apache.spark.sql.SparkSession; import 
org.apache.spark.sql.connector.catalog.Identifier; import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; + +/** Keep consistent behavior with the SparkIcebergTable */ +public class SparkHiveTable extends HiveTable { + + private SparkBaseTableHelper sparkBaseTableHelper; -/** May support more capabilities like partition management. */ -public class SparkHiveTable extends SparkBaseTable { public SparkHiveTable( Identifier identifier, Table gravitinoTable, - TableCatalog sparkCatalog, + org.apache.spark.sql.connector.catalog.Table sparkHiveTable, + TableCatalog sparkHiveCatalog, PropertiesConverter propertiesConverter, SparkTransformConverter sparkTransformConverter) { - super(identifier, gravitinoTable, sparkCatalog, propertiesConverter, sparkTransformConverter); + super( + SparkSession.active(), + ((HiveTable) sparkHiveTable).catalogTable(), + (HiveTableCatalog) sparkHiveCatalog); + this.sparkBaseTableHelper = + new SparkBaseTableHelper( + identifier, gravitinoTable, propertiesConverter, sparkTransformConverter); } @Override - protected boolean isCaseSensitive() { - return false; + public String name() { + return sparkBaseTableHelper.name(false); + } + + @Override + @SuppressWarnings("deprecation") + public StructType schema() { + return sparkBaseTableHelper.schema(); + } + + @Override + public Map properties() { + return sparkBaseTableHelper.properties(); + } + + @Override + public Transform[] partitioning() { + return sparkBaseTableHelper.partitioning(); + } + + @VisibleForTesting + public SparkTransformConverter getSparkTransformConverter() { + return sparkBaseTableHelper.getSparkTransformConverter(); } } diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/GravitinoIcebergCatalog.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/GravitinoIcebergCatalog.java index f7a028cad7a..5355dbc3dfd 100644 --- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/GravitinoIcebergCatalog.java +++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/GravitinoIcebergCatalog.java @@ -9,7 +9,6 @@ import com.datastrato.gravitino.spark.connector.PropertiesConverter; import com.datastrato.gravitino.spark.connector.SparkTransformConverter; import com.datastrato.gravitino.spark.connector.catalog.BaseCatalog; -import com.datastrato.gravitino.spark.connector.table.SparkBaseTable; import com.google.common.base.Preconditions; import java.util.HashMap; import java.util.Locale; @@ -66,14 +65,20 @@ protected TableCatalog createAndInitSparkCatalog( } @Override - protected SparkBaseTable createSparkTable( + protected org.apache.spark.sql.connector.catalog.Table createSparkTable( Identifier identifier, Table gravitinoTable, - TableCatalog sparkCatalog, + org.apache.spark.sql.connector.catalog.Table sparkTable, + TableCatalog sparkIcebergCatalog, PropertiesConverter propertiesConverter, SparkTransformConverter sparkTransformConverter) { return new SparkIcebergTable( - identifier, gravitinoTable, sparkCatalog, propertiesConverter, sparkTransformConverter); + identifier, + gravitinoTable, + sparkTable, + sparkIcebergCatalog, + propertiesConverter, + sparkTransformConverter); } @Override diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/SparkIcebergTable.java 
b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/SparkIcebergTable.java index 22dd0bb73a8..5c040e45670 100644 --- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/SparkIcebergTable.java +++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/SparkIcebergTable.java @@ -8,43 +8,72 @@ import com.datastrato.gravitino.rel.Table; import com.datastrato.gravitino.spark.connector.PropertiesConverter; import com.datastrato.gravitino.spark.connector.SparkTransformConverter; -import com.datastrato.gravitino.spark.connector.table.SparkBaseTable; +import com.datastrato.gravitino.spark.connector.utils.SparkBaseTableHelper; +import com.google.common.annotations.VisibleForTesting; +import java.lang.reflect.Field; +import java.util.Map; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.source.SparkTable; import org.apache.spark.sql.connector.catalog.Identifier; -import org.apache.spark.sql.connector.catalog.MetadataColumn; -import org.apache.spark.sql.connector.catalog.SupportsDelete; -import org.apache.spark.sql.connector.catalog.SupportsMetadataColumns; import org.apache.spark.sql.connector.catalog.TableCatalog; -import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; -public class SparkIcebergTable extends SparkBaseTable - implements SupportsDelete, SupportsMetadataColumns { +/** + * The Iceberg Spark runtime explicitly uses SparkTable to identify whether a table is an + * Iceberg table, so SparkIcebergTable must extend SparkTable. + */ +public class SparkIcebergTable extends SparkTable { + + private SparkBaseTableHelper sparkBaseTableHelper; public SparkIcebergTable( Identifier identifier, Table gravitinoTable, + org.apache.spark.sql.connector.catalog.Table sparkIcebergTable, TableCatalog sparkIcebergCatalog, PropertiesConverter propertiesConverter, SparkTransformConverter sparkTransformConverter) { - super( - identifier, - gravitinoTable, - sparkIcebergCatalog, - propertiesConverter, - sparkTransformConverter); + super(((SparkTable) sparkIcebergTable).table(), !isCacheEnabled(sparkIcebergCatalog)); + this.sparkBaseTableHelper = + new SparkBaseTableHelper( + identifier, gravitinoTable, propertiesConverter, sparkTransformConverter); + } + + @Override + public String name() { + return sparkBaseTableHelper.name(true); } @Override - public boolean canDeleteWhere(Filter[] filters) { - return ((SupportsDelete) getSparkTable()).canDeleteWhere(filters); + @SuppressWarnings("deprecation") + public StructType schema() { + return sparkBaseTableHelper.schema(); } @Override - public void deleteWhere(Filter[] filters) { - ((SupportsDelete) getSparkTable()).deleteWhere(filters); + public Map<String, String> properties() { + return sparkBaseTableHelper.properties(); } @Override - public MetadataColumn[] metadataColumns() { - return ((SupportsMetadataColumns) getSparkTable()).metadataColumns(); + public Transform[] partitioning() { + return sparkBaseTableHelper.partitioning(); + } + + @VisibleForTesting + public SparkTransformConverter getSparkTransformConverter() { + return sparkBaseTableHelper.getSparkTransformConverter(); + } + + private static boolean isCacheEnabled(TableCatalog sparkIcebergCatalog) { + try { + SparkCatalog catalog = ((SparkCatalog) sparkIcebergCatalog); + Field cacheEnabled = catalog.getClass().getDeclaredField("cacheEnabled"); + 
cacheEnabled.setAccessible(true); + return cacheEnabled.getBoolean(catalog); + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException("Failed to get cacheEnabled field from SparkCatalog", e); + } } } diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/plugin/GravitinoDriverPlugin.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/plugin/GravitinoDriverPlugin.java index 3f830de2cdc..201666cc004 100644 --- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/plugin/GravitinoDriverPlugin.java +++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/plugin/GravitinoDriverPlugin.java @@ -5,6 +5,8 @@ package com.datastrato.gravitino.spark.connector.plugin; +import static com.datastrato.gravitino.spark.connector.utils.ConnectorUtil.removeDuplicates; + import com.datastrato.gravitino.Catalog; import com.datastrato.gravitino.spark.connector.GravitinoSparkConfig; import com.datastrato.gravitino.spark.connector.catalog.GravitinoCatalogManager; @@ -15,10 +17,12 @@ import java.util.Locale; import java.util.Map; import org.apache.commons.lang3.StringUtils; +import org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions; import org.apache.spark.SparkConf; import org.apache.spark.SparkContext; import org.apache.spark.api.plugin.DriverPlugin; import org.apache.spark.api.plugin.PluginContext; +import org.apache.spark.sql.internal.StaticSQLConf; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -30,6 +34,8 @@ public class GravitinoDriverPlugin implements DriverPlugin { private static final Logger LOG = LoggerFactory.getLogger(GravitinoDriverPlugin.class); private GravitinoCatalogManager catalogManager; + private static final String[] GRAVITINO_DRIVER_EXTENSIONS = + new String[] {IcebergSparkSessionExtensions.class.getName()}; @Override public Map init(SparkContext sc, PluginContext pluginContext) { @@ -48,7 +54,7 @@ public Map init(SparkContext sc, PluginContext pluginContext) { catalogManager = GravitinoCatalogManager.create(gravitinoUri, metalake); catalogManager.loadRelationalCatalogs(); registerGravitinoCatalogs(conf, catalogManager.getCatalogs()); - registerSqlExtensions(); + registerSqlExtensions(conf); return Collections.emptyMap(); } @@ -103,6 +109,15 @@ private void registerCatalog(SparkConf sparkConf, String catalogName, String pro LOG.info("Register {} catalog to Spark catalog manager.", catalogName); } - // Todo inject Iceberg extensions - private void registerSqlExtensions() {} + private void registerSqlExtensions(SparkConf conf) { + String gravitinoDriverExtensions = String.join(",", GRAVITINO_DRIVER_EXTENSIONS); + if (conf.contains(StaticSQLConf.SPARK_SESSION_EXTENSIONS().key())) { + String sparkSessionExtensions = conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS().key()); + conf.set( + StaticSQLConf.SPARK_SESSION_EXTENSIONS().key(), + removeDuplicates(GRAVITINO_DRIVER_EXTENSIONS, sparkSessionExtensions)); + } else { + conf.set(StaticSQLConf.SPARK_SESSION_EXTENSIONS().key(), gravitinoDriverExtensions); + } + } } diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/utils/ConnectorUtil.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/utils/ConnectorUtil.java new file mode 100644 index 00000000000..eeaa56c9da2 --- /dev/null +++ 
b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/utils/ConnectorUtil.java @@ -0,0 +1,26 @@ +/* + * Copyright 2024 Datastrato Pvt Ltd. + * This software is licensed under the Apache License version 2. + */ + +package com.datastrato.gravitino.spark.connector.utils; + +import static com.datastrato.gravitino.spark.connector.ConnectorConstants.COMMA; + +import java.util.Arrays; +import java.util.LinkedHashSet; +import java.util.Set; +import org.apache.commons.lang3.StringUtils; + +public class ConnectorUtil { + + public static String removeDuplicates(String[] elements, String otherElements) { + Set<String> uniqueElements = new LinkedHashSet<>(Arrays.asList(elements)); + if (StringUtils.isNotBlank(otherElements)) { + uniqueElements.addAll(Arrays.asList(otherElements.split(COMMA))); + } + return uniqueElements.stream() + .reduce((element1, element2) -> element1 + COMMA + element2) + .orElse(""); + } +} diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/table/SparkBaseTable.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/utils/SparkBaseTableHelper.java similarity index 67% rename from spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/table/SparkBaseTable.java rename to spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/utils/SparkBaseTableHelper.java index d1333135f19..0011968bd35 100644 --- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/table/SparkBaseTable.java +++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/utils/SparkBaseTableHelper.java @@ -3,7 +3,7 @@ * This software is licensed under the Apache License version 2. 
*/ -package com.datastrato.gravitino.spark.connector.table; +package com.datastrato.gravitino.spark.connector.utils; import com.datastrato.gravitino.rel.expressions.distributions.Distribution; import com.datastrato.gravitino.rel.expressions.sorts.SortOrder; @@ -11,65 +11,47 @@ import com.datastrato.gravitino.spark.connector.PropertiesConverter; import com.datastrato.gravitino.spark.connector.SparkTransformConverter; import com.datastrato.gravitino.spark.connector.SparkTypeConverter; -import com.google.common.annotations.VisibleForTesting; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Set; import java.util.stream.Collectors; -import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; import org.apache.spark.sql.connector.catalog.Identifier; -import org.apache.spark.sql.connector.catalog.SupportsRead; -import org.apache.spark.sql.connector.catalog.SupportsWrite; -import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableCapability; -import org.apache.spark.sql.connector.catalog.TableCatalog; import org.apache.spark.sql.connector.expressions.Transform; -import org.apache.spark.sql.connector.read.ScanBuilder; -import org.apache.spark.sql.connector.write.LogicalWriteInfo; -import org.apache.spark.sql.connector.write.WriteBuilder; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.MetadataBuilder; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** * Provides schema info from Gravitino, IO from the internal spark table. The specific catalog table * could implement more capabilities like SupportsPartitionManagement for Hive table, SupportsIndex * for JDBC table, SupportsRowLevelOperations for Iceberg table. 
*/ -public abstract class SparkBaseTable implements Table, SupportsRead, SupportsWrite { +public class SparkBaseTableHelper { + private Identifier identifier; private com.datastrato.gravitino.rel.Table gravitinoTable; - private TableCatalog sparkCatalog; - private Table lazySparkTable; private PropertiesConverter propertiesConverter; private SparkTransformConverter sparkTransformConverter; - public SparkBaseTable( + public SparkBaseTableHelper( Identifier identifier, com.datastrato.gravitino.rel.Table gravitinoTable, - TableCatalog sparkCatalog, PropertiesConverter propertiesConverter, SparkTransformConverter sparkTransformConverter) { this.identifier = identifier; this.gravitinoTable = gravitinoTable; - this.sparkCatalog = sparkCatalog; this.propertiesConverter = propertiesConverter; this.sparkTransformConverter = sparkTransformConverter; } - @Override - public String name() { - return getNormalizedIdentifier(identifier, gravitinoTable.name()); + public String name(boolean isCaseSensitive) { + return getNormalizedIdentifier(identifier, gravitinoTable.name(), isCaseSensitive); } - @Override - @SuppressWarnings("deprecation") public StructType schema() { List<StructField> structs = Arrays.stream(gravitinoTable.columns()) @@ -93,7 +75,6 @@ public StructType schema() { return DataTypes.createStructType(structs); } - @Override public Map<String, String> properties() { Map<String, String> properties = new HashMap<>(); if (gravitinoTable.properties() != null) { @@ -110,22 +91,6 @@ public Map<String, String> properties() { return properties; } - @Override - public Set<TableCapability> capabilities() { - return getSparkTable().capabilities(); - } - - @Override - public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { - return ((SupportsRead) getSparkTable()).newScanBuilder(options); - } - - @Override - public WriteBuilder newWriteBuilder(LogicalWriteInfo info) { - return ((SupportsWrite) getSparkTable()).newWriteBuilder(info); - } - - @Override public Transform[] partitioning() { com.datastrato.gravitino.rel.expressions.transforms.Transform[] partitions = gravitinoTable.partitioning(); @@ -134,35 +99,20 @@ public Transform[] partitioning() { return sparkTransformConverter.toSparkTransform(partitions, distribution, sortOrders); } - protected Table getSparkTable() { - if (lazySparkTable == null) { - try { - this.lazySparkTable = sparkCatalog.loadTable(identifier); - } catch (NoSuchTableException e) { - throw new RuntimeException(e); - } - } - return lazySparkTable; - } - - @VisibleForTesting public SparkTransformConverter getSparkTransformConverter() { return sparkTransformConverter; } - protected boolean isCaseSensitive() { - return true; - } - // The underlying catalogs may not be case-sensitive; to keep consistent with the behavior of Spark SQL, // we should return normalized identifiers. 
- private String getNormalizedIdentifier(Identifier tableIdentifier, String gravitinoTableName) { + private String getNormalizedIdentifier( + Identifier tableIdentifier, String gravitinoTableName, boolean isCaseSensitive) { if (tableIdentifier.namespace().length == 0) { return gravitinoTableName; } String databaseName = tableIdentifier.namespace()[0]; - if (isCaseSensitive() == false) { + if (!isCaseSensitive) { databaseName = databaseName.toLowerCase(Locale.ROOT); } diff --git a/spark-connector/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/utils/TestConnectorUtil.java b/spark-connector/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/utils/TestConnectorUtil.java new file mode 100644 index 00000000000..81d452d28e8 --- /dev/null +++ b/spark-connector/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/utils/TestConnectorUtil.java @@ -0,0 +1,31 @@ +/* + * Copyright 2024 Datastrato Pvt Ltd. + * This software is licensed under the Apache License version 2. + */ + +package com.datastrato.gravitino.spark.connector.utils; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public class TestConnectorUtil { + + @Test + void testRemoveDuplicates() { + String[] elements = {"a", "b", "c"}; + String otherElements = "a,d,e"; + String result = ConnectorUtil.removeDuplicates(elements, otherElements); + Assertions.assertEquals("a,b,c,d,e", result); + + elements = new String[] {"a", "a", "b", "c"}; + otherElements = ""; + result = ConnectorUtil.removeDuplicates(elements, otherElements); + Assertions.assertEquals("a,b,c", result); + + elements = new String[] {"a", "a", "b", "c"}; + result = ConnectorUtil.removeDuplicates(elements, null); + Assertions.assertEquals("a,b,c", result); + } +}
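For reference, a minimal usage sketch (not part of this patch) of the row-level operations the change enables: it assumes a Spark session on which the Gravitino spark-connector plugin is already configured, plus a Gravitino-managed Iceberg catalog named "iceberg" with a hypothetical table db.employees(id INT, name STRING, age INT). Catalog, schema, and table names are illustrative only.

// Illustrative sketch; "iceberg", "db", and "employees" are hypothetical names.
import org.apache.spark.sql.SparkSession;

public class RowLevelOpsExample {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder()
            .appName("gravitino-iceberg-row-level-ops")
            // GravitinoDriverPlugin#registerSqlExtensions appends this extension automatically;
            // it is repeated here only to make the sketch self-contained.
            .config(
                "spark.sql.extensions",
                "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
            .getOrCreate();

    // DELETE and UPDATE are rewritten into Iceberg row-level plans, using copy-on-write or
    // merge-on-read depending on the table's write.delete.mode / write.update.mode properties.
    spark.sql("DELETE FROM iceberg.db.employees WHERE id <= 2");
    spark.sql("UPDATE iceberg.db.employees SET age = age + 1 WHERE id = 5");

    // MERGE INTO covers the update-by-merge and delete-by-merge paths exercised in SparkCommonIT.
    spark.sql(
        "MERGE INTO iceberg.db.employees t "
            + "USING (SELECT 6 AS id, '6' AS name, 6 AS age) s "
            + "ON t.id = s.id "
            + "WHEN MATCHED THEN UPDATE SET * "
            + "WHEN NOT MATCHED THEN INSERT *");

    spark.stop();
  }
}

These statements only work once SparkIcebergTable extends Iceberg's SparkTable (so SupportsRowLevelOperations is inherited) and IcebergSparkSessionExtensions is injected by the driver plugin, which is exactly what the patch wires up.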