Skip to content

Commit

Permalink
Skip hbo stats recording for nodes with dynamic filter
Browse files Browse the repository at this point in the history
  • Loading branch information
feilong-liu authored and xiaoxmeng committed Jun 5, 2024
1 parent 09685c8 commit 32c9f93
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,11 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.graph.GraphBuilder;
import com.google.common.graph.MutableGraph;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -66,6 +69,7 @@
import static com.facebook.presto.cost.HistoryBasedPlanStatisticsManager.historyBasedPlanCanonicalizationStrategyList;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION;
import static com.facebook.presto.sql.planner.planPrinter.PlanNodeStatsSummarizer.aggregateStageStats;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.graph.Traverser.forTree;
Expand Down Expand Up @@ -164,6 +168,7 @@ else if (trackStatsForFailedQueries) {
Map<PlanNodeWithHash, PlanStatisticsWithSourceInfo> planStatisticsMap = new HashMap<>();
Map<CanonicalPlan, PlanNodeCanonicalInfo> canonicalInfoMap = new HashMap<>();
Map<Integer, FinalAggregationStatsInfo> aggregationNodeMap = new HashMap<>();
Set<PlanNodeId> planNodeIdsDynamicFilter = getPlanNodeAppliedDynamicFilter(planNodeStatsMap, allStages);

queryInfo.getPlanCanonicalInfo().forEach(canonicalPlanWithInfo -> {
// We can have duplicate stats equivalent plan nodes. It's ok to use any stats in this case
Expand All @@ -177,7 +182,7 @@ else if (trackStatsForFailedQueries) {
boolean isScaledWriterStage = stageInfo.getPlan().isPresent() && stageInfo.getPlan().get().getPartitioning().equals(SCALED_WRITER_DISTRIBUTION);
PlanNode root = stageInfo.getPlan().get().getRoot();
for (PlanNode planNode : forTree(PlanNode::getSources).depthFirstPreOrder(root)) {
if (!planNode.getStatsEquivalentPlanNode().isPresent() && !isAggregation(planNode, AggregationNode.Step.PARTIAL)) {
if ((!planNode.getStatsEquivalentPlanNode().isPresent() && !isAggregation(planNode, AggregationNode.Step.PARTIAL)) || planNodeIdsDynamicFilter.contains(planNode.getId())) {
continue;
}
PlanNodeStats planNodeStats = planNodeStatsMap.get(planNode.getId());
Expand Down Expand Up @@ -252,6 +257,54 @@ else if (trackStatsForFailedQueries) {
return ImmutableMap.copyOf(planStatisticsMap);
}

private static Set<PlanNodeId> getPlanNodeAppliedDynamicFilter(Map<PlanNodeId, PlanNodeStats> planNodeStatsMap, List<StageInfo> allStages)
{
Map<PlanNodeId, Set<PlanNodeId>> dynamicFilterNodeMap = new HashMap<>();
planNodeStatsMap.forEach((planNodeId, planNodeStats) -> {
if (planNodeStats.getDynamicFilterStats().isPresent()) {
if (!dynamicFilterNodeMap.containsKey(planNodeId)) {
dynamicFilterNodeMap.put(planNodeId, new HashSet<>());
}
dynamicFilterNodeMap.get(planNodeId).addAll(planNodeStats.getDynamicFilterStats().get().getProducerNodeIds());
}
});
if (dynamicFilterNodeMap.isEmpty()) {
return ImmutableSet.of();
}
// Now find the path between producer and child node having dynamic filter applied. Reverse the tree so that all nodes' out degree is at most 1 and easier to find path between nodes
MutableGraph<PlanNodeId> reversePlanTree = GraphBuilder.directed().allowsSelfLoops(false).build();
for (StageInfo stageInfo : allStages) {
if (!stageInfo.getPlan().isPresent()) {
continue;
}
PlanNode root = stageInfo.getPlan().get().getRoot();
for (PlanNode planNode : forTree(PlanNode::getSources).depthFirstPreOrder(root)) {
for (PlanNode child : planNode.getSources()) {
reversePlanTree.putEdge(child.getId(), planNode.getId());
}
}
}
Set<PlanNodeId> planNodeIdsDynamicFilter = new HashSet<>();
dynamicFilterNodeMap.forEach((destNode, producerNodes) -> {
for (PlanNodeId producerNode : producerNodes) {
PlanNodeId rootNode = destNode;
Set<PlanNodeId> visitedNodes = new HashSet<>();
while (!rootNode.equals(producerNode)) {
visitedNodes.add(rootNode);
if (reversePlanTree.successors(rootNode).isEmpty()) {
break;
}
checkState(reversePlanTree.successors(rootNode).size() == 1);
rootNode = reversePlanTree.successors(rootNode).stream().findFirst().orElse(null);
}
if (rootNode.equals(producerNode)) {
planNodeIdsDynamicFilter.addAll(visitedNodes);
}
}
});
return planNodeIdsDynamicFilter;
}

private static void updatePartialAggregationStatistics(
AggregationNode partialAggregationNode,
Map<Integer, FinalAggregationStatsInfo> aggregationNodeStats,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
*/
package com.facebook.presto.sql.planner.planPrinter;

import com.facebook.presto.operator.DynamicFilterStats;
import com.facebook.presto.spi.plan.PlanNodeId;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;

import java.util.Map;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static java.lang.Math.max;
Expand Down Expand Up @@ -45,10 +47,12 @@ public HashCollisionPlanNodeStats(
long planNodeJoinBuildKeyCount,
long planNodeNullJoinProbeKeyCount,
long planNodeJoinProbeKeyCount,
Optional<DynamicFilterStats> dynamicFilterStats,
Map<String, OperatorHashCollisionsStats> operatorHashCollisionsStats)
{
super(planNodeId, planNodeScheduledTime, planNodeCpuTime, planNodeInputPositions, planNodeInputDataSize, planNodeRawInputPositions, planNodeRawInputDataSize,
planNodeOutputPositions, planNodeOutputDataSize, operatorInputStats, planNodeNullJoinBuildKeyCount, planNodeJoinBuildKeyCount, planNodeNullJoinProbeKeyCount, planNodeJoinProbeKeyCount);
planNodeOutputPositions, planNodeOutputDataSize, operatorInputStats, planNodeNullJoinBuildKeyCount, planNodeJoinBuildKeyCount, planNodeNullJoinProbeKeyCount,
planNodeJoinProbeKeyCount, dynamicFilterStats);
this.operatorHashCollisionsStats = requireNonNull(operatorHashCollisionsStats, "operatorHashCollisionsStats is null");
}

Expand Down Expand Up @@ -108,6 +112,7 @@ public PlanNodeStats mergeWith(PlanNodeStats other)
merged.getPlanNodeJoinBuildKeyCount(),
merged.getPlanNodeNullJoinProbeKeyCount(),
merged.getPlanNodeJoinProbeKeyCount(),
merged.getDynamicFilterStats(),
operatorHashCollisionsStats);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
*/
package com.facebook.presto.sql.planner.planPrinter;

import com.facebook.presto.operator.DynamicFilterStats;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.util.Mergeable;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;

import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.util.MoreMaps.mergeMaps;
Expand Down Expand Up @@ -49,6 +51,7 @@ public class PlanNodeStats
private final long planNodeJoinBuildKeyCount;
private final long planNodeNullJoinProbeKeyCount;
private final long planNodeJoinProbeKeyCount;
private final Optional<DynamicFilterStats> dynamicFilterStats;

PlanNodeStats(
PlanNodeId planNodeId,
Expand All @@ -64,7 +67,8 @@ public class PlanNodeStats
long planNodeNullJoinBuildKeyCount,
long planNodeJoinBuildKeyCount,
long planNodeNullJoinProbeKeyCount,
long planNodeJoinProbeKeyCount)
long planNodeJoinProbeKeyCount,
Optional<DynamicFilterStats> dynamicFilterStats)
{
this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");

Expand All @@ -82,6 +86,7 @@ public class PlanNodeStats
this.planNodeJoinBuildKeyCount = planNodeJoinBuildKeyCount;
this.planNodeNullJoinProbeKeyCount = planNodeNullJoinProbeKeyCount;
this.planNodeJoinProbeKeyCount = planNodeJoinProbeKeyCount;
this.dynamicFilterStats = dynamicFilterStats;
}

private static double computedStdDev(double sumSquared, double sum, long n)
Expand Down Expand Up @@ -181,6 +186,25 @@ public long getPlanNodeJoinProbeKeyCount()
return planNodeJoinProbeKeyCount;
}

public Optional<DynamicFilterStats> getDynamicFilterStats()
{
return dynamicFilterStats;
}

public static Optional<DynamicFilterStats> mergeDynamicFilterStats(Optional<DynamicFilterStats> stats1, Optional<DynamicFilterStats> stats2)
{
Optional<DynamicFilterStats> optionalDynamicFilterStats = Optional.empty();
if (stats1.isPresent()) {
DynamicFilterStats dynamicFilterStats = stats1.get();
stats2.ifPresent(dynamicFilterStats::mergeWith);
optionalDynamicFilterStats = Optional.of(dynamicFilterStats);
}
else if (stats2.isPresent()) {
optionalDynamicFilterStats = Optional.of(stats2.get());
}
return optionalDynamicFilterStats;
}

@Override
public PlanNodeStats mergeWith(PlanNodeStats other)
{
Expand All @@ -198,6 +222,7 @@ public PlanNodeStats mergeWith(PlanNodeStats other)
long planNodeJoinBuildKeyCount = this.planNodeJoinBuildKeyCount + other.planNodeJoinBuildKeyCount;
long planNodeNullJoinProbeKeyCount = this.planNodeNullJoinProbeKeyCount + other.planNodeNullJoinProbeKeyCount;
long planNodeJoinProbeKeyCount = this.planNodeJoinProbeKeyCount + other.planNodeJoinProbeKeyCount;
Optional<DynamicFilterStats> optionalDynamicFilterStats = mergeDynamicFilterStats(this.dynamicFilterStats, other.dynamicFilterStats);

return new PlanNodeStats(
planNodeId,
Expand All @@ -210,6 +235,7 @@ public PlanNodeStats mergeWith(PlanNodeStats other)
planNodeNullJoinBuildKeyCount,
planNodeJoinBuildKeyCount,
planNodeNullJoinProbeKeyCount,
planNodeJoinProbeKeyCount);
planNodeJoinProbeKeyCount,
optionalDynamicFilterStats);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.facebook.presto.execution.StageInfo;
import com.facebook.presto.execution.TaskInfo;
import com.facebook.presto.operator.DynamicFilterStats;
import com.facebook.presto.operator.HashCollisionsInfo;
import com.facebook.presto.operator.OperatorStats;
import com.facebook.presto.operator.PipelineStats;
Expand All @@ -29,6 +30,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.util.MoreMaps.mergeMaps;
Expand Down Expand Up @@ -84,6 +86,7 @@ private static List<PlanNodeStats> getPlanNodeStats(TaskStats taskStats)
Map<PlanNodeId, Long> planNodeJoinBuildKeyCount = new HashMap<>();
Map<PlanNodeId, Long> planNodeNullJoinProbeKeyCount = new HashMap<>();
Map<PlanNodeId, Long> planNodeJoinProbeKeyCount = new HashMap<>();
Map<PlanNodeId, Optional<DynamicFilterStats>> planNodeIdDynamicFilterStatsMap = new HashMap<>();

Map<PlanNodeId, Map<String, OperatorInputStats>> operatorInputStats = new HashMap<>();
Map<PlanNodeId, Map<String, OperatorHashCollisionsStats>> operatorHashCollisionsStats = new HashMap<>();
Expand Down Expand Up @@ -158,6 +161,8 @@ private static List<PlanNodeStats> getPlanNodeStats(TaskStats taskStats)
planNodeNullJoinProbeKeyCount.merge(planNodeId, operatorStats.getNullJoinProbeKeyCount(), Long::sum);
planNodeJoinProbeKeyCount.merge(planNodeId, operatorStats.getJoinProbeKeyCount(), Long::sum);

planNodeIdDynamicFilterStatsMap.merge(planNodeId, Optional.of(operatorStats.getDynamicFilterStats()), PlanNodeStats::mergeDynamicFilterStats);

processedNodes.add(planNodeId);
}

Expand Down Expand Up @@ -218,6 +223,7 @@ private static List<PlanNodeStats> getPlanNodeStats(TaskStats taskStats)
planNodeJoinBuildKeyCount.get(planNodeId),
planNodeNullJoinProbeKeyCount.get(planNodeId),
planNodeJoinProbeKeyCount.get(planNodeId),
planNodeIdDynamicFilterStatsMap.get(planNodeId),
operatorHashCollisionsStats.get(planNodeId));
}
else if (windowNodeStats.containsKey(planNodeId)) {
Expand All @@ -236,6 +242,7 @@ else if (windowNodeStats.containsKey(planNodeId)) {
planNodeJoinBuildKeyCount.get(planNodeId),
planNodeNullJoinProbeKeyCount.get(planNodeId),
planNodeJoinProbeKeyCount.get(planNodeId),
planNodeIdDynamicFilterStatsMap.get(planNodeId),
windowNodeStats.get(planNodeId));
}
else {
Expand All @@ -253,7 +260,8 @@ else if (windowNodeStats.containsKey(planNodeId)) {
planNodeNullJoinBuildKeyCount.get(planNodeId),
planNodeJoinBuildKeyCount.get(planNodeId),
planNodeNullJoinProbeKeyCount.get(planNodeId),
planNodeJoinProbeKeyCount.get(planNodeId));
planNodeJoinProbeKeyCount.get(planNodeId),
planNodeIdDynamicFilterStatsMap.get(planNodeId));
}

