Skip to content

Commit

Permalink
Enhance the used columns analyze and required datasets for dynamic qu…
Browse files Browse the repository at this point in the history
…ery mode (#556)

* check if relationable is added before generated

* enhance analysis of window, group by, order by and aliased relation

* fix the relationable generated
  • Loading branch information
goldmedal authored May 21, 2024
1 parent 3f752ef commit c530ccb
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 32 deletions.
12 changes: 10 additions & 2 deletions wren-base/src/main/java/io/wren/base/WrenMDL.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.wren.base.dto.Manifest;
import io.wren.base.dto.Metric;
import io.wren.base.dto.Model;
import io.wren.base.dto.Relationable;
import io.wren.base.dto.Relationship;
import io.wren.base.dto.View;
import io.wren.base.jinjava.JinjavaExpressionProcessor;
Expand Down Expand Up @@ -255,13 +256,13 @@ public List<View> listViews()
return manifest.getViews();
}

public static Optional<io.wren.base.dto.Column> getRelationshipColumn(Model model, String name)
public static Optional<Column> getRelationshipColumn(Model model, String name)
{
return getColumn(model, name)
.filter(column -> column.getRelationship().isPresent());
}

private static Optional<io.wren.base.dto.Column> getColumn(Model model, String name)
private static Optional<Column> getColumn(Model model, String name)
{
requireNonNull(model);
requireNonNull(name);
Expand Down Expand Up @@ -317,4 +318,11 @@ public boolean isObjectExist(String name)
|| getCumulativeMetric(name).isPresent()
|| getView(name).isPresent();
}

/**
 * Looks up a {@link Relationable} dataset by name.
 *
 * <p>A model with the given name takes precedence; a metric with that name is
 * consulted only when no such model exists.
 *
 * @param name the dataset name to look up
 * @return the matching model or metric, or an empty Optional when neither exists
 */
public Optional<Relationable> getRelationable(String name)
{
    Optional<Relationable> fromModel = getModel(name).map(Relationable.class::cast);
    if (fromModel.isPresent()) {
        return fromModel;
    }
    // No model by that name: fall back to the metric lookup.
    return getMetric(name).map(Relationable.class::cast);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
import io.wren.base.SessionContext;
import io.wren.base.Utils;
import io.wren.base.WrenMDL;
import io.wren.base.dto.Column;
import io.wren.base.dto.CumulativeMetric;
import io.wren.base.dto.Metric;
import io.wren.base.dto.Model;
import io.wren.base.dto.Relationable;
import io.wren.base.sqlrewrite.analyzer.Analysis;
import io.wren.base.sqlrewrite.analyzer.StatementAnalyzer;
import org.jgrapht.graph.DirectedAcyclicGraph;
Expand All @@ -47,7 +49,9 @@

import static com.google.common.base.Strings.nullToEmpty;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.wren.base.sqlrewrite.Utils.toCatalogSchemaTableName;
import static java.lang.String.format;
import static java.util.stream.Collectors.toMap;
import static java.util.stream.Collectors.toSet;

Expand Down Expand Up @@ -92,19 +96,23 @@ public Statement apply(Statement root, SessionContext sessionContext, Analysis a
.filter(e -> wrenMDL.getView(e.getKey()).isEmpty())
.collect(toMap(Map.Entry::getKey, Map.Entry::getValue, (a, b) -> a, LinkedHashMap::new));

// Some node be applied `count(*)` which won't be collected but its source is required.
analysis.getRequiredSourceNodes().forEach(node -> {
String tableName = analysis.getSourceNodeNames(node).map(QualifiedName::toString)
.orElseThrow(() -> new IllegalArgumentException(format("source node name not found: %s", node)));
if (!tableRequiredFields.containsKey(tableName)) {
Relationable relationable = wrenMDL.getRelationable(tableName)
.orElseThrow(() -> new IllegalArgumentException(format("dataset not found: %s", tableName)));
tableRequiredFields.put(tableName, relationable.getColumns().stream().map(Column::getName).collect(toImmutableSet()));
}
});

ImmutableList.Builder<QueryDescriptor> descriptorsBuilder = ImmutableList.builder();
tableRequiredFields.forEach((name, value) -> {
addDescriptor(name, value, wrenMDL, descriptorsBuilder);
visitedTables.remove(toCatalogSchemaTableName(sessionContext, QualifiedName.of(name)));
});

// Some node be applied `count(*)` which won't be collected but its source is required.
analysis.getRequiredSourceNodes().forEach(node ->
analysis.getSourceNodeNames(node).map(QualifiedName::toString).ifPresent(name -> {
addDescriptor(name, wrenMDL, descriptorsBuilder);
visitedTables.remove(toCatalogSchemaTableName(sessionContext, QualifiedName.of(name)));
}));

List<WithQuery> withQueries = new ArrayList<>();
// add date spine if needed
if (tableRequiredFields.keySet().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@
import io.trino.sql.tree.DefaultTraversalVisitor;
import io.trino.sql.tree.DereferenceExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FrameBound;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.Identifier;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.SubqueryExpression;
import io.trino.sql.tree.Window;
import io.trino.sql.tree.WindowOperation;
import io.trino.sql.tree.WindowSpecification;
import io.wren.base.SessionContext;
import io.wren.base.WrenMDL;

Expand Down Expand Up @@ -111,6 +115,9 @@ protected Void visitFunctionCall(FunctionCall node, Void context)
return null;
}
node.getArguments().forEach(this::process);
node.getWindow().ifPresent(this::analyzeWindow);
node.getFilter().ifPresent(this::process);
node.getOrderBy().ifPresent(orderBy -> orderBy.getSortItems().forEach(sortItem -> process(sortItem.getSortKey())));
return null;
}

Expand All @@ -121,6 +128,25 @@ protected Void visitSubqueryExpression(SubqueryExpression node, Void context)
return null;
}

