From c530ccb211c25382d44b6cea6f3cf6608df0484e Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 21 May 2024 11:27:48 +0800 Subject: [PATCH] Enhance the used columns analyze and required datasets for dynamic query mode (#556) * check if relationable is added before generated * enhance analyze window, groupby, orderby and aliased relation * fix the relationable generated --- .../src/main/java/io/wren/base/WrenMDL.java | 12 ++++- .../wren/base/sqlrewrite/WrenSqlRewrite.java | 22 ++++++--- .../analyzer/ExpressionAnalyzer.java | 26 +++++++++++ .../analyzer/StatementAnalyzer.java | 46 ++++++++++--------- .../base/sqlrewrite/AbstractTestModel.java | 25 ++++++++++ .../io/wren/base/sqlrewrite/TestMetric.java | 9 +++- 6 files changed, 108 insertions(+), 32 deletions(-) diff --git a/wren-base/src/main/java/io/wren/base/WrenMDL.java b/wren-base/src/main/java/io/wren/base/WrenMDL.java index b9f8f3bb6..a47f84ffa 100644 --- a/wren-base/src/main/java/io/wren/base/WrenMDL.java +++ b/wren-base/src/main/java/io/wren/base/WrenMDL.java @@ -26,6 +26,7 @@ import io.wren.base.dto.Manifest; import io.wren.base.dto.Metric; import io.wren.base.dto.Model; +import io.wren.base.dto.Relationable; import io.wren.base.dto.Relationship; import io.wren.base.dto.View; import io.wren.base.jinjava.JinjavaExpressionProcessor; @@ -255,13 +256,13 @@ public List listViews() return manifest.getViews(); } - public static Optional getRelationshipColumn(Model model, String name) + public static Optional getRelationshipColumn(Model model, String name) { return getColumn(model, name) .filter(column -> column.getRelationship().isPresent()); } - private static Optional getColumn(Model model, String name) + private static Optional getColumn(Model model, String name) { requireNonNull(model); requireNonNull(name); @@ -317,4 +318,11 @@ public boolean isObjectExist(String name) || getCumulativeMetric(name).isPresent() || getView(name).isPresent(); } + + public Optional getRelationable(String name) + { + return
getModel(name) + .map(model -> (Relationable) model) + .or(() -> getMetric(name).map(metric -> (Relationable) metric)); + } } diff --git a/wren-base/src/main/java/io/wren/base/sqlrewrite/WrenSqlRewrite.java b/wren-base/src/main/java/io/wren/base/sqlrewrite/WrenSqlRewrite.java index 84f02ea6a..8f6810868 100644 --- a/wren-base/src/main/java/io/wren/base/sqlrewrite/WrenSqlRewrite.java +++ b/wren-base/src/main/java/io/wren/base/sqlrewrite/WrenSqlRewrite.java @@ -28,9 +28,11 @@ import io.wren.base.SessionContext; import io.wren.base.Utils; import io.wren.base.WrenMDL; +import io.wren.base.dto.Column; import io.wren.base.dto.CumulativeMetric; import io.wren.base.dto.Metric; import io.wren.base.dto.Model; +import io.wren.base.dto.Relationable; import io.wren.base.sqlrewrite.analyzer.Analysis; import io.wren.base.sqlrewrite.analyzer.StatementAnalyzer; import org.jgrapht.graph.DirectedAcyclicGraph; @@ -47,7 +49,9 @@ import static com.google.common.base.Strings.nullToEmpty; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.wren.base.sqlrewrite.Utils.toCatalogSchemaTableName; +import static java.lang.String.format; import static java.util.stream.Collectors.toMap; import static java.util.stream.Collectors.toSet; @@ -92,19 +96,23 @@ public Statement apply(Statement root, SessionContext sessionContext, Analysis a .filter(e -> wrenMDL.getView(e.getKey()).isEmpty()) .collect(toMap(Map.Entry::getKey, Map.Entry::getValue, (a, b) -> a, LinkedHashMap::new)); + // Some node be applied `count(*)` which won't be collected but its source is required. 
+ analysis.getRequiredSourceNodes().forEach(node -> { + String tableName = analysis.getSourceNodeNames(node).map(QualifiedName::toString) + .orElseThrow(() -> new IllegalArgumentException(format("source node name not found: %s", node))); + if (!tableRequiredFields.containsKey(tableName)) { + Relationable relationable = wrenMDL.getRelationable(tableName) + .orElseThrow(() -> new IllegalArgumentException(format("dataset not found: %s", tableName))); + tableRequiredFields.put(tableName, relationable.getColumns().stream().map(Column::getName).collect(toImmutableSet())); + } + }); + ImmutableList.Builder descriptorsBuilder = ImmutableList.builder(); tableRequiredFields.forEach((name, value) -> { addDescriptor(name, value, wrenMDL, descriptorsBuilder); visitedTables.remove(toCatalogSchemaTableName(sessionContext, QualifiedName.of(name))); }); - // Some node be applied `count(*)` which won't be collected but its source is required. - analysis.getRequiredSourceNodes().forEach(node -> - analysis.getSourceNodeNames(node).map(QualifiedName::toString).ifPresent(name -> { - addDescriptor(name, wrenMDL, descriptorsBuilder); - visitedTables.remove(toCatalogSchemaTableName(sessionContext, QualifiedName.of(name))); - })); - List withQueries = new ArrayList<>(); // add date spine if needed if (tableRequiredFields.keySet().stream() diff --git a/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/ExpressionAnalyzer.java b/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/ExpressionAnalyzer.java index 7db05084d..ec7c064f9 100644 --- a/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/ExpressionAnalyzer.java +++ b/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/ExpressionAnalyzer.java @@ -19,11 +19,15 @@ import io.trino.sql.tree.DefaultTraversalVisitor; import io.trino.sql.tree.DereferenceExpression; import io.trino.sql.tree.Expression; +import io.trino.sql.tree.FrameBound; import io.trino.sql.tree.FunctionCall; import io.trino.sql.tree.Identifier; import 
io.trino.sql.tree.NodeRef; import io.trino.sql.tree.QualifiedName; import io.trino.sql.tree.SubqueryExpression; +import io.trino.sql.tree.Window; +import io.trino.sql.tree.WindowOperation; +import io.trino.sql.tree.WindowSpecification; import io.wren.base.SessionContext; import io.wren.base.WrenMDL; @@ -111,6 +115,9 @@ protected Void visitFunctionCall(FunctionCall node, Void context) return null; } node.getArguments().forEach(this::process); + node.getWindow().ifPresent(this::analyzeWindow); + node.getFilter().ifPresent(this::process); + node.getOrderBy().ifPresent(orderBy -> orderBy.getSortItems().forEach(sortItem -> process(sortItem.getSortKey()))); return null; } @@ -121,6 +128,25 @@ protected Void visitSubqueryExpression(SubqueryExpression node, Void context) return null; } + @Override + protected Void visitWindowOperation(WindowOperation node, Void context) + { + analyzeWindow(node.getWindow()); + return null; + } + + private void analyzeWindow(Window window) + { + if (window instanceof WindowSpecification windowSpecification) { + windowSpecification.getPartitionBy().forEach(this::process); + windowSpecification.getOrderBy().ifPresent(orderBy -> orderBy.getSortItems().forEach(sortItem -> process(sortItem.getSortKey()))); + windowSpecification.getFrame().ifPresent(frame -> { + frame.getStart().getValue().ifPresent(this::process); + frame.getEnd().flatMap(FrameBound::getValue).ifPresent(this::process); + }); + } + } + public Map, Field> getReferenceFields() { return referenceFields; diff --git a/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/StatementAnalyzer.java b/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/StatementAnalyzer.java index 167409200..62c971513 100644 --- a/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/StatementAnalyzer.java +++ b/wren-base/src/main/java/io/wren/base/sqlrewrite/analyzer/StatementAnalyzer.java @@ -20,6 +20,7 @@ import io.trino.sql.tree.AstVisitor; import io.trino.sql.tree.DereferenceExpression; 
import io.trino.sql.tree.Expression; +import io.trino.sql.tree.FrameBound; import io.trino.sql.tree.FunctionRelation; import io.trino.sql.tree.Identifier; import io.trino.sql.tree.Join; @@ -41,6 +42,7 @@ import io.trino.sql.tree.TableSubquery; import io.trino.sql.tree.Unnest; import io.trino.sql.tree.Values; +import io.trino.sql.tree.WindowSpecification; import io.trino.sql.tree.With; import io.trino.sql.tree.WithQuery; import io.wren.base.CatalogSchemaTableName; @@ -52,13 +54,10 @@ import io.wren.base.dto.Model; import io.wren.base.dto.TimeUnit; import io.wren.base.dto.View; -import io.wren.base.sqlrewrite.analyzer.decisionpoint.QueryAnalysis; -import io.wren.base.sqlrewrite.analyzer.matcher.PredicateMatcher; import javax.annotation.Nullable; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; @@ -315,27 +314,24 @@ protected Scope visitQuery(Query node, Optional scope) @Override protected Scope visitQuerySpecification(QuerySpecification node, Optional scope) { - QueryAnalysis.Builder queryAnalysisBuilder = QueryAnalysis.builder(); Scope sourceScope = analyzeFrom(node, scope); - List outputExpressions = analyzeSelect(node, sourceScope); + analyzeSelect(node, sourceScope); node.getWhere().ifPresent(where -> analyzeWhere(where, sourceScope)); node.getHaving().ifPresent(having -> analyzeExpression(having, sourceScope)); + node.getWindows().forEach(window -> analyzeWindowSpecification(window.getWindow(), sourceScope)); + node.getGroupBy().ifPresent(groupBy -> groupBy.getGroupingElements().forEach(groupingElement -> { + groupingElement.getExpressions().forEach(expression -> analyzeExpression(expression, sourceScope)); + })); node.getOrderBy().ifPresent(orderBy -> orderBy.getSortItems() .forEach(item -> { - QualifiedName name; - if (item.getSortKey() instanceof LongLiteral) { - long index = ((LongLiteral) item.getSortKey()).getValue() - 1; - name = 
getQualifiedName(outputExpressions.get((int) index)); - } - else { - name = getQualifiedName(item.getSortKey()); + if (!(item.getSortKey() instanceof LongLiteral)) { + analyzeExpression(item.getSortKey(), sourceScope); } })); - // TODO: this scope is wrong. return createAndAssignScope(node, scope, sourceScope); } - private List analyzeSelect(QuerySpecification node, Scope scope) + private void analyzeSelect(QuerySpecification node, Scope scope) { ImmutableList.Builder outputExpressions = ImmutableList.builder(); for (SelectItem item : node.getSelect().getSelectItems()) { @@ -349,7 +345,7 @@ else if (item instanceof SingleColumn) { throw new IllegalArgumentException("Unsupported SelectItem type: " + item.getClass().getName()); } } - return outputExpressions.build(); + outputExpressions.build(); } private void analyzeSelectAllColumns(AllColumns allColumns, Scope scope, ImmutableList.Builder outputExpressions) @@ -394,16 +390,21 @@ private Scope analyzeFrom(QuerySpecification node, Optional scope) private void analyzeWhere(Expression node, Scope scope) { ExpressionAnalysis expressionAnalysis = analyzeExpression(node, scope); - Map, Field> fields = expressionAnalysis.getReferencedFields(); - expressionAnalysis.getPredicates().stream() - .filter(PredicateMatcher.PREDICATE_MATCHER::shapeMatches) - .forEach(comparisonExpression -> { - Expression expression = comparisonExpression.getLeft(); - }); typeCoercionOptional.flatMap(typeCoercion -> typeCoercion.coerceExpression(node, scope)) .ifPresent(expression -> analysis.addTypeCoercion(NodeRef.of(node), expression)); } + private void analyzeWindowSpecification(WindowSpecification windowSpecification, Scope scope) + { + windowSpecification.getExistingWindowName().ifPresent(name -> analyzeExpression(name, scope)); + windowSpecification.getPartitionBy().forEach(expression -> analyzeExpression(expression, scope)); + windowSpecification.getOrderBy().ifPresent(orderBy -> orderBy.getSortItems().forEach(item -> 
analyzeExpression(item.getSortKey(), scope))); + windowSpecification.getFrame().ifPresent(frame -> { + frame.getStart().getValue().ifPresent(start -> analyzeExpression(start, scope)); + frame.getEnd().flatMap(FrameBound::getValue).ifPresent(end -> analyzeExpression(end, scope)); + }); + } + @Override protected Scope visitValues(Values node, Optional scope) { @@ -517,7 +518,8 @@ protected Scope visitJoin(Join node, Optional scope) protected Scope visitAliasedRelation(AliasedRelation relation, Optional scope) { Scope relationScope = process(relation.getRelation(), scope); - + relationScope.getRelationId().getSourceNode().flatMap(analysis::getSourceNodeNames) + .ifPresent(names -> analysis.addSourceNodeName(NodeRef.of(relation), names)); List fields = relationScope.getRelationType().getFields(); // if scope is a data source scope, we should get the fields from MDL if (relationScope.isDataSourceScope()) { diff --git a/wren-base/src/test/java/io/wren/base/sqlrewrite/AbstractTestModel.java b/wren-base/src/test/java/io/wren/base/sqlrewrite/AbstractTestModel.java index 57995e01d..abb064c06 100644 --- a/wren-base/src/test/java/io/wren/base/sqlrewrite/AbstractTestModel.java +++ b/wren-base/src/test/java/io/wren/base/sqlrewrite/AbstractTestModel.java @@ -188,6 +188,31 @@ public void testToOneCalculated() "WHERE l.orderkey = 44995"); assertQuery(mdl, "SELECT count(*) FROM Lineitem", "SELECT count(*) FROM lineitem"); + assertQuery(mdl, "SELECT count(*) FROM Lineitem WHERE orderkey = 44995", + "SELECT count(*) FROM lineitem WHERE orderkey = 44995"); + assertQuery(mdl, "SELECT count(*) FROM Lineitem l WHERE l.orderkey = 44995", + "SELECT count(*) FROM lineitem l WHERE l.orderkey = 44995"); + + assertQuery(mdl, "SELECT col_1 FROM Lineitem ORDER BY col_2", "SELECT (totalprice + totalprice) AS col_1\n" + + "FROM lineitem l\n" + + "LEFT JOIN orders o ON l.orderkey = o.orderkey\n" + + "LEFT JOIN customer c ON o.custkey = c.custkey\n" + + "ORDER BY concat(l.orderkey, '#', c.custkey)"); 
+ assertQuery(mdl, "SELECT count(*) FROM Lineitem group by col_1, col_2 order by 1", "SELECT count(*)\n" + + "FROM lineitem l\n" + + "LEFT JOIN orders o ON l.orderkey = o.orderkey\n" + + "LEFT JOIN customer c ON o.custkey = c.custkey\n" + + "GROUP BY (totalprice + totalprice), concat(l.orderkey, '#', c.custkey)\n" + + "ORDER BY 1"); + assertQuery(mdl, "SELECT rank() over (order by col_1) FROM Lineitem", + "SELECT rank() OVER (ORDER BY (totalprice + totalprice))\n" + + "FROM lineitem l\n" + + "LEFT JOIN orders o ON l.orderkey = o.orderkey"); + assertQuery(mdl, "SELECT count(f1) FROM (SELECT lag(extendedprice) over (partition by col_2) as f1 FROM Lineitem)", + "SELECT count(f1) FROM (SELECT lag(extendedprice) OVER (PARTITION BY concat(l.orderkey, '#', c.custkey)) as f1\n" + + "FROM lineitem l\n" + + "LEFT JOIN orders o ON l.orderkey = o.orderkey\n" + + "LEFT JOIN customer c ON o.custkey = c.custkey)"); } @Test diff --git a/wren-base/src/test/java/io/wren/base/sqlrewrite/TestMetric.java b/wren-base/src/test/java/io/wren/base/sqlrewrite/TestMetric.java index 2b9428a3d..d92977c02 100644 --- a/wren-base/src/test/java/io/wren/base/sqlrewrite/TestMetric.java +++ b/wren-base/src/test/java/io/wren/base/sqlrewrite/TestMetric.java @@ -36,6 +36,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatNoException; +import static org.assertj.core.api.Assertions.assertThatThrownBy; public class TestMetric extends AbstractTestFramework @@ -355,6 +356,10 @@ public void testDynamicMetricOnModel() // apply count(custkey) on metric assertThat(query(rewrite("SELECT count(custkey) FROM CountOrders ORDER BY 1", true))) .isEqualTo(query("WITH output AS (SELECT custkey, count(*) FROM orders GROUP BY 1) SELECT count(custkey) FROM output")); + + assertThatThrownBy(() -> query(rewrite("SELECT count(custkey) FROM notfound ORDER BY 1", true))) + .rootCause() + 
.hasMessageMatching(".*Table with name notfound does not exist(.|\\n)*"); } @Test @@ -420,6 +425,8 @@ private String rewrite(String sql, WrenMDL wrenMDL, boolean enableDynamicField) .setSchema("test") .setEnableDynamic(enableDynamicField) .build(); - return WrenPlanner.rewrite(sql, sessionContext, new AnalyzedMDL(wrenMDL, null), List.of(WREN_SQL_REWRITE)); + String result = WrenPlanner.rewrite(sql, sessionContext, new AnalyzedMDL(wrenMDL, null), List.of(WREN_SQL_REWRITE)); + System.out.println(result); + return result; } }