stats.add(nodeStats);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
*/
package com.facebook.presto.sql.planner.planPrinter;

import com.facebook.presto.operator.DynamicFilterStats;
import com.facebook.presto.spi.plan.PlanNodeId;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;

import java.util.Map;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;

Expand All @@ -41,10 +43,12 @@ public WindowPlanNodeStats(
long planNodeJoinBuildKeyCount,
long planNodeNullJoinProbeKeyCount,
long planNodeJoinProbeKeyCount,
Optional<DynamicFilterStats> dynamicFilterStats,
WindowOperatorStats windowOperatorStats)
{
super(planNodeId, planNodeScheduledTime, planNodeCpuTime, planNodeInputPositions, planNodeInputDataSize, planNodeRawInputPositions, planNodeRawInputDataSize,
planNodeOutputPositions, planNodeOutputDataSize, operatorInputStats, planNodeNullJoinBuildKeyCount, planNodeJoinBuildKeyCount, planNodeNullJoinProbeKeyCount, planNodeJoinProbeKeyCount);
planNodeOutputPositions, planNodeOutputDataSize, operatorInputStats, planNodeNullJoinBuildKeyCount, planNodeJoinBuildKeyCount, planNodeNullJoinProbeKeyCount,
planNodeJoinProbeKeyCount, dynamicFilterStats);
this.windowOperatorStats = windowOperatorStats;
}

