diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkIT.java index 918c5b0740b..9a5bb6711cf 100644 --- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkIT.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkIT.java @@ -8,6 +8,7 @@ import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo.SparkColumnInfo; import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfoChecker; import com.google.common.collect.ImmutableMap; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -278,6 +279,34 @@ void testAlterTableSetAndRemoveProperty() { Assertions.assertTrue(newProperties.containsKey("key2")); } + @Test + void testAlterTableAddAndDeleteColumn() { + String tableName = "test_column"; + dropTableIfExists(tableName); + + List simpleTableColumns = getSimpleTableColumn(); + + createSimpleTable(tableName); + checkTableColumns(tableName, simpleTableColumns, getTableInfo(tableName)); + + sql(String.format("ALTER TABLE %S ADD COLUMNS (col1 string)", tableName)); + ArrayList addColumns = new ArrayList<>(simpleTableColumns); + addColumns.add(SparkColumnInfo.of("col1", DataTypes.StringType, null)); + checkTableColumns(tableName, addColumns, getTableInfo(tableName)); + + sql(String.format("ALTER TABLE %S DROP COLUMNS (col1)", tableName)); + checkTableColumns(tableName, simpleTableColumns, getTableInfo(tableName)); + } + + private void checkTableColumns( + String tableName, List columnInfos, SparkTableInfo tableInfo) { + SparkTableInfoChecker.create() + .withName(tableName) + .withColumns(columnInfos) + .withComment(null) + .check(tableInfo); + } + private void checkTableReadWrite(SparkTableInfo table) { String name = table.getTableIdentifier(); String insertValues = diff --git a/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/GravitinoCatalog.java b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/GravitinoCatalog.java index 71c4f600969..a622ec830f7 100644 --- a/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/GravitinoCatalog.java +++ b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/GravitinoCatalog.java @@ -361,9 +361,37 @@ static com.datastrato.gravitino.rel.TableChange transformTableChange(TableChange } else if (change instanceof TableChange.RemoveProperty) { TableChange.RemoveProperty removeProperty = (TableChange.RemoveProperty) change; return com.datastrato.gravitino.rel.TableChange.removeProperty(removeProperty.property()); + } else if (change instanceof TableChange.AddColumn) { + TableChange.AddColumn addColumn = (TableChange.AddColumn) change; + return com.datastrato.gravitino.rel.TableChange.addColumn( + addColumn.fieldNames(), + SparkTypeConverter.toGravitinoType(addColumn.dataType()), + addColumn.comment(), + transformColumnPosition(addColumn.position()), + addColumn.isNullable()); + } else if (change instanceof TableChange.DeleteColumn) { + TableChange.DeleteColumn deleteColumn = (TableChange.DeleteColumn) change; + return com.datastrato.gravitino.rel.TableChange.deleteColumn( + deleteColumn.fieldNames(), deleteColumn.ifExists()); } else { throw new UnsupportedOperationException( String.format("Unsupported table change %s", change.getClass().getName())); } } + + private static com.datastrato.gravitino.rel.TableChange.ColumnPosition transformColumnPosition( + TableChange.ColumnPosition columnPosition) { + if (null == columnPosition) { + return com.datastrato.gravitino.rel.TableChange.ColumnPosition.defaultPos(); + } else if (columnPosition instanceof TableChange.First) { + return com.datastrato.gravitino.rel.TableChange.ColumnPosition.first(); + } else if (columnPosition instanceof TableChange.After) { + TableChange.After after = (TableChange.After) columnPosition; + return com.datastrato.gravitino.rel.TableChange.ColumnPosition.after(after.column()); + } else { + throw new UnsupportedOperationException( + String.format( + "Unsupported table column position %s", columnPosition.getClass().getName())); + } + } } diff --git a/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/catalog/TestTransformTableChange.java b/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/catalog/TestTransformTableChange.java index 3b35d0ee459..4f9288bf8a3 100644 --- a/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/catalog/TestTransformTableChange.java +++ b/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/catalog/TestTransformTableChange.java @@ -5,7 +5,10 @@ package com.datastrato.gravitino.spark.connector.catalog; +import org.apache.spark.sql.connector.catalog.ColumnDefaultValue; import org.apache.spark.sql.connector.catalog.TableChange; +import org.apache.spark.sql.connector.expressions.LiteralValue; +import org.apache.spark.sql.types.DataTypes; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -35,4 +38,88 @@ void testTransformRemoveProperty() { (com.datastrato.gravitino.rel.TableChange.RemoveProperty) tableChange; Assertions.assertEquals("key", gravitinoRemoveProperty.getProperty()); } + + @Test + void testTransformAddColumn() { + + TableChange.ColumnPosition first = TableChange.ColumnPosition.first(); + TableChange.ColumnPosition after = TableChange.ColumnPosition.after("col0"); + ColumnDefaultValue defaultValue = + new ColumnDefaultValue( + "CURRENT_DEFAULT", new LiteralValue("default_value", DataTypes.StringType)); + + TableChange.AddColumn sparkAddColumnFirst = + (TableChange.AddColumn) + TableChange.addColumn( + new String[] {"col1"}, DataTypes.StringType, true, "", first, defaultValue); + com.datastrato.gravitino.rel.TableChange gravitinoChangeFirst = + GravitinoCatalog.transformTableChange(sparkAddColumnFirst); + + Assertions.assertTrue( + gravitinoChangeFirst instanceof com.datastrato.gravitino.rel.TableChange.AddColumn); + com.datastrato.gravitino.rel.TableChange.AddColumn gravitinoAddColumnFirst = + (com.datastrato.gravitino.rel.TableChange.AddColumn) gravitinoChangeFirst; + + Assertions.assertEquals(sparkAddColumnFirst.fieldNames(), gravitinoAddColumnFirst.fieldName()); + Assertions.assertTrue( + "string".equalsIgnoreCase(gravitinoAddColumnFirst.getDataType().simpleString())); + Assertions.assertTrue( + gravitinoAddColumnFirst.getPosition() + instanceof com.datastrato.gravitino.rel.TableChange.First); + + TableChange.AddColumn sparkAddColumnAfter = + (TableChange.AddColumn) + TableChange.addColumn( + new String[] {"col1"}, DataTypes.StringType, true, "", after, defaultValue); + com.datastrato.gravitino.rel.TableChange gravitinoChangeAfter = + GravitinoCatalog.transformTableChange(sparkAddColumnAfter); + + Assertions.assertTrue( + gravitinoChangeAfter instanceof com.datastrato.gravitino.rel.TableChange.AddColumn); + com.datastrato.gravitino.rel.TableChange.AddColumn gravitinoAddColumnAfter = + (com.datastrato.gravitino.rel.TableChange.AddColumn) gravitinoChangeAfter; + + Assertions.assertEquals(sparkAddColumnAfter.fieldNames(), gravitinoAddColumnAfter.fieldName()); + Assertions.assertTrue( + "string".equalsIgnoreCase(gravitinoAddColumnAfter.getDataType().simpleString())); + Assertions.assertTrue( + gravitinoAddColumnAfter.getPosition() + instanceof com.datastrato.gravitino.rel.TableChange.After); + + TableChange.AddColumn sparkAddColumnDefault = + (TableChange.AddColumn) + TableChange.addColumn( + new String[] {"col1"}, DataTypes.StringType, true, "", null, defaultValue); + com.datastrato.gravitino.rel.TableChange gravitinoChangeDefault = + GravitinoCatalog.transformTableChange(sparkAddColumnDefault); + + Assertions.assertTrue( + gravitinoChangeDefault instanceof com.datastrato.gravitino.rel.TableChange.AddColumn); + com.datastrato.gravitino.rel.TableChange.AddColumn gravitinoAddColumnDefault = + (com.datastrato.gravitino.rel.TableChange.AddColumn) gravitinoChangeDefault; + + Assertions.assertEquals( + sparkAddColumnDefault.fieldNames(), gravitinoAddColumnDefault.fieldName()); + Assertions.assertTrue( + "string".equalsIgnoreCase(gravitinoAddColumnDefault.getDataType().simpleString())); + Assertions.assertTrue( + gravitinoAddColumnDefault.getPosition() + instanceof com.datastrato.gravitino.rel.TableChange.Default); + } + + @Test + void testTransformDeleteColumn() { + TableChange.DeleteColumn sparkDeleteColumn = + (TableChange.DeleteColumn) TableChange.deleteColumn(new String[] {"col1"}, true); + com.datastrato.gravitino.rel.TableChange gravitinoChange = + GravitinoCatalog.transformTableChange(sparkDeleteColumn); + + Assertions.assertTrue( + gravitinoChange instanceof com.datastrato.gravitino.rel.TableChange.DeleteColumn); + com.datastrato.gravitino.rel.TableChange.DeleteColumn gravitinoDeleteColumn = + (com.datastrato.gravitino.rel.TableChange.DeleteColumn) gravitinoChange; + + Assertions.assertEquals(sparkDeleteColumn.fieldNames(), gravitinoDeleteColumn.fieldName()); + Assertions.assertEquals(sparkDeleteColumn.ifExists(), gravitinoDeleteColumn.getIfExists()); + } }