Skip to content

Commit

Permalink
Skip RemoveRedundantTableScanPredicate if predicate doesn't change
Browse files Browse the repository at this point in the history
  • Loading branch information
sopel39 committed Feb 10, 2022
1 parent d4fc6c5 commit fed993c
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,37 @@
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.DomainTranslator;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.DomainTranslator.ExtractionResult;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.tree.Expression;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.matching.Capture.newCapture;
import static io.trino.spi.predicate.TupleDomain.intersect;
import static io.trino.sql.ExpressionUtils.combineConjuncts;
import static io.trino.sql.ExpressionUtils.extractConjuncts;
import static io.trino.sql.ExpressionUtils.filterDeterministicConjuncts;
import static io.trino.sql.ExpressionUtils.filterNonDeterministicConjuncts;
import static io.trino.sql.planner.iterative.rule.PushPredicateIntoTableScan.createResultingPredicate;
import static io.trino.sql.planner.plan.Patterns.filter;
import static io.trino.sql.planner.plan.Patterns.source;
import static io.trino.sql.planner.plan.Patterns.tableScan;
import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static java.lang.Boolean.FALSE;
import static java.lang.Boolean.TRUE;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.toList;

public class RemoveRedundantTableScanPredicate
implements Rule<FilterNode>
Expand All @@ -57,7 +64,9 @@ public class RemoveRedundantTableScanPredicate

private static final Pattern<FilterNode> PATTERN =
filter().with(source().matching(
tableScan().capturedAs(TABLE_SCAN)));
tableScan().capturedAs(TABLE_SCAN)
// avoid extra computations if table scan doesn't have any enforced predicate
.matching(node -> !node.getEnforcedConstraint().isAll())));

private final PlannerContext plannerContext;
private final TypeAnalyzer typeAnalyzer;
Expand All @@ -77,38 +86,22 @@ public Pattern<FilterNode> getPattern()
@Override
public Result apply(FilterNode filterNode, Captures captures, Context context)
{
TableScanNode tableScan = captures.get(TABLE_SCAN);
Session session = context.getSession();
TableScanNode node = captures.get(TABLE_SCAN);
Expression predicate = filterNode.getPredicate();

PlanNode rewritten = removeRedundantTableScanPredicate(
tableScan,
filterNode.getPredicate(),
context.getSession(),
context.getSymbolAllocator(),
context.getIdAllocator());

if (rewritten instanceof FilterNode
&& Objects.equals(((FilterNode) rewritten).getPredicate(), filterNode.getPredicate())) {
return Result.empty();
}

return Result.ofPlanNode(rewritten);
}

private PlanNode removeRedundantTableScanPredicate(
TableScanNode node,
Expression predicate,
Session session,
SymbolAllocator symbolAllocator,
PlanNodeIdAllocator idAllocator)
{
Expression deterministicPredicate = filterDeterministicConjuncts(plannerContext.getMetadata(), predicate);
Expression nonDeterministicPredicate = filterNonDeterministicConjuncts(plannerContext.getMetadata(), predicate);

DomainTranslator.ExtractionResult decomposedPredicate = DomainTranslator.getExtractionResult(
plannerContext,
ExtractionResult decomposedPredicate = getFullyExtractedPredicates(
session,
deterministicPredicate,
symbolAllocator.getTypes());
context.getSymbolAllocator().getTypes());

if (decomposedPredicate.getTupleDomain().isAll()) {
// no conjunct could be fully converted to tuple domain
return Result.empty();
}

TupleDomain<ColumnHandle> predicateDomain = decomposedPredicate.getTupleDomain()
.transformKeys(node.getAssignments()::get);
Expand All @@ -117,12 +110,12 @@ private PlanNode removeRedundantTableScanPredicate(
// TODO: DomainTranslator.fromPredicate can infer that the expression is "false" in some cases (TupleDomain.none()).
// This should move to another rule that simplifies the filter using that logic and then rely on RemoveTrivialFilters
// to turn the subtree into a Values node
return new ValuesNode(node.getId(), node.getOutputSymbols(), ImmutableList.of());
return Result.ofPlanNode(new ValuesNode(node.getId(), node.getOutputSymbols(), ImmutableList.of()));
}

if (node.getEnforcedConstraint().isNone()) {
// table scans with none domain should be converted to ValuesNode
return new ValuesNode(node.getId(), node.getOutputSymbols(), ImmutableList.of());
return Result.ofPlanNode(new ValuesNode(node.getId(), node.getOutputSymbols(), ImmutableList.of()));
}

Map<ColumnHandle, Domain> enforcedColumnDomains = node.getEnforcedConstraint().getDomains().orElseThrow(); // is not NONE
Expand All @@ -137,20 +130,41 @@ private PlanNode removeRedundantTableScanPredicate(
return predicateColumnDomain.intersect(enforcedColumnDomain);
});

if (unenforcedDomain.equals(predicateDomain)) {
// no change in filter predicate
return Result.empty();
}

Map<ColumnHandle, Symbol> assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse();
Expression resultingPredicate = createResultingPredicate(
plannerContext,
session,
symbolAllocator,
context.getSymbolAllocator(),
typeAnalyzer,
new DomainTranslator(plannerContext).toPredicate(session, unenforcedDomain.transformKeys(assignments::get)),
nonDeterministicPredicate,
decomposedPredicate.getRemainingExpression());

