Skip to content

Commit

Permalink
[#1736] feat(postgresql): Support PostgreSQL index.
Browse files Browse the repository at this point in the history
  • Loading branch information
Clearvive authored and Clearvive committed Jan 31, 2024
1 parent 347f5d0 commit 204ae9a
Show file tree
Hide file tree
Showing 10 changed files with 489 additions and 48 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright 2024 Datastrato Pvt Ltd.
* This software is licensed under the Apache License version 2.
*/
package com.datastrato.gravitino.catalog.jdbc.bean;

import com.datastrato.gravitino.rel.indexes.Index;
import java.util.Objects;

/** Store JDBC index information. */
public class JdbcIndexBean {

private final Index.IndexType indexType;

private final String colName;

private final String name;

/** Used for sorting */
private final int order;

public JdbcIndexBean(Index.IndexType indexType, String colName, String name, int order) {
this.indexType = indexType;
this.colName = colName;
this.name = name;
this.order = order;
}

public Index.IndexType getIndexType() {
return indexType;
}

public String getColName() {
return colName;
}

public String getName() {
return name;
}

public int getOrder() {
return order;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
JdbcIndexBean that = (JdbcIndexBean) o;
return order == that.order
&& indexType == that.indexType
&& Objects.equals(colName, that.colName)
&& Objects.equals(name, that.name);
}

@Override
public int hashCode() {
return Objects.hash(indexType, colName, name, order);
}

@Override
public String toString() {
return "JdbcIndexBean{"
+ "indexType="
+ indexType
+ ", colName='"
+ colName
+ '\''
+ ", name='"
+ name
+ '\''
+ ", order="
+ order
+ '}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import com.datastrato.gravitino.catalog.jdbc.JdbcColumn;
import com.datastrato.gravitino.catalog.jdbc.JdbcTable;
import com.datastrato.gravitino.catalog.jdbc.bean.JdbcIndexBean;
import com.datastrato.gravitino.catalog.jdbc.converter.JdbcExceptionConverter;
import com.datastrato.gravitino.catalog.jdbc.converter.JdbcTypeConverter;
import com.datastrato.gravitino.catalog.jdbc.utils.JdbcConnectorUtils;
Expand All @@ -16,15 +17,19 @@
import com.datastrato.gravitino.rel.TableChange;
import com.datastrato.gravitino.rel.expressions.transforms.Transform;
import com.datastrato.gravitino.rel.indexes.Index;
import com.datastrato.gravitino.rel.indexes.Indexes;
import com.google.common.collect.Lists;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import javax.sql.DataSource;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
Expand Down Expand Up @@ -239,7 +244,75 @@ protected void correctJdbcTableFields(

protected List<Index> getIndexes(String databaseName, String tableName, DatabaseMetaData metaData)
throws SQLException {
return Collections.emptyList();
List<Index> indexes = new ArrayList<>();

// Get primary key information
ResultSet primaryKeys = getPrimaryKeys(databaseName, tableName, metaData);
List<JdbcIndexBean> jdbcIndexBeans = new ArrayList<>();
while (primaryKeys.next()) {
jdbcIndexBeans.add(
new JdbcIndexBean(
Index.IndexType.PRIMARY_KEY,
primaryKeys.getString("COLUMN_NAME"),
primaryKeys.getString("PK_NAME"),
primaryKeys.getInt("KEY_SEQ")));
}

Set<String> primaryIndexNames =
jdbcIndexBeans.stream().map(JdbcIndexBean::getName).collect(Collectors.toSet());

// Get unique key information
ResultSet indexInfo = getIndexInfo(databaseName, tableName, metaData);
while (indexInfo.next()) {
String indexName = indexInfo.getString("INDEX_NAME");
if (!indexInfo.getBoolean("NON_UNIQUE") && !primaryIndexNames.contains(indexName)) {
jdbcIndexBeans.add(
new JdbcIndexBean(
Index.IndexType.UNIQUE_KEY,
indexInfo.getString("COLUMN_NAME"),
indexName,
indexInfo.getInt("ORDINAL_POSITION")));
}
}

// Assemble into Index
Map<Index.IndexType, List<JdbcIndexBean>> indexBeanGroupByIndexType =
jdbcIndexBeans.stream().collect(Collectors.groupingBy(JdbcIndexBean::getIndexType));

for (Map.Entry<Index.IndexType, List<JdbcIndexBean>> entry :
indexBeanGroupByIndexType.entrySet()) {
// Group by index Name
Map<String, List<JdbcIndexBean>> indexBeanGroupByName =
entry.getValue().stream().collect(Collectors.groupingBy(JdbcIndexBean::getName));
for (Map.Entry<String, List<JdbcIndexBean>> indexEntry : indexBeanGroupByName.entrySet()) {
List<String> colNames =
indexEntry.getValue().stream()
.sorted(Comparator.comparingInt(JdbcIndexBean::getOrder))
.map(JdbcIndexBean::getColName)
.collect(Collectors.toList());
String[][] colStrArrays = convertIndexFieldNames(colNames);
if (entry.getKey() == Index.IndexType.PRIMARY_KEY) {
indexes.add(Indexes.primary(indexEntry.getKey(), colStrArrays));
} else {
indexes.add(Indexes.unique(indexEntry.getKey(), colStrArrays));
}
}
}
return indexes;
}

protected ResultSet getIndexInfo(String databaseName, String tableName, DatabaseMetaData metaData)
throws SQLException {
return metaData.getIndexInfo(databaseName, null, tableName, false, false);
}

protected ResultSet getPrimaryKeys(
String databaseName, String tableName, DatabaseMetaData metaData) throws SQLException {
return metaData.getPrimaryKeys(databaseName, null, tableName);
}

protected String[][] convertIndexFieldNames(List<String> fieldNames) {
return fieldNames.stream().map(colName -> new String[] {colName}).toArray(String[][]::new);
}

protected abstract String generateCreateTableSql(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@
import com.datastrato.gravitino.rel.indexes.Index;
import com.datastrato.gravitino.rel.indexes.Indexes;
import com.google.common.base.Preconditions;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.SetMultimap;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
Expand All @@ -35,7 +32,6 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
Expand Down Expand Up @@ -215,43 +211,6 @@ public static void appendIndexesSql(Index[] indexes, StringBuilder sqlBuilder) {
}
}

protected List<Index> getIndexes(String databaseName, String tableName, DatabaseMetaData metaData)
throws SQLException {
List<Index> indexes = new ArrayList<>();

// Get primary key information
SetMultimap<String, String> primaryKeyGroupByName = HashMultimap.create();
ResultSet primaryKeys = metaData.getPrimaryKeys(databaseName, null, tableName);
while (primaryKeys.next()) {
String columnName = primaryKeys.getString("COLUMN_NAME");
primaryKeyGroupByName.put(primaryKeys.getString("PK_NAME"), columnName);
}
for (String key : primaryKeyGroupByName.keySet()) {
indexes.add(Indexes.primary(key, convertIndexFieldNames(primaryKeyGroupByName.get(key))));
}

// Get unique key information
SetMultimap<String, String> indexGroupByName = HashMultimap.create();
ResultSet indexInfo = metaData.getIndexInfo(databaseName, null, tableName, false, false);
while (indexInfo.next()) {
String indexName = indexInfo.getString("INDEX_NAME");
if (!indexInfo.getBoolean("NON_UNIQUE")
&& !StringUtils.equalsIgnoreCase(Indexes.DEFAULT_MYSQL_PRIMARY_KEY_NAME, indexName)) {
String columnName = indexInfo.getString("COLUMN_NAME");
indexGroupByName.put(indexName, columnName);
}
}
for (String key : indexGroupByName.keySet()) {
indexes.add(Indexes.unique(key, convertIndexFieldNames(indexGroupByName.get(key))));
}

return indexes;
}

private String[][] convertIndexFieldNames(Set<String> fieldNames) {
return fieldNames.stream().map(colName -> new String[] {colName}).toArray(String[][]::new);
}

@Override
protected boolean getAutoIncrementInfo(ResultSet resultSet) throws SQLException {
return "YES".equalsIgnoreCase(resultSet.getString("IS_AUTOINCREMENT"));
Expand Down
6 changes: 6 additions & 0 deletions catalogs/catalog-jdbc-postgresql/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ dependencies {
implementation(libs.commons.lang3)
implementation(libs.commons.collections4)
implementation(libs.jsqlparser)

testImplementation(libs.guava)
testImplementation(libs.commons.lang3)
testImplementation(libs.junit.jupiter.api)
testImplementation(libs.junit.jupiter.params)
testRuntimeOnly(libs.junit.jupiter.engine)
}

tasks {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.datastrato.gravitino.rel.expressions.transforms.Transform;
import com.datastrato.gravitino.rel.indexes.Index;
import com.datastrato.gravitino.rel.types.Types;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
Expand All @@ -25,6 +26,7 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import javax.sql.DataSource;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.ArrayUtils;
Expand Down Expand Up @@ -82,6 +84,7 @@ protected String generateCreateTableSql(
sqlBuilder.append(",\n");
}
}
appendIndexesSql(indexes, sqlBuilder);
sqlBuilder.append("\n)");
// Add table properties if any
if (MapUtils.isNotEmpty(properties)) {
Expand Down Expand Up @@ -120,6 +123,40 @@ protected String generateCreateTableSql(
return result;
}

@VisibleForTesting
static void appendIndexesSql(Index[] indexes, StringBuilder sqlBuilder) {
for (Index index : indexes) {
String fieldStr =
Arrays.stream(index.fieldNames())
.map(
colNames -> {
if (colNames.length > 1) {
throw new IllegalArgumentException(
"Index does not support complex fields in PostgreSQL");
}
return PG_QUOTE + colNames[0] + PG_QUOTE;
})
.collect(Collectors.joining(", "));
sqlBuilder.append(",\n");
switch (index.type()) {
case PRIMARY_KEY:
if (StringUtils.isNotEmpty(index.name())) {
sqlBuilder.append("CONSTRAINT ").append(PG_QUOTE).append(index.name()).append(PG_QUOTE);
}
sqlBuilder.append(" PRIMARY KEY (").append(fieldStr).append(")");
break;
case UNIQUE_KEY:
if (StringUtils.isNotEmpty(index.name())) {
sqlBuilder.append("CONSTRAINT ").append(PG_QUOTE).append(index.name()).append(PG_QUOTE);
}
sqlBuilder.append(" UNIQUE (").append(fieldStr).append(")");
break;
default:
throw new IllegalArgumentException("PostgreSQL doesn't support index : " + index.type());
}
}
}

private StringBuilder appendColumnDefinition(JdbcColumn column, StringBuilder sqlBuilder) {
// Add data type
sqlBuilder
Expand Down Expand Up @@ -440,6 +477,18 @@ private String updateColumnCommentFieldDefinition(
+ "';";
}

@Override
protected ResultSet getIndexInfo(String schemaName, String tableName, DatabaseMetaData metaData)
throws SQLException {
return metaData.getIndexInfo(database, schemaName, tableName, false, false);
}

@Override
protected ResultSet getPrimaryKeys(String schemaName, String tableName, DatabaseMetaData metaData)
throws SQLException {
return metaData.getPrimaryKeys(database, schemaName, tableName);
}

@Override
protected Connection getConnection(String schema) throws SQLException {
Connection connection = dataSource.getConnection();
Expand Down
Loading

0 comments on commit 204ae9a

Please sign in to comment.