Skip to content

Commit

Permalink
[feature](nereids)use mtmv to match legacy mv (#33699)
Browse files Browse the repository at this point in the history
  • Loading branch information
starocean999 authored Jul 1, 2024
1 parent 6f43278 commit 707cb40
Show file tree
Hide file tree
Showing 187 changed files with 6,141 additions and 265 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2369,7 +2369,7 @@ public static Boolean getResultIsNullable(String name, List<Type> typeList, List
FunctionName fnName = new FunctionName(name);
Function searchDesc = new Function(fnName, typeList, Type.INVALID, false, true);
List<Expr> mockedExprs = getMockedExprs(typeList, nullableList);
Function f = Env.getCurrentEnv().getFunction(searchDesc, Function.CompareMode.IS_IDENTICAL);
Function f = Env.getCurrentEnv().getFunction(searchDesc, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
return isNullable(f, mockedExprs);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ public class AggregateFunction extends Function {
"approx_count_distinct", "ndv", FunctionSet.BITMAP_UNION_INT, FunctionSet.BITMAP_UNION_COUNT,
"ndv_no_finalize", FunctionSet.WINDOW_FUNNEL, FunctionSet.RETENTION, FunctionSet.SEQUENCE_MATCH,
FunctionSet.SEQUENCE_COUNT, FunctionSet.MAP_AGG, FunctionSet.BITMAP_AGG, FunctionSet.ARRAY_AGG,
FunctionSet.COLLECT_LIST, FunctionSet.COLLECT_SET, FunctionSet.GROUP_ARRAY_INTERSECT);
FunctionSet.COLLECT_LIST, FunctionSet.COLLECT_SET, FunctionSet.GROUP_ARRAY_INTERSECT,
FunctionSet.SUM0, FunctionSet.MULTI_DISTINCT_SUM0);

public static ImmutableSet<String> ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET =
ImmutableSet.of("stddev_samp", "variance_samp", "var_samp", "percentile_approx", "first_value",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,10 @@ public void addBuiltinBothScalaAndVectorized(Function fn) {

public static final String ARRAY_AGG = "array_agg";

public static final String SUM0 = "sum0";

public static final String MULTI_DISTINCT_SUM0 = "multi_distinct_sum0";

// Populate all the aggregate builtins in the catalog.
// null symbols indicate the function does not need that step of the evaluation.
// An empty symbol indicates a TODO for the BE to implement the function.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,56 @@ public class UnboundFunction extends Function implements Unbound, PropagateNulla
private final String dbName;
private final boolean isDistinct;

// the start and end position of the function string in original sql
private final Optional<FunctionIndexInSql> functionIndexInSql;

/**
* FunctionIndexInSql
*/
public static class FunctionIndexInSql implements Comparable<FunctionIndexInSql> {
public final int functionNameBegin;
public final int functionNameEnd;
public final int functionExpressionEnd;

public FunctionIndexInSql(int nameBegin, int nameEnd, int expressionEnd) {
this.functionNameBegin = nameBegin;
this.functionNameEnd = nameEnd;
this.functionExpressionEnd = expressionEnd;
}

@Override
public int compareTo(FunctionIndexInSql functionIndexInSql) {
return this.functionNameBegin - functionIndexInSql.functionNameBegin;
}

public FunctionIndexInSql indexInQueryPart(int offset) {
return new FunctionIndexInSql(functionNameBegin - offset, functionNameEnd - offset,
functionExpressionEnd - offset);
}
}

public UnboundFunction(String name, List<Expression> arguments) {
this(null, name, false, arguments);
this(null, name, false, arguments, Optional.empty());
}

public UnboundFunction(String dbName, String name, List<Expression> arguments) {
this(dbName, name, false, arguments);
this(dbName, name, false, arguments, Optional.empty());
}

public UnboundFunction(String name, boolean isDistinct, List<Expression> arguments) {
this(null, name, isDistinct, arguments);
this(null, name, isDistinct, arguments, Optional.empty());
}

public UnboundFunction(String dbName, String name, boolean isDistinct, List<Expression> arguments) {
this(dbName, name, isDistinct, arguments, Optional.empty());
}

public UnboundFunction(String dbName, String name, boolean isDistinct,
List<Expression> arguments, Optional<FunctionIndexInSql> functionIndexInSql) {
super(name, arguments);
this.dbName = dbName;
this.isDistinct = isDistinct;
this.functionIndexInSql = functionIndexInSql;
}

@Override
Expand Down Expand Up @@ -100,6 +134,14 @@ public UnboundFunction withChildren(List<Expression> children) {
return new UnboundFunction(dbName, getName(), isDistinct, children);
}

public Optional<FunctionIndexInSql> getFunctionIndexInSql() {
return functionIndexInSql;
}

public UnboundFunction withIndexInSqlString(Optional<FunctionIndexInSql> functionIndexInSql) {
return new UnboundFunction(dbName, getName(), isDistinct, children, functionIndexInSql);
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.parser;

import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.DorisParser;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.BuiltinFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.commands.CreateMTMVCommand;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;

/**
* LogicalPlanBuilderForSyncMv
*/
public class LogicalPlanBuilderForSyncMv extends LogicalPlanBuilder {
private Optional<String> querySql;

public LogicalPlanBuilderForSyncMv() {
super(false);
}

@Override
public Expression visitFunctionCallExpression(DorisParser.FunctionCallExpressionContext ctx) {
Expression expression = super.visitFunctionCallExpression(ctx);
if (expression instanceof UnboundFunction) {
return ((UnboundFunction) expression)
.withIndexInSqlString(Optional.of(new UnboundFunction.FunctionIndexInSql(
ctx.functionIdentifier().functionNameIdentifier().start.getStartIndex(),
ctx.functionIdentifier().functionNameIdentifier().stop.getStopIndex(),
ctx.stop.getStopIndex())));
} else {
return expression;
}
}

@Override
public LogicalPlan visitQuery(DorisParser.QueryContext ctx) {
LogicalPlan logicalPlan = super.visitQuery(ctx);
PlanUtils.OutermostPlanFinderContext outermostPlanFinderContext =
new PlanUtils.OutermostPlanFinderContext();
logicalPlan.accept(PlanUtils.OutermostPlanFinder.INSTANCE, outermostPlanFinderContext);

// find outermost logicalAggregate to rewrite agg_state related function
Plan outermostAgg = outermostPlanFinderContext.outermostPlan;
while (!(outermostAgg instanceof LogicalAggregate)) {
if (!outermostAgg.children().isEmpty()) {
outermostAgg = outermostAgg.child(0);
} else {
break;
}
}
String originSql = getOriginSql(ctx);
if (outermostAgg instanceof LogicalAggregate) {
List<NamedExpression> outputs = ((LogicalAggregate) outermostAgg).getOutputs();
TreeMap<Pair<Integer, Integer>, String> indexInSqlToString =
new TreeMap<>(new Pair.PairComparator<>());
AggStateFunctionFinder aggStateFunctionFinder =
new AggStateFunctionFinder(ctx.start.getStartIndex());
for (Expression expr : outputs) {
aggStateFunctionFinder.find(expr, indexInSqlToString);
}
querySql = Optional.of(rewriteSql(originSql, indexInSqlToString));
} else {
querySql = Optional.of(originSql);
}
return logicalPlan;
}

@Override
public CreateMTMVCommand visitCreateMTMV(DorisParser.CreateMTMVContext ctx) {
visitQuery(ctx.query());
return null;
}

public Optional<String> getQuerySql() {
return querySql;
}

private static class AggStateFunctionFinder
extends DefaultExpressionRewriter<TreeMap<Pair<Integer, Integer>, String>> {
private int sqlBeginIndex;

private FunctionRegistry functionRegistry;

public AggStateFunctionFinder(int sqlBeginIndex) {
this.sqlBeginIndex = sqlBeginIndex;
this.functionRegistry = Env.getCurrentEnv().getFunctionRegistry();
}

public Expression find(Expression expression,
TreeMap<Pair<Integer, Integer>, String> indexInSqlToNewString) {
return expression.accept(this, indexInSqlToNewString);
}

@Override
public Expression visitUnboundFunction(UnboundFunction unboundFunction,
TreeMap<Pair<Integer, Integer>, String> indexInSqlToNewString) {
if (unboundFunction.getFunctionIndexInSql().isPresent()) {
// try bind agg function
List<Object> arguments = unboundFunction.isDistinct()
? ImmutableList.builder().add(unboundFunction.isDistinct())
.addAll(unboundFunction.getArguments()).build()
: (List) unboundFunction.getArguments();

String functionName = unboundFunction.getName();
FunctionBuilder builder = functionRegistry
.findFunctionBuilder(unboundFunction.getDbName(), functionName, arguments);
if (builder instanceof BuiltinFunctionBuilder) {
BoundFunction boundFunction =
(BoundFunction) builder.build(functionName, arguments).first;
if (boundFunction instanceof AggregateFunction) {
// rewrite to agg_state
UnboundFunction.FunctionIndexInSql functionIndexInSql = unboundFunction
.getFunctionIndexInSql().get().indexInQueryPart(sqlBeginIndex);
functionName = boundFunction.getName();
switch (functionName) {
case "min":
case "max":
case "sum":
case "count":
case "bitmap_union":
case "hll_union": {
// no need rewrite
break;
}
default: {
indexInSqlToNewString.put(
Pair.of(functionIndexInSql.functionNameBegin,
functionIndexInSql.functionNameEnd),
String.format("%s%s(%s%s", functionName,
AggCombinerFunctionBuilder.UNION_SUFFIX,
functionName,
AggCombinerFunctionBuilder.STATE_SUFFIX));
indexInSqlToNewString
.put(Pair.of(functionIndexInSql.functionExpressionEnd,
functionIndexInSql.functionExpressionEnd), "))");
break;
}
}
}
}
}
return unboundFunction;
}
}

private static String rewriteSql(String querySql,
Map<Pair<Integer, Integer>, String> indexStringSqlMap) {
StringBuilder builder = new StringBuilder();
int beg = 0;
for (Map.Entry<Pair<Integer, Integer>, String> entry : indexStringSqlMap.entrySet()) {
Pair<Integer, Integer> index = entry.getKey();
builder.append(querySql, beg, index.first);
builder.append(entry.getValue());
beg = index.second + 1;
}
builder.append(querySql, beg, querySql.length());
return builder.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,13 @@ private <T> T parseForCreateViewInternal(String sql, @Nullable LogicalPlanBuilde
return (T) realLogicalPlanBuilder.visit(tree);
}

public Optional<String> parseForSyncMv(String sql) {
ParserRuleContext tree = toAst(sql, DorisParser::singleStatement);
LogicalPlanBuilderForSyncMv logicalPlanBuilderForSyncMv = new LogicalPlanBuilderForSyncMv();
logicalPlanBuilderForSyncMv.visit(tree);
return logicalPlanBuilderForSyncMv.getQuerySql();
}

/** toAst */
public static ParserRuleContext toAst(String sql, Function<DorisParser, ParserRuleContext> parseFunction) {
DorisLexer lexer = new DorisLexer(new CaseInsensitiveStream(CharStreams.fromString(sql)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewFilterProjectScanRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewFilterScanRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewOnlyJoinRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewOnlyScanRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectAggregateRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectFilterAggregateRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectFilterJoinRule;
Expand Down Expand Up @@ -240,6 +241,7 @@ public class RuleSet {
.add(MaterializedViewProjectScanRule.INSTANCE)
.add(MaterializedViewProjectFilterScanRule.INSTANCE)
.add(MaterializedViewAggregateOnNoneAggregateRule.INSTANCE)
.add(MaterializedViewOnlyScanRule.INSTANCE)
.build();

public static final List<Rule> DPHYP_REORDER_RULES = ImmutableList.<Rule>builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ public enum RuleType {
MATERIALIZED_VIEW_PROJECT_SCAN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_FILTER_PROJECT_SCAN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_PROJECT_FILTER_SCAN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_ONLY_SCAN(RuleTypeClass.EXPLORATION),

// implementation rules
LOGICAL_ONE_ROW_RELATION_TO_PHYSICAL_ONE_ROW_RELATION(RuleTypeClass.IMPLEMENTATION),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,8 @@ private boolean isGroupByEquals(Pair<Plan, LogicalAggregate<Plan>> queryTopPlanA
viewGroupShuttledExpressionQueryBased.add(
groupByExpressionToViewShuttledExpressionQueryBasedMap.get(viewExpression));
}
return queryGroupShuttledExpression.equals(viewGroupShuttledExpressionQueryBased);
return materializationContext instanceof SyncMaterializationContext ? false
: queryGroupShuttledExpression.equals(viewGroupShuttledExpressionQueryBased);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,11 @@ protected List<Plan> doRewrite(StructInfo queryStructInfo, CascadesContext casca
}
rewrittenPlan = new LogicalFilter<>(Sets.newLinkedHashSet(rewriteCompensatePredicates), mvScan);
}
boolean checkResult = rewriteQueryByViewPreCheck(matchMode, queryStructInfo,
viewStructInfo, viewToQuerySlotMapping, rewrittenPlan, materializationContext);
if (!checkResult) {
continue;
}
// Rewrite query by view
rewrittenPlan = rewriteQueryByView(matchMode, queryStructInfo, viewStructInfo, viewToQuerySlotMapping,
rewrittenPlan, materializationContext, cascadesContext);
Expand Down Expand Up @@ -481,6 +486,23 @@ protected Pair<Map<BaseTableInfo, Set<String>>, Map<BaseTableInfo, Set<String>>>
return Pair.of(mvPartitionNeedRemoveNameMap, baseTablePartitionNeedUnionNameMap);
}

/**
* Query rewrite result may output origin plan , this will cause loop.
* if return origin plan, need add check hear.
*/
protected boolean rewriteQueryByViewPreCheck(MatchMode matchMode, StructInfo queryStructInfo,
StructInfo viewStructInfo, SlotMapping viewToQuerySlotMapping, Plan tempRewritedPlan,
MaterializationContext materializationContext) {
if (materializationContext instanceof SyncMaterializationContext
&& queryStructInfo.getBottomPlan() instanceof LogicalOlapScan) {
LogicalOlapScan olapScan = (LogicalOlapScan) queryStructInfo.getBottomPlan();
if (olapScan.getSelectedIndexId() != olapScan.getTable().getBaseIndexId()) {
return false;
}
}
return true;
}

/**
* Rewrite query by view, for aggregate or join rewriting should be different inherit class implementation
*/
Expand Down
Loading

0 comments on commit 707cb40

Please sign in to comment.