diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala index 890b615b18c5f..08c61381c5780 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala @@ -59,7 +59,7 @@ class AvroRowReaderSuite val df = spark.read.format("avro").load(dir.getCanonicalPath) val fileScan = df.queryExecution.executedPlan collectFirst { - case BatchScanExec(_, f: AvroScan, _) => f + case BatchScanExec(_, f: AvroScan, _, _) => f } val filePath = fileScan.get.fileIndex.inputFiles(0) val fileSize = new File(new URI(filePath)).length diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index a70fbc0d833e8..e93c1c09c9fc2 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -2335,7 +2335,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper { }) val fileScan = df.queryExecution.executedPlan collectFirst { - case BatchScanExec(_, f: AvroScan, _) => f + case BatchScanExec(_, f: AvroScan, _, _) => f } assert(fileScan.nonEmpty) assert(fileScan.get.partitionFilters.nonEmpty) @@ -2368,7 +2368,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper { assert(filterCondition.isDefined) val fileScan = df.queryExecution.executedPlan collectFirst { - case BatchScanExec(_, f: AvroScan, _) => f + case BatchScanExec(_, f: AvroScan, _, _) => f } assert(fileScan.nonEmpty) assert(fileScan.get.partitionFilters.isEmpty) @@ -2449,7 +2449,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper { .where("value = 'a'") val fileScan = df.queryExecution.executedPlan collectFirst { - case BatchScanExec(_, f: AvroScan, _) => f + case BatchScanExec(_, f: AvroScan, _, _) => f } assert(fileScan.nonEmpty) if (filtersPushdown) { diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index ebd5b844cbc9b..9f93fbf96d2bd 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -372,7 +372,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu private def checkAggregatePushed(df: DataFrame, funcName: String): Unit = { df.queryExecution.optimizedPlan.collect { - case DataSourceV2ScanRelation(_, scan, _) => + case DataSourceV2ScanRelation(_, scan, _, _) => assert(scan.isInstanceOf[V1ScanWrapper]) val wrapper = scan.asInstanceOf[V1ScanWrapper] assert(wrapper.pushedDownOperators.aggregation.isDefined) diff --git a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala index 4939b600dbfbd..8b543f1642a05 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala @@ -36,4 +36,13 @@ private[spark] object Utils { } ordering.leastOf(input.asJava, num).iterator.asScala } + + /** + * Only returns `Some` iff ALL elements in `input` are defined. In this case, it is + * equivalent to `Some(input.flatten)`. + * + * Otherwise, returns `None`. + */ + def sequenceToOption[T](input: Seq[Option[T]]): Option[Seq[T]] = + if (input.forall(_.isDefined)) Some(input.flatten) else None } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 49969cbe3ade6..b7376286826c2 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -64,7 +64,13 @@ object MimaExcludes { // [SPARK-37600][BUILD] Upgrade to Hadoop 3.3.2 ProblemFilters.exclude[MissingClassProblem]("org.apache.hadoop.shaded.net.jpountz.lz4.LZ4Compressor"), ProblemFilters.exclude[MissingClassProblem]("org.apache.hadoop.shaded.net.jpountz.lz4.LZ4Factory"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.hadoop.shaded.net.jpountz.lz4.LZ4SafeDecompressor") + ProblemFilters.exclude[MissingClassProblem]("org.apache.hadoop.shaded.net.jpountz.lz4.LZ4SafeDecompressor"), + + // [SPARK-37377][SQL] Initial implementation of Storage-Partitioned Join + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.connector.read.partitioning.ClusteredDistribution"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.connector.read.partitioning.Distribution"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.connector.read.partitioning.Partitioning.*"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.connector.read.partitioning.Partitioning.*") ) // Exclude rules for 3.2.x from 3.1.1 diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/partitioning/Distribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/partitioning/Distribution.java deleted file mode 100644 index a5911a820ac10..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/partitioning/Distribution.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.read.partitioning; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.read.PartitionReader; - -/** - * An interface to represent data distribution requirement, which specifies how the records should - * be distributed among the data partitions (one {@link PartitionReader} outputs data for one - * partition). - *

- * Note that this interface has nothing to do with the data ordering inside one - * partition(the output records of a single {@link PartitionReader}). - *

- * The instance of this interface is created and provided by Spark, then consumed by - * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to - * implement this interface, but need to catch as more concrete implementations of this interface - * as possible in {@link Partitioning#satisfy(Distribution)}. - *

- * Concrete implementations until now: - *

- * - * @since 3.0.0 - */ -@Evolving -public interface Distribution {} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/partitioning/KeyGroupedPartitioning.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/partitioning/KeyGroupedPartitioning.java new file mode 100644 index 0000000000000..552d92ad0e8b8 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/partitioning/KeyGroupedPartitioning.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read.partitioning; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Expression; + +/** + * Represents a partitioning where rows are split across partitions based on the + * partition transform expressions returned by {@link KeyGroupedPartitioning#keys}. + *

+ * Note: Data source implementations should make sure for a single partition, all of its rows + * must be evaluated to the same partition value after being applied by + * {@link KeyGroupedPartitioning#keys} expressions. Different partitions can share the same + * partition value: Spark will group these into a single logical partition during planning phase. + * + * @since 3.3.0 + */ +@Evolving +public class KeyGroupedPartitioning implements Partitioning { + private final Expression[] keys; + private final int numPartitions; + + public KeyGroupedPartitioning(Expression[] keys, int numPartitions) { + this.keys = keys; + this.numPartitions = numPartitions; + } + + /** + * Returns the partition transform expressions for this partitioning. + */ + public Expression[] keys() { + return keys; + } + + @Override + public int numPartitions() { + return numPartitions; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/partitioning/Partitioning.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/partitioning/Partitioning.java index 7befab4ec5365..09f05d84e7ffb 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/partitioning/Partitioning.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/partitioning/Partitioning.java @@ -18,33 +18,25 @@ package org.apache.spark.sql.connector.read.partitioning; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.SupportsReportPartitioning; /** * An interface to represent the output data partitioning for a data source, which is returned by - * {@link SupportsReportPartitioning#outputPartitioning()}. Note that this should work - * like a snapshot. Once created, it should be deterministic and always report the same number of - * partitions and the same "satisfy" result for a certain distribution. + * {@link SupportsReportPartitioning#outputPartitioning()}. + *

+ * Note: implementors should NOT directly implement this interface. Instead, they should + * use one of the following subclasses: + *

* * @since 3.0.0 */ @Evolving public interface Partitioning { - /** - * Returns the number of partitions(i.e., {@link InputPartition}s) the data source outputs. + * Returns the number of partitions that the data is split across. */ int numPartitions(); - - /** - * Returns true if this partitioning can satisfy the given distribution, which means Spark does - * not need to shuffle the output data of this data source for some certain operations. - *

- * Note that, Spark may add new concrete implementations of {@link Distribution} in new releases. - * This method should be aware of it and always return false for unrecognized distributions. It's - * recommended to check every Spark new release and support new distributions if possible, to - * avoid shuffle at Spark side for more cases. - */ - boolean satisfy(Distribution distribution); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/partitioning/ClusteredDistribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/partitioning/UnknownPartitioning.java similarity index 61% rename from sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/partitioning/ClusteredDistribution.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/partitioning/UnknownPartitioning.java index ed0354484d7be..a2ae360d9a51f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/partitioning/ClusteredDistribution.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/partitioning/UnknownPartitioning.java @@ -18,24 +18,22 @@ package org.apache.spark.sql.connector.read.partitioning; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.read.PartitionReader; /** - * A concrete implementation of {@link Distribution}. Represents a distribution where records that - * share the same values for the {@link #clusteredColumns} will be produced by the same - * {@link PartitionReader}. + * Represents a partitioning where rows are split across partitions in an unknown pattern. * - * @since 3.0.0 + * @since 3.3.0 */ @Evolving -public class ClusteredDistribution implements Distribution { +public class UnknownPartitioning implements Partitioning { + private final int numPartitions; - /** - * The names of the clustered columns. Note that they are order insensitive. - */ - public final String[] clusteredColumns; + public UnknownPartitioning(int numPartitions) { + this.numPartitions = numPartitions; + } - public ClusteredDistribution(String[] clusteredColumns) { - this.clusteredColumns = clusteredColumns; + @Override + public int numPartitions() { + return numPartitions; } } diff --git a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/util/InternalRowSet.scala b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/util/InternalRowSet.scala new file mode 100644 index 0000000000000..9e8ec042694d0 --- /dev/null +++ b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/util/InternalRowSet.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Murmur3HashFunction, RowOrdering} +import org.apache.spark.sql.types.{DataType, StructField, StructType} + +/** + * A mutable Set with [[InternalRow]] as its element type. It uses Spark's internal murmur hash to + * compute hash code from an row, and uses [[RowOrdering]] to perform equality checks. + * + * @param dataTypes the data types for the row keys this set holds + */ +class InternalRowSet(val dataTypes: Seq[DataType]) extends mutable.Set[InternalRow] { + private val baseSet = new mutable.HashSet[InternalRowContainer] + + private val structType = StructType(dataTypes.map(t => StructField("f", t))) + private val ordering = RowOrdering.createNaturalAscendingOrdering(dataTypes) + + override def contains(row: InternalRow): Boolean = + baseSet.contains(new InternalRowContainer(row)) + + private class InternalRowContainer(val row: InternalRow) { + override def hashCode(): Int = Murmur3HashFunction.hash(row, structType, 42L).toInt + + override def equals(other: Any): Boolean = other match { + case r: InternalRowContainer => ordering.compare(row, r.row) == 0 + case r => this == r + } + } + + override def +=(row: InternalRow): InternalRowSet.this.type = { + val rowKey = new InternalRowContainer(row) + baseSet += rowKey + this + } + + override def -=(row: InternalRow): InternalRowSet.this.type = { + val rowKey = new InternalRowContainer(row) + baseSet -= rowKey + this + } + + override def iterator: Iterator[InternalRow] = { + baseSet.iterator.map(_.row) + } +} diff --git a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/util/InternalRowSet.scala b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/util/InternalRowSet.scala new file mode 100644 index 0000000000000..66090fdf1872f --- /dev/null +++ b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/util/InternalRowSet.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Murmur3HashFunction, RowOrdering} +import org.apache.spark.sql.types.{DataType, StructField, StructType} + +/** + * A mutable Set with [[InternalRow]] as its element type. It uses Spark's internal murmur hash to + * compute hash code from an row, and uses [[RowOrdering]] to perform equality checks. + * + * @param dataTypes the data types for the row keys this set holds + */ +class InternalRowSet(val dataTypes: Seq[DataType]) extends mutable.Set[InternalRow] { + private val baseSet = new mutable.HashSet[InternalRowContainer] + + private val structType = StructType(dataTypes.map(t => StructField("f", t))) + private val ordering = RowOrdering.createNaturalAscendingOrdering(dataTypes) + + override def contains(row: InternalRow): Boolean = + baseSet.contains(new InternalRowContainer(row)) + + private class InternalRowContainer(val row: InternalRow) { + override def hashCode(): Int = Murmur3HashFunction.hash(row, structType, 42L).toInt + + override def equals(other: Any): Boolean = other match { + case r: InternalRowContainer => ordering.compare(row, r.row) == 0 + case r => this == r + } + } + + override def addOne(row: InternalRow): InternalRowSet.this.type = { + val rowKey = new InternalRowContainer(row) + baseSet += rowKey + this + } + + override def subtractOne(row: InternalRow): InternalRowSet.this.type = { + val rowKey = new InternalRowContainer(row) + baseSet -= rowKey + this + } + + override def clear(): Unit = { + baseSet.clear() + } + + override def iterator: Iterator[InternalRow] = { + baseSet.iterator.map(_.row) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala new file mode 100644 index 0000000000000..8412de554b711 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.connector.catalog.functions.BoundFunction +import org.apache.spark.sql.types.DataType + +/** + * Represents a partition transform expression, for instance, `bucket`, `days`, `years`, etc. + * + * @param function the transform function itself. Spark will use it to decide whether two + * partition transform expressions are compatible. + * @param numBucketsOpt the number of buckets if the transform is `bucket`. Unset otherwise. + */ +case class TransformExpression( + function: BoundFunction, + children: Seq[Expression], + numBucketsOpt: Option[Int] = None) extends Expression with Unevaluable { + + override def nullable: Boolean = true + + /** + * Whether this [[TransformExpression]] has the same semantics as `other`. + * For instance, `bucket(32, c)` is equal to `bucket(32, d)`, but not to `bucket(16, d)` or + * `year(c)`. + * + * This will be used, for instance, by Spark to determine whether storage-partitioned join can + * be triggered, by comparing partition transforms from both sides of the join and checking + * whether they are compatible. + * + * @param other the transform expression to compare to + * @return true if this and `other` has the same semantics w.r.t to transform, false otherwise. + */ + def isSameFunction(other: TransformExpression): Boolean = other match { + case TransformExpression(otherFunction, _, otherNumBucketsOpt) => + function.canonicalName() == otherFunction.canonicalName() && + numBucketsOpt == otherNumBucketsOpt + case _ => + false + } + + override def dataType: DataType = function.resultType() + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy(children = newChildren) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala index 72d072ff1a4a4..596d5d8b565df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala @@ -17,16 +17,22 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, IdentityTransform, NamedReference, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortValue} +import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier} +import org.apache.spark.sql.connector.catalog.functions._ +import org.apache.spark.sql.connector.expressions.{BucketTransform, Expression => V2Expression, FieldReference, IdentityTransform, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform} import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.Utils.sequenceToOption /** * A utility class that converts public connector expressions into Catalyst expressions. */ -object V2ExpressionUtils extends SQLConfHelper { +object V2ExpressionUtils extends SQLConfHelper with Logging { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper def resolveRef[T <: NamedExpression](ref: NamedReference, plan: LogicalPlan): T = { @@ -44,20 +50,85 @@ object V2ExpressionUtils extends SQLConfHelper { refs.map(ref => resolveRef[T](ref, plan)) } - def toCatalyst(expr: V2Expression, query: LogicalPlan): Expression = { + /** + * Converts the array of input V2 [[V2SortOrder]] into their counterparts in catalyst. + */ + def toCatalystOrdering(ordering: Array[V2SortOrder], query: LogicalPlan): Seq[SortOrder] = { + sequenceToOption(ordering.map(toCatalyst(_, query))).asInstanceOf[Option[Seq[SortOrder]]] + .getOrElse(Seq.empty) + } + + def toCatalyst( + expr: V2Expression, + query: LogicalPlan, + funCatalogOpt: Option[FunctionCatalog] = None): Option[Expression] = { expr match { + case t: Transform => + toCatalystTransform(t, query, funCatalogOpt) case SortValue(child, direction, nullOrdering) => - val catalystChild = toCatalyst(child, query) - SortOrder(catalystChild, toCatalyst(direction), toCatalyst(nullOrdering), Seq.empty) - case IdentityTransform(ref) => - resolveRef[NamedExpression](ref, query) + toCatalyst(child, query, funCatalogOpt).map { catalystChild => + SortOrder(catalystChild, toCatalyst(direction), toCatalyst(nullOrdering), Seq.empty) + } case ref: FieldReference => - resolveRef[NamedExpression](ref, query) + Some(resolveRef[NamedExpression](ref, query)) case _ => throw new AnalysisException(s"$expr is not currently supported") } } + def toCatalystTransform( + trans: Transform, + query: LogicalPlan, + funCatalogOpt: Option[FunctionCatalog] = None): Option[Expression] = trans match { + case IdentityTransform(ref) => + Some(resolveRef[NamedExpression](ref, query)) + case BucketTransform(numBuckets, refs, sorted) + if sorted.isEmpty && refs.length == 1 && refs.forall(_.isInstanceOf[NamedReference]) => + val resolvedRefs = refs.map(r => resolveRef[NamedExpression](r, query)) + // Create a dummy reference for `numBuckets` here and use that, together with `refs`, to + // look up the V2 function. + val numBucketsRef = AttributeReference("numBuckets", IntegerType, nullable = false)() + funCatalogOpt.flatMap { catalog => + loadV2Function(catalog, "bucket", Seq(numBucketsRef) ++ resolvedRefs).map { bound => + TransformExpression(bound, resolvedRefs, Some(numBuckets)) + } + } + case NamedTransform(name, refs) + if refs.length == 1 && refs.forall(_.isInstanceOf[NamedReference]) => + val resolvedRefs = refs.map(_.asInstanceOf[NamedReference]).map { r => + resolveRef[NamedExpression](r, query) + } + funCatalogOpt.flatMap { catalog => + loadV2Function(catalog, name, resolvedRefs).map { bound => + TransformExpression(bound, resolvedRefs) + } + } + case _ => + throw new AnalysisException(s"Transform $trans is not currently supported") + } + + private def loadV2Function( + catalog: FunctionCatalog, + name: String, + args: Seq[Expression]): Option[BoundFunction] = { + val inputType = StructType(args.zipWithIndex.map { + case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable) + }) + try { + val unbound = catalog.loadFunction(Identifier.of(Array.empty, name)) + Some(unbound.bind(inputType)) + } catch { + case _: NoSuchFunctionException => + val parameterString = args.map(_.dataType.typeName).mkString("(", ", ", ")") + logWarning(s"V2 function $name with parameter types $parameterString is used in " + + "partition transforms, but its definition couldn't be found in the function catalog " + + "provided") + None + case _: UnsupportedOperationException => + None + } + } + private def toCatalyst(direction: V2SortDirection): SortDirection = direction match { case V2SortDirection.ASCENDING => Ascending case V2SortDirection.DESCENDING => Descending diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index e4ff14b17a20c..69eeab426ed01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans.physical import scala.collection.mutable +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, IntegerType} @@ -305,6 +306,63 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions = newChildren) } +/** + * Represents a partitioning where rows are split across partitions based on transforms defined + * by `expressions`. `partitionValuesOpt`, if defined, should contain value of partition key(s) in + * ascending order, after evaluated by the transforms in `expressions`, for each input partition. + * In addition, its length must be the same as the number of input partitions (and thus is a 1-1 + * mapping), and each row in `partitionValuesOpt` must be unique. + * + * For example, if `expressions` is `[years(ts_col)]`, then a valid value of `partitionValuesOpt` is + * `[0, 1, 2]`, which represents 3 input partitions with distinct partition values. All rows + * in each partition have the same value for column `ts_col` (which is of timestamp type), after + * being applied by the `years` transform. + * + * On the other hand, `[0, 0, 1]` is not a valid value for `partitionValuesOpt` since `0` is + * duplicated twice. + * + * @param expressions partition expressions for the partitioning. + * @param numPartitions the number of partitions + * @param partitionValuesOpt if set, the values for the cluster keys of the distribution, must be + * in ascending order. + */ +case class KeyGroupedPartitioning( + expressions: Seq[Expression], + numPartitions: Int, + partitionValuesOpt: Option[Seq[InternalRow]] = None) extends Partitioning { + + override def satisfies0(required: Distribution): Boolean = { + super.satisfies0(required) || { + required match { + case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => + if (requireAllClusterKeys) { + // Checks whether this partitioning is partitioned on exactly same clustering keys of + // `ClusteredDistribution`. + c.areAllClusterKeysMatched(expressions) + } else { + // We'll need to find leaf attributes from the partition expressions first. + val attributes = expressions.flatMap(_.collectLeaves()) + attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) + } + + case _ => + false + } + } + } + + override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = + KeyGroupedShuffleSpec(this, distribution) +} + +object KeyGroupedPartitioning { + def apply( + expressions: Seq[Expression], + partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = { + KeyGroupedPartitioning(expressions, partitionValues.size, Some(partitionValues)) + } +} + /** * Represents a partitioning where rows are split across partitions based on some total ordering of * the expressions specified in `ordering`. When data is partitioned in this manner, it guarantees: @@ -456,6 +514,8 @@ trait ShuffleSpec { * A true return value means that the data partitioning from this spec can be seen as * co-partitioned with the `other`, and therefore no shuffle is required when joining the two * sides. + * + * Note that Spark assumes this to be reflexive, symmetric and transitive. */ def isCompatibleWith(other: ShuffleSpec): Boolean @@ -574,6 +634,80 @@ case class HashShuffleSpec( override def numPartitions: Int = partitioning.numPartitions } +case class KeyGroupedShuffleSpec( + partitioning: KeyGroupedPartitioning, + distribution: ClusteredDistribution) extends ShuffleSpec { + + /** + * A sequence where each element is a set of positions of the partition expression to the cluster + * keys. For instance, if cluster keys are [a, b, b] and partition expressions are + * [bucket(4, a), years(b)], the result will be [(0), (1, 2)]. + * + * Note that we only allow each partition expression to contain a single partition key. + * Therefore the mapping here is very similar to that from `HashShuffleSpec`. + */ + lazy val keyPositions: Seq[mutable.BitSet] = { + val distKeyToPos = mutable.Map.empty[Expression, mutable.BitSet] + distribution.clustering.zipWithIndex.foreach { case (distKey, distKeyPos) => + distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos) + } + partitioning.expressions.map { e => + val leaves = e.collectLeaves() + assert(leaves.size == 1, s"Expected exactly one child from $e, but found ${leaves.size}") + distKeyToPos.getOrElse(leaves.head.canonicalized, mutable.BitSet.empty) + } + } + + private lazy val ordering: Ordering[InternalRow] = + RowOrdering.createNaturalAscendingOrdering(partitioning.expressions.map(_.dataType)) + + override def numPartitions: Int = partitioning.numPartitions + + override def isCompatibleWith(other: ShuffleSpec): Boolean = other match { + // Here we check: + // 1. both distributions have the same number of clustering keys + // 2. both partitioning have the same number of partitions + // 3. partition expressions from both sides are compatible, which means: + // 3.1 both sides have the same number of partition expressions + // 3.2 for each pair of partition expressions at the same index, the corresponding + // partition keys must share overlapping positions in their respective clustering keys. + // 3.3 each pair of partition expressions at the same index must share compatible + // transform functions. + // 4. the partition values, if present on both sides, are following the same order. + case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) => + val expressions = partitioning.expressions + val otherExpressions = otherPartitioning.expressions + + distribution.clustering.length == otherDistribution.clustering.length && + numPartitions == other.numPartitions && + expressions.length == otherExpressions.length && { + val otherKeyPositions = otherSpec.keyPositions + keyPositions.zip(otherKeyPositions).forall { case (left, right) => + left.intersect(right).nonEmpty + } + } && expressions.zip(otherExpressions).forall { + case (l, r) => isExpressionCompatible(l, r) + } && partitioning.partitionValuesOpt.zip(otherPartitioning.partitionValuesOpt).forall { + case (left, right) => left.zip(right).forall { case (l, r) => + ordering.compare(l, r) == 0 + } + } + case ShuffleSpecCollection(specs) => + specs.exists(isCompatibleWith) + case _ => false + } + + private def isExpressionCompatible(left: Expression, right: Expression): Boolean = + (left, right) match { + case (_: LeafExpression, _: LeafExpression) => true + case (left: TransformExpression, right: TransformExpression) => + left.isSameFunction(right) + case _ => false + } + + override def canCreatePartitioning: Boolean = false +} + case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec { override def isCompatibleWith(other: ShuffleSpec): Boolean = { specs.exists(_.isCompatibleWith(other)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index 04af7eda6aaa9..91809b6176c8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -188,4 +188,8 @@ private[sql] object CatalogV2Implicits { def parseColumnPath(name: String): Seq[String] = { CatalystSqlParser.parseMultipartIdentifier(name) } + + def parseFunctionName(name: String): Seq[String] = { + CatalystSqlParser.parseMultipartIdentifier(name) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index d93864991fc3d..6b0760ca1637b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.logical.{ExposesMetadataColumns, LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils} import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, MetadataColumn, SupportsMetadataColumns, Table, TableCapability} @@ -113,11 +113,14 @@ case class DataSourceV2Relation( * @param relation a [[DataSourceV2Relation]] * @param scan a DSv2 [[Scan]] * @param output the output attributes of this relation + * @param keyGroupedPartitioning if set, the partitioning expressions that are used to split the + * rows in the scan across different partitions */ case class DataSourceV2ScanRelation( relation: DataSourceV2Relation, scan: Scan, - output: Seq[AttributeReference]) extends LeafNode with NamedRelation { + output: Seq[AttributeReference], + keyGroupedPartitioning: Option[Seq[Expression]] = None) extends LeafNode with NamedRelation { override def name: String = relation.table.name() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4f394a6d4fe98..910f2db5e3479 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1334,6 +1334,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val V2_BUCKETING_ENABLED = buildConf("spark.sql.sources.v2.bucketing.enabled") + .doc(s"Similar to ${BUCKETING_ENABLED.key}, this config is used to enable bucketing for V2 " + + "data sources. When turned on, Spark will recognize the specific distribution " + + "reported by a V2 data source through SupportsReportPartitioning, and will try to " + + "avoid shuffle if necessary.") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") .doc("The maximum number of buckets allowed.") .version("2.4.0") @@ -4169,6 +4178,8 @@ class SQLConf extends Serializable with Logging { def autoBucketedScanEnabled: Boolean = getConf(SQLConf.AUTO_BUCKETED_SCAN_ENABLED) + def v2BucketingEnabled: Boolean = getConf(SQLConf.V2_BUCKETING_ENABLED) + def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala index 202b03f28f082..be3baf9252006 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala @@ -55,4 +55,8 @@ class InMemoryCatalog extends InMemoryTableCatalog with FunctionCatalog { def createFunction(ident: Identifier, fn: UnboundFunction): UnboundFunction = { functions.put(ident, fn) } + + def clearFunctions(): Unit = { + functions.clear() + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index 5d72b2060bfd8..a762b0f87839f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -29,10 +29,11 @@ import org.scalatest.Assertions._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow} import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils} -import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} +import org.apache.spark.sql.connector.distributions.{ClusteredDistribution, Distribution, Distributions} import org.apache.spark.sql.connector.expressions._ import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric} import org.apache.spark.sql.connector.read._ +import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.connector.write._ import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend @@ -260,7 +261,8 @@ class InMemoryTable( var data: Seq[InputPartition], readSchema: StructType, tableSchema: StructType) - extends Scan with Batch with SupportsRuntimeFiltering with SupportsReportStatistics { + extends Scan with Batch with SupportsRuntimeFiltering with SupportsReportStatistics + with SupportsReportPartitioning { override def toBatch: Batch = this @@ -278,6 +280,13 @@ class InMemoryTable( InMemoryStats(OptionalLong.of(sizeInBytes), OptionalLong.of(numRows)) } + override def outputPartitioning(): Partitioning = { + InMemoryTable.this.distribution match { + case cd: ClusteredDistribution => new KeyGroupedPartitioning(cd.clustering(), data.size) + case _ => new UnknownPartitioning(data.size) + } + } + override def planInputPartitions(): Array[InputPartition] = data.toArray override def createReaderFactory(): PartitionReaderFactory = { @@ -293,9 +302,10 @@ class InMemoryTable( } override def filter(filters: Array[Filter]): Unit = { - if (partitioning.length == 1) { + if (partitioning.length == 1 && partitioning.head.references().length == 1) { + val ref = partitioning.head.references().head filters.foreach { - case In(attrName, values) if attrName == partitioning.head.name => + case In(attrName, values) if attrName == ref.toString => val matchingKeys = values.map(_.toString).toSet data = data.filter(partition => { val key = partition.asInstanceOf[BufferedRows].keyString diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 5e5de1fc0dc84..5c8e884c5a155 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3619,7 +3619,8 @@ class Dataset[T] private[sql]( fr.inputFiles case r: HiveTableRelation => r.tableMeta.storage.locationUri.map(_.toString).toArray - case DataSourceV2ScanRelation(DataSourceV2Relation(table: FileTable, _, _, _, _), _, _) => + case DataSourceV2ScanRelation(DataSourceV2Relation(table: FileTable, _, _, _, _), + _, _, _) => table.fileIndex.inputFiles }.flatten files.toSet.toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index dcb02ab8556ec..bfe4bd2924118 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions import org.apache.spark.sql.execution.datasources.SchemaPruning -import org.apache.spark.sql.execution.datasources.v2.{V2ScanRelationPushDown, V2Writes} +import org.apache.spark.sql.execution.datasources.v2.{V2ScanPartitioning, V2ScanRelationPushDown, V2Writes} import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning} import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs} @@ -37,7 +37,11 @@ class SparkOptimizer( override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] = // TODO: move SchemaPruning into catalyst - SchemaPruning :: V2ScanRelationPushDown :: V2Writes :: PruneFileSourcePartitions :: Nil + Seq(SchemaPruning) :+ + V2ScanRelationPushDown :+ + V2ScanPartitioning :+ + V2Writes :+ + PruneFileSourcePartitions override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ @@ -75,6 +79,7 @@ class SparkOptimizer( ExtractPythonUDFFromAggregate.ruleName :+ ExtractGroupingPythonUDFFromAggregate.ruleName :+ ExtractPythonUDFs.ruleName :+ V2ScanRelationPushDown.ruleName :+ + V2ScanPartitioning.ruleName :+ V2Writes.ruleName /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 937d18d9eb76f..0b813d52ceed1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -24,9 +24,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, SinglePartition} +import org.apache.spark.sql.catalyst.util.InternalRowSet import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeFiltering} +import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeFiltering} import org.apache.spark.sql.execution.datasources.DataSourceStrategy /** @@ -35,7 +36,8 @@ import org.apache.spark.sql.execution.datasources.DataSourceStrategy case class BatchScanExec( output: Seq[AttributeReference], @transient scan: Scan, - runtimeFilters: Seq[Expression]) extends DataSourceV2ScanExecBase { + runtimeFilters: Seq[Expression], + keyGroupedPartitioning: Option[Seq[Expression]] = None) extends DataSourceV2ScanExecBase { @transient lazy val batch = scan.toBatch @@ -49,9 +51,9 @@ case class BatchScanExec( override def hashCode(): Int = Objects.hashCode(batch, runtimeFilters) - @transient override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions() + @transient override lazy val inputPartitions: Seq[InputPartition] = batch.planInputPartitions() - @transient private lazy val filteredPartitions: Seq[InputPartition] = { + @transient private lazy val filteredPartitions: Seq[Seq[InputPartition]] = { val dataSourceFilters = runtimeFilters.flatMap { case DynamicPruningExpression(e) => DataSourceStrategy.translateRuntimeFilter(e) case _ => None @@ -68,16 +70,36 @@ case class BatchScanExec( val newPartitions = scan.toBatch.planInputPartitions() originalPartitioning match { - case p: DataSourcePartitioning if p.numPartitions != newPartitions.size => - throw new SparkException( - "Data source must have preserved the original partitioning during runtime filtering; " + - s"reported num partitions: ${p.numPartitions}, " + - s"num partitions after runtime filtering: ${newPartitions.size}") + case p: KeyGroupedPartitioning => + if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) { + throw new SparkException("Data source must have preserved the original partitioning " + + "during runtime filtering: not all partitions implement HasPartitionKey after " + + "filtering") + } + + val newRows = new InternalRowSet(p.expressions.map(_.dataType)) + newRows ++= newPartitions.map(_.asInstanceOf[HasPartitionKey].partitionKey()) + val oldRows = p.partitionValuesOpt.get + + if (oldRows.size != newRows.size) { + throw new SparkException("Data source must have preserved the original partitioning " + + "during runtime filtering: the number of unique partition values obtained " + + s"through HasPartitionKey changed: before ${oldRows.size}, after ${newRows.size}") + } + + if (!oldRows.forall(newRows.contains)) { + throw new SparkException("Data source must have preserved the original partitioning " + + "during runtime filtering: the number of unique partition values obtained " + + s"through HasPartitionKey remain the same but do not exactly match") + } + + groupPartitions(newPartitions).get.map(_._2) + case _ => // no validation is needed as the data source did not report any specific partitioning + newPartitions.map(Seq(_)) } - newPartitions } else { partitions } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala index b19db8b0e5110..5f973e10b80f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala @@ -31,7 +31,8 @@ case class ContinuousScanExec( output: Seq[Attribute], @transient scan: Scan, @transient stream: ContinuousStream, - @transient start: Offset) extends DataSourceV2ScanExecBase { + @transient start: Offset, + keyGroupedPartitioning: Option[Seq[Expression]] = None) extends DataSourceV2ScanExecBase { // TODO: unify the equal/hashCode implementation for all data source v2 query plans. override def equals(other: Any): Boolean = other match { @@ -41,13 +42,14 @@ case class ContinuousScanExec( override def hashCode(): Int = stream.hashCode() - override lazy val partitions: Seq[InputPartition] = stream.planInputPartitions(start) + override lazy val inputPartitions: Seq[InputPartition] = stream.planInputPartitions(start) override lazy val readerFactory: ContinuousPartitionReaderFactory = { stream.createContinuousReaderFactory() } override lazy val inputRDD: RDD[InternalRow] = { + assert(partitions.forall(_.length == 1), "should only contain a single partition") EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) @@ -56,7 +58,7 @@ case class ContinuousScanExec( sparkContext, conf.continuousStreamingExecutorQueueSize, conf.continuousStreamingExecutorPollIntervalMs, - partitions, + partitions.map(_.head), schema, readerFactory, customMetrics) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala deleted file mode 100644 index 9211ec25525fa..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression} -import org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.connector.read.partitioning.{ClusteredDistribution, Partitioning} - -/** - * An adapter from public data source partitioning to catalyst internal `Partitioning`. - */ -class DataSourcePartitioning( - partitioning: Partitioning, - colNames: AttributeMap[String]) extends physical.Partitioning { - - override val numPartitions: Int = partitioning.numPartitions() - - override def satisfies0(required: physical.Distribution): Boolean = { - super.satisfies0(required) || { - required match { - case d: physical.ClusteredDistribution if isCandidate(d.clustering) => - val attrs = d.clustering.map(_.asInstanceOf[Attribute]) - partitioning.satisfy( - new ClusteredDistribution(attrs.map { a => - val name = colNames.get(a) - assert(name.isDefined, s"Attribute ${a.name} is not found in the data source output") - name.get - }.toArray)) - - case _ => false - } - } - } - - private def isCandidate(clustering: Seq[Expression]): Boolean = { - clustering.forall { - case a: Attribute => colNames.contains(a) - case _ => false - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index a1eb857c4ed41..09c8756ca0189 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -29,14 +29,14 @@ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric} import org.apache.spark.sql.vectorized.ColumnarBatch -class DataSourceRDDPartition(val index: Int, val inputPartition: InputPartition) +class DataSourceRDDPartition(val index: Int, val inputPartitions: Seq[InputPartition]) extends Partition with Serializable // TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for // columnar scan. class DataSourceRDD( sc: SparkContext, - @transient private val inputPartitions: Seq[InputPartition], + @transient private val inputPartitions: Seq[Seq[InputPartition]], partitionReaderFactory: PartitionReaderFactory, columnarReads: Boolean, customMetrics: Map[String, SQLMetric]) @@ -44,7 +44,7 @@ class DataSourceRDD( override protected def getPartitions: Array[Partition] = { inputPartitions.zipWithIndex.map { - case (inputPartition, index) => new DataSourceRDDPartition(index, inputPartition) + case (inputPartitions, index) => new DataSourceRDDPartition(index, inputPartitions) }.toArray } @@ -54,31 +54,56 @@ class DataSourceRDD( } override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { - val inputPartition = castPartition(split).inputPartition - val (iter, reader) = if (columnarReads) { - val batchReader = partitionReaderFactory.createColumnarReader(inputPartition) - val iter = new MetricsBatchIterator( - new PartitionIterator[ColumnarBatch](batchReader, customMetrics)) - (iter, batchReader) - } else { - val rowReader = partitionReaderFactory.createReader(inputPartition) - val iter = new MetricsRowIterator( - new PartitionIterator[InternalRow](rowReader, customMetrics)) - (iter, rowReader) - } - context.addTaskCompletionListener[Unit] { _ => - // In case of early stopping before consuming the entire iterator, - // we need to do one more metric update at the end of the task. - CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics) - iter.forceUpdateMetrics() - reader.close() + + val iterator = new Iterator[Object] { + private val inputPartitions = castPartition(split).inputPartitions + private var currentIter: Option[Iterator[Object]] = None + private var currentIndex: Int = 0 + + override def hasNext: Boolean = currentIter.exists(_.hasNext) || advanceToNextIter() + + override def next(): Object = { + if (!hasNext) throw new NoSuchElementException("No more elements") + currentIter.get.next() + } + + private def advanceToNextIter(): Boolean = { + if (currentIndex >= inputPartitions.length) { + false + } else { + val inputPartition = inputPartitions(currentIndex) + currentIndex += 1 + + // TODO: SPARK-25083 remove the type erasure hack in data source scan + val (iter, reader) = if (columnarReads) { + val batchReader = partitionReaderFactory.createColumnarReader(inputPartition) + val iter = new MetricsBatchIterator( + new PartitionIterator[ColumnarBatch](batchReader, customMetrics)) + (iter, batchReader) + } else { + val rowReader = partitionReaderFactory.createReader(inputPartition) + val iter = new MetricsRowIterator( + new PartitionIterator[InternalRow](rowReader, customMetrics)) + (iter, rowReader) + } + context.addTaskCompletionListener[Unit] { _ => + // In case of early stopping before consuming the entire iterator, + // we need to do one more metric update at the end of the task. + CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics) + iter.forceUpdateMetrics() + reader.close() + } + currentIter = Some(iter) + hasNext + } + } } - // TODO: SPARK-25083 remove the type erasure hack in data source scan - new InterruptibleIterator(context, iter.asInstanceOf[Iterator[InternalRow]]) + + new InterruptibleIterator(context, iterator).asInstanceOf[Iterator[InternalRow]] } override def getPreferredLocations(split: Partition): Seq[String] = { - castPartition(split).inputPartition.preferredLocations() + castPartition(split).inputPartitions.flatMap(_.preferredLocations()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala index 92f454c1bcd1e..42909986fce05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.AttributeMap +import org.apache.spark.sql.catalyst.expressions.{Expression, RowOrdering} import org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan, SupportsReportPartitioning} +import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan} import org.apache.spark.sql.execution.{ExplainUtils, LeafExecNode} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.connector.SupportsMetadata import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.Utils @@ -43,16 +44,23 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { def scan: Scan - def partitions: Seq[InputPartition] - def readerFactory: PartitionReaderFactory + /** Optional partitioning expressions provided by the V2 data sources, through + * `SupportsReportPartitioning` */ + def keyGroupedPartitioning: Option[Seq[Expression]] + + protected def inputPartitions: Seq[InputPartition] + override def simpleString(maxFields: Int): String = { val result = s"$nodeName${truncatedString(output, "[", ", ", "]", maxFields)} ${scan.description()}" redact(result) } + def partitions: Seq[Seq[InputPartition]] = + groupedPartitions.map(_.map(_._2)).getOrElse(inputPartitions.map(Seq(_))) + /** * Shorthand for calling redact() without specifying redacting rules */ @@ -78,23 +86,64 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { |""".stripMargin } - override def outputPartitioning: physical.Partitioning = scan match { - case _ if partitions.length == 1 => - SinglePartition + override def outputPartitioning: physical.Partitioning = { + if (partitions.length == 1) SinglePartition + else groupedPartitions.map { partitionValues => + KeyGroupedPartitioning(keyGroupedPartitioning.get, + partitionValues.size, Some(partitionValues.map(_._1))) + }.getOrElse(super.outputPartitioning) + } - case s: SupportsReportPartitioning => - new DataSourcePartitioning( - s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) + @transient lazy val groupedPartitions: Option[Seq[(InternalRow, Seq[InputPartition])]] = + groupPartitions(inputPartitions) - case _ => super.outputPartitioning + /** + * Group partition values for all the input partitions. This returns `Some` iff: + * - [[SQLConf.V2_BUCKETING_ENABLED]] is turned on + * - all input partitions implement [[HasPartitionKey]] + * - `keyGroupedPartitioning` is set + * + * The result, if defined, is a list of tuples where the first element is a partition value, + * and the second element is a list of input partitions that share the same partition value. + * + * A non-empty result means each partition is clustered on a single key and therefore eligible + * for further optimizations to eliminate shuffling in some operations such as join and aggregate. + */ + def groupPartitions( + inputPartitions: Seq[InputPartition]): Option[Seq[(InternalRow, Seq[InputPartition])]] = { + if (!SQLConf.get.v2BucketingEnabled) return None + keyGroupedPartitioning.flatMap { expressions => + val results = inputPartitions.takeWhile { + case _: HasPartitionKey => true + case _ => false + }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p)) + + if (results.length != inputPartitions.length || inputPartitions.isEmpty) { + // Not all of the `InputPartitions` implements `HasPartitionKey`, therefore skip here. + None + } else { + val partKeyType = expressions.map(_.dataType) + + val groupedPartitions = results.groupBy(_._1).toSeq.map { case (key, s) => + (key, s.map(_._2)) + } + + // also sort the input partitions according to their partition key order. This ensures + // a canonical order from both sides of a bucketed join, for example. + val keyOrdering: Ordering[(InternalRow, Seq[InputPartition])] = { + RowOrdering.createNaturalAscendingOrdering(partKeyType).on(_._1) + } + Some(groupedPartitions.sorted(keyOrdering)) + } + } } override def supportsColumnar: Boolean = { - require(partitions.forall(readerFactory.supportColumnarReads) || - !partitions.exists(readerFactory.supportColumnarReads), + require(inputPartitions.forall(readerFactory.supportColumnarReads) || + !inputPartitions.exists(readerFactory.supportColumnarReads), "Cannot mix row-based and columnar input partitions.") - partitions.exists(readerFactory.supportColumnarReads) + inputPartitions.exists(readerFactory.supportColumnarReads) } def inputRDD: RDD[InternalRow] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index d1790796175a1..65f7bbba4e611 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -105,7 +105,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, DataSourceV2ScanRelation( - _, V1ScanWrapper(scan, pushed, pushedDownOperators), output)) => + _, V1ScanWrapper(scan, pushed, pushedDownOperators), output, _)) => val v1Relation = scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext) if (v1Relation.schema != scan.readSchema()) { throw QueryExecutionErrors.fallbackV1RelationReportsInconsistentSchemaError( @@ -126,7 +126,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat withProjectAndFilter(project, filters, dsScan, needsUnsafeConversion = false) :: Nil case PhysicalOperation(project, filters, - DataSourceV2ScanRelation(_, scan: LocalScan, output)) => + DataSourceV2ScanRelation(_, scan: LocalScan, output, _)) => val localScanExec = LocalTableScanExec(output, scan.rows().toSeq) withProjectAndFilter(project, filters, localScanExec, needsUnsafeConversion = false) :: Nil @@ -138,7 +138,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case _: DynamicPruning => true case _ => false } - val batchExec = BatchScanExec(relation.output, relation.scan, runtimeFilters) + val batchExec = BatchScanExec(relation.output, relation.scan, runtimeFilters, + relation.keyGroupedPartitioning) withProjectAndFilter(project, postScanFilters, batchExec, !batchExec.supportsColumnar) :: Nil case PhysicalOperation(p, f, r: StreamingDataSourceV2Relation) @@ -262,7 +263,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case DeleteFromTable(relation, condition) => relation match { - case DataSourceV2ScanRelation(r, _, output) => + case DataSourceV2ScanRelation(r, _, output, _) => val table = r.table if (SubqueryExpression.hasSubquery(condition)) { throw QueryCompilationErrors.unsupportedDeleteByConditionWithSubqueryError(condition) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala index 0d9146d31c883..275255c9a3d39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala @@ -17,23 +17,27 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} -import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils.toCatalyst +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RepartitionByExpression, Sort} -import org.apache.spark.sql.connector.distributions.{ClusteredDistribution, OrderedDistribution, UnspecifiedDistribution} +import org.apache.spark.sql.connector.distributions._ import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering, Write} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.collection.Utils.sequenceToOption object DistributionAndOrderingUtils { def prepareQuery(write: Write, query: LogicalPlan, conf: SQLConf): LogicalPlan = write match { case write: RequiresDistributionAndOrdering => val numPartitions = write.requiredNumPartitions() + val distribution = write.requiredDistribution match { - case d: OrderedDistribution => d.ordering.map(e => toCatalyst(e, query)) - case d: ClusteredDistribution => d.clustering.map(e => toCatalyst(e, query)) - case _: UnspecifiedDistribution => Array.empty[Expression] + case d: OrderedDistribution => toCatalystOrdering(d.ordering(), query) + case d: ClusteredDistribution => + sequenceToOption(d.clustering.map(e => toCatalyst(e, query))) + .getOrElse(Seq.empty[Expression]) + case _: UnspecifiedDistribution => Seq.empty[Expression] } val queryWithDistribution = if (distribution.nonEmpty) { @@ -52,10 +56,7 @@ object DistributionAndOrderingUtils { query } - val ordering = write.requiredOrdering.toSeq - .map(e => toCatalyst(e, query)) - .asInstanceOf[Seq[SortOrder]] - + val ordering = toCatalystOrdering(write.requiredOrdering, query) val queryWithDistributionAndOrdering = if (ordering.nonEmpty) { Sort(ordering, global = false, queryWithDistribution) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala index 1430a32c8e81a..3db7fb7851249 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan} import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset} @@ -31,7 +31,8 @@ case class MicroBatchScanExec( @transient scan: Scan, @transient stream: MicroBatchStream, @transient start: Offset, - @transient end: Offset) extends DataSourceV2ScanExecBase { + @transient end: Offset, + keyGroupedPartitioning: Option[Seq[Expression]] = None) extends DataSourceV2ScanExecBase { // TODO: unify the equal/hashCode implementation for all data source v2 query plans. override def equals(other: Any): Boolean = other match { @@ -41,7 +42,7 @@ case class MicroBatchScanExec( override def hashCode(): Int = stream.hashCode() - override lazy val partitions: Seq[InputPartition] = stream.planInputPartitions(start, end) + override lazy val inputPartitions: Seq[InputPartition] = stream.planInputPartitions(start, end) override lazy val readerFactory: PartitionReaderFactory = stream.createReaderFactory() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioning.scala new file mode 100644 index 0000000000000..8d2b3a8880cd3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioning.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.FunctionCatalog +import org.apache.spark.sql.connector.read.SupportsReportPartitioning +import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, UnknownPartitioning} +import org.apache.spark.util.collection.Utils.sequenceToOption + +/** + * Extracts [[DataSourceV2ScanRelation]] from the input logical plan, converts any V2 partitioning + * reported by data sources to their catalyst counterparts. Then, annotates the plan with the + * result. + */ +object V2ScanPartitioning extends Rule[LogicalPlan] with SQLConfHelper { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case d @ DataSourceV2ScanRelation(relation, scan: SupportsReportPartitioning, _, _) => + val funCatalogOpt = relation.catalog.flatMap { + case c: FunctionCatalog => Some(c) + case _ => None + } + + val catalystPartitioning = scan.outputPartitioning() match { + case kgp: KeyGroupedPartitioning => sequenceToOption(kgp.keys().map( + V2ExpressionUtils.toCatalyst(_, relation, funCatalogOpt))) + case _: UnknownPartitioning => None + case p => throw new IllegalArgumentException("Unsupported data source V2 partitioning " + + "type: " + p.getClass.getSimpleName) + } + + d.copy(keyGroupedPartitioning = catalystPartitioning) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala index 89d66034f06cd..114d58c739e29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala @@ -78,7 +78,7 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper { } else { None } - case (resExp, r @ DataSourceV2ScanRelation(_, scan: SupportsRuntimeFiltering, _)) => + case (resExp, r @ DataSourceV2ScanRelation(_, scan: SupportsRuntimeFiltering, _, _)) => val filterAttrs = V2ExpressionUtils.resolveRefs[Attribute](scan.filterAttributes, r) if (resExp.references.subsetOf(AttributeSet(filterAttrs))) { Some(r) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index de1806ab87b4c..67a58da89625e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.internal.SQLConf /** * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] @@ -137,8 +138,16 @@ case class EnsureRequirements( Some(finalCandidateSpecs.values.maxBy(_.numPartitions)) } + // Check if 1) all children are of `KeyGroupedPartitioning` and 2) they are all compatible + // with each other. If both are true, skip shuffle. + val allCompatible = childrenIndexes.sliding(2).forall { + case Seq(a, b) => + checkKeyGroupedSpec(specs(a)) && checkKeyGroupedSpec(specs(b)) && + specs(a).isCompatibleWith(specs(b)) + } + children = children.zip(requiredChildDistributions).zipWithIndex.map { - case ((child, _), idx) if !childrenIndexes.contains(idx) => + case ((child, _), idx) if allCompatible || !childrenIndexes.contains(idx) => child case ((child, dist), idx) => if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) { @@ -177,6 +186,26 @@ case class EnsureRequirements( children } + private def checkKeyGroupedSpec(shuffleSpec: ShuffleSpec): Boolean = { + def check(spec: KeyGroupedShuffleSpec): Boolean = { + val attributes = spec.partitioning.expressions.flatMap(_.collectLeaves()) + val clustering = spec.distribution.clustering + + if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) { + attributes.length == clustering.length && attributes.zip(clustering).forall { + case (l, r) => l.semanticEquals(r) + } + } else { + true // already validated in `KeyGroupedPartitioning.satisfies` + } + } + shuffleSpec match { + case spec: KeyGroupedShuffleSpec => check(spec) + case ShuffleSpecCollection(specs) => specs.exists(checkKeyGroupedSpec) + case _ => false + } + } + private def reorder( leftKeys: IndexedSeq[Expression], rightKeys: IndexedSeq[Expression], @@ -256,6 +285,16 @@ case class EnsureRequirements( reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, leftPartitioning, None)) + case (Some(KeyGroupedPartitioning(clustering, _, _)), _) => + val leafExprs = clustering.flatMap(_.collectLeaves()) + reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys) + .orElse(reorderJoinKeysRecursively( + leftKeys, rightKeys, None, rightPartitioning)) + case (_, Some(KeyGroupedPartitioning(clustering, _, _))) => + val leafExprs = clustering.flatMap(_.collectLeaves()) + reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys) + .orElse(reorderJoinKeysRecursively( + leftKeys, rightKeys, leftPartitioning, None)) case (Some(PartitioningCollection(partitionings)), _) => partitionings.foldLeft(Option.empty[(Seq[Expression], Seq[Expression])]) { (res, p) => res.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, Some(p), rightPartitioning)) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java index e5c50beeaf611..08be0ce9543a7 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java @@ -18,18 +18,17 @@ package test.org.apache.spark.sql.connector; import java.io.IOException; -import java.util.Arrays; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.connector.TestingV2Source; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.expressions.Expression; import org.apache.spark.sql.connector.expressions.Expressions; import org.apache.spark.sql.connector.expressions.Transform; -import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.read.*; -import org.apache.spark.sql.connector.read.partitioning.ClusteredDistribution; -import org.apache.spark.sql.connector.read.partitioning.Distribution; import org.apache.spark.sql.connector.read.partitioning.Partitioning; +import org.apache.spark.sql.connector.read.partitioning.KeyGroupedPartitioning; import org.apache.spark.sql.util.CaseInsensitiveStringMap; public class JavaPartitionAwareDataSource implements TestingV2Source { @@ -51,7 +50,8 @@ public PartitionReaderFactory createReaderFactory() { @Override public Partitioning outputPartitioning() { - return new MyPartitioning(); + Expression[] clustering = new Transform[] { Expressions.identity("i") }; + return new KeyGroupedPartitioning(clustering, 2); } } @@ -70,25 +70,7 @@ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { }; } - static class MyPartitioning implements Partitioning { - - @Override - public int numPartitions() { - return 2; - } - - @Override - public boolean satisfy(Distribution distribution) { - if (distribution instanceof ClusteredDistribution) { - String[] clusteredCols = ((ClusteredDistribution) distribution).clusteredColumns; - return Arrays.asList(clusteredCols).contains("i"); - } - - return false; - } - } - - static class SpecificInputPartition implements InputPartition { + static class SpecificInputPartition implements InputPartition, HasPartitionKey { int[] i; int[] j; @@ -97,6 +79,11 @@ static class SpecificInputPartition implements InputPartition { this.i = i; this.j = j; } + + @Override + public InternalRow partitionKey() { + return new GenericInternalRow(new Object[] {i[0]}); + } } static class SpecificReaderFactory implements PartitionReaderFactory { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index da7f81901c5bb..9e7eb1d0ad501 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -842,7 +842,7 @@ class FileBasedDataSourceSuite extends QueryTest }) val fileScan = df.queryExecution.executedPlan collectFirst { - case BatchScanExec(_, f: FileScan, _) => f + case BatchScanExec(_, f: FileScan, _, _) => f } assert(fileScan.nonEmpty) assert(fileScan.get.partitionFilters.nonEmpty) @@ -882,7 +882,7 @@ class FileBasedDataSourceSuite extends QueryTest assert(filterCondition.isDefined) val fileScan = df.queryExecution.executedPlan collectFirst { - case BatchScanExec(_, f: FileScan, _) => f + case BatchScanExec(_, f: FileScan, _, _) => f } assert(fileScan.nonEmpty) assert(fileScan.get.partitionFilters.isEmpty) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 8174888ad8321..44d4f1fa825d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -27,15 +27,16 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.expressions.{Literal, Transform} +import org.apache.spark.sql.connector.expressions.{FieldReference, Literal, Transform} import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read._ -import org.apache.spark.sql.connector.read.partitioning.{ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation} import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StructType} @@ -245,34 +246,36 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS test("partitioning reporting") { import org.apache.spark.sql.functions.{count, sum} - Seq(classOf[PartitionAwareDataSource], classOf[JavaPartitionAwareDataSource]).foreach { cls => - withClue(cls.getName) { - val df = spark.read.format(cls.getName).load() - checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2))) - - val groupByColA = df.groupBy($"i").agg(sum($"j")) - checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4))) - assert(collectFirst(groupByColA.queryExecution.executedPlan) { - case e: ShuffleExchangeExec => e - }.isEmpty) - - val groupByColAB = df.groupBy($"i", $"j").agg(count("*")) - checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2))) - assert(collectFirst(groupByColAB.queryExecution.executedPlan) { - case e: ShuffleExchangeExec => e - }.isEmpty) - - val groupByColB = df.groupBy($"j").agg(sum($"i")) - checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) - assert(collectFirst(groupByColB.queryExecution.executedPlan) { - case e: ShuffleExchangeExec => e - }.isDefined) - - val groupByAPlusB = df.groupBy($"i" + $"j").agg(count("*")) - checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) - assert(collectFirst(groupByAPlusB.queryExecution.executedPlan) { - case e: ShuffleExchangeExec => e - }.isDefined) + withSQLConf(SQLConf.V2_BUCKETING_ENABLED.key -> "true") { + Seq(classOf[PartitionAwareDataSource], classOf[JavaPartitionAwareDataSource]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2))) + + val groupByColA = df.groupBy($"i").agg(sum($"j")) + checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4))) + assert(collectFirst(groupByColA.queryExecution.executedPlan) { + case e: ShuffleExchangeExec => e + }.isEmpty) + + val groupByColAB = df.groupBy($"i", $"j").agg(count("*")) + checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2))) + assert(collectFirst(groupByColAB.queryExecution.executedPlan) { + case e: ShuffleExchangeExec => e + }.isEmpty) + + val groupByColB = df.groupBy($"j").agg(sum($"i")) + checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) + assert(collectFirst(groupByColB.queryExecution.executedPlan) { + case e: ShuffleExchangeExec => e + }.isDefined) + + val groupByAPlusB = df.groupBy($"i" + $"j").agg(count("*")) + checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) + assert(collectFirst(groupByAPlusB.queryExecution.executedPlan) { + case e: ShuffleExchangeExec => e + }.isDefined) + } } } } @@ -896,7 +899,8 @@ class PartitionAwareDataSource extends TestingV2Source { SpecificReaderFactory } - override def outputPartitioning(): Partitioning = new MyPartitioning + override def outputPartitioning(): Partitioning = + new KeyGroupedPartitioning(Array(FieldReference("i")), 2) } override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { @@ -904,18 +908,13 @@ class PartitionAwareDataSource extends TestingV2Source { new MyScanBuilder() } } - - class MyPartitioning extends Partitioning { - override def numPartitions(): Int = 2 - - override def satisfy(distribution: Distribution): Boolean = distribution match { - case c: ClusteredDistribution => c.clusteredColumns.contains("i") - case _ => false - } - } } -case class SpecificInputPartition(i: Array[Int], j: Array[Int]) extends InputPartition +case class SpecificInputPartition( + i: Array[Int], + j: Array[Int]) extends InputPartition with HasPartitionKey { + override def partitionKey(): InternalRow = InternalRow.fromSeq(Seq(i(0))) +} object SpecificReaderFactory extends PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala new file mode 100644 index 0000000000000..f4317e632761c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connector + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.connector.catalog.InMemoryCatalog +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.test.SharedSparkSession + +abstract class DistributionAndOrderingSuiteBase + extends QueryTest with SharedSparkSession with BeforeAndAfter with AdaptiveSparkPlanHelper { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName) + } + + override def afterAll(): Unit = { + spark.sessionState.conf.unsetConf("spark.sql.catalog.testcat") + super.afterAll() + } + + protected val resolver: Resolver = conf.resolver + + protected def resolvePartitioning[T <: QueryPlan[T]]( + partitioning: Partitioning, + plan: QueryPlan[T]): Partitioning = partitioning match { + case HashPartitioning(exprs, numPartitions) => + HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions) + case KeyGroupedPartitioning(clustering, numPartitions, partitionValues) => + KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, + partitionValues) + case PartitioningCollection(partitionings) => + PartitioningCollection(partitionings.map(resolvePartitioning(_, plan))) + case RangePartitioning(ordering, numPartitions) => + RangePartitioning(ordering.map(resolveAttrs(_, plan).asInstanceOf[SortOrder]), numPartitions) + case p @ SinglePartition => + p + case p: UnknownPartitioning => + p + case p => + fail(s"unexpected partitioning: $p") + } + + protected def resolveDistribution[T <: QueryPlan[T]]( + distribution: physical.Distribution, + plan: QueryPlan[T]): physical.Distribution = distribution match { + case physical.ClusteredDistribution(clustering, numPartitions, _) => + physical.ClusteredDistribution(clustering.map(resolveAttrs(_, plan)), numPartitions) + case physical.OrderedDistribution(ordering) => + physical.OrderedDistribution(ordering.map(resolveAttrs(_, plan).asInstanceOf[SortOrder])) + case physical.UnspecifiedDistribution => + physical.UnspecifiedDistribution + case d => + fail(s"unexpected distribution: $d") + } + + protected def resolveAttrs[T <: QueryPlan[T]]( + expr: catalyst.expressions.Expression, + plan: QueryPlan[T]): catalyst.expressions.Expression = { + + expr.transform { + case UnresolvedAttribute(Seq(attrName)) => + plan.output.find(attr => resolver(attr.name, attrName)).get + case UnresolvedAttribute(nameParts) => + val attrName = nameParts.mkString(".") + fail(s"cannot resolve a nested attr: $attrName") + } + } + + protected def attr(name: String): UnresolvedAttribute = { + UnresolvedAttribute(name) + } + + protected def catalog: InMemoryCatalog = { + val catalog = spark.sessionState.catalogManager.catalog("testcat") + catalog.asTableCatalog.asInstanceOf[InMemoryCatalog] + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala new file mode 100644 index 0000000000000..834faedd1ceef --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -0,0 +1,475 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connector + +import java.util.Collections + +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder => catalystSortOrder, TransformExpression} +import org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog +import org.apache.spark.sql.connector.catalog.functions._ +import org.apache.spark.sql.connector.distributions.Distribution +import org.apache.spark.sql.connector.distributions.Distributions +import org.apache.spark.sql.connector.expressions._ +import org.apache.spark.sql.connector.expressions.Expressions._ +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.joins.SortMergeJoinExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf._ +import org.apache.spark.sql.types._ + +class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { + private var originalV2BucketingEnabled: Boolean = false + private var originalAutoBroadcastJoinThreshold: Long = -1 + + override def beforeAll(): Unit = { + super.beforeAll() + originalV2BucketingEnabled = conf.getConf(V2_BUCKETING_ENABLED) + conf.setConf(V2_BUCKETING_ENABLED, true) + originalAutoBroadcastJoinThreshold = conf.getConf(AUTO_BROADCASTJOIN_THRESHOLD) + conf.setConf(AUTO_BROADCASTJOIN_THRESHOLD, -1L) + } + + override def afterAll(): Unit = { + try { + super.afterAll() + } finally { + conf.setConf(V2_BUCKETING_ENABLED, originalV2BucketingEnabled) + conf.setConf(AUTO_BROADCASTJOIN_THRESHOLD, originalAutoBroadcastJoinThreshold) + } + } + + before { + Seq(UnboundYearsFunction, UnboundDaysFunction, UnboundBucketFunction).foreach { f => + catalog.createFunction(Identifier.of(Array.empty, f.name()), f) + } + } + + after { + catalog.clearTables() + catalog.clearFunctions() + } + + private val emptyProps: java.util.Map[String, String] = { + Collections.emptyMap[String, String] + } + private val table: String = "tbl" + private val schema = new StructType() + .add("id", IntegerType) + .add("data", StringType) + .add("ts", TimestampType) + + test("clustered distribution: output partitioning should be KeyGroupedPartitioning") { + val partitions: Array[Transform] = Array(Expressions.years("ts")) + + // create a table with 3 partitions, partitioned by `years` transform + createTable(table, schema, partitions, + Distributions.clustered(partitions.map(_.asInstanceOf[Expression]))) + sql(s"INSERT INTO testcat.ns.$table VALUES " + + s"(0, 'aaa', CAST('2022-01-01' AS timestamp)), " + + s"(1, 'bbb', CAST('2021-01-01' AS timestamp)), " + + s"(2, 'ccc', CAST('2020-01-01' AS timestamp))") + + var df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY ts") + val catalystDistribution = physical.ClusteredDistribution( + Seq(TransformExpression(YearsFunction, Seq(attr("ts"))))) + val partitionValues = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v))) + + checkQueryPlan(df, catalystDistribution, + physical.KeyGroupedPartitioning(catalystDistribution.clustering, partitionValues)) + + // multiple group keys should work too as long as partition keys are subset of them + df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY id, ts") + checkQueryPlan(df, catalystDistribution, + physical.KeyGroupedPartitioning(catalystDistribution.clustering, partitionValues)) + } + + test("non-clustered distribution: fallback to super.partitioning") { + val partitions: Array[Transform] = Array(years("ts")) + val ordering: Array[SortOrder] = Array(sort(FieldReference("ts"), + SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) + + createTable(table, schema, partitions, Distributions.ordered(ordering), ordering) + sql(s"INSERT INTO testcat.ns.$table VALUES " + + s"(0, 'aaa', CAST('2022-01-01' AS timestamp)), " + + s"(1, 'bbb', CAST('2021-01-01' AS timestamp)), " + + s"(2, 'ccc', CAST('2020-01-01' AS timestamp))") + + val df = sql(s"SELECT * FROM testcat.ns.$table") + val catalystOrdering = Seq(catalystSortOrder(attr("ts"), Ascending)) + val catalystDistribution = physical.OrderedDistribution(catalystOrdering) + + checkQueryPlan(df, catalystDistribution, physical.UnknownPartitioning(0)) + } + + test("non-clustered distribution: no partition") { + val partitions: Array[Transform] = Array(bucket(32, "ts")) + createTable(table, schema, partitions, + Distributions.clustered(partitions.map(_.asInstanceOf[Expression]))) + + val df = sql(s"SELECT * FROM testcat.ns.$table") + val distribution = physical.ClusteredDistribution( + Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32)))) + + checkQueryPlan(df, distribution, physical.UnknownPartitioning(0)) + } + + test("non-clustered distribution: single partition") { + val partitions: Array[Transform] = Array(bucket(32, "ts")) + createTable(table, schema, partitions, + Distributions.clustered(partitions.map(_.asInstanceOf[Expression]))) + sql(s"INSERT INTO testcat.ns.$table VALUES (0, 'aaa', CAST('2020-01-01' AS timestamp))") + + val df = sql(s"SELECT * FROM testcat.ns.$table") + val distribution = physical.ClusteredDistribution( + Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32)))) + + checkQueryPlan(df, distribution, physical.SinglePartition) + } + + test("non-clustered distribution: no V2 catalog") { + spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryTableCatalog].getName) + val nonFunctionCatalog = spark.sessionState.catalogManager.catalog("testcat2") + .asInstanceOf[InMemoryTableCatalog] + val partitions: Array[Transform] = Array(bucket(32, "ts")) + createTable(table, schema, partitions, + Distributions.clustered(partitions.map(_.asInstanceOf[Expression])), + catalog = nonFunctionCatalog) + sql(s"INSERT INTO testcat2.ns.$table VALUES " + + s"(0, 'aaa', CAST('2022-01-01' AS timestamp)), " + + s"(1, 'bbb', CAST('2021-01-01' AS timestamp)), " + + s"(2, 'ccc', CAST('2020-01-01' AS timestamp))") + + val df = sql(s"SELECT * FROM testcat2.ns.$table") + val distribution = physical.UnspecifiedDistribution + + try { + checkQueryPlan(df, distribution, physical.UnknownPartitioning(0)) + } finally { + spark.conf.unset("spark.sql.catalog.testcat2") + } + } + + test("non-clustered distribution: no V2 function provided") { + catalog.clearFunctions() + + val partitions: Array[Transform] = Array(bucket(32, "ts")) + createTable(table, schema, partitions, + Distributions.clustered(partitions.map(_.asInstanceOf[Expression]))) + sql(s"INSERT INTO testcat.ns.$table VALUES " + + s"(0, 'aaa', CAST('2022-01-01' AS timestamp)), " + + s"(1, 'bbb', CAST('2021-01-01' AS timestamp)), " + + s"(2, 'ccc', CAST('2020-01-01' AS timestamp))") + + val df = sql(s"SELECT * FROM testcat.ns.$table") + val distribution = physical.UnspecifiedDistribution + + checkQueryPlan(df, distribution, physical.UnknownPartitioning(0)) + } + + test("non-clustered distribution: V2 bucketing disabled") { + withSQLConf(SQLConf.V2_BUCKETING_ENABLED.key -> "false") { + val partitions: Array[Transform] = Array(bucket(32, "ts")) + createTable(table, schema, partitions, + Distributions.clustered(partitions.map(_.asInstanceOf[Expression]))) + sql(s"INSERT INTO testcat.ns.$table VALUES " + + s"(0, 'aaa', CAST('2022-01-01' AS timestamp)), " + + s"(1, 'bbb', CAST('2021-01-01' AS timestamp)), " + + s"(2, 'ccc', CAST('2020-01-01' AS timestamp))") + + val df = sql(s"SELECT * FROM testcat.ns.$table") + val distribution = physical.ClusteredDistribution( + Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32)))) + + checkQueryPlan(df, distribution, physical.UnknownPartitioning(0)) + } + } + + /** + * Check whether the query plan from `df` has the expected `distribution`, `ordering` and + * `partitioning`. + */ + private def checkQueryPlan( + df: DataFrame, + distribution: physical.Distribution, + partitioning: physical.Partitioning): Unit = { + // check distribution & ordering are correctly populated in logical plan + val relation = df.queryExecution.optimizedPlan.collect { + case r: DataSourceV2ScanRelation => r + }.head + + resolveDistribution(distribution, relation) match { + case physical.ClusteredDistribution(clustering, _, _) => + assert(relation.keyGroupedPartitioning.isDefined && + relation.keyGroupedPartitioning.get == clustering) + case _ => + assert(relation.keyGroupedPartitioning.isEmpty) + } + + // check distribution, ordering and output partitioning are correctly populated in physical plan + val scan = collect(df.queryExecution.executedPlan) { + case s: BatchScanExec => s + }.head + + val expectedPartitioning = resolvePartitioning(partitioning, scan) + assert(expectedPartitioning == scan.outputPartitioning) + } + + private def createTable( + table: String, + schema: StructType, + partitions: Array[Transform], + distribution: Distribution = Distributions.unspecified(), + ordering: Array[expressions.SortOrder] = Array.empty, + catalog: InMemoryTableCatalog = catalog): Unit = { + catalog.createTable(Identifier.of(Array("ns"), table), + schema, partitions, emptyProps, distribution, ordering, None) + } + + private val customers: String = "customers" + private val customers_schema = new StructType() + .add("customer_name", StringType) + .add("customer_age", IntegerType) + .add("customer_id", LongType) + + private val orders: String = "orders" + private val orders_schema = new StructType() + .add("order_amount", DoubleType) + .add("customer_id", LongType) + + private def testWithCustomersAndOrders( + customers_partitions: Array[Transform], + customers_distribution: Distribution, + orders_partitions: Array[Transform], + orders_distribution: Distribution, + expectedNumOfShuffleExecs: Int): Unit = { + createTable(customers, customers_schema, customers_partitions, customers_distribution) + sql(s"INSERT INTO testcat.ns.$customers VALUES " + + s"('aaa', 10, 1), ('bbb', 20, 2), ('ccc', 30, 3)") + + createTable(orders, orders_schema, orders_partitions, orders_distribution) + sql(s"INSERT INTO testcat.ns.$orders VALUES " + + s"(100.0, 1), (200.0, 1), (150.0, 2), (250.0, 2), (350.0, 2), (400.50, 3)") + + val df = sql("SELECT customer_name, customer_age, order_amount " + + s"FROM testcat.ns.$customers c JOIN testcat.ns.$orders o " + + "ON c.customer_id = o.customer_id ORDER BY c.customer_id, order_amount") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.length == expectedNumOfShuffleExecs) + + checkAnswer(df, + Seq(Row("aaa", 10, 100.0), Row("aaa", 10, 200.0), Row("bbb", 20, 150.0), + Row("bbb", 20, 250.0), Row("bbb", 20, 350.0), Row("ccc", 30, 400.50))) + } + + private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] = { + // here we skip collecting shuffle operators that are not associated with SMJ + collect(plan) { + case s: SortMergeJoinExec => s + }.flatMap(smj => + collect(smj) { + case s: ShuffleExchangeExec => s + }) + } + + test("partitioned join: exact distribution (same number of buckets) from both sides") { + val customers_partitions = Array(bucket(4, "customer_id")) + val orders_partitions = Array(bucket(4, "customer_id")) + + testWithCustomersAndOrders(customers_partitions, + Distributions.clustered(customers_partitions.toArray), + orders_partitions, + Distributions.clustered(orders_partitions.toArray), + 0) + } + + test("partitioned join: number of buckets mismatch should trigger shuffle") { + val customers_partitions = Array(bucket(4, "customer_id")) + val orders_partitions = Array(bucket(2, "customer_id")) + + // should shuffle both sides when number of buckets are not the same + testWithCustomersAndOrders(customers_partitions, + Distributions.clustered(customers_partitions.toArray), + orders_partitions, + Distributions.clustered(orders_partitions.toArray), + 2) + } + + test("partitioned join: only one side reports partitioning") { + val customers_partitions = Array(bucket(4, "customer_id")) + val orders_partitions = Array(bucket(2, "customer_id")) + + testWithCustomersAndOrders(customers_partitions, + Distributions.clustered(customers_partitions.toArray), + orders_partitions, + Distributions.unspecified(), + 2) + } + + private val items: String = "items" + private val items_schema: StructType = new StructType() + .add("id", LongType) + .add("name", StringType) + .add("price", FloatType) + .add("arrive_time", TimestampType) + + private val purchases: String = "purchases" + private val purchases_schema: StructType = new StructType() + .add("item_id", LongType) + .add("price", FloatType) + .add("time", TimestampType) + + test("partitioned join: join with two partition keys and matching & sorted partitions") { + val items_partitions = Array(bucket(8, "id"), days("arrive_time")) + createTable(items, items_schema, items_partitions, + Distributions.clustered(items_partitions.toArray)) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " + + s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(bucket(8, "item_id"), days("time")) + createTable(purchases, purchases_schema, purchases_partitions, + Distributions.clustered(purchases_partitions.toArray)) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 42.0, cast('2020-01-01' as timestamp)), " + + s"(1, 44.0, cast('2020-01-15' as timestamp)), " + + s"(1, 45.0, cast('2020-01-15' as timestamp)), " + + s"(2, 11.0, cast('2020-01-01' as timestamp)), " + + s"(3, 19.5, cast('2020-02-01' as timestamp))") + + val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.id = p.item_id AND i.arrive_time = p.time ORDER BY id, purchase_price, sale_price") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + checkAnswer(df, + Seq(Row(1, "aa", 40.0, 42.0), Row(1, "aa", 41.0, 44.0), Row(1, "aa", 41.0, 45.0), + Row(2, "bb", 10.0, 11.0), Row(2, "bb", 10.5, 11.0), Row(3, "cc", 15.5, 19.5)) + ) + } + + test("partitioned join: join with two partition keys and unsorted partitions") { + val items_partitions = Array(bucket(8, "id"), days("arrive_time")) + createTable(items, items_schema, items_partitions, + Distributions.clustered(items_partitions.toArray)) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp)), " + + s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp))") + + val purchases_partitions = Array(bucket(8, "item_id"), days("time")) + createTable(purchases, purchases_schema, purchases_partitions, + Distributions.clustered(purchases_partitions.toArray)) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(2, 11.0, cast('2020-01-01' as timestamp)), " + + s"(1, 42.0, cast('2020-01-01' as timestamp)), " + + s"(1, 44.0, cast('2020-01-15' as timestamp)), " + + s"(1, 45.0, cast('2020-01-15' as timestamp)), " + + s"(3, 19.5, cast('2020-02-01' as timestamp))") + + val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.id = p.item_id AND i.arrive_time = p.time ORDER BY id, purchase_price, sale_price") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + checkAnswer(df, + Seq(Row(1, "aa", 40.0, 42.0), Row(1, "aa", 41.0, 44.0), Row(1, "aa", 41.0, 45.0), + Row(2, "bb", 10.0, 11.0), Row(2, "bb", 10.5, 11.0), Row(3, "cc", 15.5, 19.5)) + ) + } + + test("partitioned join: join with two partition keys and different # of partition keys") { + val items_partitions = Array(bucket(8, "id"), days("arrive_time")) + createTable(items, items_schema, items_partitions, + Distributions.clustered(items_partitions.toArray)) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(bucket(8, "item_id"), days("time")) + createTable(purchases, purchases_schema, purchases_partitions, + Distributions.clustered(purchases_partitions.toArray)) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 42.0, cast('2020-01-01' as timestamp)), " + + s"(2, 11.0, cast('2020-01-01' as timestamp))") + + val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.id = p.item_id AND i.arrive_time = p.time ORDER BY id, purchase_price, sale_price") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.nonEmpty, "should add shuffle when partition keys mismatch") + } + + test("data source partitioning + dynamic partition filtering") { + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "10") { + val items_partitions = Array(identity("id")) + createTable(items, items_schema, items_partitions, + Distributions.clustered(items_partitions.toArray)) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " + + s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchases_schema, purchases_partitions, + Distributions.clustered(purchases_partitions.toArray)) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 42.0, cast('2020-01-01' as timestamp)), " + + s"(1, 44.0, cast('2020-01-15' as timestamp)), " + + s"(1, 45.0, cast('2020-01-15' as timestamp)), " + + s"(2, 11.0, cast('2020-01-01' as timestamp)), " + + s"(3, 19.5, cast('2020-02-01' as timestamp))") + + // number of unique partitions changed after dynamic filtering - should throw exception + var df = sql(s"SELECT sum(p.price) from testcat.ns.$items i, testcat.ns.$purchases p WHERE " + + s"i.id = p.item_id AND i.price > 40.0") + val e = intercept[Exception](df.collect()) + assert(e.getMessage.contains("number of unique partition values")) + + // dynamic filtering doesn't change partitioning so storage-partitioned join should kick in + df = sql(s"SELECT sum(p.price) from testcat.ns.$items i, testcat.ns.$purchases p WHERE " + + s"i.id = p.item_id AND i.price >= 10.0") + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + checkAnswer(df, Seq(Row(303.5))) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala index 5f8684a144778..36efe5ec1d2ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala @@ -19,41 +19,28 @@ package org.apache.spark.sql.connector import java.util.Collections -import org.scalatest.BeforeAndAfter - -import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, QueryTest, Row} -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, UnknownPartitioning} -import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog} +import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NullOrdering, SortDirection, SortOrder} import org.apache.spark.sql.connector.expressions.LogicalExpressions._ import org.apache.spark.sql.execution.{QueryExecution, SortExec, SparkPlan} -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.functions.lit import org.apache.spark.sql.streaming.{StreamingQueryException, Trigger} -import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StringType, StructType} import org.apache.spark.sql.util.QueryExecutionListener -class WriteDistributionAndOrderingSuite - extends QueryTest with SharedSparkSession with BeforeAndAfter with AdaptiveSparkPlanHelper { - - import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ +class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase { import testImplicits._ - before { - spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) - } - after { spark.sessionState.catalogManager.reset() - spark.sessionState.conf.unsetConf("spark.sql.catalog.testcat") } private val microBatchPrefix = "micro_batch_" @@ -65,8 +52,6 @@ class WriteDistributionAndOrderingSuite .add("id", IntegerType) .add("data", StringType) - private val resolver = conf.resolver - test("ordered distribution and sort with same exprs: append") { checkOrderedDistributionAndSortWithSameExprs("append") } @@ -1027,28 +1012,6 @@ class WriteDistributionAndOrderingSuite assert(actualOrdering == expectedOrdering, "ordering must match") } - private def resolveAttrs( - expr: catalyst.expressions.Expression, - plan: SparkPlan): catalyst.expressions.Expression = { - - expr.transform { - case UnresolvedAttribute(Seq(attrName)) => - plan.output.find(attr => resolver(attr.name, attrName)).get - case UnresolvedAttribute(nameParts) => - val attrName = nameParts.mkString(".") - fail(s"cannot resolve a nested attr: $attrName") - } - } - - private def attr(name: String): UnresolvedAttribute = { - UnresolvedAttribute(name) - } - - private def catalog: InMemoryTableCatalog = { - val catalog = spark.sessionState.catalogManager.catalog("testcat") - catalog.asTableCatalog.asInstanceOf[InMemoryTableCatalog] - } - // executes a write operation and keeps the executed physical plan private def execute(writeFunc: => Unit): SparkPlan = { var executedPlan: SparkPlan = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala new file mode 100644 index 0000000000000..1994874d3289e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connector.catalog.functions + +import org.apache.spark.sql.types._ + +object UnboundYearsFunction extends UnboundFunction { + override def bind(inputType: StructType): BoundFunction = { + if (inputType.size == 1 && isValidType(inputType.head.dataType)) YearsFunction + else throw new UnsupportedOperationException( + "'years' only take date or timestamp as input type") + } + + private def isValidType(dt: DataType): Boolean = dt match { + case DateType | TimestampType => true + case _ => false + } + + override def description(): String = name() + override def name(): String = "years" +} + +object YearsFunction extends BoundFunction { + override def inputTypes(): Array[DataType] = Array(TimestampType) + override def resultType(): DataType = LongType + override def name(): String = "years" + override def canonicalName(): String = name() +} + +object DaysFunction extends BoundFunction { + override def inputTypes(): Array[DataType] = Array(TimestampType) + override def resultType(): DataType = LongType + override def name(): String = "days" + override def canonicalName(): String = name() +} + +object UnboundDaysFunction extends UnboundFunction { + override def bind(inputType: StructType): BoundFunction = { + if (inputType.size == 1 && isValidType(inputType.head.dataType)) DaysFunction + else throw new UnsupportedOperationException( + "'days' only take date or timestamp as input type") + } + + private def isValidType(dt: DataType): Boolean = dt match { + case DateType | TimestampType => true + case _ => false + } + + override def description(): String = name() + override def name(): String = "days" +} + +object UnboundBucketFunction extends UnboundFunction { + override def bind(inputType: StructType): BoundFunction = BucketFunction + override def description(): String = name() + override def name(): String = "bucket" +} + +object BucketFunction extends BoundFunction { + override def inputTypes(): Array[DataType] = Array(IntegerType, IntegerType) + override def resultType(): DataType = IntegerType + override def name(): String = "bucket" + override def canonicalName(): String = name() +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala index 9909996059dac..775f34f1f6156 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala @@ -95,7 +95,7 @@ abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase { assert(getScanExecPartitionSize(plan) == expectedPartitionCount) val collectFn: PartialFunction[SparkPlan, Seq[Expression]] = collectPartitionFiltersFn orElse { - case BatchScanExec(_, scan: FileScan, _) => scan.partitionFilters + case BatchScanExec(_, scan: FileScan, _, _) => scan.partitionFilters } val pushedDownPartitionFilters = plan.collectFirst(collectFn) .map(exps => exps.filterNot(e => e.isInstanceOf[IsNotNull])) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index a62ce9226a6a6..a9cb01b6d5657 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -59,7 +59,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { .where(Column(predicate)) query.queryExecution.optimizedPlan match { - case PhysicalOperation(_, filters, DataSourceV2ScanRelation(_, o: OrcScan, _)) => + case PhysicalOperation(_, filters, DataSourceV2ScanRelation(_, o: OrcScan, _, _)) => assert(filters.nonEmpty, "No filter is analyzed from the given query") assert(o.pushedFilters.nonEmpty, "No filter is pushed down") val maybeFilter = OrcFilters.createFilter(query.schema, o.pushedFilters) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala index 46a7f8d3d90de..53d2ccdc5af68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala @@ -120,8 +120,7 @@ trait OrcTest extends QueryTest with FileBasedDataSourceTest with BeforeAndAfter .where(Column(predicate)) query.queryExecution.optimizedPlan match { - case PhysicalOperation(_, filters, - DataSourceV2ScanRelation(_, o: OrcScan, _)) => + case PhysicalOperation(_, filters, DataSourceV2ScanRelation(_, o: OrcScan, _, _)) => assert(filters.nonEmpty, "No filter is analyzed from the given query") if (noneSupported) { assert(o.pushedFilters.isEmpty, "Unsupported filters should not show in pushed filters") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala index 107a2b7912029..7fb6d4c36968d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala @@ -40,7 +40,7 @@ class OrcV2SchemaPruningSuite extends SchemaPruningSuite with AdaptiveSparkPlanH override def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = { val fileSourceScanSchemata = collect(df.queryExecution.executedPlan) { - case BatchScanExec(_, scan: OrcScan, _) => scan.readDataSchema + case BatchScanExec(_, scan: OrcScan, _, _) => scan.readDataSchema } assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size, s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index d2ec82bf443e0..0c4b239c11818 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -2016,7 +2016,7 @@ class ParquetV2FilterSuite extends ParquetFilterSuite { query.queryExecution.optimizedPlan.collectFirst { case PhysicalOperation(_, filters, - DataSourceV2ScanRelation(_, scan: ParquetScan, _)) => + DataSourceV2ScanRelation(_, scan: ParquetScan, _, _)) => assert(filters.nonEmpty, "No filter is analyzed from the given query") val sourceFilters = filters.flatMap(DataSourceStrategy.translateFilter(_, true)).toArray val pushedFilters = scan.pushedFilters diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index db99557466d95..7237cc5f0fa51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.execution.exchange -import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, SortOrder} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.connector.catalog.functions._ import org.apache.spark.sql.execution.{DummySparkPlan, SortExec} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION import org.apache.spark.sql.test.SharedSparkSession class EnsureRequirementsSuite extends SharedSparkSession { @@ -79,6 +81,48 @@ class EnsureRequirementsSuite extends SharedSparkSession { } } + test("reorder should handle KeyGroupedPartitioning") { + // partitioning on the left + val plan1 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning(Seq( + years(exprA), bucket(4, exprB), days(exprC)), 4) + ) + val plan2 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning(Seq( + years(exprB), bucket(4, exprA), days(exprD)), 4) + ) + val smjExec = SortMergeJoinExec( + exprB :: exprC :: exprA :: Nil, exprA :: exprD :: exprB :: Nil, + Inner, None, plan1, plan2 + ) + EnsureRequirements.apply(smjExec) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, DummySparkPlan(_, _, _: KeyGroupedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, _: KeyGroupedPartitioning, _, _), _), _) => + assert(leftKeys === Seq(exprA, exprB, exprC)) + assert(rightKeys === Seq(exprB, exprA, exprD)) + case other => fail(other.toString) + } + + // partitioning on the right + val plan3 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning(Seq( + bucket(4, exprD), days(exprA), years(exprC)), 4) + ) + val smjExec2 = SortMergeJoinExec( + exprB :: exprD :: exprC :: Nil, exprA :: exprC :: exprD :: Nil, + Inner, None, plan1, plan3 + ) + EnsureRequirements.apply(smjExec2) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), _) => + assert(leftKeys === Seq(exprC, exprB, exprD)) + assert(rightKeys === Seq(exprD, exprA, exprC)) + case other => fail(other.toString) + } + } + test("reorder should fallback to the other side partitioning") { val plan1 = DummySparkPlan( outputPartitioning = HashPartitioning(exprA :: exprB :: exprC :: Nil, 5)) @@ -645,4 +689,268 @@ class EnsureRequirementsSuite extends SharedSparkSession { } } } + + test("Check with KeyGroupedPartitioning") { + // simplest case: identity transforms + var plan1 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning(exprA :: exprB :: Nil, 5)) + var plan2 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning(exprA :: exprC :: Nil, 5)) + var smjExec = SortMergeJoinExec( + exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) + EnsureRequirements.apply(smjExec) match { + case SortMergeJoinExec(_, _, _, _, + SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + assert(left.expressions === Seq(exprA, exprB)) + assert(right.expressions === Seq(exprA, exprC)) + case other => fail(other.toString) + } + + // matching bucket transforms from both sides + plan1 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning( + bucket(4, exprA) :: bucket(16, exprB) :: Nil, 4) + ) + plan2 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning( + bucket(4, exprA) :: bucket(16, exprC) :: Nil, 4) + ) + smjExec = SortMergeJoinExec( + exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) + EnsureRequirements.apply(smjExec) match { + case SortMergeJoinExec(_, _, _, _, + SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + assert(left.expressions === Seq(bucket(4, exprA), bucket(16, exprB))) + assert(right.expressions === Seq(bucket(4, exprA), bucket(16, exprC))) + case other => fail(other.toString) + } + + // partition collections + plan1 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning( + bucket(4, exprA) :: bucket(16, exprB) :: Nil, 4) + ) + plan2 = DummySparkPlan( + outputPartitioning = PartitioningCollection(Seq( + KeyGroupedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, 4), + HashPartitioning(exprA :: exprC :: Nil, 4)) + ) + ) + smjExec = SortMergeJoinExec( + exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) + EnsureRequirements.apply(smjExec) match { + case SortMergeJoinExec(_, _, _, _, + SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) => + assert(left.expressions === Seq(bucket(4, exprA), bucket(16, exprB))) + case other => fail(other.toString) + } + smjExec = SortMergeJoinExec( + exprA :: exprC :: Nil, exprA :: exprB :: Nil, Inner, None, plan2, plan1) + EnsureRequirements.apply(smjExec) match { + case SortMergeJoinExec(_, _, _, _, + SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + assert(right.expressions === Seq(bucket(4, exprA), bucket(16, exprB))) + case other => fail(other.toString) + } + + // bucket + years transforms from both sides + plan1 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning(bucket(4, exprA) :: years(exprB) :: Nil, 4) + ) + plan2 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning(bucket(4, exprA) :: years(exprC) :: Nil, 4) + ) + smjExec = SortMergeJoinExec( + exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) + EnsureRequirements.apply(smjExec) match { + case SortMergeJoinExec(_, _, _, _, + SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + assert(left.expressions === Seq(bucket(4, exprA), years(exprB))) + assert(right.expressions === Seq(bucket(4, exprA), years(exprC))) + case other => fail(other.toString) + } + + // by default spark.sql.requireAllClusterKeysForCoPartition is true, so when there isn't + // exact match on all partition keys, Spark will fallback to shuffle. + plan1 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning( + bucket(4, exprA) :: bucket(4, exprB) :: Nil, 4) + ) + plan2 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning( + bucket(4, exprA) :: bucket(4, exprC) :: Nil, 4) + ) + smjExec = SortMergeJoinExec( + exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) + EnsureRequirements.apply(smjExec) match { + case SortMergeJoinExec(_, _, _, _, + SortExec(_, _, ShuffleExchangeExec(left: HashPartitioning, _, _), _), + SortExec(_, _, ShuffleExchangeExec(right: HashPartitioning, _, _), _), _) => + assert(left.expressions === Seq(exprA, exprB, exprB)) + assert(right.expressions === Seq(exprA, exprC, exprC)) + case other => fail(other.toString) + } + } + + test(s"KeyGroupedPartitioning with ${REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key} = false") { + var plan1 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning( + bucket(4, exprB) :: years(exprC) :: Nil, 4) + ) + var plan2 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning( + bucket(4, exprC) :: years(exprB) :: Nil, 4) + ) + + // simple case + var smjExec = SortMergeJoinExec( + exprA :: exprB :: exprC :: Nil, exprA :: exprC :: exprB :: Nil, Inner, None, plan1, plan2) + applyEnsureRequirementsWithSubsetKeys(smjExec) match { + case SortMergeJoinExec(_, _, _, _, + SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + assert(left.expressions === Seq(bucket(4, exprB), years(exprC))) + assert(right.expressions === Seq(bucket(4, exprC), years(exprB))) + case other => fail(other.toString) + } + + // should also work with distributions with duplicated keys + plan1 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning( + bucket(4, exprA) :: years(exprB) :: Nil, 4) + ) + plan2 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning( + bucket(4, exprA) :: years(exprC) :: Nil, 4) + ) + smjExec = SortMergeJoinExec( + exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) + applyEnsureRequirementsWithSubsetKeys(smjExec) match { + case SortMergeJoinExec(_, _, _, _, + SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + assert(left.expressions === Seq(bucket(4, exprA), years(exprB))) + assert(right.expressions === Seq(bucket(4, exprA), years(exprC))) + case other => fail(other.toString) + } + + // both partitioning and distribution have duplicated keys + plan1 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning( + years(exprA) :: bucket(4, exprB) :: days(exprA) :: Nil, 5)) + plan2 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning( + years(exprA) :: bucket(4, exprC) :: days(exprA) :: Nil, 5)) + smjExec = SortMergeJoinExec( + exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) + applyEnsureRequirementsWithSubsetKeys(smjExec) match { + case SortMergeJoinExec(_, _, _, _, + SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + assert(left.expressions === Seq(years(exprA), bucket(4, exprB), days(exprA))) + assert(right.expressions === Seq(years(exprA), bucket(4, exprC), days(exprA))) + case other => fail(other.toString) + } + + // invalid case: partitioning key positions don't match + plan1 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning( + bucket(4, exprA) :: bucket(4, exprB) :: Nil, 4) + ) + plan2 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning( + bucket(4, exprB) :: bucket(4, exprC) :: Nil, 4) + ) + + smjExec = SortMergeJoinExec( + exprA :: exprB :: exprC :: Nil, exprA :: exprB :: exprC :: Nil, Inner, None, plan1, plan2) + applyEnsureRequirementsWithSubsetKeys(smjExec) match { + case SortMergeJoinExec(_, _, _, _, + SortExec(_, _, ShuffleExchangeExec(left: HashPartitioning, _, _), _), + SortExec(_, _, ShuffleExchangeExec(right: HashPartitioning, _, _), _), _) => + assert(left.expressions === Seq(exprA, exprB, exprC)) + assert(right.expressions === Seq(exprA, exprB, exprC)) + case other => fail(other.toString) + } + + // invalid case: different number of buckets (we don't support coalescing/repartitioning yet) + plan1 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning( + bucket(4, exprA) :: bucket(4, exprB) :: Nil, 4) + ) + plan2 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning( + bucket(4, exprA) :: bucket(8, exprC) :: Nil, 4) + ) + smjExec = SortMergeJoinExec( + exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) + applyEnsureRequirementsWithSubsetKeys(smjExec) match { + case SortMergeJoinExec(_, _, _, _, + SortExec(_, _, ShuffleExchangeExec(left: HashPartitioning, _, _), _), + SortExec(_, _, ShuffleExchangeExec(right: HashPartitioning, _, _), _), _) => + assert(left.expressions === Seq(exprA, exprB, exprB)) + assert(right.expressions === Seq(exprA, exprC, exprC)) + case other => fail(other.toString) + } + + // invalid case: partition key positions match but with different transforms + plan1 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning(years(exprA) :: bucket(4, exprB) :: Nil, 4) + ) + plan2 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning(days(exprA) :: bucket(4, exprC) :: Nil, 4) + ) + smjExec = SortMergeJoinExec( + exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) + applyEnsureRequirementsWithSubsetKeys(smjExec) match { + case SortMergeJoinExec(_, _, _, _, + SortExec(_, _, ShuffleExchangeExec(left: HashPartitioning, _, _), _), + SortExec(_, _, ShuffleExchangeExec(right: HashPartitioning, _, _), _), _) => + assert(left.expressions === Seq(exprA, exprB, exprB)) + assert(right.expressions === Seq(exprA, exprC, exprC)) + case other => fail(other.toString) + } + + + // invalid case: multiple references in transform + plan1 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning( + years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, 4) + ) + plan2 = DummySparkPlan( + outputPartitioning = KeyGroupedPartitioning( + years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, 4) + ) + smjExec = SortMergeJoinExec( + exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) + applyEnsureRequirementsWithSubsetKeys(smjExec) match { + case SortMergeJoinExec(_, _, _, _, + SortExec(_, _, ShuffleExchangeExec(left: HashPartitioning, _, _), _), + SortExec(_, _, ShuffleExchangeExec(right: HashPartitioning, _, _), _), _) => + assert(left.expressions === Seq(exprA, exprB, exprB)) + assert(right.expressions === Seq(exprA, exprC, exprC)) + case other => fail(other.toString) + } + } + + def bucket(numBuckets: Int, expr: Expression): TransformExpression = { + TransformExpression(BucketFunction, Seq(expr), Some(numBuckets)) + } + + def buckets(numBuckets: Int, expr: Seq[Expression]): TransformExpression = { + TransformExpression(BucketFunction, expr, Some(numBuckets)) + } + + def years(expr: Expression): TransformExpression = { + TransformExpression(YearsFunction, Seq(expr)) + } + + def days(expr: Expression): TransformExpression = { + TransformExpression(DaysFunction, Seq(expr)) + } }