Expand Down Expand Up @@ -74,6 +78,7 @@ public PlanNodeStats mergeWith(PlanNodeStats other)
merged.getPlanNodeJoinBuildKeyCount(),
merged.getPlanNodeNullJoinProbeKeyCount(),
merged.getPlanNodeJoinProbeKeyCount(),
merged.getDynamicFilterStats(),
windowOperatorStats);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.nativeworker;

import com.facebook.presto.Session;
import com.facebook.presto.execution.SqlQueryManager;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.testing.InMemoryHistoryBasedPlanStatisticsProvider;
import com.facebook.presto.testing.QueryRunner;
import com.facebook.presto.tests.AbstractTestQueryFramework;
import com.facebook.presto.tests.DistributedQueryRunner;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.any;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node;
import static java.lang.Double.NaN;

@Test(singleThreaded = true)
public class TestNativeHistoryBasedStatsTracking
extends AbstractTestQueryFramework
{
@Override
protected QueryRunner createQueryRunner()
throws Exception
{
return PrestoNativeQueryRunnerUtils.createNativeQueryRunner(true);
}

@BeforeMethod(alwaysRun = true)
public void setUp()
{
getHistoryProvider().clearCache();
}

@Test
public void testDynamicFilterEnabled()
{
Session broadcastSession = Session.builder(getQueryRunner().getDefaultSession()).setSystemProperty(JOIN_DISTRIBUTION_TYPE, "BROADCAST").build();
String sql = "select s.name, s.acctbal, sum(l.quantity) from lineitem l join supplier s on l.suppkey=s.suppkey where acctbal < 0 and length(l.comment) > 2 group by 1, 2";
// CBO Statistics
assertPlan(
broadcastSession,
sql,
anyTree(
node(ProjectNode.class, anyTree(any())).withOutputRowCount(NaN),
anyTree(any())));

// HBO Statistics
executeAndTrackHistory(sql, broadcastSession);
assertPlan(
broadcastSession,
sql,
anyTree(
node(ProjectNode.class, anyTree(any())).withOutputRowCount(NaN),
anyTree(any())));
}

private void executeAndTrackHistory(String sql, Session session)
{
getQueryRunner().execute(session, sql);
getHistoryProvider().waitProcessQueryEvents();
}

private InMemoryHistoryBasedPlanStatisticsProvider getHistoryProvider()
{
DistributedQueryRunner queryRunner = (DistributedQueryRunner) getQueryRunner();
SqlQueryManager sqlQueryManager = (SqlQueryManager) queryRunner.getCoordinator().getQueryManager();
return (InMemoryHistoryBasedPlanStatisticsProvider) sqlQueryManager.getHistoryBasedPlanStatisticsTracker().getHistoryBasedPlanStatisticsProvider();
}
}
Loading

0 comments on commit 32c9f93

Please sign in to comment.