Skip to content

Commit

Permalink
Support PostgreSQL, openGauss function table and update from parse (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
strongduanmu authored Sep 25, 2024
1 parent f1ed537 commit bdeccfe
Show file tree
Hide file tree
Showing 17 changed files with 368 additions and 154 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public UpdateStatement bind(final UpdateStatement sqlStatement, final SQLStateme
UpdateStatement result = copy(sqlStatement);
Map<String, TableSegmentBinderContext> tableBinderContexts = new LinkedHashMap<>();
result.setTable(TableSegmentBinder.bind(sqlStatement.getTable(), binderContext, tableBinderContexts, Collections.emptyMap()));
sqlStatement.getFrom().ifPresent(optional -> result.setFrom(TableSegmentBinder.bind(optional, binderContext, tableBinderContexts, Collections.emptyMap())));
sqlStatement.getAssignmentSegment().ifPresent(optional -> result.setSetAssignment(AssignmentSegmentBinder.bind(optional, binderContext, tableBinderContexts, Collections.emptyMap())));
sqlStatement.getWhere().ifPresent(optional -> result.setWhere(WhereSegmentBinder.bind(optional, binderContext, tableBinderContexts, Collections.emptyMap())));
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.WindowItemSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.WindowSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.WithSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.FunctionTableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.JoinTableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SubqueryTableSegment;
Expand Down Expand Up @@ -898,6 +899,9 @@ public ASTNode visitUpdate(final UpdateContext ctx) {
if (null != ctx.whereOrCurrentClause()) {
result.setWhere((WhereSegment) visit(ctx.whereOrCurrentClause()));
}
if (null != ctx.fromClause()) {
result.setFrom((TableSegment) visit(ctx.fromClause()));
}
result.addParameterMarkerSegments(getParameterMarkerSegments());
return result;
}
Expand Down Expand Up @@ -1222,6 +1226,9 @@ public ASTNode visitTableReference(final TableReferenceContext ctx) {
if (null != ctx.tableReference()) {
return getJoinTableSegment(ctx);
}
if (null != ctx.functionTable() && null != ctx.functionTable().functionExprWindowless() && null != ctx.functionTable().functionExprWindowless().funcApplication()) {
return getFunctionTableSegment(ctx);
}
// TODO deal with functionTable and xmlTable
return new SimpleTableSegment(new TableNameSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), new IdentifierValue("not support")));
}
Expand Down Expand Up @@ -1263,6 +1270,19 @@ private JoinTableSegment getJoinTableSegment(final TableReferenceContext ctx) {
return result;
}

private FunctionTableSegment getFunctionTableSegment(final TableReferenceContext ctx) {
FunctionSegment functionSegment = (FunctionSegment) visit(ctx.functionTable().functionExprWindowless().funcApplication());
return new FunctionTableSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), functionSegment);
}

@Override
public ASTNode visitFuncApplication(final FuncApplicationContext ctx) {
Collection<ExpressionSegment> expressionSegments = getExpressionSegments(getTargetRuleContextFromParseTree(ctx, AExprContext.class));
FunctionSegment result = new FunctionSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), ctx.funcName().getText(), getOriginalText(ctx));
result.getParameters().addAll(expressionSegments);
return result;
}

