From d404224ed8026b373d3cfe975141fecd946cc2d3 Mon Sep 17 00:00:00 2001 From: fanng Date: Sun, 18 Feb 2024 19:40:50 +0800 Subject: [PATCH] add partition --- gradle/libs.versions.toml | 1 + .../integration/test/spark/SparkCommonIT.java | 149 +++++++++-- .../integration/test/spark/SparkEnvIT.java | 32 ++- .../test/spark/hive/SparkHiveCatalogIT.java | 37 +++ .../test/util/spark/SparkTableInfo.java | 56 +++++ .../util/spark/SparkTableInfoChecker.java | 42 ++++ spark-connector/build.gradle.kts | 5 + .../spark/connector/ConnectorConstants.java | 7 + .../connector/SparkTransformConverter.java | 233 ++++++++++++++++++ .../connector/catalog/GravitinoCatalog.java | 14 +- .../spark/connector/table/SparkBaseTable.java | 13 + .../TestSparkTransformConverter.java | 229 +++++++++++++++++ 12 files changed, 790 insertions(+), 28 deletions(-) create mode 100644 spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/SparkTransformConverter.java create mode 100644 spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/TestSparkTransformConverter.java diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 71deaac89ef..acc44b37f8d 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -31,6 +31,7 @@ iceberg = '1.3.1' # 1.4.0 causes test to fail trino = '426' spark = "3.4.1" # 3.5.0 causes tests to fail scala-collection-compat = "2.7.0" +scala-java-compat = "1.0.2" sqlite-jdbc = "3.42.0.0" testng = "7.5.1" testcontainers = "1.19.0" diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkCommonIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkCommonIT.java index 6b735affd69..fe4df1129db 100644 --- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkCommonIT.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkCommonIT.java @@ -8,12 +8,14 @@ import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo.SparkColumnInfo; import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfoChecker; import com.google.common.collect.ImmutableMap; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import org.apache.hadoop.fs.Path; import org.apache.spark.sql.AnalysisException; import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; @@ -23,16 +25,10 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIf; import org.junit.platform.commons.util.StringUtils; public abstract class SparkCommonIT extends SparkEnvIT { - private static String getSelectAllSql(String tableName) { - return String.format("SELECT * FROM %s", tableName); - } - - private static String getInsertWithoutPartitionSql(String tableName, String values) { - return String.format("INSERT INTO %s VALUES (%s)", tableName, values); - } // To generate test data for write&read table. 
private static final Map typeConstant = @@ -51,8 +47,25 @@ private static String getInsertWithoutPartitionSql(String tableName, String valu DataTypes.createStructField("col2", DataTypes.StringType, true))), "struct(1, 'a')"); - // Use a custom database not the original default database because SparkCommonIT couldn't - // read&write data to tables in default database. The main reason is default database location is + private static String getSelectAllSql(String tableName) { + return String.format("SELECT * FROM %s", tableName); + } + + private static String getInsertWithoutPartitionSql(String tableName, String values) { + return String.format("INSERT INTO %s VALUES (%s)", tableName, values); + } + + private static String getInsertWithPartitionSql( + String tableName, String partitionString, String values) { + return String.format( + "INSERT OVERWRITE %s PARTITION (%s) VALUES (%s)", tableName, partitionString, values); + } + + // Whether supports [CLUSTERED BY col_name3 SORTED BY col_name INTO num_buckets BUCKETS] + protected abstract boolean supportsSparkSQLClusteredBy(); + + // Use a custom database not the original default database because SparkIT couldn't read&write + // data to tables in default database. The main reason is default database location is // determined by `hive.metastore.warehouse.dir` in hive-site.xml which is local HDFS address // not real HDFS address. The location of tables created under default database is like // hdfs://localhost:9000/xxx which couldn't read write data from SparkCommonIT. Will use default @@ -69,10 +82,6 @@ void init() { sql("USE " + getDefaultDatabase()); } - protected String getDefaultDatabase() { - return "default_db"; - } - @Test void testLoadCatalogs() { Set catalogs = getCatalogs(); @@ -442,24 +451,94 @@ void testComplexType() { checkTableReadWrite(tableInfo); } - private void checkTableColumns( - String tableName, List columnInfos, SparkTableInfo tableInfo) { - SparkTableInfoChecker.create() - .withName(tableName) - .withColumns(columnInfos) - .withComment(null) - .check(tableInfo); + @Test + void testCreateDatasourceFormatPartitionTable() { + String tableName = "datasource_partition_table"; + + dropTableIfExists(tableName); + String createTableSQL = getCreateSimpleTableString(tableName); + createTableSQL = createTableSQL + "USING PARQUET PARTITIONED BY (name, age)"; + sql(createTableSQL); + SparkTableInfo tableInfo = getTableInfo(tableName); + SparkTableInfoChecker checker = + SparkTableInfoChecker.create() + .withName(tableName) + .withColumns(getSimpleTableColumn()) + .withIdentifyPartition(Arrays.asList("name", "age")); + checker.check(tableInfo); + checkTableReadWrite(tableInfo); + checkPartitionDirExists(tableInfo); + } + + @Test + @EnabledIf("supportsSparkSQLClusteredBy") + void testCreateBucketTable() { + String tableName = "bucket_table"; + + dropTableIfExists(tableName); + String createTableSQL = getCreateSimpleTableString(tableName); + createTableSQL = createTableSQL + "CLUSTERED BY (id, name) INTO 4 buckets;"; + sql(createTableSQL); + SparkTableInfo tableInfo = getTableInfo(tableName); + SparkTableInfoChecker checker = + SparkTableInfoChecker.create() + .withName(tableName) + .withColumns(getSimpleTableColumn()) + .withBucket(4, Arrays.asList("id", "name")); + checker.check(tableInfo); + checkTableReadWrite(tableInfo); } - private void checkTableReadWrite(SparkTableInfo table) { + @Test + @EnabledIf("supportsSparkSQLClusteredBy") + void testCreateSortBucketTable() { + String tableName = "sort_bucket_table"; + + 
dropTableIfExists(tableName); + String createTableSQL = getCreateSimpleTableString(tableName); + createTableSQL = + createTableSQL + "CLUSTERED BY (id, name) SORTED BY (name, id) INTO 4 buckets;"; + sql(createTableSQL); + SparkTableInfo tableInfo = getTableInfo(tableName); + SparkTableInfoChecker checker = + SparkTableInfoChecker.create() + .withName(tableName) + .withColumns(getSimpleTableColumn()) + .withBucket(4, Arrays.asList("id", "name"), Arrays.asList("name", "id")); + checker.check(tableInfo); + checkTableReadWrite(tableInfo); + } + + protected void checkPartitionDirExists(SparkTableInfo table) { + Assertions.assertTrue(table.isPartitionTable(), "Not a partition table"); + String tableLocation = table.getTableLocation(); + String partitionExpression = getPartitionExpression(table, "/").replace("'", ""); + Path partitionPath = new Path(tableLocation, partitionExpression); + try { + Assertions.assertTrue( + hdfs.exists(partitionPath), "Partition directory not exists," + partitionPath); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected void checkTableReadWrite(SparkTableInfo table) { String name = table.getTableIdentifier(); + boolean isPartitionTable = table.isPartitionTable(); String insertValues = - table.getColumns().stream() + table.getUnPartitionedColumns().stream() .map(columnInfo -> typeConstant.get(columnInfo.getType())) .map(Object::toString) .collect(Collectors.joining(",")); - sql(getInsertWithoutPartitionSql(name, insertValues)); + String insertDataSQL = ""; + if (isPartitionTable) { + String partitionExpressions = getPartitionExpression(table, ","); + insertDataSQL = getInsertWithPartitionSql(name, partitionExpressions, insertValues); + } else { + insertDataSQL = getInsertWithoutPartitionSql(name, insertValues); + } + sql(insertDataSQL); // do something to match the query result: // 1. remove "'" from values, such as 'a' is trans to a @@ -514,23 +593,43 @@ private void checkTableReadWrite(SparkTableInfo table) { Assertions.assertEquals(checkValues, queryResult.get(0)); } - private String getCreateSimpleTableString(String tableName) { + protected String getCreateSimpleTableString(String tableName) { return String.format( "CREATE TABLE %s (id INT COMMENT 'id comment', name STRING COMMENT '', age INT)", tableName); } - private List getSimpleTableColumn() { + protected List getSimpleTableColumn() { return Arrays.asList( SparkColumnInfo.of("id", DataTypes.IntegerType, "id comment"), SparkColumnInfo.of("name", DataTypes.StringType, ""), SparkColumnInfo.of("age", DataTypes.IntegerType, null)); } + protected String getDefaultDatabase() { + return "default_db"; + } + // Helper method to create a simple table, and could use corresponding // getSimpleTableColumn to check table column. 
private void createSimpleTable(String identifier) { String createTableSql = getCreateSimpleTableString(identifier); sql(createTableSql); } + + private void checkTableColumns( + String tableName, List columns, SparkTableInfo tableInfo) { + SparkTableInfoChecker.create() + .withName(tableName) + .withColumns(columns) + .withComment(null) + .check(tableInfo); + } + + // partition expression may contain "'", like a='s'/b=1 + private String getPartitionExpression(SparkTableInfo table, String delimiter) { + return table.getPartitionedColumns().stream() + .map(column -> column.getName() + "=" + typeConstant.get(column.getType())) + .collect(Collectors.joining(delimiter)); + } } diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkEnvIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkEnvIT.java index b0b7fd895e6..52de8da4a67 100644 --- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkEnvIT.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkEnvIT.java @@ -14,8 +14,11 @@ import com.datastrato.gravitino.spark.connector.GravitinoSparkConfig; import com.datastrato.gravitino.spark.connector.plugin.GravitinoSparkPlugin; import com.google.common.collect.Maps; +import java.io.IOException; import java.util.Collections; import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; import org.apache.spark.sql.SparkSession; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assertions; @@ -28,11 +31,12 @@ public abstract class SparkEnvIT extends SparkUtilIT { private static final Logger LOG = LoggerFactory.getLogger(SparkEnvIT.class); private static final ContainerSuite containerSuite = ContainerSuite.getInstance(); + protected FileSystem hdfs; private final String metalakeName = "test"; private SparkSession sparkSession; - private String hiveMetastoreUri; - private String gravitinoUri; + private String hiveMetastoreUri = "thrift://127.0.0.1:9083"; + private String gravitinoUri = "http://127.0.0.1:8090"; protected abstract String getCatalogName(); @@ -47,6 +51,7 @@ protected SparkSession getSparkSession() { @BeforeAll void startUp() { initHiveEnv(); + initHdfsFileSystem(); initGravitinoEnv(); initMetalakeAndCatalogs(); initSparkEnv(); @@ -58,6 +63,13 @@ void startUp() { @AfterAll void stop() { + if (hdfs != null) { + try { + hdfs.close(); + } catch (IOException e) { + LOG.warn("Close HDFS filesystem failed,", e); + } + } if (sparkSession != null) { sparkSession.close(); } @@ -92,6 +104,22 @@ private void initHiveEnv() { HiveContainer.HIVE_METASTORE_PORT); } + private void initHdfsFileSystem() { + Configuration conf = new Configuration(); + conf.set( + "fs.defaultFS", + String.format( + "hdfs://%s:%d", + containerSuite.getHiveContainer().getContainerIpAddress(), + HiveContainer.HDFS_DEFAULTFS_PORT)); + try { + hdfs = FileSystem.get(conf); + } catch (IOException e) { + LOG.error("Create HDFS filesystem failed", e); + throw new RuntimeException(e); + } + } + private void initSparkEnv() { sparkSession = SparkSession.builder() diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/hive/SparkHiveCatalogIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/hive/SparkHiveCatalogIT.java index bce6cb212bf..ef477aeb2ff 100644 --- 
a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/hive/SparkHiveCatalogIT.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/hive/SparkHiveCatalogIT.java @@ -5,7 +5,15 @@ package com.datastrato.gravitino.integration.test.spark.hive; import com.datastrato.gravitino.integration.test.spark.SparkCommonIT; +import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo; +import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo.SparkColumnInfo; +import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfoChecker; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.spark.sql.types.DataTypes; import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; @Tag("gravitino-docker-it") @@ -21,4 +29,33 @@ protected String getCatalogName() { protected String getProvider() { return "hive"; } + + @Override + protected boolean supportsSparkSQLClusteredBy() { + return true; + } + + @Test + public void testCreateHiveFormatPartitionTable() { + String tableName = "hive_partition_table"; + + dropTableIfExists(tableName); + String createTableSQL = getCreateSimpleTableString(tableName); + createTableSQL = createTableSQL + "PARTITIONED BY (age_p1 INT, age_p2 STRING)"; + sql(createTableSQL); + + List columns = new ArrayList<>(getSimpleTableColumn()); + columns.add(SparkColumnInfo.of("age_p1", DataTypes.IntegerType)); + columns.add(SparkColumnInfo.of("age_p2", DataTypes.StringType)); + + SparkTableInfo tableInfo = getTableInfo(tableName); + SparkTableInfoChecker checker = + SparkTableInfoChecker.create() + .withName(tableName) + .withColumns(columns) + .withIdentifyPartition(Arrays.asList("age_p1", "age_p2")); + checker.check(tableInfo); + checkTableReadWrite(tableInfo); + checkPartitionDirExists(tableInfo); + } } diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java index 65e06c977c3..6f42cb810b2 100644 --- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java @@ -9,11 +9,18 @@ import com.datastrato.gravitino.spark.connector.table.SparkBaseTable; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; +import javax.ws.rs.NotSupportedException; import lombok.Data; import org.apache.commons.lang3.StringUtils; +import org.apache.spark.sql.connector.expressions.BucketTransform; +import org.apache.spark.sql.connector.expressions.IdentityTransform; +import org.apache.spark.sql.connector.expressions.SortedBucketTransform; +import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.types.DataType; import org.junit.jupiter.api.Assertions; @@ -26,6 +33,9 @@ public class SparkTableInfo { private List columns; private Map tableProperties; private List unknownItems = new ArrayList<>(); + private Transform bucket; + private List partitions = new ArrayList<>(); + private Set partitionColumnNames = new HashSet<>(); public SparkTableInfo() {} @@ -42,6 +52,28 @@ public String getTableIdentifier() { } } + public String getTableLocation() { + return 
tableProperties.get(ConnectorConstants.LOCATION); + } + + public boolean isPartitionTable() { + return partitions.size() > 0; + } + + void setBucket(Transform bucket) { + Assertions.assertNull(this.bucket, "Should be only one distribution"); + this.bucket = bucket; + } + + void addPartition(Transform partition) { + this.partitions.add(partition); + if (partition instanceof IdentityTransform) { + partitionColumnNames.add(((IdentityTransform) partition).reference().fieldNames()[0]); + } else { + throw new NotSupportedException(partition.name() + " is not supported yet."); + } + } + static SparkTableInfo create(SparkBaseTable baseTable) { SparkTableInfo sparkTableInfo = new SparkTableInfo(); String identifier = baseTable.name(); @@ -62,9 +94,33 @@ static SparkTableInfo create(SparkBaseTable baseTable) { .collect(Collectors.toList()); sparkTableInfo.comment = baseTable.properties().remove(ConnectorConstants.COMMENT); sparkTableInfo.tableProperties = baseTable.properties(); + Arrays.stream(baseTable.partitioning()) + .forEach( + transform -> { + if (transform instanceof BucketTransform + || transform instanceof SortedBucketTransform) { + sparkTableInfo.setBucket(transform); + } else if (transform instanceof IdentityTransform) { + sparkTableInfo.addPartition(transform); + } else { + throw new NotSupportedException("Unsupported Spark transform: " + transform.name()); + } + }); return sparkTableInfo; } + public List<SparkColumnInfo> getUnPartitionedColumns() { + return columns.stream() + .filter(column -> !partitionColumnNames.contains(column.name)) + .collect(Collectors.toList()); + } + + public List<SparkColumnInfo> getPartitionedColumns() { + return columns.stream() + .filter(column -> partitionColumnNames.contains(column.name)) + .collect(Collectors.toList()); + } + @Data public static class SparkColumnInfo { private String name; diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfoChecker.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfoChecker.java index e95730d1ae3..ca853c08e18 100644 --- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfoChecker.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfoChecker.java @@ -6,8 +6,12 @@ package com.datastrato.gravitino.integration.test.util.spark; import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo.SparkColumnInfo; +import com.datastrato.gravitino.spark.connector.SparkTransformConverter; import java.util.ArrayList; import java.util.List; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.IdentityTransform; +import org.apache.spark.sql.connector.expressions.Transform; import org.junit.jupiter.api.Assertions; /** @@ -27,6 +31,8 @@ public static SparkTableInfoChecker create() { private enum CheckField { NAME, COLUMN, + PARTITION, + BUCKET, COMMENT, } @@ -42,6 +48,35 @@ public SparkTableInfoChecker withColumns(List<SparkColumnInfo> columns) { return this; } + public SparkTableInfoChecker withIdentifyPartition(List<String> partitionColumns) { + partitionColumns.stream() + .forEach( + columnName -> { + IdentityTransform identityTransform = + SparkTransformConverter.createSparkIdentityTransform(columnName); + this.expectedTableInfo.addPartition(identityTransform); + }); + this.checkFields.add(CheckField.PARTITION); + return this; + } + + public SparkTableInfoChecker withBucket(int bucketNum, List<String> bucketColumns) { 
Transform bucketTransform = Expressions.bucket(bucketNum, bucketColumns.toArray(new String[0])); + this.expectedTableInfo.setBucket(bucketTransform); + this.checkFields.add(CheckField.BUCKET); + return this; + } + + public SparkTableInfoChecker withBucket( + int bucketNum, List bucketColumns, List sortColumns) { + Transform sortBucketTransform = + SparkTransformConverter.createSortBucketTransform( + bucketNum, bucketColumns.toArray(new String[0]), sortColumns.toArray(new String[0])); + this.expectedTableInfo.setBucket(sortBucketTransform); + this.checkFields.add(CheckField.BUCKET); + return this; + } + public SparkTableInfoChecker withComment(String comment) { this.expectedTableInfo.setComment(comment); this.checkFields.add(CheckField.COMMENT); @@ -61,6 +96,13 @@ public void check(SparkTableInfo realTableInfo) { Assertions.assertEquals( expectedTableInfo.getColumns(), realTableInfo.getColumns()); break; + case PARTITION: + Assertions.assertEquals( + expectedTableInfo.getPartitions(), realTableInfo.getPartitions()); + break; + case BUCKET: + Assertions.assertEquals(expectedTableInfo.getBucket(), realTableInfo.getBucket()); + break; case COMMENT: Assertions.assertEquals( expectedTableInfo.getComment(), realTableInfo.getComment()); diff --git a/spark-connector/build.gradle.kts b/spark-connector/build.gradle.kts index 245577f67de..1a03e73f34f 100644 --- a/spark-connector/build.gradle.kts +++ b/spark-connector/build.gradle.kts @@ -16,6 +16,7 @@ val scalaVersion: String = project.properties["scalaVersion"] as? String ?: extr val sparkVersion: String = libs.versions.spark.get() val icebergVersion: String = libs.versions.iceberg.get() val kyuubiVersion: String = libs.versions.kyuubi.get() +val scalaJava8CompatVersion: String = libs.versions.scala.java.compat.get() dependencies { implementation(project(":api")) @@ -27,6 +28,10 @@ dependencies { implementation("org.apache.kyuubi:kyuubi-spark-connector-hive_$scalaVersion:$kyuubiVersion") implementation("org.apache.spark:spark-catalyst_$scalaVersion:$sparkVersion") implementation("org.apache.spark:spark-sql_$scalaVersion:$sparkVersion") + implementation("org.scala-lang.modules:scala-java8-compat_$scalaVersion:$scalaJava8CompatVersion") + + annotationProcessor(libs.lombok) + compileOnly(libs.lombok) testImplementation(libs.junit.jupiter.api) testImplementation(libs.junit.jupiter.params) diff --git a/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/ConnectorConstants.java b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/ConnectorConstants.java index 40ae3b5c712..3a49a21470f 100644 --- a/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/ConnectorConstants.java +++ b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/ConnectorConstants.java @@ -5,8 +5,15 @@ package com.datastrato.gravitino.spark.connector; +import com.datastrato.gravitino.rel.expressions.sorts.SortDirection; + public class ConnectorConstants { public static final String COMMENT = "comment"; + public static final SortDirection SPARK_DEFAULT_SORT_DIRECTION = SortDirection.ASCENDING; + public static final String LOCATION = "location"; + + public static final String DOT = "."; + private ConnectorConstants() {} } diff --git a/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/SparkTransformConverter.java b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/SparkTransformConverter.java new file mode 100644 index 00000000000..7ea61fb13fe --- /dev/null +++ 
b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/SparkTransformConverter.java @@ -0,0 +1,233 @@ +/* + * Copyright 2024 Datastrato Pvt Ltd. + * This software is licensed under the Apache License version 2. + */ + +package com.datastrato.gravitino.spark.connector; + +import com.datastrato.gravitino.dto.rel.partitioning.Partitioning.SingleFieldPartitioning; +import com.datastrato.gravitino.rel.expressions.Expression; +import com.datastrato.gravitino.rel.expressions.NamedReference; +import com.datastrato.gravitino.rel.expressions.distributions.Distribution; +import com.datastrato.gravitino.rel.expressions.distributions.Distributions; +import com.datastrato.gravitino.rel.expressions.sorts.SortOrder; +import com.datastrato.gravitino.rel.expressions.sorts.SortOrders; +import com.datastrato.gravitino.rel.expressions.transforms.Transform; +import com.datastrato.gravitino.rel.expressions.transforms.Transforms; +import com.google.common.base.Preconditions; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import javax.ws.rs.NotSupportedException; +import lombok.Getter; +import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.spark.sql.connector.expressions.BucketTransform; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.IdentityTransform; +import org.apache.spark.sql.connector.expressions.LogicalExpressions; +import org.apache.spark.sql.connector.expressions.SortedBucketTransform; +import scala.collection.JavaConverters; + +/** + * SparkTransformConverter translates between Spark transforms and Gravitino partitioning, + * distribution, and sort orders. There may be multiple partition transforms, but there should be + * at most one bucket transform. A Spark bucket transform corresponds to a Gravitino Hash + * distribution without sort orders; a Spark sorted bucket transform corresponds to a Gravitino + * Hash distribution with sort orders. 
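+ * + * <p>For example (an illustrative mapping, mirroring the cases covered by + * TestSparkTransformConverter in this patch), a Spark DDL clause like + * {@code PARTITIONED BY (name) CLUSTERED BY (id) SORTED BY (name) INTO 4 BUCKETS} converts to + * {@code Transforms.identity("name")} plus a Hash distribution of 4 buckets on "id" with an + * ASCENDING sort order on "name", and the reverse conversion rebuilds the corresponding Spark + * transforms. 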
+ */ +public class SparkTransformConverter { + + public static class GravitinoTransformBundles { + private List partitions; + @Getter private Distribution distribution; + @Getter private SortOrder[] sortOrders; + + public Transform[] getPartitions() { + if (partitions == null) { + return null; + } + return partitions.toArray(new Transform[0]); + } + + private void addPartition(Transform partition) { + if (partitions == null) { + this.partitions = new ArrayList<>(); + } + partitions.add(partition); + } + + private void setDistribution(Distribution distribution) { + Preconditions.checkState(this.distribution == null, "Should only set distribution once"); + this.distribution = distribution; + } + + private void setSortOrders(SortOrder[] sortOrders) { + Preconditions.checkState(this.sortOrders == null, "Should only set sort orders once"); + this.sortOrders = sortOrders; + } + } + + public static GravitinoTransformBundles toGravitinoTransform( + org.apache.spark.sql.connector.expressions.Transform[] transforms) { + GravitinoTransformBundles bundles = new GravitinoTransformBundles(); + if (ArrayUtils.isEmpty(transforms)) { + return bundles; + } + + Arrays.stream(transforms) + .forEach( + transform -> { + if (transform instanceof IdentityTransform) { + IdentityTransform identityTransform = (IdentityTransform) transform; + bundles.addPartition( + Transforms.identity(identityTransform.reference().fieldNames())); + } else if (transform instanceof SortedBucketTransform) { + Pair pair = + toGravitinoDistributionAndSortOrders((SortedBucketTransform) transform); + bundles.setDistribution(pair.getLeft()); + bundles.setSortOrders(pair.getRight()); + } else if (transform instanceof BucketTransform) { + BucketTransform bucketTransform = (BucketTransform) transform; + Distribution distribution = toGravitinoDistribution(bucketTransform); + bundles.setDistribution(distribution); + } else { + throw new NotSupportedException("Not support Spark transform: " + transform.name()); + } + }); + + return bundles; + } + + public static org.apache.spark.sql.connector.expressions.Transform[] toSparkTransform( + com.datastrato.gravitino.rel.expressions.transforms.Transform[] partitions, + Distribution distribution, + SortOrder[] sortOrder) { + List sparkTransforms = new ArrayList<>(); + if (ArrayUtils.isNotEmpty(partitions)) { + Arrays.stream(partitions) + .forEach( + transform -> { + SingleFieldPartitioning identityTransform = (SingleFieldPartitioning) transform; + String[] fieldName = identityTransform.fieldName(); + switch (identityTransform.strategy()) { + case IDENTITY: + sparkTransforms.add( + createSparkIdentityTransform( + String.join(ConnectorConstants.DOT, fieldName))); + break; + default: + throw new UnsupportedOperationException( + "Not support gravitino partition: " + + transform.name() + + ", className: " + + transform.getClass().getName()); + } + }); + } + + org.apache.spark.sql.connector.expressions.Transform bucketTransform = + toSparkBucketTransform(distribution, sortOrder); + if (bucketTransform != null) { + sparkTransforms.add(bucketTransform); + } + + return sparkTransforms.toArray(new org.apache.spark.sql.connector.expressions.Transform[0]); + } + + private static Distribution toGravitinoDistribution(BucketTransform bucketTransform) { + int bucketNum = (Integer) bucketTransform.numBuckets().value(); + Expression[] expressions = + JavaConverters.seqAsJavaList(bucketTransform.columns()).stream() + .map(sparkReference -> NamedReference.field(sparkReference.fieldNames())) + .toArray(Expression[]::new); + 
return Distributions.hash(bucketNum, expressions); + } + + // Spark DataSourceV2 doesn't support specifying the sort order direction, so ASCENDING is used as the default. + private static Pair<Distribution, SortOrder[]> toGravitinoDistributionAndSortOrders( + SortedBucketTransform sortedBucketTransform) { + int bucketNum = (Integer) sortedBucketTransform.numBuckets().value(); + Expression[] bucketColumns = + transToGravitinoNamedReference( + JavaConverters.seqAsJavaList(sortedBucketTransform.columns())); + + Expression[] sortColumns = + transToGravitinoNamedReference( + JavaConverters.seqAsJavaList(sortedBucketTransform.sortedColumns())); + SortOrder[] sortOrders = + Arrays.stream(sortColumns) + .map( + sortColumn -> + SortOrders.of(sortColumn, ConnectorConstants.SPARK_DEFAULT_SORT_DIRECTION)) + .toArray(SortOrder[]::new); + + return Pair.of(Distributions.hash(bucketNum, bucketColumns), sortOrders); + } + + private static org.apache.spark.sql.connector.expressions.Transform toSparkBucketTransform( + Distribution distribution, SortOrder[] sortOrders) { + if (distribution == null) { + return null; + } + + switch (distribution.strategy()) { + case NONE: + return null; + case HASH: + int bucketNum = distribution.number(); + String[] bucketFields = + Arrays.stream(distribution.expressions()) + .map( + expression -> + getFieldNameFromGravitinoNamedReference((NamedReference) expression)) + .toArray(String[]::new); + if (sortOrders == null || sortOrders.length == 0) { + return Expressions.bucket(bucketNum, bucketFields); + } else { + String[] sortOrderFields = + Arrays.stream(sortOrders) + .map( + sortOrder -> + getFieldNameFromGravitinoNamedReference( + (NamedReference) sortOrder.expression())) + .toArray(String[]::new); + return createSortBucketTransform(bucketNum, bucketFields, sortOrderFields); + } + // Spark doesn't support EVEN or RANGE distribution + default: + throw new NotSupportedException( + "Unsupported distribution strategy: " + distribution.strategy()); + } + } + + private static Expression[] transToGravitinoNamedReference( + List<org.apache.spark.sql.connector.expressions.NamedReference> sparkNamedReferences) { + return sparkNamedReferences.stream() + .map(sparkReference -> NamedReference.field(sparkReference.fieldNames())) + .toArray(Expression[]::new); + } + + public static org.apache.spark.sql.connector.expressions.Transform createSortBucketTransform( + int bucketNum, String[] bucketFields, String[] sortFields) { + return LogicalExpressions.bucket( + bucketNum, createSparkNamedReference(bucketFields), createSparkNamedReference(sortFields)); + } + + // columnName could be "a" or "a.b" for a nested column + public static IdentityTransform createSparkIdentityTransform(String columnName) { + return IdentityTransform.apply(Expressions.column(columnName)); + } + + private static org.apache.spark.sql.connector.expressions.NamedReference[] + createSparkNamedReference(String[] fields) { + return Arrays.stream(fields) + .map(Expressions::column) + .toArray(org.apache.spark.sql.connector.expressions.NamedReference[]::new); + } + + // Gravitino uses ["a","b"] for nested fields while Spark uses "a.b". + private static String getFieldNameFromGravitinoNamedReference( + NamedReference gravitinoNamedReference) { + return String.join(ConnectorConstants.DOT, gravitinoNamedReference.fieldName()); + } +} diff --git a/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/GravitinoCatalog.java b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/GravitinoCatalog.java index 0449a6e8c82..9b21c00df15 100644 --- 
a/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/GravitinoCatalog.java +++ b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/GravitinoCatalog.java @@ -17,6 +17,8 @@ import com.datastrato.gravitino.spark.connector.GravitinoCatalogAdaptor; import com.datastrato.gravitino.spark.connector.GravitinoCatalogAdaptorFactory; import com.datastrato.gravitino.spark.connector.PropertiesConverter; +import com.datastrato.gravitino.spark.connector.SparkTransformConverter; +import com.datastrato.gravitino.spark.connector.SparkTransformConverter.GravitinoTransformBundles; import com.datastrato.gravitino.spark.connector.SparkTypeConverter; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; @@ -127,11 +129,21 @@ public Table createTable( // Spark store comment in properties, we should retrieve it and pass to Gravitino explicitly. String comment = gravitinoProperties.remove(ConnectorConstants.COMMENT); + GravitinoTransformBundles gravitinoTransformContext = + SparkTransformConverter.toGravitinoTransform(partitions); + try { com.datastrato.gravitino.rel.Table table = gravitinoCatalogClient .asTableCatalog() - .createTable(gravitinoIdentifier, gravitinoColumns, comment, gravitinoProperties); + .createTable( + gravitinoIdentifier, + gravitinoColumns, + comment, + gravitinoProperties, + gravitinoTransformContext.getPartitions(), + gravitinoTransformContext.getDistribution(), + gravitinoTransformContext.getSortOrders()); return gravitinoAdaptor.createSparkTable(ident, table, sparkCatalog, propertiesConverter); } catch (NoSuchSchemaException e) { throw new NoSuchNamespaceException(ident.namespace()); diff --git a/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/table/SparkBaseTable.java b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/table/SparkBaseTable.java index b6ae81e4d41..0d057656e86 100644 --- a/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/table/SparkBaseTable.java +++ b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/table/SparkBaseTable.java @@ -5,8 +5,11 @@ package com.datastrato.gravitino.spark.connector.table; +import com.datastrato.gravitino.rel.expressions.distributions.Distribution; +import com.datastrato.gravitino.rel.expressions.sorts.SortOrder; import com.datastrato.gravitino.spark.connector.ConnectorConstants; import com.datastrato.gravitino.spark.connector.PropertiesConverter; +import com.datastrato.gravitino.spark.connector.SparkTransformConverter; import com.datastrato.gravitino.spark.connector.SparkTypeConverter; import java.util.Arrays; import java.util.HashMap; @@ -22,6 +25,7 @@ import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableCapability; import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.connector.write.LogicalWriteInfo; import org.apache.spark.sql.connector.write.WriteBuilder; @@ -117,6 +121,15 @@ public WriteBuilder newWriteBuilder(LogicalWriteInfo info) { return ((SupportsWrite) getSparkTable()).newWriteBuilder(info); } + @Override + public Transform[] partitioning() { + com.datastrato.gravitino.rel.expressions.transforms.Transform[] partitions = + gravitinoTable.partitioning(); + Distribution distribution = gravitinoTable.distribution(); + SortOrder[] sortOrders = 
gravitinoTable.sortOrder(); + return SparkTransformConverter.toSparkTransform(partitions, distribution, sortOrders); + } + protected Table getSparkTable() { if (lazySparkTable == null) { try { diff --git a/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/TestSparkTransformConverter.java b/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/TestSparkTransformConverter.java new file mode 100644 index 00000000000..3fc9b7f951e --- /dev/null +++ b/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/TestSparkTransformConverter.java @@ -0,0 +1,229 @@ +/* + * Copyright 2024 Datastrato Pvt Ltd. + * This software is licensed under the Apache License version 2. + */ + +package com.datastrato.gravitino.spark.connector; + +import com.datastrato.gravitino.dto.rel.partitioning.IdentityPartitioningDTO; +import com.datastrato.gravitino.rel.expressions.NamedReference; +import com.datastrato.gravitino.rel.expressions.distributions.Distribution; +import com.datastrato.gravitino.rel.expressions.distributions.Distributions; +import com.datastrato.gravitino.rel.expressions.sorts.SortDirection; +import com.datastrato.gravitino.rel.expressions.sorts.SortOrder; +import com.datastrato.gravitino.rel.expressions.sorts.SortOrders; +import com.datastrato.gravitino.rel.expressions.transforms.Transform; +import com.datastrato.gravitino.rel.expressions.transforms.Transforms; +import com.datastrato.gravitino.spark.connector.SparkTransformConverter.GravitinoTransformBundles; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import javax.ws.rs.NotSupportedException; +import org.apache.spark.sql.connector.expressions.BucketTransform; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.LogicalExpressions; +import org.apache.spark.sql.connector.expressions.SortedBucketTransform; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.TestInstance.Lifecycle; +import scala.collection.JavaConverters; + +@TestInstance(Lifecycle.PER_CLASS) +public class TestSparkTransformConverter { + private Map + sparkToGravitinoPartitionTransformMaps = new HashMap<>(); + + private Map + gravitinoToSparkPartitionTransformMaps = new HashMap<>(); + + @BeforeAll + void init() { + initSparkToGravitinoTransformMap(); + initGravitinoToSparkTransformMap(); + } + + @Test + void testPartition() { + sparkToGravitinoPartitionTransformMaps.forEach( + (sparkTransform, gravitinoTransform) -> { + GravitinoTransformBundles bundles = + SparkTransformConverter.toGravitinoTransform( + new org.apache.spark.sql.connector.expressions.Transform[] {sparkTransform}); + Assertions.assertNull(bundles.getDistribution()); + Assertions.assertNull(bundles.getSortOrders()); + Transform[] gravitinoPartitions = bundles.getPartitions(); + Assertions.assertTrue(gravitinoPartitions != null && gravitinoPartitions.length == 1); + Assertions.assertEquals(gravitinoTransform, gravitinoPartitions[0]); + }); + + gravitinoToSparkPartitionTransformMaps.forEach( + (gravitinoTransform, sparkTransform) -> { + org.apache.spark.sql.connector.expressions.Transform[] sparkTransforms = + SparkTransformConverter.toSparkTransform( + new Transform[] {gravitinoTransform}, null, null); + Assertions.assertTrue(sparkTransforms.length == 1); + 
Assertions.assertEquals(sparkTransform, sparkTransforms[0]); + }); + } + + @Test + void testGravitinoToSparkDistributionWithoutSortOrder() { + int bucketNum = 16; + String[][] columnNames = createGravitinoFieldReferenceNames("a", "b.c"); + Distribution gravitinoDistribution = createHashDistribution(bucketNum, columnNames); + + org.apache.spark.sql.connector.expressions.Transform[] sparkTransforms = + SparkTransformConverter.toSparkTransform(null, gravitinoDistribution, null); + Assertions.assertTrue(sparkTransforms != null && sparkTransforms.length == 1); + Assertions.assertTrue(sparkTransforms[0] instanceof BucketTransform); + BucketTransform bucket = (BucketTransform) sparkTransforms[0]; + Assertions.assertEquals(bucketNum, (Integer) bucket.numBuckets().value()); + String[][] columns = + JavaConverters.seqAsJavaList(bucket.columns()).stream() + .map(namedReference -> namedReference.fieldNames()) + .toArray(String[][]::new); + Assertions.assertArrayEquals(columnNames, columns); + + // none and null distribution + sparkTransforms = SparkTransformConverter.toSparkTransform(null, null, null); + Assertions.assertEquals(0, sparkTransforms.length); + sparkTransforms = SparkTransformConverter.toSparkTransform(null, Distributions.NONE, null); + Assertions.assertEquals(0, sparkTransforms.length); + + // range and even distribution + Assertions.assertThrowsExactly( + NotSupportedException.class, + () -> SparkTransformConverter.toSparkTransform(null, Distributions.RANGE, null)); + Distribution evenDistribution = Distributions.even(bucketNum, NamedReference.field("")); + Assertions.assertThrowsExactly( + NotSupportedException.class, + () -> SparkTransformConverter.toSparkTransform(null, evenDistribution, null)); + } + + @Test + void testSparkToGravitinoDistributionWithoutSortOrder() { + int bucketNum = 16; + String[] sparkFieldReferences = new String[] {"a", "b.c"}; + + org.apache.spark.sql.connector.expressions.Transform sparkBucket = + Expressions.bucket(bucketNum, sparkFieldReferences); + GravitinoTransformBundles bundles = + SparkTransformConverter.toGravitinoTransform( + new org.apache.spark.sql.connector.expressions.Transform[] {sparkBucket}); + + Assertions.assertNull(bundles.getSortOrders()); + Assertions.assertNull(bundles.getPartitions()); + + Distribution distribution = bundles.getDistribution(); + String[][] gravitinoFieldReferences = createGravitinoFieldReferenceNames(sparkFieldReferences); + Assertions.assertTrue( + distribution.equals(createHashDistribution(bucketNum, gravitinoFieldReferences))); + } + + @Test + void testSparkToGravitinoSortOrder() { + int bucketNum = 16; + String[][] bucketColumnNames = createGravitinoFieldReferenceNames("a", "b.c"); + String[][] sortColumnNames = createGravitinoFieldReferenceNames("f", "m.n"); + SortedBucketTransform sortedBucketTransform = + LogicalExpressions.bucket( + bucketNum, + createSparkFieldReference(bucketColumnNames), + createSparkFieldReference(sortColumnNames)); + + GravitinoTransformBundles bundles = + SparkTransformConverter.toGravitinoTransform( + new org.apache.spark.sql.connector.expressions.Transform[] {sortedBucketTransform}); + Assertions.assertNull(bundles.getPartitions()); + Assertions.assertTrue( + bundles.getDistribution().equals(createHashDistribution(bucketNum, bucketColumnNames))); + + SortOrder[] sortOrders = + createSortOrders(sortColumnNames, ConnectorConstants.SPARK_DEFAULT_SORT_DIRECTION); + // SortOrder doesn't implement equals for now + Assertions.assertEquals(sortOrders.length, bundles.getSortOrders().length); + 
for (int i = 0; i < sortOrders.length; i++) { + Assertions.assertEquals( + sortOrders[i].nullOrdering(), bundles.getSortOrders()[i].nullOrdering()); + Assertions.assertEquals(sortOrders[i].direction(), bundles.getSortOrders()[i].direction()); + Assertions.assertEquals(sortOrders[i].expression(), bundles.getSortOrders()[i].expression()); + } + } + + @Test + void testGravitinoToSparkSortOrder() { + int bucketNum = 16; + String[][] bucketColumnNames = createGravitinoFieldReferenceNames("a", "b.c"); + String[][] sortColumnNames = createGravitinoFieldReferenceNames("f", "m.n"); + Distribution distribution = createHashDistribution(bucketNum, bucketColumnNames); + SortOrder[] sortOrders = + createSortOrders(sortColumnNames, ConnectorConstants.SPARK_DEFAULT_SORT_DIRECTION); + + org.apache.spark.sql.connector.expressions.Transform[] transforms = + SparkTransformConverter.toSparkTransform(null, distribution, sortOrders); + Assertions.assertTrue(transforms.length == 1); + Assertions.assertTrue(transforms[0] instanceof SortedBucketTransform); + + SortedBucketTransform sortedBucketTransform = (SortedBucketTransform) transforms[0]; + Assertions.assertEquals(bucketNum, (Integer) sortedBucketTransform.numBuckets().value()); + String[][] sparkSortColumns = + JavaConverters.seqAsJavaList(sortedBucketTransform.sortedColumns()).stream() + .map(sparkNamedReference -> sparkNamedReference.fieldNames()) + .toArray(String[][]::new); + + String[][] sparkBucketColumns = + JavaConverters.seqAsJavaList(sortedBucketTransform.columns()).stream() + .map(sparkNamedReference -> sparkNamedReference.fieldNames()) + .toArray(String[][]::new); + + Assertions.assertArrayEquals(bucketColumnNames, sparkBucketColumns); + Assertions.assertArrayEquals(sortColumnNames, sparkSortColumns); + } + + private org.apache.spark.sql.connector.expressions.NamedReference[] createSparkFieldReference( + String[][] fields) { + return Arrays.stream(fields) + .map(field -> FieldReference.apply(String.join(ConnectorConstants.DOT, field))) + .toArray(org.apache.spark.sql.connector.expressions.NamedReference[]::new); + } + + // split column name for Gravitino + private String[][] createGravitinoFieldReferenceNames(String... 
columnNames) { + return Arrays.stream(columnNames) + .map(columnName -> columnName.split("\\.")) + .toArray(String[][]::new); + } + + private SortOrder[] createSortOrders(String[][] columnNames, SortDirection direction) { + return Arrays.stream(columnNames) + .map(columnName -> SortOrders.of(NamedReference.field(columnName), direction)) + .toArray(SortOrder[]::new); + } + + private Distribution createHashDistribution(int bucketNum, String[][] columnNames) { + NamedReference[] namedReferences = + Arrays.stream(columnNames) + .map(columnName -> NamedReference.field(columnName)) + .toArray(NamedReference[]::new); + return Distributions.hash(bucketNum, namedReferences); + } + + private void initSparkToGravitinoTransformMap() { + sparkToGravitinoPartitionTransformMaps.put( + SparkTransformConverter.createSparkIdentityTransform("a"), Transforms.identity("a")); + sparkToGravitinoPartitionTransformMaps.put( + SparkTransformConverter.createSparkIdentityTransform("a.b"), + Transforms.identity(new String[] {"a", "b"})); + } + + private void initGravitinoToSparkTransformMap() { + gravitinoToSparkPartitionTransformMaps.put( + IdentityPartitioningDTO.of("a"), SparkTransformConverter.createSparkIdentityTransform("a")); + gravitinoToSparkPartitionTransformMaps.put( + IdentityPartitioningDTO.of("a", "b"), + SparkTransformConverter.createSparkIdentityTransform("a.b")); + } +}