Skip to content

Commit

Permalink
Merge pull request #33704 from strongduanmu/dev-1118
Browse files Browse the repository at this point in the history
Extract subquery context for encrypt predicate rewrite and refactor subquery type to SelectStatement
  • Loading branch information
iamhucong authored Nov 18, 2024
2 parents e878d61 + 5819b41 commit 3de68f4
Show file tree
Hide file tree
Showing 28 changed files with 383 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@

import org.apache.shardingsphere.encrypt.exception.metadata.MissingMatchedEncryptQueryAlgorithmException;
import org.apache.shardingsphere.encrypt.rewrite.token.comparator.JoinConditionsEncryptorComparator;
import org.apache.shardingsphere.encrypt.rewrite.util.EncryptPredicateSegmentUtils;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.encrypt.rule.table.EncryptTable;
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
import org.apache.shardingsphere.infra.checker.SupportedSQLChecker;
Expand Down Expand Up @@ -53,7 +55,9 @@ public boolean isCheck(final SQLStatementContext sqlStatementContext) {

@Override
public void check(final EncryptRule rule, final ShardingSphereSchema schema, final SQLStatementContext sqlStatementContext) {
ShardingSpherePreconditions.checkState(JoinConditionsEncryptorComparator.isSame(((WhereAvailable) sqlStatementContext).getJoinConditions(), rule),
Collection<SelectStatementContext> allSubqueryContexts = EncryptPredicateSegmentUtils.getAllSubqueryContexts(sqlStatementContext);
Collection<BinaryOperationExpression> joinConditions = EncryptPredicateSegmentUtils.getJoinConditions((WhereAvailable) sqlStatementContext, allSubqueryContexts);
ShardingSpherePreconditions.checkState(JoinConditionsEncryptorComparator.isSame(joinConditions, rule),
() -> new UnsupportedSQLOperationException("Can not use different encryptor in join condition"));
check(rule, schema, (WhereAvailable) sqlStatementContext);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
import org.apache.shardingsphere.encrypt.rewrite.condition.EncryptConditionEngine;
import org.apache.shardingsphere.encrypt.rewrite.parameter.EncryptParameterRewritersRegistry;
import org.apache.shardingsphere.encrypt.rewrite.token.EncryptTokenGenerateBuilder;
import org.apache.shardingsphere.encrypt.rewrite.util.EncryptPredicateSegmentUtils;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
Expand Down Expand Up @@ -88,8 +90,9 @@ private Collection<EncryptCondition> doCreateEncryptConditions(final EncryptRule
if (!(sqlStatementContext instanceof WhereAvailable)) {
return Collections.emptyList();
}
Collection<WhereSegment> whereSegments = ((WhereAvailable) sqlStatementContext).getWhereSegments();
Collection<ColumnSegment> columnSegments = ((WhereAvailable) sqlStatementContext).getColumnSegments();
Collection<SelectStatementContext> allSubqueryContexts = EncryptPredicateSegmentUtils.getAllSubqueryContexts(sqlStatementContext);
Collection<WhereSegment> whereSegments = EncryptPredicateSegmentUtils.getWhereSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
Collection<ColumnSegment> columnSegments = EncryptPredicateSegmentUtils.getColumnSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
return new EncryptConditionEngine(rule, sqlRewriteContext.getDatabase().getSchemas()).createEncryptConditions(whereSegments, columnSegments, sqlStatementContext,
sqlRewriteContext.getDatabase().getName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
import org.apache.shardingsphere.encrypt.rule.table.EncryptTable;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
Expand All @@ -51,7 +53,25 @@ public final class EncryptPredicateParameterRewriter implements ParameterRewrite

@Override
public boolean isNeedRewrite(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty();
if (sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty()) {
return true;
}
if (sqlStatementContext instanceof SelectStatementContext) {
return isSubqueryNeedRewrite((SelectStatementContext) sqlStatementContext);
}
if (sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()) {
return isSubqueryNeedRewrite(((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext());
}
return false;
}

private boolean isSubqueryNeedRewrite(final SelectStatementContext selectStatementContext) {
for (SelectStatementContext each : selectStatementContext.getSubqueryContexts().values()) {
if (isNeedRewrite(each)) {
return true;
}
}
return false;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.base.Preconditions;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.apache.shardingsphere.encrypt.rewrite.util.EncryptPredicateSegmentUtils;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
import org.apache.shardingsphere.encrypt.rule.column.item.LikeQueryColumnItem;
Expand All @@ -28,6 +29,7 @@
import org.apache.shardingsphere.infra.binder.context.segment.select.projection.Projection;
import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;
Expand Down Expand Up @@ -72,8 +74,9 @@ public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext)

@Override
public Collection<SQLToken> generateSQLTokens(final SQLStatementContext sqlStatementContext) {
Collection<ColumnSegment> columnSegments = ((WhereAvailable) sqlStatementContext).getColumnSegments();
Collection<WhereSegment> whereSegments = ((WhereAvailable) sqlStatementContext).getWhereSegments();
Collection<SelectStatementContext> allSubqueryContexts = EncryptPredicateSegmentUtils.getAllSubqueryContexts(sqlStatementContext);
Collection<WhereSegment> whereSegments = EncryptPredicateSegmentUtils.getWhereSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
Collection<ColumnSegment> columnSegments = EncryptPredicateSegmentUtils.getColumnSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
ShardingSphereSchema schema = ((TableAvailable) sqlStatementContext).getTablesContext().getSchemaName().map(schemas::get).orElseGet(() -> defaultSchema);
Map<String, String> columnExpressionTableNames = ((TableAvailable) sqlStatementContext).getTablesContext().findTableNames(columnSegments, schema);
return generateSQLTokens(columnSegments, columnExpressionTableNames, whereSegments, sqlStatementContext.getDatabaseType());
Expand Down
Loading

0 comments on commit 3de68f4

Please sign in to comment.