Skip to content

Commit

Permalink
for apache#2900, refactor EncryptRuntimeContext
Browse files Browse the repository at this point in the history
  • Loading branch information
terrymanu committed Aug 22, 2019
1 parent 876d008 commit d34a2ab
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public final class EncryptInsertOptimizeEngine implements EncryptOptimizeEngine<
public EncryptInsertOptimizedStatement optimize(final EncryptRule encryptRule, final TableMetas tableMetas, final String sql, final List<Object> parameters, final InsertStatement sqlStatement) {
InsertValueEngine insertValueEngine = new InsertValueEngine();
EncryptInsertOptimizedStatement result = new EncryptInsertOptimizedStatement(
sqlStatement, new EncryptInsertColumns(encryptRule, tableMetas, sqlStatement), insertValueEngine.createInsertValues(sqlStatement));
sqlStatement, new EncryptInsertColumns(tableMetas, sqlStatement), insertValueEngine.createInsertValues(sqlStatement));
int derivedColumnsCount = encryptRule.getAssistedQueryAndPlainColumnCount(sqlStatement.getTable().getTableName());
int parametersCount = 0;
Collection<String> columnNames = getColumnNames(tableMetas, sqlStatement);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.apache.shardingsphere.core.metadata.table.TableMetas;
import org.apache.shardingsphere.core.optimize.api.segment.InsertColumns;
import org.apache.shardingsphere.core.parse.sql.statement.dml.InsertStatement;
import org.apache.shardingsphere.core.rule.EncryptRule;

