Skip to content

Commit

Permalink
Add builders and matchers for testing plans with PatternRecognitionNode
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiafi committed Jun 21, 2021
1 parent e07bc55 commit 378ab0d
Show file tree
Hide file tree
Showing 9 changed files with 761 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@
import static io.trino.sql.planner.planprinter.TextRenderer.formatPositions;
import static io.trino.sql.planner.planprinter.TextRenderer.indentString;
import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static io.trino.sql.tree.PatternRecognitionRelation.RowsPerMatch.WINDOW;
import static java.lang.Math.abs;
import static java.lang.String.format;
import static java.util.Arrays.stream;
Expand Down Expand Up @@ -732,7 +733,9 @@ public Void visitPatternRecognition(PatternRecognitionNode node, Void context)
nodeOutput.appendDetailsLine("%s := %s", entry.getKey(), unresolveFunctions(entry.getValue().getExpressionAndValuePointers().getExpression()));
appendValuePointers(nodeOutput, entry.getValue().getExpressionAndValuePointers());
}
nodeOutput.appendDetailsLine(formatRowsPerMatch(node.getRowsPerMatch()));
if (node.getRowsPerMatch() != WINDOW) {
nodeOutput.appendDetailsLine(formatRowsPerMatch(node.getRowsPerMatch()));
}
nodeOutput.appendDetailsLine(formatSkipTo(node.getSkipToPosition(), node.getSkipToLabel()));
nodeOutput.appendDetailsLine(format("pattern[%s] (%s)", node.getPattern(), node.isInitial() ? "INITIAL" : "SEEK"));
nodeOutput.appendDetailsLine(format("subsets[%s]", node.getSubsets().entrySet().stream()
Expand Down Expand Up @@ -805,10 +808,9 @@ private String formatRowsPerMatch(RowsPerMatch rowsPerMatch)
return "ALL ROWS PER MATCH OMIT EMPTY MATCHES";
case ALL_WITH_UNMATCHED:
return "ALL ROWS PER MATCH WITH UNMATCHED ROWS";
case WINDOW:
throw new UnsupportedOperationException("pattern matching in WINDOW is not supported");
default:
throw new IllegalArgumentException("unexpected rowsPer match value: " + rowsPerMatch.name());
}
throw new UnsupportedOperationException("unsupported ROWS PER MATCH option");
}

private String formatSkipTo(Position position, Optional<IrLabel> label)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.rowpattern;

import com.google.common.collect.ImmutableMap;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.rowpattern.LogicalIndexExtractor.ExpressionAndValuePointers;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.SymbolReference;

import java.util.Map;
import java.util.function.BiFunction;

import static io.trino.sql.util.AstUtils.treeEqual;

