Refactor PrestoSparkRddFactory
Move partitioning assignment to PrestoSparkQueryExecutionFactory

This will allow the Spark partitioner to simply follow the number of
partitions set in bucketToPartition, instead of running the partition
count assignment logic twice.
arhimondr committed Apr 15, 2021
1 parent 75ec89c commit d35dbfe
Showing 3 changed files with 116 additions and 80 deletions.
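In short: the planner now assigns a bucketToPartition mapping to every fragment's output partitioning scheme up front (configureOutputPartitioning below), and the Spark partitioner is then sized directly from that mapping (createPartitioner below). A minimal sketch of the sizing rule; the array value is illustrative only:

    int[] bucketToPartition = {0, 1, 2, 3}; // identity mapping, e.g. FIXED_HASH_DISTRIBUTION with hash_partition_count = 4
    int numberOfPartitions = IntStream.of(bucketToPartition).max().getAsInt() + 1; // -> 4 partitions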
@@ -58,8 +58,10 @@
import com.facebook.presto.spark.classloader_interface.PrestoSparkConfInitializer;
import com.facebook.presto.spark.classloader_interface.PrestoSparkExecutionException;
import com.facebook.presto.spark.classloader_interface.PrestoSparkMutableRow;
import com.facebook.presto.spark.classloader_interface.PrestoSparkPartitioner;
import com.facebook.presto.spark.classloader_interface.PrestoSparkSerializedPage;
import com.facebook.presto.spark.classloader_interface.PrestoSparkSession;
import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleSerializer;
import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleStats;
import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleStats.Operation;
import com.facebook.presto.spark.classloader_interface.PrestoSparkStorageHandle;
@@ -81,13 +83,17 @@
import com.facebook.presto.spi.QueryId;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.connector.ConnectorCapabilities;
import com.facebook.presto.spi.connector.ConnectorNodePartitioningProvider;
import com.facebook.presto.spi.memory.MemoryPoolId;
import com.facebook.presto.spi.page.PagesSerde;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spi.resourceGroups.ResourceGroupId;
import com.facebook.presto.spi.storage.StorageCapabilities;
import com.facebook.presto.spi.storage.TempDataOperationContext;
import com.facebook.presto.spi.storage.TempStorage;
import com.facebook.presto.sql.planner.PartitioningHandle;
import com.facebook.presto.sql.planner.PartitioningProviderManager;
import com.facebook.presto.sql.planner.PartitioningScheme;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.SubPlan;
import com.facebook.presto.sql.planner.plan.PlanFragmentId;
@@ -104,12 +110,14 @@
import com.google.common.io.BaseEncoding;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import org.apache.spark.Partitioner;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkException;
import org.apache.spark.api.java.JavaFutureAction;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.ShuffledRDD;
import org.apache.spark.util.CollectionAccumulator;
import org.joda.time.DateTime;
import scala.Option;
@@ -132,8 +140,10 @@
import java.util.TreeMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.IntStream;

import static com.facebook.airlift.concurrent.MoreFutures.getFutureValue;
import static com.facebook.presto.SystemSessionProperties.getHashPartitionCount;
import static com.facebook.presto.SystemSessionProperties.getQueryMaxBroadcastMemory;
import static com.facebook.presto.SystemSessionProperties.getQueryMaxExecutionTime;
import static com.facebook.presto.SystemSessionProperties.getQueryMaxRunTime;
@@ -156,6 +166,7 @@
import static com.facebook.presto.spark.SparkErrorCode.UNSUPPORTED_STORAGE_TYPE;
import static com.facebook.presto.spark.classloader_interface.ScalaUtils.collectScalaIterator;
import static com.facebook.presto.spark.classloader_interface.ScalaUtils.emptyScalaIterator;
import static com.facebook.presto.spark.util.PrestoSparkUtils.classTag;
import static com.facebook.presto.spark.util.PrestoSparkUtils.computeNextTimeout;
import static com.facebook.presto.spark.util.PrestoSparkUtils.createPagesSerde;
import static com.facebook.presto.spark.util.PrestoSparkUtils.deserializeZstdCompressed;
@@ -167,7 +178,10 @@
import static com.facebook.presto.spi.connector.ConnectorCapabilities.SUPPORTS_PAGE_SINK_COMMIT;
import static com.facebook.presto.spi.storage.StorageCapabilities.REMOTELY_ACCESSIBLE;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.textDistributedPlan;
import static com.facebook.presto.util.Failures.toFailure;
import static com.google.common.base.MoreObjects.firstNonNull;
@@ -215,6 +229,7 @@ public class PrestoSparkQueryExecutionFactory
    private final PrestoSparkTaskExecutorFactory prestoSparkTaskExecutorFactory;
    private final SessionPropertyDefaults sessionPropertyDefaults;
    private final WarningCollectorFactory warningCollectorFactory;
    private final PartitioningProviderManager partitioningProviderManager;

    private final Set<PrestoSparkCredentialsProvider> credentialsProviders;
    private final Set<PrestoSparkAuthenticatorProvider> authenticatorProviders;
@@ -245,6 +260,7 @@ public PrestoSparkQueryExecutionFactory(
            PrestoSparkTaskExecutorFactory prestoSparkTaskExecutorFactory,
            SessionPropertyDefaults sessionPropertyDefaults,
            WarningCollectorFactory warningCollectorFactory,
            PartitioningProviderManager partitioningProviderManager,
            Set<PrestoSparkCredentialsProvider> credentialsProviders,
            Set<PrestoSparkAuthenticatorProvider> authenticatorProviders,
            TempStorageManager tempStorageManager,
@@ -272,6 +288,7 @@ public PrestoSparkQueryExecutionFactory(
        this.prestoSparkTaskExecutorFactory = requireNonNull(prestoSparkTaskExecutorFactory, "prestoSparkTaskExecutorFactory is null");
        this.sessionPropertyDefaults = requireNonNull(sessionPropertyDefaults, "sessionPropertyDefaults is null");
        this.warningCollectorFactory = requireNonNull(warningCollectorFactory, "warningCollectorFactory is null");
        this.partitioningProviderManager = requireNonNull(partitioningProviderManager, "partitioningProviderManager is null");
        this.credentialsProviders = ImmutableSet.copyOf(requireNonNull(credentialsProviders, "credentialsProviders is null"));
        this.authenticatorProviders = ImmutableSet.copyOf(requireNonNull(authenticatorProviders, "authenticatorProviders is null"));
        this.tempStorageManager = requireNonNull(tempStorageManager, "tempStorageManager is null");
@@ -381,6 +398,7 @@ public IPrestoSparkQueryExecution create(
            planAndMore = queryPlanner.createQueryPlan(session, preparedQuery, warningCollector);
            SubPlan fragmentedPlan = planFragmenter.fragmentQueryPlan(session, planAndMore.getPlan(), warningCollector);
            log.info(textDistributedPlan(fragmentedPlan, metadata.getFunctionAndTypeManager(), session, true));
            fragmentedPlan = configureOutputPartitioning(session, fragmentedPlan);
            TableWriteInfo tableWriteInfo = getTableWriteInfo(session, fragmentedPlan);

            JavaSparkContext javaSparkContext = new JavaSparkContext(sparkContext);
@@ -472,6 +490,61 @@ public IPrestoSparkQueryExecution create(
        }
    }

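    // Recursively rewrites the fragmented plan, assigning an identity
    // bucket-to-partition mapping to every fragment whose output partitioning
    // scheme does not have one yet and for which one can be derived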
    private SubPlan configureOutputPartitioning(Session session, SubPlan subPlan)
    {
        PlanFragment fragment = subPlan.getFragment();
        if (!fragment.getPartitioningScheme().getBucketToPartition().isPresent()) {
            PartitioningHandle partitioningHandle = fragment.getPartitioningScheme().getPartitioning().getHandle();
            Optional<int[]> bucketToPartition = getBucketToPartition(session, partitioningHandle);
            if (bucketToPartition.isPresent()) {
                fragment = fragment.withBucketToPartition(bucketToPartition);
            }
        }
        return new SubPlan(
                fragment,
                subPlan.getChildren().stream()
                        .map(child -> configureOutputPartitioning(session, child))
                        .collect(toImmutableList()));
    }

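    // Derives the identity bucket-to-partition mapping: hash_partition_count
    // entries for the system hash and arbitrary distributions, the connector's
    // bucket count for connector-provided partitionings, empty otherwise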
    private Optional<int[]> getBucketToPartition(Session session, PartitioningHandle partitioningHandle)
    {
        if (partitioningHandle.equals(FIXED_HASH_DISTRIBUTION)) {
            int hashPartitionCount = getHashPartitionCount(session);
            return Optional.of(IntStream.range(0, hashPartitionCount).toArray());
        }
        // FIXED_ARBITRARY_DISTRIBUTION is used for UNION ALL
        // UNION ALL inputs can be either source inputs or shuffle inputs
        if (partitioningHandle.equals(FIXED_ARBITRARY_DISTRIBUTION)) {
            // given a modular hash function, the partition count can be arbitrary;
            // simply reuse hash_partition_count for convenience
            // (it could also be exposed as a separate session property if needed)
            int partitionCount = getHashPartitionCount(session);
            return Optional.of(IntStream.range(0, partitionCount).toArray());
        }
        if (partitioningHandle.getConnectorId().isPresent()) {
            int connectorPartitionCount = getPartitionCount(session, partitioningHandle);
            return Optional.of(IntStream.range(0, connectorPartitionCount).toArray());
        }
        return Optional.empty();
    }

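    // Asks the connector's ConnectorNodePartitioningProvider for the bucket
    // count of a connector-provided partitioning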
    private int getPartitionCount(Session session, PartitioningHandle partitioning)
    {
        ConnectorNodePartitioningProvider partitioningProvider = getPartitioningProvider(partitioning);
        return partitioningProvider.getBucketCount(
                partitioning.getTransactionHandle().orElse(null),
                session.toConnectorSession(),
                partitioning.getConnectorHandle());
    }

    private ConnectorNodePartitioningProvider getPartitioningProvider(PartitioningHandle partitioning)
    {
        ConnectorId connectorId = partitioning.getConnectorId()
                .orElseThrow(() -> new IllegalArgumentException("Unexpected partitioning: " + partitioning));
        return partitioningProviderManager.getPartitioningProvider(connectorId);
    }

    private TableWriteInfo getTableWriteInfo(Session session, SubPlan plan)
    {
        StreamingPlanSection streamingPlanSection = extractStreamingSections(plan);
@@ -1041,7 +1114,7 @@ private <T extends PrestoSparkTaskOutput> RddAndMore<T> createRdd(SubPlan subPlan
            }
            else {
                RddAndMore<PrestoSparkMutableRow> childRdd = createRdd(child, PrestoSparkMutableRow.class);
-               rddInputs.put(childFragment.getId(), childRdd.getRdd());
+               rddInputs.put(childFragment.getId(), partitionBy(childRdd.getRdd(), child.getFragment().getPartitioningScheme()));
                broadcastDependencies.addAll(childRdd.getBroadcastDependencies());
            }
        }
@@ -1059,6 +1132,40 @@ private <T extends PrestoSparkTaskOutput> RddAndMore<T> createRdd(SubPlan subPlan
        return new RddAndMore<>(rdd, broadcastDependencies.build());
    }

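    // Shuffles a child fragment's output RDD according to the fragment's output
    // partitioning scheme, with the Presto-specific shuffle serializer swapped in
    // on the underlying ShuffledRDD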
    private static JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow> partitionBy(
            JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow> rdd,
            PartitioningScheme partitioningScheme)
    {
        Partitioner partitioner = createPartitioner(partitioningScheme);
        JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow> javaPairRdd = rdd.partitionBy(partitioner);
        ShuffledRDD<MutablePartitionId, PrestoSparkMutableRow, PrestoSparkMutableRow> shuffledRdd = (ShuffledRDD<MutablePartitionId, PrestoSparkMutableRow, PrestoSparkMutableRow>) javaPairRdd.rdd();
        shuffledRdd.setSerializer(new PrestoSparkShuffleSerializer());
        return JavaPairRDD.fromRDD(
                shuffledRdd,
                classTag(MutablePartitionId.class),
                classTag(PrestoSparkMutableRow.class));
    }

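    // Sizes the Spark partitioner directly from the precomputed bucketToPartition
    // mapping (largest partition id + 1) instead of re-running the partition
    // count assignment logic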
    private static Partitioner createPartitioner(PartitioningScheme partitioningScheme)
    {
        PartitioningHandle partitioning = partitioningScheme.getPartitioning().getHandle();
        if (partitioning.equals(SINGLE_DISTRIBUTION)) {
            return new PrestoSparkPartitioner(1);
        }
        if (partitioning.equals(FIXED_HASH_DISTRIBUTION)
                || partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION)
                || partitioning.getConnectorId().isPresent()) {
            int[] bucketToPartition = partitioningScheme.getBucketToPartition().orElseThrow(
                    () -> new IllegalArgumentException("bucketToPartition is expected to be assigned at this point"));
            checkArgument(bucketToPartition.length > 0, "bucketToPartition is expected to be non-empty");
            int numberOfPartitions = IntStream.of(bucketToPartition)
                    .max()
                    .getAsInt() + 1;
            return new PrestoSparkPartitioner(numberOfPartitions);
        }
        throw new IllegalArgumentException("Unexpected partitioning: " + partitioning);
    }

    private void validateStorageCapabilities(TempStorage tempStorage)
    {
        boolean isLocalMode = isLocalMaster(sparkContext.getConf());

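For reference, here is a self-contained sketch of the partitionBy-then-setSerializer idiom used above, written against stock Spark APIs only. FixedPartitioner is a hypothetical stand-in for PrestoSparkPartitioner, and Spark's built-in KryoSerializer stands in for PrestoSparkShuffleSerializer (the real classes live in presto-spark-classloader-interface and are not reproduced here):

import org.apache.spark.Partitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.rdd.ShuffledRDD;
import org.apache.spark.serializer.KryoSerializer;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;

public class PartitionByExample
{
    // fixed-size partitioner: routes each integer key to a stable partition
    static class FixedPartitioner
            extends Partitioner
    {
        private final int numPartitions;

        FixedPartitioner(int numPartitions)
        {
            this.numPartitions = numPartitions;
        }

        @Override
        public int numPartitions()
        {
            return numPartitions;
        }

        @Override
        public int getPartition(Object key)
        {
            // floorMod keeps the result non-negative for negative keys
            return Math.floorMod((Integer) key, numPartitions);
        }
    }

    static JavaPairRDD<Integer, String> partitionBy(JavaPairRDD<Integer, String> rdd, int numberOfPartitions)
    {
        // shuffle the rows to their target partitions
        JavaPairRDD<Integer, String> partitioned = rdd.partitionBy(new FixedPartitioner(numberOfPartitions));
        // partitionBy produces a ShuffledRDD under the hood; override its shuffle
        // serializer, exactly as the commit does with PrestoSparkShuffleSerializer
        ShuffledRDD<Integer, String, String> shuffled = (ShuffledRDD<Integer, String, String>) partitioned.rdd();
        shuffled.setSerializer(new KryoSerializer(new SparkConf()));
        // re-wrap the Scala RDD as a JavaPairRDD
        ClassTag<Integer> keyTag = ClassTag$.MODULE$.apply(Integer.class);
        ClassTag<String> valueTag = ClassTag$.MODULE$.apply(String.class);
        return JavaPairRDD.fromRDD(shuffled, keyTag, valueTag);
    }
}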