diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java
index ae218e42e97..e95b07d3fab 100644
--- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java
+++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java
@@ -5,13 +5,20 @@
 package com.datastrato.gravitino.integration.test.spark.iceberg;
 
 import com.datastrato.gravitino.integration.test.spark.SparkCommonIT;
+import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo;
+import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfoChecker;
 import com.datastrato.gravitino.spark.connector.iceberg.SparkIcebergTable;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import org.apache.hadoop.fs.Path;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.catalyst.analysis.ResolvedTable;
 import org.apache.spark.sql.catalyst.plans.logical.CommandResult;
 import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation;
 import org.apache.spark.sql.connector.catalog.Table;
+import org.apache.spark.sql.types.DataTypes;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Tag;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.TestInstance;
@@ -20,6 +27,21 @@
 @TestInstance(TestInstance.Lifecycle.PER_CLASS)
 public class SparkIcebergCatalogIT extends SparkCommonIT {
 
+  /** Expected columns of the table created by getCreateIcebergSimpleTableString. */
+  protected List<SparkTableInfo.SparkColumnInfo> getIcebergSimpleTableColumn() {
+    return Arrays.asList(
+        SparkTableInfo.SparkColumnInfo.of("id", DataTypes.IntegerType, "id comment"),
+        SparkTableInfo.SparkColumnInfo.of("name", DataTypes.StringType, ""),
+        SparkTableInfo.SparkColumnInfo.of("ts", DataTypes.TimestampType, null));
+  }
+
+  /** Builds the CREATE TABLE statement, without a partition clause, shared by the tests. */
+  private String getCreateIcebergSimpleTableString(String tableName) {
+    return String.format(
+        "CREATE TABLE %s (id INT COMMENT 'id comment', name STRING COMMENT '', ts TIMESTAMP)",
+        tableName);
+  }
+
   @Override
   protected String getCatalogName() {
     return "iceberg";
@@ -37,7 +59,134 @@ protected boolean supportsSparkSQLClusteredBy() {
 
   @Override
   protected boolean supportsPartition() {
-    return false;
+    return true;
+  }
+
+  @Test
+  void testCreateIcebergBucketPartitionTable() {
+    String tableName = "iceberg_bucket_partition_table";
+    dropTableIfExists(tableName);
+    String createTableSQL = getCreateIcebergSimpleTableString(tableName);
+    createTableSQL = createTableSQL + " PARTITIONED BY (bucket(16, id));";
+    sql(createTableSQL);
+    SparkTableInfo tableInfo = getTableInfo(tableName);
+    SparkTableInfoChecker checker =
+        SparkTableInfoChecker.create()
+            .withName(tableName)
+            .withColumns(getIcebergSimpleTableColumn())
+            .withBucket(16, Collections.singletonList("id"));
+    checker.check(tableInfo);
+
+    checkInsertedRowAndPartitionDir(tableName, tableInfo, "id_bucket=4");
+  }
+
+  @Test
+  void testCreateIcebergHourPartitionTable() {
+    String tableName = "iceberg_hour_partition_table";
+    dropTableIfExists(tableName);
+    String createTableSQL = getCreateIcebergSimpleTableString(tableName);
+    createTableSQL = createTableSQL + " PARTITIONED BY (hours(ts));";
+    sql(createTableSQL);
+    SparkTableInfo tableInfo = getTableInfo(tableName);
+    SparkTableInfoChecker checker =
+        SparkTableInfoChecker.create()
+            .withName(tableName)
+            .withColumns(getIcebergSimpleTableColumn())
+            .withHour(Collections.singletonList("ts"));
+    checker.check(tableInfo);
+
+    checkInsertedRowAndPartitionDir(tableName, tableInfo, "ts_hour=12");
+  }
+
+  @Test
+  void testCreateIcebergDayPartitionTable() {
+    String tableName = "iceberg_day_partition_table";
+    dropTableIfExists(tableName);
+    String createTableSQL = getCreateIcebergSimpleTableString(tableName);
+    createTableSQL = createTableSQL + " PARTITIONED BY (days(ts));";
+    sql(createTableSQL);
+    SparkTableInfo tableInfo = getTableInfo(tableName);
+    SparkTableInfoChecker checker =
+        SparkTableInfoChecker.create()
+            .withName(tableName)
+            .withColumns(getIcebergSimpleTableColumn())
+            .withDay(Collections.singletonList("ts"));
+    checker.check(tableInfo);
+
+    checkInsertedRowAndPartitionDir(tableName, tableInfo, "ts_day=2024-01-01");
+  }
+
+  @Test
+  void testCreateIcebergMonthPartitionTable() {
+    String tableName = "iceberg_month_partition_table";
+    dropTableIfExists(tableName);
+    String createTableSQL = getCreateIcebergSimpleTableString(tableName);
+    createTableSQL = createTableSQL + " PARTITIONED BY (months(ts));";
+    sql(createTableSQL);
+    SparkTableInfo tableInfo = getTableInfo(tableName);
+    SparkTableInfoChecker checker =
+        SparkTableInfoChecker.create()
+            .withName(tableName)
+            .withColumns(getIcebergSimpleTableColumn())
+            .withMonth(Collections.singletonList("ts"));
+    checker.check(tableInfo);
+
+    checkInsertedRowAndPartitionDir(tableName, tableInfo, "ts_month=2024-01");
+  }
+
+  @Test
+  void testCreateIcebergYearPartitionTable() {
+    String tableName = "iceberg_year_partition_table";
+    dropTableIfExists(tableName);
+    String createTableSQL = getCreateIcebergSimpleTableString(tableName);
+    createTableSQL = createTableSQL + " PARTITIONED BY (years(ts));";
+    sql(createTableSQL);
+    SparkTableInfo tableInfo = getTableInfo(tableName);
+    SparkTableInfoChecker checker =
+        SparkTableInfoChecker.create()
+            .withName(tableName)
+            .withColumns(getIcebergSimpleTableColumn())
+            .withYear(Collections.singletonList("ts"));
+    checker.check(tableInfo);
+
+    checkInsertedRowAndPartitionDir(tableName, tableInfo, "ts_year=2024");
+  }
+
+  @Test
+  void testCreateIcebergTruncatePartitionTable() {
+    String tableName = "iceberg_truncate_partition_table";
+    dropTableIfExists(tableName);
+    String createTableSQL = getCreateIcebergSimpleTableString(tableName);
+    createTableSQL = createTableSQL + " PARTITIONED BY (truncate(1, name));";
+    sql(createTableSQL);
+    SparkTableInfo tableInfo = getTableInfo(tableName);
+    SparkTableInfoChecker checker =
+        SparkTableInfoChecker.create()
+            .withName(tableName)
+            .withColumns(getIcebergSimpleTableColumn())
+            .withTruncate(1, Collections.singletonList("name"));
+    checker.check(tableInfo);
+
+    checkInsertedRowAndPartitionDir(tableName, tableInfo, "name_trunc=a");
+  }
+
+  /**
+   * Inserts one fixed row into the table, reads it back, and checks that the expected partition
+   * directory exists under the table's {@code data} directory.
+   */
+  private void checkInsertedRowAndPartitionDir(
+      String tableName, SparkTableInfo tableInfo, String partitionExpression) {
+    String insertData =
+        String.format(
+            "INSERT into %s values(2,'a',cast('2024-01-01 12:00:00.000' as timestamp));",
+            tableName);
+    sql(insertData);
+    List<String> queryResult = getTableData(tableName);
+    Assertions.assertEquals(1, queryResult.size());
+    Assertions.assertEquals("2,a,2024-01-01 12:00:00.000", queryResult.get(0));
+    Path dataDir = new Path(tableInfo.getTableLocation(), "data");
+    Path partitionPath = new Path(dataDir, partitionExpression);
+    checkDirExists(partitionPath);
   }
 
   // TODO
diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java
index 449237ff157..c984646313b 100644
--- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java
+++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java
@@ -17,10 +17,15 @@
 import javax.ws.rs.NotSupportedException;
 import lombok.Data;
 import org.apache.commons.lang3.StringUtils;
+import org.apache.spark.sql.connector.expressions.ApplyTransform;
 import org.apache.spark.sql.connector.expressions.BucketTransform;
+import org.apache.spark.sql.connector.expressions.DaysTransform;
+import org.apache.spark.sql.connector.expressions.HoursTransform;
 import org.apache.spark.sql.connector.expressions.IdentityTransform;
+import org.apache.spark.sql.connector.expressions.MonthsTransform;
 import org.apache.spark.sql.connector.expressions.SortedBucketTransform;
 import org.apache.spark.sql.connector.expressions.Transform;
+import org.apache.spark.sql.connector.expressions.YearsTransform;
 import org.apache.spark.sql.types.DataType;
 import org.junit.jupiter.api.Assertions;
 
@@ -34,6 +39,11 @@ public class SparkTableInfo {
   private Map<String, String> tableProperties;
   private List<String> unknownItems = new ArrayList<>();
   private Transform bucket;
+  private Transform hour;
+  private Transform day;
+  private Transform month;
+  private Transform year;
+  private Transform truncate;
   private List<Transform> partitions = new ArrayList<>();
   private Set<String> partitionColumnNames = new HashSet<>();
 
@@ -65,6 +75,31 @@ void setBucket(Transform bucket) {
     this.bucket = bucket;
   }
 
+  void setHour(Transform hour) {
+    Assertions.assertNull(this.hour, "Should only have one hour transform");
+    this.hour = hour;
+  }
+
+  void setDay(Transform day) {
+    Assertions.assertNull(this.day, "Should only have one day transform");
+    this.day = day;
+  }
+
+  void setMonth(Transform month) {
+    Assertions.assertNull(this.month, "Should only have one month transform");
+    this.month = month;
+  }
+
+  void setYear(Transform year) {
+    Assertions.assertNull(this.year, "Should only have one year transform");
+    this.year = year;
+  }
+
+  void setTruncate(Transform truncate) {
+    Assertions.assertNull(this.truncate, "Should only have one truncate transform");
+    this.truncate = truncate;
+  }
+
   void addPartition(Transform partition) {
     if (partition instanceof IdentityTransform) {
       partitionColumnNames.add(((IdentityTransform) partition).reference().fieldNames()[0]);
@@ -102,6 +137,17 @@ static SparkTableInfo create(SparkBaseTable baseTable) {
                 sparkTableInfo.setBucket(transform);
               } else if (transform instanceof IdentityTransform) {
                 sparkTableInfo.addPartition(transform);
+              } else if (transform instanceof HoursTransform) {
+                sparkTableInfo.setHour(transform);
+              } else if (transform instanceof DaysTransform) {
+                sparkTableInfo.setDay(transform);
+              } else if (transform instanceof MonthsTransform) {
+                sparkTableInfo.setMonth(transform);
+              } else if (transform instanceof YearsTransform) {
+                sparkTableInfo.setYear(transform);
+              } else if (transform instanceof ApplyTransform
+                  && "truncate".equals(transform.name())) {
+                sparkTableInfo.setTruncate(transform);
               } else {
                 throw new NotSupportedException(
                     "Doesn't support Spark transform: " + transform.name());
diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfoChecker.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfoChecker.java
index d346769281c..5724232952c 100644
--- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfoChecker.java
+++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfoChecker.java
@@ -33,6 +33,11 @@ private enum CheckField {
     COLUMN,
     PARTITION,
     BUCKET,
+    HOUR,
+    DAY,
+    MONTH,
+    YEAR,
+    TRUNCATE,
     COMMENT,
   }
 
@@ -76,6 +81,43 @@ public SparkTableInfoChecker withBucket(
     return this;
   }
 
+  public SparkTableInfoChecker withHour(List<String> partitionColumns) {
+    Transform hourTransform = Expressions.hours(partitionColumns.get(0));
+    this.expectedTableInfo.setHour(hourTransform);
+    this.checkFields.add(CheckField.HOUR);
+    return this;
+  }
+
+  public SparkTableInfoChecker withDay(List<String> partitionColumns) {
+    Transform dayTransform = Expressions.days(partitionColumns.get(0));
+    this.expectedTableInfo.setDay(dayTransform);
+    this.checkFields.add(CheckField.DAY);
+    return this;
+  }
+
+  public SparkTableInfoChecker withMonth(List<String> partitionColumns) {
+    Transform monthTransform = Expressions.months(partitionColumns.get(0));
+    this.expectedTableInfo.setMonth(monthTransform);
+    this.checkFields.add(CheckField.MONTH);
+    return this;
+  }
+
+  public SparkTableInfoChecker withYear(List<String> partitionColumns) {
+    Transform yearTransform = Expressions.years(partitionColumns.get(0));
+    this.expectedTableInfo.setYear(yearTransform);
+    this.checkFields.add(CheckField.YEAR);
+    return this;
+  }
+
+  public SparkTableInfoChecker withTruncate(int width, List<String> partitionColumns) {
+    Transform truncateTransform =
+        Expressions.apply(
+            "truncate", Expressions.literal(width), Expressions.column(partitionColumns.get(0)));
+    this.expectedTableInfo.setTruncate(truncateTransform);
+    this.checkFields.add(CheckField.TRUNCATE);
+    return this;
+  }
+
   public SparkTableInfoChecker withComment(String comment) {
     this.expectedTableInfo.setComment(comment);
     this.checkFields.add(CheckField.COMMENT);
@@ -102,6 +144,22 @@ public void check(SparkTableInfo realTableInfo) {
                 case BUCKET:
                   Assertions.assertEquals(expectedTableInfo.getBucket(), realTableInfo.getBucket());
                   break;
+                case HOUR:
+                  Assertions.assertEquals(expectedTableInfo.getHour(), realTableInfo.getHour());
+                  break;
+                case DAY:
+                  Assertions.assertEquals(expectedTableInfo.getDay(), realTableInfo.getDay());
+                  break;
+                case MONTH:
+                  Assertions.assertEquals(expectedTableInfo.getMonth(), realTableInfo.getMonth());
+                  break;
+                case YEAR:
+                  Assertions.assertEquals(expectedTableInfo.getYear(), realTableInfo.getYear());
+                  break;
+                case TRUNCATE:
+                  Assertions.assertEquals(
+                      expectedTableInfo.getTruncate(), realTableInfo.getTruncate());
+                  break;
                 case COMMENT:
                   Assertions.assertEquals(
                       expectedTableInfo.getComment(), realTableInfo.getComment());
diff --git a/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/SparkTransformConverter.java b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/SparkTransformConverter.java
index 9afad670b76..7b6bf5ca6b5 100644
--- a/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/SparkTransformConverter.java
+++ b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/SparkTransformConverter.java
@@ -21,11 +21,19 @@
 import lombok.Getter;
 import org.apache.commons.lang3.ArrayUtils;
 import org.apache.commons.lang3.tuple.Pair;
+import org.apache.spark.sql.connector.expressions.ApplyTransform;
 import org.apache.spark.sql.connector.expressions.BucketTransform;
+import org.apache.spark.sql.connector.expressions.DaysTransform;
 import org.apache.spark.sql.connector.expressions.Expressions;
+import org.apache.spark.sql.connector.expressions.HoursTransform;
 import org.apache.spark.sql.connector.expressions.IdentityTransform;
+import org.apache.spark.sql.connector.expressions.Literal;
 import org.apache.spark.sql.connector.expressions.LogicalExpressions;
+import org.apache.spark.sql.connector.expressions.MonthsTransform;
 import org.apache.spark.sql.connector.expressions.SortedBucketTransform;
+import org.apache.spark.sql.connector.expressions.YearsTransform;
+import org.apache.spark.sql.types.IntegerType;
+import org.apache.spark.sql.types.LongType;
 import scala.collection.JavaConverters;
 
 /**
@@ -68,6 +76,22 @@ public static Transform[] toGravitinoPartitionings(
               if (transform instanceof IdentityTransform) {
                 IdentityTransform identityTransform = (IdentityTransform) transform;
                 return Transforms.identity(identityTransform.reference().fieldNames());
+              } else if (transform instanceof HoursTransform) {
+                HoursTransform hoursTransform = (HoursTransform) transform;
+                return Transforms.hour(hoursTransform.reference().fieldNames());
+              } else if (transform instanceof DaysTransform) {
+                DaysTransform daysTransform = (DaysTransform) transform;
+                return Transforms.day(daysTransform.reference().fieldNames());
+              } else if (transform instanceof MonthsTransform) {
+                MonthsTransform monthsTransform = (MonthsTransform) transform;
+                return Transforms.month(monthsTransform.reference().fieldNames());
+              } else if (transform instanceof YearsTransform) {
+                YearsTransform yearsTransform = (YearsTransform) transform;
+                return Transforms.year(yearsTransform.reference().fieldNames());
+              } else if (transform instanceof ApplyTransform
+                  && "truncate".equals(transform.name())) {
+                return Transforms.truncate(
+                    findWidth(transform), transform.references()[0].fieldNames());
               } else {
                 throw new NotSupportedException(
                     "Doesn't support Spark transform: " + transform.name());
@@ -122,6 +146,33 @@ public static org.apache.spark.sql.connector.expressions.Transform[] toSparkTran
                   sparkTransforms.add(
                       createSparkIdentityTransform(
                           String.join(ConnectorConstants.DOT, identityTransform.fieldName())));
+                } else if (transform instanceof Transforms.HourTransform) {
+                  Transforms.HourTransform hourTransform = (Transforms.HourTransform) transform;
+                  sparkTransforms.add(createSparkHoursTransform(hourTransform.fieldName()));
+                } else if (transform instanceof Transforms.BucketTransform) {
+                  Transforms.BucketTransform bucketTransform =
+                      (Transforms.BucketTransform) transform;
+                  int numBuckets = bucketTransform.numBuckets();
+                  String[] fieldNames =
+                      Arrays.stream(bucketTransform.fieldNames())
+                          .map(f -> String.join(ConnectorConstants.DOT, f))
+                          .toArray(String[]::new);
+                  sparkTransforms.add(createSparkBucketTransform(numBuckets, fieldNames));
+                } else if (transform instanceof Transforms.DayTransform) {
+                  Transforms.DayTransform dayTransform = (Transforms.DayTransform) transform;
+                  sparkTransforms.add(createSparkDaysTransform(dayTransform.fieldName()));
+                } else if (transform instanceof Transforms.MonthTransform) {
+                  Transforms.MonthTransform monthTransform = (Transforms.MonthTransform) transform;
+                  sparkTransforms.add(createSparkMonthsTransform(monthTransform.fieldName()));
+                } else if (transform instanceof Transforms.YearTransform) {
+                  Transforms.YearTransform yearTransform = (Transforms.YearTransform) transform;
+                  sparkTransforms.add(createSparkYearsTransform(yearTransform.fieldName()));
+                } else if (transform instanceof Transforms.TruncateTransform) {
+                  Transforms.TruncateTransform truncateTransform =
+                      (Transforms.TruncateTransform) transform;
+                  int width = truncateTransform.width();
+                  String[] fieldName = truncateTransform.fieldName();
+                  sparkTransforms.add(createSparkTruncateTransform("truncate", width, fieldName));
                 } else {
                   throw new UnsupportedOperationException(
                       "Doesn't support Gravitino partition: "
@@ -224,6 +275,44 @@ public static IdentityTransform createSparkIdentityTransform(String columnName)
     return IdentityTransform.apply(Expressions.column(columnName));
   }
 
+  /** Creates a Spark {@code hours} transform for the given (possibly nested) field. */
+  public static HoursTransform createSparkHoursTransform(String[] fieldName) {
+    return LogicalExpressions.hours(
+        Expressions.column(String.join(ConnectorConstants.DOT, fieldName)));
+  }
+
+  /** Creates a Spark {@code bucket} transform with {@code numBuckets} buckets on the fields. */
+  public static BucketTransform createSparkBucketTransform(int numBuckets, String[] fieldNames) {
+    return LogicalExpressions.bucket(numBuckets, createSparkNamedReference(fieldNames));
+  }
+
+  /** Creates a Spark {@code days} transform for the given (possibly nested) field. */
+  public static DaysTransform createSparkDaysTransform(String[] fieldName) {
+    return LogicalExpressions.days(
+        Expressions.column(String.join(ConnectorConstants.DOT, fieldName)));
+  }
+
+  /** Creates a Spark {@code months} transform for the given (possibly nested) field. */
+  public static MonthsTransform createSparkMonthsTransform(String[] fieldName) {
+    return LogicalExpressions.months(
+        Expressions.column(String.join(ConnectorConstants.DOT, fieldName)));
+  }
+
+  /** Creates a Spark {@code years} transform for the given (possibly nested) field. */
+  public static YearsTransform createSparkYearsTransform(String[] fieldName) {
+    return LogicalExpressions.years(
+        Expressions.column(String.join(ConnectorConstants.DOT, fieldName)));
+  }
+
+  /** Creates a Spark transform applying {@code functionName} to a width literal and a field. */
+  public static org.apache.spark.sql.connector.expressions.Transform createSparkTruncateTransform(
+      String functionName, int width, String[] fieldName) {
+    return Expressions.apply(
+        functionName,
+        Expressions.literal(width),
+        Expressions.column(String.join(ConnectorConstants.DOT, fieldName)));
+  }
+
   private static org.apache.spark.sql.connector.expressions.NamedReference[]
       createSparkNamedReference(String[] fields) {
     return Arrays.stream(fields)
@@ -241,4 +330,36 @@ private static boolean isBucketTransform(
       org.apache.spark.sql.connector.expressions.Transform transform) {
     return transform instanceof BucketTransform || transform instanceof SortedBucketTransform;
   }
+
+  /**
+   * Extracts the width argument of a Spark {@code truncate} transform.
+   *
+   * <p>Adapted from {@code org.apache.iceberg.spark.Spark3Util}.
+   *
+   * @param transform the Spark transform whose arguments are searched for an integral literal
+   * @return the first int or long literal argument, as an {@code int}
+   * @throws IllegalArgumentException if the width is not positive, does not fit in an int, or no
+   *     integral literal argument is present
+   */
+  private static int findWidth(org.apache.spark.sql.connector.expressions.Transform transform) {
+    for (org.apache.spark.sql.connector.expressions.Expression expr : transform.arguments()) {
+      if (expr instanceof Literal) {
+        if (((Literal<?>) expr).dataType() instanceof IntegerType) {
+          Literal<Integer> lit = (Literal<Integer>) expr;
+          Preconditions.checkArgument(
+              lit.value() > 0, "Unsupported width for transform: %s", transform.describe());
+          return lit.value();
+        } else if (((Literal<?>) expr).dataType() instanceof LongType) {
+          Literal<Long> lit = (Literal<Long>) expr;
+          Preconditions.checkArgument(
+              lit.value() > 0 && lit.value() <= Integer.MAX_VALUE,
+              "Unsupported width for transform: %s",
+              transform.describe());
+          return lit.value().intValue();
+        }
+      }
+    }
+
+    throw new IllegalArgumentException("Cannot find width for transform: " + transform.describe());
+  }
 }
diff --git a/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/TestSparkTransformConverter.java b/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/TestSparkTransformConverter.java
index ea00eeb5b58..d76a05e7bc4 100644
--- a/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/TestSparkTransformConverter.java
+++ b/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/TestSparkTransformConverter.java
@@ -202,5 +202,25 @@ private void initSparkToGravitinoTransformMap() {
     sparkToGravitinoPartitionTransformMaps.put(
         SparkTransformConverter.createSparkIdentityTransform("a.b"),
         Transforms.identity(new String[] {"a", "b"}));
+    sparkToGravitinoPartitionTransformMaps.put(
+        SparkTransformConverter.createSparkHoursTransform(new String[] {"date"}),
+        Transforms.hour("date"));
+    sparkToGravitinoPartitionTransformMaps.put(
+        SparkTransformConverter.createSparkDaysTransform(new String[] {"date"}),
+        Transforms.day("date"));
+    sparkToGravitinoPartitionTransformMaps.put(
+        SparkTransformConverter.createSparkMonthsTransform(new String[] {"date"}),
+        Transforms.month("date"));
+    sparkToGravitinoPartitionTransformMaps.put(
+        SparkTransformConverter.createSparkYearsTransform(new String[] {"date"}),
+        Transforms.year("date"));
+    sparkToGravitinoPartitionTransformMaps.put(
+        SparkTransformConverter.createSparkTruncateTransform(
+            "truncate", 10, new String[] {"package"}),
+        Transforms.truncate(10, "package"));
+    sparkToGravitinoPartitionTransformMaps.put(
+        SparkTransformConverter.createSparkTruncateTransform(
+            "truncate", 10, new String[] {"a", "b"}),
+        Transforms.truncate(10, new String[] {"a", "b"}));
   }
 }