private JoinTableSegment visitJoinedTable(final JoinedTableContext ctx, final JoinTableSegment tableSegment) {
TableSegment right = (TableSegment) visit(ctx.tableReference());
tableSegment.setRight(right);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.WindowItemSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.WindowSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.WithSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.FunctionTableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.JoinTableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SubqueryTableSegment;
Expand Down Expand Up @@ -871,6 +872,9 @@ public ASTNode visitUpdate(final UpdateContext ctx) {
if (null != ctx.whereOrCurrentClause()) {
result.setWhere((WhereSegment) visit(ctx.whereOrCurrentClause()));
}
if (null != ctx.fromClause()) {
result.setFrom((TableSegment) visit(ctx.fromClause()));
}
result.addParameterMarkerSegments(getParameterMarkerSegments());
return result;
}
Expand Down Expand Up @@ -1184,25 +1188,39 @@ public ASTNode visitFromList(final FromListContext ctx) {
@Override
public ASTNode visitTableReference(final TableReferenceContext ctx) {
if (null != ctx.relationExpr()) {
SimpleTableSegment result = (SimpleTableSegment) visit(ctx.relationExpr().qualifiedName());
if (null != ctx.aliasClause()) {
result.setAlias((AliasSegment) visit(ctx.aliasClause()));
}
return result;
return getSimpleTableSegment(ctx);
}
if (null != ctx.selectWithParens()) {
PostgreSQLSelectStatement select = (PostgreSQLSelectStatement) visit(ctx.selectWithParens());
SubquerySegment subquery = new SubquerySegment(ctx.selectWithParens().start.getStartIndex(), ctx.selectWithParens().stop.getStopIndex(), select, getOriginalText(ctx.selectWithParens()));
AliasSegment alias = null == ctx.aliasClause() ? null : (AliasSegment) visit(ctx.aliasClause());
SubqueryTableSegment result = new SubqueryTableSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), subquery);
result.setAlias(alias);
return result;
return getSubqueryTableSegment(ctx);
}
if (null != ctx.tableReference()) {
return getJoinTableSegment(ctx);
}
if (null == ctx.tableReference()) {
// TODO deal with functionTable and xmlTable
TableNameSegment tableName = new TableNameSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), new IdentifierValue("not support"));
return new SimpleTableSegment(tableName);
if (null != ctx.functionTable() && null != ctx.functionTable().functionExprWindowless() && null != ctx.functionTable().functionExprWindowless().funcApplication()) {
return getFunctionTableSegment(ctx);
}
// TODO deal with functionTable and xmlTable
return new SimpleTableSegment(new TableNameSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), new IdentifierValue("not support")));
}

private SimpleTableSegment getSimpleTableSegment(final TableReferenceContext ctx) {
SimpleTableSegment result = (SimpleTableSegment) visit(ctx.relationExpr().qualifiedName());
if (null != ctx.aliasClause()) {
result.setAlias((AliasSegment) visit(ctx.aliasClause()));
}
return result;
}

private SubqueryTableSegment getSubqueryTableSegment(final TableReferenceContext ctx) {
PostgreSQLSelectStatement select = (PostgreSQLSelectStatement) visit(ctx.selectWithParens());
SubquerySegment subquery = new SubquerySegment(ctx.selectWithParens().start.getStartIndex(), ctx.selectWithParens().stop.getStopIndex(), select, getOriginalText(ctx.selectWithParens()));
AliasSegment alias = null == ctx.aliasClause() ? null : (AliasSegment) visit(ctx.aliasClause());
SubqueryTableSegment result = new SubqueryTableSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), subquery);
result.setAlias(alias);
return result;
}

private ASTNode getJoinTableSegment(final TableReferenceContext ctx) {
JoinTableSegment result = new JoinTableSegment();
result.setLeft((TableSegment) visit(ctx.tableReference()));
int startIndex = null == ctx.LP_() ? ctx.tableReference().start.getStartIndex() : ctx.LP_().getSymbol().getStartIndex();
Expand All @@ -1221,6 +1239,19 @@ public ASTNode visitTableReference(final TableReferenceContext ctx) {
return result;
}

private FunctionTableSegment getFunctionTableSegment(final TableReferenceContext ctx) {
FunctionSegment functionSegment = (FunctionSegment) visit(ctx.functionTable().functionExprWindowless().funcApplication());
return new FunctionTableSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), functionSegment);
}

@Override
public ASTNode visitFuncApplication(final FuncApplicationContext ctx) {
Collection<ExpressionSegment> expressionSegments = getExpressionSegments(getTargetRuleContextFromParseTree(ctx, AExprContext.class));
FunctionSegment result = new FunctionSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), ctx.funcName().getText(), getOriginalText(ctx));
result.getParameters().addAll(expressionSegments);
return result;
}

