diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/checker/sql/predicate/EncryptPredicateColumnSupportedChecker.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/checker/sql/predicate/EncryptPredicateColumnSupportedChecker.java index 00ae44a762f1c..bf010e28022da 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/checker/sql/predicate/EncryptPredicateColumnSupportedChecker.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/checker/sql/predicate/EncryptPredicateColumnSupportedChecker.java @@ -19,10 +19,12 @@ import org.apache.shardingsphere.encrypt.exception.metadata.MissingMatchedEncryptQueryAlgorithmException; import org.apache.shardingsphere.encrypt.rewrite.token.comparator.JoinConditionsEncryptorComparator; +import org.apache.shardingsphere.encrypt.rewrite.util.EncryptPredicateSegmentUtils; import org.apache.shardingsphere.encrypt.rule.EncryptRule; import org.apache.shardingsphere.encrypt.rule.table.EncryptTable; import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation; import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; +import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext; import org.apache.shardingsphere.infra.binder.context.type.TableAvailable; import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable; import org.apache.shardingsphere.infra.checker.SupportedSQLChecker; @@ -53,7 +55,9 @@ public boolean isCheck(final SQLStatementContext sqlStatementContext) { @Override public void check(final EncryptRule rule, final ShardingSphereSchema schema, final SQLStatementContext sqlStatementContext) { - ShardingSpherePreconditions.checkState(JoinConditionsEncryptorComparator.isSame(((WhereAvailable) sqlStatementContext).getJoinConditions(), rule), + Collection allSubqueryContexts = EncryptPredicateSegmentUtils.getAllSubqueryContexts(sqlStatementContext); + Collection joinConditions = EncryptPredicateSegmentUtils.getJoinConditions((WhereAvailable) sqlStatementContext, allSubqueryContexts); + ShardingSpherePreconditions.checkState(JoinConditionsEncryptorComparator.isSame(joinConditions, rule), () -> new UnsupportedSQLOperationException("Can not use different encryptor in join condition")); check(rule, schema, (WhereAvailable) sqlStatementContext); } diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/context/EncryptSQLRewriteContextDecorator.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/context/EncryptSQLRewriteContextDecorator.java index f117fc4e27797..52cbfb2e281ac 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/context/EncryptSQLRewriteContextDecorator.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/context/EncryptSQLRewriteContextDecorator.java @@ -22,10 +22,12 @@ import org.apache.shardingsphere.encrypt.rewrite.condition.EncryptConditionEngine; import org.apache.shardingsphere.encrypt.rewrite.parameter.EncryptParameterRewritersRegistry; import org.apache.shardingsphere.encrypt.rewrite.token.EncryptTokenGenerateBuilder; +import org.apache.shardingsphere.encrypt.rewrite.util.EncryptPredicateSegmentUtils; import org.apache.shardingsphere.encrypt.rule.EncryptRule; import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation; import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext; +import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext; import org.apache.shardingsphere.infra.binder.context.type.TableAvailable; import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable; import org.apache.shardingsphere.infra.config.props.ConfigurationProperties; @@ -88,8 +90,9 @@ private Collection doCreateEncryptConditions(final EncryptRule if (!(sqlStatementContext instanceof WhereAvailable)) { return Collections.emptyList(); } - Collection whereSegments = ((WhereAvailable) sqlStatementContext).getWhereSegments(); - Collection columnSegments = ((WhereAvailable) sqlStatementContext).getColumnSegments(); + Collection allSubqueryContexts = EncryptPredicateSegmentUtils.getAllSubqueryContexts(sqlStatementContext); + Collection whereSegments = EncryptPredicateSegmentUtils.getWhereSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts); + Collection columnSegments = EncryptPredicateSegmentUtils.getColumnSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts); return new EncryptConditionEngine(rule, sqlRewriteContext.getDatabase().getSchemas()).createEncryptConditions(whereSegments, columnSegments, sqlStatementContext, sqlRewriteContext.getDatabase().getName()); } diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/parameter/rewriter/EncryptPredicateParameterRewriter.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/parameter/rewriter/EncryptPredicateParameterRewriter.java index b4ca5e082c7cc..ffe18e38314e1 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/parameter/rewriter/EncryptPredicateParameterRewriter.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/parameter/rewriter/EncryptPredicateParameterRewriter.java @@ -25,6 +25,8 @@ import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn; import org.apache.shardingsphere.encrypt.rule.table.EncryptTable; import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; +import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext; +import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext; import org.apache.shardingsphere.infra.binder.context.type.TableAvailable; import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable; import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry; @@ -51,7 +53,25 @@ public final class EncryptPredicateParameterRewriter implements ParameterRewrite @Override public boolean isNeedRewrite(final SQLStatementContext sqlStatementContext) { - return sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty(); + if (sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty()) { + return true; + } + if (sqlStatementContext instanceof SelectStatementContext) { + return isSubqueryNeedRewrite((SelectStatementContext) sqlStatementContext); + } + if (sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()) { + return isSubqueryNeedRewrite(((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext()); + } + return false; + } + + private boolean isSubqueryNeedRewrite(final SelectStatementContext selectStatementContext) { + for (SelectStatementContext each : selectStatementContext.getSubqueryContexts().values()) { + if (isNeedRewrite(each)) { + return true; + } + } + return false; } @Override diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateColumnTokenGenerator.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateColumnTokenGenerator.java index 9b2bd6706f1ff..179d24a793253 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateColumnTokenGenerator.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateColumnTokenGenerator.java @@ -20,6 +20,7 @@ import com.google.common.base.Preconditions; import lombok.RequiredArgsConstructor; import lombok.Setter; +import org.apache.shardingsphere.encrypt.rewrite.util.EncryptPredicateSegmentUtils; import org.apache.shardingsphere.encrypt.rule.EncryptRule; import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn; import org.apache.shardingsphere.encrypt.rule.column.item.LikeQueryColumnItem; @@ -28,6 +29,7 @@ import org.apache.shardingsphere.infra.binder.context.segment.select.projection.Projection; import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection; import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; +import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext; import org.apache.shardingsphere.infra.binder.context.type.TableAvailable; import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable; import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter; @@ -72,8 +74,9 @@ public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) @Override public Collection generateSQLTokens(final SQLStatementContext sqlStatementContext) { - Collection columnSegments = ((WhereAvailable) sqlStatementContext).getColumnSegments(); - Collection whereSegments = ((WhereAvailable) sqlStatementContext).getWhereSegments(); + Collection allSubqueryContexts = EncryptPredicateSegmentUtils.getAllSubqueryContexts(sqlStatementContext); + Collection whereSegments = EncryptPredicateSegmentUtils.getWhereSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts); + Collection columnSegments = EncryptPredicateSegmentUtils.getColumnSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts); ShardingSphereSchema schema = ((TableAvailable) sqlStatementContext).getTablesContext().getSchemaName().map(schemas::get).orElseGet(() -> defaultSchema); Map columnExpressionTableNames = ((TableAvailable) sqlStatementContext).getTablesContext().findTableNames(columnSegments, schema); return generateSQLTokens(columnSegments, columnExpressionTableNames, whereSegments, sqlStatementContext.getDatabaseType()); diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/projection/EncryptProjectionTokenGenerator.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/projection/EncryptProjectionTokenGenerator.java index facfea346055b..f6fb343b46777 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/projection/EncryptProjectionTokenGenerator.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/projection/EncryptProjectionTokenGenerator.java @@ -17,6 +17,7 @@ package org.apache.shardingsphere.encrypt.rewrite.token.generator.projection; +import com.cedarsoftware.util.CaseInsensitiveSet; import lombok.RequiredArgsConstructor; import org.apache.shardingsphere.encrypt.rule.EncryptRule; import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn; @@ -34,11 +35,16 @@ import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken; import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.generic.SubstitutableColumnNameToken; import org.apache.shardingsphere.sql.parser.statement.core.enums.SubqueryType; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ColumnProjectionSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ProjectionSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ShorthandProjectionSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.OwnerSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.ParenthesesSegment; +import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.JoinTableSegment; +import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SubqueryTableSegment; +import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableSegment; +import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.SelectStatement; import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue; import java.util.Collection; @@ -46,6 +52,8 @@ import java.util.LinkedList; import java.util.List; import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; /** * Projection token generator for encrypt. @@ -74,28 +82,31 @@ public Collection generateSQLTokens(final SelectStatementContext selec private Collection generateSelectSQLTokens(final SelectStatementContext selectStatementContext) { Collection result = new LinkedList<>(); + Collection existColumnNames = new CaseInsensitiveSet<>(); for (ProjectionSegment each : selectStatementContext.getSqlStatement().getProjections().getProjections()) { if (each instanceof ColumnProjectionSegment) { - generateSQLToken(selectStatementContext, (ColumnProjectionSegment) each).ifPresent(result::add); + generateSQLToken(selectStatementContext, (ColumnProjectionSegment) each, existColumnNames).ifPresent(result::add); } if (each instanceof ShorthandProjectionSegment) { ShorthandProjectionSegment shorthandSegment = (ShorthandProjectionSegment) each; Collection actualColumns = getShorthandProjection(shorthandSegment, selectStatementContext.getProjectionsContext()).getActualColumns(); if (!actualColumns.isEmpty()) { - result.add(generateSQLToken(shorthandSegment, actualColumns, selectStatementContext, selectStatementContext.getSubqueryType())); + result.add(generateSQLToken(shorthandSegment, actualColumns, selectStatementContext, selectStatementContext.getSubqueryType(), existColumnNames)); } } } return result; } - private Optional generateSQLToken(final SelectStatementContext selectStatementContext, final ColumnProjectionSegment columnSegment) { + private Optional generateSQLToken(final SelectStatementContext selectStatementContext, final ColumnProjectionSegment columnSegment, + final Collection existColumnNames) { ColumnProjection columnProjection = buildColumnProjection(columnSegment); String columnName = columnProjection.getOriginalColumn().getValue(); + boolean newAddedColumn = existColumnNames.add(columnProjection.getOriginalTable().getValue() + "." + columnName); Optional encryptTable = encryptRule.findEncryptTable(columnProjection.getOriginalTable().getValue()); - if (encryptTable.isPresent() && encryptTable.get().isEncryptColumn(columnName) && !selectStatementContext.containsTableSubquery()) { + if (encryptTable.isPresent() && encryptTable.get().isEncryptColumn(columnName) && isNeedRewrite(selectStatementContext, columnSegment)) { EncryptColumn encryptColumn = encryptTable.get().getEncryptColumn(columnName); - Collection projections = generateProjections(encryptColumn, columnProjection, selectStatementContext.getSubqueryType()); + Collection projections = generateProjections(encryptColumn, columnProjection, selectStatementContext.getSubqueryType(), newAddedColumn); int startIndex = getStartIndex(columnSegment); int stopIndex = getStopIndex(columnSegment); previousSQLTokens.removeIf(each -> each.getStartIndex() == startIndex); @@ -104,16 +115,17 @@ private Optional generateSQLToken(final SelectStat return Optional.empty(); } - private SubstitutableColumnNameToken generateSQLToken(final ShorthandProjectionSegment segment, final Collection actualColumns, - final SelectStatementContext selectStatementContext, final SubqueryType subqueryType) { + private SubstitutableColumnNameToken generateSQLToken(final ShorthandProjectionSegment segment, final Collection actualColumns, final SelectStatementContext selectStatementContext, + final SubqueryType subqueryType, final Collection existColumnNames) { Collection projections = new LinkedList<>(); for (Projection each : actualColumns) { if (each instanceof ColumnProjection) { ColumnProjection columnProjection = (ColumnProjection) each; + boolean newAddedColumn = existColumnNames.add(columnProjection.getOriginalTable().getValue() + "." + columnProjection.getOriginalColumn().getValue()); Optional encryptTable = encryptRule.findEncryptTable(columnProjection.getOriginalTable().getValue()); if (encryptTable.isPresent() && encryptTable.get().isEncryptColumn(columnProjection.getOriginalColumn().getValue()) && !selectStatementContext.containsTableSubquery()) { EncryptColumn encryptColumn = encryptTable.get().getEncryptColumn(columnProjection.getOriginalColumn().getValue()); - projections.addAll(generateProjections(encryptColumn, columnProjection, subqueryType)); + projections.addAll(generateProjections(encryptColumn, columnProjection, subqueryType, newAddedColumn)); continue; } } @@ -125,6 +137,51 @@ private SubstitutableColumnNameToken generateSQLToken(final ShorthandProjectionS return new SubstitutableColumnNameToken(startIndex, segment.getStopIndex(), projections, selectStatementContext.getDatabaseType()); } + private boolean isNeedRewrite(final SelectStatementContext selectStatementContext, final ColumnProjectionSegment columnSegment) { + SelectStatement sqlStatement = selectStatementContext.getSqlStatement(); + if (sqlStatement.getWithSegment().isPresent() && !(sqlStatement.getFrom().isPresent() && sqlStatement.getFrom().get() instanceof SubqueryTableSegment) + && columnSegment.getColumn().getOwner().isPresent()) { + if (columnSegment.getStopIndex() < sqlStatement.getWithSegment().get().getStartIndex() || columnSegment.getStartIndex() > sqlStatement.getWithSegment().get().getStopIndex()) { + Set withTableAlias = + sqlStatement.getWithSegment().get().getCommonTableExpressions().stream().map(each -> each.getAliasSegment().getIdentifier().getValue()).collect(Collectors.toSet()); + return !withTableAlias.contains(columnSegment.getColumn().getOwner().get().getIdentifier().getValue()); + } + } + if (sqlStatement.getFrom().isPresent() && isContainsInJoinSubquery(sqlStatement.getFrom().get(), columnSegment)) { + return false; + } + return !selectStatementContext.containsTableSubquery(); + } + + private boolean isContainsInJoinSubquery(final TableSegment tableSegment, final ColumnProjectionSegment columnSegment) { + if (tableSegment instanceof JoinTableSegment && isContainsInJoinSubquery(((JoinTableSegment) tableSegment).getLeft(), columnSegment)) { + return true; + } + if (tableSegment instanceof JoinTableSegment && isContainsInJoinSubquery(((JoinTableSegment) tableSegment).getRight(), columnSegment)) { + return true; + } + if (tableSegment instanceof SubqueryTableSegment) { + SubqueryTableSegment subqueryTable = (SubqueryTableSegment) tableSegment; + ColumnSegment column = columnSegment.getColumn(); + if (subqueryTable.getAliasName().isPresent() && column.getOwner().isPresent()) { + return subqueryTable.getAliasName().get().equalsIgnoreCase(column.getOwner().get().getIdentifier().getValue()); + } else { + return isContainsInSubqueryProjections(columnSegment, subqueryTable); + } + } + return false; + } + + private boolean isContainsInSubqueryProjections(final ColumnProjectionSegment columnSegment, final SubqueryTableSegment subqueryTable) { + for (ProjectionSegment each : subqueryTable.getSubquery().getSelect().getProjections().getProjections()) { + if (each instanceof ColumnProjectionSegment + && ((ColumnProjectionSegment) each).getColumn().getIdentifier().getValue().equalsIgnoreCase(columnSegment.getColumn().getIdentifier().getValue())) { + return true; + } + } + return false; + } + private int getStartIndex(final ColumnProjectionSegment columnSegment) { if (columnSegment.getColumn().getLeftParentheses().isPresent()) { return columnSegment.getColumn().getLeftParentheses().get().getStartIndex(); @@ -146,12 +203,12 @@ private ColumnProjection buildColumnProjection(final ColumnProjectionSegment seg } private Collection generateProjections(final EncryptColumn encryptColumn, final ColumnProjection columnProjection, - final SubqueryType subqueryType) { - if (null == subqueryType || SubqueryType.PROJECTION == subqueryType) { + final SubqueryType subqueryType, final boolean newAddedColumn) { + if (null == subqueryType || SubqueryType.PROJECTION == subqueryType || SubqueryType.WITH == subqueryType) { return Collections.singleton(generateProjection(encryptColumn, columnProjection)); } if (SubqueryType.TABLE == subqueryType || SubqueryType.JOIN == subqueryType) { - return generateProjectionsInTableSegmentSubquery(encryptColumn, columnProjection, subqueryType); + return generateProjectionsInTableSegmentSubquery(encryptColumn, columnProjection, newAddedColumn); } if (SubqueryType.PREDICATE == subqueryType) { return Collections.singleton(generateProjectionInPredicateSubquery(encryptColumn, columnProjection)); @@ -169,14 +226,17 @@ private ColumnProjection generateProjection(final EncryptColumn encryptColumn, f columnProjection.getLeftParentheses().orElse(null), columnProjection.getRightParentheses().orElse(null)); } - private Collection generateProjectionsInTableSegmentSubquery(final EncryptColumn encryptColumn, final ColumnProjection columnProjection, final SubqueryType subqueryType) { + private Collection generateProjectionsInTableSegmentSubquery(final EncryptColumn encryptColumn, final ColumnProjection columnProjection, final boolean newAddedColumn) { Collection result = new LinkedList<>(); QuoteCharacter quoteCharacter = columnProjection.getName().getQuoteCharacter(); - IdentifierValue alias = SubqueryType.JOIN == subqueryType ? null : columnProjection.getAlias().orElse(columnProjection.getName()); + IdentifierValue alias = columnProjection.getAlias().orElse(columnProjection.getName()); IdentifierValue cipherColumnName = new IdentifierValue(encryptColumn.getCipher().getName(), quoteCharacter); ParenthesesSegment leftParentheses = columnProjection.getLeftParentheses().orElse(null); ParenthesesSegment rightParentheses = columnProjection.getRightParentheses().orElse(null); result.add(new ColumnProjection(columnProjection.getOwner().orElse(null), cipherColumnName, alias, databaseType, leftParentheses, rightParentheses)); + if (newAddedColumn) { + result.add(new ColumnProjection(columnProjection.getOwner().orElse(null), cipherColumnName, null, databaseType)); + } IdentifierValue assistedColumOwner = columnProjection.getOwner().orElse(null); encryptColumn.getAssistedQuery().ifPresent( optional -> result.add(new ColumnProjection(assistedColumOwner, new IdentifierValue(optional.getName(), quoteCharacter), null, databaseType, leftParentheses, rightParentheses))); diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptPredicateEqualRightValueToken.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptPredicateEqualRightValueToken.java index 704473d859e16..ea9f4b7e167fd 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptPredicateEqualRightValueToken.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptPredicateEqualRightValueToken.java @@ -23,7 +23,6 @@ import java.util.Collection; import java.util.Map; -import java.util.Objects; /** * Predicate equal right value token for encrypt. @@ -51,16 +50,4 @@ public String toString() { } return "?"; } - - @Override - public boolean equals(final Object obj) { - return obj instanceof EncryptPredicateEqualRightValueToken && ((EncryptPredicateEqualRightValueToken) obj).getStartIndex() == getStartIndex() - && ((EncryptPredicateEqualRightValueToken) obj).getStopIndex() == stopIndex && ((EncryptPredicateEqualRightValueToken) obj).indexValues.equals(indexValues) - && ((EncryptPredicateEqualRightValueToken) obj).paramMarkerIndexes.equals(paramMarkerIndexes); - } - - @Override - public int hashCode() { - return Objects.hash(getStartIndex(), stopIndex, indexValues, paramMarkerIndexes); - } } diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptPredicateFunctionRightValueToken.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptPredicateFunctionRightValueToken.java index 399dee059f6e6..0223b6ac55f49 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptPredicateFunctionRightValueToken.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptPredicateFunctionRightValueToken.java @@ -25,7 +25,6 @@ import java.util.Collection; import java.util.Map; -import java.util.Objects; import java.util.concurrent.atomic.AtomicInteger; /** @@ -91,16 +90,4 @@ private void appendRewrittenParameters(final StringBuilder builder, final int pa } builder.append(COMMA_SEPARATOR); } - - @Override - public boolean equals(final Object obj) { - return obj instanceof EncryptPredicateFunctionRightValueToken && ((EncryptPredicateFunctionRightValueToken) obj).getStartIndex() == getStartIndex() - && ((EncryptPredicateFunctionRightValueToken) obj).getStopIndex() == stopIndex && ((EncryptPredicateFunctionRightValueToken) obj).indexValues.equals(indexValues) - && ((EncryptPredicateFunctionRightValueToken) obj).paramMarkerIndexes.equals(paramMarkerIndexes); - } - - @Override - public int hashCode() { - return Objects.hash(getStartIndex(), stopIndex, indexValues, paramMarkerIndexes); - } } diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptPredicateInRightValueToken.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptPredicateInRightValueToken.java index adae78e4cc58c..2b7f47b9c12d2 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptPredicateInRightValueToken.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/pojo/EncryptPredicateInRightValueToken.java @@ -23,7 +23,6 @@ import java.util.Collection; import java.util.Map; -import java.util.Objects; /** * Predicate in right value token for encrypt. @@ -63,16 +62,4 @@ public String toString() { result.delete(result.length() - 2, result.length()).append(')'); return result.toString(); } - - @Override - public boolean equals(final Object obj) { - return obj instanceof EncryptPredicateInRightValueToken && ((EncryptPredicateInRightValueToken) obj).getStartIndex() == getStartIndex() - && ((EncryptPredicateInRightValueToken) obj).getStopIndex() == stopIndex && ((EncryptPredicateInRightValueToken) obj).indexValues.equals(indexValues) - && ((EncryptPredicateInRightValueToken) obj).paramMarkerIndexes.equals(paramMarkerIndexes); - } - - @Override - public int hashCode() { - return Objects.hash(getStartIndex(), stopIndex, indexValues, paramMarkerIndexes); - } } diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/util/EncryptPredicateSegmentUtils.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/util/EncryptPredicateSegmentUtils.java new file mode 100644 index 0000000000000..d4a5b4d9413fd --- /dev/null +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/util/EncryptPredicateSegmentUtils.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shardingsphere.encrypt.rewrite.util; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertSelectContext; +import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; +import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext; +import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext; +import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BinaryOperationExpression; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment; + +import java.util.Collection; +import java.util.LinkedList; + +/** + * Encrypt predicate segment utility class. + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public final class EncryptPredicateSegmentUtils { + + /** + * Get all subquery contexts. + * + * @param sqlStatementContext SQL statement context + * @return all subquery contexts + */ + public static Collection getAllSubqueryContexts(final SQLStatementContext sqlStatementContext) { + Collection result = new LinkedList<>(); + if (sqlStatementContext instanceof SelectStatementContext) { + result.addAll(((SelectStatementContext) sqlStatementContext).getSubqueryContexts().values()); + ((SelectStatementContext) sqlStatementContext).getSubqueryContexts().values().forEach(each -> result.addAll(getAllSubqueryContexts(each))); + } + if (sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()) { + InsertSelectContext insertSelectContext = ((InsertStatementContext) sqlStatementContext).getInsertSelectContext(); + result.add(insertSelectContext.getSelectStatementContext()); + result.addAll(insertSelectContext.getSelectStatementContext().getSubqueryContexts().values()); + insertSelectContext.getSelectStatementContext().getSubqueryContexts().values().forEach(each -> result.addAll(getAllSubqueryContexts(each))); + } + return result; + } + + /** + * Get all where segments. + * + * @param whereAvailable where available + * @param allSubqueryContexts all subquery contexts + * @return all where segments + */ + public static Collection getWhereSegments(final WhereAvailable whereAvailable, final Collection allSubqueryContexts) { + Collection result = new LinkedList<>(whereAvailable.getWhereSegments()); + allSubqueryContexts.forEach(each -> result.addAll(each.getWhereSegments())); + return result; + } + + /** + * Get all column segments. + * + * @param whereAvailable where available + * @param allSubqueryContexts all subquery contexts + * @return all column segments + */ + public static Collection getColumnSegments(final WhereAvailable whereAvailable, final Collection allSubqueryContexts) { + Collection result = new LinkedList<>(whereAvailable.getColumnSegments()); + allSubqueryContexts.forEach(each -> result.addAll(each.getColumnSegments())); + return result; + } + + /** + * Get all join conditions. + * + * @param whereAvailable where available + * @param allSubqueryContexts all subquery contexts + * @return all join conditions + */ + public static Collection getJoinConditions(final WhereAvailable whereAvailable, final Collection allSubqueryContexts) { + Collection result = new LinkedList<>(whereAvailable.getJoinConditions()); + allSubqueryContexts.forEach(each -> result.addAll(each.getJoinConditions())); + return result; + } +} diff --git a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/checker/sql/predicate/EncryptPredicateColumnSupportedCheckerTest.java b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/checker/sql/predicate/EncryptPredicateColumnSupportedCheckerTest.java index 5983cf5d37037..1c6c1e87845fe 100644 --- a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/checker/sql/predicate/EncryptPredicateColumnSupportedCheckerTest.java +++ b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/checker/sql/predicate/EncryptPredicateColumnSupportedCheckerTest.java @@ -91,6 +91,8 @@ private SQLStatementContext mockSelectStatementContextWithLike() { when(result.getTablesContext().findTableNames(Collections.singleton(columnSegment), null)).thenReturn(Collections.singletonMap("user_name", "t_user")); when(result.getColumnSegments()).thenReturn(Collections.singleton(columnSegment)); when(result.getWhereSegments()).thenReturn(Collections.singleton(new WhereSegment(0, 0, new BinaryOperationExpression(0, 0, columnSegment, columnSegment, "LIKE", "")))); + when(result.getSubqueryContexts()).thenReturn(Collections.emptyMap()); + when(result.getJoinConditions()).thenReturn(Collections.emptyList()); return result; } @@ -107,6 +109,8 @@ private SQLStatementContext mockSelectStatementContextWithEqual() { when(result.getTablesContext().findTableNames(Collections.singleton(columnSegment), null)).thenReturn(Collections.singletonMap("user_name", "t_user")); when(result.getColumnSegments()).thenReturn(Collections.singleton(columnSegment)); when(result.getWhereSegments()).thenReturn(Collections.singleton(new WhereSegment(0, 0, new BinaryOperationExpression(0, 0, columnSegment, columnSegment, "=", "")))); + when(result.getSubqueryContexts()).thenReturn(Collections.emptyMap()); + when(result.getJoinConditions()).thenReturn(Collections.emptyList()); return result; } } diff --git a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/util/EncryptPredicateSegmentUtilsTest.java b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/util/EncryptPredicateSegmentUtilsTest.java new file mode 100644 index 0000000000000..37cc449212848 --- /dev/null +++ b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/util/EncryptPredicateSegmentUtilsTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shardingsphere.encrypt.rewrite.util; + +import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertSelectContext; +import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext; +import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BinaryOperationExpression; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment; +import org.junit.jupiter.api.Test; + +import java.util.Collection; +import java.util.Collections; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +class EncryptPredicateSegmentUtilsTest { + + @Test + void assertGetAllSubqueryContextsForSelectStatement() { + SelectStatementContext selectStatementContext = mock(SelectStatementContext.class); + SelectStatementContext subquerySelectStatementContext = mock(SelectStatementContext.class); + when(subquerySelectStatementContext.getSubqueryContexts()).thenReturn(Collections.emptyMap()); + when(selectStatementContext.getSubqueryContexts()).thenReturn(Collections.singletonMap(0, subquerySelectStatementContext)); + Collection actual = EncryptPredicateSegmentUtils.getAllSubqueryContexts(selectStatementContext); + assertThat(actual.size(), is(1)); + } + + @Test + void assertGetAllSubqueryContextsForInsertStatement() { + InsertStatementContext insertStatementContext = mock(InsertStatementContext.class); + SelectStatementContext selectStatementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS); + when(selectStatementContext.getSubqueryContexts()).thenReturn(Collections.singletonMap(0, mock(SelectStatementContext.class))); + InsertSelectContext insertSelectContext = mock(InsertSelectContext.class); + when(insertSelectContext.getSelectStatementContext()).thenReturn(selectStatementContext); + when(insertStatementContext.getInsertSelectContext()).thenReturn(insertSelectContext); + Collection actual = EncryptPredicateSegmentUtils.getAllSubqueryContexts(insertStatementContext); + assertThat(actual.size(), is(2)); + } + + @Test + void assertGetWhereSegments() { + SelectStatementContext selectStatementContext = mock(SelectStatementContext.class); + SelectStatementContext subquerySelectStatementContext = mock(SelectStatementContext.class); + when(selectStatementContext.getWhereSegments()).thenReturn(Collections.singleton(mock(WhereSegment.class))); + when(subquerySelectStatementContext.getWhereSegments()).thenReturn(Collections.singleton(mock(WhereSegment.class))); + Collection actual = EncryptPredicateSegmentUtils.getWhereSegments(selectStatementContext, Collections.singleton(subquerySelectStatementContext)); + assertThat(actual.size(), is(2)); + } + + @Test + void assertGetColumnSegment() { + SelectStatementContext selectStatementContext = mock(SelectStatementContext.class); + SelectStatementContext subquerySelectStatementContext = mock(SelectStatementContext.class); + when(selectStatementContext.getColumnSegments()).thenReturn(Collections.singleton(mock(ColumnSegment.class))); + when(subquerySelectStatementContext.getColumnSegments()).thenReturn(Collections.singleton(mock(ColumnSegment.class))); + Collection actual = EncryptPredicateSegmentUtils.getColumnSegments(selectStatementContext, Collections.singleton(subquerySelectStatementContext)); + assertThat(actual.size(), is(2)); + } + + @Test + void assertGetJoinConditions() { + SelectStatementContext selectStatementContext = mock(SelectStatementContext.class); + SelectStatementContext subquerySelectStatementContext = mock(SelectStatementContext.class); + when(selectStatementContext.getJoinConditions()).thenReturn(Collections.singleton(mock(BinaryOperationExpression.class))); + when(subquerySelectStatementContext.getJoinConditions()).thenReturn(Collections.singleton(mock(BinaryOperationExpression.class))); + Collection actual = EncryptPredicateSegmentUtils.getJoinConditions(selectStatementContext, Collections.singleton(subquerySelectStatementContext)); + assertThat(actual.size(), is(2)); + } +} diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java index 264c03ac33799..7d00fa3470d75 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java @@ -28,6 +28,7 @@ import org.apache.shardingsphere.infra.binder.context.segment.table.TablesContext; import org.apache.shardingsphere.infra.binder.context.statement.CommonSQLStatementContext; import org.apache.shardingsphere.infra.binder.context.type.TableAvailable; +import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable; import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry; import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions; import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.database.NoDatabaseSelectedException; @@ -42,8 +43,10 @@ import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.OnDuplicateKeyColumnsSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.combine.CombineSegment; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BinaryOperationExpression; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubquerySegment; +import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment; import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.InsertStatement; import org.apache.shardingsphere.sql.parser.statement.core.util.TableExtractor; @@ -62,14 +65,12 @@ /** * Insert SQL statement context. */ -public final class InsertStatementContext extends CommonSQLStatementContext implements TableAvailable, ParameterAware { +public final class InsertStatementContext extends CommonSQLStatementContext implements TableAvailable, ParameterAware, WhereAvailable { private final ShardingSphereMetaData metaData; private final String currentDatabaseName; - private final List insertColumnNames; - private final Map insertColumnNamesAndIndexes; private final List> valueExpressions; @@ -95,13 +96,13 @@ public InsertStatementContext(final ShardingSphereMetaData metaData, final List< super(sqlStatement); this.metaData = metaData; this.currentDatabaseName = currentDatabaseName; - insertColumnNames = getInsertColumnNames(); valueExpressions = getAllValueExpressions(sqlStatement); AtomicInteger parametersOffset = new AtomicInteger(0); insertValueContexts = getInsertValueContexts(params, parametersOffset, valueExpressions); insertSelectContext = getInsertSelectContext(metaData, params, parametersOffset, currentDatabaseName).orElse(null); onDuplicateKeyUpdateValueContext = getOnDuplicateKeyUpdateValueContext(params, parametersOffset).orElse(null); tablesContext = new TablesContext(getAllSimpleTableSegments(), getDatabaseType(), currentDatabaseName); + List insertColumnNames = getInsertColumnNames(); ShardingSphereSchema schema = getSchema(metaData, currentDatabaseName); columnNames = containsInsertColumns() ? insertColumnNames : sqlStatement.getTable().map(optional -> schema.getVisibleColumnNames(optional.getTableName().getIdentifier().getValue())).orElseGet(Collections::emptyList); @@ -314,4 +315,19 @@ public void setUpParameters(final List params) { ShardingSphereSchema schema = getSchema(metaData, currentDatabaseName); generatedKeyContext = new GeneratedKeyContextEngine(getSqlStatement(), schema).createGenerateKeyContext(insertColumnNamesAndIndexes, insertValueContexts, params).orElse(null); } + + @Override + public Collection getWhereSegments() { + return null == insertSelectContext ? Collections.emptyList() : insertSelectContext.getSelectStatementContext().getWhereSegments(); + } + + @Override + public Collection getColumnSegments() { + return null == insertSelectContext ? Collections.emptyList() : insertSelectContext.getSelectStatementContext().getColumnSegments(); + } + + @Override + public Collection getJoinConditions() { + return null == insertSelectContext ? Collections.emptyList() : insertSelectContext.getSelectStatementContext().getJoinConditions(); + } } diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/SelectStatementContext.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/SelectStatementContext.java index 2fe420260efe8..9c36de7cae642 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/SelectStatementContext.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/SelectStatementContext.java @@ -191,7 +191,7 @@ private Map createSubqueryContexts(final Shardi Map result = new HashMap<>(subquerySegments.size(), 1F); for (SubquerySegment each : subquerySegments) { SelectStatementContext subqueryContext = new SelectStatementContext(metaData, params, each.getSelect(), currentDatabaseName, tableSegments); - subqueryContext.setSubqueryType(each.getSubqueryType()); + each.getSelect().getSubqueryType().ifPresent(subqueryContext::setSubqueryType); result.put(each.getStartIndex(), subqueryContext); } return result; diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/combine/CombineSegmentBinder.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/combine/CombineSegmentBinder.java index a21b4a8bea17e..a8c9d3538df67 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/combine/CombineSegmentBinder.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/combine/CombineSegmentBinder.java @@ -51,11 +51,9 @@ public static CombineSegment bind(final CombineSegment segment, final SQLStateme private static SubquerySegment bindSubquerySegment(final SubquerySegment segment, final SQLStatementBinderContext binderContext, final Multimap outerTableBinderContexts) { SubquerySegment result = new SubquerySegment(segment.getStartIndex(), segment.getStopIndex(), segment.getText()); - result.setSubqueryType(segment.getSubqueryType()); SQLStatementBinderContext subqueryBinderContext = new SQLStatementBinderContext(segment.getSelect(), binderContext.getMetaData(), binderContext.getCurrentDatabaseName()); subqueryBinderContext.getExternalTableBinderContexts().putAll(binderContext.getExternalTableBinderContexts()); result.setSelect(new SelectStatementBinder(outerTableBinderContexts).bind(segment.getSelect(), subqueryBinderContext)); return result; } - } diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/expression/type/SubquerySegmentBinder.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/expression/type/SubquerySegmentBinder.java index e986c0a0cfd57..bcefdbd14f9bc 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/expression/type/SubquerySegmentBinder.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/expression/type/SubquerySegmentBinder.java @@ -46,8 +46,6 @@ public static SubquerySegment bind(final SubquerySegment segment, final SQLState SQLStatementBinderContext selectBinderContext = new SQLStatementBinderContext(segment.getSelect(), binderContext.getMetaData(), binderContext.getCurrentDatabaseName()); selectBinderContext.getExternalTableBinderContexts().putAll(binderContext.getExternalTableBinderContexts()); SelectStatement boundSelectStatement = new SelectStatementBinder(outerTableBinderContexts).bind(segment.getSelect(), selectBinderContext); - SubquerySegment result = new SubquerySegment(segment.getStartIndex(), segment.getStopIndex(), boundSelectStatement, segment.getText()); - result.setSubqueryType(segment.getSubqueryType()); - return result; + return new SubquerySegment(segment.getStartIndex(), segment.getStopIndex(), boundSelectStatement, segment.getText()); } } diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/from/type/SubqueryTableSegmentBinder.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/from/type/SubqueryTableSegmentBinder.java index e1f8e17040233..fd056f567f1be 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/from/type/SubqueryTableSegmentBinder.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/from/type/SubqueryTableSegmentBinder.java @@ -55,7 +55,6 @@ public static SubqueryTableSegment bind(final SubqueryTableSegment segment, fina subqueryBinderContext.getExternalTableBinderContexts().putAll(binderContext.getExternalTableBinderContexts()); SelectStatement boundSubSelect = new SelectStatementBinder(outerTableBinderContexts).bind(segment.getSubquery().getSelect(), subqueryBinderContext); SubquerySegment boundSubquerySegment = new SubquerySegment(segment.getSubquery().getStartIndex(), segment.getSubquery().getStopIndex(), boundSubSelect, segment.getSubquery().getText()); - boundSubquerySegment.setSubqueryType(segment.getSubquery().getSubqueryType()); IdentifierValue subqueryTableName = segment.getAliasSegment().map(AliasSegment::getIdentifier).orElseGet(() -> new IdentifierValue("")); SubqueryTableSegment result = new SubqueryTableSegment(segment.getStartIndex(), segment.getStopIndex(), boundSubquerySegment); segment.getAliasSegment().ifPresent(result::setAlias); diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/statement/dml/SelectStatementBinder.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/statement/dml/SelectStatementBinder.java index be63bc0c7a815..47c8357c87dc8 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/statement/dml/SelectStatementBinder.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/statement/dml/SelectStatementBinder.java @@ -70,6 +70,7 @@ private SelectStatement copy(final SelectStatement sqlStatement) { sqlStatement.getLimit().ifPresent(result::setLimit); sqlStatement.getWindow().ifPresent(result::setWindow); sqlStatement.getModelSegment().ifPresent(result::setModelSegment); + sqlStatement.getSubqueryType().ifPresent(result::setSubqueryType); sqlStatement.getWithSegment().ifPresent(result::setWithSegment); result.addParameterMarkerSegments(sqlStatement.getParameterMarkerSegments()); result.getCommentSegments().addAll(sqlStatement.getCommentSegments()); diff --git a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/engine/segment/expression/type/ExistsSubqueryExpressionBinderTest.java b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/engine/segment/expression/type/ExistsSubqueryExpressionBinderTest.java index 69b32b38daad7..ff89350b51902 100644 --- a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/engine/segment/expression/type/ExistsSubqueryExpressionBinderTest.java +++ b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/engine/segment/expression/type/ExistsSubqueryExpressionBinderTest.java @@ -48,7 +48,6 @@ void assertBindExistsSubqueryExpression() { assertThat(actual.getText(), is("t_test")); assertThat(actual.getSubquery().getStartIndex(), is(existsSubqueryExpression.getSubquery().getStartIndex())); assertThat(actual.getSubquery().getStopIndex(), is(existsSubqueryExpression.getSubquery().getStopIndex())); - assertThat(actual.getSubquery().getSubqueryType(), is(existsSubqueryExpression.getSubquery().getSubqueryType())); assertThat(actual.getSubquery().getText(), is("t_test")); assertThat(actual.getSubquery().getSelect().getDatabaseType(), is(existsSubqueryExpression.getSubquery().getSelect().getDatabaseType())); } diff --git a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/segment/projection/impl/SubqueryProjectionConverter.java b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/segment/projection/impl/SubqueryProjectionConverter.java index 231c4150bbfdb..2df7a98a526c6 100644 --- a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/segment/projection/impl/SubqueryProjectionConverter.java +++ b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/segment/projection/impl/SubqueryProjectionConverter.java @@ -54,7 +54,7 @@ public static Optional convert(final SubqueryProjectionSegment segment) if (segment.getAliasName().isPresent()) { sqlNode = convertWithAlias(sqlNode, segment.getAliasName().get()); } - return SubqueryType.EXISTS == segment.getSubquery().getSubqueryType() + return segment.getSubquery().getSelect().getSubqueryType().map(optional -> optional == SubqueryType.EXISTS).orElse(false) ? Optional.of(new SqlBasicCall(SqlStdOperatorTable.EXISTS, Collections.singletonList(sqlNode), SqlParserPos.ZERO)) : Optional.of(sqlNode); } diff --git a/parser/sql/dialect/doris/src/main/java/org/apache/shardingsphere/sql/parser/doris/visitor/statement/DorisStatementVisitor.java b/parser/sql/dialect/doris/src/main/java/org/apache/shardingsphere/sql/parser/doris/visitor/statement/DorisStatementVisitor.java index 5bbb10f0ae5a0..1fc3967231b7e 100644 --- a/parser/sql/dialect/doris/src/main/java/org/apache/shardingsphere/sql/parser/doris/visitor/statement/DorisStatementVisitor.java +++ b/parser/sql/dialect/doris/src/main/java/org/apache/shardingsphere/sql/parser/doris/visitor/statement/DorisStatementVisitor.java @@ -606,7 +606,7 @@ public final ASTNode visitSimpleExpr(final SimpleExprContext ctx) { if (null == ctx.EXISTS()) { return new SubqueryExpressionSegment(subquerySegment); } - subquerySegment.setSubqueryType(SubqueryType.EXISTS); + subquerySegment.getSelect().setSubqueryType(SubqueryType.EXISTS); return new ExistsSubqueryExpression(startIndex, stopIndex, subquerySegment); } if (null != ctx.parameterMarker()) { diff --git a/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java b/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java index 56dfda8057462..d058633429a6f 100644 --- a/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java +++ b/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java @@ -603,7 +603,7 @@ public final ASTNode visitSimpleExpr(final SimpleExprContext ctx) { SubquerySegment subquerySegment = new SubquerySegment( ctx.subquery().getStart().getStartIndex(), ctx.subquery().getStop().getStopIndex(), (MySQLSelectStatement) visit(ctx.subquery()), getOriginalText(ctx.subquery())); if (null != ctx.EXISTS()) { - subquerySegment.setSubqueryType(SubqueryType.EXISTS); + subquerySegment.getSelect().setSubqueryType(SubqueryType.EXISTS); return new ExistsSubqueryExpression(startIndex, stopIndex, subquerySegment); } return new SubqueryExpressionSegment(subquerySegment); diff --git a/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java b/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java index f7785529e07aa..266d7d1fca7e9 100644 --- a/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java +++ b/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java @@ -422,7 +422,7 @@ private ExpressionSegment createSubqueryExpressionSegment(final CExprContext ctx SubquerySegment subquerySegment = new SubquerySegment(ctx.selectWithParens().getStart().getStartIndex(), ctx.selectWithParens().getStop().getStopIndex(), (OpenGaussSelectStatement) visit(ctx.selectWithParens()), getOriginalText(ctx.selectWithParens())); if (null != ctx.EXISTS()) { - subquerySegment.setSubqueryType(SubqueryType.EXISTS); + subquerySegment.getSelect().setSubqueryType(SubqueryType.EXISTS); return new ExistsSubqueryExpression(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), subquerySegment); } return new SubqueryExpressionSegment(subquerySegment); diff --git a/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java b/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java index 908961e3fe014..123cd9fc4ef78 100644 --- a/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java +++ b/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java @@ -423,7 +423,7 @@ private ExpressionSegment createSubqueryExpressionSegment(final CExprContext ctx SubquerySegment subquerySegment = new SubquerySegment(ctx.selectWithParens().getStart().getStartIndex(), ctx.selectWithParens().getStop().getStopIndex(), (PostgreSQLSelectStatement) visit(ctx.selectWithParens()), getOriginalText(ctx.selectWithParens())); if (null != ctx.EXISTS()) { - subquerySegment.setSubqueryType(SubqueryType.EXISTS); + subquerySegment.getSelect().setSubqueryType(SubqueryType.EXISTS); return new ExistsSubqueryExpression(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), subquerySegment); } return new SubqueryExpressionSegment(subquerySegment); diff --git a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/segment/dml/expr/subquery/SubquerySegment.java b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/segment/dml/expr/subquery/SubquerySegment.java index c83c2bdf6cac0..630485d58347f 100644 --- a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/segment/dml/expr/subquery/SubquerySegment.java +++ b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/segment/dml/expr/subquery/SubquerySegment.java @@ -20,7 +20,6 @@ import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.Setter; -import org.apache.shardingsphere.sql.parser.statement.core.enums.SubqueryType; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment; import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.MergeStatement; import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.SelectStatement; @@ -44,9 +43,6 @@ public final class SubquerySegment implements ExpressionSegment { private final String text; - @Setter - private SubqueryType subqueryType; - public SubquerySegment(final int startIndex, final int stopIndex, final SelectStatement select, final String text) { this.startIndex = startIndex; this.stopIndex = stopIndex; diff --git a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/statement/dml/SelectStatement.java b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/statement/dml/SelectStatement.java index 7254c28a45dcf..33b35962c3d5d 100644 --- a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/statement/dml/SelectStatement.java +++ b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/statement/dml/SelectStatement.java @@ -19,6 +19,7 @@ import lombok.Getter; import lombok.Setter; +import org.apache.shardingsphere.sql.parser.statement.core.enums.SubqueryType; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.combine.CombineSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ProjectionsSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.order.GroupBySegment; @@ -58,6 +59,8 @@ public abstract class SelectStatement extends AbstractSQLStatement implements DM private WithSegment withSegment; + private SubqueryType subqueryType; + /** * Get from. * @@ -121,6 +124,15 @@ public Optional getWithSegment() { return Optional.ofNullable(withSegment); } + /** + * Get subquery type. + * + * @return subquery type + */ + public Optional getSubqueryType() { + return Optional.ofNullable(subqueryType); + } + /** * Get limit segment. * diff --git a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtils.java b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtils.java index 13716c2a2d79e..af880834f4fdd 100644 --- a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtils.java +++ b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtils.java @@ -65,18 +65,19 @@ public final class SubqueryExtractUtils { */ public static Collection getSubquerySegments(final SelectStatement selectStatement, final boolean needRecursive) { List result = new LinkedList<>(); - extractSubquerySegments(result, selectStatement, needRecursive); + SubqueryType parentSubqueryType = selectStatement.getSubqueryType().orElse(null); + extractSubquerySegments(result, selectStatement, needRecursive, parentSubqueryType); return result; } - private static void extractSubquerySegments(final List result, final SelectStatement selectStatement, final boolean needRecursive) { + private static void extractSubquerySegments(final List result, final SelectStatement selectStatement, final boolean needRecursive, final SubqueryType parentSubqueryType) { extractSubquerySegmentsFromProjections(result, selectStatement.getProjections(), needRecursive); selectStatement.getFrom().ifPresent(optional -> extractSubquerySegmentsFromTableSegment(result, optional, needRecursive)); if (selectStatement.getWhere().isPresent()) { extractSubquerySegmentsFromWhere(result, selectStatement.getWhere().get().getExpr(), needRecursive); } if (selectStatement.getCombine().isPresent()) { - extractSubquerySegmentsFromCombine(result, selectStatement.getCombine().get(), needRecursive); + extractSubquerySegmentsFromCombine(result, selectStatement.getCombine().get(), needRecursive, parentSubqueryType); } if (selectStatement.getWithSegment().isPresent()) { extractSubquerySegmentsFromCTEs(result, selectStatement.getWithSegment().get().getCommonTableExpressions(), needRecursive); @@ -85,15 +86,15 @@ private static void extractSubquerySegments(final List result, private static void extractSubquerySegmentsFromCTEs(final List result, final Collection withSegment, final boolean needRecursive) { for (CommonTableExpressionSegment each : withSegment) { - each.getSubquery().setSubqueryType(SubqueryType.WITH); + each.getSubquery().getSelect().setSubqueryType(SubqueryType.WITH); result.add(each.getSubquery()); - extractRecursive(needRecursive, result, each.getSubquery().getSelect()); + extractRecursive(needRecursive, result, each.getSubquery().getSelect(), SubqueryType.TABLE); } } - private static void extractRecursive(final boolean needRecursive, final List result, final SelectStatement select) { + private static void extractRecursive(final boolean needRecursive, final List result, final SelectStatement select, final SubqueryType parentSubqueryType) { if (needRecursive) { - extractSubquerySegments(result, select, true); + extractSubquerySegments(result, select, true, parentSubqueryType); } } @@ -104,9 +105,9 @@ private static void extractSubquerySegmentsFromProjections(final List result, final TableSegment tableSegment, final boolean needRecursive) { if (tableSegment instanceof SubqueryTableSegment) { SubquerySegment subquery = ((SubqueryTableSegment) tableSegment).getSubquery(); - subquery.setSubqueryType(SubqueryType.JOIN); + subquery.getSelect().setSubqueryType(SubqueryType.JOIN); result.add(subquery); - extractRecursive(needRecursive, result, subquery.getSelect()); + extractRecursive(needRecursive, result, subquery.getSelect(), SubqueryType.TABLE); } else if (tableSegment instanceof JoinTableSegment) { extractSubquerySegmentsFromJoinTableSegment(result, ((JoinTableSegment) tableSegment).getLeft(), needRecursive); extractSubquerySegmentsFromJoinTableSegment(result, ((JoinTableSegment) tableSegment).getRight(), needRecursive); @@ -137,9 +138,9 @@ private static void extractSubquerySegmentsFromJoinTableSegment(final List result, final SubqueryTableSegment subqueryTableSegment, final boolean needRecursive) { SubquerySegment subquery = subqueryTableSegment.getSubquery(); - subquery.setSubqueryType(SubqueryType.TABLE); + subquery.getSelect().setSubqueryType(SubqueryType.TABLE); result.add(subquery); - extractRecursive(needRecursive, result, subquery.getSelect()); + extractRecursive(needRecursive, result, subquery.getSelect(), SubqueryType.TABLE); } private static void extractSubquerySegmentsFromWhere(final List result, final ExpressionSegment expressionSegment, final boolean needRecursive) { @@ -150,15 +151,15 @@ private static void extractSubquerySegmentsFromExpression(final List extractSubquerySegmentsFromExpression(result, each, subqueryType, needRecursive)); @@ -210,10 +211,13 @@ private static void extractSubquerySegmentsFromCaseWhenExpression(final List result, final CombineSegment combineSegment, final boolean needRecursive) { + private static void extractSubquerySegmentsFromCombine(final List result, final CombineSegment combineSegment, final boolean needRecursive, + final SubqueryType parentSubqueryType) { + combineSegment.getLeft().getSelect().setSubqueryType(parentSubqueryType); + combineSegment.getRight().getSelect().setSubqueryType(parentSubqueryType); result.add(combineSegment.getLeft()); result.add(combineSegment.getRight()); - extractRecursive(needRecursive, result, combineSegment.getLeft().getSelect()); - extractRecursive(needRecursive, result, combineSegment.getRight().getSelect()); + extractRecursive(needRecursive, result, combineSegment.getLeft().getSelect(), parentSubqueryType); + extractRecursive(needRecursive, result, combineSegment.getRight().getSelect(), parentSubqueryType); } } diff --git a/test/it/rewriter/src/test/resources/scenario/encrypt/case/query-with-cipher/dml/select/select-subquery.xml b/test/it/rewriter/src/test/resources/scenario/encrypt/case/query-with-cipher/dml/select/select-subquery.xml index 6d9e381ba956e..5fd447c1a666b 100644 --- a/test/it/rewriter/src/test/resources/scenario/encrypt/case/query-with-cipher/dml/select/select-subquery.xml +++ b/test/it/rewriter/src/test/resources/scenario/encrypt/case/query-with-cipher/dml/select/select-subquery.xml @@ -24,32 +24,32 @@ - + - + - + - + - + - + @@ -64,7 +64,7 @@ - + @@ -89,7 +89,7 @@ - + @@ -99,6 +99,6 @@ - + diff --git a/test/it/rewriter/src/test/resources/scenario/mix/case/query-with-cipher/dml/select/select-subquery.xml b/test/it/rewriter/src/test/resources/scenario/mix/case/query-with-cipher/dml/select/select-subquery.xml index 90828aac50a2c..aa6eb09503456 100644 --- a/test/it/rewriter/src/test/resources/scenario/mix/case/query-with-cipher/dml/select/select-subquery.xml +++ b/test/it/rewriter/src/test/resources/scenario/mix/case/query-with-cipher/dml/select/select-subquery.xml @@ -19,25 +19,25 @@ - - + + - - + + - - + + - - + +