/**
 * Visits a window operation by analyzing the window attached to it.
 *
 * @param node the window operation being visited
 * @param context unused
 * @return always {@code null}
 */
@Override
protected Void visitWindowOperation(WindowOperation node, Void context)
{
    Window operationWindow = node.getWindow();
    analyzeWindow(operationWindow);
    return null;
}

/**
 * Processes every expression embedded in a window definition: the partition
 * keys, the ORDER BY sort keys, and the start/end values of the frame bounds.
 *
 * <p>Only {@link WindowSpecification} windows are inspected; any other kind of
 * {@link Window} is ignored.
 *
 * @param window the window attached to a function call or window operation
 */
private void analyzeWindow(Window window)
{
    if (!(window instanceof WindowSpecification specification)) {
        return;
    }

    for (Expression partitionKey : specification.getPartitionBy()) {
        process(partitionKey);
    }

    specification.getOrderBy().ifPresent(orderBy ->
            orderBy.getSortItems().stream()
                    .map(sortItem -> sortItem.getSortKey())
                    .forEach(this::process));

    specification.getFrame().ifPresent(frame -> {
        // The start bound always exists; the end bound is optional. Either may lack a value.
        frame.getStart().getValue().ifPresent(startValue -> process(startValue));
        frame.getEnd().ifPresent(endBound -> endBound.getValue().ifPresent(endValue -> process(endValue)));
    });
}