if (!TRUE_LITERAL.equals(resultingPredicate)) {
return new FilterNode(idAllocator.getNextId(), node, resultingPredicate);
return Result.ofPlanNode(new FilterNode(context.getIdAllocator().getNextId(), node, resultingPredicate));
}

return node;
return Result.ofPlanNode(node);
}

private ExtractionResult getFullyExtractedPredicates(Session session, Expression predicate, TypeProvider types)
{
Map<Boolean, List<ExtractionResult>> extractedPredicates = extractConjuncts(predicate).stream()
.map(conjunct -> DomainTranslator.getExtractionResult(plannerContext, session, conjunct, types))
.collect(groupingBy(result -> result.getRemainingExpression().equals(TRUE_LITERAL), toList()));
return new ExtractionResult(
intersect(extractedPredicates.getOrDefault(TRUE, ImmutableList.of()).stream()
.map(ExtractionResult::getTupleDomain)
.collect(toImmutableList())),
combineConjuncts(
plannerContext.getMetadata(),
extractedPredicates.getOrDefault(FALSE, ImmutableList.of()).stream()
.map(ExtractionResult::getRemainingExpression)
.collect(toImmutableList())));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.Slices;
import io.trino.connector.CatalogName;
import io.trino.metadata.TableHandle;
import io.trino.plugin.tpch.TpchColumnHandle;
Expand All @@ -37,6 +38,7 @@

import static io.trino.spi.predicate.Domain.singleValue;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.spi.type.VarcharType.createVarcharType;
import static io.trino.sql.planner.assertions.PlanMatchPattern.constrainedTableScanWithTableLayout;
import static io.trino.sql.planner.assertions.PlanMatchPattern.filter;
Expand Down Expand Up @@ -237,4 +239,64 @@ public void doesNotAddTableLayoutToFilterTableScan()
ImmutableMap.of(p.symbol("orderstatus", createVarcharType(1)), new TpchColumnHandle("orderstatus", createVarcharType(1))))))
.doesNotFire();
}

@Test
public void doesNotFireOnNoTableScanPredicate()
{
ColumnHandle columnHandle = new TpchColumnHandle("nationkey", BIGINT);
tester().assertThat(removeRedundantTableScanPredicate)
.on(p -> p.filter(expression("(nationkey > 3 OR nationkey > 0) AND (nationkey > 3 OR nationkey < 1)"),
p.tableScan(
nationTableHandle,
ImmutableList.of(p.symbol("nationkey", BIGINT)),
ImmutableMap.of(p.symbol("nationkey", BIGINT), columnHandle),
TupleDomain.all())))
.doesNotFire();
}

@Test
public void doesNotFireOnNotFullyExtractedConjunct()
{
ColumnHandle columnHandle = new TpchColumnHandle("name", VARCHAR);
tester().assertThat(removeRedundantTableScanPredicate)
.on(p -> p.filter(expression("name LIKE 'LARGE PLATED %'"),
p.tableScan(
nationTableHandle,
ImmutableList.of(p.symbol("name", VARCHAR)),
ImmutableMap.of(p.symbol("name", VARCHAR), columnHandle),
TupleDomain.fromFixedValues(ImmutableMap.of(
columnHandle, NullableValue.of(VARCHAR, Slices.utf8Slice("value")))))))
.doesNotFire();
}

@Test
public void skipNotFullyExtractedConjunct()
{
ColumnHandle textColumnHandle = new TpchColumnHandle("name", VARCHAR);
ColumnHandle nationKeyColumnHandle = new TpchColumnHandle("nationkey", BIGINT);
tester().assertThat(removeRedundantTableScanPredicate)
.on(p -> p.filter(expression("name LIKE 'LARGE PLATED %' AND nationkey = BIGINT '44'"),
p.tableScan(
nationTableHandle,
ImmutableList.of(
p.symbol("name", VARCHAR),
p.symbol("nationkey", BIGINT)),
ImmutableMap.of(
p.symbol("name", VARCHAR), textColumnHandle,
p.symbol("nationkey", BIGINT), nationKeyColumnHandle),
TupleDomain.fromFixedValues(ImmutableMap.of(
textColumnHandle, NullableValue.of(VARCHAR, Slices.utf8Slice("value")),
nationKeyColumnHandle, NullableValue.of(BIGINT, (long) 44))))))
.matches(
filter(
expression("name LIKE 'LARGE PLATED %'"),
constrainedTableScanWithTableLayout(
"nation",
ImmutableMap.of(
"nationkey", Domain.singleValue(BIGINT, 44L),
"name", Domain.singleValue(VARCHAR, Slices.utf8Slice("value"))),
ImmutableMap.of(
"nationkey", "nationkey",
"name", "name"))));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ public void testSubsumePartitionPartOfAFilter()
exchange(LOCAL,
exchange(REMOTE, REPARTITION,
project(
filter("R_INT_COL IN (2, 3, 4)",
filter("R_INT_COL IN (2, 3, 4) AND R_INT_COL BETWEEN 2 AND 4", // TODO: R_INT_COL BETWEEN 2 AND 4 is redundant
tableScan("table_unpartitioned", Map.of("R_STR_COL", "str_col", "R_INT_COL", "int_col"))))))))));
}

Expand Down

0 comments on commit fed993c

Please sign in to comment.