public class ExpressionAndValuePointersEquivalence
{
private ExpressionAndValuePointersEquivalence() {}

public static boolean equivalent(ExpressionAndValuePointers left, ExpressionAndValuePointers right)
{
return equivalent(left, right, Symbol::equals);
}

public static boolean equivalent(ExpressionAndValuePointers left, ExpressionAndValuePointers right, BiFunction<Symbol, Symbol, Boolean> symbolEquivalence)
{
if (left.getLayout().size() != right.getLayout().size()) {
return false;
}

for (int i = 0; i < left.getLayout().size(); i++) {
if (!left.getValuePointers().get(i).getLogicalIndexPointer().equals(right.getValuePointers().get(i).getLogicalIndexPointer())) {
return false;
}
}

ImmutableMap.Builder<Symbol, Symbol> mapping = ImmutableMap.builder();
for (int i = 0; i < left.getLayout().size(); i++) {
Symbol leftLayoutSymbol = left.getLayout().get(i);
boolean leftIsClassifier = left.getClassifierSymbols().contains(leftLayoutSymbol);
boolean leftIsMatchNumber = left.getMatchNumberSymbols().contains(leftLayoutSymbol);

Symbol rightLayoutSymbol = right.getLayout().get(i);
boolean rightIsClassifier = right.getClassifierSymbols().contains(rightLayoutSymbol);
boolean rightIsMatchNumber = right.getMatchNumberSymbols().contains(rightLayoutSymbol);

if (leftIsClassifier != rightIsClassifier || leftIsMatchNumber != rightIsMatchNumber) {
return false;
}

if (!leftIsClassifier && !leftIsMatchNumber) {
Symbol leftInputSymbol = left.getValuePointers().get(i).getInputSymbol();
Symbol rightInputSymbol = right.getValuePointers().get(i).getInputSymbol();
if (!symbolEquivalence.apply(leftInputSymbol, rightInputSymbol)) {
return false;
}
}

mapping.put(leftLayoutSymbol, rightLayoutSymbol);
}

return treeEqual(left.getExpression(), right.getExpression(), mappingComparator(mapping.build()));
}

private static BiFunction<Node, Node, Boolean> mappingComparator(Map<Symbol, Symbol> mapping)
{
return (left, right) -> {
if (left instanceof SymbolReference && right instanceof SymbolReference) {
Symbol leftSymbol = Symbol.from((SymbolReference) left);
Symbol rightSymbol = Symbol.from((SymbolReference) right);
return rightSymbol.equals(mapping.get(leftSymbol));
}
if (!left.shallowEquals(right)) {
return false;
}
return null;
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.assertions;

import io.trino.Session;
import io.trino.metadata.Metadata;
import io.trino.spi.type.Type;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.PatternRecognitionNode;
import io.trino.sql.planner.plan.PatternRecognitionNode.Measure;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.rowpattern.ExpressionAndValuePointersEquivalence;
import io.trino.sql.planner.rowpattern.ir.IrLabel;

import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkState;
import static io.trino.sql.planner.assertions.PatternRecognitionMatcher.rewrite;
import static java.util.Objects.requireNonNull;

public class MeasureMatcher
implements RvalueMatcher
{
private final String expression;
private final Map<IrLabel, Set<IrLabel>> subsets;
private final Type type;

public MeasureMatcher(String expression, Map<IrLabel, Set<IrLabel>> subsets, Type type)
{
this.expression = requireNonNull(expression, "expression is null");
this.subsets = requireNonNull(subsets, "subsets is null");
this.type = requireNonNull(type, "type is null");
}

@Override
public Optional<Symbol> getAssignedSymbol(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases)
{
Optional<Symbol> result = Optional.empty();
if (!(node instanceof PatternRecognitionNode)) {
return result;
}

PatternRecognitionNode patternRecognitionNode = (PatternRecognitionNode) node;

Measure expectedMeasure = new Measure(rewrite(expression, subsets), type);

for (Map.Entry<Symbol, Measure> assignment : patternRecognitionNode.getMeasures().entrySet()) {
Measure actualMeasure = assignment.getValue();
if (measuresEquivalent(actualMeasure, expectedMeasure, symbolAliases)) {
checkState(result.isEmpty(), "Ambiguous measures in %s", patternRecognitionNode);
result = Optional.of(assignment.getKey());
}
}

return result;
}

private static boolean measuresEquivalent(Measure actual, Measure expected, SymbolAliases symbolAliases)
{
if (!actual.getType().equals(expected.getType())) {
return false;
}

ExpressionVerifier verifier = new ExpressionVerifier(symbolAliases);
return ExpressionAndValuePointersEquivalence.equivalent(
actual.getExpressionAndValuePointers(),
expected.getExpressionAndValuePointers(),
(actualSymbol, expectedSymbol) -> verifier.process(actualSymbol.toSymbolReference(), expectedSymbol.toSymbolReference()));
}

@Override
public String toString()
{
return toStringHelper(this)
.add("expression", expression)
.add("subsets", subsets)
.add("type", type)
.toString();
}
}
Loading

0 comments on commit 378ab0d

Please sign in to comment.