public Map<NodeRef<Expression>, Field> getReferenceFields()
{
return referenceFields;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.trino.sql.tree.AstVisitor;
import io.trino.sql.tree.DereferenceExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FrameBound;
import io.trino.sql.tree.FunctionRelation;
import io.trino.sql.tree.Identifier;
import io.trino.sql.tree.Join;
Expand All @@ -41,6 +42,7 @@
import io.trino.sql.tree.TableSubquery;
import io.trino.sql.tree.Unnest;
import io.trino.sql.tree.Values;
import io.trino.sql.tree.WindowSpecification;
import io.trino.sql.tree.With;
import io.trino.sql.tree.WithQuery;
import io.wren.base.CatalogSchemaTableName;
Expand All @@ -52,13 +54,10 @@
import io.wren.base.dto.Model;
import io.wren.base.dto.TimeUnit;
import io.wren.base.dto.View;
import io.wren.base.sqlrewrite.analyzer.decisionpoint.QueryAnalysis;
import io.wren.base.sqlrewrite.analyzer.matcher.PredicateMatcher;

import javax.annotation.Nullable;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -315,27 +314,24 @@ protected Scope visitQuery(Query node, Optional<Scope> scope)
@Override
protected Scope visitQuerySpecification(QuerySpecification node, Optional<Scope> scope)
{
QueryAnalysis.Builder queryAnalysisBuilder = QueryAnalysis.builder();
Scope sourceScope = analyzeFrom(node, scope);
List<Expression> outputExpressions = analyzeSelect(node, sourceScope);
analyzeSelect(node, sourceScope);
node.getWhere().ifPresent(where -> analyzeWhere(where, sourceScope));
node.getHaving().ifPresent(having -> analyzeExpression(having, sourceScope));
node.getWindows().forEach(window -> analyzeWindowSpecification(window.getWindow(), sourceScope));
node.getGroupBy().ifPresent(groupBy -> groupBy.getGroupingElements().forEach(groupingElement -> {
groupingElement.getExpressions().forEach(expression -> analyzeExpression(expression, sourceScope));
}));
node.getOrderBy().ifPresent(orderBy -> orderBy.getSortItems()
.forEach(item -> {
QualifiedName name;
if (item.getSortKey() instanceof LongLiteral) {
long index = ((LongLiteral) item.getSortKey()).getValue() - 1;
name = getQualifiedName(outputExpressions.get((int) index));
}
else {
name = getQualifiedName(item.getSortKey());
if (!(item.getSortKey() instanceof LongLiteral)) {
analyzeExpression(item.getSortKey(), sourceScope);
}
}));
// TODO: this scope is wrong.
return createAndAssignScope(node, scope, sourceScope);
}

private List<Expression> analyzeSelect(QuerySpecification node, Scope scope)
private void analyzeSelect(QuerySpecification node, Scope scope)
{
ImmutableList.Builder<Expression> outputExpressions = ImmutableList.builder();
for (SelectItem item : node.getSelect().getSelectItems()) {
Expand All @@ -349,7 +345,7 @@ else if (item instanceof SingleColumn) {
throw new IllegalArgumentException("Unsupported SelectItem type: " + item.getClass().getName());
}
}
return outputExpressions.build();
outputExpressions.build();
}

private void analyzeSelectAllColumns(AllColumns allColumns, Scope scope, ImmutableList.Builder<Expression> outputExpressions)
Expand Down Expand Up @@ -394,16 +390,21 @@ private Scope analyzeFrom(QuerySpecification node, Optional<Scope> scope)
private void analyzeWhere(Expression node, Scope scope)
{
ExpressionAnalysis expressionAnalysis = analyzeExpression(node, scope);
Map<NodeRef<Expression>, Field> fields = expressionAnalysis.getReferencedFields();
expressionAnalysis.getPredicates().stream()
.filter(PredicateMatcher.PREDICATE_MATCHER::shapeMatches)
.forEach(comparisonExpression -> {
Expression expression = comparisonExpression.getLeft();
});
typeCoercionOptional.flatMap(typeCoercion -> typeCoercion.coerceExpression(node, scope))
.ifPresent(expression -> analysis.addTypeCoercion(NodeRef.of(node), expression));
}

/**
 * Runs expression analysis, against the given scope, over every expression of
 * a window specification: the existing window name (if any), the partition
 * keys, the ORDER BY sort keys, and the frame bound values.
 *
 * @param windowSpecification the window specification to analyze
 * @param scope the scope the expressions are analyzed against
 */
private void analyzeWindowSpecification(WindowSpecification windowSpecification, Scope scope)
{
    // NOTE(review): the existing window name is analyzed like an ordinary expression;
    // confirm this is intended, since it presumably names a window rather than a column.
    windowSpecification.getExistingWindowName().ifPresent(name -> analyzeExpression(name, scope));

    for (Expression partitionKey : windowSpecification.getPartitionBy()) {
        analyzeExpression(partitionKey, scope);
    }

    windowSpecification.getOrderBy().ifPresent(orderBy ->
            orderBy.getSortItems().stream()
                    .map(sortItem -> sortItem.getSortKey())
                    .forEach(sortKey -> analyzeExpression(sortKey, scope)));

    windowSpecification.getFrame().ifPresent(frame -> {
        // The end bound is optional, and either bound may carry no value expression.
        frame.getStart().getValue().ifPresent(startValue -> analyzeExpression(startValue, scope));
        frame.getEnd().ifPresent(endBound ->
                endBound.getValue().ifPresent(endValue -> analyzeExpression(endValue, scope)));
    });
}

@Override
protected Scope visitValues(Values node, Optional<Scope> scope)
{
Expand Down Expand Up @@ -517,7 +518,8 @@ protected Scope visitJoin(Join node, Optional<Scope> scope)
protected Scope visitAliasedRelation(AliasedRelation relation, Optional<Scope> scope)
{
Scope relationScope = process(relation.getRelation(), scope);

relationScope.getRelationId().getSourceNode().flatMap(analysis::getSourceNodeNames)
.ifPresent(names -> analysis.addSourceNodeName(NodeRef.of(relation), names));
List<Field> fields = relationScope.getRelationType().getFields();
// if scope is a data source scope, we should get the fields from MDL
if (relationScope.isDataSourceScope()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,31 @@ public void testToOneCalculated()
"WHERE l.orderkey = 44995");

assertQuery(mdl, "SELECT count(*) FROM Lineitem", "SELECT count(*) FROM lineitem");
assertQuery(mdl, "SELECT count(*) FROM Lineitem WHERE orderkey = 44995",
"SELECT count(*) FROM lineitem WHERE orderkey = 44995");
assertQuery(mdl, "SELECT count(*) FROM Lineitem l WHERE l.orderkey = 44995",
"SELECT count(*) FROM lineitem l WHERE l.orderkey = 44995");

assertQuery(mdl, "SELECT col_1 FROM Lineitem ORDER BY col_2", "SELECT (totalprice + totalprice) AS col_1\n" +
"FROM lineitem l\n" +
"LEFT JOIN orders o ON l.orderkey = o.orderkey\n" +
"LEFT JOIN customer c ON o.custkey = c.custkey\n" +
"ORDER BY concat(l.orderkey, '#', c.custkey)");
assertQuery(mdl, "SELECT count(*) FROM Lineitem group by col_1, col_2 order by 1", "SELECT count(*)\n" +
"FROM lineitem l\n" +
"LEFT JOIN orders o ON l.orderkey = o.orderkey\n" +
"LEFT JOIN customer c ON o.custkey = c.custkey\n" +
"GROUP BY (totalprice + totalprice), concat(l.orderkey, '#', c.custkey)\n" +
"ORDER BY 1");
assertQuery(mdl, "SELECT rank() over (order by col_1) FROM Lineitem",
"SELECT rank() OVER (ORDER BY (totalprice + totalprice))\n" +
"FROM lineitem l\n" +
"LEFT JOIN orders o ON l.orderkey = o.orderkey");
assertQuery(mdl, "SELECT count(f1) FROM (SELECT lag(extendedprice) over (partition by col_2) as f1 FROM Lineitem)",
"SELECT count(f1) FROM (SELECT lag(extendedprice) OVER (PARTITION BY concat(l.orderkey, '#', c.custkey)) as f1\n" +
"FROM lineitem l\n" +
"LEFT JOIN orders o ON l.orderkey = o.orderkey\n" +
"LEFT JOIN customer c ON o.custkey = c.custkey)");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatNoException;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

public class TestMetric
extends AbstractTestFramework
Expand Down Expand Up @@ -355,6 +356,10 @@ public void testDynamicMetricOnModel()
// apply count(custkey) on metric
assertThat(query(rewrite("SELECT count(custkey) FROM CountOrders ORDER BY 1", true)))
.isEqualTo(query("WITH output AS (SELECT custkey, count(*) FROM orders GROUP BY 1) SELECT count(custkey) FROM output"));

assertThatThrownBy(() -> query(rewrite("SELECT count(custkey) FROM notfound ORDER BY 1", true)))
.rootCause()
.hasMessageMatching(".*Table with name notfound does not exist(.|\\n)*");
}

@Test
Expand Down Expand Up @@ -420,6 +425,8 @@ private String rewrite(String sql, WrenMDL wrenMDL, boolean enableDynamicField)
.setSchema("test")
.setEnableDynamic(enableDynamicField)
.build();
return WrenPlanner.rewrite(sql, sessionContext, new AnalyzedMDL(wrenMDL, null), List.of(WREN_SQL_REWRITE));
String result = WrenPlanner.rewrite(sql, sessionContext, new AnalyzedMDL(wrenMDL, null), List.of(WREN_SQL_REWRITE));
System.out.println(result);
return result;
}
}

0 comments on commit c530ccb

Please sign in to comment.