Skip to content

Commit

Permalink
Extract subquery context for encrypt predicate rewrite and refactor s…
Browse files Browse the repository at this point in the history
…ubquery type to SelectStatement
  • Loading branch information
strongduanmu committed Nov 18, 2024
1 parent 5a85a20 commit 5819b41
Show file tree
Hide file tree
Showing 28 changed files with 383 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@

import org.apache.shardingsphere.encrypt.exception.metadata.MissingMatchedEncryptQueryAlgorithmException;
import org.apache.shardingsphere.encrypt.rewrite.token.comparator.JoinConditionsEncryptorComparator;
import org.apache.shardingsphere.encrypt.rewrite.util.EncryptPredicateSegmentUtils;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.encrypt.rule.table.EncryptTable;
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
import org.apache.shardingsphere.infra.checker.SupportedSQLChecker;
Expand Down Expand Up @@ -53,7 +55,9 @@ public boolean isCheck(final SQLStatementContext sqlStatementContext) {

@Override
public void check(final EncryptRule rule, final ShardingSphereSchema schema, final SQLStatementContext sqlStatementContext) {
ShardingSpherePreconditions.checkState(JoinConditionsEncryptorComparator.isSame(((WhereAvailable) sqlStatementContext).getJoinConditions(), rule),
Collection<SelectStatementContext> allSubqueryContexts = EncryptPredicateSegmentUtils.getAllSubqueryContexts(sqlStatementContext);
Collection<BinaryOperationExpression> joinConditions = EncryptPredicateSegmentUtils.getJoinConditions((WhereAvailable) sqlStatementContext, allSubqueryContexts);
ShardingSpherePreconditions.checkState(JoinConditionsEncryptorComparator.isSame(joinConditions, rule),
() -> new UnsupportedSQLOperationException("Can not use different encryptor in join condition"));
check(rule, schema, (WhereAvailable) sqlStatementContext);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
import org.apache.shardingsphere.encrypt.rewrite.condition.EncryptConditionEngine;
import org.apache.shardingsphere.encrypt.rewrite.parameter.EncryptParameterRewritersRegistry;
import org.apache.shardingsphere.encrypt.rewrite.token.EncryptTokenGenerateBuilder;
import org.apache.shardingsphere.encrypt.rewrite.util.EncryptPredicateSegmentUtils;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
Expand Down Expand Up @@ -88,8 +90,9 @@ private Collection<EncryptCondition> doCreateEncryptConditions(final EncryptRule
if (!(sqlStatementContext instanceof WhereAvailable)) {
return Collections.emptyList();
}
Collection<WhereSegment> whereSegments = ((WhereAvailable) sqlStatementContext).getWhereSegments();
Collection<ColumnSegment> columnSegments = ((WhereAvailable) sqlStatementContext).getColumnSegments();
Collection<SelectStatementContext> allSubqueryContexts = EncryptPredicateSegmentUtils.getAllSubqueryContexts(sqlStatementContext);
Collection<WhereSegment> whereSegments = EncryptPredicateSegmentUtils.getWhereSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
Collection<ColumnSegment> columnSegments = EncryptPredicateSegmentUtils.getColumnSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
return new EncryptConditionEngine(rule, sqlRewriteContext.getDatabase().getSchemas()).createEncryptConditions(whereSegments, columnSegments, sqlStatementContext,
sqlRewriteContext.getDatabase().getName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
import org.apache.shardingsphere.encrypt.rule.table.EncryptTable;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
Expand All @@ -51,7 +53,25 @@ public final class EncryptPredicateParameterRewriter implements ParameterRewrite

@Override
public boolean isNeedRewrite(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty();
if (sqlStatementContext instanceof WhereAvailable && !((WhereAvailable) sqlStatementContext).getWhereSegments().isEmpty()) {
return true;
}
if (sqlStatementContext instanceof SelectStatementContext) {
return isSubqueryNeedRewrite((SelectStatementContext) sqlStatementContext);
}
if (sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()) {
return isSubqueryNeedRewrite(((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext());
}
return false;
}

private boolean isSubqueryNeedRewrite(final SelectStatementContext selectStatementContext) {
for (SelectStatementContext each : selectStatementContext.getSubqueryContexts().values()) {
if (isNeedRewrite(each)) {
return true;
}
}
return false;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.base.Preconditions;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.apache.shardingsphere.encrypt.rewrite.util.EncryptPredicateSegmentUtils;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
import org.apache.shardingsphere.encrypt.rule.column.item.LikeQueryColumnItem;
Expand All @@ -28,6 +29,7 @@
import org.apache.shardingsphere.infra.binder.context.segment.select.projection.Projection;
import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;
Expand Down Expand Up @@ -72,8 +74,9 @@ public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext)

@Override
public Collection<SQLToken> generateSQLTokens(final SQLStatementContext sqlStatementContext) {
Collection<ColumnSegment> columnSegments = ((WhereAvailable) sqlStatementContext).getColumnSegments();
Collection<WhereSegment> whereSegments = ((WhereAvailable) sqlStatementContext).getWhereSegments();
Collection<SelectStatementContext> allSubqueryContexts = EncryptPredicateSegmentUtils.getAllSubqueryContexts(sqlStatementContext);
Collection<WhereSegment> whereSegments = EncryptPredicateSegmentUtils.getWhereSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
Collection<ColumnSegment> columnSegments = EncryptPredicateSegmentUtils.getColumnSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
ShardingSphereSchema schema = ((TableAvailable) sqlStatementContext).getTablesContext().getSchemaName().map(schemas::get).orElseGet(() -> defaultSchema);
Map<String, String> columnExpressionTableNames = ((TableAvailable) sqlStatementContext).getTablesContext().findTableNames(columnSegments, schema);
return generateSQLTokens(columnSegments, columnExpressionTableNames, whereSegments, sqlStatementContext.getDatabaseType());
Expand Down
Loading

0 comments on commit 5819b41

Please sign in to comment.