private JoinTableSegment visitJoinedTable(final JoinedTableContext ctx, final JoinTableSegment tableSegment) {
TableSegment right = (TableSegment) visit(ctx.tableReference());
tableSegment.setRight(right);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,15 @@ exec
;

update
: withClause? UPDATE top? tableReferences withTableHint? setAssignmentsClause outputClause? whereClause? optionHint?
: withClause? UPDATE top? tableReferences withTableHint? setAssignmentsClause fromClause? outputClause? whereClause? optionHint?
;

assignment
: columnName ((PLUS_ | MINUS_ | ASTERISK_ | SLASH_ | MOD_)? EQ_ | DOT_) assignmentValue
;

setAssignmentsClause
: SET assignment (COMMA_ assignment)* fromClause?
: SET assignment (COMMA_ assignment)*
;

assignmentValues
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1335,6 +1335,9 @@ public ASTNode visitUpdate(final UpdateContext ctx) {
}
result.setTable((TableSegment) visit(ctx.tableReferences()));
result.setSetAssignment((SetAssignmentSegment) visit(ctx.setAssignmentsClause()));
if (null != ctx.fromClause()) {
result.setFrom((TableSegment) visit(ctx.fromClause()));
}
if (null != ctx.withTableHint()) {
result.setWithTableHintSegment((WithTableHintSegment) visit(ctx.withTableHint()));
}
Expand Down Expand Up @@ -1362,11 +1365,7 @@ public ASTNode visitSetAssignmentsClause(final SetAssignmentsClauseContext ctx)
for (AssignmentContext each : ctx.assignment()) {
assignments.add((ColumnAssignmentSegment) visit(each));
}
SetAssignmentSegment result = new SetAssignmentSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), assignments);
if (null != ctx.fromClause()) {
result.setFrom((TableSegment) visit(ctx.fromClause()));
}
return result;
return new SetAssignmentSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), assignments);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.apache.shardingsphere.sql.parser.statement.core.segment.SQLSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableSegment;

import java.util.Collection;

