From 9fe77b6372b8a1ae47d75bda3bc8ab5fc0ec39b2 Mon Sep 17 00:00:00 2001 From: Raunaq Morarka Date: Sat, 16 May 2020 00:23:28 +0530 Subject: [PATCH 1/3] Remove dynamic filters with expression which are not a SymbolReference --- .../iterative/rule/RemoveUnsupportedDynamicFilters.java | 4 +++- .../sql/planner/sanity/DynamicFiltersChecker.java | 8 +++++--- .../java/io/prestosql/sql/planner/TestDynamicFilter.java | 6 ++---- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java index c08c88d0437df..2843e1c7c8fea 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java @@ -34,6 +34,7 @@ import io.prestosql.sql.tree.ExpressionRewriter; import io.prestosql.sql.tree.ExpressionTreeRewriter; import io.prestosql.sql.tree.LogicalBinaryExpression; +import io.prestosql.sql.tree.SymbolReference; import java.util.HashSet; import java.util.List; @@ -222,7 +223,8 @@ private Expression removeDynamicFilters(Expression expression, Set allow .filter(conjunct -> getDescriptor(conjunct) .map(descriptor -> { - if (allowedDynamicFilterIds.contains(descriptor.getId())) { + if (descriptor.getInput() instanceof SymbolReference && + allowedDynamicFilterIds.contains(descriptor.getId())) { consumedDynamicFilterIds.add(descriptor.getId()); return true; } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/DynamicFiltersChecker.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/DynamicFiltersChecker.java index 466558bd41bf7..b403222b109ef 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/DynamicFiltersChecker.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/DynamicFiltersChecker.java @@ -29,6 +29,7 @@ import io.prestosql.sql.planner.plan.PlanVisitor; import io.prestosql.sql.planner.plan.TableScanNode; import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.SymbolReference; import java.util.HashSet; import java.util.List; @@ -104,9 +105,10 @@ public Set visitFilter(FilterNode node, Void context) verify(node.getSource() instanceof TableScanNode, "Dynamic filters %s present in filter predicate whose source is not a table scan.", dynamicFilters); } ImmutableSet.Builder consumed = ImmutableSet.builder(); - dynamicFilters.stream() - .map(DynamicFilters.Descriptor::getId) - .forEach(consumed::add); + dynamicFilters.forEach(descriptor -> { + verify(descriptor.getInput() instanceof SymbolReference, "Dynamic filter expression must be a SymbolReference"); + consumed.add(descriptor.getId()); + }); consumed.addAll(node.getSource().accept(this, context)); return consumed.build(); } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestDynamicFilter.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestDynamicFilter.java index 4c1bf01b24966..330297f4a7e20 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestDynamicFilter.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestDynamicFilter.java @@ -116,10 +116,8 @@ public void testJoinOnCast() node( JoinNode.class, anyTree( - node( - FilterNode.class, - tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey"))) - .with(numberOfDynamicFilters(1))), + project( + 
tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey")))), anyTree( tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey")))))); } From ac45ad2fefb409e01165452b857b1969557ab1aa Mon Sep 17 00:00:00 2001 From: Raunaq Morarka Date: Tue, 2 Jul 2019 19:25:29 +0530 Subject: [PATCH 2/3] Implement runtime partition pruning --- .../benchmark/AbstractOperatorBenchmark.java | 2 +- .../hive/BackgroundHiveSplitLoader.java | 23 +- .../plugin/hive/HivePartitionManager.java | 11 +- .../plugin/hive/HiveSplitManager.java | 28 +- .../hive/util/InternalHiveSplitFactory.java | 10 + .../plugin/hive/AbstractTestHive.java | 3 +- .../hive/AbstractTestHiveFileSystem.java | 3 +- .../hive/TestBackgroundHiveSplitLoader.java | 8 + .../execution/SqlQueryExecution.java | 32 +- .../java/io/prestosql/execution/SqlTask.java | 16 +- .../io/prestosql/execution/StageState.java | 24 ++ .../io/prestosql/execution/TaskStatus.java | 21 +- .../scheduler/SqlQueryScheduler.java | 22 +- .../io/prestosql/operator/TaskContext.java | 19 + .../prestosql/server/CoordinatorModule.java | 3 + .../server/DynamicFilterService.java | 225 ++++++++++ .../java/io/prestosql/split/SplitManager.java | 12 +- .../planner/DistributedExecutionPlanner.java | 14 +- .../sql/planner/LocalDynamicFilter.java | 74 +++- .../sql/planner/LocalExecutionPlanner.java | 20 +- .../iterative/rule/ExtractSpatialJoins.java | 2 +- .../prestosql/testing/LocalQueryRunner.java | 4 +- .../execution/MockRemoteTaskFactory.java | 6 +- .../server/TestDynamicFilterService.java | 392 ++++++++++++++++++ .../server/remotetask/TestHttpRemoteTask.java | 3 +- .../sql/planner/TestLocalDynamicFilter.java | 63 +-- .../ClassLoaderSafeConnectorSplitManager.java | 12 + .../spi/connector/ConnectorSplitManager.java | 14 + 28 files changed, 960 insertions(+), 106 deletions(-) create mode 100644 presto-main/src/main/java/io/prestosql/server/DynamicFilterService.java create mode 100644 presto-main/src/test/java/io/prestosql/server/TestDynamicFilterService.java diff --git a/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java b/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java index a82dc69b97627..f64541c3d3bbd 100644 --- a/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java +++ b/presto-benchmark/src/main/java/io/prestosql/benchmark/AbstractOperatorBenchmark.java @@ -199,7 +199,7 @@ public OperatorFactory duplicate() private Split getLocalQuerySplit(Session session, TableHandle handle) { - SplitSource splitSource = localQueryRunner.getSplitManager().getSplits(session, handle, UNGROUPED_SCHEDULING); + SplitSource splitSource = localQueryRunner.getSplitManager().getSplits(session, handle, UNGROUPED_SCHEDULING, TupleDomain::all); List splits = new ArrayList<>(); while (!splitSource.isFinished()) { splits.addAll(getNextBatch(splitSource)); diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/BackgroundHiveSplitLoader.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/BackgroundHiveSplitLoader.java index 621fb6cc42aa2..40f6eb71c4850 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/BackgroundHiveSplitLoader.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/BackgroundHiveSplitLoader.java @@ -37,6 +37,7 @@ import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ConnectorSession; import io.prestosql.spi.predicate.TupleDomain; +import io.prestosql.spi.type.TypeManager; import org.apache.hadoop.conf.Configuration; import 
org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.FileSystem; @@ -71,7 +72,9 @@ import java.util.concurrent.Executor; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.function.BooleanSupplier; import java.util.function.IntPredicate; +import java.util.function.Supplier; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -85,6 +88,7 @@ import static io.prestosql.plugin.hive.HiveErrorCode.HIVE_INVALID_METADATA; import static io.prestosql.plugin.hive.HiveErrorCode.HIVE_INVALID_PARTITION_VALUE; import static io.prestosql.plugin.hive.HiveErrorCode.HIVE_UNKNOWN_ERROR; +import static io.prestosql.plugin.hive.HivePartitionManager.partitionMatches; import static io.prestosql.plugin.hive.HiveSessionProperties.isForceLocalScheduling; import static io.prestosql.plugin.hive.metastore.MetastoreUtil.getHiveSchema; import static io.prestosql.plugin.hive.metastore.MetastoreUtil.getPartitionLocation; @@ -97,6 +101,7 @@ import static io.prestosql.plugin.hive.util.HiveUtil.getFooterCount; import static io.prestosql.plugin.hive.util.HiveUtil.getHeaderCount; import static io.prestosql.plugin.hive.util.HiveUtil.getInputFormat; +import static io.prestosql.plugin.hive.util.HiveUtil.getPartitionKeyColumnHandles; import static io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED; import static java.lang.Integer.parseInt; import static java.lang.Math.max; @@ -119,6 +124,8 @@ public class BackgroundHiveSplitLoader private final Table table; private final TupleDomain compactEffectivePredicate; + private final Supplier> dynamicFilterSupplier; + private final TypeManager typeManager; private final Optional tableBucketInfo; private final HdfsEnvironment hdfsEnvironment; private final HdfsContext hdfsContext; @@ -157,6 +164,8 @@ public BackgroundHiveSplitLoader( Table table, Iterable partitions, TupleDomain compactEffectivePredicate, + Supplier> dynamicFilterSupplier, + TypeManager typeManager, Optional tableBucketInfo, ConnectorSession session, HdfsEnvironment hdfsEnvironment, @@ -170,6 +179,8 @@ public BackgroundHiveSplitLoader( { this.table = table; this.compactEffectivePredicate = compactEffectivePredicate; + this.dynamicFilterSupplier = dynamicFilterSupplier; + this.typeManager = typeManager; this.tableBucketInfo = tableBucketInfo; this.loaderConcurrency = loaderConcurrency; this.session = session; @@ -302,11 +313,19 @@ private ListenableFuture loadSplits() private ListenableFuture loadPartition(HivePartitionMetadata partition) throws IOException { - String partitionName = partition.getHivePartition().getPartitionId(); + HivePartition hivePartition = partition.getHivePartition(); + String partitionName = hivePartition.getPartitionId(); Properties schema = getPartitionSchema(table, partition.getPartition()); List partitionKeys = getPartitionKeys(table, partition.getPartition()); TupleDomain effectivePredicate = compactEffectivePredicate.transform(HiveColumnHandle.class::cast); + List partitionColumns = getPartitionKeyColumnHandles(table, typeManager); + BooleanSupplier partitionMatchSupplier = () -> partitionMatches(partitionColumns, dynamicFilterSupplier.get(), hivePartition); + if (!partitionMatchSupplier.getAsBoolean()) { + // Avoid listing files and creating splits from a partition if it has been pruned due to dynamic filters + return COMPLETED_FUTURE; + } + Path path = new Path(getPartitionLocation(table, partition.getPartition())); Configuration configuration = hdfsEnvironment.getConfiguration(hdfsContext, path); 
InputFormat inputFormat = getInputFormat(configuration, schema, false); @@ -349,6 +368,7 @@ private ListenableFuture loadPartition(HivePartitionMetadata partition) schema, partitionKeys, effectivePredicate, + partitionMatchSupplier, partition.getTableToPartitionMapping(), Optional.empty(), isForceLocalScheduling(session), @@ -386,6 +406,7 @@ private ListenableFuture loadPartition(HivePartitionMetadata partition) schema, partitionKeys, effectivePredicate, + partitionMatchSupplier, partition.getTableToPartitionMapping(), bucketConversionRequiresWorkerParticipation ? bucketConversion : Optional.empty(), isForceLocalScheduling(session), diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/HivePartitionManager.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/HivePartitionManager.java index 755b0d62ce5d1..344d1d211d740 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/HivePartitionManager.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/HivePartitionManager.java @@ -253,6 +253,14 @@ private Optional parseValuesAndFilterPartition( private boolean partitionMatches(List partitionColumns, TupleDomain constraintSummary, Predicate> constraint, HivePartition partition) { + return partitionMatches(partitionColumns, constraintSummary, partition) && constraint.test(partition.getKeys()); + } + + public static boolean partitionMatches(List partitionColumns, TupleDomain constraintSummary, HivePartition partition) + { + if (constraintSummary.isNone()) { + return false; + } Map domains = constraintSummary.getDomains().get(); for (HiveColumnHandle column : partitionColumns) { NullableValue value = partition.getKeys().get(column); @@ -261,8 +269,7 @@ private boolean partitionMatches(List partitionColumns, TupleD return false; } } - - return constraint.test(partition.getKeys()); + return true; } private List getFilteredPartitionNames(SemiTransactionalHiveMetastore metastore, HiveIdentity identity, SchemaTableName tableName, List partitionKeys, TupleDomain effectivePredicate) diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveSplitManager.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveSplitManager.java index 149436d5e7c7d..330b8f427e0df 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveSplitManager.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveSplitManager.java @@ -29,6 +29,7 @@ import io.prestosql.plugin.hive.util.HiveBucketing.HiveBucketFilter; import io.prestosql.spi.PrestoException; import io.prestosql.spi.VersionEmbedder; +import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ConnectorSession; import io.prestosql.spi.connector.ConnectorSplitManager; import io.prestosql.spi.connector.ConnectorSplitSource; @@ -37,6 +38,8 @@ import io.prestosql.spi.connector.FixedSplitSource; import io.prestosql.spi.connector.SchemaTableName; import io.prestosql.spi.connector.TableNotFoundException; +import io.prestosql.spi.predicate.TupleDomain; +import io.prestosql.spi.type.TypeManager; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; @@ -51,6 +54,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.RejectedExecutionException; import java.util.function.Function; +import java.util.function.Supplier; import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkArgument; @@ -99,6 +103,7 @@ public class HiveSplitManager private final int maxSplitsPerSecond; private final boolean 
recursiveDfsWalkerEnabled; private final CounterStat highMemorySplitSourceCounter; + private final TypeManager typeManager; @Inject public HiveSplitManager( @@ -110,7 +115,8 @@ public HiveSplitManager( DirectoryLister directoryLister, ExecutorService executorService, VersionEmbedder versionEmbedder, - CoercionPolicy coercionPolicy) + CoercionPolicy coercionPolicy, + TypeManager typeManager) { this( metastoreProvider, @@ -128,7 +134,8 @@ public HiveSplitManager( hiveConfig.getMaxInitialSplits(), hiveConfig.getSplitLoaderConcurrency(), hiveConfig.getMaxSplitsPerSecond(), - hiveConfig.getRecursiveDirWalkerEnabled()); + hiveConfig.getRecursiveDirWalkerEnabled(), + typeManager); } public HiveSplitManager( @@ -147,7 +154,8 @@ public HiveSplitManager( int maxInitialSplits, int splitLoaderConcurrency, @Nullable Integer maxSplitsPerSecond, - boolean recursiveDfsWalkerEnabled) + boolean recursiveDfsWalkerEnabled, + TypeManager typeManager) { this.metastoreProvider = requireNonNull(metastoreProvider, "metastore is null"); this.partitionManager = requireNonNull(partitionManager, "partitionManager is null"); @@ -166,6 +174,7 @@ public HiveSplitManager( this.splitLoaderConcurrency = splitLoaderConcurrency; this.maxSplitsPerSecond = firstNonNull(maxSplitsPerSecond, Integer.MAX_VALUE); this.recursiveDfsWalkerEnabled = recursiveDfsWalkerEnabled; + this.typeManager = requireNonNull(typeManager, "typeManager is null"); } @Override @@ -174,6 +183,17 @@ public ConnectorSplitSource getSplits( ConnectorSession session, ConnectorTableHandle tableHandle, SplitSchedulingStrategy splitSchedulingStrategy) + { + return getSplits(transaction, session, tableHandle, splitSchedulingStrategy, TupleDomain::all); + } + + @Override + public ConnectorSplitSource getSplits( + ConnectorTransactionHandle transaction, + ConnectorSession session, + ConnectorTableHandle tableHandle, + SplitSchedulingStrategy splitSchedulingStrategy, + Supplier> dynamicFilter) { HiveTableHandle hiveTable = (HiveTableHandle) tableHandle; SchemaTableName tableName = hiveTable.getSchemaTableName(); @@ -215,6 +235,8 @@ public ConnectorSplitSource getSplits( table, hivePartitions, hiveTable.getCompactEffectivePredicate(), + dynamicFilter, + typeManager, createBucketSplitInfo(bucketHandle, bucketFilter), session, hdfsEnvironment, diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/util/InternalHiveSplitFactory.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/util/InternalHiveSplitFactory.java index 39c9f9fcd420d..babcf658731fa 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/util/InternalHiveSplitFactory.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/util/InternalHiveSplitFactory.java @@ -42,6 +42,7 @@ import java.util.Optional; import java.util.OptionalInt; import java.util.Properties; +import java.util.function.BooleanSupplier; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -59,6 +60,7 @@ public class InternalHiveSplitFactory private final List partitionKeys; private final Optional pathDomain; private final TableToPartitionMapping tableToPartitionMapping; + private final BooleanSupplier partitionMatchSupplier; private final Optional bucketConversion; private final boolean forceLocalScheduling; private final boolean s3SelectPushdownEnabled; @@ -70,6 +72,7 @@ public InternalHiveSplitFactory( Properties schema, List partitionKeys, TupleDomain effectivePredicate, + BooleanSupplier partitionMatchSupplier, 
TableToPartitionMapping tableToPartitionMapping, Optional bucketConversion, boolean forceLocalScheduling, @@ -81,6 +84,7 @@ public InternalHiveSplitFactory( this.schema = requireNonNull(schema, "schema is null"); this.partitionKeys = requireNonNull(partitionKeys, "partitionKeys is null"); pathDomain = getPathDomain(requireNonNull(effectivePredicate, "effectivePredicate is null")); + this.partitionMatchSupplier = requireNonNull(partitionMatchSupplier, "partitionMatchSupplier is null"); this.tableToPartitionMapping = requireNonNull(tableToPartitionMapping, "tableToPartitionMapping is null"); this.bucketConversion = requireNonNull(bucketConversion, "bucketConversion is null"); this.forceLocalScheduling = forceLocalScheduling; @@ -139,6 +143,12 @@ private Optional createInternalHiveSplit( return Optional.empty(); } + // Dynamic filter may not have been ready when partition was loaded in BackgroundHiveSplitLoader, + // but it might be ready when splits are enumerated lazily. + if (!partitionMatchSupplier.getAsBoolean()) { + return Optional.empty(); + } + boolean forceLocalScheduling = this.forceLocalScheduling; // For empty files, some filesystem (e.g. LocalFileSystem) produce one empty block diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHive.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHive.java index 69e39b2a3e5a6..5ddcb3b771f36 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHive.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHive.java @@ -786,7 +786,8 @@ protected final void setup(String databaseName, HiveConfig hiveConfig, HiveMetas hiveConfig.getMaxInitialSplits(), hiveConfig.getSplitLoaderConcurrency(), hiveConfig.getMaxSplitsPerSecond(), - false); + false, + TYPE_MANAGER); pageSinkProvider = new HivePageSinkProvider( getDefaultHiveFileWriterFactories(hiveConfig, hdfsEnvironment), hdfsEnvironment, diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveFileSystem.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveFileSystem.java index dee77e4b43c74..2ca4a88f899b9 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveFileSystem.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveFileSystem.java @@ -216,7 +216,8 @@ protected void setup(String host, int port, String databaseName, boolean s3Selec config.getMaxInitialSplits(), config.getSplitLoaderConcurrency(), config.getMaxSplitsPerSecond(), - config.getRecursiveDirWalkerEnabled()); + config.getRecursiveDirWalkerEnabled(), + TYPE_MANAGER); pageSinkProvider = new HivePageSinkProvider( getDefaultHiveFileWriterFactories(config, hdfsEnvironment), hdfsEnvironment, diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestBackgroundHiveSplitLoader.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestBackgroundHiveSplitLoader.java index 56cf814cd2843..8090e72008beb 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestBackgroundHiveSplitLoader.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestBackgroundHiveSplitLoader.java @@ -361,6 +361,8 @@ public HivePartitionMetadata next() } }, TupleDomain.all(), + TupleDomain::all, + TYPE_MANAGER, createBucketSplitInfo(Optional.empty(), Optional.empty()), SESSION, new TestingHdfsEnvironment(TEST_FILES), @@ -645,6 +647,8 @@ private static BackgroundHiveSplitLoader backgroundHiveSplitLoader( table, hivePartitionMetadatas, compactEffectivePredicate, + 
TupleDomain::all, + TYPE_MANAGER, createBucketSplitInfo(bucketHandle, hiveBucketFilter), SESSION, hdfsEnvironment, @@ -672,6 +676,8 @@ private static BackgroundHiveSplitLoader backgroundHiveSplitLoader(List 0, "scheduleSplitBatchSize must be greater than 0"); this.scheduleSplitBatchSize = scheduleSplitBatchSize; @@ -175,6 +181,15 @@ private SqlQueryExecution( // analyze query this.analysis = analyze(preparedQuery, stateMachine, metadata, accessControl, sqlParser, queryExplainer, warningCollector); + stateMachine.addStateChangeListener(state -> { + if (state == STARTING) { + dynamicFilterService.registerQuery(this); + } + else if (state.isDone()) { + dynamicFilterService.removeQuery(stateMachine.getQueryId()); + } + }); + // when the query finishes cache the final query info, and clear the reference to the output stage AtomicReference queryScheduler = this.queryScheduler; stateMachine.addStateChangeListener(state -> { @@ -409,7 +424,7 @@ private PlanRoot doPlanQuery() private void planDistribution(PlanRoot plan) { // plan the execution on the active nodes - DistributedExecutionPlanner distributedPlanner = new DistributedExecutionPlanner(splitManager, metadata); + DistributedExecutionPlanner distributedPlanner = new DistributedExecutionPlanner(splitManager, metadata, dynamicFilterService); StageExecutionPlan outputStageExecutionPlan = distributedPlanner.plan(plan.getRoot(), stateMachine.getSession()); // ensure split sources are closed @@ -553,6 +568,15 @@ public QueryInfo getQueryInfo() } } + public List getAllStages() + { + SqlQueryScheduler scheduler = queryScheduler.get(); + if (scheduler != null) { + return StageInfo.getAllStages(Optional.of(scheduler.getStageInfo())); + } + return ImmutableList.of(); + } + @Override public QueryState getState() { @@ -644,6 +668,7 @@ public static class SqlQueryExecutionFactory private final Map executionPolicies; private final StatsCalculator statsCalculator; private final CostCalculator costCalculator; + private final DynamicFilterService dynamicFilterService; @Inject SqlQueryExecutionFactory( @@ -666,7 +691,8 @@ public static class SqlQueryExecutionFactory Map executionPolicies, SplitSchedulerStats schedulerStats, StatsCalculator statsCalculator, - CostCalculator costCalculator) + CostCalculator costCalculator, + DynamicFilterService dynamicFilterService) { requireNonNull(config, "config is null"); this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null"); @@ -689,6 +715,7 @@ public static class SqlQueryExecutionFactory this.planOptimizers = requireNonNull(planOptimizers, "planOptimizers is null").get(); this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); + this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); } @Override @@ -726,6 +753,7 @@ public QueryExecution createQueryExecution( schedulerStats, statsCalculator, costCalculator, + dynamicFilterService, warningCollector); } } diff --git a/presto-main/src/main/java/io/prestosql/execution/SqlTask.java b/presto-main/src/main/java/io/prestosql/execution/SqlTask.java index de9d70faeba4b..90026a62daa65 100644 --- a/presto-main/src/main/java/io/prestosql/execution/SqlTask.java +++ b/presto-main/src/main/java/io/prestosql/execution/SqlTask.java @@ -14,6 +14,7 @@ package io.prestosql.execution; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import 
com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; @@ -34,6 +35,7 @@ import io.prestosql.operator.PipelineStatus; import io.prestosql.operator.TaskContext; import io.prestosql.operator.TaskStats; +import io.prestosql.spi.predicate.Domain; import io.prestosql.sql.planner.PlanFragment; import io.prestosql.sql.planner.plan.PlanNodeId; import org.joda.time.DateTime; @@ -42,6 +44,7 @@ import java.net.URI; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.OptionalInt; import java.util.Set; @@ -54,6 +57,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.util.concurrent.Futures.immediateFuture; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.units.DataSize.succinctBytes; @@ -241,8 +245,10 @@ private TaskStatus createTaskStatus(TaskHolder taskHolder) Set completedDriverGroups = ImmutableSet.of(); long fullGcCount = 0; Duration fullGcTime = new Duration(0, MILLISECONDS); + Map dynamicTupleDomains = ImmutableMap.of(); if (taskHolder.getFinalTaskInfo() != null) { - TaskStats taskStats = taskHolder.getFinalTaskInfo().getStats(); + TaskInfo taskInfo = taskHolder.getFinalTaskInfo(); + TaskStats taskStats = taskInfo.getStats(); queuedPartitionedDrivers = taskStats.getQueuedPartitionedDrivers(); runningPartitionedDrivers = taskStats.getRunningPartitionedDrivers(); physicalWrittenDataSize = taskStats.getPhysicalWrittenDataSize(); @@ -251,6 +257,7 @@ private TaskStatus createTaskStatus(TaskHolder taskHolder) revocableMemoryReservation = taskStats.getRevocableMemoryReservation(); fullGcCount = taskStats.getFullGcCount(); fullGcTime = taskStats.getFullGcTime(); + dynamicTupleDomains = taskInfo.getTaskStatus().getDynamicFilterDomains(); } else if (taskHolder.getTaskExecution() != null) { long physicalWrittenBytes = 0; @@ -268,7 +275,11 @@ else if (taskHolder.getTaskExecution() != null) { completedDriverGroups = taskContext.getCompletedDriverGroups(); fullGcCount = taskContext.getFullGcCount(); fullGcTime = taskContext.getFullGcTime(); + dynamicTupleDomains = taskContext.getDynamicTupleDomains(); } + // Compact TupleDomain before reporting dynamic filters to coordinator to avoid bloating QueryInfo + Map compactDynamicTupleDomains = dynamicTupleDomains.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().simplify())); return new TaskStatus(taskStateMachine.getTaskId(), taskInstanceId, @@ -286,7 +297,8 @@ else if (taskHolder.getTaskExecution() != null) { systemMemoryReservation, revocableMemoryReservation, fullGcCount, - fullGcTime); + fullGcTime, + compactDynamicTupleDomains); } private TaskStats getTaskStats(TaskHolder taskHolder) diff --git a/presto-main/src/main/java/io/prestosql/execution/StageState.java b/presto-main/src/main/java/io/prestosql/execution/StageState.java index 511b75ebd0e92..aed57b6cef5d0 100644 --- a/presto-main/src/main/java/io/prestosql/execution/StageState.java +++ b/presto-main/src/main/java/io/prestosql/execution/StageState.java @@ -88,4 +88,28 @@ public boolean isFailure() { return failureState; } + + public boolean canScheduleMoreTasks() + { + switch (this) { + case PLANNED: + case SCHEDULING: + // workers are still being added to the query + return true; + case 
SCHEDULING_SPLITS: + case SCHEDULED: + case RUNNING: + case FINISHED: + case CANCELED: + // no more workers will be added to the query + return false; + case ABORTED: + case FAILED: + // DO NOT complete a FAILED or ABORTED stage. This will cause the + // stage above to finish normally, which will result in a query + // completing successfully when it should fail.. + return true; + } + return true; + } } diff --git a/presto-main/src/main/java/io/prestosql/execution/TaskStatus.java b/presto-main/src/main/java/io/prestosql/execution/TaskStatus.java index 84ff4499b49ab..a5baa766dc649 100644 --- a/presto-main/src/main/java/io/prestosql/execution/TaskStatus.java +++ b/presto-main/src/main/java/io/prestosql/execution/TaskStatus.java @@ -16,12 +16,15 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.airlift.units.DataSize; import io.airlift.units.Duration; +import io.prestosql.spi.predicate.Domain; import java.net.URI; import java.util.List; +import java.util.Map; import java.util.Set; import static com.google.common.base.MoreObjects.toStringHelper; @@ -71,6 +74,8 @@ public class TaskStatus private final List failures; + private final Map dynamicFilterDomains; + @JsonCreator public TaskStatus( @JsonProperty("taskId") TaskId taskId, @@ -89,7 +94,8 @@ public TaskStatus( @JsonProperty("systemMemoryReservation") DataSize systemMemoryReservation, @JsonProperty("revocableMemoryReservation") DataSize revocableMemoryReservation, @JsonProperty("fullGcCount") long fullGcCount, - @JsonProperty("fullGcTime") Duration fullGcTime) + @JsonProperty("fullGcTime") Duration fullGcTime, + @JsonProperty("dynamicFilterDomains") Map dynamicFilterDomains) { this.taskId = requireNonNull(taskId, "taskId is null"); this.taskInstanceId = requireNonNull(taskInstanceId, "taskInstanceId is null"); @@ -119,6 +125,7 @@ public TaskStatus( checkArgument(fullGcCount >= 0, "fullGcCount is negative"); this.fullGcCount = fullGcCount; this.fullGcTime = requireNonNull(fullGcTime, "fullGcTime is null"); + this.dynamicFilterDomains = requireNonNull(dynamicFilterDomains, "dynamicFilterDomains is null"); } @JsonProperty @@ -223,6 +230,12 @@ public Duration getFullGcTime() return fullGcTime; } + @JsonProperty + public Map getDynamicFilterDomains() + { + return dynamicFilterDomains; + } + @Override public String toString() { @@ -251,7 +264,8 @@ public static TaskStatus initialTaskStatus(TaskId taskId, URI location, String n DataSize.ofBytes(0), DataSize.ofBytes(0), 0, - new Duration(0, MILLISECONDS)); + new Duration(0, MILLISECONDS), + ImmutableMap.of()); } public static TaskStatus failWith(TaskStatus taskStatus, TaskState state, List exceptions) @@ -273,6 +287,7 @@ public static TaskStatus failWith(TaskStatus taskStatus, TaskState state, List getChildStageIds() public void processScheduleResults(StageState newState, Set newTasks) { - boolean noMoreTasks = false; - switch (newState) { - case PLANNED: - case SCHEDULING: - // workers are still being added to the query - break; - case SCHEDULING_SPLITS: - case SCHEDULED: - case RUNNING: - case FINISHED: - case CANCELED: - // no more workers will be added to the query - noMoreTasks = true; - case ABORTED: - case FAILED: - // DO NOT complete a FAILED or ABORTED stage. 
This will cause the - // stage above to finish normally, which will result in a query - // completing successfully when it should fail.. - break; - } - + boolean noMoreTasks = !newState.canScheduleMoreTasks(); // Add an exchange location to the parent stage for each new task parent.addExchangeLocations(currentStageFragmentId, newTasks, noMoreTasks); diff --git a/presto-main/src/main/java/io/prestosql/operator/TaskContext.java b/presto-main/src/main/java/io/prestosql/operator/TaskContext.java index 48865b0c5a739..a72d3a1926acc 100644 --- a/presto-main/src/main/java/io/prestosql/operator/TaskContext.java +++ b/presto-main/src/main/java/io/prestosql/operator/TaskContext.java @@ -15,6 +15,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.AtomicDouble; import com.google.common.util.concurrent.ListenableFuture; @@ -32,12 +33,15 @@ import io.prestosql.memory.QueryContextVisitor; import io.prestosql.memory.context.LocalMemoryContext; import io.prestosql.memory.context.MemoryTrackingContext; +import io.prestosql.spi.predicate.Domain; import org.joda.time.DateTime; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.OptionalInt; import java.util.Set; import java.util.concurrent.CopyOnWriteArrayList; @@ -101,6 +105,9 @@ public class TaskContext private final MemoryTrackingContext taskMemoryContext; + @GuardedBy("this") + private Map dynamicTupleDomains = new HashMap<>(); + public static TaskContext createTaskContext( QueryContext queryContext, TaskStateMachine taskStateMachine, @@ -374,6 +381,18 @@ public int getFullGcCount() return toIntExact(max(0, endFullGcCount - startFullGcCount)); } + public synchronized void collectDynamicTupleDomain(Map dynamicFilterDomains) + { + for (Map.Entry entry : dynamicFilterDomains.entrySet()) { + dynamicTupleDomains.merge(entry.getKey(), entry.getValue(), Domain::intersect); + } + } + + public synchronized Map getDynamicTupleDomains() + { + return ImmutableMap.copyOf(dynamicTupleDomains); + } + public TaskStats getTaskStats() { // check for end state to avoid callback ordering problems diff --git a/presto-main/src/main/java/io/prestosql/server/CoordinatorModule.java b/presto-main/src/main/java/io/prestosql/server/CoordinatorModule.java index 8064d13384ff9..2e253e42af2e2 100644 --- a/presto-main/src/main/java/io/prestosql/server/CoordinatorModule.java +++ b/presto-main/src/main/java/io/prestosql/server/CoordinatorModule.java @@ -273,6 +273,9 @@ protected void setup(Binder binder) binder.bind(CostCalculator.class).annotatedWith(EstimatedExchanges.class).to(CostCalculatorWithEstimatedExchanges.class).in(Scopes.SINGLETON); binder.bind(CostComparator.class).in(Scopes.SINGLETON); + // dynamic filtering service + binder.bind(DynamicFilterService.class).in(Scopes.SINGLETON); + // planner binder.bind(PlanFragmenter.class).in(Scopes.SINGLETON); binder.bind(PlanOptimizers.class).in(Scopes.SINGLETON); diff --git a/presto-main/src/main/java/io/prestosql/server/DynamicFilterService.java b/presto-main/src/main/java/io/prestosql/server/DynamicFilterService.java new file mode 100644 index 0000000000000..9e83969d915d3 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/server/DynamicFilterService.java @@ -0,0 +1,225 @@ +/* + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.server; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; +import io.prestosql.execution.SqlQueryExecution; +import io.prestosql.execution.StageInfo; +import io.prestosql.execution.StageState; +import io.prestosql.execution.TaskInfo; +import io.prestosql.execution.TaskManagerConfig; +import io.prestosql.spi.QueryId; +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.predicate.Domain; +import io.prestosql.spi.predicate.TupleDomain; +import io.prestosql.sql.DynamicFilters; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.optimizations.PlanNodeSearcher; +import io.prestosql.sql.planner.plan.JoinNode; + +import javax.annotation.PreDestroy; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.Immutable; +import javax.annotation.concurrent.ThreadSafe; +import javax.inject.Inject; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ScheduledExecutorService; +import java.util.function.Supplier; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.stream.Collectors.groupingBy; +import static java.util.stream.Collectors.mapping; +import static java.util.stream.Collectors.toList; + +@ThreadSafe +public class DynamicFilterService +{ + private final Map dynamicFilterSummaries = new ConcurrentHashMap<>(); + + @GuardedBy("this") + private final Map>> queries = new HashMap<>(); + + private final ScheduledExecutorService collectDynamicFiltersExecutor = newSingleThreadScheduledExecutor(daemonThreadsNamed("DynamicFilterService")); + + @Inject + public DynamicFilterService(TaskManagerConfig taskConfig) + { + collectDynamicFiltersExecutor.scheduleWithFixedDelay(this::collectDynamicFilters, 0, taskConfig.getStatusRefreshMaxWait().toMillis(), MILLISECONDS); + } + + @PreDestroy + public void stop() + { + collectDynamicFiltersExecutor.shutdownNow(); + } + + public void registerQuery(SqlQueryExecution sqlQueryExecution) + { + // register query only if it contains dynamic filters + boolean hasDynamicFilters = PlanNodeSearcher.searchFrom(sqlQueryExecution.getQueryPlan().getRoot()) + .where(node -> node instanceof JoinNode && !((JoinNode) node).getDynamicFilters().isEmpty()) + .matches(); + if (hasDynamicFilters) { + registerQuery(sqlQueryExecution.getQueryId(), sqlQueryExecution::getAllStages); + } + } + + @VisibleForTesting + public synchronized void registerQuery(QueryId queryId, 
Supplier> stageInfoSupplier) + { + queries.putIfAbsent(queryId, stageInfoSupplier); + } + + public synchronized void removeQuery(QueryId queryId) + { + dynamicFilterSummaries.keySet().removeIf(sourceDescriptor -> sourceDescriptor.getQueryId().equals(queryId)); + queries.remove(queryId); + } + + @VisibleForTesting + public synchronized void collectDynamicFilters() + { + for (Map.Entry>> entry : queries.entrySet()) { + QueryId queryId = entry.getKey(); + for (StageInfo stageInfo : entry.getValue().get()) { + StageState stageState = stageInfo.getState(); + // wait until stage has finished scheduling tasks + if (stageState.canScheduleMoreTasks()) { + continue; + } + List tasks = stageInfo.getTasks(); + Map> stageDynamicFilterDomains = tasks.stream() + .map(taskInfo -> taskInfo.getTaskStatus().getDynamicFilterDomains()) + .flatMap(taskDomains -> taskDomains.entrySet().stream()) + .collect(groupingBy(Map.Entry::getKey, mapping(Map.Entry::getValue, toList()))); + + stageDynamicFilterDomains.entrySet().stream() + // check if all tasks of a dynamic filter source have reported dynamic filter summary + .filter(stageDomains -> stageDomains.getValue().size() == tasks.size()) + .forEach(stageDomains -> dynamicFilterSummaries.put( + SourceDescriptor.of(queryId, stageDomains.getKey()), + Domain.union(stageDomains.getValue()))); + } + } + } + + public Supplier> createDynamicFilterSupplier(QueryId queryId, List dynamicFilters, Map columnHandles) + { + Map sourceColumnHandles = extractSourceColumnHandles(dynamicFilters, columnHandles); + + return () -> dynamicFilters.stream() + .map(filter -> getSummary(queryId, filter.getId()) + .map(summary -> translateSummaryToTupleDomain(filter.getId(), summary, sourceColumnHandles))) + .filter(Optional::isPresent) + .map(Optional::get) + .reduce(TupleDomain.all(), TupleDomain::intersect); + } + + @VisibleForTesting + Optional getSummary(QueryId queryId, String filterId) + { + return Optional.ofNullable(dynamicFilterSummaries.get(SourceDescriptor.of(queryId, filterId))); + } + + @Immutable + private static class SourceDescriptor + { + private final QueryId queryId; + private final String filterId; + + public static SourceDescriptor of(QueryId queryId, String filterId) + { + return new SourceDescriptor(queryId, filterId); + } + + private SourceDescriptor(QueryId queryId, String filterId) + { + this.queryId = requireNonNull(queryId, "queryId is null"); + this.filterId = requireNonNull(filterId, "filterId is null"); + } + + public QueryId getQueryId() + { + return queryId; + } + + public String getFilterId() + { + return filterId; + } + + @Override + public boolean equals(Object other) + { + if (other == this) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + + SourceDescriptor sourceDescriptor = (SourceDescriptor) other; + + return Objects.equals(queryId, sourceDescriptor.queryId) && + Objects.equals(filterId, sourceDescriptor.filterId); + } + + @Override + public int hashCode() + { + return Objects.hash(queryId, filterId); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("queryId", queryId) + .add("filterId", filterId) + .toString(); + } + } + + private static TupleDomain translateSummaryToTupleDomain(String filterId, Domain summary, Map sourceColumnHandles) + { + if (summary.isNone()) { + return TupleDomain.none(); + } + ColumnHandle sourceColumnHandle = requireNonNull(sourceColumnHandles.get(filterId), () -> format("Source column handle for dynamic filter %s is null", filterId)); + 
return TupleDomain.withColumnDomains(ImmutableMap.builder() + .put(sourceColumnHandle, summary) + .build()); + } + + private static Map extractSourceColumnHandles(List dynamicFilters, Map columnHandles) + { + return dynamicFilters.stream() + .collect(toImmutableMap( + DynamicFilters.Descriptor::getId, + descriptor -> columnHandles.get(Symbol.from(descriptor.getInput())))); + } +} diff --git a/presto-main/src/main/java/io/prestosql/split/SplitManager.java b/presto-main/src/main/java/io/prestosql/split/SplitManager.java index cc2e198b09d1a..c8361a113f60c 100644 --- a/presto-main/src/main/java/io/prestosql/split/SplitManager.java +++ b/presto-main/src/main/java/io/prestosql/split/SplitManager.java @@ -18,18 +18,21 @@ import io.prestosql.execution.QueryManagerConfig; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableHandle; +import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ConnectorSession; import io.prestosql.spi.connector.ConnectorSplitManager; import io.prestosql.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy; import io.prestosql.spi.connector.ConnectorSplitSource; import io.prestosql.spi.connector.ConnectorTableLayoutHandle; import io.prestosql.spi.connector.Constraint; +import io.prestosql.spi.predicate.TupleDomain; import javax.inject.Inject; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.function.Supplier; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -65,7 +68,7 @@ public void removeConnectorSplitManager(CatalogName catalogName) splitManagers.remove(catalogName); } - public SplitSource getSplits(Session session, TableHandle table, SplitSchedulingStrategy splitSchedulingStrategy) + public SplitSource getSplits(Session session, TableHandle table, SplitSchedulingStrategy splitSchedulingStrategy, Supplier> dynamicFilter) { CatalogName catalogName = table.getCatalogName(); ConnectorSplitManager splitManager = getConnectorSplitManager(catalogName); @@ -83,7 +86,12 @@ public SplitSource getSplits(Session session, TableHandle table, SplitScheduling source = splitManager.getSplits(table.getTransaction(), connectorSession, layout, splitSchedulingStrategy); } else { - source = splitManager.getSplits(table.getTransaction(), connectorSession, table.getConnectorHandle(), splitSchedulingStrategy); + source = splitManager.getSplits( + table.getTransaction(), + connectorSession, + table.getConnectorHandle(), + splitSchedulingStrategy, + dynamicFilter); } SplitSource splitSource = new ConnectorAwareSplitSource(catalogName, source); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/DistributedExecutionPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/DistributedExecutionPlanner.java index 3c71ffb578841..26e42e767772f 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/DistributedExecutionPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/DistributedExecutionPlanner.java @@ -22,6 +22,9 @@ import io.prestosql.metadata.TableMetadata; import io.prestosql.metadata.TableProperties; import io.prestosql.operator.StageExecutionDescriptor; +import io.prestosql.server.DynamicFilterService; +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.split.SampledSplitSource; import io.prestosql.split.SplitManager; import io.prestosql.split.SplitSource; @@ -67,6 
+70,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Supplier; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.Iterables.getOnlyElement; @@ -81,12 +85,14 @@ public class DistributedExecutionPlanner private final SplitManager splitManager; private final Metadata metadata; + private final DynamicFilterService dynamicFilterService; @Inject - public DistributedExecutionPlanner(SplitManager splitManager, Metadata metadata) + public DistributedExecutionPlanner(SplitManager splitManager, Metadata metadata, DynamicFilterService dynamicFilterService) { this.splitManager = requireNonNull(splitManager, "splitManager is null"); this.metadata = requireNonNull(metadata, "metadata is null"); + this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); } public StageExecutionPlan plan(SubPlan root, Session session) @@ -180,16 +186,18 @@ private Map visitScanAndFilter(TableScanNode node, Opti .map(DynamicFilters.ExtractResult::getDynamicConjuncts) .orElse(ImmutableList.of()); - // TODO: Execution must be plugged in here + Supplier> dynamicFilterSupplier = TupleDomain::all; if (!dynamicFilters.isEmpty()) { log.debug("Dynamic filters: %s", dynamicFilters); + dynamicFilterSupplier = dynamicFilterService.createDynamicFilterSupplier(session.getQueryId(), dynamicFilters, node.getAssignments()); } // get dataSource for table SplitSource splitSource = splitManager.getSplits( session, node.getTable(), - stageExecutionDescriptor.isScanGroupedExecution(node.getId()) ? GROUPED_SCHEDULING : UNGROUPED_SCHEDULING); + stageExecutionDescriptor.isScanGroupedExecution(node.getId()) ? GROUPED_SCHEDULING : UNGROUPED_SCHEDULING, + dynamicFilterSupplier); splitSources.add(splitSource); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LocalDynamicFilter.java b/presto-main/src/main/java/io/prestosql/sql/planner/LocalDynamicFilter.java index c8515c63e845a..06ecb294ceefa 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LocalDynamicFilter.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LocalDynamicFilter.java @@ -16,11 +16,13 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Multimap; +import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.airlift.log.Logger; import io.prestosql.spi.predicate.Domain; import io.prestosql.spi.predicate.TupleDomain; +import io.prestosql.spi.type.Type; import io.prestosql.sql.DynamicFilters; import io.prestosql.sql.planner.optimizations.PlanNodeSearcher; import io.prestosql.sql.planner.plan.FilterNode; @@ -32,13 +34,14 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.function.Consumer; import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.prestosql.sql.DynamicFilters.Descriptor; import static io.prestosql.sql.DynamicFilters.extractDynamicFilters; import static java.util.Objects.requireNonNull; @@ -54,9 +57,10 @@ public class 
LocalDynamicFilter // Mapping from dynamic filter ID to its build channel indices. private final Map buildChannels; - private final TypeProvider types; + // Mapping from dynamic filter ID to its build channel type. + private final Map filterBuildTypes; - private final SettableFuture> resultFuture; + private final SettableFuture> resultFuture; // Number of build-side partitions to be collected. private final int partitionCount; @@ -64,12 +68,14 @@ public class LocalDynamicFilter // The resulting predicates from each build-side partition. private final List> partitions; - public LocalDynamicFilter(Multimap probeSymbols, Map buildChannels, TypeProvider types, int partitionCount) + public LocalDynamicFilter(Multimap probeSymbols, Map buildChannels, Map filterBuildTypes, int partitionCount) { this.probeSymbols = requireNonNull(probeSymbols, "probeSymbols is null"); this.buildChannels = requireNonNull(buildChannels, "buildChannels is null"); - verify(probeSymbols.keySet().equals(buildChannels.keySet()), "probeSymbols and buildChannels must have same keys"); - this.types = requireNonNull(types, "types is null"); + verify(buildChannels.keySet().containsAll(probeSymbols.keySet()), "probeSymbols should be subset of buildChannels"); + + this.filterBuildTypes = requireNonNull(filterBuildTypes, "filterBuildTypes is null"); + verify(buildChannels.keySet().equals(filterBuildTypes.keySet()), "filterBuildTypes and buildChannels must have same keys"); this.resultFuture = SettableFuture.create(); @@ -77,6 +83,16 @@ public LocalDynamicFilter(Multimap probeSymbols, Map(partitionCount); } + public ListenableFuture> getDynamicFilterDomains() + { + return Futures.transform(resultFuture, this::convertTupleDomain, directExecutor()); + } + + public ListenableFuture> getNodeLocalDynamicFilterForSymbols() + { + return Futures.transform(resultFuture, this::convertTupleDomainForLocalFilters, directExecutor()); + } + private synchronized void addPartition(TupleDomain tupleDomain) { // Called concurrently by each DynamicFilterSourceOperator instance (when collection is over). @@ -85,19 +101,24 @@ private synchronized void addPartition(TupleDomain tupleDomain) // See the comment at TupleDomain::columnWiseUnion() for more details. partitions.add(tupleDomain); if (partitions.size() == partitionCount) { - Map result = convertTupleDomain(TupleDomain.columnWiseUnion(partitions)); + TupleDomain result = TupleDomain.columnWiseUnion(partitions); // No more partitions are left to be processed. resultFuture.set(result); } } - private Map convertTupleDomain(TupleDomain result) + private Map convertTupleDomainForLocalFilters(TupleDomain result) { if (result.isNone()) { // One of the join build symbols has no non-null values, therefore no symbols can match predicate - return buildChannels.keySet().stream() - .flatMap(filterId -> probeSymbols.get(filterId).stream()) - .collect(toImmutableMap(identity(), probeSymbol -> Domain.none(types.get(probeSymbol)))); + ImmutableMap.Builder builder = ImmutableMap.builder(); + for (Map.Entry entry : filterBuildTypes.entrySet()) { + // Store `none` domain explicitly for each probe symbol + for (Symbol probeSymbol : probeSymbols.get(entry.getKey())) { + builder.put(probeSymbol, Domain.none(entry.getValue())); + } + } + return builder.build(); } // Convert the predicate to use probe symbols (instead dynamic filter IDs). // Note that in case of a probe-side union, a single dynamic filter may match multiple probe symbols. 
@@ -112,8 +133,20 @@ private Map convertTupleDomain(TupleDomain result) return builder.build(); } - public static Optional create(JoinNode planNode, TypeProvider types, int partitionCount) + private Map convertTupleDomain(TupleDomain result) { + if (result.isNone()) { + // One of the join build symbols has no non-null values, therefore no filters can match predicate + return buildChannels.keySet().stream() + .collect(toImmutableMap(identity(), filterId -> Domain.none(filterBuildTypes.get(filterId)))); + } + return result.getDomains().get(); + } + + public static LocalDynamicFilter create(JoinNode planNode, List buildSourceTypes, int partitionCount) + { + checkArgument(!planNode.getDynamicFilters().isEmpty(), "Join node dynamicFilters is empty."); + Set joinDynamicFilters = planNode.getDynamicFilters().keySet(); List filterNodes = PlanNodeSearcher .searchFrom(planNode.getLeft()) @@ -138,9 +171,8 @@ public static Optional create(JoinNode planNode, TypeProvide Multimap probeSymbols = probeSymbolsBuilder.build(); PlanNode buildNode = planNode.getRight(); + // Collect dynamic filters for all dynamic filters produced by join Map buildChannels = planNode.getDynamicFilters().entrySet().stream() - // Skip build channels that don't match local probe dynamic filters. - .filter(entry -> probeSymbols.containsKey(entry.getKey())) .collect(toImmutableMap( // Dynamic filter ID Map.Entry::getKey, @@ -152,10 +184,11 @@ public static Optional create(JoinNode planNode, TypeProvide return buildChannelIndex; })); - if (buildChannels.isEmpty()) { - return Optional.empty(); - } - return Optional.of(new LocalDynamicFilter(probeSymbols, buildChannels, types, partitionCount)); + Map filterBuildTypes = buildChannels.entrySet().stream() + .collect(toImmutableMap( + Map.Entry::getKey, + entry -> buildSourceTypes.get(entry.getValue()))); + return new LocalDynamicFilter(probeSymbols, buildChannels, filterBuildTypes, partitionCount); } private static boolean isFilterAboveTableScan(PlanNode node) @@ -173,11 +206,6 @@ public Map getBuildChannels() return buildChannels; } - public ListenableFuture> getResultFuture() - { - return resultFuture; - } - public Consumer> getTupleDomainConsumer() { return this::addPartition; diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java index 45e15e5a1b9ed..23a621309981b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java @@ -124,6 +124,7 @@ import io.prestosql.spi.connector.ConnectorIndex; import io.prestosql.spi.connector.ConnectorSession; import io.prestosql.spi.connector.RecordSet; +import io.prestosql.spi.predicate.Domain; import io.prestosql.spi.predicate.NullableValue; import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.spi.type.Type; @@ -607,6 +608,11 @@ public LocalDynamicFiltersCollector getDynamicFiltersCollector() return dynamicFiltersCollector; } + private void addDynamicFilter(Map dynamicTupleDomain) + { + taskContext.collectDynamicTupleDomain(dynamicTupleDomain); + } + public Optional getIndexSourceContext() { return indexSourceContext; @@ -2151,14 +2157,12 @@ private Optional createDynamicFilter(PhysicalOperation build "Dynamic filtering cannot be used with grouped execution"); log.debug("[Join] Dynamic filters: %s", node.getDynamicFilters()); LocalDynamicFiltersCollector collector = context.getDynamicFiltersCollector(); - 
return LocalDynamicFilter - .create(node, context.getTypes(), partitionCount) - .map(filter -> { - // Intersect dynamic filters' predicates when they become ready, - // in order to support multiple join nodes in the same plan fragment. - addSuccessCallback(filter.getResultFuture(), collector::addDynamicFilter); - return filter; - }); + LocalDynamicFilter filter = LocalDynamicFilter.create(node, buildSource.getTypes(), partitionCount); + // Intersect dynamic filters' predicates when they become ready, + // in order to support multiple join nodes in the same plan fragment. + addSuccessCallback(filter.getDynamicFilterDomains(), context::addDynamicFilter); + addSuccessCallback(filter.getNodeLocalDynamicFilterForSymbols(), collector::addDynamicFilter); + return Optional.of(filter); } private JoinFilterFunctionFactory compileJoinFilterFunction( diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java index 8b2b9792c1f46..4c633330d6a79 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java @@ -465,7 +465,7 @@ private static KdbTree loadKdbTree(String tableName, Session session, Metadata m ColumnHandle kdbTreeColumn = Iterables.getOnlyElement(visibleColumnHandles); Optional kdbTree = Optional.empty(); - try (SplitSource splitSource = splitManager.getSplits(session, tableHandle, UNGROUPED_SCHEDULING)) { + try (SplitSource splitSource = splitManager.getSplits(session, tableHandle, UNGROUPED_SCHEDULING, TupleDomain::all)) { while (!Thread.currentThread().isInterrupted()) { SplitBatch splitBatch = getFutureValue(splitSource.getNextBatch(NOT_PARTITIONED, Lifespan.taskWide(), 1000)); List splits = splitBatch.getSplits(); diff --git a/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java b/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java index 3ce5bb4806176..143e5e7ea7b2a 100644 --- a/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/io/prestosql/testing/LocalQueryRunner.java @@ -113,6 +113,7 @@ import io.prestosql.spi.PageSorter; import io.prestosql.spi.Plugin; import io.prestosql.spi.connector.ConnectorFactory; +import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.spi.session.PropertyMetadata; import io.prestosql.spiller.FileSingleStreamSpillerFactory; import io.prestosql.spiller.GenericPartitioningSpillerFactory; @@ -748,7 +749,8 @@ private List createDrivers(Session session, Plan plan, OutputFactory out SplitSource splitSource = splitManager.getSplits( session, table, - stageExecutionDescriptor.isScanGroupedExecution(tableScan.getId()) ? GROUPED_SCHEDULING : UNGROUPED_SCHEDULING); + stageExecutionDescriptor.isScanGroupedExecution(tableScan.getId()) ? 
GROUPED_SCHEDULING : UNGROUPED_SCHEDULING, + TupleDomain::all); ImmutableSet.Builder scheduledSplits = ImmutableSet.builder(); while (!splitSource.isFinished()) { diff --git a/presto-main/src/test/java/io/prestosql/execution/MockRemoteTaskFactory.java b/presto-main/src/test/java/io/prestosql/execution/MockRemoteTaskFactory.java index 043b72d340bd8..cba3ec0aa47e0 100644 --- a/presto-main/src/test/java/io/prestosql/execution/MockRemoteTaskFactory.java +++ b/presto-main/src/test/java/io/prestosql/execution/MockRemoteTaskFactory.java @@ -248,7 +248,8 @@ public TaskInfo getTaskInfo() DataSize.ofBytes(0), DataSize.ofBytes(0), 0, - new Duration(0, MILLISECONDS)), + new Duration(0, MILLISECONDS), + ImmutableMap.of()), DateTime.now(), outputBuffer.getInfo(), ImmutableSet.of(), @@ -276,7 +277,8 @@ public TaskStatus getTaskStatus() stats.getSystemMemoryReservation(), stats.getRevocableMemoryReservation(), 0, - new Duration(0, MILLISECONDS)); + new Duration(0, MILLISECONDS), + ImmutableMap.of()); } private synchronized void updateSplitQueueSpace() diff --git a/presto-main/src/test/java/io/prestosql/server/TestDynamicFilterService.java b/presto-main/src/test/java/io/prestosql/server/TestDynamicFilterService.java new file mode 100644 index 0000000000000..208f7c0dd4054 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/server/TestDynamicFilterService.java @@ -0,0 +1,392 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.server; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.airlift.stats.Distribution; +import io.airlift.units.DataSize; +import io.airlift.units.Duration; +import io.prestosql.cost.StatsAndCosts; +import io.prestosql.execution.StageId; +import io.prestosql.execution.StageInfo; +import io.prestosql.execution.StageState; +import io.prestosql.execution.StageStats; +import io.prestosql.execution.TaskId; +import io.prestosql.execution.TaskInfo; +import io.prestosql.execution.TaskManagerConfig; +import io.prestosql.execution.TaskStatus; +import io.prestosql.operator.TaskStats; +import io.prestosql.spi.QueryId; +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.connector.TestingColumnHandle; +import io.prestosql.spi.eventlistener.StageGcStatistics; +import io.prestosql.spi.predicate.Domain; +import io.prestosql.spi.predicate.TupleDomain; +import io.prestosql.sql.DynamicFilters; +import io.prestosql.sql.planner.Partitioning; +import io.prestosql.sql.planner.PartitioningScheme; +import io.prestosql.sql.planner.PlanFragment; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.plan.JoinNode; +import io.prestosql.sql.planner.plan.PlanFragmentId; +import io.prestosql.sql.planner.plan.PlanNodeId; +import io.prestosql.sql.planner.plan.TableScanNode; +import io.prestosql.sql.tree.Expression; +import org.joda.time.DateTime; +import org.testng.annotations.Test; + +import java.net.URI; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Supplier; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.units.DataSize.Unit.BYTE; +import static io.prestosql.execution.TaskInfo.createInitialTask; +import static io.prestosql.operator.StageExecutionDescriptor.ungroupedExecution; +import static io.prestosql.spi.predicate.Domain.multipleValues; +import static io.prestosql.spi.predicate.Domain.none; +import static io.prestosql.spi.predicate.Domain.singleValue; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static io.prestosql.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; +import static io.prestosql.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; +import static io.prestosql.sql.planner.iterative.rule.test.PlanBuilder.expression; +import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; +import static io.prestosql.testing.TestingHandles.TEST_TABLE_HANDLE; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class TestDynamicFilterService +{ + @Test + public void testDynamicFilterSummaryCompletion() + { + DynamicFilterService dynamicFilterService = new DynamicFilterService(new TaskManagerConfig()); + String filterId = "df"; + QueryId queryId = new QueryId("query"); + StageId stageId = new StageId(queryId, 0); + List taskIds = ImmutableList.of(new TaskId(stageId, 0), new TaskId(stageId, 1), new TaskId(stageId, 2)); + + assertFalse(dynamicFilterService.getSummary(queryId, filterId).isPresent()); + TestDynamicFiltersStageSupplier dynamicFiltersStageSupplier = new TestDynamicFiltersStageSupplier(); + dynamicFiltersStageSupplier.addDynamicFilter(filterId, taskIds, 
"probeColumn"); + dynamicFilterService.registerQuery(queryId, dynamicFiltersStageSupplier); + assertFalse(dynamicFilterService.getSummary(queryId, filterId).isPresent()); + + dynamicFiltersStageSupplier.storeSummary( + filterId, + new TaskId(stageId, 0), + singleValue(INTEGER, 1L)); + dynamicFilterService.collectDynamicFilters(); + assertFalse(dynamicFilterService.getSummary(queryId, filterId).isPresent()); + + dynamicFiltersStageSupplier.storeSummary( + filterId, + new TaskId(stageId, 1), + singleValue(INTEGER, 2L)); + dynamicFilterService.collectDynamicFilters(); + assertFalse(dynamicFilterService.getSummary(queryId, filterId).isPresent()); + + dynamicFiltersStageSupplier.storeSummary( + filterId, + new TaskId(stageId, 2), + singleValue(INTEGER, 3L)); + dynamicFilterService.collectDynamicFilters(); + Optional summary = dynamicFilterService.getSummary(queryId, filterId); + assertTrue(summary.isPresent()); + assertEquals(summary.get(), multipleValues(INTEGER, ImmutableList.of(1L, 2L, 3L))); + } + + @Test + public void testDynamicFilterSupplier() + { + DynamicFilterService dynamicFilterService = new DynamicFilterService(new TaskManagerConfig()); + String filterId1 = "df1"; + String filterId2 = "df2"; + String filterId3 = "df3"; + Expression df1 = expression("DF_SYMBOL1"); + Expression df2 = expression("DF_SYMBOL2"); + Expression df3 = expression("DF_SYMBOL3"); + QueryId queryId = new QueryId("query"); + StageId stageId1 = new StageId(queryId, 1); + StageId stageId2 = new StageId(queryId, 2); + StageId stageId3 = new StageId(queryId, 3); + + Supplier> dynamicFilterSupplier = dynamicFilterService.createDynamicFilterSupplier( + queryId, + ImmutableList.of( + new DynamicFilters.Descriptor(filterId1, df1), + new DynamicFilters.Descriptor(filterId2, df2), + new DynamicFilters.Descriptor(filterId3, df3)), + ImmutableMap.of( + Symbol.from(df1), new TestingColumnHandle("probeColumnA"), + Symbol.from(df2), new TestingColumnHandle("probeColumnA"), + Symbol.from(df3), new TestingColumnHandle("probeColumnB"))); + + assertTrue(dynamicFilterSupplier.get().isAll()); + TestDynamicFiltersStageSupplier dynamicFiltersStageSupplier = new TestDynamicFiltersStageSupplier(); + + List taskIds = ImmutableList.of(new TaskId(stageId1, 0), new TaskId(stageId1, 1)); + dynamicFiltersStageSupplier.addDynamicFilter(filterId1, taskIds, "probeColumnA"); + + taskIds = ImmutableList.of(new TaskId(stageId2, 0), new TaskId(stageId2, 1)); + dynamicFiltersStageSupplier.addDynamicFilter(filterId2, taskIds, "probeColumnA"); + + taskIds = ImmutableList.of(new TaskId(stageId3, 0), new TaskId(stageId3, 1)); + dynamicFiltersStageSupplier.addDynamicFilter(filterId3, taskIds, "probeColumnB"); + + dynamicFilterService.registerQuery(queryId, dynamicFiltersStageSupplier); + assertTrue(dynamicFilterSupplier.get().isAll()); + + dynamicFiltersStageSupplier.storeSummary( + filterId1, + new TaskId(stageId1, 0), + singleValue(INTEGER, 1L)); + dynamicFilterService.collectDynamicFilters(); + assertTrue(dynamicFilterSupplier.get().isAll()); + + dynamicFiltersStageSupplier.storeSummary( + filterId1, + new TaskId(stageId1, 1), + singleValue(INTEGER, 2L)); + dynamicFilterService.collectDynamicFilters(); + assertEquals(dynamicFilterSupplier.get(), TupleDomain.withColumnDomains(ImmutableMap.of( + new TestingColumnHandle("probeColumnA"), + multipleValues(INTEGER, ImmutableList.of(1L, 2L))))); + + dynamicFiltersStageSupplier.storeSummary( + filterId2, + new TaskId(stageId2, 0), + singleValue(INTEGER, 2L)); + dynamicFilterService.collectDynamicFilters(); + 
assertEquals(dynamicFilterSupplier.get(), TupleDomain.withColumnDomains(ImmutableMap.of( + new TestingColumnHandle("probeColumnA"), + multipleValues(INTEGER, ImmutableList.of(1L, 2L))))); + + dynamicFiltersStageSupplier.storeSummary( + filterId2, + new TaskId(stageId2, 1), + singleValue(INTEGER, 3L)); + dynamicFilterService.collectDynamicFilters(); + assertEquals(dynamicFilterSupplier.get(), TupleDomain.withColumnDomains(ImmutableMap.of( + new TestingColumnHandle("probeColumnA"), + singleValue(INTEGER, 2L)))); + + dynamicFiltersStageSupplier.storeSummary( + filterId3, + new TaskId(stageId3, 0), + none(INTEGER)); + dynamicFilterService.collectDynamicFilters(); + assertEquals(dynamicFilterSupplier.get(), TupleDomain.withColumnDomains(ImmutableMap.of( + new TestingColumnHandle("probeColumnA"), + singleValue(INTEGER, 2L)))); + + dynamicFiltersStageSupplier.storeSummary( + filterId3, + new TaskId(stageId3, 1), + none(INTEGER)); + dynamicFilterService.collectDynamicFilters(); + assertEquals(dynamicFilterSupplier.get(), TupleDomain.none()); + } + + private static class TestDynamicFiltersStageSupplier + implements Supplier> + { + private static final StageStats TEST_STAGE_STATS = new StageStats( + new DateTime(0), + + new Distribution(0).snapshot(), + + 4, + 5, + 6, + + 7, + 8, + 10, + 26, + 11, + + 12.0, + DataSize.of(13, BYTE), + DataSize.of(14, BYTE), + DataSize.of(15, BYTE), + DataSize.of(16, BYTE), + DataSize.of(17, BYTE), + + new Duration(15, NANOSECONDS), + new Duration(16, NANOSECONDS), + new Duration(18, NANOSECONDS), + false, + ImmutableSet.of(), + + DataSize.of(191, BYTE), + 201, + new Duration(19, NANOSECONDS), + + DataSize.of(192, BYTE), + 202, + + DataSize.of(19, BYTE), + 20, + + DataSize.of(21, BYTE), + 22, + + DataSize.of(23, BYTE), + DataSize.of(24, BYTE), + 25, + + DataSize.of(26, BYTE), + + new StageGcStatistics( + 101, + 102, + 103, + 104, + 105, + 106, + 107), + + ImmutableList.of()); + + private final Map probes = new HashMap<>(); + private final Map stagesInfo = new HashMap<>(); + + void addDynamicFilter(String filterId, List taskIds, String probeColumnName) + { + String colName = "column" + filterId; + Symbol buildSymbol = new Symbol(colName); + TableScanNode build = TableScanNode.newInstance( + new PlanNodeId("build" + filterId), + TEST_TABLE_HANDLE, + ImmutableList.of(buildSymbol), + ImmutableMap.of(buildSymbol, new TestingColumnHandle(colName))); + + Symbol probeSymbol = new Symbol(probeColumnName); + TableScanNode probe = probes.computeIfAbsent( + probeSymbol, + symbol -> TableScanNode.newInstance( + new PlanNodeId("probe" + filterId), + TEST_TABLE_HANDLE, + ImmutableList.of(symbol), + ImmutableMap.of(symbol, new TestingColumnHandle(symbol.getName())))); + + PlanFragment testFragment = new PlanFragment( + new PlanFragmentId("plan_id" + filterId), + new JoinNode( + new PlanNodeId("join_id" + filterId), + INNER, + build, + probe, + ImmutableList.of(), + build.getOutputSymbols(), + probe.getOutputSymbols(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableMap.of(filterId, probeSymbol), + Optional.empty()), + ImmutableMap.of(probeSymbol, VARCHAR), + SOURCE_DISTRIBUTION, + ImmutableList.of(), + new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(probeSymbol)), + ungroupedExecution(), + StatsAndCosts.empty(), + Optional.empty()); + + StageId stageId = taskIds.stream().findFirst().get().getStageId(); + List tasks = taskIds.stream() + .map(taskId -> createInitialTask( 
+ taskId, URI.create(""), "", ImmutableList.of(), new TaskStats(DateTime.now(), DateTime.now()))) + .collect(toImmutableList()); + + stagesInfo.put(stageId, new StageInfo( + stageId, + StageState.RUNNING, + testFragment, + ImmutableList.of(), + TEST_STAGE_STATS, + tasks, + ImmutableList.of(), + ImmutableMap.of(), + null)); + } + + void storeSummary(String filterId, TaskId taskId, Domain domain) + { + StageId stageId = taskId.getStageId(); + ImmutableList.Builder updatedTasks = ImmutableList.builder(); + StageInfo stageInfo = stagesInfo.get(stageId); + for (TaskInfo task : stageInfo.getTasks()) { + if (task.getTaskStatus().getTaskId().equals(taskId)) { + TaskStatus taskStatus = task.getTaskStatus(); + updatedTasks.add(new TaskInfo( + new TaskStatus( + taskStatus.getTaskId(), + taskStatus.getTaskInstanceId(), + taskStatus.getVersion(), + taskStatus.getState(), + taskStatus.getSelf(), + taskStatus.getNodeId(), + taskStatus.getCompletedDriverGroups(), + taskStatus.getFailures(), + taskStatus.getQueuedPartitionedDrivers(), + taskStatus.getRunningPartitionedDrivers(), + taskStatus.isOutputBufferOverutilized(), + taskStatus.getPhysicalWrittenDataSize(), + taskStatus.getMemoryReservation(), + taskStatus.getSystemMemoryReservation(), + taskStatus.getRevocableMemoryReservation(), + taskStatus.getFullGcCount(), + taskStatus.getFullGcTime(), + ImmutableMap.of(filterId, domain)), + task.getLastHeartbeat(), + task.getOutputBuffers(), + task.getNoMoreSplits(), + task.getStats(), + task.isNeedsPlan())); + } + else { + updatedTasks.add(task); + } + } + stagesInfo.put(stageId, new StageInfo( + stageInfo.getStageId(), + stageInfo.getState(), + stageInfo.getPlan(), + stageInfo.getTypes(), + TEST_STAGE_STATS, + updatedTasks.build(), + stageInfo.getSubStages(), + stageInfo.getTables(), + null)); + } + + @Override + public List get() + { + return ImmutableList.copyOf(stagesInfo.values()); + } + } +} diff --git a/presto-main/src/test/java/io/prestosql/server/remotetask/TestHttpRemoteTask.java b/presto-main/src/test/java/io/prestosql/server/remotetask/TestHttpRemoteTask.java index 06bebf8fe83de..c6599e063a681 100644 --- a/presto-main/src/test/java/io/prestosql/server/remotetask/TestHttpRemoteTask.java +++ b/presto-main/src/test/java/io/prestosql/server/remotetask/TestHttpRemoteTask.java @@ -476,7 +476,8 @@ private TaskStatus buildTaskStatus() initialTaskStatus.getSystemMemoryReservation(), initialTaskStatus.getRevocableMemoryReservation(), initialTaskStatus.getFullGcCount(), - initialTaskStatus.getFullGcTime()); + initialTaskStatus.getFullGcTime(), + initialTaskStatus.getDynamicFilterDomains()); } } } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestLocalDynamicFilter.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestLocalDynamicFilter.java index f21436fd17869..5965de4ee07fe 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestLocalDynamicFilter.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestLocalDynamicFilter.java @@ -17,7 +17,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; -import com.google.common.collect.Iterables; import com.google.common.util.concurrent.ListenableFuture; import io.prestosql.Session; import io.prestosql.spi.predicate.Domain; @@ -32,11 +31,11 @@ import java.util.Comparator; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.function.Consumer; 
import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; import static io.prestosql.SystemSessionProperties.ENABLE_DYNAMIC_FILTERING; import static io.prestosql.SystemSessionProperties.FORCE_SINGLE_NODE_OUTPUT; import static io.prestosql.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; @@ -66,11 +65,11 @@ public void testSimple() LocalDynamicFilter filter = new LocalDynamicFilter( ImmutableMultimap.of("123", new Symbol("a")), ImmutableMap.of("123", 0), - TypeProvider.copyOf(ImmutableMap.of(new Symbol("a"), INTEGER)), + ImmutableMap.of("123", INTEGER), 1); assertEquals(filter.getBuildChannels(), ImmutableMap.of("123", 0)); Consumer> consumer = filter.getTupleDomainConsumer(); - ListenableFuture> result = filter.getResultFuture(); + ListenableFuture> result = filter.getNodeLocalDynamicFilterForSymbols(); assertFalse(result.isDone()); consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( @@ -86,11 +85,11 @@ public void testMultipleProbeSymbols() LocalDynamicFilter filter = new LocalDynamicFilter( ImmutableMultimap.of("123", new Symbol("a1"), "123", new Symbol("a2")), ImmutableMap.of("123", 0), - TypeProvider.copyOf(ImmutableMap.of(new Symbol("a1"), INTEGER, new Symbol("a2"), INTEGER)), + ImmutableMap.of("123", INTEGER), 1); assertEquals(filter.getBuildChannels(), ImmutableMap.of("123", 0)); Consumer> consumer = filter.getTupleDomainConsumer(); - ListenableFuture> result = filter.getResultFuture(); + ListenableFuture> result = filter.getNodeLocalDynamicFilterForSymbols(); assertFalse(result.isDone()); consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( @@ -107,11 +106,11 @@ public void testMultiplePartitions() LocalDynamicFilter filter = new LocalDynamicFilter( ImmutableMultimap.of("123", new Symbol("a")), ImmutableMap.of("123", 0), - TypeProvider.copyOf(ImmutableMap.of(new Symbol("a"), INTEGER)), + ImmutableMap.of("123", INTEGER), 2); assertEquals(filter.getBuildChannels(), ImmutableMap.of("123", 0)); Consumer> consumer = filter.getTupleDomainConsumer(); - ListenableFuture> result = filter.getResultFuture(); + ListenableFuture> result = filter.getNodeLocalDynamicFilterForSymbols(); assertFalse(result.isDone()); consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( @@ -132,11 +131,11 @@ public void testNone() LocalDynamicFilter filter = new LocalDynamicFilter( ImmutableMultimap.of("123", new Symbol("a")), ImmutableMap.of("123", 0), - TypeProvider.copyOf(ImmutableMap.of(new Symbol("a"), INTEGER)), + ImmutableMap.of("123", INTEGER), 1); assertEquals(filter.getBuildChannels(), ImmutableMap.of("123", 0)); Consumer> consumer = filter.getTupleDomainConsumer(); - ListenableFuture> result = filter.getResultFuture(); + ListenableFuture> result = filter.getNodeLocalDynamicFilterForSymbols(); assertFalse(result.isDone()); consumer.accept(TupleDomain.none()); @@ -152,11 +151,11 @@ public void testMultipleColumns() LocalDynamicFilter filter = new LocalDynamicFilter( ImmutableMultimap.of("123", new Symbol("a"), "456", new Symbol("b")), ImmutableMap.of("123", 0, "456", 1), - TypeProvider.copyOf(ImmutableMap.of(new Symbol("a"), INTEGER)), + ImmutableMap.of("123", INTEGER, "456", INTEGER), 1); assertEquals(filter.getBuildChannels(), ImmutableMap.of("123", 0, "456", 1)); Consumer> consumer = filter.getTupleDomainConsumer(); - ListenableFuture> result = filter.getResultFuture(); + ListenableFuture> result = filter.getNodeLocalDynamicFilterForSymbols(); assertFalse(result.isDone()); 
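// result stays incomplete until every build-side partition (partitionCount is 1 in this test) has
// pushed its summary through the consumer below; once complete it maps probe-side Symbols, rather
// than dynamic filter IDs, to their collected Domains, which is the shape handed to the local collector.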
consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( @@ -174,11 +173,11 @@ public void testMultiplePartitionsAndColumns() LocalDynamicFilter filter = new LocalDynamicFilter( ImmutableMultimap.of("123", new Symbol("a"), "456", new Symbol("b")), ImmutableMap.of("123", 0, "456", 1), - TypeProvider.copyOf(ImmutableMap.of(new Symbol("a"), INTEGER, new Symbol("b"), BIGINT)), + ImmutableMap.of("123", INTEGER, "456", BIGINT), 2); assertEquals(filter.getBuildChannels(), ImmutableMap.of("123", 0, "456", 1)); Consumer> consumer = filter.getTupleDomainConsumer(); - ListenableFuture> result = filter.getResultFuture(); + ListenableFuture> result = filter.getNodeLocalDynamicFilterForSymbols(); assertFalse(result.isDone()); consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( @@ -205,18 +204,19 @@ public void testCreateSingleColumn() OPTIMIZED_AND_VALIDATED, false); JoinNode joinNode = searchJoins(subplan.getChildren().get(0).getFragment()).findOnlyElement(); - LocalDynamicFilter filter = LocalDynamicFilter.create(joinNode, TypeProvider.copyOf(subplan.getFragment().getSymbols()), 1).get(); - String filterId = Iterables.getOnlyElement(filter.getBuildChannels().keySet()); - Symbol probeSymbol = Iterables.getOnlyElement(joinNode.getCriteria()).getLeft(); + LocalDynamicFilter filter = LocalDynamicFilter.create(joinNode, ImmutableList.copyOf(subplan.getFragment().getSymbols().values()), 1); + String filterId = getOnlyElement(filter.getBuildChannels().keySet()); + Symbol probeSymbol = getOnlyElement(joinNode.getCriteria()).getLeft(); filter.getTupleDomainConsumer().accept(TupleDomain.withColumnDomains(ImmutableMap.of( filterId, Domain.singleValue(BIGINT, 3L)))); - assertEquals(filter.getResultFuture().get(), ImmutableMap.of( + assertEquals(filter.getNodeLocalDynamicFilterForSymbols().get(), ImmutableMap.of( probeSymbol, Domain.singleValue(BIGINT, 3L))); } @Test public void testCreateDistributedJoin() + throws Exception { Session session = Session.builder(getQueryRunner().getDefaultSession()) .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "PARTITIONED") @@ -228,8 +228,15 @@ public void testCreateDistributedJoin() false, session); JoinNode joinNode = searchJoins(subplan.getChildren().get(0).getFragment()).findOnlyElement(); + LocalDynamicFilter filter = LocalDynamicFilter.create(joinNode, ImmutableList.copyOf(subplan.getFragment().getSymbols().values()), 1); + String filterId = getOnlyElement(filter.getBuildChannels().keySet()); assertFalse(joinNode.getDynamicFilters().isEmpty()); - assertEquals(LocalDynamicFilter.create(joinNode, TypeProvider.copyOf(subplan.getFragment().getSymbols()), 1), Optional.empty()); + + filter.getTupleDomainConsumer().accept(TupleDomain.withColumnDomains(ImmutableMap.of( + filterId, Domain.singleValue(BIGINT, 3L)))); + assertEquals(filter.getNodeLocalDynamicFilterForSymbols().get(), ImmutableMap.of()); + assertEquals(filter.getDynamicFilterDomains().get(), ImmutableMap.of( + filterId, Domain.singleValue(BIGINT, 3L))); } @Test @@ -244,7 +251,7 @@ public void testCreateMultipleCriteria() false); JoinNode joinNode = searchJoins(subplan.getChildren().get(0).getFragment()).findOnlyElement(); - LocalDynamicFilter filter = LocalDynamicFilter.create(joinNode, TypeProvider.copyOf(subplan.getFragment().getSymbols()), 1).get(); + LocalDynamicFilter filter = LocalDynamicFilter.create(joinNode, ImmutableList.copyOf(subplan.getFragment().getSymbols().values()), 1); List filterIds = filter .getBuildChannels() .entrySet() @@ -256,7 +263,7 @@ public void testCreateMultipleCriteria() 
filterIds.get(0), Domain.singleValue(BIGINT, 4L), filterIds.get(1), Domain.singleValue(BIGINT, 5L)))); - assertEquals(filter.getResultFuture().get(), ImmutableMap.of( + assertEquals(filter.getNodeLocalDynamicFilterForSymbols().get(), ImmutableMap.of( new Symbol("partkey"), Domain.singleValue(BIGINT, 4L), new Symbol("suppkey"), Domain.singleValue(BIGINT, 5L))); } @@ -275,13 +282,13 @@ public void testCreateMultipleJoins() List joinNodes = searchJoins(subplan.getChildren().get(0).getFragment()).findAll(); assertEquals(joinNodes.size(), 2); for (JoinNode joinNode : joinNodes) { - LocalDynamicFilter filter = LocalDynamicFilter.create(joinNode, TypeProvider.copyOf(subplan.getFragment().getSymbols()), 1).get(); - String filterId = Iterables.getOnlyElement(filter.getBuildChannels().keySet()); - Symbol probeSymbol = Iterables.getOnlyElement(joinNode.getCriteria()).getLeft(); + LocalDynamicFilter filter = LocalDynamicFilter.create(joinNode, ImmutableList.copyOf(subplan.getFragment().getSymbols().values()), 1); + String filterId = getOnlyElement(filter.getBuildChannels().keySet()); + Symbol probeSymbol = getOnlyElement(joinNode.getCriteria()).getLeft(); filter.getTupleDomainConsumer().accept(TupleDomain.withColumnDomains(ImmutableMap.of( filterId, Domain.singleValue(BIGINT, 6L)))); - assertEquals(filter.getResultFuture().get(), ImmutableMap.of( + assertEquals(filter.getNodeLocalDynamicFilterForSymbols().get(), ImmutableMap.of( probeSymbol, Domain.singleValue(BIGINT, 6L))); } } @@ -299,12 +306,12 @@ public void testCreateProbeSideUnion() true); JoinNode joinNode = searchJoins(subplan.getFragment()).findOnlyElement(); - LocalDynamicFilter filter = LocalDynamicFilter.create(joinNode, TypeProvider.copyOf(subplan.getFragment().getSymbols()), 1).get(); - String filterId = Iterables.getOnlyElement(filter.getBuildChannels().keySet()); + LocalDynamicFilter filter = LocalDynamicFilter.create(joinNode, ImmutableList.copyOf(subplan.getFragment().getSymbols().values()), 1); + String filterId = getOnlyElement(filter.getBuildChannels().keySet()); filter.getTupleDomainConsumer().accept(TupleDomain.withColumnDomains(ImmutableMap.of( filterId, Domain.singleValue(BIGINT, 7L)))); - assertEquals(filter.getResultFuture().get(), ImmutableMap.of( + assertEquals(filter.getNodeLocalDynamicFilterForSymbols().get(), ImmutableMap.of( new Symbol("partkey"), Domain.singleValue(BIGINT, 7L), new Symbol("suppkey"), Domain.singleValue(BIGINT, 7L))); } diff --git a/presto-plugin-toolkit/src/main/java/io/prestosql/plugin/base/classloader/ClassLoaderSafeConnectorSplitManager.java b/presto-plugin-toolkit/src/main/java/io/prestosql/plugin/base/classloader/ClassLoaderSafeConnectorSplitManager.java index b53bcfc95c7dd..3d25ccf29e7fd 100644 --- a/presto-plugin-toolkit/src/main/java/io/prestosql/plugin/base/classloader/ClassLoaderSafeConnectorSplitManager.java +++ b/presto-plugin-toolkit/src/main/java/io/prestosql/plugin/base/classloader/ClassLoaderSafeConnectorSplitManager.java @@ -14,15 +14,19 @@ package io.prestosql.plugin.base.classloader; import io.prestosql.spi.classloader.ThreadContextClassLoader; +import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ConnectorSession; import io.prestosql.spi.connector.ConnectorSplitManager; import io.prestosql.spi.connector.ConnectorSplitSource; import io.prestosql.spi.connector.ConnectorTableHandle; import io.prestosql.spi.connector.ConnectorTableLayoutHandle; import io.prestosql.spi.connector.ConnectorTransactionHandle; +import io.prestosql.spi.predicate.TupleDomain; 
import javax.inject.Inject; +import java.util.function.Supplier; + import static java.util.Objects.requireNonNull; public final class ClassLoaderSafeConnectorSplitManager @@ -53,4 +57,12 @@ public ConnectorSplitSource getSplits(ConnectorTransactionHandle transaction, Co return delegate.getSplits(transaction, session, table, splitSchedulingStrategy); } } + + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorTableHandle table, SplitSchedulingStrategy splitSchedulingStrategy, Supplier> dynamicFilter) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getSplits(transaction, session, table, splitSchedulingStrategy, dynamicFilter); + } + } } diff --git a/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorSplitManager.java b/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorSplitManager.java index 8b20eb2155390..531405817da18 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorSplitManager.java +++ b/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorSplitManager.java @@ -13,6 +13,10 @@ */ package io.prestosql.spi.connector; +import io.prestosql.spi.predicate.TupleDomain; + +import java.util.function.Supplier; + public interface ConnectorSplitManager { @Deprecated @@ -34,6 +38,16 @@ default ConnectorSplitSource getSplits( throw new UnsupportedOperationException(); } + default ConnectorSplitSource getSplits( + ConnectorTransactionHandle transaction, + ConnectorSession session, + ConnectorTableHandle table, + SplitSchedulingStrategy splitSchedulingStrategy, + Supplier> dynamicFilter) + { + return getSplits(transaction, session, table, splitSchedulingStrategy); + } + enum SplitSchedulingStrategy { UNGROUPED_SCHEDULING, From dc2eedbdfb6ec7290fae39956fc6908176428e6e Mon Sep 17 00:00:00 2001 From: Raunaq Morarka Date: Fri, 15 May 2020 10:15:08 +0530 Subject: [PATCH 3/3] Rename LocalDynamicFilter to LocalDynamicFilterConsumer --- ...r.java => LocalDynamicFilterConsumer.java} | 12 ++++----- .../sql/planner/LocalExecutionPlanner.java | 12 ++++----- ...va => TestLocalDynamicFilterConsumer.java} | 26 +++++++++---------- 3 files changed, 25 insertions(+), 25 deletions(-) rename presto-main/src/main/java/io/prestosql/sql/planner/{LocalDynamicFilter.java => LocalDynamicFilterConsumer.java} (94%) rename presto-main/src/test/java/io/prestosql/sql/planner/{TestLocalDynamicFilter.java => TestLocalDynamicFilterConsumer.java} (91%) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LocalDynamicFilter.java b/presto-main/src/main/java/io/prestosql/sql/planner/LocalDynamicFilterConsumer.java similarity index 94% rename from presto-main/src/main/java/io/prestosql/sql/planner/LocalDynamicFilter.java rename to presto-main/src/main/java/io/prestosql/sql/planner/LocalDynamicFilterConsumer.java index 06ecb294ceefa..2e8b325868e64 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LocalDynamicFilter.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LocalDynamicFilterConsumer.java @@ -47,9 +47,9 @@ import static java.util.Objects.requireNonNull; import static java.util.function.Function.identity; -public class LocalDynamicFilter +public class LocalDynamicFilterConsumer { - private static final Logger log = Logger.get(LocalDynamicFilter.class); + private static final Logger log = Logger.get(LocalDynamicFilterConsumer.class); // Mapping from dynamic filter ID to its probe symbols. 
private final Multimap probeSymbols; @@ -68,7 +68,7 @@ public class LocalDynamicFilter // The resulting predicates from each build-side partition. private final List> partitions; - public LocalDynamicFilter(Multimap probeSymbols, Map buildChannels, Map filterBuildTypes, int partitionCount) + public LocalDynamicFilterConsumer(Multimap probeSymbols, Map buildChannels, Map filterBuildTypes, int partitionCount) { this.probeSymbols = requireNonNull(probeSymbols, "probeSymbols is null"); this.buildChannels = requireNonNull(buildChannels, "buildChannels is null"); @@ -143,14 +143,14 @@ private Map convertTupleDomain(TupleDomain result) return result.getDomains().get(); } - public static LocalDynamicFilter create(JoinNode planNode, List buildSourceTypes, int partitionCount) + public static LocalDynamicFilterConsumer create(JoinNode planNode, List buildSourceTypes, int partitionCount) { checkArgument(!planNode.getDynamicFilters().isEmpty(), "Join node dynamicFilters is empty."); Set joinDynamicFilters = planNode.getDynamicFilters().keySet(); List filterNodes = PlanNodeSearcher .searchFrom(planNode.getLeft()) - .where(LocalDynamicFilter::isFilterAboveTableScan) + .where(LocalDynamicFilterConsumer::isFilterAboveTableScan) .findAll(); // Mapping from probe-side dynamic filters' IDs to their matching probe symbols. @@ -188,7 +188,7 @@ public static LocalDynamicFilter create(JoinNode planNode, List buildSourc .collect(toImmutableMap( Map.Entry::getKey, entry -> buildSourceTypes.get(entry.getValue()))); - return new LocalDynamicFilter(probeSymbols, buildChannels, filterBuildTypes, partitionCount); + return new LocalDynamicFilterConsumer(probeSymbols, buildChannels, filterBuildTypes, partitionCount); } private static boolean isFilterAboveTableScan(PlanNode node) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java index 23a621309981b..79174eb3e6d48 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java @@ -2125,7 +2125,7 @@ private JoinBridgeManager createLookupSourceFact } private DynamicFilterSourceOperatorFactory createDynamicFilterSourceOperatorFactory( - LocalDynamicFilter dynamicFilter, + LocalDynamicFilterConsumer dynamicFilter, JoinNode node, PhysicalOperation buildSource, LocalExecutionPlanContext context) @@ -2147,7 +2147,7 @@ private DynamicFilterSourceOperatorFactory createDynamicFilterSourceOperatorFact getDynamicFilteringMaxPerDriverSize(context.getSession())); } - private Optional createDynamicFilter(PhysicalOperation buildSource, JoinNode node, LocalExecutionPlanContext context, int partitionCount) + private Optional createDynamicFilter(PhysicalOperation buildSource, JoinNode node, LocalExecutionPlanContext context, int partitionCount) { if (node.getDynamicFilters().isEmpty()) { return Optional.empty(); @@ -2157,12 +2157,12 @@ private Optional createDynamicFilter(PhysicalOperation build "Dynamic filtering cannot be used with grouped execution"); log.debug("[Join] Dynamic filters: %s", node.getDynamicFilters()); LocalDynamicFiltersCollector collector = context.getDynamicFiltersCollector(); - LocalDynamicFilter filter = LocalDynamicFilter.create(node, buildSource.getTypes(), partitionCount); + LocalDynamicFilterConsumer filterConsumer = LocalDynamicFilterConsumer.create(node, buildSource.getTypes(), partitionCount); // Intersect dynamic filters' predicates 
when they become ready, // in order to support multiple join nodes in the same plan fragment. - addSuccessCallback(filter.getDynamicFilterDomains(), context::addDynamicFilter); - addSuccessCallback(filter.getNodeLocalDynamicFilterForSymbols(), collector::addDynamicFilter); - return Optional.of(filter); + addSuccessCallback(filterConsumer.getDynamicFilterDomains(), context::addDynamicFilter); + addSuccessCallback(filterConsumer.getNodeLocalDynamicFilterForSymbols(), collector::addDynamicFilter); + return Optional.of(filterConsumer); } private JoinFilterFunctionFactory compileJoinFilterFunction( diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestLocalDynamicFilter.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestLocalDynamicFilterConsumer.java similarity index 91% rename from presto-main/src/test/java/io/prestosql/sql/planner/TestLocalDynamicFilter.java rename to presto-main/src/test/java/io/prestosql/sql/planner/TestLocalDynamicFilterConsumer.java index 5965de4ee07fe..e9a34acbe7cda 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestLocalDynamicFilter.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestLocalDynamicFilterConsumer.java @@ -46,10 +46,10 @@ import static io.prestosql.testing.assertions.Assert.assertEquals; import static org.testng.Assert.assertFalse; -public class TestLocalDynamicFilter +public class TestLocalDynamicFilterConsumer extends BasePlanTest { - public TestLocalDynamicFilter() + public TestLocalDynamicFilterConsumer() { super(ImmutableMap.of( FORCE_SINGLE_NODE_OUTPUT, "false", @@ -62,7 +62,7 @@ public TestLocalDynamicFilter() public void testSimple() throws ExecutionException, InterruptedException { - LocalDynamicFilter filter = new LocalDynamicFilter( + LocalDynamicFilterConsumer filter = new LocalDynamicFilterConsumer( ImmutableMultimap.of("123", new Symbol("a")), ImmutableMap.of("123", 0), ImmutableMap.of("123", INTEGER), @@ -82,7 +82,7 @@ public void testSimple() public void testMultipleProbeSymbols() throws ExecutionException, InterruptedException { - LocalDynamicFilter filter = new LocalDynamicFilter( + LocalDynamicFilterConsumer filter = new LocalDynamicFilterConsumer( ImmutableMultimap.of("123", new Symbol("a1"), "123", new Symbol("a2")), ImmutableMap.of("123", 0), ImmutableMap.of("123", INTEGER), @@ -103,7 +103,7 @@ public void testMultipleProbeSymbols() public void testMultiplePartitions() throws ExecutionException, InterruptedException { - LocalDynamicFilter filter = new LocalDynamicFilter( + LocalDynamicFilterConsumer filter = new LocalDynamicFilterConsumer( ImmutableMultimap.of("123", new Symbol("a")), ImmutableMap.of("123", 0), ImmutableMap.of("123", INTEGER), @@ -128,7 +128,7 @@ public void testMultiplePartitions() public void testNone() throws ExecutionException, InterruptedException { - LocalDynamicFilter filter = new LocalDynamicFilter( + LocalDynamicFilterConsumer filter = new LocalDynamicFilterConsumer( ImmutableMultimap.of("123", new Symbol("a")), ImmutableMap.of("123", 0), ImmutableMap.of("123", INTEGER), @@ -148,7 +148,7 @@ public void testNone() public void testMultipleColumns() throws ExecutionException, InterruptedException { - LocalDynamicFilter filter = new LocalDynamicFilter( + LocalDynamicFilterConsumer filter = new LocalDynamicFilterConsumer( ImmutableMultimap.of("123", new Symbol("a"), "456", new Symbol("b")), ImmutableMap.of("123", 0, "456", 1), ImmutableMap.of("123", INTEGER, "456", INTEGER), @@ -170,7 +170,7 @@ public void testMultipleColumns() public void 
testMultiplePartitionsAndColumns() throws ExecutionException, InterruptedException { - LocalDynamicFilter filter = new LocalDynamicFilter( + LocalDynamicFilterConsumer filter = new LocalDynamicFilterConsumer( ImmutableMultimap.of("123", new Symbol("a"), "456", new Symbol("b")), ImmutableMap.of("123", 0, "456", 1), ImmutableMap.of("123", INTEGER, "456", BIGINT), @@ -204,7 +204,7 @@ public void testCreateSingleColumn() OPTIMIZED_AND_VALIDATED, false); JoinNode joinNode = searchJoins(subplan.getChildren().get(0).getFragment()).findOnlyElement(); - LocalDynamicFilter filter = LocalDynamicFilter.create(joinNode, ImmutableList.copyOf(subplan.getFragment().getSymbols().values()), 1); + LocalDynamicFilterConsumer filter = LocalDynamicFilterConsumer.create(joinNode, ImmutableList.copyOf(subplan.getFragment().getSymbols().values()), 1); String filterId = getOnlyElement(filter.getBuildChannels().keySet()); Symbol probeSymbol = getOnlyElement(joinNode.getCriteria()).getLeft(); @@ -228,7 +228,7 @@ public void testCreateDistributedJoin() false, session); JoinNode joinNode = searchJoins(subplan.getChildren().get(0).getFragment()).findOnlyElement(); - LocalDynamicFilter filter = LocalDynamicFilter.create(joinNode, ImmutableList.copyOf(subplan.getFragment().getSymbols().values()), 1); + LocalDynamicFilterConsumer filter = LocalDynamicFilterConsumer.create(joinNode, ImmutableList.copyOf(subplan.getFragment().getSymbols().values()), 1); String filterId = getOnlyElement(filter.getBuildChannels().keySet()); assertFalse(joinNode.getDynamicFilters().isEmpty()); @@ -251,7 +251,7 @@ public void testCreateMultipleCriteria() false); JoinNode joinNode = searchJoins(subplan.getChildren().get(0).getFragment()).findOnlyElement(); - LocalDynamicFilter filter = LocalDynamicFilter.create(joinNode, ImmutableList.copyOf(subplan.getFragment().getSymbols().values()), 1); + LocalDynamicFilterConsumer filter = LocalDynamicFilterConsumer.create(joinNode, ImmutableList.copyOf(subplan.getFragment().getSymbols().values()), 1); List filterIds = filter .getBuildChannels() .entrySet() @@ -282,7 +282,7 @@ public void testCreateMultipleJoins() List joinNodes = searchJoins(subplan.getChildren().get(0).getFragment()).findAll(); assertEquals(joinNodes.size(), 2); for (JoinNode joinNode : joinNodes) { - LocalDynamicFilter filter = LocalDynamicFilter.create(joinNode, ImmutableList.copyOf(subplan.getFragment().getSymbols().values()), 1); + LocalDynamicFilterConsumer filter = LocalDynamicFilterConsumer.create(joinNode, ImmutableList.copyOf(subplan.getFragment().getSymbols().values()), 1); String filterId = getOnlyElement(filter.getBuildChannels().keySet()); Symbol probeSymbol = getOnlyElement(joinNode.getCriteria()).getLeft(); @@ -306,7 +306,7 @@ public void testCreateProbeSideUnion() true); JoinNode joinNode = searchJoins(subplan.getFragment()).findOnlyElement(); - LocalDynamicFilter filter = LocalDynamicFilter.create(joinNode, ImmutableList.copyOf(subplan.getFragment().getSymbols().values()), 1); + LocalDynamicFilterConsumer filter = LocalDynamicFilterConsumer.create(joinNode, ImmutableList.copyOf(subplan.getFragment().getSymbols().values()), 1); String filterId = getOnlyElement(filter.getBuildChannels().keySet()); filter.getTupleDomainConsumer().accept(TupleDomain.withColumnDomains(ImmutableMap.of(
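For connector authors, the dynamic-filter-aware getSplits overload added to ConnectorSplitManager above is the hook for runtime partition pruning. The sketch below shows how an implementation might consume it; ExamplePartitionedSplitManager, ExamplePartition, listPartitions and toSplits are hypothetical names used only for illustration, while the SPI signature and the TupleDomain/Domain calls come from the patch itself.

import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ConnectorSession;
import io.prestosql.spi.connector.ConnectorSplit;
import io.prestosql.spi.connector.ConnectorSplitManager;
import io.prestosql.spi.connector.ConnectorSplitSource;
import io.prestosql.spi.connector.ConnectorTableHandle;
import io.prestosql.spi.connector.ConnectorTransactionHandle;
import io.prestosql.spi.connector.FixedSplitSource;
import io.prestosql.spi.predicate.Domain;
import io.prestosql.spi.predicate.NullableValue;
import io.prestosql.spi.predicate.TupleDomain;

import java.util.List;
import java.util.Map;
import java.util.function.Supplier;

import static java.util.Collections.emptyList;
import static java.util.stream.Collectors.toList;

public abstract class ExamplePartitionedSplitManager
        implements ConnectorSplitManager
{
    @Override
    public ConnectorSplitSource getSplits(
            ConnectorTransactionHandle transaction,
            ConnectorSession session,
            ConnectorTableHandle table,
            SplitSchedulingStrategy splitSchedulingStrategy,
            Supplier<TupleDomain<ColumnHandle>> dynamicFilter)
    {
        // Evaluate the supplier at enumeration time; domains collected from completed
        // build sides are visible here even though they did not exist at planning time.
        TupleDomain<ColumnHandle> summary = dynamicFilter.get();
        if (summary.isNone()) {
            return new FixedSplitSource(emptyList());
        }
        Map<ColumnHandle, Domain> domains = summary.getDomains().orElseThrow(IllegalStateException::new);
        List<ExamplePartition> kept = listPartitions(table).stream()
                .filter(partition -> matches(partition, domains))
                .collect(toList());
        return new FixedSplitSource(toSplits(kept));
    }

    private static boolean matches(ExamplePartition partition, Map<ColumnHandle, Domain> domains)
    {
        // Drop the partition if any of its key values falls outside the corresponding domain.
        for (Map.Entry<ColumnHandle, NullableValue> key : partition.getKeys().entrySet()) {
            Domain domain = domains.get(key.getKey());
            if (domain != null && !domain.includesNullableValue(key.getValue().getValue())) {
                return false;
            }
        }
        return true;
    }

    // Hypothetical connector-specific pieces, not defined by the patch.
    protected abstract List<ExamplePartition> listPartitions(ConnectorTableHandle table);

    protected abstract List<ConnectorSplit> toSplits(List<ExamplePartition> partitions);

    public interface ExamplePartition
    {
        Map<ColumnHandle, NullableValue> getKeys();
    }
}

Passing a Supplier rather than a materialized TupleDomain lets a connector that enumerates splits lazily re-check the filter as loading progresses and observe domains that were narrowed only after scheduling started.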