Skip to content

Commit

Permalink
[#1736] feat(postgresql): Support PostgreSQL index.
Browse files Browse the repository at this point in the history
  • Loading branch information
Clearvive authored and Clearvive committed Jan 31, 2024
1 parent f4b9d1c commit 5ab650b
Show file tree
Hide file tree
Showing 5 changed files with 333 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
import com.datastrato.gravitino.rel.TableChange;
import com.datastrato.gravitino.rel.expressions.transforms.Transform;
import com.datastrato.gravitino.rel.indexes.Index;
import com.datastrato.gravitino.rel.indexes.Indexes;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.SetMultimap;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
Expand All @@ -25,6 +28,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.sql.DataSource;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
Expand Down Expand Up @@ -239,7 +243,47 @@ protected void correctJdbcTableFields(

protected List<Index> getIndexes(String databaseName, String tableName, DatabaseMetaData metaData)
throws SQLException {
return Collections.emptyList();
List<Index> indexes = new ArrayList<>();

// Get primary key information
SetMultimap<String, String> primaryKeyGroupByName = HashMultimap.create();
ResultSet primaryKeys = getPrimaryKeys(databaseName, tableName, metaData);
while (primaryKeys.next()) {
String columnName = primaryKeys.getString("COLUMN_NAME");
primaryKeyGroupByName.put(primaryKeys.getString("PK_NAME"), columnName);
}
for (String key : primaryKeyGroupByName.keySet()) {
indexes.add(Indexes.primary(key, convertIndexFieldNames(primaryKeyGroupByName.get(key))));
}

// Get unique key information
SetMultimap<String, String> indexGroupByName = HashMultimap.create();
ResultSet indexInfo = getIndexInfo(databaseName, tableName, metaData);
while (indexInfo.next()) {
String indexName = indexInfo.getString("INDEX_NAME");
if (!indexInfo.getBoolean("NON_UNIQUE") && !primaryKeyGroupByName.containsKey(indexName)) {
String columnName = indexInfo.getString("COLUMN_NAME");
indexGroupByName.put(indexName, columnName);
}
}
for (String key : indexGroupByName.keySet()) {
indexes.add(Indexes.unique(key, convertIndexFieldNames(indexGroupByName.get(key))));
}
return indexes;
}

protected ResultSet getIndexInfo(String databaseName, String tableName, DatabaseMetaData metaData)
throws SQLException {
return metaData.getIndexInfo(databaseName, null, tableName, false, false);
}

protected ResultSet getPrimaryKeys(
String databaseName, String tableName, DatabaseMetaData metaData) throws SQLException {
return metaData.getPrimaryKeys(databaseName, null, tableName);
}

protected String[][] convertIndexFieldNames(Set<String> fieldNames) {
return fieldNames.stream().map(colName -> new String[] {colName}).toArray(String[][]::new);
}

protected abstract String generateCreateTableSql(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@
import com.datastrato.gravitino.rel.indexes.Index;
import com.datastrato.gravitino.rel.indexes.Indexes;
import com.google.common.base.Preconditions;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.SetMultimap;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
Expand All @@ -35,7 +32,6 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
Expand Down Expand Up @@ -210,43 +206,6 @@ public static void appendIndexesSql(Index[] indexes, StringBuilder sqlBuilder) {
}
}

protected List<Index> getIndexes(String databaseName, String tableName, DatabaseMetaData metaData)
throws SQLException {
List<Index> indexes = new ArrayList<>();

// Get primary key information
SetMultimap<String, String> primaryKeyGroupByName = HashMultimap.create();
ResultSet primaryKeys = metaData.getPrimaryKeys(databaseName, null, tableName);
while (primaryKeys.next()) {
String columnName = primaryKeys.getString("COLUMN_NAME");
primaryKeyGroupByName.put(primaryKeys.getString("PK_NAME"), columnName);
}
for (String key : primaryKeyGroupByName.keySet()) {
indexes.add(Indexes.primary(key, convertIndexFieldNames(primaryKeyGroupByName.get(key))));
}

// Get unique key information
SetMultimap<String, String> indexGroupByName = HashMultimap.create();
ResultSet indexInfo = metaData.getIndexInfo(databaseName, null, tableName, false, false);
while (indexInfo.next()) {
String indexName = indexInfo.getString("INDEX_NAME");
if (!indexInfo.getBoolean("NON_UNIQUE")
&& !StringUtils.equalsIgnoreCase(Indexes.DEFAULT_MYSQL_PRIMARY_KEY_NAME, indexName)) {
String columnName = indexInfo.getString("COLUMN_NAME");
indexGroupByName.put(indexName, columnName);
}
}
for (String key : indexGroupByName.keySet()) {
indexes.add(Indexes.unique(key, convertIndexFieldNames(indexGroupByName.get(key))));
}

return indexes;
}

private String[][] convertIndexFieldNames(Set<String> fieldNames) {
return fieldNames.stream().map(colName -> new String[] {colName}).toArray(String[][]::new);
}

@Override
protected boolean getAutoIncrementInfo(ResultSet resultSet) throws SQLException {
return "YES".equalsIgnoreCase(resultSet.getString("IS_AUTOINCREMENT"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.datastrato.gravitino.rel.expressions.transforms.Transform;
import com.datastrato.gravitino.rel.indexes.Index;
import com.datastrato.gravitino.rel.types.Types;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
Expand All @@ -25,6 +26,7 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import javax.sql.DataSource;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.ArrayUtils;
Expand Down Expand Up @@ -77,6 +79,7 @@ protected String generateCreateTableSql(
sqlBuilder.append(",\n");
}
}
appendIndexesSql(indexes, sqlBuilder);
sqlBuilder.append("\n)");
// Add table properties if any
if (MapUtils.isNotEmpty(properties)) {
Expand Down Expand Up @@ -115,6 +118,40 @@ protected String generateCreateTableSql(
return result;
}

@VisibleForTesting
public static void appendIndexesSql(Index[] indexes, StringBuilder sqlBuilder) {
for (Index index : indexes) {
String fieldStr =
Arrays.stream(index.fieldNames())
.map(
colNames -> {
if (colNames.length > 1) {
throw new IllegalArgumentException(
"Index does not support complex fields in PostgreSQL");
}
return PG_QUOTE + colNames[0] + PG_QUOTE;
})
.collect(Collectors.joining(", "));
sqlBuilder.append(",\n");
switch (index.type()) {
case PRIMARY_KEY:
if (StringUtils.isNotEmpty(index.name())) {
sqlBuilder.append("CONSTRAINT ").append(PG_QUOTE).append(index.name()).append(PG_QUOTE);
}
sqlBuilder.append(" PRIMARY KEY (").append(fieldStr).append(")");
break;
case UNIQUE_KEY:
if (StringUtils.isNotEmpty(index.name())) {
sqlBuilder.append("CONSTRAINT ").append(PG_QUOTE).append(index.name()).append(PG_QUOTE);
}
sqlBuilder.append(" UNIQUE (").append(fieldStr).append(")");
break;
default:
throw new IllegalArgumentException("PostgreSQL doesn't support index : " + index.type());
}
}
}

private StringBuilder appendColumnDefinition(JdbcColumn column, StringBuilder sqlBuilder) {
// Add data type
sqlBuilder
Expand Down Expand Up @@ -377,6 +414,18 @@ private String updateColumnCommentFieldDefinition(
return "COMMENT ON COLUMN " + tableName + "." + col + " IS '" + newComment + "';";
}

@Override
protected ResultSet getIndexInfo(String schemaName, String tableName, DatabaseMetaData metaData)
throws SQLException {
return metaData.getIndexInfo(database, schemaName, tableName, false, false);
}

@Override
protected ResultSet getPrimaryKeys(String schemaName, String tableName, DatabaseMetaData metaData)
throws SQLException {
return metaData.getPrimaryKeys(database, schemaName, tableName);
}

@Override
protected Connection getConnection(String schema) throws SQLException {
Connection connection = dataSource.getConnection();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
*/
package com.datastrato.gravitino.integration.test.catalog.jdbc.postgresql;

import static org.junit.jupiter.api.Assertions.assertThrows;

import com.datastrato.gravitino.Catalog;
import com.datastrato.gravitino.NameIdentifier;
import com.datastrato.gravitino.Namespace;
Expand All @@ -29,6 +31,8 @@
import com.datastrato.gravitino.rel.expressions.sorts.SortOrder;
import com.datastrato.gravitino.rel.expressions.transforms.Transform;
import com.datastrato.gravitino.rel.expressions.transforms.Transforms;
import com.datastrato.gravitino.rel.indexes.Index;
import com.datastrato.gravitino.rel.indexes.Indexes;
import com.datastrato.gravitino.rel.types.Types;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
Expand Down Expand Up @@ -552,4 +556,120 @@ void testCreateAndLoadSchema() {
Assertions.assertEquals("anonymous", schema.auditInfo().creator());
Assertions.assertTrue(StringUtils.isEmpty(schema.comment()));
}

@Test
void testCreateIndexTable() {
Column col1 = Column.of("col_1", Types.LongType.get(), "id", false, false, null);
Column col2 = Column.of("col_2", Types.VarCharType.of(100), "yes", false, false, null);
Column col3 = Column.of("col_3", Types.DateType.get(), "comment", false, false, null);
Column col4 = Column.of("col_4", Types.VarCharType.of(255), "code", false, false, null);
Column col5 = Column.of("col_5", Types.VarCharType.of(255), "config", false, false, null);
Column[] newColumns = new Column[] {col1, col2, col3, col4, col5};

Index[] indexes =
new Index[] {
Indexes.primary("k1_pk", new String[][] {{"col_1"}, {"col_2"}}),
Indexes.unique("u1_key", new String[][] {{"col_2"}, {"col_3"}}),
Indexes.unique("u2_key", new String[][] {{"col_3"}, {"col_4"}}),
Indexes.unique("u3_key", new String[][] {{"col_5"}, {"col_4"}}),
Indexes.unique("u4_key", new String[][] {{"col_2"}, {"col_3"}, {"col_4"}}),
Indexes.unique("u5_key", new String[][] {{"col_2"}, {"col_3"}, {"col_5"}}),
Indexes.unique("u6_key", new String[][] {{"col_1"}, {"col_2"}, {"col_3"}, {"col_4"}}),
};

NameIdentifier tableIdentifier =
NameIdentifier.of(metalakeName, catalogName, schemaName, tableName);

Map<String, String> properties = createProperties();
TableCatalog tableCatalog = catalog.asTableCatalog();
Table createdTable =
tableCatalog.createTable(
tableIdentifier,
newColumns,
table_comment,
properties,
Transforms.EMPTY_TRANSFORM,
Distributions.NONE,
new SortOrder[0],
indexes);
assertionsTableInfo(
tableName, table_comment, Arrays.asList(newColumns), properties, indexes, createdTable);
Table table = tableCatalog.loadTable(tableIdentifier);
assertionsTableInfo(
tableName, table_comment, Arrays.asList(newColumns), properties, indexes, table);

IllegalArgumentException illegalArgumentException =
assertThrows(
IllegalArgumentException.class,
() -> {
tableCatalog.createTable(
NameIdentifier.of(metalakeName, catalogName, schemaName, "test_failed"),
newColumns,
table_comment,
properties,
Transforms.EMPTY_TRANSFORM,
Distributions.NONE,
new SortOrder[0],
new Index[] {Indexes.createMysqlPrimaryKey(new String[][] {{"col_1", "col_2"}})});
});
Assertions.assertTrue(
StringUtils.contains(
illegalArgumentException.getMessage(),
"Index does not support complex fields in PostgreSQL"));

illegalArgumentException =
assertThrows(
IllegalArgumentException.class,
() -> {
tableCatalog.createTable(
NameIdentifier.of(metalakeName, catalogName, schemaName, "test_failed"),
newColumns,
table_comment,
properties,
Transforms.EMPTY_TRANSFORM,
Distributions.NONE,
new SortOrder[0],
new Index[] {Indexes.unique("u1_key", new String[][] {{"col_2", "col_3"}})});
});
Assertions.assertTrue(
StringUtils.contains(
illegalArgumentException.getMessage(),
"Index does not support complex fields in PostgreSQL"));

table =
tableCatalog.createTable(
NameIdentifier.of(metalakeName, catalogName, schemaName, "test_null_key"),
newColumns,
table_comment,
properties,
Transforms.EMPTY_TRANSFORM,
Distributions.NONE,
new SortOrder[0],
new Index[] {
Indexes.of(
Index.IndexType.UNIQUE_KEY,
null,
new String[][] {{"col_1"}, {"col_3"}, {"col_4"}}),
Indexes.of(Index.IndexType.UNIQUE_KEY, null, new String[][] {{"col_4"}}),
});
Assertions.assertEquals(2, table.index().length);
Assertions.assertNotNull(table.index()[0].name());
Assertions.assertNotNull(table.index()[1].name());

table =
tableCatalog.createTable(
NameIdentifier.of(metalakeName, catalogName, schemaName, "many_index"),
newColumns,
table_comment,
properties,
Transforms.EMPTY_TRANSFORM,
Distributions.NONE,
new SortOrder[0],
new Index[] {
Indexes.unique("u4_key_2", new String[][] {{"col_2"}, {"col_3"}, {"col_4"}}),
Indexes.unique("u5_key_3", new String[][] {{"col_2"}, {"col_3"}, {"col_4"}}),
});
Assertions.assertEquals(1, table.index().length);
Assertions.assertEquals("u4_key_2", table.index()[0].name());
}
}
Loading

0 comments on commit 5ab650b

Please sign in to comment.