Expand All @@ -37,7 +35,4 @@ public final class SetAssignmentSegment implements SQLSegment {
private final int stopIndex;

private final Collection<ColumnAssignmentSegment> assignments;

@Setter
private TableSegment from;
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,4 +155,21 @@ public Optional<OutputSegment> getOutputSegment() {
*/
public void setOutputSegment(final OutputSegment outputSegment) {
}

/**
* Get from segment.
*
* @return from segment
*/
public Optional<TableSegment> getFrom() {
return Optional.empty();
}

/**
* Set from segment.
*
* @param from from segment
*/
public void setFrom(final TableSegment from) {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,23 @@

package org.apache.shardingsphere.sql.parser.statement.opengauss.dml;

import lombok.Setter;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.UpdateStatement;
import org.apache.shardingsphere.sql.parser.statement.opengauss.OpenGaussStatement;

import java.util.Optional;

/**
* OpenGauss update statement.
*/
@Setter
public final class OpenGaussUpdateStatement extends UpdateStatement implements OpenGaussStatement {

private TableSegment from;

@Override
public Optional<TableSegment> getFrom() {
return Optional.ofNullable(from);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,23 @@

package org.apache.shardingsphere.sql.parser.statement.postgresql.dml;

import lombok.Setter;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.UpdateStatement;
import org.apache.shardingsphere.sql.parser.statement.postgresql.PostgreSQLStatement;

import java.util.Optional;

/**
* PostgreSQL update statement.
*/
@Setter
public final class PostgreSQLUpdateStatement extends UpdateStatement implements PostgreSQLStatement {

private TableSegment from;

@Override
public Optional<TableSegment> getFrom() {
return Optional.ofNullable(from);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
package org.apache.shardingsphere.sql.parser.statement.sqlserver.dml;

import lombok.Setter;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.hint.OptionHintSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.hint.WithTableHintSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.OutputSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.WithSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.UpdateStatement;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.hint.OptionHintSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.hint.WithTableHintSegment;
import org.apache.shardingsphere.sql.parser.statement.sqlserver.SQLServerStatement;

import java.util.Optional;
Expand All @@ -41,6 +42,8 @@ public final class SQLServerUpdateStatement extends UpdateStatement implements S

private OutputSegment outputSegment;

private TableSegment from;

@Override
public Optional<WithSegment> getWithSegment() {
return Optional.ofNullable(withSegment);
Expand All @@ -64,4 +67,9 @@ public Optional<OptionHintSegment> getOptionHintSegment() {
public Optional<OutputSegment> getOutputSegment() {
return Optional.ofNullable(outputSegment);
}

@Override
public Optional<TableSegment> getFrom() {
return Optional.ofNullable(from);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public final class UpdateStatementAssert {
public static void assertIs(final SQLCaseAssertContext assertContext, final UpdateStatement actual, final UpdateStatementTestCase expected) {
assertTable(assertContext, actual, expected);
assertSetClause(assertContext, actual, expected);
assertFromClause(assertContext, actual, expected);
assertWhereClause(assertContext, actual, expected);
assertOrderByClause(assertContext, actual, expected);
assertLimitClause(assertContext, actual, expected);
Expand All @@ -77,6 +78,15 @@ private static void assertSetClause(final SQLCaseAssertContext assertContext, fi
SetClauseAssert.assertIs(assertContext, actual.getSetAssignment(), expected.getSetClause());
}

private static void assertFromClause(final SQLCaseAssertContext assertContext, final UpdateStatement actual, final UpdateStatementTestCase expected) {
if (null == expected.getFrom()) {
assertFalse(actual.getFrom().isPresent(), assertContext.getText("Actual from segment should not exist."));
} else {
assertTrue(actual.getFrom().isPresent(), assertContext.getText("Actual from segment should exist."));
TableAssert.assertIs(assertContext, actual.getFrom().get(), expected.getFrom());
}
}

private static void assertWhereClause(final SQLCaseAssertContext assertContext, final UpdateStatement actual, final UpdateStatementTestCase expected) {
if (null == expected.getWhereClause()) {
assertFalse(actual.getWhere().isPresent(), assertContext.getText("Actual where segment should not exist."));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@

import lombok.Getter;
import lombok.Setter;
import org.apache.shardingsphere.test.it.sql.parser.internal.cases.parser.jaxb.SQLParserTestCase;
import org.apache.shardingsphere.test.it.sql.parser.internal.cases.parser.jaxb.segment.impl.hint.ExpectedOptionHint;
import org.apache.shardingsphere.test.it.sql.parser.internal.cases.parser.jaxb.segment.impl.limit.ExpectedLimitClause;
import org.apache.shardingsphere.test.it.sql.parser.internal.cases.parser.jaxb.segment.impl.orderby.ExpectedOrderByClause;
import org.apache.shardingsphere.test.it.sql.parser.internal.cases.parser.jaxb.segment.impl.output.ExpectedOutputClause;
import org.apache.shardingsphere.test.it.sql.parser.internal.cases.parser.jaxb.segment.impl.set.ExpectedSetClause;
import org.apache.shardingsphere.test.it.sql.parser.internal.cases.parser.jaxb.segment.impl.table.ExpectedTable;
import org.apache.shardingsphere.test.it.sql.parser.internal.cases.parser.jaxb.segment.impl.where.ExpectedWhereClause;
import org.apache.shardingsphere.test.it.sql.parser.internal.cases.parser.jaxb.SQLParserTestCase;

import javax.xml.bind.annotation.XmlElement;

Expand All @@ -43,6 +43,9 @@ public final class UpdateStatementTestCase extends SQLParserTestCase {
@XmlElement(name = "set")
private ExpectedSetClause setClause;

@XmlElement
private ExpectedTable from;

@XmlElement(name = "where")
private ExpectedWhereClause whereClause;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2641,7 +2641,14 @@
<shorthand-projection start-index="7" stop-index="7" />
</projections>
<from>
<simple-table name="not support" start-index="14" stop-index="70" />
<function-table>
<table-function function-name="cypher" text="cypher('sharding_test_1', $$ CREATE (n) $$)">
<parameter>
<literal-expression value="sharding_test_1" start-index="21" stop-index="37" />
<common-expression literal-text="$$ CREATE (n) $$" start-index="40" stop-index="55"/>
</parameter>
</table-function>
</function-table>
</from>
</select>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4130,4 +4130,20 @@
</expression-projection>
</projections>
</select>

<select sql-case-id="select_from_table_function">
<projections start-index="7" stop-index="7">
<shorthand-projection start-index="7" stop-index="7"/>
</projections>
<from>
<function-table>
<table-function function-name="GENERATE_SERIES" text="GENERATE_SERIES(1, name)">
<parameter>
<literal-expression value="1" start-index="30" stop-index="30" />
<column name="name" start-index="33" stop-index="36" />
</parameter>
</table-function>
</function-table>
</from>
</select>
</sql-parser-test-cases>
Loading

0 comments on commit bdeccfe

Please sign in to comment.