Skip to content

Commit

Permalink
for #701: auto increment key can use cache for parse
Browse files Browse the repository at this point in the history
  • Loading branch information
terrymanu committed Apr 10, 2018
1 parent 4d04eef commit 84f0d43
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import io.shardingjdbc.core.parsing.lexer.LexerEngineFactory;
import io.shardingjdbc.core.parsing.parser.sql.SQLParserFactory;
import io.shardingjdbc.core.parsing.parser.sql.SQLStatement;
import io.shardingjdbc.core.parsing.parser.token.GeneratedKeyToken;
import io.shardingjdbc.core.parsing.parser.token.SQLToken;
import io.shardingjdbc.core.rule.ShardingRule;
import lombok.RequiredArgsConstructor;

Expand Down Expand Up @@ -57,8 +55,7 @@ public SQLStatement parse(final boolean useCache) {
LexerEngine lexerEngine = LexerEngineFactory.newInstance(dbType, sql);
lexerEngine.nextToken();
SQLStatement result = SQLParserFactory.newInstance(dbType, lexerEngine.getCurrentToken().getType(), shardingRule, lexerEngine).parse();
// TODO cannot cache InsertStatement here by generate key, should not modify original InsertStatement on router.
if (useCache && !findGeneratedKeyToken(result)) {
if (useCache) {
ParsingResultCache.getInstance().put(sql, result);
}
return result;
Expand All @@ -67,13 +64,4 @@ public SQLStatement parse(final boolean useCache) {
private Optional<SQLStatement> getSQLStatementFromCache(final boolean useCache) {
return useCache ? Optional.fromNullable(ParsingResultCache.getInstance().getSQLStatement(sql)) : Optional.<SQLStatement>absent();
}

private boolean findGeneratedKeyToken(final SQLStatement sqlStatement) {
for (SQLToken each : sqlStatement.getSqlTokens()) {
if (each instanceof GeneratedKeyToken) {
return true;
}
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,12 @@
package io.shardingjdbc.core.parsing.parser.sql.dml.insert;

import com.google.common.base.Optional;
import io.shardingjdbc.core.parsing.lexer.token.Symbol;
import io.shardingjdbc.core.parsing.parser.context.GeneratedKey;
import io.shardingjdbc.core.parsing.parser.context.condition.Column;
import io.shardingjdbc.core.parsing.parser.context.condition.Condition;
import io.shardingjdbc.core.parsing.parser.context.condition.Conditions;
import io.shardingjdbc.core.parsing.parser.expression.SQLNumberExpression;
import io.shardingjdbc.core.parsing.parser.expression.SQLPlaceholderExpression;
import io.shardingjdbc.core.parsing.parser.sql.dml.DMLStatement;
import io.shardingjdbc.core.parsing.parser.token.GeneratedKeyToken;
import io.shardingjdbc.core.parsing.parser.token.ItemsToken;
import io.shardingjdbc.core.parsing.parser.token.SQLToken;
import io.shardingjdbc.core.rule.ShardingRule;
import io.shardingjdbc.core.rule.TableRule;
import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
Expand Down Expand Up @@ -64,42 +57,11 @@ public final class InsertStatement extends DMLStatement {
private GeneratedKey generatedKey;

/**
* Append generate key token.
*
* @param shardingRule databases and tables sharding rule
* Find generated key token.
*
* @return generated key token
*/
public void appendGenerateKeyToken(final ShardingRule shardingRule) {
if (null != generatedKey) {
return;
}
Optional<TableRule> tableRule = shardingRule.tryFindTableRuleByLogicTable(getTables().getSingleTableName());
if (!tableRule.isPresent()) {
return;
}
Optional<GeneratedKeyToken> generatedKeysToken = findGeneratedKeyToken();
if (!generatedKeysToken.isPresent()) {
return;
}
ItemsToken valuesToken = new ItemsToken(generatedKeysToken.get().getBeginPosition());
appendGenerateKeyToken(shardingRule, tableRule.get(), valuesToken);
getSqlTokens().remove(generatedKeysToken.get());
getSqlTokens().add(valuesToken);
}

private void appendGenerateKeyToken(final ShardingRule shardingRule, final TableRule tableRule, final ItemsToken valuesToken) {
if (0 == getParametersIndex()) {
Number generatedKey = shardingRule.generateKey(tableRule.getLogicTable());
valuesToken.getItems().add(generatedKey.toString());
getConditions().add(new Condition(new Column(tableRule.getGenerateKeyColumn(), tableRule.getLogicTable()), new SQLNumberExpression(generatedKey)), shardingRule);
this.generatedKey = new GeneratedKey(tableRule.getLogicTable(), -1, generatedKey);
} else {
valuesToken.getItems().add(Symbol.QUESTION.getLiterals());
getConditions().add(new Condition(new Column(tableRule.getGenerateKeyColumn(), tableRule.getLogicTable()), new SQLPlaceholderExpression(getParametersIndex())), shardingRule);
generatedKey = new GeneratedKey(tableRule.getGenerateKeyColumn(), getParametersIndex(), null);
}
}

private Optional<GeneratedKeyToken> findGeneratedKeyToken() {
public Optional<GeneratedKeyToken> findGeneratedKeyToken() {
for (SQLToken each : getSqlTokens()) {
if (each instanceof GeneratedKeyToken) {
return Optional.of((GeneratedKeyToken) each);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@
import com.google.common.base.Strings;
import io.shardingjdbc.core.constant.DatabaseType;
import io.shardingjdbc.core.parsing.lexer.token.DefaultKeyword;
import io.shardingjdbc.core.parsing.lexer.token.Symbol;
import io.shardingjdbc.core.parsing.parser.context.GeneratedKey;
import io.shardingjdbc.core.parsing.parser.context.OrderItem;
import io.shardingjdbc.core.parsing.parser.context.limit.Limit;
import io.shardingjdbc.core.parsing.parser.sql.SQLStatement;
import io.shardingjdbc.core.parsing.parser.sql.dql.select.SelectStatement;
import io.shardingjdbc.core.parsing.parser.token.GeneratedKeyToken;
import io.shardingjdbc.core.parsing.parser.token.IndexToken;
import io.shardingjdbc.core.parsing.parser.token.ItemsToken;
import io.shardingjdbc.core.parsing.parser.token.OffsetToken;
Expand Down Expand Up @@ -68,19 +71,23 @@ public final class SQLRewriteEngine {

private final SQLStatement sqlStatement;

private final GeneratedKey generatedKey;

/**
* Constructs SQL rewrite engine.
*
* @param shardingRule databases and tables sharding rule
* @param originalSQL original SQL
* @param databaseType database type
* @param sqlStatement SQL statement
* @param generatedKey generated key
*/
public SQLRewriteEngine(final ShardingRule shardingRule, final String originalSQL, final DatabaseType databaseType, final SQLStatement sqlStatement) {
public SQLRewriteEngine(final ShardingRule shardingRule, final String originalSQL, final DatabaseType databaseType, final SQLStatement sqlStatement, final GeneratedKey generatedKey) {
this.shardingRule = shardingRule;
this.originalSQL = originalSQL;
this.databaseType = databaseType;
this.sqlStatement = sqlStatement;
this.generatedKey = generatedKey;
sqlTokens.addAll(sqlStatement.getSqlTokens());
}

Expand Down Expand Up @@ -110,6 +117,8 @@ public SQLBuilder rewrite(final boolean isRewriteLimit) {
appendIndexPlaceholder(result, (IndexToken) each, count, sqlTokens);
} else if (each instanceof ItemsToken) {
appendItemsToken(result, (ItemsToken) each, count, sqlTokens);
} else if (each instanceof GeneratedKeyToken) {
appendGenerateKeyToken(result, (GeneratedKeyToken) each, count, sqlTokens);
} else if (each instanceof RowCountToken) {
appendLimitRowCount(result, (RowCountToken) each, count, sqlTokens, isRewriteLimit);
} else if (each instanceof OffsetToken) {
Expand Down Expand Up @@ -164,6 +173,16 @@ private void appendItemsToken(final SQLBuilder sqlBuilder, final ItemsToken item
appendRest(sqlBuilder, count, sqlTokens, beginPosition);
}

private void appendGenerateKeyToken(final SQLBuilder sqlBuilder, final GeneratedKeyToken generatedKeyToken, final int count, final List<SQLToken> sqlTokens) {
ItemsToken valuesToken = new ItemsToken(generatedKeyToken.getBeginPosition());
if (0 == sqlStatement.getParametersIndex()) {
valuesToken.getItems().add(generatedKey.getValue().toString());
} else {
valuesToken.getItems().add(Symbol.QUESTION.getLiterals());
}
appendItemsToken(sqlBuilder, valuesToken, count, sqlTokens);
}

private void appendLimitRowCount(final SQLBuilder sqlBuilder, final RowCountToken rowCountToken, final int count, final List<SQLToken> sqlTokens, final boolean isRewrite) {
SelectStatement selectStatement = (SelectStatement) sqlStatement;
Limit limit = selectStatement.getLimit();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package io.shardingjdbc.core.routing;

import io.shardingjdbc.core.parsing.parser.context.GeneratedKey;
import io.shardingjdbc.core.parsing.parser.sql.SQLStatement;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
Expand All @@ -38,6 +39,8 @@ public final class SQLRouteResult {

private final SQLStatement sqlStatement;

private final GeneratedKey generatedKey;

private final Set<SQLExecutionUnit> executionUnits = new LinkedHashSet<>();

private final List<Number> generatedKeys = new LinkedList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public SQLStatement parse(final String logicSQL, final boolean useCache) {
@Override
// TODO insert SQL need parse gen key
public SQLRouteResult route(final String logicSQL, final List<Object> parameters, final SQLStatement sqlStatement) {
SQLRouteResult result = new SQLRouteResult(sqlStatement);
SQLRouteResult result = new SQLRouteResult(sqlStatement, null);
RoutingResult routingResult = new DatabaseHintRoutingEngine(shardingRule.getDataSourceNames(), (HintShardingStrategy) shardingRule.getDefaultDatabaseShardingStrategy()).route();
for (TableUnit each : routingResult.getTableUnits().getTableUnits()) {
result.getExecutionUnits().add(new SQLExecutionUnit(each.getDataSourceName(), logicSQL));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package io.shardingjdbc.core.routing.router;

import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import io.shardingjdbc.core.constant.DatabaseType;
import io.shardingjdbc.core.parsing.SQLParsingEngine;
import io.shardingjdbc.core.parsing.parser.context.GeneratedKey;
Expand All @@ -28,6 +30,7 @@
import io.shardingjdbc.core.parsing.parser.sql.ddl.DDLStatement;
import io.shardingjdbc.core.parsing.parser.sql.dml.insert.InsertStatement;
import io.shardingjdbc.core.parsing.parser.sql.dql.select.SelectStatement;
import io.shardingjdbc.core.parsing.parser.token.GeneratedKeyToken;
import io.shardingjdbc.core.rewrite.SQLBuilder;
import io.shardingjdbc.core.rewrite.SQLRewriteEngine;
import io.shardingjdbc.core.routing.SQLExecutionUnit;
Expand All @@ -45,6 +48,7 @@
import io.shardingjdbc.core.routing.type.standard.StandardRoutingEngine;
import io.shardingjdbc.core.routing.type.unicast.UnicastRoutingEngine;
import io.shardingjdbc.core.rule.ShardingRule;
import io.shardingjdbc.core.rule.TableRule;
import io.shardingjdbc.core.util.SQLLogger;
import lombok.RequiredArgsConstructor;

Expand All @@ -70,22 +74,21 @@ public final class ParsingSQLRouter implements SQLRouter {

@Override
public SQLStatement parse(final String logicSQL, final boolean useCache) {
SQLParsingEngine parsingEngine = new SQLParsingEngine(databaseType, logicSQL, shardingRule);
SQLStatement result = parsingEngine.parse(useCache);
if (result instanceof InsertStatement) {
((InsertStatement) result).appendGenerateKeyToken(shardingRule);
}
return result;
return new SQLParsingEngine(databaseType, logicSQL, shardingRule).parse(useCache);
}

@Override
public SQLRouteResult route(final String logicSQL, final List<Object> parameters, final SQLStatement sqlStatement) {
SQLRouteResult result = new SQLRouteResult(sqlStatement);
if (sqlStatement instanceof InsertStatement && null != ((InsertStatement) sqlStatement).getGeneratedKey()) {
processGeneratedKey(parameters, (InsertStatement) sqlStatement, result);
GeneratedKey generatedKey = null;
if (sqlStatement instanceof InsertStatement) {
generatedKey = getGenerateKey(shardingRule, (InsertStatement) sqlStatement);
}
RoutingResult routingResult = route(parameters, sqlStatement);
SQLRewriteEngine rewriteEngine = new SQLRewriteEngine(shardingRule, logicSQL, databaseType, sqlStatement);
SQLRouteResult result = new SQLRouteResult(sqlStatement, generatedKey);
if (null != generatedKey) {
processGeneratedKey(parameters, generatedKey, sqlStatement.getTables().getSingleTableName(), result);
}
RoutingResult routingResult = route(parameters, sqlStatement, generatedKey);
SQLRewriteEngine rewriteEngine = new SQLRewriteEngine(shardingRule, logicSQL, databaseType, sqlStatement, generatedKey);
boolean isSingleRouting = routingResult.isSingleRouting();
if (sqlStatement instanceof SelectStatement && null != ((SelectStatement) sqlStatement).getLimit()) {
processLimit(parameters, (SelectStatement) sqlStatement, isSingleRouting);
Expand All @@ -108,7 +111,7 @@ public SQLRouteResult route(final String logicSQL, final List<Object> parameters
return result;
}

private RoutingResult route(final List<Object> parameters, final SQLStatement sqlStatement) {
private RoutingResult route(final List<Object> parameters, final SQLStatement sqlStatement, final GeneratedKey generatedKey) {
Collection<String> tableNames = sqlStatement.getTables().getTableNames();
RoutingEngine routingEngine;
if (sqlStatement instanceof UseStatement) {
Expand All @@ -124,20 +127,38 @@ private RoutingResult route(final List<Object> parameters, final SQLStatement sq
} else if (tableNames.isEmpty()) {
routingEngine = new DatabaseBroadcastRoutingEngine(shardingRule);
} else if (1 == tableNames.size() || shardingRule.isAllBindingTables(tableNames) || shardingRule.isAllInDefaultDataSource(tableNames)) {
routingEngine = new StandardRoutingEngine(shardingRule, parameters, tableNames.iterator().next(), sqlStatement);
routingEngine = new StandardRoutingEngine(shardingRule, parameters, tableNames.iterator().next(), sqlStatement, generatedKey);
} else {
// TODO config for cartesian set
routingEngine = new ComplexRoutingEngine(shardingRule, parameters, tableNames, sqlStatement);
}
return routingEngine.route();
}

private void processGeneratedKey(final List<Object> parameters, final InsertStatement insertStatement, final SQLRouteResult sqlRouteResult) {
GeneratedKey generatedKey = insertStatement.getGeneratedKey();
private GeneratedKey getGenerateKey(final ShardingRule shardingRule, final InsertStatement insertStatement) {
if (null != insertStatement.getGeneratedKey()) {
return insertStatement.getGeneratedKey();
}
Optional<TableRule> tableRule = shardingRule.tryFindTableRuleByLogicTable(insertStatement.getTables().getSingleTableName());
if (!tableRule.isPresent()) {
return null;
}
Optional<GeneratedKeyToken> generatedKeysToken = insertStatement.findGeneratedKeyToken();
if (!generatedKeysToken.isPresent()) {
return null;
}
String logicTableName = insertStatement.getTables().getSingleTableName();
Optional<String> generateKeyColumn = shardingRule.getGenerateKeyColumn(logicTableName);
Preconditions.checkState(generateKeyColumn.isPresent());
return 0 == insertStatement.getParametersIndex()
? new GeneratedKey(generateKeyColumn.get(), -1, shardingRule.generateKey(logicTableName)) : new GeneratedKey(generateKeyColumn.get(), insertStatement.getParametersIndex(), null);
}

private void processGeneratedKey(final List<Object> parameters, final GeneratedKey generatedKey, final String logicTableName, final SQLRouteResult sqlRouteResult) {
if (parameters.isEmpty()) {
sqlRouteResult.getGeneratedKeys().add(generatedKey.getValue());
} else if (parameters.size() == generatedKey.getIndex()) {
Number key = shardingRule.generateKey(insertStatement.getTables().getSingleTableName());
Number key = shardingRule.generateKey(logicTableName);
parameters.add(key);
setGeneratedKeys(sqlRouteResult, key);
} else if (-1 != generatedKey.getIndex()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public RoutingResult route() {
Optional<TableRule> tableRule = shardingRule.tryFindTableRuleByLogicTable(each);
if (tableRule.isPresent()) {
if (!bindingTableNames.contains(each)) {
result.add(new StandardRoutingEngine(shardingRule, parameters, tableRule.get().getLogicTable(), sqlStatement).route());
result.add(new StandardRoutingEngine(shardingRule, parameters, tableRule.get().getLogicTable(), sqlStatement, null).route());
}
Optional<BindingTableRule> bindingTableRule = shardingRule.findBindingTableRule(each);
if (bindingTableRule.isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@

import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import io.shardingjdbc.core.api.algorithm.sharding.ListShardingValue;
import io.shardingjdbc.core.api.algorithm.sharding.ShardingValue;
import io.shardingjdbc.core.hint.HintManagerHolder;
import io.shardingjdbc.core.hint.ShardingKey;
import io.shardingjdbc.core.parsing.parser.context.GeneratedKey;
import io.shardingjdbc.core.parsing.parser.context.condition.Column;
import io.shardingjdbc.core.parsing.parser.context.condition.Condition;
import io.shardingjdbc.core.parsing.parser.sql.SQLStatement;
Expand All @@ -36,6 +38,7 @@

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

Expand All @@ -55,6 +58,8 @@ public final class StandardRoutingEngine implements RoutingEngine {

private final SQLStatement sqlStatement;

private final GeneratedKey generatedKey;

@Override
public RoutingResult route() {
TableRule tableRule = shardingRule.getTableRule(logicTableName);
Expand Down Expand Up @@ -106,6 +111,9 @@ private List<ShardingValue> getShardingValues(final Collection<String> shardingC
Optional<Condition> condition = sqlStatement.getConditions().find(new Column(each, logicTableName));
if (condition.isPresent()) {
result.add(condition.get().getShardingValue(parameters));
} else if (null != generatedKey && each.equals(generatedKey.getColumn())) {
Comparable key = null == generatedKey.getValue() ? (Comparable) parameters.get(generatedKey.getIndex()) : (Comparable) generatedKey.getValue();
result.add(new ListShardingValue<>(sqlStatement.getTables().getSingleTableName(), each, Collections.singletonList(key)));
}
}
return result;
Expand Down
Loading

0 comments on commit 84f0d43

Please sign in to comment.