Skip to content

Commit

Permalink
Support like concat nested concat statement rewrite with encrypt feat…
Browse files Browse the repository at this point in the history
…ure (#32970)
  • Loading branch information
strongduanmu authored Sep 24, 2024
1 parent 184c4f9 commit 13691f4
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ private SQLToken generateSQLToken(final String schemaName, final EncryptTable en
encryptCondition.getPositionValueMap().keySet(), getEncryptedValues(schemaName, encryptTable, encryptCondition, encryptCondition.getValues(parameters)));
Collection<Integer> parameterMarkerIndexes = encryptCondition.getPositionIndexMap().keySet();
if (encryptCondition instanceof EncryptBinaryCondition && ((EncryptBinaryCondition) encryptCondition).getExpressionSegment() instanceof FunctionSegment) {
return new EncryptPredicateFunctionRightValueToken(startIndex, stopIndex,
((FunctionSegment) ((EncryptBinaryCondition) encryptCondition).getExpressionSegment()).getFunctionName(), indexValues, parameterMarkerIndexes);
FunctionSegment functionSegment = (FunctionSegment) ((EncryptBinaryCondition) encryptCondition).getExpressionSegment();
return new EncryptPredicateFunctionRightValueToken(startIndex, stopIndex, functionSegment.getFunctionName(), functionSegment.getParameters(), indexValues, parameterMarkerIndexes);
}
return encryptCondition instanceof EncryptInCondition
? new EncryptPredicateInRightValueToken(startIndex, stopIndex, indexValues, parameterMarkerIndexes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,52 +20,76 @@
import lombok.Getter;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.Substitutable;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.FunctionSegment;

import java.util.Collection;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;

/**
* Predicate in right value token for encrypt.
*/
public final class EncryptPredicateFunctionRightValueToken extends SQLToken implements Substitutable {

private static final String COMMA_SEPARATOR = ", ";

@Getter
private final int stopIndex;

private final String functionName;

private final Collection<ExpressionSegment> parameters;

private final Map<Integer, Object> indexValues;

private final Collection<Integer> paramMarkerIndexes;

public EncryptPredicateFunctionRightValueToken(final int startIndex, final int stopIndex, final String functionName,
public EncryptPredicateFunctionRightValueToken(final int startIndex, final int stopIndex, final String functionName, final Collection<ExpressionSegment> parameters,
final Map<Integer, Object> indexValues, final Collection<Integer> paramMarkerIndexes) {
super(startIndex);
this.stopIndex = stopIndex;
this.functionName = functionName;
this.parameters = parameters;
this.indexValues = indexValues;
this.paramMarkerIndexes = paramMarkerIndexes;
}

@Override
public String toString() {
StringBuilder result = new StringBuilder();
result.append(functionName).append(" (");
for (int i = 0; i < indexValues.size() + paramMarkerIndexes.size(); i++) {
if (paramMarkerIndexes.contains(i)) {
result.append('?');
AtomicInteger parameterIndex = new AtomicInteger();
appendFunctionSegment(functionName, parameters, result, parameterIndex);
return result.toString();
}

private void appendFunctionSegment(final String functionName, final Collection<ExpressionSegment> parameters, final StringBuilder builder, final AtomicInteger parameterIndex) {
builder.append(functionName).append(" (");
for (ExpressionSegment each : parameters) {
if (each instanceof FunctionSegment) {
appendFunctionSegment(((FunctionSegment) each).getFunctionName(), ((FunctionSegment) each).getParameters(), builder, parameterIndex);
} else {
if (indexValues.get(i) instanceof String) {
result.append('\'').append(indexValues.get(i)).append('\'');
} else {
result.append(indexValues.get(i));
}
appendRewrittenParameters(builder, parameterIndex.getAndIncrement());
}
result.append(", ");
}
result.delete(result.length() - 2, result.length()).append(')');
return result.toString();
if (builder.toString().endsWith(COMMA_SEPARATOR)) {
builder.delete(builder.length() - 2, builder.length());
}
builder.append(')');
}

private void appendRewrittenParameters(final StringBuilder builder, final int parameterIndex) {
if (paramMarkerIndexes.contains(parameterIndex)) {
builder.append('?');
} else {
if (indexValues.get(parameterIndex) instanceof String) {
builder.append('\'').append(indexValues.get(parameterIndex)).append('\'');
} else {
builder.append(indexValues.get(parameterIndex));
}
}
builder.append(COMMA_SEPARATOR);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.shardingsphere.encrypt.rewrite.token.pojo.EncryptPredicateFunctionRightValueToken;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.FunctionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.LiteralExpressionSegment;
import org.junit.jupiter.api.Test;

import java.util.Collections;
Expand All @@ -31,13 +32,34 @@
class EncryptPredicateFunctionRightValueTokenTest {

@Test
void assertToStringWithoutPlaceholderWithoutTableOwnerWithFunction() {
void assertToStringWithSimpleFunction() {
Map<Integer, Object> indexValues = new LinkedHashMap<>(3, 1F);
indexValues.put(0, "%");
indexValues.put(1, "abc");
indexValues.put(2, "%");
FunctionSegment functionSegment = new FunctionSegment(0, 0, "CONCAT", "('%','abc','%')");
EncryptPredicateFunctionRightValueToken actual = new EncryptPredicateFunctionRightValueToken(0, 0, functionSegment.getFunctionName(), indexValues, Collections.emptyList());
functionSegment.getParameters().add(new LiteralExpressionSegment(0, 0, "%"));
functionSegment.getParameters().add(new LiteralExpressionSegment(0, 0, "abc"));
functionSegment.getParameters().add(new LiteralExpressionSegment(0, 0, "%"));
EncryptPredicateFunctionRightValueToken actual =
new EncryptPredicateFunctionRightValueToken(0, 0, functionSegment.getFunctionName(), functionSegment.getParameters(), indexValues, Collections.emptyList());
assertThat(actual.toString(), is("CONCAT ('%', 'abc', '%')"));
}

@Test
void assertToStringWithNestedFunction() {
Map<Integer, Object> indexValues = new LinkedHashMap<>(3, 1F);
indexValues.put(0, "%");
indexValues.put(1, "abc");
indexValues.put(2, "%");
FunctionSegment functionSegment = new FunctionSegment(0, 0, "CONCAT", "('%',CONCAT('abc','%'))");
functionSegment.getParameters().add(new LiteralExpressionSegment(0, 0, "%"));
FunctionSegment nestedFunctionSegment = new FunctionSegment(0, 0, "CONCAT", "('abc','%')");
nestedFunctionSegment.getParameters().add(new LiteralExpressionSegment(0, 0, "abc"));
nestedFunctionSegment.getParameters().add(new LiteralExpressionSegment(0, 0, "%"));
functionSegment.getParameters().add(nestedFunctionSegment);
EncryptPredicateFunctionRightValueToken actual =
new EncryptPredicateFunctionRightValueToken(0, 0, functionSegment.getFunctionName(), functionSegment.getParameters(), indexValues, Collections.emptyList());
assertThat(actual.toString(), is("CONCAT ('%', CONCAT ('abc', '%'))"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,8 @@
<test-case sql="SELECT * FROM t_merchant WHERE business_code LIKE CONCAT('%', ?, '%')" db-types="MySQL,PostgreSQL,openGauss" scenario-types="encrypt">
<assertion parameters="abc:String" expected-data-source-name="read_dataset" />
</test-case>

<test-case sql="SELECT * FROM t_merchant WHERE business_code LIKE CONCAT('%', CONCAT(?, '%'))" db-types="MySQL,PostgreSQL,openGauss" scenario-types="encrypt">
<assertion parameters="abc:String" expected-data-source-name="read_dataset" />
</test-case>
</e2e-test-cases>
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,25 @@
<input sql="SELECT a.account_id, a.password, a.amount AS a, a.status AS s FROM t_account_bak AS a WHERE a.amount in (?, ?)" parameters="1000, 2000" />
<output sql="SELECT a.account_id, a.cipher_password AS password, a.cipher_amount AS a, a.status AS s FROM t_account_bak AS a WHERE a.cipher_amount in (?, ?)" parameters="encrypt_1000, encrypt_2000" />
</rewrite-assertion>

<rewrite-assertion id="select_where_with_cipher_column_like_concat_for_literals" db-types="PostgreSQL,openGauss">
<input sql="SELECT a.account_id, a.password, a.amount AS a, a.status AS s FROM t_account_bak AS a WHERE a.account_id = 1 AND a.certificate_number like concat('%','abc','%')" />
<output sql="SELECT a.account_id, a.cipher_password AS password, a.cipher_amount AS a, a.status AS s FROM t_account_bak AS a WHERE a.account_id = 1 AND a.like_query_certificate_number like concat ('like_query_%', 'like_query_abc', 'like_query_%')" />

<rewrite-assertion id="select_where_with_cipher_column_like_concat_for_parameters" db-types="MySQL,PostgreSQL,openGauss">
<input sql="SELECT a.account_id, a.password, a.amount AS a, a.status AS s FROM t_account_bak a WHERE a.account_id = 1 AND a.certificate_number like concat ('%', ? ,'%')" parameters="abc" />
<output sql="SELECT a.account_id, a.cipher_password AS password, a.cipher_amount AS a, a.status AS s FROM t_account_bak a WHERE a.account_id = 1 AND a.like_query_certificate_number like concat ('like_query_%', ?, 'like_query_%')" parameters="like_query_abc" />
</rewrite-assertion>

<rewrite-assertion id="select_where_with_cipher_column_like_concat_for_literals" db-types="MySQL,PostgreSQL,openGauss">
<input sql="SELECT a.account_id, a.password, a.amount AS a, a.status AS s FROM t_account_bak a WHERE a.account_id = 1 AND a.certificate_number like concat ('%','abc','%')" />
<output sql="SELECT a.account_id, a.cipher_password AS password, a.cipher_amount AS a, a.status AS s FROM t_account_bak a WHERE a.account_id = 1 AND a.like_query_certificate_number like concat ('like_query_%', 'like_query_abc', 'like_query_%')" />
</rewrite-assertion>

<rewrite-assertion id="select_where_with_cipher_column_like_nested_concat_for_parameters" db-types="MySQL,PostgreSQL,openGauss">
<input sql="SELECT a.account_id, a.password, a.amount AS a, a.status AS s FROM t_account_bak a WHERE a.account_id = 1 AND a.certificate_number like concat ('%', concat(?, '%'))" parameters="abc" />
<output sql="SELECT a.account_id, a.cipher_password AS password, a.cipher_amount AS a, a.status AS s FROM t_account_bak a WHERE a.account_id = 1 AND a.like_query_certificate_number like concat ('like_query_%', concat (?, 'like_query_%'))" parameters="like_query_abc" />
</rewrite-assertion>

<rewrite-assertion id="select_where_with_cipher_column_like_nested_concat_for_literals" db-types="MySQL,PostgreSQL,openGauss">
<input sql="SELECT a.account_id, a.password, a.amount AS a, a.status AS s FROM t_account_bak a WHERE a.account_id = 1 AND a.certificate_number like concat ('%', concat('abc','%'))" />
<output sql="SELECT a.account_id, a.cipher_password AS password, a.cipher_amount AS a, a.status AS s FROM t_account_bak a WHERE a.account_id = 1 AND a.like_query_certificate_number like concat ('like_query_%', concat ('like_query_abc', 'like_query_%'))" />
</rewrite-assertion>

<rewrite-assertion id="select_from_user_with_column_alias" db-types="SQLServer">
Expand Down

0 comments on commit 13691f4

Please sign in to comment.