Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#1736] feat(postgresql): Support PostgreSQL index. #1796

Merged
merged 2 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright 2024 Datastrato Pvt Ltd.
* This software is licensed under the Apache License version 2.
*/
package com.datastrato.gravitino.catalog.jdbc.bean;

import com.datastrato.gravitino.rel.indexes.Index;
import java.util.Objects;

/** Store JDBC index information. */
public class JdbcIndexBean {
Clearvive marked this conversation as resolved.
Show resolved Hide resolved

private final Index.IndexType indexType;

private final String colName;

private final String name;

/** Used for sorting */
private final int order;

public JdbcIndexBean(Index.IndexType indexType, String colName, String name, int order) {
this.indexType = indexType;
this.colName = colName;
this.name = name;
this.order = order;
}

public Index.IndexType getIndexType() {
return indexType;
}

public String getColName() {
return colName;
}

public String getName() {
return name;
}

public int getOrder() {
return order;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
JdbcIndexBean that = (JdbcIndexBean) o;
return order == that.order
&& indexType == that.indexType
&& Objects.equals(colName, that.colName)
&& Objects.equals(name, that.name);
}

@Override
public int hashCode() {
return Objects.hash(indexType, colName, name, order);
}

@Override
public String toString() {
return "JdbcIndexBean{"
+ "indexType="
+ indexType
+ ", colName='"
+ colName
+ '\''
+ ", name='"
+ name
+ '\''
+ ", order="
+ order
+ '}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import com.datastrato.gravitino.catalog.jdbc.JdbcColumn;
import com.datastrato.gravitino.catalog.jdbc.JdbcTable;
import com.datastrato.gravitino.catalog.jdbc.bean.JdbcIndexBean;
import com.datastrato.gravitino.catalog.jdbc.converter.JdbcExceptionConverter;
import com.datastrato.gravitino.catalog.jdbc.converter.JdbcTypeConverter;
import com.datastrato.gravitino.catalog.jdbc.utils.JdbcConnectorUtils;
Expand All @@ -16,15 +17,19 @@
import com.datastrato.gravitino.rel.TableChange;
import com.datastrato.gravitino.rel.expressions.transforms.Transform;
import com.datastrato.gravitino.rel.indexes.Index;
import com.datastrato.gravitino.rel.indexes.Indexes;
import com.google.common.collect.Lists;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import javax.sql.DataSource;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
Expand Down Expand Up @@ -239,7 +244,75 @@ protected void correctJdbcTableFields(

protected List<Index> getIndexes(String databaseName, String tableName, DatabaseMetaData metaData)
throws SQLException {
return Collections.emptyList();
List<Index> indexes = new ArrayList<>();

// Get primary key information
ResultSet primaryKeys = getPrimaryKeys(databaseName, tableName, metaData);
List<JdbcIndexBean> jdbcIndexBeans = new ArrayList<>();
while (primaryKeys.next()) {
jdbcIndexBeans.add(
new JdbcIndexBean(
Index.IndexType.PRIMARY_KEY,
primaryKeys.getString("COLUMN_NAME"),
primaryKeys.getString("PK_NAME"),
primaryKeys.getInt("KEY_SEQ")));
}

Set<String> primaryIndexNames =
jdbcIndexBeans.stream().map(JdbcIndexBean::getName).collect(Collectors.toSet());

// Get unique key information
ResultSet indexInfo = getIndexInfo(databaseName, tableName, metaData);
while (indexInfo.next()) {
String indexName = indexInfo.getString("INDEX_NAME");
if (!indexInfo.getBoolean("NON_UNIQUE") && !primaryIndexNames.contains(indexName)) {
Clearvive marked this conversation as resolved.
Show resolved Hide resolved
jdbcIndexBeans.add(
new JdbcIndexBean(
Index.IndexType.UNIQUE_KEY,
indexInfo.getString("COLUMN_NAME"),
indexName,
indexInfo.getInt("ORDINAL_POSITION")));
}
}

// Assemble into Index
Map<Index.IndexType, List<JdbcIndexBean>> indexBeanGroupByIndexType =
jdbcIndexBeans.stream().collect(Collectors.groupingBy(JdbcIndexBean::getIndexType));

for (Map.Entry<Index.IndexType, List<JdbcIndexBean>> entry :
indexBeanGroupByIndexType.entrySet()) {
// Group by index Name
Map<String, List<JdbcIndexBean>> indexBeanGroupByName =
entry.getValue().stream().collect(Collectors.groupingBy(JdbcIndexBean::getName));
for (Map.Entry<String, List<JdbcIndexBean>> indexEntry : indexBeanGroupByName.entrySet()) {
List<String> colNames =
indexEntry.getValue().stream()
.sorted(Comparator.comparingInt(JdbcIndexBean::getOrder))
.map(JdbcIndexBean::getColName)
.collect(Collectors.toList());
String[][] colStrArrays = convertIndexFieldNames(colNames);
if (entry.getKey() == Index.IndexType.PRIMARY_KEY) {
indexes.add(Indexes.primary(indexEntry.getKey(), colStrArrays));
} else {
indexes.add(Indexes.unique(indexEntry.getKey(), colStrArrays));
}
}
}
return indexes;
}

protected ResultSet getIndexInfo(String databaseName, String tableName, DatabaseMetaData metaData)
throws SQLException {
return metaData.getIndexInfo(databaseName, null, tableName, false, false);
}

protected ResultSet getPrimaryKeys(
String databaseName, String tableName, DatabaseMetaData metaData) throws SQLException {
return metaData.getPrimaryKeys(databaseName, null, tableName);
}

protected String[][] convertIndexFieldNames(List<String> fieldNames) {
return fieldNames.stream().map(colName -> new String[] {colName}).toArray(String[][]::new);
}

protected abstract String generateCreateTableSql(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@
import com.datastrato.gravitino.rel.indexes.Index;
import com.datastrato.gravitino.rel.indexes.Indexes;
import com.google.common.base.Preconditions;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.SetMultimap;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
Expand All @@ -35,7 +32,6 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
Expand Down Expand Up @@ -215,43 +211,6 @@ public static void appendIndexesSql(Index[] indexes, StringBuilder sqlBuilder) {
}
}

protected List<Index> getIndexes(String databaseName, String tableName, DatabaseMetaData metaData)
throws SQLException {
List<Index> indexes = new ArrayList<>();

// Get primary key information
SetMultimap<String, String> primaryKeyGroupByName = HashMultimap.create();
ResultSet primaryKeys = metaData.getPrimaryKeys(databaseName, null, tableName);
while (primaryKeys.next()) {
String columnName = primaryKeys.getString("COLUMN_NAME");
primaryKeyGroupByName.put(primaryKeys.getString("PK_NAME"), columnName);
}
for (String key : primaryKeyGroupByName.keySet()) {
indexes.add(Indexes.primary(key, convertIndexFieldNames(primaryKeyGroupByName.get(key))));
}

// Get unique key information
SetMultimap<String, String> indexGroupByName = HashMultimap.create();
ResultSet indexInfo = metaData.getIndexInfo(databaseName, null, tableName, false, false);
while (indexInfo.next()) {
String indexName = indexInfo.getString("INDEX_NAME");
if (!indexInfo.getBoolean("NON_UNIQUE")
&& !StringUtils.equalsIgnoreCase(Indexes.DEFAULT_MYSQL_PRIMARY_KEY_NAME, indexName)) {
String columnName = indexInfo.getString("COLUMN_NAME");
indexGroupByName.put(indexName, columnName);
}
}
for (String key : indexGroupByName.keySet()) {
indexes.add(Indexes.unique(key, convertIndexFieldNames(indexGroupByName.get(key))));
}

return indexes;
}

private String[][] convertIndexFieldNames(Set<String> fieldNames) {
return fieldNames.stream().map(colName -> new String[] {colName}).toArray(String[][]::new);
}

@Override
protected boolean getAutoIncrementInfo(ResultSet resultSet) throws SQLException {
return "YES".equalsIgnoreCase(resultSet.getString("IS_AUTOINCREMENT"));
Expand Down
6 changes: 6 additions & 0 deletions catalogs/catalog-jdbc-postgresql/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ dependencies {
implementation(libs.commons.lang3)
implementation(libs.commons.collections4)
implementation(libs.jsqlparser)

testImplementation(libs.guava)
testImplementation(libs.commons.lang3)
testImplementation(libs.junit.jupiter.api)
testImplementation(libs.junit.jupiter.params)
testRuntimeOnly(libs.junit.jupiter.engine)
}

tasks {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.datastrato.gravitino.rel.expressions.transforms.Transform;
import com.datastrato.gravitino.rel.indexes.Index;
import com.datastrato.gravitino.rel.types.Types;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
Expand All @@ -25,6 +26,7 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import javax.sql.DataSource;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.ArrayUtils;
Expand Down Expand Up @@ -82,6 +84,7 @@ protected String generateCreateTableSql(
sqlBuilder.append(",\n");
}
}
appendIndexesSql(indexes, sqlBuilder);
sqlBuilder.append("\n)");
// Add table properties if any
if (MapUtils.isNotEmpty(properties)) {
Expand Down Expand Up @@ -120,6 +123,40 @@ protected String generateCreateTableSql(
return result;
}

@VisibleForTesting
static void appendIndexesSql(Index[] indexes, StringBuilder sqlBuilder) {
for (Index index : indexes) {
String fieldStr =
Arrays.stream(index.fieldNames())
.map(
colNames -> {
if (colNames.length > 1) {
throw new IllegalArgumentException(
"Index does not support complex fields in PostgreSQL");
}
return PG_QUOTE + colNames[0] + PG_QUOTE;
})
.collect(Collectors.joining(", "));
sqlBuilder.append(",\n");
switch (index.type()) {
case PRIMARY_KEY:
if (StringUtils.isNotEmpty(index.name())) {
sqlBuilder.append("CONSTRAINT ").append(PG_QUOTE).append(index.name()).append(PG_QUOTE);
}
sqlBuilder.append(" PRIMARY KEY (").append(fieldStr).append(")");
break;
case UNIQUE_KEY:
if (StringUtils.isNotEmpty(index.name())) {
sqlBuilder.append("CONSTRAINT ").append(PG_QUOTE).append(index.name()).append(PG_QUOTE);
}
sqlBuilder.append(" UNIQUE (").append(fieldStr).append(")");
break;
default:
throw new IllegalArgumentException("PostgreSQL doesn't support index : " + index.type());
}
}
}

private StringBuilder appendColumnDefinition(JdbcColumn column, StringBuilder sqlBuilder) {
// Add data type
sqlBuilder
Expand Down Expand Up @@ -440,6 +477,18 @@ private String updateColumnCommentFieldDefinition(
+ "';";
}

@Override
protected ResultSet getIndexInfo(String schemaName, String tableName, DatabaseMetaData metaData)
throws SQLException {
return metaData.getIndexInfo(database, schemaName, tableName, false, false);
}

@Override
protected ResultSet getPrimaryKeys(String schemaName, String tableName, DatabaseMetaData metaData)
throws SQLException {
return metaData.getPrimaryKeys(database, schemaName, tableName);
}

@Override
protected Connection getConnection(String schema) throws SQLException {
Connection connection = dataSource.getConnection();
Expand Down
Loading
Loading