import java.util.Collection;
import java.util.LinkedHashSet;
Expand All @@ -33,32 +32,13 @@
* @author zhangliang
* @author panjuan
*/
@Getter
@ToString
public final class EncryptInsertColumns implements InsertColumns {

private final Collection<String> assistedQueryAndPlainColumnNames;

@Getter
private final Collection<String> regularColumnNames;

public EncryptInsertColumns(final EncryptRule encryptRule, final TableMetas tableMetas, final InsertStatement insertStatement) {
assistedQueryAndPlainColumnNames = encryptRule.getAssistedQueryAndPlainColumns(insertStatement.getTable().getTableName());
regularColumnNames = insertStatement.useDefaultColumns() ? getRegularColumnNamesFromMetaData(encryptRule, tableMetas, insertStatement) : insertStatement.getColumnNames();
}

private Collection<String> getRegularColumnNamesFromMetaData(final EncryptRule encryptRule, final TableMetas tableMetas, final InsertStatement insertStatement) {
Collection<String> allColumnNames = tableMetas.getAllColumnNames(insertStatement.getTable().getTableName());
Collection<String> result = new LinkedHashSet<>(allColumnNames.size() - assistedQueryAndPlainColumnNames.size());
String tableName = insertStatement.getTable().getTableName();
for (String each : allColumnNames) {
if (encryptRule.getCipherColumns(tableName).contains(each)) {
result.add(encryptRule.getLogicColumn(tableName, each));
continue;
}
if (!assistedQueryAndPlainColumnNames.contains(each)) {
result.add(each);
}
}
return result;
public EncryptInsertColumns(final TableMetas tableMetas, final InsertStatement insertStatement) {
regularColumnNames = insertStatement.useDefaultColumns() ? new LinkedHashSet<>(tableMetas.getAllColumnNames(insertStatement.getTable().getTableName())) : insertStatement.getColumnNames();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.base.Optional;
import lombok.Getter;
import org.apache.shardingsphere.core.metadata.column.ColumnMetaData;
import org.apache.shardingsphere.core.metadata.column.EncryptColumnMetaData;
import org.apache.shardingsphere.core.metadata.table.TableMetaData;
import org.apache.shardingsphere.core.metadata.table.TableMetas;
import org.apache.shardingsphere.core.rule.EncryptRule;
Expand All @@ -46,19 +47,25 @@
@Getter
public final class EncryptRuntimeContext extends AbstractRuntimeContext<EncryptRule> {

private static final String COLUMN_NAME = "COLUMN_NAME";

private static final String TYPE_NAME = "TYPE_NAME";

private static final String INDEX_NAME = "INDEX_NAME";

private final TableMetas tableMetas;

public EncryptRuntimeContext(final DataSource dataSource, final EncryptRule rule, final Properties props, final DatabaseType databaseType) throws SQLException {
super(rule, props, databaseType);
tableMetas = createEncryptTableMetas(dataSource, rule);
public EncryptRuntimeContext(final DataSource dataSource, final EncryptRule encryptRule, final Properties props, final DatabaseType databaseType) throws SQLException {
super(encryptRule, props, databaseType);
tableMetas = createEncryptTableMetas(dataSource, encryptRule);
}

private TableMetas createEncryptTableMetas(final DataSource dataSource, final EncryptRule encryptRule) throws SQLException {
Map<String, TableMetaData> tables = new LinkedHashMap<>();
try (Connection connection = dataSource.getConnection()) {
for (String each : encryptRule.getEncryptTableNames()) {
if (isTableExist(connection, each)) {
tables.put(each, new TableMetaData(getColumnMetaDataList(connection, each), getLogicIndexes(connection, each)));
tables.put(each, new TableMetaData(getColumnMetaDataList(connection, each, encryptRule), getIndexes(connection, each)));
}
}
}
Expand All @@ -71,47 +78,55 @@ private boolean isTableExist(final Connection connection, final String tableName
}
}

private List<ColumnMetaData> getColumnMetaDataList(final Connection connection, final String tableName) throws SQLException {
private List<ColumnMetaData> getColumnMetaDataList(final Connection connection, final String tableName, final EncryptRule encryptRule) throws SQLException {
List<ColumnMetaData> result = new LinkedList<>();
Collection<String> primaryKeys = getPrimaryKeys(connection, tableName);
Collection<String> derivedColumns = encryptRule.getAssistedQueryAndPlainColumns(tableName);
try (ResultSet resultSet = connection.getMetaData().getColumns(connection.getCatalog(), null, tableName, "%")) {
while (resultSet.next()) {
String columnName = resultSet.getString("COLUMN_NAME");
String columnType = resultSet.getString("TYPE_NAME");
result.add(new ColumnMetaData(columnName, columnType, primaryKeys.contains(columnName)));
String columnName = resultSet.getString(COLUMN_NAME);
String columnType = resultSet.getString(TYPE_NAME);
boolean isPrimaryKey = primaryKeys.contains(columnName);
Optional<ColumnMetaData> columnMetaData = getColumnMetaData(tableName, columnName, columnType, isPrimaryKey, encryptRule, derivedColumns);
if (columnMetaData.isPresent()) {
result.add(columnMetaData.get());
}
}
}
return result;
}

private Optional<ColumnMetaData> getColumnMetaData(final String logicTableName, final String columnName, final String columnType, final boolean isPrimaryKey,
final EncryptRule encryptRule, final Collection<String> derivedColumns) {
if (derivedColumns.contains(columnName)) {
return Optional.absent();
}
if (encryptRule.isCipherColumn(logicTableName, columnName)) {
String logicColumnName = encryptRule.getLogicColumn(logicTableName, columnName);
String plainColumnName = encryptRule.getPlainColumn(logicTableName, logicColumnName).orNull();
String assistedQueryColumnName = encryptRule.getAssistedQueryColumn(logicTableName, logicColumnName).orNull();
return Optional.<ColumnMetaData>of(new EncryptColumnMetaData(logicColumnName, columnType, isPrimaryKey, columnName, plainColumnName, assistedQueryColumnName));
}
return Optional.of(new ColumnMetaData(columnName, columnType, isPrimaryKey));
}

private Collection<String> getPrimaryKeys(final Connection connection, final String tableName) throws SQLException {
Collection<String> result = new HashSet<>();
try (ResultSet resultSet = connection.getMetaData().getPrimaryKeys(connection.getCatalog(), null, tableName)) {
while (resultSet.next()) {
result.add(resultSet.getString("COLUMN_NAME"));
result.add(resultSet.getString(COLUMN_NAME));
}
}
return result;
}

private Set<String> getLogicIndexes(final Connection connection, final String actualTableName) throws SQLException {
private Set<String> getIndexes(final Connection connection, final String tableName) throws SQLException {
Set<String> result = new HashSet<>();
try (ResultSet resultSet = connection.getMetaData().getIndexInfo(connection.getCatalog(), connection.getCatalog(), actualTableName, false, false)) {
try (ResultSet resultSet = connection.getMetaData().getIndexInfo(connection.getCatalog(), connection.getCatalog(), tableName, false, false)) {
while (resultSet.next()) {
Optional<String> logicIndex = getLogicIndex(resultSet.getString("INDEX_NAME"), actualTableName);
if (logicIndex.isPresent()) {
result.add(logicIndex.get());
}
result.add(resultSet.getString(INDEX_NAME));
}
}
return result;
}

private Optional<String> getLogicIndex(final String actualIndexName, final String actualTableName) {
String indexNameSuffix = "_" + actualTableName;
if (actualIndexName.contains(indexNameSuffix)) {
return Optional.of(actualIndexName.replace(indexNameSuffix, ""));
}
return Optional.absent();
}
}

0 comments on commit d34a2ab

Please sign in to comment.