diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index e3900cbf95ab..3f9bc254c7e4 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -178,6 +178,7 @@ public final class SystemSessionProperties public static final String MAX_UNACKNOWLEDGED_SPLITS_PER_TASK = "max_unacknowledged_splits_per_task"; public static final String OPTIMIZE_JOINS_WITH_EMPTY_SOURCES = "optimize_joins_with_empty_sources"; public static final String SPOOLING_OUTPUT_BUFFER_ENABLED = "spooling_output_buffer_enabled"; + public static final String SPARK_ASSIGN_BUCKET_TO_PARTITION_FOR_PARTITIONED_TABLE_WRITE_ENABLED = "spark_assign_bucket_to_partition_for_partitioned_table_write_enabled"; private final List> sessionProperties; @@ -932,7 +933,12 @@ public SystemSessionProperties( SPOOLING_OUTPUT_BUFFER_ENABLED, "Enable spooling output buffer for terminal task", featuresConfig.isSpoolingOutputBufferEnabled(), - false)); + false), + booleanProperty( + SPARK_ASSIGN_BUCKET_TO_PARTITION_FOR_PARTITIONED_TABLE_WRITE_ENABLED, + "Assign bucket to partition map for partitioned table write when adding an exchange", + featuresConfig.isPrestoSparkAssignBucketToPartitionForPartitionedTableWriteEnabled(), + true)); } public static boolean isEmptyJoinOptimization(Session session) @@ -1578,4 +1584,9 @@ public static int getMaxUnacknowledgedSplitsPerTask(Session session) { return session.getSystemProperty(MAX_UNACKNOWLEDGED_SPLITS_PER_TASK, Integer.class); } + + public static boolean isPrestoSparkAssignBucketToPartitionForPartitionedTableWriteEnabled(Session session) + { + return session.getSystemProperty(SPARK_ASSIGN_BUCKET_TO_PARTITION_FOR_PARTITIONED_TABLE_WRITE_ENABLED, Boolean.class); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index f05824892dc7..81c85bb8932e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -186,6 +186,7 @@ public class FeaturesConfig private PartitioningPrecisionStrategy partitioningPrecisionStrategy = PartitioningPrecisionStrategy.AUTOMATIC; private boolean enforceFixedDistributionForOutputOperator; + private boolean prestoSparkAssignBucketToPartitionForPartitionedTableWriteEnabled; public enum PartitioningPrecisionStrategy { @@ -1583,4 +1584,16 @@ public FeaturesConfig setSpoolingOutputBufferTempStorage(String spoolingOutputBu this.spoolingOutputBufferTempStorage = spoolingOutputBufferTempStorage; return this; } + + public boolean isPrestoSparkAssignBucketToPartitionForPartitionedTableWriteEnabled() + { + return prestoSparkAssignBucketToPartitionForPartitionedTableWriteEnabled; + } + + @Config("spark.assign-bucket-to-partition-for-partitioned-table-write-enabled") + public FeaturesConfig setPrestoSparkAssignBucketToPartitionForPartitionedTableWriteEnabled(boolean prestoSparkAssignBucketToPartitionForPartitionedTableWriteEnabled) + { + this.prestoSparkAssignBucketToPartitionForPartitionedTableWriteEnabled = prestoSparkAssignBucketToPartitionForPartitionedTableWriteEnabled; + return this; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 5e1d5449ccb2..0fd7ac814718 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -165,7 +165,8 @@ public PlanOptimizers( CostCalculator costCalculator, @EstimatedExchanges CostCalculator estimatedExchangesCostCalculator, CostComparator costComparator, - TaskCountEstimator taskCountEstimator) + TaskCountEstimator taskCountEstimator, + PartitioningProviderManager partitioningProviderManager) { this(metadata, sqlParser, @@ -178,7 +179,8 @@ public PlanOptimizers( costCalculator, estimatedExchangesCostCalculator, costComparator, - taskCountEstimator); + taskCountEstimator, + partitioningProviderManager); } @PostConstruct @@ -207,7 +209,8 @@ public PlanOptimizers( CostCalculator costCalculator, CostCalculator estimatedExchangesCostCalculator, CostComparator costComparator, - TaskCountEstimator taskCountEstimator) + TaskCountEstimator taskCountEstimator, + PartitioningProviderManager partitioningProviderManager) { this.exporter = exporter; ImmutableList.Builder builder = ImmutableList.builder(); @@ -448,7 +451,7 @@ public PlanOptimizers( ImmutableSet.>builder() .addAll(new PushDownDereferences(metadata).rules()) .build()), - new PruneUnreferencedOutputs()); + new PruneUnreferencedOutputs()); builder.add(new IterativeOptimizer( ruleStats, @@ -540,7 +543,7 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of(new PushTableWriteThroughUnion()))); // Must run before AddExchanges - builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new AddExchanges(metadata, sqlParser))); + builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new AddExchanges(metadata, sqlParser, partitioningProviderManager))); } //noinspection UnusedAssignment diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java index a82cdbd77bbe..a602169c45fe 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java @@ -18,11 +18,13 @@ import com.facebook.presto.connector.system.GlobalSystemConnector; import com.facebook.presto.execution.QueryManagerConfig.ExchangeMaterializationStrategy; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.GroupingProperty; import com.facebook.presto.spi.LocalProperty; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SortingProperty; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.connector.ConnectorNodePartitioningProvider; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.DistinctLimitNode; @@ -43,6 +45,7 @@ import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.Partitioning; import com.facebook.presto.sql.planner.PartitioningHandle; +import com.facebook.presto.sql.planner.PartitioningProviderManager; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanVariableAllocator; import com.facebook.presto.sql.planner.TypeProvider; @@ -98,12 +101,14 @@ import static com.facebook.presto.SystemSessionProperties.getHashPartitionCount; import static com.facebook.presto.SystemSessionProperties.getPartialMergePushdownStrategy; import static com.facebook.presto.SystemSessionProperties.getPartitioningProviderCatalog; +import static com.facebook.presto.SystemSessionProperties.getTaskPartitionedWriterCount; import static com.facebook.presto.SystemSessionProperties.isColocatedJoinEnabled; import static com.facebook.presto.SystemSessionProperties.isDistributedIndexJoinEnabled; import static com.facebook.presto.SystemSessionProperties.isDistributedSortEnabled; import static com.facebook.presto.SystemSessionProperties.isExactPartitioningPreferred; import static com.facebook.presto.SystemSessionProperties.isForceSingleNodeOutput; import static com.facebook.presto.SystemSessionProperties.isPreferDistributedUnion; +import static com.facebook.presto.SystemSessionProperties.isPrestoSparkAssignBucketToPartitionForPartitionedTableWriteEnabled; import static com.facebook.presto.SystemSessionProperties.isRedistributeWrites; import static com.facebook.presto.SystemSessionProperties.isScaleWriters; import static com.facebook.presto.SystemSessionProperties.isUseStreamingExchangeForMarkDistinctEnabled; @@ -140,6 +145,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; import static java.lang.String.format; +import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; public class AddExchanges @@ -147,17 +153,19 @@ public class AddExchanges { private final SqlParser parser; private final Metadata metadata; + private final PartitioningProviderManager partitioningProviderManager; - public AddExchanges(Metadata metadata, SqlParser parser) + public AddExchanges(Metadata metadata, SqlParser parser, PartitioningProviderManager partitioningProviderManager) { - this.metadata = metadata; - this.parser = parser; + this.metadata = requireNonNull(metadata, "metadata is null"); + this.parser = requireNonNull(parser, "parser is null"); + this.partitioningProviderManager = requireNonNull(partitioningProviderManager, "partitioningProviderManager is null"); } @Override public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, PlanVariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) { - PlanWithProperties result = plan.accept(new Rewriter(idAllocator, variableAllocator, session), PreferredProperties.any()); + PlanWithProperties result = plan.accept(new Rewriter(idAllocator, variableAllocator, session, partitioningProviderManager), PreferredProperties.any()); return result.getNode(); } @@ -177,8 +185,13 @@ private class Rewriter private final String partitioningProviderCatalog; private final int hashPartitionCount; private final ExchangeMaterializationStrategy exchangeMaterializationStrategy; + private final PartitioningProviderManager partitioningProviderManager; - public Rewriter(PlanNodeIdAllocator idAllocator, PlanVariableAllocator variableAllocator, Session session) + public Rewriter( + PlanNodeIdAllocator idAllocator, + PlanVariableAllocator variableAllocator, + Session session, + PartitioningProviderManager partitioningProviderManager) { this.idAllocator = idAllocator; this.variableAllocator = variableAllocator; @@ -193,6 +206,7 @@ public Rewriter(PlanNodeIdAllocator idAllocator, PlanVariableAllocator variableA this.partitioningProviderCatalog = getPartitioningProviderCatalog(session); this.hashPartitionCount = getHashPartitionCount(session); this.exchangeMaterializationStrategy = getExchangeMaterializationStrategy(session); + this.partitioningProviderManager = requireNonNull(partitioningProviderManager, "partitioningProviderManager is null"); } @Override @@ -598,17 +612,44 @@ else if (redistributeWrites) { !source.getProperties().isCompatibleTablePartitioningWith(shufflePartitioningScheme.get().getPartitioning(), false, metadata, session) && !(source.getProperties().isRefinedPartitioningOver(shufflePartitioningScheme.get().getPartitioning(), false, metadata, session) && canPushdownPartialMerge(source.getNode(), partialMergePushdownStrategy))) { + PartitioningScheme exchangePartitioningScheme = shufflePartitioningScheme.get(); + if (node.getTablePartitioningScheme().isPresent() && isPrestoSparkAssignBucketToPartitionForPartitionedTableWriteEnabled(session)) { + int writerThreadsPerNode = getTaskPartitionedWriterCount(session); + int bucketCount = getBucketCount(node.getTablePartitioningScheme().get().getPartitioning().getHandle()); + int[] bucketToPartition = new int[bucketCount]; + for (int i = 0; i < bucketCount; i++) { + bucketToPartition[i] = i / writerThreadsPerNode; + } + exchangePartitioningScheme = exchangePartitioningScheme.withBucketToPartition(Optional.of(bucketToPartition)); + } + source = withDerivedProperties( partitionedExchange( idAllocator.getNextId(), REMOTE_STREAMING, source.getNode(), - shufflePartitioningScheme.get()), + exchangePartitioningScheme), source.getProperties()); } return rebaseAndDeriveProperties(node, source); } + private int getBucketCount(PartitioningHandle partitioning) + { + ConnectorNodePartitioningProvider partitioningProvider = getPartitioningProvider(partitioning); + return partitioningProvider.getBucketCount( + partitioning.getTransactionHandle().orElse(null), + session.toConnectorSession(), + partitioning.getConnectorHandle()); + } + + private ConnectorNodePartitioningProvider getPartitioningProvider(PartitioningHandle partitioning) + { + ConnectorId connectorId = partitioning.getConnectorId() + .orElseThrow(() -> new IllegalArgumentException("Unexpected partitioning: " + partitioning)); + return partitioningProviderManager.getPartitioningProvider(connectorId); + } + private PlanWithProperties planTableScan(TableScanNode node, RowExpression predicate) { PlanNode plan = pushPredicateIntoTableScan(node, predicate, true, session, idAllocator, metadata); diff --git a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index 8bfa5118fd8d..35ad456bd45c 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -946,7 +946,8 @@ public List getPlanOptimizers(boolean forceSingleNode) costCalculator, estimatedExchangesCostCalculator, new CostComparator(featuresConfig), - taskCountEstimator).getPlanningTimeOptimizers(); + taskCountEstimator, + partitioningProviderManager).getPlanningTimeOptimizers(); } public Plan createPlan(Session session, @Language("SQL") String sql, List optimizers, WarningCollector warningCollector) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 0fb9038e0fa4..1d06b48efbc6 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -159,7 +159,8 @@ public void testDefaults() .setEmptyJoinOptimization(false) .setSpoolingOutputBufferEnabled(false) .setSpoolingOutputBufferThreshold(new DataSize(8, MEGABYTE)) - .setSpoolingOutputBufferTempStorage("local")); + .setSpoolingOutputBufferTempStorage("local") + .setPrestoSparkAssignBucketToPartitionForPartitionedTableWriteEnabled(false)); } @Test @@ -272,6 +273,7 @@ public void testExplicitPropertyMappings() .put("spooling-output-buffer-enabled", "true") .put("spooling-output-buffer-threshold", "16MB") .put("spooling-output-buffer-temp-storage", "tempfs") + .put("spark.assign-bucket-to-partition-for-partitioned-table-write-enabled", "true") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -381,7 +383,8 @@ public void testExplicitPropertyMappings() .setEmptyJoinOptimization(true) .setSpoolingOutputBufferEnabled(true) .setSpoolingOutputBufferThreshold(new DataSize(16, MEGABYTE)) - .setSpoolingOutputBufferTempStorage("tempfs"); + .setSpoolingOutputBufferTempStorage("tempfs") + .setPrestoSparkAssignBucketToPartitionForPartitionedTableWriteEnabled(true); assertFullMapping(properties, expected); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java index 2ef65f2ae80b..b006b3400fbb 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java @@ -15,6 +15,7 @@ import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.PartitioningProviderManager; import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.assertions.BasePlanTest; import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider; @@ -94,7 +95,7 @@ public void assertUnitPlan(@Language("SQL") String sql, PlanMatchPattern pattern getQueryRunner().getStatsCalculator(), getQueryRunner().getCostCalculator(), new TranslateExpressions(getMetadata(), new SqlParser()).rules()), - new AddExchanges(getQueryRunner().getMetadata(), new SqlParser()), + new AddExchanges(getQueryRunner().getMetadata(), new SqlParser(), new PartitioningProviderManager()), new UnaliasSymbolReferences(getMetadata().getFunctionAndTypeManager()), new PruneUnreferencedOutputs(), new IterativeOptimizer( diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkSettingsRequirements.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkSettingsRequirements.java index fad7e7d74588..f5035e2df652 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkSettingsRequirements.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkSettingsRequirements.java @@ -90,6 +90,7 @@ public static void setDefaults(FeaturesConfig config) config.setForceSingleNodeOutput(false); config.setInlineSqlFunctions(true); config.setEnforceFixedDistributionForOutputOperator(true); + config.setPrestoSparkAssignBucketToPartitionForPartitionedTableWriteEnabled(true); } public static void setDefaults(QueryManagerConfig config) diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/PrestoSparkQueryRunner.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/PrestoSparkQueryRunner.java index 9f3412fea3fc..979a804af9c5 100644 --- a/presto-spark-base/src/test/java/com/facebook/presto/spark/PrestoSparkQueryRunner.java +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/PrestoSparkQueryRunner.java @@ -202,6 +202,8 @@ public PrestoSparkQueryRunner(String defaultCatalog, Map additio ImmutableMap.Builder configProperties = ImmutableMap.builder(); configProperties.put("presto.version", "testversion"); configProperties.put("query.hash-partition-count", Integer.toString(NODE_COUNT * 2)); + configProperties.put("task.writer-count", Integer.toString(2)); + configProperties.put("task.partitioned-writer-count", Integer.toString(4)); configProperties.putAll(additionalConfigProperties); PrestoSparkInjectorFactory injectorFactory = new PrestoSparkInjectorFactory( diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkQueryRunner.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkQueryRunner.java index abab75ac7c0b..651a2000a083 100644 --- a/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkQueryRunner.java +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/TestPrestoSparkQueryRunner.java @@ -21,12 +21,17 @@ import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; +import java.util.List; + import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; +import static com.facebook.presto.SystemSessionProperties.PARTIAL_MERGE_PUSHDOWN_STRATEGY; import static com.facebook.presto.spark.PrestoSparkQueryRunner.createHivePrestoSparkQueryRunner; import static com.facebook.presto.spark.PrestoSparkSessionProperties.STORAGE_BASED_BROADCAST_JOIN_ENABLED; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.PartialMergePushdownStrategy.PUSH_THROUGH_LOW_MEMORY_OPERATORS; import static com.facebook.presto.testing.assertions.Assert.assertEquals; import static com.facebook.presto.tests.QueryAssertions.assertEqualsIgnoreOrder; import static io.airlift.tpch.TpchTable.NATION; +import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; public class TestPrestoSparkQueryRunner @@ -91,31 +96,157 @@ public void testTableWrite() } @Test - public void testBucketedTableWrite() + public void testBucketedTableWriteSimple() + { + // simple write from a bucketed table to a bucketed table + // same bucket count + testBucketedTableWriteSimple(getSession(), 8, 8); + for (Session testSession : getTestCompatibleBucketCountSessions()) { + // incompatible bucket count + testBucketedTableWriteSimple(testSession, 3, 13); + testBucketedTableWriteSimple(testSession, 13, 7); + // compatible bucket count + testBucketedTableWriteSimple(testSession, 4, 8); + testBucketedTableWriteSimple(testSession, 8, 4); + } + } + + private void testBucketedTableWriteSimple(Session session, int inputBucketCount, int outputBucketCount) { - // create from bucketed table assertUpdate( - "CREATE TABLE hive.hive_test.hive_orders_bucketed_1 WITH (bucketed_by=array['orderkey'], bucket_count=11) AS " + + session, + format("CREATE TABLE hive.hive_test.test_hive_orders_bucketed_simple_input WITH (bucketed_by=array['orderkey'], bucket_count=%s) AS " + "SELECT orderkey, custkey, orderstatus, totalprice, orderdate, orderpriority, clerk, shippriority, comment " + - "FROM orders_bucketed", + "FROM orders_bucketed", inputBucketCount), 15000); assertQuery( + session, "SELECT count(*) " + - "FROM hive.hive_test.hive_orders_bucketed_1 " + - "WHERE \"$bucket\" = 1", - "SELECT 1365"); + "FROM hive.hive_test.test_hive_orders_bucketed_simple_input " + + "WHERE \"$bucket\" = 0", + format("SELECT count(*) FROM orders WHERE orderkey %% %s = 0", inputBucketCount)); - // create from non bucketed table assertUpdate( - "CREATE TABLE hive.hive_test.hive_orders_bucketed_2 WITH (bucketed_by=array['orderkey'], bucket_count=11) AS " + + session, + format("CREATE TABLE hive.hive_test.test_hive_orders_bucketed_simple_output WITH (bucketed_by=array['orderkey'], bucket_count=%s) AS " + "SELECT orderkey, custkey, orderstatus, totalprice, orderdate, orderpriority, clerk, shippriority, comment " + - "FROM orders", + "FROM hive.hive_test.test_hive_orders_bucketed_simple_input", outputBucketCount), + 15000); + assertQuery( + session, + "SELECT count(*) " + + "FROM hive.hive_test.test_hive_orders_bucketed_simple_output " + + "WHERE \"$bucket\" = 0", + format("SELECT count(*) FROM orders WHERE orderkey %% %s = 0", outputBucketCount)); + + dropTable("hive_test", "test_hive_orders_bucketed_simple_input"); + dropTable("hive_test", "test_hive_orders_bucketed_simple_output"); + } + + @Test + public void testBucketedTableWriteAggregation() + { + // aggregate on a bucket key and write to a bucketed table + // same bucket count + testBucketedTableWriteAggregation(getSession(), 8, 8); + for (Session testSession : getTestCompatibleBucketCountSessions()) { + // incompatible bucket count + testBucketedTableWriteAggregation(testSession, 7, 13); + testBucketedTableWriteAggregation(testSession, 13, 7); + // compatible bucket count + testBucketedTableWriteAggregation(testSession, 4, 8); + testBucketedTableWriteAggregation(testSession, 8, 4); + } + } + + private void testBucketedTableWriteAggregation(Session session, int inputBucketCount, int outputBucketCount) + { + assertUpdate( + session, + format("CREATE TABLE hive.hive_test.test_hive_orders_bucketed_aggregation_input WITH (bucketed_by=array['orderkey'], bucket_count=%s) AS " + + "SELECT orderkey, custkey, orderstatus, totalprice, orderdate, orderpriority, clerk, shippriority, comment " + + "FROM orders_bucketed", inputBucketCount), + 15000); + + assertUpdate( + session, + format("CREATE TABLE hive.hive_test.test_hive_orders_bucketed_aggregation_output WITH (bucketed_by=array['orderkey'], bucket_count=%s) AS " + + "SELECT orderkey, sum(totalprice) totalprice " + + "FROM hive.hive_test.test_hive_orders_bucketed_aggregation_input " + + "GROUP BY orderkey", outputBucketCount), + 15000); + assertQuery( + session, + "SELECT count(*) " + + "FROM hive.hive_test.test_hive_orders_bucketed_aggregation_output " + + "WHERE \"$bucket\" = 0", + format("SELECT count(*) FROM orders WHERE orderkey %% %s = 0", outputBucketCount)); + + dropTable("hive_test", "test_hive_orders_bucketed_aggregation_input"); + dropTable("hive_test", "test_hive_orders_bucketed_aggregation_output"); + } + + @Test + public void testBucketedTableWriteJoin() + { + // join on a bucket key and write to a bucketed table + // same bucket count + testBucketedTableWriteJoin(getSession(), 8, 8, 8); + for (Session testSession : getTestCompatibleBucketCountSessions()) { + // incompatible bucket count + testBucketedTableWriteJoin(testSession, 7, 13, 17); + testBucketedTableWriteJoin(testSession, 13, 7, 17); + testBucketedTableWriteJoin(testSession, 7, 7, 17); + // compatible bucket count + testBucketedTableWriteJoin(testSession, 4, 4, 8); + testBucketedTableWriteJoin(testSession, 8, 8, 4); + testBucketedTableWriteJoin(testSession, 4, 8, 8); + testBucketedTableWriteJoin(testSession, 8, 4, 8); + testBucketedTableWriteJoin(testSession, 4, 8, 4); + testBucketedTableWriteJoin(testSession, 8, 4, 4); + } + } + + private void testBucketedTableWriteJoin(Session session, int firstInputBucketCount, int secondInputBucketCount, int outputBucketCount) + { + assertUpdate( + session, + format("CREATE TABLE hive.hive_test.test_hive_orders_bucketed_join_input_1 WITH (bucketed_by=array['orderkey'], bucket_count=%s) AS " + + "SELECT orderkey, custkey, orderstatus, totalprice, orderdate, orderpriority, clerk, shippriority, comment " + + "FROM orders_bucketed", firstInputBucketCount), + 15000); + + assertUpdate( + session, + format("CREATE TABLE hive.hive_test.test_hive_orders_bucketed_join_input_2 WITH (bucketed_by=array['orderkey'], bucket_count=%s) AS " + + "SELECT orderkey, custkey, orderstatus, totalprice, orderdate, orderpriority, clerk, shippriority, comment " + + "FROM orders_bucketed", secondInputBucketCount), + 15000); + + assertUpdate( + session, + format("CREATE TABLE hive.hive_test.test_hive_orders_bucketed_join_output WITH (bucketed_by=array['orderkey'], bucket_count=%s) AS " + + "SELECT first.orderkey, second.totalprice " + + "FROM hive.hive_test.test_hive_orders_bucketed_join_input_1 first " + + "INNER JOIN hive.hive_test.test_hive_orders_bucketed_join_input_2 second " + + "ON first.orderkey = second.orderkey ", + outputBucketCount), 15000); assertQuery( + session, "SELECT count(*) " + - "FROM hive.hive_test.hive_orders_bucketed_2 " + - "WHERE \"$bucket\" = 1", - "SELECT 1365"); + "FROM hive.hive_test.test_hive_orders_bucketed_join_output " + + "WHERE \"$bucket\" = 0", + format("SELECT count(*) FROM orders WHERE orderkey %% %s = 0", outputBucketCount)); + + dropTable("hive_test", "test_hive_orders_bucketed_join_input_1"); + dropTable("hive_test", "test_hive_orders_bucketed_join_input_2"); + dropTable("hive_test", "test_hive_orders_bucketed_join_output"); + } + + private void dropTable(String schema, String table) + { + ((PrestoSparkQueryRunner) getQueryRunner()).getMetastore().dropTable(schema, table, true); } @Test @@ -161,6 +292,66 @@ public void testBucketedJoin() "JOIN orders_bucketed o " + "ON l.orderkey = o.orderkey " + "WHERE l.orderkey % 223 = 42 AND l.linenumber = 4 and o.orderstatus = 'O'"); + + // different number of buckets + assertUpdate("create table if not exists hive.hive_test.bucketed_nation_for_join_4 " + + "WITH (bucket_count = 4, bucketed_by = ARRAY['nationkey']) as select * from nation", + 25); + assertUpdate("create table if not exists hive.hive_test.bucketed_nation_for_join_8 " + + "WITH (bucket_count = 8, bucketed_by = ARRAY['nationkey']) as select * from nation", + 25); + + for (Session session : getTestCompatibleBucketCountSessions()) { + String expected = "SELECT * FROM nation first " + + "INNER JOIN nation second " + + "ON first.nationkey = second.nationkey"; + assertQuery( + session, + "SELECT * FROM hive.hive_test.bucketed_nation_for_join_4 first " + + "INNER JOIN hive.hive_test.bucketed_nation_for_join_8 second " + + "ON first.nationkey = second.nationkey", + expected); + assertQuery( + session, + "SELECT * FROM hive.hive_test.bucketed_nation_for_join_8 first " + + "INNER JOIN hive.hive_test.bucketed_nation_for_join_4 second " + + "ON first.nationkey = second.nationkey", + expected); + + expected = "SELECT * FROM nation first " + + "INNER JOIN nation second " + + "ON first.nationkey = second.nationkey " + + "INNER JOIN nation third " + + "ON second.nationkey = third.nationkey"; + + assertQuery( + session, + "SELECT * FROM hive.hive_test.bucketed_nation_for_join_4 first " + + "INNER JOIN hive.hive_test.bucketed_nation_for_join_8 second " + + "ON first.nationkey = second.nationkey " + + "INNER JOIN nation third " + + "ON second.nationkey = third.nationkey", + expected); + assertQuery( + session, + "SELECT * FROM hive.hive_test.bucketed_nation_for_join_8 first " + + "INNER JOIN hive.hive_test.bucketed_nation_for_join_4 second " + + "ON first.nationkey = second.nationkey " + + "INNER JOIN nation third " + + "ON second.nationkey = third.nationkey", + expected); + } + } + + private List getTestCompatibleBucketCountSessions() + { + return ImmutableList.of( + Session.builder(getSession()) + .setSystemProperty(PARTIAL_MERGE_PUSHDOWN_STRATEGY, PUSH_THROUGH_LOW_MEMORY_OPERATORS.name()) + .build(), + Session.builder(getSession()) + .setCatalogSessionProperty("hive", "optimize_mismatched_bucket_count", "true") + .build()); } @Test diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java index 2e2972f88b50..fc99c4bb3faa 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java @@ -27,6 +27,7 @@ import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.PartitioningProviderManager; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.PlanFragmenter; import com.facebook.presto.sql.planner.PlanOptimizers; @@ -426,7 +427,8 @@ private QueryExplainer getQueryExplainer() costCalculator, new CostCalculatorWithEstimatedExchanges(costCalculator, taskCountEstimator), new CostComparator(featuresConfig), - taskCountEstimator).getPlanningTimeOptimizers(); + taskCountEstimator, + new PartitioningProviderManager()).getPlanningTimeOptimizers(); return new QueryExplainer( optimizers, new PlanFragmenter(metadata, queryRunner.getNodePartitioningManager(), new QueryManagerConfig(), sqlParser, new FeaturesConfig()),