Skip to content

Commit

Permalink
Support subqueries in MATCH_RECOGNIZE
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiafi committed Aug 6, 2021
1 parent 236a6d4 commit 41d147b
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -1735,6 +1735,7 @@ protected Scope visitPatternRecognitionRelation(PatternRecognitionRelation relat
for (VariableDefinition variableDefinition : relation.getVariableDefinitions()) {
Expression expression = variableDefinition.getExpression();
ExpressionAnalysis expressionAnalysis = analyzePatternRecognitionExpression(expression, inputScope, patternRecognitionAnalysis.getAllLabels());
analysis.recordSubqueries(relation, expressionAnalysis);
Type type = expressionAnalysis.getType(expression);
if (!type.equals(BOOLEAN)) {
throw semanticException(TYPE_MISMATCH, expression, "Expression defining a label must be boolean (actual type: %s)", type);
Expand All @@ -1744,6 +1745,7 @@ protected Scope visitPatternRecognitionRelation(PatternRecognitionRelation relat
for (MeasureDefinition measureDefinition : relation.getMeasures()) {
Expression expression = measureDefinition.getExpression();
ExpressionAnalysis expressionAnalysis = analyzePatternRecognitionExpression(expression, inputScope, patternRecognitionAnalysis.getAllLabels());
analysis.recordSubqueries(relation, expressionAnalysis);
measureTypesBuilder.put(NodeRef.of(expression), expressionAnalysis.getType(expression));
}
Map<NodeRef<Node>, Type> measureTypes = measureTypesBuilder.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1438,7 +1438,7 @@ private PlanBuilder planWindowMeasures(Node node, PlanBuilder subPlan, List<Wind
return subPlan;
}

private static List<Expression> extractPatternRecognitionExpressions(List<VariableDefinition> variableDefinitions, List<MeasureDefinition> measureDefinitions)
public static List<Expression> extractPatternRecognitionExpressions(List<VariableDefinition> variableDefinitions, List<MeasureDefinition> measureDefinitions)
{
ImmutableList.Builder<Expression> expressions = ImmutableList.builder();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
import static io.trino.sql.planner.PlanBuilder.newPlanBuilder;
import static io.trino.sql.planner.QueryPlanner.coerce;
import static io.trino.sql.planner.QueryPlanner.coerceIfNecessary;
import static io.trino.sql.planner.QueryPlanner.extractPatternRecognitionExpressions;
import static io.trino.sql.planner.QueryPlanner.planWindowSpecification;
import static io.trino.sql.planner.QueryPlanner.pruneInvisibleFields;
import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet;
Expand Down Expand Up @@ -360,6 +361,8 @@ protected RelationPlan visitPatternRecognitionRelation(PatternRecognitionRelatio
.forEach(outputLayout::add);
}

planBuilder = subqueryPlanner.handleSubqueries(planBuilder, extractPatternRecognitionExpressions(node.getVariableDefinitions(), node.getMeasures()), analysis.getSubqueries(node));

PatternRecognitionComponents components = planPatternRecognitionComponents(
planBuilder::rewrite,
node.getSubsets(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1314,4 +1314,63 @@ public void testMultipleMatchRecognize()
" ('A', 'B', 'C')," +
" ('A', 'B', 'C') ");
}

@Test
public void testSubqueries()
{
String query = "SELECT m.val " +
" FROM (VALUES " +
" (1, 100), " +
" (2, 200), " +
" (3, 300), " +
" (4, 400) " +
" ) t(id, value) " +
" MATCH_RECOGNIZE ( " +
" ORDER BY id " +
" MEASURES %s AS val " +
" ONE ROW PER MATCH " +
" AFTER MATCH SKIP TO NEXT ROW " +
" PATTERN (A+) " +
" DEFINE A AS %s " +
" ) AS m";

assertThat(assertions.query(format(query, "(SELECT 'x')", "(SELECT true)")))
.matches("VALUES " +
" ('x'), " +
" ('x'), " +
" ('x'), " +
" ('x') ");

// subquery nested in navigation
assertThat(assertions.query(format(query, "FINAL LAST(A.value + (SELECT 1000))", "FIRST(A.value < 0 OR (SELECT true))")))
.matches("VALUES " +
" (1400), " +
" (1400), " +
" (1400), " +
" (1400) ");

// IN-predicate: value and value list without column references
assertThat(assertions.query(format(query, "LAST(A.id < 0 OR 1 IN (SELECT 1))", "FIRST(A.id > 0 AND 1 IN (SELECT 1))")))
.matches("VALUES " +
" (true), " +
" (true), " +
" (true), " +
" (true) ");

// IN-predicate: unlabeled column reference in value
assertThat(assertions.query(format(query, "FIRST(id % 2 IN (SELECT 0))", "FIRST(value * 0 IN (SELECT 0))")))
.matches("VALUES " +
" (false), " +
" (true), " +
" (false), " +
" (true) ");

// EXISTS-predicate
assertThat(assertions.query(format(query, "LAST(A.value < 0 OR EXISTS(SELECT 1))", "FIRST(A.value < 0 OR EXISTS(SELECT 1))")))
.matches("VALUES " +
" (true), " +
" (true), " +
" (true), " +
" (true) ");
}
}

0 comments on commit 41d147b

Please sign in to comment.