From fe24634d14bc0973ca38222db2f58eafbf0c890d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 23 Jun 2017 00:43:21 -0700 Subject: [PATCH 001/779] [SPARK-21145][SS] Added StateStoreProviderId with queryRunId to reload StateStoreProviders when query is restarted ## What changes were proposed in this pull request? StateStoreProvider instances are loaded on-demand in a executor when a query is started. When a query is restarted, the loaded provider instance will get reused. Now, there is a non-trivial chance, that the task of the previous query run is still running, while the tasks of the restarted run has started. So for a stateful partition, there may be two concurrent tasks related to the same stateful partition, and there for using the same provider instance. This can lead to inconsistent results and possibly random failures, as state store implementations are not designed to be thread-safe. To fix this, I have introduced a `StateStoreProviderId`, that unique identifies a provider loaded in an executor. It has the query run id in it, thus making sure that restarted queries will force the executor to load a new provider instance, thus avoiding two concurrent tasks (from two different runs) from reusing the same provider instance. Additional minor bug fixes - All state stores related to query run is marked as deactivated in the `StateStoreCoordinator` so that the executors can unload them and clear resources. - Moved the code that determined the checkpoint directory of a state store from implementation-specific code (`HDFSBackedStateStoreProvider`) to non-specific code (StateStoreId), so that implementation do not accidentally get it wrong. - Also added store name to the path, to support multiple stores per sql operator partition. *Note:* This change does not address the scenario where two tasks of the same run (e.g. speculative tasks) are concurrently running in the same executor. The chance of this very small, because ideally speculative tasks should never run in the same executor. ## How was this patch tested? Existing unit tests + new unit test. Author: Tathagata Das Closes #18355 from tdas/SPARK-21145. --- .../sql/execution/aggregate/AggUtils.scala | 2 +- .../sql/execution/command/commands.scala | 5 +- .../FlatMapGroupsWithStateExec.scala | 7 +- .../streaming/IncrementalExecution.scala | 27 +++--- .../execution/streaming/StreamExecution.scala | 1 + .../state/HDFSBackedStateStoreProvider.scala | 16 ++-- .../streaming/state/StateStore.scala | 91 +++++++++++++----- .../state/StateStoreCoordinator.scala | 41 ++++---- .../streaming/state/StateStoreRDD.scala | 21 ++++- .../execution/streaming/state/package.scala | 25 ++--- .../streaming/statefulOperators.scala | 38 ++++---- .../sql/streaming/StreamingQueryManager.scala | 1 + .../state/StateStoreCoordinatorSuite.scala | 61 ++++++++++-- .../streaming/state/StateStoreRDDSuite.scala | 51 +++++----- .../streaming/state/StateStoreSuite.scala | 93 +++++++++++++++---- .../spark/sql/streaming/StreamSuite.scala | 2 +- .../spark/sql/streaming/StreamTest.scala | 13 ++- 17 files changed, 329 insertions(+), 166 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index aa789af6f812f..12f8cffb6774a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -311,7 +311,7 @@ object AggUtils { val saved = StateStoreSaveExec( groupingAttributes, - stateId = None, + stateInfo = None, outputMode = None, eventTimeWatermark = None, partialMerged2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 2d82fcf4da6e9..81bc93e7ebcf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.util.UUID + import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -117,7 +119,8 @@ case class ExplainCommand( // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the // output mode does not matter since there is no `Sink`. new IncrementalExecution( - sparkSession, logicalPlan, OutputMode.Append(), "", 0, OffsetSeqMetadata(0, 0)) + sparkSession, logicalPlan, OutputMode.Append(), "", + UUID.randomUUID, 0, OffsetSeqMetadata(0, 0)) } else { sparkSession.sessionState.executePlan(logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 2aad8701a4eca..9dcac33b4107c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -50,7 +50,7 @@ case class FlatMapGroupsWithStateExec( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputObjAttr: Attribute, - stateId: Option[OperatorStateId], + stateInfo: Option[StatefulOperatorStateInfo], stateEncoder: ExpressionEncoder[Any], outputMode: OutputMode, timeoutConf: GroupStateTimeout, @@ -107,10 +107,7 @@ case class FlatMapGroupsWithStateExec( } child.execute().mapPartitionsWithStateStore[InternalRow]( - getStateId.checkpointLocation, - getStateId.operatorId, - storeName = "default", - getStateId.batchId, + getStateInfo, groupingAttributes.toStructType, stateAttributes.toStructType, indexOrdinal = None, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 622e049630db2..ab89dc6b705d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import java.util.UUID import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.internal.Logging @@ -36,6 +37,7 @@ class IncrementalExecution( logicalPlan: LogicalPlan, val outputMode: OutputMode, val checkpointLocation: String, + val runId: UUID, val currentBatchId: Long, offsetSeqMetadata: OffsetSeqMetadata) extends QueryExecution(sparkSession, logicalPlan) with Logging { @@ -69,7 +71,13 @@ class IncrementalExecution( * Records the current id for a given stateful operator in the query plan as the `state` * preparation walks the query plan. */ - private val operatorId = new AtomicInteger(0) + private val statefulOperatorId = new AtomicInteger(0) + + /** Get the state info of the next stateful operator */ + private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = { + StatefulOperatorStateInfo( + checkpointLocation, runId, statefulOperatorId.getAndIncrement(), currentBatchId) + } /** Locates save/restore pairs surrounding aggregation. */ val state = new Rule[SparkPlan] { @@ -78,35 +86,28 @@ class IncrementalExecution( case StateStoreSaveExec(keys, None, None, None, UnaryExecNode(agg, StateStoreRestoreExec(keys2, None, child))) => - val stateId = - OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - + val aggStateInfo = nextStatefulOperationStateInfo StateStoreSaveExec( keys, - Some(stateId), + Some(aggStateInfo), Some(outputMode), Some(offsetSeqMetadata.batchWatermarkMs), agg.withNewChildren( StateStoreRestoreExec( keys, - Some(stateId), + Some(aggStateInfo), child) :: Nil)) case StreamingDeduplicateExec(keys, child, None, None) => - val stateId = - OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - StreamingDeduplicateExec( keys, child, - Some(stateId), + Some(nextStatefulOperationStateInfo), Some(offsetSeqMetadata.batchWatermarkMs)) case m: FlatMapGroupsWithStateExec => - val stateId = - OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) m.copy( - stateId = Some(stateId), + stateInfo = Some(nextStatefulOperationStateInfo), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 74f0f509bbf85..06bdec8b06407 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -652,6 +652,7 @@ class StreamExecution( triggerLogicalPlan, outputMode, checkpointFile("state"), + runId, currentBatchId, offsetSeqMetadata) lastExecution.executedPlan // Force the lazy generation of execution plan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 67d86daf10812..bae7a15165e43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -92,7 +92,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit @volatile private var state: STATE = UPDATING @volatile private var finalDeltaFile: Path = null - override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id + override def id: StateStoreId = HDFSBackedStateStoreProvider.this.stateStoreId override def get(key: UnsafeRow): UnsafeRow = { mapToUpdate.get(key) @@ -177,7 +177,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** * Whether all updates have been committed */ - override private[streaming] def hasCommitted: Boolean = { + override def hasCommitted: Boolean = { state == COMMITTED } @@ -205,7 +205,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit indexOrdinal: Option[Int], // for sorting the data storeConf: StateStoreConf, hadoopConf: Configuration): Unit = { - this.stateStoreId = stateStoreId + this.stateStoreId_ = stateStoreId this.keySchema = keySchema this.valueSchema = valueSchema this.storeConf = storeConf @@ -213,7 +213,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit fs.mkdirs(baseDir) } - override def id: StateStoreId = stateStoreId + override def stateStoreId: StateStoreId = stateStoreId_ /** Do maintenance backing data files, including creating snapshots and cleaning up old files */ override def doMaintenance(): Unit = { @@ -231,20 +231,20 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } override def toString(): String = { - s"HDFSStateStoreProvider[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" + s"HDFSStateStoreProvider[" + + s"id = (op=${stateStoreId.operatorId},part=${stateStoreId.partitionId}),dir = $baseDir]" } /* Internal fields and methods */ - @volatile private var stateStoreId: StateStoreId = _ + @volatile private var stateStoreId_ : StateStoreId = _ @volatile private var keySchema: StructType = _ @volatile private var valueSchema: StructType = _ @volatile private var storeConf: StateStoreConf = _ @volatile private var hadoopConf: Configuration = _ private lazy val loadedMaps = new mutable.HashMap[Long, MapType] - private lazy val baseDir = - new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}") + private lazy val baseDir = stateStoreId.storeCheckpointLocation() private lazy val fs = baseDir.getFileSystem(hadoopConf) private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 29c456f86e1ed..a94ff8a7ebd1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.UUID import java.util.concurrent.{ScheduledFuture, TimeUnit} import javax.annotation.concurrent.GuardedBy @@ -24,14 +25,14 @@ import scala.collection.mutable import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types.StructType import org.apache.spark.util.{ThreadUtils, Utils} - /** * Base trait for a versioned key-value store. Each instance of a `StateStore` represents a specific * version of state data, and such instances are created through a [[StateStoreProvider]]. @@ -99,7 +100,7 @@ trait StateStore { /** * Whether all updates have been committed */ - private[streaming] def hasCommitted: Boolean + def hasCommitted: Boolean } @@ -147,7 +148,7 @@ trait StateStoreProvider { * Return the id of the StateStores this provider will generate. * Should be the same as the one passed in init(). */ - def id: StateStoreId + def stateStoreId: StateStoreId /** Called when the provider instance is unloaded from the executor */ def close(): Unit @@ -179,13 +180,46 @@ object StateStoreProvider { } } +/** + * Unique identifier for a provider, used to identify when providers can be reused. + * Note that `queryRunId` is used uniquely identify a provider, so that the same provider + * instance is not reused across query restarts. + */ +case class StateStoreProviderId(storeId: StateStoreId, queryRunId: UUID) -/** Unique identifier for a bunch of keyed state data. */ +/** + * Unique identifier for a bunch of keyed state data. + * @param checkpointRootLocation Root directory where all the state data of a query is stored + * @param operatorId Unique id of a stateful operator + * @param partitionId Index of the partition of an operators state data + * @param storeName Optional, name of the store. Each partition can optionally use multiple state + * stores, but they have to be identified by distinct names. + */ case class StateStoreId( - checkpointLocation: String, + checkpointRootLocation: String, operatorId: Long, partitionId: Int, - name: String = "") + storeName: String = StateStoreId.DEFAULT_STORE_NAME) { + + /** + * Checkpoint directory to be used by a single state store, identified uniquely by the tuple + * (operatorId, partitionId, storeName). All implementations of [[StateStoreProvider]] should + * use this path for saving state data, as this ensures that distinct stores will write to + * different locations. + */ + def storeCheckpointLocation(): Path = { + if (storeName == StateStoreId.DEFAULT_STORE_NAME) { + // For reading state store data that was generated before store names were used (Spark <= 2.2) + new Path(checkpointRootLocation, s"$operatorId/$partitionId") + } else { + new Path(checkpointRootLocation, s"$operatorId/$partitionId/$storeName") + } + } +} + +object StateStoreId { + val DEFAULT_STORE_NAME = "default" +} /** Mutable, and reusable class for representing a pair of UnsafeRows. */ class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) { @@ -211,7 +245,7 @@ object StateStore extends Logging { val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60 @GuardedBy("loadedProviders") - private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]() + private val loadedProviders = new mutable.HashMap[StateStoreProviderId, StateStoreProvider]() /** * Runs the `task` periodically and automatically cancels it if there is an exception. `onError` @@ -253,7 +287,7 @@ object StateStore extends Logging { /** Get or create a store associated with the id. */ def get( - storeId: StateStoreId, + storeProviderId: StateStoreProviderId, keySchema: StructType, valueSchema: StructType, indexOrdinal: Option[Int], @@ -264,24 +298,24 @@ object StateStore extends Logging { val storeProvider = loadedProviders.synchronized { startMaintenanceIfNeeded() val provider = loadedProviders.getOrElseUpdate( - storeId, + storeProviderId, StateStoreProvider.instantiate( - storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) + storeProviderId.storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) ) - reportActiveStoreInstance(storeId) + reportActiveStoreInstance(storeProviderId) provider } storeProvider.getStore(version) } /** Unload a state store provider */ - def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized { - loadedProviders.remove(storeId).foreach(_.close()) + def unload(storeProviderId: StateStoreProviderId): Unit = loadedProviders.synchronized { + loadedProviders.remove(storeProviderId).foreach(_.close()) } /** Whether a state store provider is loaded or not */ - def isLoaded(storeId: StateStoreId): Boolean = loadedProviders.synchronized { - loadedProviders.contains(storeId) + def isLoaded(storeProviderId: StateStoreProviderId): Boolean = loadedProviders.synchronized { + loadedProviders.contains(storeProviderId) } def isMaintenanceRunning: Boolean = loadedProviders.synchronized { @@ -340,21 +374,21 @@ object StateStore extends Logging { } } - private def reportActiveStoreInstance(storeId: StateStoreId): Unit = { + private def reportActiveStoreInstance(storeProviderId: StateStoreProviderId): Unit = { if (SparkEnv.get != null) { val host = SparkEnv.get.blockManager.blockManagerId.host val executorId = SparkEnv.get.blockManager.blockManagerId.executorId - coordinatorRef.foreach(_.reportActiveInstance(storeId, host, executorId)) - logDebug(s"Reported that the loaded instance $storeId is active") + coordinatorRef.foreach(_.reportActiveInstance(storeProviderId, host, executorId)) + logInfo(s"Reported that the loaded instance $storeProviderId is active") } } - private def verifyIfStoreInstanceActive(storeId: StateStoreId): Boolean = { + private def verifyIfStoreInstanceActive(storeProviderId: StateStoreProviderId): Boolean = { if (SparkEnv.get != null) { val executorId = SparkEnv.get.blockManager.blockManagerId.executorId val verified = - coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false) - logDebug(s"Verified whether the loaded instance $storeId is active: $verified") + coordinatorRef.map(_.verifyIfInstanceActive(storeProviderId, executorId)).getOrElse(false) + logDebug(s"Verified whether the loaded instance $storeProviderId is active: $verified") verified } else { false @@ -364,12 +398,21 @@ object StateStore extends Logging { private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { val env = SparkEnv.get if (env != null) { - if (_coordRef == null) { + logInfo("Env is not null") + val isDriver = + env.executorId == SparkContext.DRIVER_IDENTIFIER || + env.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER + // If running locally, then the coordinator reference in _coordRef may be have become inactive + // as SparkContext + SparkEnv may have been restarted. Hence, when running in driver, + // always recreate the reference. + if (isDriver || _coordRef == null) { + logInfo("Getting StateStoreCoordinatorRef") _coordRef = StateStoreCoordinatorRef.forExecutor(env) } - logDebug(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") + logInfo(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") Some(_coordRef) } else { + logInfo("Env is null") _coordRef = null None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index d0f81887e62d1..3884f5e6ce766 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.UUID + import scala.collection.mutable import org.apache.spark.SparkEnv @@ -29,16 +31,19 @@ import org.apache.spark.util.RpcUtils private sealed trait StateStoreCoordinatorMessage extends Serializable /** Classes representing messages */ -private case class ReportActiveInstance(storeId: StateStoreId, host: String, executorId: String) +private case class ReportActiveInstance( + storeId: StateStoreProviderId, + host: String, + executorId: String) extends StateStoreCoordinatorMessage -private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: String) +private case class VerifyIfInstanceActive(storeId: StateStoreProviderId, executorId: String) extends StateStoreCoordinatorMessage -private case class GetLocation(storeId: StateStoreId) +private case class GetLocation(storeId: StateStoreProviderId) extends StateStoreCoordinatorMessage -private case class DeactivateInstances(checkpointLocation: String) +private case class DeactivateInstances(runId: UUID) extends StateStoreCoordinatorMessage private object StopCoordinator @@ -80,25 +85,27 @@ object StateStoreCoordinatorRef extends Logging { class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { private[state] def reportActiveInstance( - storeId: StateStoreId, + stateStoreProviderId: StateStoreProviderId, host: String, executorId: String): Unit = { - rpcEndpointRef.send(ReportActiveInstance(storeId, host, executorId)) + rpcEndpointRef.send(ReportActiveInstance(stateStoreProviderId, host, executorId)) } /** Verify whether the given executor has the active instance of a state store */ - private[state] def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = { - rpcEndpointRef.askSync[Boolean](VerifyIfInstanceActive(storeId, executorId)) + private[state] def verifyIfInstanceActive( + stateStoreProviderId: StateStoreProviderId, + executorId: String): Boolean = { + rpcEndpointRef.askSync[Boolean](VerifyIfInstanceActive(stateStoreProviderId, executorId)) } /** Get the location of the state store */ - private[state] def getLocation(storeId: StateStoreId): Option[String] = { - rpcEndpointRef.askSync[Option[String]](GetLocation(storeId)) + private[state] def getLocation(stateStoreProviderId: StateStoreProviderId): Option[String] = { + rpcEndpointRef.askSync[Option[String]](GetLocation(stateStoreProviderId)) } - /** Deactivate instances related to a set of operator */ - private[state] def deactivateInstances(storeRootLocation: String): Unit = { - rpcEndpointRef.askSync[Boolean](DeactivateInstances(storeRootLocation)) + /** Deactivate instances related to a query */ + private[sql] def deactivateInstances(runId: UUID): Unit = { + rpcEndpointRef.askSync[Boolean](DeactivateInstances(runId)) } private[state] def stop(): Unit = { @@ -113,7 +120,7 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { */ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { - private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] + private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation] override def receive: PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId) => @@ -135,11 +142,11 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) logDebug(s"Got location of the state store $id: $executorId") context.reply(executorId) - case DeactivateInstances(checkpointLocation) => + case DeactivateInstances(runId) => val storeIdsToRemove = - instances.keys.filter(_.checkpointLocation == checkpointLocation).toSeq + instances.keys.filter(_.queryRunId == runId).toSeq instances --= storeIdsToRemove - logDebug(s"Deactivating instances related to checkpoint location $checkpointLocation: " + + logDebug(s"Deactivating instances related to checkpoint location $runId: " + storeIdsToRemove.mkString(", ")) context.reply(true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index b744c25dc97a8..01d8e75980993 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.UUID + import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} @@ -34,8 +36,8 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( dataRDD: RDD[T], storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], checkpointLocation: String, + queryRunId: UUID, operatorId: Long, - storeName: String, storeVersion: Long, keySchema: StructType, valueSchema: StructType, @@ -52,16 +54,25 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( override protected def getPartitions: Array[Partition] = dataRDD.partitions + /** + * Set the preferred location of each partition using the executor that has the related + * [[StateStoreProvider]] already loaded. + */ override def getPreferredLocations(partition: Partition): Seq[String] = { - val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName) - storeCoordinator.flatMap(_.getLocation(storeId)).toSeq + val stateStoreProviderId = StateStoreProviderId( + StateStoreId(checkpointLocation, operatorId, partition.index), + queryRunId) + storeCoordinator.flatMap(_.getLocation(stateStoreProviderId)).toSeq } override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { var store: StateStore = null - val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName) + val storeProviderId = StateStoreProviderId( + StateStoreId(checkpointLocation, operatorId, partition.index), + queryRunId) + store = StateStore.get( - storeId, keySchema, valueSchema, indexOrdinal, storeVersion, + storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion, storeConf, hadoopConfBroadcast.value.value) val inputIter = dataRDD.iterator(partition, ctxt) storeUpdateFunction(store, inputIter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 228fe86d59940..a0086e251f9c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import java.util.UUID + import scala.reflect.ClassTag import org.apache.spark.TaskContext @@ -32,20 +34,14 @@ package object state { /** Map each partition of an RDD along with data in a [[StateStore]]. */ def mapPartitionsWithStateStore[U: ClassTag]( sqlContext: SQLContext, - checkpointLocation: String, - operatorId: Long, - storeName: String, - storeVersion: Long, + stateInfo: StatefulOperatorStateInfo, keySchema: StructType, valueSchema: StructType, indexOrdinal: Option[Int])( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { mapPartitionsWithStateStore( - checkpointLocation, - operatorId, - storeName, - storeVersion, + stateInfo, keySchema, valueSchema, indexOrdinal, @@ -56,10 +52,7 @@ package object state { /** Map each partition of an RDD along with data in a [[StateStore]]. */ private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( - checkpointLocation: String, - operatorId: Long, - storeName: String, - storeVersion: Long, + stateInfo: StatefulOperatorStateInfo, keySchema: StructType, valueSchema: StructType, indexOrdinal: Option[Int], @@ -79,10 +72,10 @@ package object state { new StateStoreRDD( dataRDD, wrappedF, - checkpointLocation, - operatorId, - storeName, - storeVersion, + stateInfo.checkpointLocation, + stateInfo.queryRunId, + stateInfo.operatorId, + stateInfo.storeVersion, keySchema, valueSchema, indexOrdinal, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 3e57f3fbada32..c5722466a33af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import java.util.UUID import java.util.concurrent.TimeUnit._ import org.apache.spark.rdd.RDD @@ -36,20 +37,22 @@ import org.apache.spark.util.{CompletionIterator, NextIterator} /** Used to identify the state store for a given operator. */ -case class OperatorStateId( +case class StatefulOperatorStateInfo( checkpointLocation: String, + queryRunId: UUID, operatorId: Long, - batchId: Long) + storeVersion: Long) /** - * An operator that reads or writes state from the [[StateStore]]. The [[OperatorStateId]] should - * be filled in by `prepareForExecution` in [[IncrementalExecution]]. + * An operator that reads or writes state from the [[StateStore]]. + * The [[StatefulOperatorStateInfo]] should be filled in by `prepareForExecution` in + * [[IncrementalExecution]]. */ trait StatefulOperator extends SparkPlan { - def stateId: Option[OperatorStateId] + def stateInfo: Option[StatefulOperatorStateInfo] - protected def getStateId: OperatorStateId = attachTree(this) { - stateId.getOrElse { + protected def getStateInfo: StatefulOperatorStateInfo = attachTree(this) { + stateInfo.getOrElse { throw new IllegalStateException("State location not present for execution") } } @@ -140,7 +143,7 @@ trait WatermarkSupport extends UnaryExecNode { */ case class StateStoreRestoreExec( keyExpressions: Seq[Attribute], - stateId: Option[OperatorStateId], + stateInfo: Option[StatefulOperatorStateInfo], child: SparkPlan) extends UnaryExecNode with StateStoreReader { @@ -148,10 +151,7 @@ case class StateStoreRestoreExec( val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitionsWithStateStore( - getStateId.checkpointLocation, - operatorId = getStateId.operatorId, - storeName = "default", - storeVersion = getStateId.batchId, + getStateInfo, keyExpressions.toStructType, child.output.toStructType, indexOrdinal = None, @@ -177,7 +177,7 @@ case class StateStoreRestoreExec( */ case class StateStoreSaveExec( keyExpressions: Seq[Attribute], - stateId: Option[OperatorStateId] = None, + stateInfo: Option[StatefulOperatorStateInfo] = None, outputMode: Option[OutputMode] = None, eventTimeWatermark: Option[Long] = None, child: SparkPlan) @@ -189,10 +189,7 @@ case class StateStoreSaveExec( "Incorrect planning in IncrementalExecution, outputMode has not been set") child.execute().mapPartitionsWithStateStore( - getStateId.checkpointLocation, - getStateId.operatorId, - storeName = "default", - getStateId.batchId, + getStateInfo, keyExpressions.toStructType, child.output.toStructType, indexOrdinal = None, @@ -319,7 +316,7 @@ case class StateStoreSaveExec( case class StreamingDeduplicateExec( keyExpressions: Seq[Attribute], child: SparkPlan, - stateId: Option[OperatorStateId] = None, + stateInfo: Option[StatefulOperatorStateInfo] = None, eventTimeWatermark: Option[Long] = None) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { @@ -331,10 +328,7 @@ case class StreamingDeduplicateExec( metrics // force lazy init at driver child.execute().mapPartitionsWithStateStore( - getStateId.checkpointLocation, - getStateId.operatorId, - storeName = "default", - getStateId.batchId, + getStateInfo, keyExpressions.toStructType, child.output.toStructType, indexOrdinal = None, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 002c45413b4c2..48b0ea20e5da1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -332,5 +332,6 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } awaitTerminationLock.notifyAll() } + stateStoreCoordinator.deactivateInstances(terminatedQuery.runId) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index a7e32626264cc..9a7595eee7bd0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -17,11 +17,17 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.UUID + import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.functions.count +import org.apache.spark.util.Utils class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { @@ -29,7 +35,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("report, verify, getLocation") { withCoordinatorRef(sc) { coordinatorRef => - val id = StateStoreId("x", 0, 0) + val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID) assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) assert(coordinatorRef.getLocation(id) === None) @@ -57,9 +63,11 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("make inactive") { withCoordinatorRef(sc) { coordinatorRef => - val id1 = StateStoreId("x", 0, 0) - val id2 = StateStoreId("y", 1, 0) - val id3 = StateStoreId("x", 0, 1) + val runId1 = UUID.randomUUID + val runId2 = UUID.randomUUID + val id1 = StateStoreProviderId(StateStoreId("x", 0, 0), runId1) + val id2 = StateStoreProviderId(StateStoreId("y", 1, 0), runId2) + val id3 = StateStoreProviderId(StateStoreId("x", 0, 1), runId1) val host = "hostX" val exec = "exec1" @@ -73,7 +81,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === true) } - coordinatorRef.deactivateInstances("x") + coordinatorRef.deactivateInstances(runId1) assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === false) assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true) @@ -85,7 +93,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { Some(ExecutorCacheTaskLocation(host, exec).toString)) assert(coordinatorRef.getLocation(id3) === None) - coordinatorRef.deactivateInstances("y") + coordinatorRef.deactivateInstances(runId2) assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === false) assert(coordinatorRef.getLocation(id2) === None) } @@ -95,7 +103,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { withCoordinatorRef(sc) { coordRef1 => val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env) - val id = StateStoreId("x", 0, 0) + val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID) coordRef1.reportActiveInstance(id, "hostX", "exec1") @@ -107,6 +115,45 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } } } + + test("query stop deactivates related store providers") { + var coordRef: StateStoreCoordinatorRef = null + try { + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + SparkSession.setActiveSession(spark) + import spark.implicits._ + coordRef = spark.streams.stateStoreCoordinator + implicit val sqlContext = spark.sqlContext + spark.conf.set("spark.sql.shuffle.partitions", "1") + + // Start a query and run a batch to load state stores + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().groupBy("value").agg(count("*")) // stateful query + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + inputData.addData(1, 2, 3) + query.processAllAvailable() + + // Verify state store has been loaded + val stateCheckpointDir = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation + val providerId = StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, 0), query.runId) + assert(coordRef.getLocation(providerId).nonEmpty) + + // Stop and verify whether the stores are deactivated in the coordinator + query.stop() + assert(coordRef.getLocation(providerId).isEmpty) + } finally { + SparkSession.getActiveSession.foreach(_.streams.active.foreach(_.stop())) + if (coordRef != null) coordRef.stop() + StateStore.stop() + } + } } object StateStoreCoordinatorSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 4a1a089af54c2..defb9ed63a881 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -19,20 +19,19 @@ package org.apache.spark.sql.execution.streaming.state import java.io.File import java.nio.file.Files +import java.util.UUID import scala.util.Random import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.scalatest.concurrent.Eventually._ -import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.sql.LocalSparkSession._ -import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.LocalSparkSession._ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.{CompletionIterator, Utils} @@ -57,16 +56,14 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("versioning and immutability") { withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - val opId = 0 - val rdd1 = - makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( + spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)( + spark.sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)( increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) @@ -76,7 +73,6 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("recovering from files") { - val opId = 0 val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString def makeStoreRDD( @@ -85,7 +81,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn storeVersion: Int): RDD[(String, Int)] = { implicit val sqlContext = spark.sqlContext makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion, keySchema, valueSchema, None)(increment) + sqlContext, operatorStateInfo(path, version = storeVersion), + keySchema, valueSchema, None)(increment) } // Generate RDDs and state store data @@ -132,17 +129,17 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( iteratorOfGets) assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None)) val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( iteratorOfPuts) assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1)) val rddOfGets2 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)( + sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)( iteratorOfGets) assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None)) } @@ -150,22 +147,25 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("preferred locations using StateStoreCoordinator") { quietly { + val queryRunId = UUID.randomUUID val opId = 0 val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext val coordinatorRef = sqlContext.streams.stateStoreCoordinator - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0, "name"), "host1", "exec1") - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1, "name"), "host2", "exec2") + val storeProviderId1 = StateStoreProviderId(StateStoreId(path, opId, 0), queryRunId) + val storeProviderId2 = StateStoreProviderId(StateStoreId(path, opId, 1), queryRunId) + coordinatorRef.reportActiveInstance(storeProviderId1, "host1", "exec1") + coordinatorRef.reportActiveInstance(storeProviderId2, "host2", "exec2") - assert( - coordinatorRef.getLocation(StateStoreId(path, opId, 0, "name")) === + require( + coordinatorRef.getLocation(storeProviderId1) === Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) val rdd = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( - increment) + sqlContext, operatorStateInfo(path, queryRunId = queryRunId), + keySchema, valueSchema, None)(increment) require(rdd.partitions.length === 2) assert( @@ -192,12 +192,12 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(increment) + sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)(increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)(increment) + sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. @@ -210,6 +210,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2) } + private def operatorStateInfo( + path: String, + queryRunId: UUID = UUID.randomUUID, + version: Int = 0): StatefulOperatorStateInfo = { + StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version) + } + private val increment = (store: StateStore, iter: Iterator[String]) => { iter.foreach { s => val key = stringToRow(s) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index af2b9f1c11fb6..c2087ec219e57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{File, IOException} import java.net.URI +import java.util.UUID import scala.collection.JavaConverters._ import scala.collection.mutable @@ -33,8 +34,11 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.LocalSparkContext._ +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.functions.count import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -143,7 +147,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] provider.getStore(0).commit() // Verify we don't leak temp files - val tempFiles = FileUtils.listFiles(new File(provider.id.checkpointLocation), + val tempFiles = FileUtils.listFiles(new File(provider.stateStoreId.checkpointRootLocation), null, true).asScala.filter(_.getName.startsWith("temp-")) assert(tempFiles.isEmpty) } @@ -183,7 +187,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] test("StateStore.get") { quietly { val dir = newDir() - val storeId = StateStoreId(dir, 0, 0) + val storeId = StateStoreProviderId(StateStoreId(dir, 0, 0), UUID.randomUUID) val storeConf = StateStoreConf.empty val hadoopConf = new Configuration() @@ -243,18 +247,18 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] .set("spark.rpc.numRetries", "1") val opId = 0 val dir = newDir() - val storeId = StateStoreId(dir, opId, 0) + val storeProviderId = StateStoreProviderId(StateStoreId(dir, opId, 0), UUID.randomUUID) val sqlConf = new SQLConf() sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) val storeConf = StateStoreConf(sqlConf) val hadoopConf = new Configuration() - val provider = newStoreProvider(storeId) + val provider = newStoreProvider(storeProviderId.storeId) var latestStoreVersion = 0 def generateStoreVersions() { for (i <- 1 to 20) { - val store = StateStore.get(storeId, keySchema, valueSchema, None, + val store = StateStore.get(storeProviderId, keySchema, valueSchema, None, latestStoreVersion, storeConf, hadoopConf) put(store, "a", i) store.commit() @@ -274,7 +278,8 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] eventually(timeout(timeoutDuration)) { // Store should have been reported to the coordinator - assert(coordinatorRef.getLocation(storeId).nonEmpty, "active instance was not reported") + assert(coordinatorRef.getLocation(storeProviderId).nonEmpty, + "active instance was not reported") // Background maintenance should clean up and generate snapshots assert(StateStore.isMaintenanceRunning, "Maintenance task is not running") @@ -295,35 +300,35 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") } - // If driver decides to deactivate all instances of the store, then this instance - // should be unloaded - coordinatorRef.deactivateInstances(dir) + // If driver decides to deactivate all stores related to a query run, + // then this instance should be unloaded + coordinatorRef.deactivateInstances(storeProviderId.queryRunId) eventually(timeout(timeoutDuration)) { - assert(!StateStore.isLoaded(storeId)) + assert(!StateStore.isLoaded(storeProviderId)) } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None, + StateStore.get(storeProviderId, keySchema, valueSchema, indexOrdinal = None, latestStoreVersion, storeConf, hadoopConf) - assert(StateStore.isLoaded(storeId)) + assert(StateStore.isLoaded(storeProviderId)) // If some other executor loads the store, then this instance should be unloaded - coordinatorRef.reportActiveInstance(storeId, "other-host", "other-exec") + coordinatorRef.reportActiveInstance(storeProviderId, "other-host", "other-exec") eventually(timeout(timeoutDuration)) { - assert(!StateStore.isLoaded(storeId)) + assert(!StateStore.isLoaded(storeProviderId)) } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None, + StateStore.get(storeProviderId, keySchema, valueSchema, indexOrdinal = None, latestStoreVersion, storeConf, hadoopConf) - assert(StateStore.isLoaded(storeId)) + assert(StateStore.isLoaded(storeProviderId)) } } // Verify if instance is unloaded if SparkContext is stopped eventually(timeout(timeoutDuration)) { require(SparkEnv.get === null) - assert(!StateStore.isLoaded(storeId)) + assert(!StateStore.isLoaded(storeProviderId)) assert(!StateStore.isMaintenanceRunning) } } @@ -344,7 +349,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] test("SPARK-18416: do not create temp delta file until the store is updated") { val dir = newDir() - val storeId = StateStoreId(dir, 0, 0) + val storeId = StateStoreProviderId(StateStoreId(dir, 0, 0), UUID.randomUUID) val storeConf = StateStoreConf.empty val hadoopConf = new Configuration() val deltaFileDir = new File(s"$dir/0/0/") @@ -408,12 +413,60 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(numDeltaFiles === 3) } + test("SPARK-21145: Restarted queries create new provider instances") { + try { + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val spark = SparkSession.builder().master("local[2]").getOrCreate() + SparkSession.setActiveSession(spark) + implicit val sqlContext = spark.sqlContext + spark.conf.set("spark.sql.shuffle.partitions", "1") + import spark.implicits._ + val inputData = MemoryStream[Int] + + def runQueryAndGetLoadedProviders(): Seq[StateStoreProvider] = { + val aggregated = inputData.toDF().groupBy("value").agg(count("*")) + // stateful query + val query = aggregated.writeStream + .format("memory") + .outputMode("complete") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + inputData.addData(1, 2, 3) + query.processAllAvailable() + require(query.lastProgress != null) // at least one batch processed after start + val loadedProvidersMethod = + PrivateMethod[mutable.HashMap[StateStoreProviderId, StateStoreProvider]]('loadedProviders) + val loadedProvidersMap = StateStore invokePrivate loadedProvidersMethod() + val loadedProviders = loadedProvidersMap.synchronized { loadedProvidersMap.values.toSeq } + query.stop() + loadedProviders + } + + val loadedProvidersAfterRun1 = runQueryAndGetLoadedProviders() + require(loadedProvidersAfterRun1.length === 1) + + val loadedProvidersAfterRun2 = runQueryAndGetLoadedProviders() + assert(loadedProvidersAfterRun2.length === 2) // two providers loaded for 2 runs + + // Both providers should have the same StateStoreId, but the should be different objects + assert(loadedProvidersAfterRun2(0).stateStoreId === loadedProvidersAfterRun2(1).stateStoreId) + assert(loadedProvidersAfterRun2(0) ne loadedProvidersAfterRun2(1)) + + } finally { + SparkSession.getActiveSession.foreach { spark => + spark.streams.active.foreach(_.stop()) + spark.stop() + } + } + } + override def newStoreProvider(): HDFSBackedStateStoreProvider = { newStoreProvider(opId = Random.nextInt(), partition = 0) } override def newStoreProvider(storeId: StateStoreId): HDFSBackedStateStoreProvider = { - newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointLocation) + newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointRootLocation) } override def getLatestData(storeProvider: HDFSBackedStateStoreProvider): Set[(String, Int)] = { @@ -423,7 +476,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] override def getData( provider: HDFSBackedStateStoreProvider, version: Int = -1): Set[(String, Int)] = { - val reloadedProvider = newStoreProvider(provider.id) + val reloadedProvider = newStoreProvider(provider.stateStoreId) if (version < 0) { reloadedProvider.latestIterator().map(rowsToStringInt).toSet } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 4ede4fd9a035e..86c3a35a59c13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -777,7 +777,7 @@ class TestStateStoreProvider extends StateStoreProvider { throw new Exception("Successfully instantiated") } - override def id: StateStoreId = null + override def stateStoreId: StateStoreId = null override def close(): Unit = { } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 2a4039cc5831a..b2c42eef88f6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -26,9 +26,8 @@ import scala.reflect.ClassTag import scala.util.Random import scala.util.control.NonFatal -import org.scalatest.Assertions +import org.scalatest.{Assertions, BeforeAndAfterAll} import org.scalatest.concurrent.{Eventually, Timeouts} -import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.Span @@ -39,9 +38,10 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, Ro import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * A framework for implementing tests for streaming queries and sources. @@ -67,7 +67,12 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} * avoid hanging forever in the case of failures. However, individual suites can change this * by overriding `streamingTimeout`. */ -trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { +trait StreamTest extends QueryTest with SharedSQLContext with Timeouts with BeforeAndAfterAll { + + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() // stop the state store maintenance thread and unload store providers + } /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds From 153dd49b74e1b6df2b8e35760806c9754ca7bfae Mon Sep 17 00:00:00 2001 From: jinxing Date: Fri, 23 Jun 2017 20:41:17 +0800 Subject: [PATCH 002/779] [SPARK-21047] Add test suites for complicated cases in ColumnarBatchSuite ## What changes were proposed in this pull request? Current ColumnarBatchSuite has very simple test cases for `Array` and `Struct`. This pr wants to add some test suites for complicated cases in ColumnVector. Author: jinxing Closes #18327 from jinxing64/SPARK-21047. --- .../execution/vectorized/ColumnarBatch.java | 35 ++++- .../vectorized/ColumnarBatchSuite.scala | 122 ++++++++++++++++++ 2 files changed, 156 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 8b7b0e655b31d..e23a64350cbc5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -241,7 +241,40 @@ public MapData getMap(int ordinal) { @Override public Object get(int ordinal, DataType dataType) { - throw new UnsupportedOperationException(); + if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { + return getByte(ordinal); + } else if (dataType instanceof ShortType) { + return getShort(ordinal); + } else if (dataType instanceof IntegerType) { + return getInt(ordinal); + } else if (dataType instanceof LongType) { + return getLong(ordinal); + } else if (dataType instanceof FloatType) { + return getFloat(ordinal); + } else if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } else if (dataType instanceof StringType) { + return getUTF8String(ordinal); + } else if (dataType instanceof BinaryType) { + return getBinary(ordinal); + } else if (dataType instanceof DecimalType) { + DecimalType t = (DecimalType) dataType; + return getDecimal(ordinal, t.precision(), t.scale()); + } else if (dataType instanceof DateType) { + return getInt(ordinal); + } else if (dataType instanceof TimestampType) { + return getLong(ordinal); + } else if (dataType instanceof ArrayType) { + return getArray(ordinal); + } else if (dataType instanceof StructType) { + return getStruct(ordinal, ((StructType)dataType).fields().length); + } else if (dataType instanceof MapType) { + return getMap(ordinal); + } else { + throw new UnsupportedOperationException("Datatype not supported " + dataType); + } } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index e48e3f6402901..80d41577dcf2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -739,6 +739,128 @@ class ColumnarBatchSuite extends SparkFunSuite { }} } + test("Nest Array in Array.") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => + val column = ColumnVector.allocate(10, new ArrayType(new ArrayType(IntegerType, true), true), + memMode) + val childColumn = column.arrayData() + val data = column.arrayData().arrayData() + (0 until 6).foreach { + case 3 => data.putNull(3) + case i => data.putInt(i, i) + } + // Arrays in child column: [0], [1, 2], [], [null, 4, 5] + childColumn.putArray(0, 0, 1) + childColumn.putArray(1, 1, 2) + childColumn.putArray(2, 2, 0) + childColumn.putArray(3, 3, 3) + // Arrays in column: [[0]], [[1, 2], []], [[], [null, 4, 5]], null + column.putArray(0, 0, 1) + column.putArray(1, 1, 2) + column.putArray(2, 2, 2) + column.putNull(3) + + assert(column.getArray(0).getArray(0).toIntArray() === Array(0)) + assert(column.getArray(1).getArray(0).toIntArray() === Array(1, 2)) + assert(column.getArray(1).getArray(1).toIntArray() === Array()) + assert(column.getArray(2).getArray(0).toIntArray() === Array()) + assert(column.getArray(2).getArray(1).isNullAt(0)) + assert(column.getArray(2).getArray(1).getInt(1) === 4) + assert(column.getArray(2).getArray(1).getInt(2) === 5) + assert(column.isNullAt(3)) + } + } + + test("Nest Struct in Array.") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => + val schema = new StructType().add("int", IntegerType).add("long", LongType) + val column = ColumnVector.allocate(10, new ArrayType(schema, true), memMode) + val data = column.arrayData() + val c0 = data.getChildColumn(0) + val c1 = data.getChildColumn(1) + // Structs in child column: (0, 0), (1, 10), (2, 20), (3, 30), (4, 40), (5, 50) + (0 until 6).foreach { i => + c0.putInt(i, i) + c1.putLong(i, i * 10) + } + // Arrays in column: [(0, 0), (1, 10)], [(1, 10), (2, 20), (3, 30)], + // [(4, 40), (5, 50)] + column.putArray(0, 0, 2) + column.putArray(1, 1, 3) + column.putArray(2, 4, 2) + + assert(column.getArray(0).getStruct(0, 2).toSeq(schema) === Seq(0, 0)) + assert(column.getArray(0).getStruct(1, 2).toSeq(schema) === Seq(1, 10)) + assert(column.getArray(1).getStruct(0, 2).toSeq(schema) === Seq(1, 10)) + assert(column.getArray(1).getStruct(1, 2).toSeq(schema) === Seq(2, 20)) + assert(column.getArray(1).getStruct(2, 2).toSeq(schema) === Seq(3, 30)) + assert(column.getArray(2).getStruct(0, 2).toSeq(schema) === Seq(4, 40)) + assert(column.getArray(2).getStruct(1, 2).toSeq(schema) === Seq(5, 50)) + } + } + + test("Nest Array in Struct.") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => + val schema = new StructType() + .add("int", IntegerType) + .add("array", new ArrayType(IntegerType, true)) + val column = ColumnVector.allocate(10, schema, memMode) + val c0 = column.getChildColumn(0) + val c1 = column.getChildColumn(1) + c0.putInt(0, 0) + c0.putInt(1, 1) + c0.putInt(2, 2) + val c1Child = c1.arrayData() + (0 until 6).foreach { i => + c1Child.putInt(i, i) + } + // Arrays in c1: [0, 1], [2], [3, 4, 5] + c1.putArray(0, 0, 2) + c1.putArray(1, 2, 1) + c1.putArray(2, 3, 3) + + assert(column.getStruct(0).getInt(0) === 0) + assert(column.getStruct(0).getArray(1).toIntArray() === Array(0, 1)) + assert(column.getStruct(1).getInt(0) === 1) + assert(column.getStruct(1).getArray(1).toIntArray() === Array(2)) + assert(column.getStruct(2).getInt(0) === 2) + assert(column.getStruct(2).getArray(1).toIntArray() === Array(3, 4, 5)) + } + } + + test("Nest Struct in Struct.") { + (MemoryMode.ON_HEAP :: Nil).foreach { memMode => + val subSchema = new StructType() + .add("int", IntegerType) + .add("int", IntegerType) + val schema = new StructType() + .add("int", IntegerType) + .add("struct", subSchema) + val column = ColumnVector.allocate(10, schema, memMode) + val c0 = column.getChildColumn(0) + val c1 = column.getChildColumn(1) + c0.putInt(0, 0) + c0.putInt(1, 1) + c0.putInt(2, 2) + val c1c0 = c1.getChildColumn(0) + val c1c1 = c1.getChildColumn(1) + // Structs in c1: (7, 70), (8, 80), (9, 90) + c1c0.putInt(0, 7) + c1c0.putInt(1, 8) + c1c0.putInt(2, 9) + c1c1.putInt(0, 70) + c1c1.putInt(1, 80) + c1c1.putInt(2, 90) + + assert(column.getStruct(0).getInt(0) === 0) + assert(column.getStruct(0).getStruct(1, 2).toSeq(subSchema) === Seq(7, 70)) + assert(column.getStruct(1).getInt(0) === 1) + assert(column.getStruct(1).getStruct(1, 2).toSeq(subSchema) === Seq(8, 80)) + assert(column.getStruct(2).getInt(0) === 2) + assert(column.getStruct(2).getStruct(1, 2).toSeq(subSchema) === Seq(9, 90)) + } + } + test("ColumnarBatch basic") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val schema = new StructType() From acd208ee50b29bde4e097bf88761867b1d57a665 Mon Sep 17 00:00:00 2001 From: 10129659 Date: Fri, 23 Jun 2017 20:53:26 +0800 Subject: [PATCH 003/779] [SPARK-21115][CORE] If the cores left is less than the coresPerExecutor,the cores left will not be allocated, so it should not to check in every schedule ## What changes were proposed in this pull request? If we start an app with the param --total-executor-cores=4 and spark.executor.cores=3, the cores left is always 1, so it will try to allocate executors in the function org.apache.spark.deploy.master.startExecutorsOnWorkers in every schedule. Another question is, is it will be better to allocate another executor with 1 core for the cores left. ## How was this patch tested? unit test Author: 10129659 Closes #18322 from eatoncys/leftcores. --- .../scala/org/apache/spark/SparkConf.scala | 11 +++++++ .../apache/spark/deploy/master/Master.scala | 29 ++++++++++--------- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index ba7a65f79c414..de2f475c6895f 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -543,6 +543,17 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } } + if (contains("spark.cores.max") && contains("spark.executor.cores")) { + val totalCores = getInt("spark.cores.max", 1) + val executorCores = getInt("spark.executor.cores", 1) + val leftCores = totalCores % executorCores + if (leftCores != 0) { + logWarning(s"Total executor cores: ${totalCores} is not " + + s"divisible by cores per executor: ${executorCores}, " + + s"the left cores: ${leftCores} will not be allocated") + } + } + val encryptionEnabled = get(NETWORK_ENCRYPTION_ENABLED) || get(SASL_ENCRYPTION_ENABLED) require(!encryptionEnabled || get(NETWORK_AUTH_ENABLED), s"${NETWORK_AUTH_ENABLED.key} must be enabled when enabling encryption.") diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index c192a0cc82ef6..0dee25fb2ebe2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -659,19 +659,22 @@ private[deploy] class Master( private def startExecutorsOnWorkers(): Unit = { // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app // in the queue, then the second app, etc. - for (app <- waitingApps if app.coresLeft > 0) { - val coresPerExecutor: Option[Int] = app.desc.coresPerExecutor - // Filter out workers that don't have enough resources to launch an executor - val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) - .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && - worker.coresFree >= coresPerExecutor.getOrElse(1)) - .sortBy(_.coresFree).reverse - val assignedCores = scheduleExecutorsOnWorkers(app, usableWorkers, spreadOutApps) - - // Now that we've decided how many cores to allocate on each worker, let's allocate them - for (pos <- 0 until usableWorkers.length if assignedCores(pos) > 0) { - allocateWorkerResourceToExecutors( - app, assignedCores(pos), coresPerExecutor, usableWorkers(pos)) + for (app <- waitingApps) { + val coresPerExecutor = app.desc.coresPerExecutor.getOrElse(1) + // If the cores left is less than the coresPerExecutor,the cores left will not be allocated + if (app.coresLeft >= coresPerExecutor) { + // Filter out workers that don't have enough resources to launch an executor + val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) + .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && + worker.coresFree >= coresPerExecutor) + .sortBy(_.coresFree).reverse + val assignedCores = scheduleExecutorsOnWorkers(app, usableWorkers, spreadOutApps) + + // Now that we've decided how many cores to allocate on each worker, let's allocate them + for (pos <- 0 until usableWorkers.length if assignedCores(pos) > 0) { + allocateWorkerResourceToExecutors( + app, assignedCores(pos), app.desc.coresPerExecutor, usableWorkers(pos)) + } } } } From 5dca10b8fdec81a3cc476301fa4f82ea917c34ec Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 23 Jun 2017 21:51:55 +0800 Subject: [PATCH 004/779] [SPARK-21193][PYTHON] Specify Pandas version in setup.py ## What changes were proposed in this pull request? It looks we missed specifying the Pandas version. This PR proposes to fix it. For the current state, it should be Pandas 0.13.0 given my test. This PR propose to fix it as 0.13.0. Running the codes below: ```python from pyspark.sql.types import * schema = StructType().add("a", IntegerType()).add("b", StringType())\ .add("c", BooleanType()).add("d", FloatType()) data = [ (1, "foo", True, 3.0,), (2, "foo", True, 5.0), (3, "bar", False, -1.0), (4, "bar", False, 6.0), ] spark.createDataFrame(data, schema).toPandas().dtypes ``` prints ... **With Pandas 0.13.0** - released, 2014-01 ``` a int32 b object c bool d float32 dtype: object ``` **With Pandas 0.12.0** - - released, 2013-06 ``` Traceback (most recent call last): File "", line 1, in File ".../spark/python/pyspark/sql/dataframe.py", line 1734, in toPandas pdf[f] = pdf[f].astype(t, copy=False) TypeError: astype() got an unexpected keyword argument 'copy' ``` without `copy` ``` a int32 b object c bool d float32 dtype: object ``` **With Pandas 0.11.0** - released, 2013-03 ``` Traceback (most recent call last): File "", line 1, in File ".../spark/python/pyspark/sql/dataframe.py", line 1734, in toPandas pdf[f] = pdf[f].astype(t, copy=False) TypeError: astype() got an unexpected keyword argument 'copy' ``` without `copy` ``` a int32 b object c bool d float32 dtype: object ``` **With Pandas 0.10.0** - released, 2012-12 ``` Traceback (most recent call last): File "", line 1, in File ".../spark/python/pyspark/sql/dataframe.py", line 1734, in toPandas pdf[f] = pdf[f].astype(t, copy=False) TypeError: astype() got an unexpected keyword argument 'copy' ``` without `copy` ``` a int64 # <- this should be 'int32' b object c bool d float64 # <- this should be 'float32' ``` ## How was this patch tested? Manually tested with Pandas from 0.10.0 to 0.13.0. Author: hyukjinkwon Closes #18403 from HyukjinKwon/SPARK-21193. --- python/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/setup.py b/python/setup.py index f50035435e26b..2644d3e79dea1 100644 --- a/python/setup.py +++ b/python/setup.py @@ -199,7 +199,7 @@ def _supports_symlinks(): extras_require={ 'ml': ['numpy>=1.7'], 'mllib': ['numpy>=1.7'], - 'sql': ['pandas'] + 'sql': ['pandas>=0.13.0'] }, classifiers=[ 'Development Status :: 5 - Production/Stable', From f3dea60793d86212ba1068e88ad89cb3dcf07801 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 23 Jun 2017 09:28:02 -0700 Subject: [PATCH 005/779] [SPARK-21144][SQL] Print a warning if the data schema and partition schema have the duplicate columns ## What changes were proposed in this pull request? The current master outputs unexpected results when the data schema and partition schema have the duplicate columns: ``` withTempPath { dir => val basePath = dir.getCanonicalPath spark.range(0, 3).toDF("foo").write.parquet(new Path(basePath, "foo=1").toString) spark.range(0, 3).toDF("foo").write.parquet(new Path(basePath, "foo=a").toString) spark.read.parquet(basePath).show() } +---+ |foo| +---+ | 1| | 1| | a| | a| | 1| | a| +---+ ``` This patch added code to print a warning when the duplication found. ## How was this patch tested? Manually checked. Author: Takeshi Yamamuro Closes #18375 from maropu/SPARK-21144-3. --- .../apache/spark/sql/util/SchemaUtils.scala | 53 +++++++++++++++++++ .../execution/datasources/DataSource.scala | 6 +++ 2 files changed, 59 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala new file mode 100644 index 0000000000000..e881685ce6262 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import org.apache.spark.internal.Logging + + +/** + * Utils for handling schemas. + * + * TODO: Merge this file with [[org.apache.spark.ml.util.SchemaUtils]]. + */ +private[spark] object SchemaUtils extends Logging { + + /** + * Checks if input column names have duplicate identifiers. Prints a warning message if + * the duplication exists. + * + * @param columnNames column names to check + * @param colType column type name, used in a warning message + * @param caseSensitiveAnalysis whether duplication checks should be case sensitive or not + */ + def checkColumnNameDuplication( + columnNames: Seq[String], colType: String, caseSensitiveAnalysis: Boolean): Unit = { + val names = if (caseSensitiveAnalysis) { + columnNames + } else { + columnNames.map(_.toLowerCase) + } + if (names.distinct.length != names.length) { + val duplicateColumns = names.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => s"`$x`" + } + logWarning(s"Found duplicate column(s) $colType: ${duplicateColumns.mkString(", ")}. " + + "You might need to assign different column names.") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 08c78e6e326af..75e530607570f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{CalendarIntervalType, StructType} +import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.Utils /** @@ -182,6 +183,11 @@ case class DataSource( throw new AnalysisException( s"Unable to infer schema for $format. It must be specified manually.") } + + SchemaUtils.checkColumnNameDuplication( + (dataSchema ++ partitionSchema).map(_.name), "in the data schema and the partition schema", + sparkSession.sessionState.conf.caseSensitiveAnalysis) + (dataSchema, partitionSchema) } From 07479b3cfb7a617a18feca14e9e31c208c80630e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 23 Jun 2017 09:59:24 -0700 Subject: [PATCH 006/779] [SPARK-21149][R] Add job description API for R ## What changes were proposed in this pull request? Extend `setJobDescription` to SparkR API. ## How was this patch tested? It looks difficult to add a test. Manually tested as below: ```r df <- createDataFrame(iris) count(df) setJobDescription("This is an example job.") count(df) ``` prints ... ![2017-06-22 12 05 49](https://user-images.githubusercontent.com/6477701/27415670-2a649936-5743-11e7-8e95-312f1cd103af.png) Author: hyukjinkwon Closes #18382 from HyukjinKwon/SPARK-21149. --- R/pkg/NAMESPACE | 3 ++- R/pkg/R/sparkR.R | 17 +++++++++++++++++ R/pkg/tests/fulltests/test_context.R | 1 + 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 229de4a997eef..b7fdae58de459 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -75,7 +75,8 @@ exportMethods("glm", # Job group lifecycle management methods export("setJobGroup", "clearJobGroup", - "cancelJobGroup") + "cancelJobGroup", + "setJobDescription") # Export Utility methods export("setLogLevel") diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index d0a12b7ecec65..f2d2620e5447a 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -535,6 +535,23 @@ cancelJobGroup <- function(sc, groupId) { } } +#' Set a human readable description of the current job. +#' +#' Set a description that is shown as a job description in UI. +#' +#' @param value The job description of the current job. +#' @rdname setJobDescription +#' @name setJobDescription +#' @examples +#'\dontrun{ +#' setJobDescription("This is an example job.") +#'} +#' @note setJobDescription since 2.3.0 +setJobDescription <- function(value) { + sc <- getSparkContext() + invisible(callJMethod(sc, "setJobDescription", value)) +} + sparkConfToSubmitOps <- new.env() sparkConfToSubmitOps[["spark.driver.memory"]] <- "--driver-memory" sparkConfToSubmitOps[["spark.driver.extraClassPath"]] <- "--driver-class-path" diff --git a/R/pkg/tests/fulltests/test_context.R b/R/pkg/tests/fulltests/test_context.R index 710485d56685a..77635c5a256b9 100644 --- a/R/pkg/tests/fulltests/test_context.R +++ b/R/pkg/tests/fulltests/test_context.R @@ -100,6 +100,7 @@ test_that("job group functions can be called", { setJobGroup("groupId", "job description", TRUE) cancelJobGroup("groupId") clearJobGroup() + setJobDescription("job description") suppressWarnings(setJobGroup(sc, "groupId", "job description", TRUE)) suppressWarnings(cancelJobGroup(sc, "groupId")) From b803b66a8133f705463039325ee71ee6827ce1a7 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Fri, 23 Jun 2017 10:33:53 -0700 Subject: [PATCH 007/779] [SPARK-21180][SQL] Remove conf from stats functions since now we have conf in LogicalPlan ## What changes were proposed in this pull request? After wiring `SQLConf` in logical plan ([PR 18299](https://github.com/apache/spark/pull/18299)), we can remove the need of passing `conf` into `def stats` and `def computeStats`. ## How was this patch tested? Covered by existing tests, plus some modified existing tests. Author: wangzhenhua Author: Zhenhua Wang Closes #18391 from wzhfy/removeConf. --- .../sql/catalyst/catalog/interface.scala | 3 +- .../optimizer/CostBasedJoinReorder.scala | 4 +- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../optimizer/StarSchemaDetection.scala | 14 ++-- .../plans/logical/LocalRelation.scala | 3 +- .../catalyst/plans/logical/LogicalPlan.scala | 15 ++--- .../plans/logical/basicLogicalOperators.scala | 65 +++++++++---------- .../sql/catalyst/plans/logical/hints.scala | 5 +- .../statsEstimation/AggregateEstimation.scala | 7 +- .../statsEstimation/EstimationUtils.scala | 5 +- .../statsEstimation/FilterEstimation.scala | 5 +- .../statsEstimation/JoinEstimation.scala | 21 +++--- .../statsEstimation/ProjectEstimation.scala | 7 +- .../optimizer/JoinOptimizationSuite.scala | 2 +- .../optimizer/LimitPushdownSuite.scala | 6 +- .../AggregateEstimationSuite.scala | 30 +++++---- .../BasicStatsEstimationSuite.scala | 27 +++++--- .../FilterEstimationSuite.scala | 2 +- .../statsEstimation/JoinEstimationSuite.scala | 26 ++++---- .../ProjectEstimationSuite.scala | 4 +- .../StatsEstimationTestBase.scala | 18 +++-- .../spark/sql/execution/ExistingRDD.scala | 5 +- .../spark/sql/execution/QueryExecution.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 13 ++-- .../execution/columnar/InMemoryRelation.scala | 3 +- .../datasources/LogicalRelation.scala | 3 +- .../sql/execution/streaming/memory.scala | 3 +- .../apache/spark/sql/CachedTableSuite.scala | 2 +- .../org/apache/spark/sql/DatasetSuite.scala | 2 +- .../org/apache/spark/sql/JoinSuite.scala | 2 +- .../spark/sql/StatisticsCollectionSuite.scala | 18 ++--- .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../datasources/HadoopFsRelationSuite.scala | 2 +- .../execution/streaming/MemorySinkSuite.scala | 6 +- .../apache/spark/sql/test/SQLTestData.scala | 3 - .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- .../spark/sql/hive/StatisticsSuite.scala | 10 +-- .../PruneFileSourcePartitionsSuite.scala | 2 +- 38 files changed, 178 insertions(+), 173 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index c043ed9c431b7..b63bef9193332 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Attri import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -436,7 +435,7 @@ case class CatalogRelation( createTime = -1 )) - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { // For data source tables, we will create a `LogicalRelation` and won't call this method, for // hive serde tables, we will always generate a statistics. // TODO: unify the table stats generation. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 51eca6ca33760..3a7543e2141e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -58,7 +58,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr // Do reordering if the number of items is appropriate and join conditions exist. // We also need to check if costs of all items can be evaluated. if (items.size > 2 && items.size <= conf.joinReorderDPThreshold && conditions.nonEmpty && - items.forall(_.stats(conf).rowCount.isDefined)) { + items.forall(_.stats.rowCount.isDefined)) { JoinReorderDP.search(conf, items, conditions, output) } else { plan @@ -322,7 +322,7 @@ object JoinReorderDP extends PredicateHelper with Logging { /** Get the cost of the root node of this plan tree. */ def rootCost(conf: SQLConf): Cost = { if (itemIds.size > 1) { - val rootStats = plan.stats(conf) + val rootStats = plan.stats Cost(rootStats.rowCount.get, rootStats.sizeInBytes) } else { // If the plan is a leaf item, it has zero cost. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3ab70fb90470c..b410312030c5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -317,7 +317,7 @@ case class LimitPushDown(conf: SQLConf) extends Rule[LogicalPlan] { case FullOuter => (left.maxRows, right.maxRows) match { case (None, None) => - if (left.stats(conf).sizeInBytes >= right.stats(conf).sizeInBytes) { + if (left.stats.sizeInBytes >= right.stats.sizeInBytes) { join.copy(left = maybePushLimit(exp, left)) } else { join.copy(right = maybePushLimit(exp, right)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala index 97ee9988386dd..ca729127e7d1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala @@ -82,7 +82,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { // Find if the input plans are eligible for star join detection. // An eligible plan is a base table access with valid statistics. val foundEligibleJoin = input.forall { - case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true + case PhysicalOperation(_, _, t: LeafNode) if t.stats.rowCount.isDefined => true case _ => false } @@ -181,7 +181,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { val leafCol = findLeafNodeCol(column, plan) leafCol match { case Some(col) if t.outputSet.contains(col) => - val stats = t.stats(conf) + val stats = t.stats stats.rowCount match { case Some(rowCount) if rowCount >= 0 => if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) { @@ -237,7 +237,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { val leafCol = findLeafNodeCol(column, plan) leafCol match { case Some(col) if t.outputSet.contains(col) => - val stats = t.stats(conf) + val stats = t.stats stats.attributeStats.nonEmpty && stats.attributeStats.contains(col) case None => false } @@ -296,11 +296,11 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { */ private def getTableAccessCardinality( input: LogicalPlan): Option[BigInt] = input match { - case PhysicalOperation(_, cond, t: LeafNode) if t.stats(conf).rowCount.isDefined => - if (conf.cboEnabled && input.stats(conf).rowCount.isDefined) { - Option(input.stats(conf).rowCount.get) + case PhysicalOperation(_, cond, t: LeafNode) if t.stats.rowCount.isDefined => + if (conf.cboEnabled && input.stats.rowCount.isDefined) { + Option(input.stats.rowCount.get) } else { - Option(t.stats(conf).rowCount.get) + Option(t.stats.rowCount.get) } case _ => None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 9cd5dfd21b160..dc2add64b68b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { @@ -67,7 +66,7 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) } } - override def computeStats(conf: SQLConf): Statistics = + override def computeStats: Statistics = Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 95b4165f6b10d..0c098ac0209e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.CurrentOrigin -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -90,8 +89,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai * first time. If the configuration changes, the cache can be invalidated by calling * [[invalidateStatsCache()]]. */ - final def stats(conf: SQLConf): Statistics = statsCache.getOrElse { - statsCache = Some(computeStats(conf)) + final def stats: Statistics = statsCache.getOrElse { + statsCache = Some(computeStats) statsCache.get } @@ -108,11 +107,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai * * [[LeafNode]]s must override this. */ - protected def computeStats(conf: SQLConf): Statistics = { + protected def computeStats: Statistics = { if (children.isEmpty) { throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") } - Statistics(sizeInBytes = children.map(_.stats(conf).sizeInBytes).product) + Statistics(sizeInBytes = children.map(_.stats.sizeInBytes).product) } override def verboseStringWithSuffix: String = { @@ -333,13 +332,13 @@ abstract class UnaryNode extends LogicalPlan { override protected def validConstraints: Set[Expression] = child.constraints - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { // There should be some overhead in Row object, the size should not be zero when there is // no columns, this help to prevent divide-by-zero error. val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8 val outputRowSize = output.map(_.dataType.defaultSize).sum + 8 // Assume there will be the same number of rows as child has. - var sizeInBytes = (child.stats(conf).sizeInBytes * outputRowSize) / childRowSize + var sizeInBytes = (child.stats.sizeInBytes * outputRowSize) / childRowSize if (sizeInBytes == 0) { // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero // (product of children). @@ -347,7 +346,7 @@ abstract class UnaryNode extends LogicalPlan { } // Don't propagate rowCount and attributeStats, since they are not estimated here. - Statistics(sizeInBytes = sizeInBytes, hints = child.stats(conf).hints) + Statistics(sizeInBytes = sizeInBytes, hints = child.stats.hints) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 6e88b7a57dc33..d8f89b108e63f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.RandomSampler @@ -65,11 +64,11 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend override def validConstraints: Set[Expression] = child.constraints.union(getAliasedConstraints(projectList)) - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { if (conf.cboEnabled) { - ProjectEstimation.estimate(conf, this).getOrElse(super.computeStats(conf)) + ProjectEstimation.estimate(this).getOrElse(super.computeStats) } else { - super.computeStats(conf) + super.computeStats } } } @@ -139,11 +138,11 @@ case class Filter(condition: Expression, child: LogicalPlan) child.constraints.union(predicates.toSet) } - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { if (conf.cboEnabled) { - FilterEstimation(this, conf).estimate.getOrElse(super.computeStats(conf)) + FilterEstimation(this).estimate.getOrElse(super.computeStats) } else { - super.computeStats(conf) + super.computeStats } } } @@ -192,13 +191,13 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation } } - override def computeStats(conf: SQLConf): Statistics = { - val leftSize = left.stats(conf).sizeInBytes - val rightSize = right.stats(conf).sizeInBytes + override def computeStats: Statistics = { + val leftSize = left.stats.sizeInBytes + val rightSize = right.stats.sizeInBytes val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize Statistics( sizeInBytes = sizeInBytes, - hints = left.stats(conf).hints.resetForJoin()) + hints = left.stats.hints.resetForJoin()) } } @@ -209,8 +208,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le override protected def validConstraints: Set[Expression] = leftConstraints - override def computeStats(conf: SQLConf): Statistics = { - left.stats(conf).copy() + override def computeStats: Statistics = { + left.stats.copy() } } @@ -248,8 +247,8 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { children.length > 1 && childrenResolved && allChildrenCompatible } - override def computeStats(conf: SQLConf): Statistics = { - val sizeInBytes = children.map(_.stats(conf).sizeInBytes).sum + override def computeStats: Statistics = { + val sizeInBytes = children.map(_.stats.sizeInBytes).sum Statistics(sizeInBytes = sizeInBytes) } @@ -357,20 +356,20 @@ case class Join( case _ => resolvedExceptNatural } - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { def simpleEstimation: Statistics = joinType match { case LeftAnti | LeftSemi => // LeftSemi and LeftAnti won't ever be bigger than left - left.stats(conf) + left.stats case _ => // Make sure we don't propagate isBroadcastable in other joins, because // they could explode the size. - val stats = super.computeStats(conf) + val stats = super.computeStats stats.copy(hints = stats.hints.resetForJoin()) } if (conf.cboEnabled) { - JoinEstimation.estimate(conf, this).getOrElse(simpleEstimation) + JoinEstimation.estimate(this).getOrElse(simpleEstimation) } else { simpleEstimation } @@ -523,7 +522,7 @@ case class Range( override def newInstance(): Range = copy(output = output.map(_.newInstance())) - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { val sizeInBytes = LongType.defaultSize * numElements Statistics( sizeInBytes = sizeInBytes ) } @@ -556,20 +555,20 @@ case class Aggregate( child.constraints.union(getAliasedConstraints(nonAgg)) } - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { def simpleEstimation: Statistics = { if (groupingExpressions.isEmpty) { Statistics( sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1), rowCount = Some(1), - hints = child.stats(conf).hints) + hints = child.stats.hints) } else { - super.computeStats(conf) + super.computeStats } } if (conf.cboEnabled) { - AggregateEstimation.estimate(conf, this).getOrElse(simpleEstimation) + AggregateEstimation.estimate(this).getOrElse(simpleEstimation) } else { simpleEstimation } @@ -672,8 +671,8 @@ case class Expand( override def references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references)) - override def computeStats(conf: SQLConf): Statistics = { - val sizeInBytes = super.computeStats(conf).sizeInBytes * projections.length + override def computeStats: Statistics = { + val sizeInBytes = super.computeStats.sizeInBytes * projections.length Statistics(sizeInBytes = sizeInBytes) } @@ -743,9 +742,9 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN case _ => None } } - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] - val childStats = child.stats(conf) + val childStats = child.stats val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit) // Don't propagate column stats, because we don't know the distribution after a limit operation Statistics( @@ -763,9 +762,9 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo case _ => None } } - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] - val childStats = child.stats(conf) + val childStats = child.stats if (limit == 0) { // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero // (product of children). @@ -832,9 +831,9 @@ case class Sample( override def output: Seq[Attribute] = child.output - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { val ratio = upperBound - lowerBound - val childStats = child.stats(conf) + val childStats = child.stats var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio) if (sizeInBytes == 0) { sizeInBytes = 1 @@ -898,7 +897,7 @@ case class RepartitionByExpression( case object OneRowRelation extends LeafNode { override def maxRows: Option[Long] = Some(1) override def output: Seq[Attribute] = Nil - override def computeStats(conf: SQLConf): Statistics = Statistics(sizeInBytes = 1) + override def computeStats: Statistics = Statistics(sizeInBytes = 1) } /** A logical plan for `dropDuplicates`. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index e49970df80457..8479c702d7561 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.internal.SQLConf /** * A general hint for the child that is not yet resolved. This node is generated by the parser and @@ -44,8 +43,8 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo()) override lazy val canonicalized: LogicalPlan = child.canonicalized - override def computeStats(conf: SQLConf): Statistics = { - val stats = child.stats(conf) + override def computeStats: Statistics = { + val stats = child.stats stats.copy(hints = hints) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index a0c23198451a8..c41fac4015ec0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics} -import org.apache.spark.sql.internal.SQLConf object AggregateEstimation { @@ -29,13 +28,13 @@ object AggregateEstimation { * Estimate the number of output rows based on column stats of group-by columns, and propagate * column stats for aggregate expressions. */ - def estimate(conf: SQLConf, agg: Aggregate): Option[Statistics] = { - val childStats = agg.child.stats(conf) + def estimate(agg: Aggregate): Option[Statistics] = { + val childStats = agg.child.stats // Check if we have column stats for all group-by columns. val colStatsExist = agg.groupingExpressions.forall { e => e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute]) } - if (rowCountsExist(conf, agg.child) && colStatsExist) { + if (rowCountsExist(agg.child) && colStatsExist) { // Multiply distinct counts of group-by columns. This is an upper bound, which assumes // the data contains all combinations of distinct values of group-by columns. var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index e5fcdf9039be9..9c34a9b7aa756 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -21,15 +21,14 @@ import scala.math.BigDecimal.RoundingMode import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DecimalType, _} object EstimationUtils { /** Check if each plan has rowCount in its statistics. */ - def rowCountsExist(conf: SQLConf, plans: LogicalPlan*): Boolean = - plans.forall(_.stats(conf).rowCount.isDefined) + def rowCountsExist(plans: LogicalPlan*): Boolean = + plans.forall(_.stats.rowCount.isDefined) /** Check if each attribute has column stat in the corresponding statistics. */ def columnStatsExist(statsAndAttr: (Statistics, Attribute)*): Boolean = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index df190867189ec..5a3bee7b9e449 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -25,12 +25,11 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging { +case class FilterEstimation(plan: Filter) extends Logging { - private val childStats = plan.child.stats(catalystConf) + private val childStats = plan.child.stats private val colStatsMap = new ColumnStatsMap(childStats.attributeStats) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 8ef905c45d50d..f48196997a24d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ -import org.apache.spark.sql.internal.SQLConf object JoinEstimation extends Logging { @@ -34,12 +33,12 @@ object JoinEstimation extends Logging { * Estimate statistics after join. Return `None` if the join type is not supported, or we don't * have enough statistics for estimation. */ - def estimate(conf: SQLConf, join: Join): Option[Statistics] = { + def estimate(join: Join): Option[Statistics] = { join.joinType match { case Inner | Cross | LeftOuter | RightOuter | FullOuter => - InnerOuterEstimation(conf, join).doEstimate() + InnerOuterEstimation(join).doEstimate() case LeftSemi | LeftAnti => - LeftSemiAntiEstimation(conf, join).doEstimate() + LeftSemiAntiEstimation(join).doEstimate() case _ => logDebug(s"[CBO] Unsupported join type: ${join.joinType}") None @@ -47,16 +46,16 @@ object JoinEstimation extends Logging { } } -case class InnerOuterEstimation(conf: SQLConf, join: Join) extends Logging { +case class InnerOuterEstimation(join: Join) extends Logging { - private val leftStats = join.left.stats(conf) - private val rightStats = join.right.stats(conf) + private val leftStats = join.left.stats + private val rightStats = join.right.stats /** * Estimate output size and number of rows after a join operator, and update output column stats. */ def doEstimate(): Option[Statistics] = join match { - case _ if !rowCountsExist(conf, join.left, join.right) => + case _ if !rowCountsExist(join.left, join.right) => None case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) => @@ -273,13 +272,13 @@ case class InnerOuterEstimation(conf: SQLConf, join: Join) extends Logging { } } -case class LeftSemiAntiEstimation(conf: SQLConf, join: Join) { +case class LeftSemiAntiEstimation(join: Join) { def doEstimate(): Option[Statistics] = { // TODO: It's error-prone to estimate cardinalities for LeftSemi and LeftAnti based on basic // column stats. Now we just propagate the statistics from left side. We should do more // accurate estimation when advanced stats (e.g. histograms) are available. - if (rowCountsExist(conf, join.left)) { - val leftStats = join.left.stats(conf) + if (rowCountsExist(join.left)) { + val leftStats = join.left.stats // Propagate the original column stats for cartesian product val outputRows = leftStats.rowCount.get Some(Statistics( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala index d700cd3b20f7d..489eb904ffd05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{Project, Statistics} -import org.apache.spark.sql.internal.SQLConf object ProjectEstimation { import EstimationUtils._ - def estimate(conf: SQLConf, project: Project): Option[Statistics] = { - if (rowCountsExist(conf, project.child)) { - val childStats = project.child.stats(conf) + def estimate(project: Project): Option[Statistics] = { + if (rowCountsExist(project.child)) { + val childStats = project.child.stats val inputAttrStats = childStats.attributeStats // Match alias with its child's column stat val aliasStats = project.expressions.collect { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index 105407d43bf39..a6584aa5fbba7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -142,7 +142,7 @@ class JoinOptimizationSuite extends PlanTest { comparePlans(optimized, expected) val broadcastChildren = optimized.collect { - case Join(_, r, _, _) if r.stats(conf).sizeInBytes == 1 => r + case Join(_, r, _, _) if r.stats.sizeInBytes == 1 => r } assert(broadcastChildren.size == 1) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index fb34c82de468b..d8302dfc9462d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -112,7 +112,7 @@ class LimitPushdownSuite extends PlanTest { } test("full outer join where neither side is limited and both sides have same statistics") { - assert(x.stats(conf).sizeInBytes === y.stats(conf).sizeInBytes) + assert(x.stats.sizeInBytes === y.stats.sizeInBytes) val originalQuery = x.join(y, FullOuter).limit(1) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = Limit(1, LocalLimit(1, x).join(y, FullOuter)).analyze @@ -121,7 +121,7 @@ class LimitPushdownSuite extends PlanTest { test("full outer join where neither side is limited and left side has larger statistics") { val xBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('x) - assert(xBig.stats(conf).sizeInBytes > y.stats(conf).sizeInBytes) + assert(xBig.stats.sizeInBytes > y.stats.sizeInBytes) val originalQuery = xBig.join(y, FullOuter).limit(1) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = Limit(1, LocalLimit(1, xBig).join(y, FullOuter)).analyze @@ -130,7 +130,7 @@ class LimitPushdownSuite extends PlanTest { test("full outer join where neither side is limited and right side has larger statistics") { val yBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('y) - assert(x.stats(conf).sizeInBytes < yBig.stats(conf).sizeInBytes) + assert(x.stats.sizeInBytes < yBig.stats.sizeInBytes) val originalQuery = x.join(yBig, FullOuter).limit(1) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = Limit(1, x.join(LocalLimit(1, yBig), FullOuter)).analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala index 38483a298cef0..30ddf03bd3c4f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala @@ -100,17 +100,23 @@ class AggregateEstimationSuite extends StatsEstimationTestBase { size = Some(4 * (8 + 4)), attributeStats = AttributeMap(Seq("key12").map(nameToColInfo))) - val noGroupAgg = Aggregate(groupingExpressions = Nil, - aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child) - assert(noGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == - // overhead + count result size - Statistics(sizeInBytes = 8 + 8, rowCount = Some(1))) - - val hasGroupAgg = Aggregate(groupingExpressions = attributes, - aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child) - assert(hasGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == - // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize - Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4))) + val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED) + try { + SQLConf.get.setConf(SQLConf.CBO_ENABLED, false) + val noGroupAgg = Aggregate(groupingExpressions = Nil, + aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child) + assert(noGroupAgg.stats == + // overhead + count result size + Statistics(sizeInBytes = 8 + 8, rowCount = Some(1))) + + val hasGroupAgg = Aggregate(groupingExpressions = attributes, + aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child) + assert(hasGroupAgg.stats == + // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize + Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4))) + } finally { + SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue) + } } private def checkAggStats( @@ -134,6 +140,6 @@ class AggregateEstimationSuite extends StatsEstimationTestBase { rowCount = Some(expectedOutputRowCount), attributeStats = expectedAttrStats) - assert(testAgg.stats(conf) == expectedStats) + assert(testAgg.stats == expectedStats) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 833f5a71994f7..e9ed36feec48c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -57,16 +57,16 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { val localLimit = LocalLimit(Literal(2), plan) val globalLimit = GlobalLimit(Literal(2), plan) // LocalLimit's stats is just its child's stats except column stats - checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + checkStats(localLimit, plan.stats.copy(attributeStats = AttributeMap(Nil))) checkStats(globalLimit, Statistics(sizeInBytes = 24, rowCount = Some(2))) } test("limit estimation: limit > child's rowCount") { val localLimit = LocalLimit(Literal(20), plan) val globalLimit = GlobalLimit(Literal(20), plan) - checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + checkStats(localLimit, plan.stats.copy(attributeStats = AttributeMap(Nil))) // Limit is larger than child's rowCount, so GlobalLimit's stats is equal to its child's stats. - checkStats(globalLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + checkStats(globalLimit, plan.stats.copy(attributeStats = AttributeMap(Nil))) } test("limit estimation: limit = 0") { @@ -113,12 +113,19 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { plan: LogicalPlan, expectedStatsCboOn: Statistics, expectedStatsCboOff: Statistics): Unit = { - // Invalidate statistics - plan.invalidateStatsCache() - assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> true)) == expectedStatsCboOn) - - plan.invalidateStatsCache() - assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == expectedStatsCboOff) + val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED) + try { + // Invalidate statistics + plan.invalidateStatsCache() + SQLConf.get.setConf(SQLConf.CBO_ENABLED, true) + assert(plan.stats == expectedStatsCboOn) + + plan.invalidateStatsCache() + SQLConf.get.setConf(SQLConf.CBO_ENABLED, false) + assert(plan.stats == expectedStatsCboOff) + } finally { + SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue) + } } /** Check estimated stats when it's the same whether cbo is turned on or off. */ @@ -135,6 +142,6 @@ private case class DummyLogicalPlan( cboStats: Statistics) extends LogicalPlan { override def output: Seq[Attribute] = Nil override def children: Seq[LogicalPlan] = Nil - override def computeStats(conf: SQLConf): Statistics = + override def computeStats: Statistics = if (conf.cboEnabled) cboStats else defaultStats } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 2fa53a6466ef2..455037e6c9952 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -620,7 +620,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { rowCount = Some(expectedRowCount), attributeStats = expectedAttributeMap) - val filterStats = filter.stats(conf) + val filterStats = filter.stats assert(filterStats.sizeInBytes == expectedStats.sizeInBytes) assert(filterStats.rowCount == expectedStats.rowCount) val rowCountValue = filterStats.rowCount.getOrElse(0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index 2d6b6e8e21f34..097c78eb27fca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -77,7 +77,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // Keep the column stat from both sides unchanged. attributeStats = AttributeMap( Seq("key-1-5", "key-5-9", "key-1-2", "key-2-4").map(nameToColInfo))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("disjoint inner join") { @@ -90,7 +90,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { sizeInBytes = 1, rowCount = Some(0), attributeStats = AttributeMap(Nil)) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("disjoint left outer join") { @@ -106,7 +106,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // Null count for right side columns = left row count Seq(nameToAttr("key-1-2") -> nullColumnStat(nameToAttr("key-1-2").dataType, 5), nameToAttr("key-2-4") -> nullColumnStat(nameToAttr("key-2-4").dataType, 5)))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("disjoint right outer join") { @@ -122,7 +122,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // Null count for left side columns = right row count Seq(nameToAttr("key-1-5") -> nullColumnStat(nameToAttr("key-1-5").dataType, 3), nameToAttr("key-5-9") -> nullColumnStat(nameToAttr("key-5-9").dataType, 3)))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("disjoint full outer join") { @@ -140,7 +140,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = 3), nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = 5), nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = 5)))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("inner join") { @@ -161,7 +161,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { attributeStats = AttributeMap( Seq(nameToAttr("key-1-5") -> joinedColStat, nameToAttr("key-1-2") -> joinedColStat, nameToAttr("key-5-9") -> colStatForkey59, nameToColInfo("key-2-4")))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("inner join with multiple equi-join keys") { @@ -183,7 +183,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { attributeStats = AttributeMap( Seq(nameToAttr("key-1-2") -> joinedColStat1, nameToAttr("key-1-2") -> joinedColStat1, nameToAttr("key-2-4") -> joinedColStat2, nameToAttr("key-2-3") -> joinedColStat2))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("left outer join") { @@ -201,7 +201,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { attributeStats = AttributeMap( Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-3"), nameToColInfo("key-1-2"), nameToAttr("key-2-4") -> joinedColStat))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("right outer join") { @@ -219,7 +219,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { attributeStats = AttributeMap( Seq(nameToColInfo("key-1-2"), nameToAttr("key-2-4") -> joinedColStat, nameToColInfo("key-1-2"), nameToColInfo("key-2-3")))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("full outer join") { @@ -234,7 +234,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // Keep the column stat from both sides unchanged. attributeStats = AttributeMap(Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-4"), nameToColInfo("key-1-2"), nameToColInfo("key-2-3")))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("left semi/anti join") { @@ -248,7 +248,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { sizeInBytes = 3 * (8 + 4 * 2), rowCount = Some(3), attributeStats = AttributeMap(Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-4")))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } } @@ -306,7 +306,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { sizeInBytes = 1 * (8 + 2 * getColSize(key1, columnInfo1(key1))), rowCount = Some(1), attributeStats = AttributeMap(Seq(key1 -> columnInfo1(key1), key2 -> columnInfo1(key1)))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } } } @@ -323,6 +323,6 @@ class JoinEstimationSuite extends StatsEstimationTestBase { sizeInBytes = 1, rowCount = Some(0), attributeStats = AttributeMap(Nil)) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala index a5c4d22a29386..cda54fa9d64f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala @@ -45,7 +45,7 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { sizeInBytes = 2 * (8 + 4 + 4), rowCount = Some(2), attributeStats = expectedAttrStats) - assert(proj.stats(conf) == expectedStats) + assert(proj.stats == expectedStats) } test("project on empty table") { @@ -131,6 +131,6 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { sizeInBytes = expectedSize, rowCount = Some(expectedRowCount), attributeStats = projectAttrMap) - assert(proj.stats(conf) == expectedStats) + assert(proj.stats == expectedStats) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index 263f4e18803d5..eaa33e44a6a5a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -21,14 +21,24 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, CBO_ENABLED} import org.apache.spark.sql.types.{IntegerType, StringType} trait StatsEstimationTestBase extends SparkFunSuite { - /** Enable stats estimation based on CBO. */ - protected val conf = new SQLConf().copy(CASE_SENSITIVE -> true, CBO_ENABLED -> true) + var originalValue: Boolean = false + + override def beforeAll(): Unit = { + super.beforeAll() + // Enable stats estimation based on CBO. + originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED) + SQLConf.get.setConf(SQLConf.CBO_ENABLED, true) + } + + override def afterAll(): Unit = { + SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue) + super.afterAll() + } def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match { // For UTF8String: base + offset + numBytes @@ -55,7 +65,7 @@ case class StatsTestPlan( attributeStats: AttributeMap[ColumnStat], size: Option[BigInt] = None) extends LeafNode { override def output: Seq[Attribute] = outputList - override def computeStats(conf: SQLConf): Statistics = Statistics( + override def computeStats: Statistics = Statistics( // If sizeInBytes is useless in testing, we just use a fake value sizeInBytes = size.getOrElse(Int.MaxValue), rowCount = Some(rowCount), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 3d1b481a53e75..66f66a289a065 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils @@ -89,7 +88,7 @@ case class ExternalRDD[T]( override protected def stringArgs: Iterator[Any] = Iterator(output) - @transient override def computeStats(conf: SQLConf): Statistics = Statistics( + @transient override def computeStats: Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) @@ -157,7 +156,7 @@ case class LogicalRDD( override protected def stringArgs: Iterator[Any] = Iterator(output) - @transient override def computeStats(conf: SQLConf): Statistics = Statistics( + @transient override def computeStats: Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 34998cbd61552..c7cac332a0377 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -221,7 +221,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { def stringWithStats: String = { // trigger to compute stats for logical plans - optimizedPlan.stats(sparkSession.sessionState.conf) + optimizedPlan.stats // only show optimized logical plan and physical plan s"""== Optimized Logical Plan == diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ea86f6e00fefa..a57d5abb90c0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -114,9 +113,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Matches a plan whose output should be small enough to be used in broadcast join. */ private def canBroadcast(plan: LogicalPlan): Boolean = { - plan.stats(conf).hints.broadcast || - (plan.stats(conf).sizeInBytes >= 0 && - plan.stats(conf).sizeInBytes <= conf.autoBroadcastJoinThreshold) + plan.stats.hints.broadcast || + (plan.stats.sizeInBytes >= 0 && + plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold) } /** @@ -126,7 +125,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * dynamic. */ private def canBuildLocalHashMap(plan: LogicalPlan): Boolean = { - plan.stats(conf).sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions + plan.stats.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions } /** @@ -137,7 +136,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * use the size of bytes here as estimation. */ private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = { - a.stats(conf).sizeInBytes * 3 <= b.stats(conf).sizeInBytes + a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes } private def canBuildRight(joinType: JoinType): Boolean = joinType match { @@ -206,7 +205,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Join(left, right, joinType, condition) => val buildSide = - if (right.stats(conf).sizeInBytes <= left.stats(conf).sizeInBytes) { + if (right.stats.sizeInBytes <= left.stats.sizeInBytes) { BuildRight } else { BuildLeft diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 456a8f3b20f30..2972132336de0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator @@ -70,7 +69,7 @@ case class InMemoryRelation( @transient val partitionStatistics = new PartitionStatistics(output) - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { if (batchStats.value == 0L) { // Underlying columnar RDD hasn't been materialized, no useful statistics information // available, return the default statistics. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 3813f953e06a3..c1b2895f1747e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -20,7 +20,6 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.util.Utils @@ -46,7 +45,7 @@ case class LogicalRelation( // Only care about relation when canonicalizing. override def preCanonicalized: LogicalPlan = copy(catalogTable = None) - @transient override def computeStats(conf: SQLConf): Statistics = { + @transient override def computeStats: Statistics = { catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse( Statistics(sizeInBytes = relation.sizeInBytes)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 7eaa803a9ecb4..a5dac469f85b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -230,6 +229,6 @@ case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum - override def computeStats(conf: SQLConf): Statistics = + override def computeStats: Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 8532a5b5bc8eb..506cc2548e260 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -313,7 +313,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext spark.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum - assert(cached.stats(sqlConf).sizeInBytes === actualSizeInBytes) + assert(cached.stats.sizeInBytes === actualSizeInBytes) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 165176f6c040e..87b7b090de3bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1146,7 +1146,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { // instead of Int for avoiding possible overflow. val ds = (0 to 10000).map( i => (i, Seq((i, Seq((i, "This is really not that long of a string")))))).toDS() - val sizeInBytes = ds.logicalPlan.stats(sqlConf).sizeInBytes + val sizeInBytes = ds.logicalPlan.stats.sizeInBytes // sizeInBytes is 2404280404, before the fix, it overflows to a negative number assert(sizeInBytes > 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 1a66aa85f5a02..895ca196a7a51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -33,7 +33,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { setupTestData() def statisticSizeInByte(df: DataFrame): BigInt = { - df.queryExecution.optimizedPlan.stats(sqlConf).sizeInBytes + df.queryExecution.optimizedPlan.stats.sizeInBytes } test("equi-join is hash-join") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 601324f2c0172..9824062f969b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -60,7 +60,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared val df = df1.join(df2, Seq("k"), "left") val sizes = df.queryExecution.analyzed.collect { case g: Join => - g.stats(conf).sizeInBytes + g.stats.sizeInBytes } assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") @@ -107,9 +107,9 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("SPARK-15392: DataFrame created from RDD should not be broadcasted") { val rdd = sparkContext.range(1, 100).map(i => Row(i, i)) val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType)) - assert(df.queryExecution.analyzed.stats(conf).sizeInBytes > + assert(df.queryExecution.analyzed.stats.sizeInBytes > spark.sessionState.conf.autoBroadcastJoinThreshold) - assert(df.selectExpr("a").queryExecution.analyzed.stats(conf).sizeInBytes > + assert(df.selectExpr("a").queryExecution.analyzed.stats.sizeInBytes > spark.sessionState.conf.autoBroadcastJoinThreshold) } @@ -250,13 +250,13 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils test("SPARK-18856: non-empty partitioned table should not report zero size") { withTable("ds_tbl", "hive_tbl") { spark.range(100).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("ds_tbl") - val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats(conf) + val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats assert(stats.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { sql("CREATE TABLE hive_tbl(i int) PARTITIONED BY (j int)") sql("INSERT INTO hive_tbl PARTITION(j=1) SELECT 1") - val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats(conf) + val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats assert(stats2.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") } } @@ -296,10 +296,10 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat)) // Check relation statistics - assert(relation.stats(conf).sizeInBytes == 0) - assert(relation.stats(conf).rowCount == Some(0)) - assert(relation.stats(conf).attributeStats.size == 1) - val (attribute, colStat) = relation.stats(conf).attributeStats.head + assert(relation.stats.sizeInBytes == 0) + assert(relation.stats.rowCount == Some(0)) + assert(relation.stats.attributeStats.size == 1) + val (attribute, colStat) = relation.stats.attributeStats.head assert(attribute.name == "c1") assert(colStat == emptyColStat) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 109b1d9db60d2..8d411eb191cd9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -126,7 +126,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { .toDF().createOrReplaceTempView("sizeTst") spark.catalog.cacheTable("sizeTst") assert( - spark.table("sizeTst").queryExecution.analyzed.stats(sqlConf).sizeInBytes > + spark.table("sizeTst").queryExecution.analyzed.stats.sizeInBytes > spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala index becb3aa270401..caf03885e3873 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -36,7 +36,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { }) val totalSize = allFiles.map(_.length()).sum val df = spark.read.parquet(dir.toString) - assert(df.queryExecution.logical.stats(sqlConf).sizeInBytes === BigInt(totalSize)) + assert(df.queryExecution.logical.stats.sizeInBytes === BigInt(totalSize)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index 24a7b7740fa5b..e8420eee7fe9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -216,15 +216,15 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { // Before adding data, check output checkAnswer(sink.allData, Seq.empty) - assert(plan.stats(sqlConf).sizeInBytes === 0) + assert(plan.stats.sizeInBytes === 0) sink.addBatch(0, 1 to 3) plan.invalidateStatsCache() - assert(plan.stats(sqlConf).sizeInBytes === 12) + assert(plan.stats.sizeInBytes === 12) sink.addBatch(1, 4 to 6) plan.invalidateStatsCache() - assert(plan.stats(sqlConf).sizeInBytes === 24) + assert(plan.stats.sizeInBytes === 24) } ignore("stress test") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index f9b3ff8405823..0cfe260e52152 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -21,7 +21,6 @@ import java.nio.charset.StandardCharsets import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits} -import org.apache.spark.sql.internal.SQLConf /** * A collection of sample data used in SQL tests. @@ -29,8 +28,6 @@ import org.apache.spark.sql.internal.SQLConf private[sql] trait SQLTestData { self => protected def spark: SparkSession - protected def sqlConf: SQLConf = spark.sessionState.conf - // Helper object to import SQL implicits without a concrete SQLContext private object internalImplicits extends SQLImplicits { protected override def _sqlContext: SQLContext = self.spark.sqlContext diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index ff5afc8e3ce05..808dc013f170b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -154,7 +154,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log Some(partitionSchema)) val logicalRelation = cached.getOrElse { - val sizeInBytes = relation.stats(sparkSession.sessionState.conf).sizeInBytes.toLong + val sizeInBytes = relation.stats.sizeInBytes.toLong val fileIndex = { val index = new CatalogFileIndex(sparkSession, relation.tableMeta, sizeInBytes) if (lazyPruningEnabled) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 001bbc230ff18..279db9a397258 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -68,7 +68,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto assert(properties("totalSize").toLong <= 0, "external table totalSize must be <= 0") assert(properties("rawDataSize").toLong <= 0, "external table rawDataSize must be <= 0") - val sizeInBytes = relation.stats(conf).sizeInBytes + val sizeInBytes = relation.stats.sizeInBytes assert(sizeInBytes === BigInt(file1.length() + file2.length())) } } @@ -77,7 +77,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("analyze Hive serde tables") { def queryTotalSize(tableName: String): BigInt = - spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes // Non-partitioned table sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() @@ -659,7 +659,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("estimates the size of a test Hive serde tables") { val df = sql("""SELECT * FROM src""") val sizes = df.queryExecution.analyzed.collect { - case relation: CatalogRelation => relation.stats(conf).sizeInBytes + case relation: CatalogRelation => relation.stats.sizeInBytes } assert(sizes.size === 1, s"Size wrong for:\n ${df.queryExecution}") assert(sizes(0).equals(BigInt(5812)), @@ -679,7 +679,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // Assert src has a size smaller than the threshold. val sizes = df.queryExecution.analyzed.collect { - case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.stats(conf).sizeInBytes + case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.stats.sizeInBytes } assert(sizes.size === 2 && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold, @@ -733,7 +733,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // Assert src has a size smaller than the threshold. val sizes = df.queryExecution.analyzed.collect { - case relation: CatalogRelation => relation.stats(conf).sizeInBytes + case relation: CatalogRelation => relation.stats.sizeInBytes } assert(sizes.size === 2 && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala index d91f25a4da013..3a724aa14f2a9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -86,7 +86,7 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te case relation: LogicalRelation => relation } assert(relations.size === 1, s"Size wrong for:\n ${df.queryExecution}") - val size2 = relations(0).computeStats(conf).sizeInBytes + val size2 = relations(0).computeStats.sizeInBytes assert(size2 == relations(0).catalogTable.get.stats.get.sizeInBytes) assert(size2 < tableStats.get.sizeInBytes) } From 1ebe7ffe072bcac03360e65e959a6cd36530a9c4 Mon Sep 17 00:00:00 2001 From: Dhruve Ashar Date: Fri, 23 Jun 2017 10:36:29 -0700 Subject: [PATCH 008/779] [SPARK-21181] Release byteBuffers to suppress netty error messages ## What changes were proposed in this pull request? We are explicitly calling release on the byteBuf's used to encode the string to Base64 to suppress the memory leak error message reported by netty. This is to make it less confusing for the user. ### Changes proposed in this fix By explicitly invoking release on the byteBuf's we are decrement the internal reference counts for the wrappedByteBuf's. Now, when the GC kicks in, these would be reclaimed as before, just that netty wouldn't report any memory leak error messages as the internal ref. counts are now 0. ## How was this patch tested? Ran a few spark-applications and examined the logs. The error message no longer appears. Original PR was opened against branch-2.1 => https://github.com/apache/spark/pull/18392 Author: Dhruve Ashar Closes #18407 from dhruve/master. --- .../spark/network/sasl/SparkSaslServer.java | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java index e24fdf0c74de3..00f3e83dbc8b3 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -34,6 +34,7 @@ import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; +import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.handler.codec.base64.Base64; import org.slf4j.Logger; @@ -187,14 +188,31 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback /* Encode a byte[] identifier as a Base64-encoded string. */ public static String encodeIdentifier(String identifier) { Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); - return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(StandardCharsets.UTF_8))) - .toString(StandardCharsets.UTF_8); + return getBase64EncodedString(identifier); } /** Encode a password as a base64-encoded char[] array. */ public static char[] encodePassword(String password) { Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled"); - return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(StandardCharsets.UTF_8))) - .toString(StandardCharsets.UTF_8).toCharArray(); + return getBase64EncodedString(password).toCharArray(); + } + + /** Return a Base64-encoded string. */ + private static String getBase64EncodedString(String str) { + ByteBuf byteBuf = null; + ByteBuf encodedByteBuf = null; + try { + byteBuf = Unpooled.wrappedBuffer(str.getBytes(StandardCharsets.UTF_8)); + encodedByteBuf = Base64.encode(byteBuf); + return encodedByteBuf.toString(StandardCharsets.UTF_8); + } finally { + // The release is called to suppress the memory leak error messages raised by netty. + if (byteBuf != null) { + byteBuf.release(); + if (encodedByteBuf != null) { + encodedByteBuf.release(); + } + } + } } } From 2ebd0838d165fe33b404e8d86c0fa445d1f47439 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 23 Jun 2017 10:55:02 -0700 Subject: [PATCH 009/779] [SPARK-21192][SS] Preserve State Store provider class configuration across StreamingQuery restarts ## What changes were proposed in this pull request? If the SQL conf for StateStore provider class is changed between restarts (i.e. query started with providerClass1 and attempted to restart using providerClass2), then the query will fail in a unpredictable way as files saved by one provider class cannot be used by the newer one. Ideally, the provider class used to start the query should be used to restart the query, and the configuration in the session where it is being restarted should be ignored. This PR saves the provider class config to OffsetSeqLog, in the same way # shuffle partitions is saved and recovered. ## How was this patch tested? new unit tests Author: Tathagata Das Closes #18402 from tdas/SPARK-21192. --- .../apache/spark/sql/internal/SQLConf.scala | 5 +- .../sql/execution/streaming/OffsetSeq.scala | 39 +++++++++++++- .../execution/streaming/StreamExecution.scala | 26 +++------- .../streaming/state/StateStore.scala | 3 +- .../streaming/state/StateStoreConf.scala | 2 +- .../streaming/OffsetSeqLogSuite.scala | 10 ++-- .../spark/sql/streaming/StreamSuite.scala | 51 +++++++++++++++---- 7 files changed, 96 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e609256db2802..9c8e26a8eeadf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -601,7 +601,8 @@ object SQLConf { "The class used to manage state data in stateful streaming queries. This class must " + "be a subclass of StateStoreProvider, and must have a zero-arg constructor.") .stringConf - .createOptional + .createWithDefault( + "org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider") val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT = buildConf("spark.sql.streaming.stateStore.minDeltasForSnapshot") @@ -897,7 +898,7 @@ class SQLConf extends Serializable with Logging { def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD) - def stateStoreProviderClass: Option[String] = getConf(STATE_STORE_PROVIDER_CLASS) + def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS) def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 8249adab4bba8..4e0a468b962a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -20,6 +20,10 @@ package org.apache.spark.sql.execution.streaming import org.json4s.NoTypeHints import org.json4s.jackson.Serialization +import org.apache.spark.internal.Logging +import org.apache.spark.sql.RuntimeConfig +import org.apache.spark.sql.internal.SQLConf.{SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS} + /** * An ordered collection of offsets, used to track the progress of processing data from one or more * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance @@ -78,7 +82,40 @@ case class OffsetSeqMetadata( def json: String = Serialization.write(this)(OffsetSeqMetadata.format) } -object OffsetSeqMetadata { +object OffsetSeqMetadata extends Logging { private implicit val format = Serialization.formats(NoTypeHints) + private val relevantSQLConfs = Seq(SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS) + def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) + + def apply( + batchWatermarkMs: Long, + batchTimestampMs: Long, + sessionConf: RuntimeConfig): OffsetSeqMetadata = { + val confs = relevantSQLConfs.map { conf => conf.key -> sessionConf.get(conf.key) }.toMap + OffsetSeqMetadata(batchWatermarkMs, batchTimestampMs, confs) + } + + /** Set the SparkSession configuration with the values in the metadata */ + def setSessionConf(metadata: OffsetSeqMetadata, sessionConf: RuntimeConfig): Unit = { + OffsetSeqMetadata.relevantSQLConfs.map(_.key).foreach { confKey => + + metadata.conf.get(confKey) match { + + case Some(valueInMetadata) => + // Config value exists in the metadata, update the session config with this value + val optionalValueInSession = sessionConf.getOption(confKey) + if (optionalValueInSession.isDefined && optionalValueInSession.get != valueInMetadata) { + logWarning(s"Updating the value of conf '$confKey' in current session from " + + s"'${optionalValueInSession.get}' to '$valueInMetadata'.") + } + sessionConf.set(confKey, valueInMetadata) + + case None => + // For backward compatibility, if a config was not recorded in the offset log, + // then log it, and let the existing conf value in SparkSession prevail. + logWarning (s"Conf '$confKey' was not found in the offset log, using existing value") + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 06bdec8b06407..d5f8d2acba92b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -125,9 +125,8 @@ class StreamExecution( } /** Metadata associated with the offset seq of a batch in the query. */ - protected var offsetSeqMetadata = OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0, - conf = Map(SQLConf.SHUFFLE_PARTITIONS.key -> - sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS).toString)) + protected var offsetSeqMetadata = OffsetSeqMetadata( + batchWatermarkMs = 0, batchTimestampMs = 0, sparkSession.conf) override val id: UUID = UUID.fromString(streamMetadata.id) @@ -285,9 +284,8 @@ class StreamExecution( val sparkSessionToRunBatches = sparkSession.cloneSession() // Adaptive execution can change num shuffle partitions, disallow sparkSessionToRunBatches.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false") - offsetSeqMetadata = OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0, - conf = Map(SQLConf.SHUFFLE_PARTITIONS.key -> - sparkSessionToRunBatches.conf.get(SQLConf.SHUFFLE_PARTITIONS.key))) + offsetSeqMetadata = OffsetSeqMetadata( + batchWatermarkMs = 0, batchTimestampMs = 0, sparkSessionToRunBatches.conf) if (state.compareAndSet(INITIALIZING, ACTIVE)) { // Unblock `awaitInitialization` @@ -441,21 +439,9 @@ class StreamExecution( // update offset metadata nextOffsets.metadata.foreach { metadata => - val shufflePartitionsSparkSession: Int = - sparkSessionToRunBatches.conf.get(SQLConf.SHUFFLE_PARTITIONS) - val shufflePartitionsToUse = metadata.conf.getOrElse(SQLConf.SHUFFLE_PARTITIONS.key, { - // For backward compatibility, if # partitions was not recorded in the offset log, - // then ensure it is not missing. The new value is picked up from the conf. - logWarning("Number of shuffle partitions from previous run not found in checkpoint. " - + s"Using the value from the conf, $shufflePartitionsSparkSession partitions.") - shufflePartitionsSparkSession - }) + OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf) offsetSeqMetadata = OffsetSeqMetadata( - metadata.batchWatermarkMs, metadata.batchTimestampMs, - metadata.conf + (SQLConf.SHUFFLE_PARTITIONS.key -> shufflePartitionsToUse.toString)) - // Update conf with correct number of shuffle partitions - sparkSessionToRunBatches.conf.set( - SQLConf.SHUFFLE_PARTITIONS.key, shufflePartitionsToUse.toString) + metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) } /* identify the current batch id: if commit log indicates we successfully processed the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index a94ff8a7ebd1e..86886466c4f56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -172,8 +172,7 @@ object StateStoreProvider { indexOrdinal: Option[Int], // for sorting the data storeConf: StateStoreConf, hadoopConf: Configuration): StateStoreProvider = { - val providerClass = storeConf.providerClass.map(Utils.classForName) - .getOrElse(classOf[HDFSBackedStateStoreProvider]) + val providerClass = Utils.classForName(storeConf.providerClass) val provider = providerClass.newInstance().asInstanceOf[StateStoreProvider] provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) provider diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index bab297c7df594..765ff076cb467 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -38,7 +38,7 @@ class StateStoreConf(@transient private val sqlConf: SQLConf) * Optional fully qualified name of the subclass of [[StateStoreProvider]] * managing state data. That is, the implementation of the State Store to use. */ - val providerClass: Option[String] = sqlConf.stateStoreProviderClass + val providerClass: String = sqlConf.stateStoreProviderClass /** * Additional configurations related to state store. This will capture all configs in diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala index dc556322beddb..e6cdc063c4e9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala @@ -37,16 +37,18 @@ class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { } // None set - assert(OffsetSeqMetadata(0, 0, Map.empty) === OffsetSeqMetadata("""{}""")) + assert(new OffsetSeqMetadata(0, 0, Map.empty) === OffsetSeqMetadata("""{}""")) // One set - assert(OffsetSeqMetadata(1, 0, Map.empty) === OffsetSeqMetadata("""{"batchWatermarkMs":1}""")) - assert(OffsetSeqMetadata(0, 2, Map.empty) === OffsetSeqMetadata("""{"batchTimestampMs":2}""")) + assert(new OffsetSeqMetadata(1, 0, Map.empty) === + OffsetSeqMetadata("""{"batchWatermarkMs":1}""")) + assert(new OffsetSeqMetadata(0, 2, Map.empty) === + OffsetSeqMetadata("""{"batchTimestampMs":2}""")) assert(OffsetSeqMetadata(0, 0, getConfWith(shufflePartitions = 2)) === OffsetSeqMetadata(s"""{"conf": {"$key":2}}""")) // Two set - assert(OffsetSeqMetadata(1, 2, Map.empty) === + assert(new OffsetSeqMetadata(1, 2, Map.empty) === OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}""")) assert(OffsetSeqMetadata(1, 0, getConfWith(shufflePartitions = 3)) === OffsetSeqMetadata(s"""{"batchWatermarkMs":1,"conf": {"$key":3}}""")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 86c3a35a59c13..6f7b9d35a6bb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -637,19 +637,11 @@ class StreamSuite extends StreamTest { } testQuietly("specify custom state store provider") { - val queryName = "memStream" val providerClassName = classOf[TestStateStoreProvider].getCanonicalName withSQLConf("spark.sql.streaming.stateStore.providerClass" -> providerClassName) { val input = MemoryStream[Int] - val query = input - .toDS() - .groupBy() - .count() - .writeStream - .outputMode("complete") - .format("memory") - .queryName(queryName) - .start() + val df = input.toDS().groupBy().count() + val query = df.writeStream.outputMode("complete").format("memory").queryName("name").start() input.addData(1, 2, 3) val e = intercept[Exception] { query.awaitTermination() @@ -659,6 +651,45 @@ class StreamSuite extends StreamTest { assert(e.getMessage.contains("instantiated")) } } + + testQuietly("custom state store provider read from offset log") { + val input = MemoryStream[Int] + val df = input.toDS().groupBy().count() + val providerConf1 = "spark.sql.streaming.stateStore.providerClass" -> + "org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider" + val providerConf2 = "spark.sql.streaming.stateStore.providerClass" -> + classOf[TestStateStoreProvider].getCanonicalName + + def runQuery(queryName: String, checkpointLoc: String): Unit = { + val query = df.writeStream + .outputMode("complete") + .format("memory") + .queryName(queryName) + .option("checkpointLocation", checkpointLoc) + .start() + input.addData(1, 2, 3) + query.processAllAvailable() + query.stop() + } + + withTempDir { dir => + val checkpointLoc1 = new File(dir, "1").getCanonicalPath + withSQLConf(providerConf1) { + runQuery("query1", checkpointLoc1) // generate checkpoints + } + + val checkpointLoc2 = new File(dir, "2").getCanonicalPath + withSQLConf(providerConf2) { + // Verify new query will use new provider that throw error on loading + intercept[Exception] { + runQuery("query2", checkpointLoc2) + } + + // Verify old query from checkpoint will still use old provider + runQuery("query1", checkpointLoc1) + } + } + } } abstract class FakeSource extends StreamSourceProvider { From 4cc62951a2b12a372a2b267bf8597a0a31e2b2cb Mon Sep 17 00:00:00 2001 From: Ong Ming Yang Date: Fri, 23 Jun 2017 10:56:59 -0700 Subject: [PATCH 010/779] [MINOR][DOCS] Docs in DataFrameNaFunctions.scala use wrong method ## What changes were proposed in this pull request? * Following the first few examples in this file, the remaining methods should also be methods of `df.na` not `df`. * Filled in some missing parentheses ## How was this patch tested? N/A Author: Ong Ming Yang Closes #18398 from ongmingyang/master. --- .../spark/sql/DataFrameNaFunctions.scala | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index ee949e78fa3ba..871fff71e5538 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -268,13 +268,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * import com.google.common.collect.ImmutableMap; * * // Replaces all occurrences of 1.0 with 2.0 in column "height". - * df.replace("height", ImmutableMap.of(1.0, 2.0)); + * df.na.replace("height", ImmutableMap.of(1.0, 2.0)); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". - * df.replace("name", ImmutableMap.of("UNKNOWN", "unnamed")); + * df.na.replace("name", ImmutableMap.of("UNKNOWN", "unnamed")); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. - * df.replace("*", ImmutableMap.of("UNKNOWN", "unnamed")); + * df.na.replace("*", ImmutableMap.of("UNKNOWN", "unnamed")); * }}} * * @param col name of the column to apply the value replacement @@ -295,10 +295,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * import com.google.common.collect.ImmutableMap; * * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". - * df.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0)); + * df.na.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0)); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". - * df.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed")); + * df.na.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed")); * }}} * * @param cols list of columns to apply the value replacement @@ -319,13 +319,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height". - * df.replace("height", Map(1.0 -> 2.0)) + * df.na.replace("height", Map(1.0 -> 2.0)); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". - * df.replace("name", Map("UNKNOWN" -> "unnamed") + * df.na.replace("name", Map("UNKNOWN" -> "unnamed")); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. - * df.replace("*", Map("UNKNOWN" -> "unnamed") + * df.na.replace("*", Map("UNKNOWN" -> "unnamed")); * }}} * * @param col name of the column to apply the value replacement @@ -348,10 +348,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". - * df.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0)); + * df.na.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0)); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". - * df.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed"); + * df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed")); * }}} * * @param cols list of columns to apply the value replacement From 13c2a4f2f8c6d3484f920caadddf4e5edce0a945 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 23 Jun 2017 11:02:54 -0700 Subject: [PATCH 011/779] [SPARK-20417][SQL] Move subquery error handling to checkAnalysis from Analyzer ## What changes were proposed in this pull request? Currently we do a lot of validations for subquery in the Analyzer. We should move them to CheckAnalysis which is the framework to catch and report Analysis errors. This was mentioned as a review comment in SPARK-18874. ## How was this patch tested? Exists tests + A few tests added to SQLQueryTestSuite. Author: Dilip Biswal Closes #17713 from dilipbiswal/subquery_checkanalysis. --- .../sql/catalyst/analysis/Analyzer.scala | 230 +----------- .../sql/catalyst/analysis/CheckAnalysis.scala | 338 ++++++++++++++---- .../sql/catalyst/expressions/predicates.scala | 46 ++- .../analysis/AnalysisErrorSuite.scala | 3 +- .../analysis/ResolveSubquerySuite.scala | 2 +- .../negative-cases/subq-input-typecheck.sql | 47 +++ .../subq-input-typecheck.sql.out | 106 ++++++ .../org/apache/spark/sql/SubquerySuite.scala | 2 +- 8 files changed, 464 insertions(+), 310 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 647fc0b9342c1..193082eb77024 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects} import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ -import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.catalyst.rules._ @@ -1257,217 +1256,16 @@ class Analyzer( } /** - * Validates to make sure the outer references appearing inside the subquery - * are legal. This function also returns the list of expressions - * that contain outer references. These outer references would be kept as children - * of subquery expressions by the caller of this function. - */ - private def checkAndGetOuterReferences(sub: LogicalPlan): Seq[Expression] = { - val outerReferences = ArrayBuffer.empty[Expression] - - // Validate that correlated aggregate expression do not contain a mixture - // of outer and local references. - def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = { - expr.foreach { - case a: AggregateExpression if containsOuter(a) => - val outer = a.collect { case OuterReference(e) => e.toAttribute } - val local = a.references -- outer - if (local.nonEmpty) { - val msg = - s""" - |Found an aggregate expression in a correlated predicate that has both - |outer and local references, which is not supported yet. - |Aggregate expression: ${SubExprUtils.stripOuterReference(a).sql}, - |Outer references: ${outer.map(_.sql).mkString(", ")}, - |Local references: ${local.map(_.sql).mkString(", ")}. - """.stripMargin.replace("\n", " ").trim() - failAnalysis(msg) - } - case _ => - } - } - - // Make sure a plan's subtree does not contain outer references - def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { - if (hasOuterReferences(p)) { - failAnalysis(s"Accessing outer query column is not allowed in:\n$p") - } - } - - // Make sure a plan's expressions do not contain : - // 1. Aggregate expressions that have mixture of outer and local references. - // 2. Expressions containing outer references on plan nodes other than Filter. - def failOnInvalidOuterReference(p: LogicalPlan): Unit = { - p.expressions.foreach(checkMixedReferencesInsideAggregateExpr) - if (!p.isInstanceOf[Filter] && p.expressions.exists(containsOuter)) { - failAnalysis( - "Expressions referencing the outer query are not supported outside of WHERE/HAVING " + - s"clauses:\n$p") - } - } - - // SPARK-17348: A potential incorrect result case. - // When a correlated predicate is a non-equality predicate, - // certain operators are not permitted from the operator - // hosting the correlated predicate up to the operator on the outer table. - // Otherwise, the pull up of the correlated predicate - // will generate a plan with a different semantics - // which could return incorrect result. - // Currently we check for Aggregate and Window operators - // - // Below shows an example of a Logical Plan during Analyzer phase that - // show this problem. Pulling the correlated predicate [outer(c2#77) >= ..] - // through the Aggregate (or Window) operator could alter the result of - // the Aggregate. - // - // Project [c1#76] - // +- Project [c1#87, c2#88] - // : (Aggregate or Window operator) - // : +- Filter [outer(c2#77) >= c2#88)] - // : +- SubqueryAlias t2, `t2` - // : +- Project [_1#84 AS c1#87, _2#85 AS c2#88] - // : +- LocalRelation [_1#84, _2#85] - // +- SubqueryAlias t1, `t1` - // +- Project [_1#73 AS c1#76, _2#74 AS c2#77] - // +- LocalRelation [_1#73, _2#74] - def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = { - if (found) { - // Report a non-supported case as an exception - failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p") - } - } - - var foundNonEqualCorrelatedPred : Boolean = false - - // Simplify the predicates before validating any unsupported correlation patterns - // in the plan. - BooleanSimplification(sub).foreachUp { - - // Whitelist operators allowed in a correlated subquery - // There are 4 categories: - // 1. Operators that are allowed anywhere in a correlated subquery, and, - // by definition of the operators, they either do not contain - // any columns or cannot host outer references. - // 2. Operators that are allowed anywhere in a correlated subquery - // so long as they do not host outer references. - // 3. Operators that need special handlings. These operators are - // Project, Filter, Join, Aggregate, and Generate. - // - // Any operators that are not in the above list are allowed - // in a correlated subquery only if they are not on a correlation path. - // In other word, these operators are allowed only under a correlation point. - // - // A correlation path is defined as the sub-tree of all the operators that - // are on the path from the operator hosting the correlated expressions - // up to the operator producing the correlated values. - - // Category 1: - // BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias - case _: ResolvedHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias => - - // Category 2: - // These operators can be anywhere in a correlated subquery. - // so long as they do not host outer references in the operators. - case s: Sort => - failOnInvalidOuterReference(s) - case r: RepartitionByExpression => - failOnInvalidOuterReference(r) - - // Category 3: - // Filter is one of the two operators allowed to host correlated expressions. - // The other operator is Join. Filter can be anywhere in a correlated subquery. - case f: Filter => - // Find all predicates with an outer reference. - val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter) - - // Find any non-equality correlated predicates - foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { - case _: EqualTo | _: EqualNullSafe => false - case _ => true - } - - failOnInvalidOuterReference(f) - // The aggregate expressions are treated in a special way by getOuterReferences. If the - // aggregate expression contains only outer reference attributes then the entire aggregate - // expression is isolated as an OuterReference. - // i.e min(OuterReference(b)) => OuterReference(min(b)) - outerReferences ++= getOuterReferences(correlated) - - // Project cannot host any correlated expressions - // but can be anywhere in a correlated subquery. - case p: Project => - failOnInvalidOuterReference(p) - - // Aggregate cannot host any correlated expressions - // It can be on a correlation path if the correlation contains - // only equality correlated predicates. - // It cannot be on a correlation path if the correlation has - // non-equality correlated predicates. - case a: Aggregate => - failOnInvalidOuterReference(a) - failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) - - // Join can host correlated expressions. - case j @ Join(left, right, joinType, _) => - joinType match { - // Inner join, like Filter, can be anywhere. - case _: InnerLike => - failOnInvalidOuterReference(j) - - // Left outer join's right operand cannot be on a correlation path. - // LeftAnti and ExistenceJoin are special cases of LeftOuter. - // Note that ExistenceJoin cannot be expressed externally in both SQL and DataFrame - // so it should not show up here in Analysis phase. This is just a safety net. - // - // LeftSemi does not allow output from the right operand. - // Any correlated references in the subplan - // of the right operand cannot be pulled up. - case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => - failOnInvalidOuterReference(j) - failOnOuterReferenceInSubTree(right) - - // Likewise, Right outer join's left operand cannot be on a correlation path. - case RightOuter => - failOnInvalidOuterReference(j) - failOnOuterReferenceInSubTree(left) - - // Any other join types not explicitly listed above, - // including Full outer join, are treated as Category 4. - case _ => - failOnOuterReferenceInSubTree(j) - } - - // Generator with join=true, i.e., expressed with - // LATERAL VIEW [OUTER], similar to inner join, - // allows to have correlation under it - // but must not host any outer references. - // Note: - // Generator with join=false is treated as Category 4. - case g: Generate if g.join => - failOnInvalidOuterReference(g) - - // Category 4: Any other operators not in the above 3 categories - // cannot be on a correlation path, that is they are allowed only - // under a correlation point but they and their descendant operators - // are not allowed to have any correlated expressions. - case p => - failOnOuterReferenceInSubTree(p) - } - outerReferences - } - - /** - * Resolves the subquery. The subquery is resolved using its outer plans. This method - * will resolve the subquery by alternating between the regular analyzer and by applying the - * resolveOuterReferences rule. + * Resolves the subquery plan that is referenced in a subquery expression. The normal + * attribute references are resolved using regular analyzer and the outer references are + * resolved from the outer plans using the resolveOuterReferences method. * * Outer references from the correlated predicates are updated as children of * Subquery expression. */ private def resolveSubQuery( e: SubqueryExpression, - plans: Seq[LogicalPlan], - requiredColumns: Int = 0)( + plans: Seq[LogicalPlan])( f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = { // Step 1: Resolve the outer expressions. var previous: LogicalPlan = null @@ -1488,15 +1286,8 @@ class Analyzer( // Step 2: If the subquery plan is fully resolved, pull the outer references and record // them as children of SubqueryExpression. if (current.resolved) { - // Make sure the resolved query has the required number of output columns. This is only - // needed for Scalar and IN subqueries. - if (requiredColumns > 0 && requiredColumns != current.output.size) { - failAnalysis(s"The number of columns in the subquery (${current.output.size}) " + - s"does not match the required number of columns ($requiredColumns)") - } - // Validate the outer reference and record the outer references as children of - // subquery expression. - f(current, checkAndGetOuterReferences(current)) + // Record the outer references as children of subquery expression. + f(current, SubExprUtils.getOuterReferences(current)) } else { e.withNewPlan(current) } @@ -1514,16 +1305,11 @@ class Analyzer( private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = { plan transformExpressions { case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved => - resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId)) + resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId) if !sub.resolved => resolveSubQuery(e, plans)(Exists(_, _, exprId)) case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved => - // Get the left hand side expressions. - val expressions = value match { - case cns : CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - val expr = resolveSubQuery(l, plans, expressions.size)(ListQuery(_, _, exprId)) + val expr = resolveSubQuery(l, plans)(ListQuery(_, _, exprId)) In(value, Seq(expr)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 2e3ac3e474866..fb81a7006bc5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -21,6 +21,8 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ +import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -129,61 +131,8 @@ trait CheckAnalysis extends PredicateHelper { case None => w } - case s @ ScalarSubquery(query, conditions, _) => - checkAnalysis(query) - - // If no correlation, the output must be exactly one column - if (conditions.isEmpty && query.output.size != 1) { - failAnalysis( - s"Scalar subquery must return only one column, but got ${query.output.size}") - } else if (conditions.nonEmpty) { - def checkAggregate(agg: Aggregate): Unit = { - // Make sure correlated scalar subqueries contain one row for every outer row by - // enforcing that they are aggregates containing exactly one aggregate expression. - // The analyzer has already checked that subquery contained only one output column, - // and added all the grouping expressions to the aggregate. - val aggregates = agg.expressions.flatMap(_.collect { - case a: AggregateExpression => a - }) - if (aggregates.isEmpty) { - failAnalysis("The output of a correlated scalar subquery must be aggregated") - } - - // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns - // are not part of the correlated columns. - val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) - // Collect the local references from the correlated predicate in the subquery. - val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references) - .filterNot(conditions.flatMap(_.references).contains) - val correlatedCols = AttributeSet(subqueryColumns) - val invalidCols = groupByCols -- correlatedCols - // GROUP BY columns must be a subset of columns in the predicates - if (invalidCols.nonEmpty) { - failAnalysis( - "A GROUP BY clause in a scalar correlated subquery " + - "cannot contain non-correlated columns: " + - invalidCols.mkString(",")) - } - } - - // Skip subquery aliases added by the Analyzer. - // For projects, do the necessary mapping and skip to its child. - def cleanQuery(p: LogicalPlan): LogicalPlan = p match { - case s: SubqueryAlias => cleanQuery(s.child) - case p: Project => cleanQuery(p.child) - case child => child - } - - cleanQuery(query) match { - case a: Aggregate => checkAggregate(a) - case Filter(_, a: Aggregate) => checkAggregate(a) - case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail") - } - } - s - case s: SubqueryExpression => - checkAnalysis(s.plan) + checkSubqueryExpression(operator, s) s } @@ -291,19 +240,6 @@ trait CheckAnalysis extends PredicateHelper { case LocalLimit(limitExpr, _) => checkLimitClause(limitExpr) - case p if p.expressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) => - p match { - case _: Filter | _: Aggregate | _: Project => // Ok - case other => failAnalysis( - s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p") - } - - case p if p.expressions.exists(SubqueryExpression.hasInOrExistsSubquery) => - p match { - case _: Filter => // Ok - case _ => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") - } - case _: Union | _: SetOperation if operator.children.length > 1 => def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType) def ordinalNumber(i: Int): String = i match { @@ -414,4 +350,272 @@ trait CheckAnalysis extends PredicateHelper { plan.foreach(_.setAnalyzed()) } + + /** + * Validates subquery expressions in the plan. Upon failure, returns an user facing error. + */ + private def checkSubqueryExpression(plan: LogicalPlan, expr: SubqueryExpression): Unit = { + def checkAggregateInScalarSubquery( + conditions: Seq[Expression], + query: LogicalPlan, agg: Aggregate): Unit = { + // Make sure correlated scalar subqueries contain one row for every outer row by + // enforcing that they are aggregates containing exactly one aggregate expression. + val aggregates = agg.expressions.flatMap(_.collect { + case a: AggregateExpression => a + }) + if (aggregates.isEmpty) { + failAnalysis("The output of a correlated scalar subquery must be aggregated") + } + + // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns + // are not part of the correlated columns. + val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) + // Collect the local references from the correlated predicate in the subquery. + val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references) + .filterNot(conditions.flatMap(_.references).contains) + val correlatedCols = AttributeSet(subqueryColumns) + val invalidCols = groupByCols -- correlatedCols + // GROUP BY columns must be a subset of columns in the predicates + if (invalidCols.nonEmpty) { + failAnalysis( + "A GROUP BY clause in a scalar correlated subquery " + + "cannot contain non-correlated columns: " + + invalidCols.mkString(",")) + } + } + + // Skip subquery aliases added by the Analyzer. + // For projects, do the necessary mapping and skip to its child. + def cleanQueryInScalarSubquery(p: LogicalPlan): LogicalPlan = p match { + case s: SubqueryAlias => cleanQueryInScalarSubquery(s.child) + case p: Project => cleanQueryInScalarSubquery(p.child) + case child => child + } + + // Validate the subquery plan. + checkAnalysis(expr.plan) + + expr match { + case ScalarSubquery(query, conditions, _) => + // Scalar subquery must return one column as output. + if (query.output.size != 1) { + failAnalysis( + s"Scalar subquery must return only one column, but got ${query.output.size}") + } + + if (conditions.nonEmpty) { + cleanQueryInScalarSubquery(query) match { + case a: Aggregate => checkAggregateInScalarSubquery(conditions, query, a) + case Filter(_, a: Aggregate) => checkAggregateInScalarSubquery(conditions, query, a) + case fail => failAnalysis(s"Correlated scalar subqueries must be aggregated: $fail") + } + + // Only certain operators are allowed to host subquery expression containing + // outer references. + plan match { + case _: Filter | _: Aggregate | _: Project => // Ok + case other => failAnalysis( + "Correlated scalar sub-queries can only be used in a " + + s"Filter/Aggregate/Project: $plan") + } + } + + case inSubqueryOrExistsSubquery => + plan match { + case _: Filter => // Ok + case _ => + failAnalysis(s"IN/EXISTS predicate sub-queries can only be used in a Filter: $plan") + } + } + + // Validate to make sure the correlations appearing in the query are valid and + // allowed by spark. + checkCorrelationsInSubquery(expr.plan) + } + + /** + * Validates to make sure the outer references appearing inside the subquery + * are allowed. + */ + private def checkCorrelationsInSubquery(sub: LogicalPlan): Unit = { + // Validate that correlated aggregate expression do not contain a mixture + // of outer and local references. + def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = { + expr.foreach { + case a: AggregateExpression if containsOuter(a) => + val outer = a.collect { case OuterReference(e) => e.toAttribute } + val local = a.references -- outer + if (local.nonEmpty) { + val msg = + s""" + |Found an aggregate expression in a correlated predicate that has both + |outer and local references, which is not supported yet. + |Aggregate expression: ${SubExprUtils.stripOuterReference(a).sql}, + |Outer references: ${outer.map(_.sql).mkString(", ")}, + |Local references: ${local.map(_.sql).mkString(", ")}. + """.stripMargin.replace("\n", " ").trim() + failAnalysis(msg) + } + case _ => + } + } + + // Make sure a plan's subtree does not contain outer references + def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { + if (hasOuterReferences(p)) { + failAnalysis(s"Accessing outer query column is not allowed in:\n$p") + } + } + + // Make sure a plan's expressions do not contain : + // 1. Aggregate expressions that have mixture of outer and local references. + // 2. Expressions containing outer references on plan nodes other than Filter. + def failOnInvalidOuterReference(p: LogicalPlan): Unit = { + p.expressions.foreach(checkMixedReferencesInsideAggregateExpr) + if (!p.isInstanceOf[Filter] && p.expressions.exists(containsOuter)) { + failAnalysis( + "Expressions referencing the outer query are not supported outside of WHERE/HAVING " + + s"clauses:\n$p") + } + } + + // SPARK-17348: A potential incorrect result case. + // When a correlated predicate is a non-equality predicate, + // certain operators are not permitted from the operator + // hosting the correlated predicate up to the operator on the outer table. + // Otherwise, the pull up of the correlated predicate + // will generate a plan with a different semantics + // which could return incorrect result. + // Currently we check for Aggregate and Window operators + // + // Below shows an example of a Logical Plan during Analyzer phase that + // show this problem. Pulling the correlated predicate [outer(c2#77) >= ..] + // through the Aggregate (or Window) operator could alter the result of + // the Aggregate. + // + // Project [c1#76] + // +- Project [c1#87, c2#88] + // : (Aggregate or Window operator) + // : +- Filter [outer(c2#77) >= c2#88)] + // : +- SubqueryAlias t2, `t2` + // : +- Project [_1#84 AS c1#87, _2#85 AS c2#88] + // : +- LocalRelation [_1#84, _2#85] + // +- SubqueryAlias t1, `t1` + // +- Project [_1#73 AS c1#76, _2#74 AS c2#77] + // +- LocalRelation [_1#73, _2#74] + def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = { + if (found) { + // Report a non-supported case as an exception + failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p") + } + } + + var foundNonEqualCorrelatedPred: Boolean = false + + // Simplify the predicates before validating any unsupported correlation patterns + // in the plan. + BooleanSimplification(sub).foreachUp { + // Whitelist operators allowed in a correlated subquery + // There are 4 categories: + // 1. Operators that are allowed anywhere in a correlated subquery, and, + // by definition of the operators, they either do not contain + // any columns or cannot host outer references. + // 2. Operators that are allowed anywhere in a correlated subquery + // so long as they do not host outer references. + // 3. Operators that need special handlings. These operators are + // Filter, Join, Aggregate, and Generate. + // + // Any operators that are not in the above list are allowed + // in a correlated subquery only if they are not on a correlation path. + // In other word, these operators are allowed only under a correlation point. + // + // A correlation path is defined as the sub-tree of all the operators that + // are on the path from the operator hosting the correlated expressions + // up to the operator producing the correlated values. + + // Category 1: + // ResolvedHint, Distinct, LeafNode, Repartition, and SubqueryAlias + case _: ResolvedHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias => + + // Category 2: + // These operators can be anywhere in a correlated subquery. + // so long as they do not host outer references in the operators. + case p: Project => + failOnInvalidOuterReference(p) + + case s: Sort => + failOnInvalidOuterReference(s) + + case r: RepartitionByExpression => + failOnInvalidOuterReference(r) + + // Category 3: + // Filter is one of the two operators allowed to host correlated expressions. + // The other operator is Join. Filter can be anywhere in a correlated subquery. + case f: Filter => + val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter) + + // Find any non-equality correlated predicates + foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { + case _: EqualTo | _: EqualNullSafe => false + case _ => true + } + failOnInvalidOuterReference(f) + + // Aggregate cannot host any correlated expressions + // It can be on a correlation path if the correlation contains + // only equality correlated predicates. + // It cannot be on a correlation path if the correlation has + // non-equality correlated predicates. + case a: Aggregate => + failOnInvalidOuterReference(a) + failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) + + // Join can host correlated expressions. + case j @ Join(left, right, joinType, _) => + joinType match { + // Inner join, like Filter, can be anywhere. + case _: InnerLike => + failOnInvalidOuterReference(j) + + // Left outer join's right operand cannot be on a correlation path. + // LeftAnti and ExistenceJoin are special cases of LeftOuter. + // Note that ExistenceJoin cannot be expressed externally in both SQL and DataFrame + // so it should not show up here in Analysis phase. This is just a safety net. + // + // LeftSemi does not allow output from the right operand. + // Any correlated references in the subplan + // of the right operand cannot be pulled up. + case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => + failOnInvalidOuterReference(j) + failOnOuterReferenceInSubTree(right) + + // Likewise, Right outer join's left operand cannot be on a correlation path. + case RightOuter => + failOnInvalidOuterReference(j) + failOnOuterReferenceInSubTree(left) + + // Any other join types not explicitly listed above, + // including Full outer join, are treated as Category 4. + case _ => + failOnOuterReferenceInSubTree(j) + } + + // Generator with join=true, i.e., expressed with + // LATERAL VIEW [OUTER], similar to inner join, + // allows to have correlation under it + // but must not host any outer references. + // Note: + // Generator with join=false is treated as Category 4. + case g: Generate if g.join => + failOnInvalidOuterReference(g) + + // Category 4: Any other operators not in the above 3 categories + // cannot be on a correlation path, that is they are allowed only + // under a correlation point but they and their descendant operators + // are not allowed to have any correlated expressions. + case p => + failOnOuterReferenceInSubTree(p) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index c15ee2ab270bc..f3fe58caa6fe2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -144,27 +144,39 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { case cns: CreateNamedStruct => cns.valExprs case expr => Seq(expr) } - - val mismatchedColumns = valExprs.zip(sub.output).flatMap { - case (l, r) if l.dataType != r.dataType => - s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" - case _ => None - } - - if (mismatchedColumns.nonEmpty) { + if (valExprs.length != sub.output.length) { TypeCheckResult.TypeCheckFailure( s""" - |The data type of one or more elements in the left hand side of an IN subquery - |is not compatible with the data type of the output of the subquery - |Mismatched columns: - |[${mismatchedColumns.mkString(", ")}] - |Left side: - |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. - |Right side: - |[${sub.output.map(_.dataType.catalogString).mkString(", ")}]. + |The number of columns in the left hand side of an IN subquery does not match the + |number of columns in the output of subquery. + |#columns in left hand side: ${valExprs.length}. + |#columns in right hand side: ${sub.output.length}. + |Left side columns: + |[${valExprs.map(_.sql).mkString(", ")}]. + |Right side columns: + |[${sub.output.map(_.sql).mkString(", ")}]. """.stripMargin) } else { - TypeCheckResult.TypeCheckSuccess + val mismatchedColumns = valExprs.zip(sub.output).flatMap { + case (l, r) if l.dataType != r.dataType => + s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" + case _ => None + } + if (mismatchedColumns.nonEmpty) { + TypeCheckResult.TypeCheckFailure( + s""" + |The data type of one or more elements in the left hand side of an IN subquery + |is not compatible with the data type of the output of the subquery + |Mismatched columns: + |[${mismatchedColumns.mkString(", ")}] + |Left side: + |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. + |Right side: + |[${sub.output.map(_.dataType.catalogString).mkString(", ")}]. + """.stripMargin) + } else { + TypeCheckResult.TypeCheckSuccess + } } case _ => if (list.exists(l => l.dataType != value.dataType)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 5050318d96358..4ed995e20d7ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -111,8 +111,7 @@ class AnalysisErrorSuite extends AnalysisTest { "scalar subquery with 2 columns", testRelation.select( (ScalarSubquery(testRelation.select('a, dateLit.as('b))) + Literal(1)).as('a)), - "The number of columns in the subquery (2)" :: - "does not match the required number of columns (1)":: Nil) + "Scalar subquery must return only one column, but got 2" :: Nil) errorTest( "scalar subquery with no column", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 55693121431a2..1bf8d76da04d8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -35,7 +35,7 @@ class ResolveSubquerySuite extends AnalysisTest { test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") { val expr = Filter(In(a, Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1) val m = intercept[AnalysisException] { - SimpleAnalyzer.ResolveSubquery(expr) + SimpleAnalyzer.checkAnalysis(SimpleAnalyzer.ResolveSubquery(expr)) }.getMessage assert(m.contains( "Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses")) diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql new file mode 100644 index 0000000000000..b15f4da81dd93 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql @@ -0,0 +1,47 @@ +-- The test file contains negative test cases +-- of invalid queries where error messages are expected. + +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (1, 2, 3) +AS t1(t1a, t1b, t1c); + +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES + (1, 0, 1) +AS t2(t2a, t2b, t2c); + +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES + (3, 1, 2) +AS t3(t3a, t3b, t3c); + +-- TC 01.01 +SELECT + ( SELECT max(t2b), min(t2b) + FROM t2 + WHERE t2.t2b = t1.t1b + GROUP BY t2.t2b + ) +FROM t1; + +-- TC 01.01 +SELECT + ( SELECT max(t2b), min(t2b) + FROM t2 + WHERE t2.t2b > 0 + GROUP BY t2.t2b + ) +FROM t1; + +-- TC 01.03 +SELECT * FROM t1 +WHERE +t1a IN (SELECT t2a, t2b + FROM t2 + WHERE t1a = t2a); + +-- TC 01.04 +SELECT * FROM T1 +WHERE +(t1a, t1b) IN (SELECT t2a + FROM t2 + WHERE t1a = t2a); + diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out new file mode 100644 index 0000000000000..9ea9d3c4c6f40 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out @@ -0,0 +1,106 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 7 + + +-- !query 0 +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (1, 2, 3) +AS t1(t1a, t1b, t1c) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES + (1, 0, 1) +AS t2(t2a, t2b, t2c) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES + (3, 1, 2) +AS t3(t3a, t3b, t3c) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +SELECT + ( SELECT max(t2b), min(t2b) + FROM t2 + WHERE t2.t2b = t1.t1b + GROUP BY t2.t2b + ) +FROM t1 +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +Scalar subquery must return only one column, but got 2; + + +-- !query 4 +SELECT + ( SELECT max(t2b), min(t2b) + FROM t2 + WHERE t2.t2b > 0 + GROUP BY t2.t2b + ) +FROM t1 +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +Scalar subquery must return only one column, but got 2; + + +-- !query 5 +SELECT * FROM t1 +WHERE +t1a IN (SELECT t2a, t2b + FROM t2 + WHERE t1a = t2a) +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve '(t1.`t1a` IN (listquery(t1.`t1a`)))' due to data type mismatch: +The number of columns in the left hand side of an IN subquery does not match the +number of columns in the output of subquery. +#columns in left hand side: 1. +#columns in right hand side: 2. +Left side columns: +[t1.`t1a`]. +Right side columns: +[t2.`t2a`, t2.`t2b`]. + ; + + +-- !query 6 +SELECT * FROM T1 +WHERE +(t1a, t1b) IN (SELECT t2a + FROM t2 + WHERE t1a = t2a) +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +cannot resolve '(named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`)))' due to data type mismatch: +The number of columns in the left hand side of an IN subquery does not match the +number of columns in the output of subquery. +#columns in left hand side: 2. +#columns in right hand side: 1. +Left side columns: +[t1.`t1a`, t1.`t1b`]. +Right side columns: +[t2.`t2a`]. + ; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 4629a8c0dbe5f..820cff655c4ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -517,7 +517,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { val msg1 = intercept[AnalysisException] { sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1") } - assert(msg1.getMessage.contains("Correlated scalar subqueries must be Aggregated")) + assert(msg1.getMessage.contains("Correlated scalar subqueries must be aggregated")) val msg2 = intercept[AnalysisException] { sql("select a, (select b from l l2 where l2.a = l1.a group by 1) sum_b from l l1") From 03eb6117affcca21798be25706a39e0d5a2f7288 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 23 Jun 2017 14:48:33 -0700 Subject: [PATCH 012/779] [SPARK-21164][SQL] Remove isTableSample from Sample and isGenerated from Alias and AttributeReference ## What changes were proposed in this pull request? `isTableSample` and `isGenerated ` were introduced for SQL Generation respectively by https://github.com/apache/spark/pull/11148 and https://github.com/apache/spark/pull/11050 Since SQL Generation is removed, we do not need to keep `isTableSample`. ## How was this patch tested? The existing test cases Author: Xiao Li Closes #18379 from gatorsmile/CleanSample. --- .../sql/catalyst/analysis/Analyzer.scala | 8 ++--- .../expressions/namedExpressions.scala | 34 +++++++------------ .../optimizer/RewriteDistinctAggregates.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../sql/catalyst/planning/patterns.scala | 4 +-- .../spark/sql/catalyst/plans/QueryPlan.scala | 2 +- .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../plans/logical/basicLogicalOperators.scala | 6 +--- .../analysis/AnalysisErrorSuite.scala | 2 +- .../analysis/UnsupportedOperationsSuite.scala | 2 +- .../optimizer/ColumnPruningSuite.scala | 8 ++--- .../sql/catalyst/parser/PlanParserSuite.scala | 4 +-- .../spark/sql/catalyst/plans/PlanTest.scala | 10 +++--- .../BasicStatsEstimationSuite.scala | 4 +-- .../scala/org/apache/spark/sql/Dataset.scala | 4 +-- 15 files changed, 40 insertions(+), 54 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 193082eb77024..7e5ebfc93286f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -874,7 +874,7 @@ class Analyzer( def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = { expressions.map { - case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated) + case a: Alias => Alias(a.child, a.name)() case other => other } } @@ -1368,7 +1368,7 @@ class Analyzer( val aggregatedCondition = Aggregate( grouping, - Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil, + Alias(havingCondition, "havingCondition")() :: Nil, child) val resolvedOperator = execute(aggregatedCondition) def resolvedAggregateFilter = @@ -1424,7 +1424,7 @@ class Analyzer( try { val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s)) val aliasedOrdering = - unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")(isGenerated = true)) + unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] val resolvedAliasedOrdering: Seq[Alias] = @@ -1935,7 +1935,7 @@ class Analyzer( leafNondeterministic.distinct.map { e => val ne = e match { case n: NamedExpression => n - case _ => Alias(e, "_nondeterministic")(isGenerated = true) + case _ => Alias(e, "_nondeterministic")() } e -> ne } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index c842f85af693c..29c33804f077a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -81,9 +81,6 @@ trait NamedExpression extends Expression { /** Returns the metadata when an expression is a reference to another expression with metadata. */ def metadata: Metadata = Metadata.empty - /** Returns true if the expression is generated by Catalyst */ - def isGenerated: java.lang.Boolean = false - /** Returns a copy of this expression with a new `exprId`. */ def newInstance(): NamedExpression @@ -128,13 +125,11 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn * qualified way. Consider the examples tableName.name, subQueryAlias.name. * tableName and subQueryAlias are possible qualifiers. * @param explicitMetadata Explicit metadata associated with this alias that overwrites child's. - * @param isGenerated A flag to indicate if this alias is generated by Catalyst */ case class Alias(child: Expression, name: String)( val exprId: ExprId = NamedExpression.newExprId, val qualifier: Option[String] = None, - val explicitMetadata: Option[Metadata] = None, - override val isGenerated: java.lang.Boolean = false) + val explicitMetadata: Option[Metadata] = None) extends UnaryExpression with NamedExpression { // Alias(Generator, xx) need to be transformed into Generate(generator, ...) @@ -159,13 +154,11 @@ case class Alias(child: Expression, name: String)( } def newInstance(): NamedExpression = - Alias(child, name)( - qualifier = qualifier, explicitMetadata = explicitMetadata, isGenerated = isGenerated) + Alias(child, name)(qualifier = qualifier, explicitMetadata = explicitMetadata) override def toAttribute: Attribute = { if (resolved) { - AttributeReference(name, child.dataType, child.nullable, metadata)( - exprId, qualifier, isGenerated) + AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifier) } else { UnresolvedAttribute(name) } @@ -174,7 +167,7 @@ case class Alias(child: Expression, name: String)( override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix" override protected final def otherCopyArgs: Seq[AnyRef] = { - exprId :: qualifier :: explicitMetadata :: isGenerated :: Nil + exprId :: qualifier :: explicitMetadata :: Nil } override def hashCode(): Int = { @@ -207,7 +200,6 @@ case class Alias(child: Expression, name: String)( * @param qualifier An optional string that can be used to referred to this attribute in a fully * qualified way. Consider the examples tableName.name, subQueryAlias.name. * tableName and subQueryAlias are possible qualifiers. - * @param isGenerated A flag to indicate if this reference is generated by Catalyst */ case class AttributeReference( name: String, @@ -215,8 +207,7 @@ case class AttributeReference( nullable: Boolean = true, override val metadata: Metadata = Metadata.empty)( val exprId: ExprId = NamedExpression.newExprId, - val qualifier: Option[String] = None, - override val isGenerated: java.lang.Boolean = false) + val qualifier: Option[String] = None) extends Attribute with Unevaluable { /** @@ -253,8 +244,7 @@ case class AttributeReference( } override def newInstance(): AttributeReference = - AttributeReference(name, dataType, nullable, metadata)( - qualifier = qualifier, isGenerated = isGenerated) + AttributeReference(name, dataType, nullable, metadata)(qualifier = qualifier) /** * Returns a copy of this [[AttributeReference]] with changed nullability. @@ -263,7 +253,7 @@ case class AttributeReference( if (nullable == newNullability) { this } else { - AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifier, isGenerated) + AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifier) } } @@ -271,7 +261,7 @@ case class AttributeReference( if (name == newName) { this } else { - AttributeReference(newName, dataType, nullable, metadata)(exprId, qualifier, isGenerated) + AttributeReference(newName, dataType, nullable, metadata)(exprId, qualifier) } } @@ -282,7 +272,7 @@ case class AttributeReference( if (newQualifier == qualifier) { this } else { - AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifier, isGenerated) + AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifier) } } @@ -290,16 +280,16 @@ case class AttributeReference( if (exprId == newExprId) { this } else { - AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifier, isGenerated) + AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifier) } } override def withMetadata(newMetadata: Metadata): Attribute = { - AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier, isGenerated) + AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier) } override protected final def otherCopyArgs: Seq[AnyRef] = { - exprId :: qualifier :: isGenerated :: Nil + exprId :: qualifier :: Nil } /** Used to signal the column used to calculate an eventTime watermark (e.g. a#1-T{delayMs}) */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 3b27cd2ffe028..4448ace7105a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -134,7 +134,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // Aggregation strategy can handle queries with a single distinct group. if (distinctAggGroups.size > 1) { // Create the attributes for the grouping id and the group by clause. - val gid = AttributeReference("gid", IntegerType, nullable = false)(isGenerated = true) + val gid = AttributeReference("gid", IntegerType, nullable = false)() val groupByMap = a.groupingExpressions.collect { case ne: NamedExpression => ne -> ne.toAttribute case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 315c6721b3f65..ef79cbcaa0ce6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -627,7 +627,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps, s"Sampling fraction ($fraction) must be on interval [0, 1]", ctx) - Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true) + Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query) } ctx.sampleType.getType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index ef925f92ecc7e..7f370fb731b2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -80,12 +80,12 @@ object PhysicalOperation extends PredicateHelper { expr.transform { case a @ Alias(ref: AttributeReference, name) => aliases.get(ref) - .map(Alias(_, name)(a.exprId, a.qualifier, isGenerated = a.isGenerated)) + .map(Alias(_, name)(a.exprId, a.qualifier)) .getOrElse(a) case a: AttributeReference => aliases.get(a) - .map(Alias(_, a.name)(a.exprId, a.qualifier, isGenerated = a.isGenerated)).getOrElse(a) + .map(Alias(_, a.name)(a.exprId, a.qualifier)).getOrElse(a) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 1f6d05bc8d816..01b3da3f7c482 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -200,7 +200,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT // normalize that for equality testing, by assigning expr id from 0 incrementally. The // alias name doesn't matter and should be erased. val normalizedChild = QueryPlan.normalizeExprId(a.child, allAttributes) - Alias(normalizedChild, "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated) + Alias(normalizedChild, "")(ExprId(id), a.qualifier) case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 => // Top level `AttributeReference` may also be used for output like `Alias`, we should diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 0c098ac0209e8..0d30aa76049a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -221,7 +221,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai nameParts: Seq[String], resolver: Resolver, attribute: Attribute): Option[(Attribute, List[String])] = { - if (!attribute.isGenerated && resolver(attribute.name, nameParts.head)) { + if (resolver(attribute.name, nameParts.head)) { Option((attribute.withName(nameParts.head), nameParts.tail.toList)) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index d8f89b108e63f..e89caabf252d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -807,15 +807,13 @@ case class SubqueryAlias( * @param withReplacement Whether to sample with replacement. * @param seed the random seed * @param child the LogicalPlan - * @param isTableSample Is created from TABLESAMPLE in the parser. */ case class Sample( lowerBound: Double, upperBound: Double, withReplacement: Boolean, seed: Long, - child: LogicalPlan)( - val isTableSample: java.lang.Boolean = false) extends UnaryNode { + child: LogicalPlan) extends UnaryNode { val eps = RandomSampler.roundingEpsilon val fraction = upperBound - lowerBound @@ -842,8 +840,6 @@ case class Sample( // Don't propagate column stats, because we don't know the distribution after a sample operation Statistics(sizeInBytes, sampledRowCount, hints = childStats.hints) } - - override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 4ed995e20d7ce..7311dc3899e53 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -573,7 +573,7 @@ class AnalysisErrorSuite extends AnalysisTest { val plan5 = Filter( Exists( Sample(0.0, 0.5, false, 1L, - Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))().select('b) + Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b))).select('b) ), LocalRelation(a)) assertAnalysisError(plan5, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index c39e372c272b1..f68d930f60523 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -491,7 +491,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { // Other unary operations testUnaryOperatorInStreamingPlan( - "sample", Sample(0.1, 1, true, 1L, _)(), expectedMsg = "sampling") + "sample", Sample(0.1, 1, true, 1L, _), expectedMsg = "sampling") testUnaryOperatorInStreamingPlan( "window", Window(Nil, Nil, Nil, _), expectedMsg = "non-time-based windows") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 0b419e9631b29..08e58d47e0e25 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -349,14 +349,14 @@ class ColumnPruningSuite extends PlanTest { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val x = testRelation.subquery('x) - val query1 = Sample(0.0, 0.6, false, 11L, x)().select('a) + val query1 = Sample(0.0, 0.6, false, 11L, x).select('a) val optimized1 = Optimize.execute(query1.analyze) - val expected1 = Sample(0.0, 0.6, false, 11L, x.select('a))() + val expected1 = Sample(0.0, 0.6, false, 11L, x.select('a)) comparePlans(optimized1, expected1.analyze) - val query2 = Sample(0.0, 0.6, false, 11L, x)().select('a as 'aa) + val query2 = Sample(0.0, 0.6, false, 11L, x).select('a as 'aa) val optimized2 = Optimize.execute(query2.analyze) - val expected2 = Sample(0.0, 0.6, false, 11L, x.select('a))().select('a as 'aa) + val expected2 = Sample(0.0, 0.6, false, 11L, x.select('a)).select('a as 'aa) comparePlans(optimized2, expected2.analyze) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 0a4ae098d65cc..bf15b85d5b510 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -411,9 +411,9 @@ class PlanParserSuite extends AnalysisTest { assertEqual(s"$sql tablesample(100 rows)", table("t").limit(100).select(star())) assertEqual(s"$sql tablesample(43 percent) as x", - Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + Sample(0, .43d, withReplacement = false, 10L, table("t").as("x")).select(star())) assertEqual(s"$sql tablesample(bucket 4 out of 10) as x", - Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + Sample(0, .4d, withReplacement = false, 10L, table("t").as("x")).select(star())) intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x", "TABLESAMPLE(BUCKET x OUT OF y ON colname) is not supported") intercept(s"$sql tablesample(bucket 11 out of 10) as x", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 25313af2be184..6883d23d477e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -63,14 +63,14 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { */ protected def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan transform { - case filter @ Filter(condition: Expression, child: LogicalPlan) => - Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode()) + case Filter(condition: Expression, child: LogicalPlan) => + Filter(splitConjunctivePredicates(condition).map(rewriteEqual).sortBy(_.hashCode()) .reduce(And), child) case sample: Sample => - sample.copy(seed = 0L)(true) - case join @ Join(left, right, joinType, condition) if condition.isDefined => + sample.copy(seed = 0L) + case Join(left, right, joinType, condition) if condition.isDefined => val newCondition = - splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode()) + splitConjunctivePredicates(condition.get).map(rewriteEqual).sortBy(_.hashCode()) .reduce(And) Join(left, right, joinType, Some(newCondition)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index e9ed36feec48c..912c5fed63450 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -78,14 +78,14 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { } test("sample estimation") { - val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan)() + val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan) checkStats(sample, Statistics(sizeInBytes = 60, rowCount = Some(5))) // Child doesn't have rowCount in stats val childStats = Statistics(sizeInBytes = 120) val childPlan = DummyLogicalPlan(childStats, childStats) val sample2 = - Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan)() + Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan) checkStats(sample2, Statistics(sizeInBytes = 14)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 767dad3e63a6d..6e66e92091ff9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1807,7 +1807,7 @@ class Dataset[T] private[sql]( */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = { withTypedPlan { - Sample(0.0, fraction, withReplacement, seed, logicalPlan)() + Sample(0.0, fraction, withReplacement, seed, logicalPlan) } } @@ -1863,7 +1863,7 @@ class Dataset[T] private[sql]( val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => new Dataset[T]( - sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan)(), encoder) + sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan), encoder) }.toArray } From 7525ce98b4575b1ac4e44cc9b3a5773f03eba19e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 24 Jun 2017 11:39:41 +0800 Subject: [PATCH 013/779] [SPARK-20431][SS][FOLLOWUP] Specify a schema by using a DDL-formatted string in DataStreamReader ## What changes were proposed in this pull request? This pr supported a DDL-formatted string in `DataStreamReader.schema`. This fix could make users easily define a schema without importing the type classes. For example, ```scala scala> spark.readStream.schema("col0 INT, col1 DOUBLE").load("/tmp/abc").printSchema() root |-- col0: integer (nullable = true) |-- col1: double (nullable = true) ``` ## How was this patch tested? Added tests in `DataStreamReaderWriterSuite`. Author: hyukjinkwon Closes #18373 from HyukjinKwon/SPARK-20431. --- python/pyspark/sql/readwriter.py | 2 ++ python/pyspark/sql/streaming.py | 24 ++++++++++++------- .../sql/streaming/DataStreamReader.scala | 12 ++++++++++ .../test/DataStreamReaderWriterSuite.scala | 12 ++++++++++ 4 files changed, 42 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index aef71f9ca7001..7279173df6e4f 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -98,6 +98,8 @@ def schema(self, schema): :param schema: a :class:`pyspark.sql.types.StructType` object or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). + + >>> s = spark.read.schema("col0 INT, col1 DOUBLE") """ from pyspark.sql import SparkSession spark = SparkSession.builder.getOrCreate() diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 58aa2468e006d..5bbd70cf0a789 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -319,16 +319,21 @@ def schema(self, schema): .. note:: Evolving. - :param schema: a :class:`pyspark.sql.types.StructType` object + :param schema: a :class:`pyspark.sql.types.StructType` object or a DDL-formatted string + (For example ``col0 INT, col1 DOUBLE``). >>> s = spark.readStream.schema(sdf_schema) + >>> s = spark.readStream.schema("col0 INT, col1 DOUBLE") """ from pyspark.sql import SparkSession - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType") spark = SparkSession.builder.getOrCreate() - jschema = spark._jsparkSession.parseDataType(schema.json()) - self._jreader = self._jreader.schema(jschema) + if isinstance(schema, StructType): + jschema = spark._jsparkSession.parseDataType(schema.json()) + self._jreader = self._jreader.schema(jschema) + elif isinstance(schema, basestring): + self._jreader = self._jreader.schema(schema) + else: + raise TypeError("schema should be StructType or string") return self @since(2.0) @@ -372,7 +377,8 @@ def load(self, path=None, format=None, schema=None, **options): :param path: optional string for file-system backed data sources. :param format: optional string for format of the data source. Default to 'parquet'. - :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema. + :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema + or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param options: all other string options >>> json_sdf = spark.readStream.format("json") \\ @@ -415,7 +421,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param path: string represents path to the JSON dataset, or RDD of Strings storing JSON objects. - :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. + :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema + or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param primitivesAsString: infers all primitive values as a string type. If None is set, it uses the default value, ``false``. :param prefersDecimal: infers all floating-point values as a decimal type. If the values @@ -542,7 +549,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non .. note:: Evolving. :param path: string, or list of strings, for input path(s). - :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. + :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema + or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param sep: sets the single character as a separator for each field and value. If None is set, it uses the default value, ``,``. :param encoding: decodes the CSV files by the given encoding type. If None is set, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 7e8e6394b4862..70ddfa8e9b835 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -59,6 +59,18 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo this } + /** + * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) can + * infer the input schema automatically from data. By specifying the schema here, the underlying + * data source can skip the schema inference step, and thus speed up data loading. + * + * @since 2.3.0 + */ + def schema(schemaString: String): DataStreamReader = { + this.userSpecifiedSchema = Option(StructType.fromDDL(schemaString)) + this + } + /** * Adds an input option for the underlying data source. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index b5f1e28d7396a..3de0ae67a3892 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -663,4 +663,16 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } assert(fs.exists(checkpointDir)) } + + test("SPARK-20431: Specify a schema by using a DDL-formatted string") { + spark.readStream + .format("org.apache.spark.sql.streaming.test") + .schema("aa INT") + .load() + + assert(LastOptions.schema.isDefined) + assert(LastOptions.schema.get === StructType(StructField("aa", IntegerType) :: Nil)) + + LastOptions.clear() + } } From b837bf9ae97cf7ee7558c10a5a34636e69367a05 Mon Sep 17 00:00:00 2001 From: Gabor Feher Date: Fri, 23 Jun 2017 21:53:38 -0700 Subject: [PATCH 014/779] [SPARK-20555][SQL] Fix mapping of Oracle DECIMAL types to Spark types in read path ## What changes were proposed in this pull request? This PR is to revert some code changes in the read path of https://github.com/apache/spark/pull/14377. The original fix is https://github.com/apache/spark/pull/17830 When merging this PR, please give the credit to gaborfeher ## How was this patch tested? Added a test case to OracleIntegrationSuite.scala Author: Gabor Feher Author: gatorsmile Closes #18408 from gatorsmile/OracleType. --- .../sql/jdbc/OracleIntegrationSuite.scala | 65 +++++++++++++------ .../apache/spark/sql/jdbc/OracleDialect.scala | 4 -- 2 files changed, 45 insertions(+), 24 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index f7b1ec34ced76..b2f096964427e 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, Date, Timestamp} import java.util.Properties +import java.math.BigDecimal import org.apache.spark.sql.Row import org.apache.spark.sql.test.SharedSQLContext @@ -93,8 +94,31 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo |USING org.apache.spark.sql.jdbc |OPTIONS (url '$jdbcUrl', dbTable 'datetime1', oracle.jdbc.mapDateToTimestamp 'false') """.stripMargin.replaceAll("\n", " ")) + + + conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate(); + conn.prepareStatement( + "INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate(); + conn.commit(); } + + test("SPARK-16625 : Importing Oracle numeric types") { + val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties); + val rows = df.collect() + assert(rows.size == 1) + val row = rows(0) + // The main point of the below assertions is not to make sure that these Oracle types are + // mapped to decimal types, but to make sure that the returned values are correct. + // A value > 1 from DECIMAL(1) is correct: + assert(row.getDecimal(0).compareTo(BigDecimal.valueOf(4)) == 0) + // A value with fractions from DECIMAL(3, 2) is correct: + assert(row.getDecimal(1).compareTo(BigDecimal.valueOf(1.23)) == 0) + // A value > Int.MaxValue from DECIMAL(10) is correct: + assert(row.getDecimal(2).compareTo(BigDecimal.valueOf(9999999999l)) == 0) + } + + test("SPARK-12941: String datatypes to be mapped to Varchar in Oracle") { // create a sample dataframe with string type val df1 = sparkContext.parallelize(Seq(("foo"))).toDF("x") @@ -154,27 +178,28 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo val dfRead = spark.read.jdbc(jdbcUrl, tableName, props) val rows = dfRead.collect() // verify the data type is inserted - val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types(0).equals("class java.lang.Boolean")) - assert(types(1).equals("class java.lang.Integer")) - assert(types(2).equals("class java.lang.Long")) - assert(types(3).equals("class java.lang.Float")) - assert(types(4).equals("class java.lang.Float")) - assert(types(5).equals("class java.lang.Integer")) - assert(types(6).equals("class java.lang.Integer")) - assert(types(7).equals("class java.lang.String")) - assert(types(8).equals("class [B")) - assert(types(9).equals("class java.sql.Date")) - assert(types(10).equals("class java.sql.Timestamp")) + val types = dfRead.schema.map(field => field.dataType) + assert(types(0).equals(DecimalType(1, 0))) + assert(types(1).equals(DecimalType(10, 0))) + assert(types(2).equals(DecimalType(19, 0))) + assert(types(3).equals(DecimalType(19, 4))) + assert(types(4).equals(DecimalType(19, 4))) + assert(types(5).equals(DecimalType(3, 0))) + assert(types(6).equals(DecimalType(5, 0))) + assert(types(7).equals(StringType)) + assert(types(8).equals(BinaryType)) + assert(types(9).equals(DateType)) + assert(types(10).equals(TimestampType)) + // verify the value is the inserted correct or not val values = rows(0) - assert(values.getBoolean(0).equals(booleanVal)) - assert(values.getInt(1).equals(integerVal)) - assert(values.getLong(2).equals(longVal)) - assert(values.getFloat(3).equals(floatVal)) - assert(values.getFloat(4).equals(doubleVal.toFloat)) - assert(values.getInt(5).equals(byteVal.toInt)) - assert(values.getInt(6).equals(shortVal.toInt)) + assert(values.getDecimal(0).compareTo(BigDecimal.valueOf(1)) == 0) + assert(values.getDecimal(1).compareTo(BigDecimal.valueOf(integerVal)) == 0) + assert(values.getDecimal(2).compareTo(BigDecimal.valueOf(longVal)) == 0) + assert(values.getDecimal(3).compareTo(BigDecimal.valueOf(floatVal)) == 0) + assert(values.getDecimal(4).compareTo(BigDecimal.valueOf(doubleVal)) == 0) + assert(values.getDecimal(5).compareTo(BigDecimal.valueOf(byteVal)) == 0) + assert(values.getDecimal(6).compareTo(BigDecimal.valueOf(shortVal)) == 0) assert(values.getString(7).equals(stringVal)) assert(values.getAs[Array[Byte]](8).mkString.equals("678")) assert(values.getDate(9).equals(dateVal)) @@ -183,7 +208,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo test("SPARK-19318: connection property keys should be case-sensitive") { def checkRow(row: Row): Unit = { - assert(row.getInt(0) == 1) + assert(row.getDecimal(0).equals(BigDecimal.valueOf(1))) assert(row.getDate(1).equals(Date.valueOf("1991-11-09"))) assert(row.getTimestamp(2).equals(Timestamp.valueOf("1996-01-01 01:23:45"))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index f541996b651e9..20e634c06b610 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -43,10 +43,6 @@ private case object OracleDialect extends JdbcDialect { // Not sure if there is a more robust way to identify the field as a float (or other // numeric types that do not specify a scale. case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) - case 1 => Option(BooleanType) - case 3 | 5 | 10 => Option(IntegerType) - case 19 if scale == 0L => Option(LongType) - case 19 if scale == 4L => Option(FloatType) case _ => None } } else { From bfd73a7c48b87456d1b84d826e04eca938a1be64 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sat, 24 Jun 2017 13:23:43 +0800 Subject: [PATCH 015/779] [SPARK-21159][CORE] Don't try to connect to launcher in standalone cluster mode. Monitoring for standalone cluster mode is not implemented (see SPARK-11033), but the same scheduler implementation is used, and if it tries to connect to the launcher it will fail. So fix the scheduler so it only tries that in client mode; cluster mode applications will be correctly launched and will work, but monitoring through the launcher handle will not be available. Tested by running a cluster mode app with "SparkLauncher.startApplication". Author: Marcelo Vanzin Closes #18397 from vanzin/SPARK-21159. --- .../scheduler/cluster/StandaloneSchedulerBackend.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index fd8e64454bf70..a4e2a74341283 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -58,7 +58,13 @@ private[spark] class StandaloneSchedulerBackend( override def start() { super.start() - launcherBackend.connect() + + // SPARK-21159. The scheduler backend should only try to connect to the launcher when in client + // mode. In cluster mode, the code that submits the application to the Master needs to connect + // to the launcher instead. + if (sc.deployMode == "client") { + launcherBackend.connect() + } // The endpoint for executors to talk to us val driverUrl = RpcEndpointAddress( From 7c7bc8fc0ff85fe70968b47433bb7757326a6b12 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 24 Jun 2017 10:14:31 +0100 Subject: [PATCH 016/779] [SPARK-21189][INFRA] Handle unknown error codes in Jenkins rather then leaving incomplete comment in PRs ## What changes were proposed in this pull request? Recently, Jenkins tests were unstable due to unknown reasons as below: ``` /home/jenkins/workspace/SparkPullRequestBuilder/dev/lint-r ; process was terminated by signal 9 test_result_code, test_result_note = run_tests(tests_timeout) File "./dev/run-tests-jenkins.py", line 140, in run_tests test_result_note = ' * This patch **fails %s**.' % failure_note_by_errcode[test_result_code] KeyError: -9 ``` ``` Traceback (most recent call last): File "./dev/run-tests-jenkins.py", line 226, in main() File "./dev/run-tests-jenkins.py", line 213, in main test_result_code, test_result_note = run_tests(tests_timeout) File "./dev/run-tests-jenkins.py", line 140, in run_tests test_result_note = ' * This patch **fails %s**.' % failure_note_by_errcode[test_result_code] KeyError: -10 ``` This exception looks causing failing to update the comments in the PR. For example: ![2017-06-23 4 19 41](https://user-images.githubusercontent.com/6477701/27470626-d035ecd8-582f-11e7-883e-0ae6941659b7.png) ![2017-06-23 4 19 50](https://user-images.githubusercontent.com/6477701/27470629-d11ba782-582f-11e7-97e0-64d28cbc19aa.png) these comment just remain. This always requires, for both reviewers and the author, a overhead to click and check the logs, which I believe are not really useful. This PR proposes to leave the code in the PR comment messages and let update the comments. ## How was this patch tested? Jenkins tests below, I manually gave the error code to test this. Author: hyukjinkwon Closes #18399 from HyukjinKwon/jenkins-print-errors. --- dev/run-tests-jenkins.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 53061bc947e5f..914eb93622d51 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -137,7 +137,9 @@ def run_tests(tests_timeout): if test_result_code == 0: test_result_note = ' * This patch passes all tests.' else: - test_result_note = ' * This patch **fails %s**.' % failure_note_by_errcode[test_result_code] + note = failure_note_by_errcode.get( + test_result_code, "due to an unknown error code, %s" % test_result_code) + test_result_note = ' * This patch **fails %s**.' % note return [test_result_code, test_result_note] From 2e1586f60a77ea0adb6f3f68ba74323f0c242199 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 24 Jun 2017 22:35:59 +0800 Subject: [PATCH 017/779] [SPARK-21203][SQL] Fix wrong results of insertion of Array of Struct ### What changes were proposed in this pull request? ```SQL CREATE TABLE `tab1` (`custom_fields` ARRAY>) USING parquet INSERT INTO `tab1` SELECT ARRAY(named_struct('id', 1, 'value', 'a'), named_struct('id', 2, 'value', 'b')) SELECT custom_fields.id, custom_fields.value FROM tab1 ``` The above query always return the last struct of the array, because the rule `SimplifyCasts` incorrectly rewrites the query. The underlying cause is we always use the same `GenericInternalRow` object when doing the cast. ### How was this patch tested? Author: gatorsmile Closes #18412 from gatorsmile/castStruct. --- .../spark/sql/catalyst/expressions/Cast.scala | 4 ++-- .../spark/sql/sources/InsertSuite.scala | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a53ef426f79b5..43df19ba009a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -482,15 +482,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case (fromField, toField) => cast(fromField.dataType, toField.dataType) } // TODO: Could be faster? - val newRow = new GenericInternalRow(from.fields.length) buildCast[InternalRow](_, row => { + val newRow = new GenericInternalRow(from.fields.length) var i = 0 while (i < row.numFields) { newRow.update(i, if (row.isNullAt(i)) null else castFuncs(i)(row.get(i, from.apply(i).dataType))) i += 1 } - newRow.copy() + newRow }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 2eae66dda88de..41abff2a5da25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -345,4 +345,25 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { ) } } + + test("SPARK-21203 wrong results of insertion of Array of Struct") { + val tabName = "tab1" + withTable(tabName) { + spark.sql( + """ + |CREATE TABLE `tab1` + |(`custom_fields` ARRAY>) + |USING parquet + """.stripMargin) + spark.sql( + """ + |INSERT INTO `tab1` + |SELECT ARRAY(named_struct('id', 1, 'value', 'a'), named_struct('id', 2, 'value', 'b')) + """.stripMargin) + + checkAnswer( + spark.sql("SELECT custom_fields.id, custom_fields.value FROM tab1"), + Row(Array(1, 2), Array("a", "b"))) + } + } } From b449a1d6aa322a50cf221cd7a2ae85a91d6c7e9f Mon Sep 17 00:00:00 2001 From: Masha Basmanova Date: Sat, 24 Jun 2017 22:49:35 -0700 Subject: [PATCH 018/779] [SPARK-21079][SQL] Calculate total size of a partition table as a sum of individual partitions ## What changes were proposed in this pull request? Storage URI of a partitioned table may or may not point to a directory under which individual partitions are stored. In fact, individual partitions may be located in totally unrelated directories. Before this change, ANALYZE TABLE table COMPUTE STATISTICS command calculated total size of a table by adding up sizes of files found under table's storage URI. This calculation could produce 0 if partitions are stored elsewhere. This change uses storage URIs of individual partitions to calculate the sizes of all partitions of a table and adds these up to produce the total size of a table. CC: wzhfy ## How was this patch tested? Added unit test. Ran ANALYZE TABLE xxx COMPUTE STATISTICS on a partitioned Hive table and verified that sizeInBytes is calculated correctly. Before this change, the size would be zero. Author: Masha Basmanova Closes #18309 from mbasmanova/mbasmanova-analyze-part-table. --- .../command/AnalyzeTableCommand.scala | 29 ++++++-- .../spark/sql/hive/StatisticsSuite.scala | 72 +++++++++++++++++++ 2 files changed, 95 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 3c59b982c2dca..06e588f56f1e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.net.URI + import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} @@ -81,6 +83,21 @@ case class AnalyzeTableCommand( object AnalyzeTableCommand extends Logging { def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): Long = { + if (catalogTable.partitionColumnNames.isEmpty) { + calculateLocationSize(sessionState, catalogTable.identifier, catalogTable.storage.locationUri) + } else { + // Calculate table size as a sum of the visible partitions. See SPARK-21079 + val partitions = sessionState.catalog.listPartitions(catalogTable.identifier) + partitions.map(p => + calculateLocationSize(sessionState, catalogTable.identifier, p.storage.locationUri) + ).sum + } + } + + private def calculateLocationSize( + sessionState: SessionState, + tableId: TableIdentifier, + locationUri: Option[URI]): Long = { // This method is mainly based on // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) // in Hive 0.13 (except that we do not use fs.getContentSummary). @@ -91,13 +108,13 @@ object AnalyzeTableCommand extends Logging { // countFileSize to count the table size. val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") - def calculateTableSize(fs: FileSystem, path: Path): Long = { + def calculateLocationSize(fs: FileSystem, path: Path): Long = { val fileStatus = fs.getFileStatus(path) val size = if (fileStatus.isDirectory) { fs.listStatus(path) .map { status => if (!status.getPath.getName.startsWith(stagingDir)) { - calculateTableSize(fs, status.getPath) + calculateLocationSize(fs, status.getPath) } else { 0L } @@ -109,16 +126,16 @@ object AnalyzeTableCommand extends Logging { size } - catalogTable.storage.locationUri.map { p => + locationUri.map { p => val path = new Path(p) try { val fs = path.getFileSystem(sessionState.newHadoopConf()) - calculateTableSize(fs, path) + calculateLocationSize(fs, path) } catch { case NonFatal(e) => logWarning( - s"Failed to get the size of table ${catalogTable.identifier.table} in the " + - s"database ${catalogTable.identifier.database} because of ${e.toString}", e) + s"Failed to get the size of table ${tableId.table} in the " + + s"database ${tableId.database} because of ${e.toString}", e) 0L } }.getOrElse(0L) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 279db9a397258..0ee18bbe9befe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { @@ -128,6 +129,77 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto TableIdentifier("tempTable"), ignoreIfNotExists = true, purge = false) } + test("SPARK-21079 - analyze table with location different than that of individual partitions") { + def queryTotalSize(tableName: String): BigInt = + spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + + val tableName = "analyzeTable_part" + withTable(tableName) { + withTempPath { path => + sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)") + + val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03") + partitionDates.foreach { ds => + sql(s"INSERT INTO TABLE $tableName PARTITION (ds='$ds') SELECT * FROM src") + } + + sql(s"ALTER TABLE $tableName SET LOCATION '$path'") + + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") + + assert(queryTotalSize(tableName) === BigInt(17436)) + } + } + } + + test("SPARK-21079 - analyze partitioned table with only a subset of partitions visible") { + def queryTotalSize(tableName: String): BigInt = + spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + + val sourceTableName = "analyzeTable_part" + val tableName = "analyzeTable_part_vis" + withTable(sourceTableName, tableName) { + withTempPath { path => + // Create a table with 3 partitions all located under a single top-level directory 'path' + sql( + s""" + |CREATE TABLE $sourceTableName (key STRING, value STRING) + |PARTITIONED BY (ds STRING) + |LOCATION '$path' + """.stripMargin) + + val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03") + partitionDates.foreach { ds => + sql( + s""" + |INSERT INTO TABLE $sourceTableName PARTITION (ds='$ds') + |SELECT * FROM src + """.stripMargin) + } + + // Create another table referring to the same location + sql( + s""" + |CREATE TABLE $tableName (key STRING, value STRING) + |PARTITIONED BY (ds STRING) + |LOCATION '$path' + """.stripMargin) + + // Register only one of the partitions found on disk + val ds = partitionDates.head + sql(s"ALTER TABLE $tableName ADD PARTITION (ds='$ds')").collect() + + // Analyze original table - expect 3 partitions + sql(s"ANALYZE TABLE $sourceTableName COMPUTE STATISTICS noscan") + assert(queryTotalSize(sourceTableName) === BigInt(3 * 5812)) + + // Analyze partial-copy table - expect only 1 partition + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") + assert(queryTotalSize(tableName) === BigInt(5812)) + } + } + } + test("analyzing views is not supported") { def assertAnalyzeUnsupported(analyzeCommand: String): Unit = { val err = intercept[AnalysisException] { From 884347e1f79e4e7c157834881e79447d7ee58f88 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Sun, 25 Jun 2017 15:06:29 +0100 Subject: [PATCH 019/779] [HOT FIX] fix stats functions in the recent patch ## What changes were proposed in this pull request? Builds failed due to the recent [merge](https://github.com/apache/spark/commit/b449a1d6aa322a50cf221cd7a2ae85a91d6c7e9f). This is because [PR#18309](https://github.com/apache/spark/pull/18309) needed update after [this patch](https://github.com/apache/spark/commit/b803b66a8133f705463039325ee71ee6827ce1a7) was merged. ## How was this patch tested? N/A Author: Zhenhua Wang Closes #18415 from wzhfy/hotfixStats. --- .../scala/org/apache/spark/sql/hive/StatisticsSuite.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 0ee18bbe9befe..64deb3818d5d1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { @@ -131,7 +130,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("SPARK-21079 - analyze table with location different than that of individual partitions") { def queryTotalSize(tableName: String): BigInt = - spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes val tableName = "analyzeTable_part" withTable(tableName) { @@ -154,7 +153,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("SPARK-21079 - analyze partitioned table with only a subset of partitions visible") { def queryTotalSize(tableName: String): BigInt = - spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes val sourceTableName = "analyzeTable_part" val tableName = "analyzeTable_part_vis" From 6b3d02285ee0debc73cbcab01b10398a498fbeb8 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 25 Jun 2017 11:05:57 -0700 Subject: [PATCH 020/779] [SPARK-21093][R] Terminate R's worker processes in the parent of R's daemon to prevent a leak ## What changes were proposed in this pull request? `mcfork` in R looks opening a pipe ahead but the existing logic does not properly close it when it is executed hot. This leads to the failure of more forking due to the limit for number of files open. This hot execution looks particularly for `gapply`/`gapplyCollect`. For unknown reason, this happens more easily in CentOS and could be reproduced in Mac too. All the details are described in https://issues.apache.org/jira/browse/SPARK-21093 This PR proposes simply to terminate R's worker processes in the parent of R's daemon to prevent a leak. ## How was this patch tested? I ran the codes below on both CentOS and Mac with that configuration disabled/enabled. ```r df <- createDataFrame(list(list(1L, 1, "1", 0.1)), c("a", "b", "c", "d")) collect(gapply(df, "a", function(key, x) { x }, schema(df))) collect(gapply(df, "a", function(key, x) { x }, schema(df))) ... # 30 times ``` Also, now it passes R tests on CentOS as below: ``` SparkSQL functions: Spark package found in SPARK_HOME: .../spark .............................................................................................................................................................. .............................................................................................................................................................. .............................................................................................................................................................. .............................................................................................................................................................. .............................................................................................................................................................. .................................................................................................................................... ``` Author: hyukjinkwon Closes #18320 from HyukjinKwon/SPARK-21093. --- R/pkg/inst/worker/daemon.R | 59 +++++++++++++++++++++++++++++++++++--- 1 file changed, 55 insertions(+), 4 deletions(-) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index 3a318b71ea06d..6e385b2a27622 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -30,8 +30,55 @@ port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout) +# Waits indefinitely for a socket connecion by default. +selectTimeout <- NULL + +# Exit code that children send to the parent to indicate they exited. +exitCode <- 1 + while (TRUE) { - ready <- socketSelect(list(inputCon)) + ready <- socketSelect(list(inputCon), timeout = selectTimeout) + + # Note that the children should be terminated in the parent. If each child terminates + # itself, it appears that the resource is not released properly, that causes an unexpected + # termination of this daemon due to, for example, running out of file descriptors + # (see SPARK-21093). Therefore, the current implementation tries to retrieve children + # that are exited (but not terminated) and then sends a kill signal to terminate them properly + # in the parent. + # + # There are two paths that it attempts to send a signal to terminate the children in the parent. + # + # 1. Every second if any socket connection is not available and if there are child workers + # running. + # 2. Right after a socket connection is available. + # + # In other words, the parent attempts to send the signal to the children every second if + # any worker is running or right before launching other worker children from the following + # new socket connection. + + # Only the process IDs of children sent data to the parent are returned below. The children + # send a custom exit code to the parent after being exited and the parent tries + # to terminate them only if they sent the exit code. + children <- parallel:::selectChildren(timeout = 0) + + if (is.integer(children)) { + lapply(children, function(child) { + # This data should be raw bytes if any data was sent from this child. + # Otherwise, this returns the PID. + data <- parallel:::readChild(child) + if (is.raw(data)) { + # This checks if the data from this child is the exit code that indicates an exited child. + if (unserialize(data) == exitCode) { + # If so, we terminate this child. + tools::pskill(child, tools::SIGUSR1) + } + } + }) + } else if (is.null(children)) { + # If it is NULL, there are no children. Waits indefinitely for a socket connecion. + selectTimeout <- NULL + } + if (ready) { port <- SparkR:::readInt(inputCon) # There is a small chance that it could be interrupted by signal, retry one time @@ -44,12 +91,16 @@ while (TRUE) { } p <- parallel:::mcfork() if (inherits(p, "masterProcess")) { + # Reach here because this is a child process. close(inputCon) Sys.setenv(SPARKR_WORKER_PORT = port) try(source(script)) - # Set SIGUSR1 so that child can exit - tools::pskill(Sys.getpid(), tools::SIGUSR1) - parallel:::mcexit(0L) + # Note that this mcexit does not fully terminate this child. So, this writes back + # a custom exit code so that the parent can read and terminate this child. + parallel:::mcexit(0L, send = exitCode) + } else { + # Forking succeeded and we need to check if they finished their jobs every second. + selectTimeout <- 1 } } } From 5282bae0408dec8aa0cefafd7673dd34d232ead9 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 26 Jun 2017 01:26:32 -0700 Subject: [PATCH 021/779] [SPARK-21153] Use project instead of expand in tumbling windows ## What changes were proposed in this pull request? Time windowing in Spark currently performs an Expand + Filter, because there is no way to guarantee the amount of windows a timestamp will fall in, in the general case. However, for tumbling windows, a record is guaranteed to fall into a single bucket. In this case, doubling the number of records with Expand is wasteful, and can be improved by using a simple Projection instead. Benchmarks show that we get an order of magnitude performance improvement after this patch. ## How was this patch tested? Existing unit tests. Benchmarked using the following code: ```scala import org.apache.spark.sql.functions._ spark.time { spark.range(numRecords) .select(from_unixtime((current_timestamp().cast("long") * 1000 + 'id / 1000) / 1000) as 'time) .select(window('time, "10 seconds")) .count() } ``` Setup: - 1 c3.2xlarge worker (8 cores) ![image](https://user-images.githubusercontent.com/5243515/27348748-ed991b84-55a9-11e7-8f8b-6e7abc524417.png) 1 B rows ran in 287 seconds after this optimization. I didn't wait for it to finish without the optimization. Shows about 5x improvement for large number of records. Author: Burak Yavuz Closes #18364 from brkyvz/opt-tumble. --- .../sql/catalyst/analysis/Analyzer.scala | 72 +++++++++++++------ .../sql/catalyst/expressions/TimeWindow.scala | 12 ++-- .../sql/DataFrameTimeWindowingSuite.scala | 49 +++++++++---- 3 files changed, 94 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7e5ebfc93286f..434b6ffee37fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2301,6 +2301,7 @@ object EliminateEventTimeWatermark extends Rule[LogicalPlan] { object TimeWindowing extends Rule[LogicalPlan] { import org.apache.spark.sql.catalyst.dsl.expressions._ + private final val WINDOW_COL_NAME = "window" private final val WINDOW_START = "start" private final val WINDOW_END = "end" @@ -2336,49 +2337,76 @@ object TimeWindowing extends Rule[LogicalPlan] { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = - p.expressions.flatMap(_.collect { case t: TimeWindow => t }).distinct.toList // Not correct. + p.expressions.flatMap(_.collect { case t: TimeWindow => t }).toSet // Only support a single window expression for now if (windowExpressions.size == 1 && windowExpressions.head.timeColumn.resolved && windowExpressions.head.checkInputDataTypes().isSuccess) { + val window = windowExpressions.head val metadata = window.timeColumn match { case a: Attribute => a.metadata case _ => Metadata.empty } - val windowAttr = - AttributeReference("window", window.dataType, metadata = metadata)() - - val maxNumOverlapping = math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt - val windows = Seq.tabulate(maxNumOverlapping + 1) { i => - val windowId = Ceil((PreciseTimestamp(window.timeColumn) - window.startTime) / - window.slideDuration) - val windowStart = (windowId + i - maxNumOverlapping) * - window.slideDuration + window.startTime + + def getWindow(i: Int, overlappingWindows: Int): Expression = { + val division = (PreciseTimestampConversion( + window.timeColumn, TimestampType, LongType) - window.startTime) / window.slideDuration + val ceil = Ceil(division) + // if the division is equal to the ceiling, our record is the start of a window + val windowId = CaseWhen(Seq((ceil === division, ceil + 1)), Some(ceil)) + val windowStart = (windowId + i - overlappingWindows) * + window.slideDuration + window.startTime val windowEnd = windowStart + window.windowDuration CreateNamedStruct( - Literal(WINDOW_START) :: windowStart :: - Literal(WINDOW_END) :: windowEnd :: Nil) + Literal(WINDOW_START) :: + PreciseTimestampConversion(windowStart, LongType, TimestampType) :: + Literal(WINDOW_END) :: + PreciseTimestampConversion(windowEnd, LongType, TimestampType) :: + Nil) } - val projections = windows.map(_ +: p.children.head.output) + val windowAttr = AttributeReference( + WINDOW_COL_NAME, window.dataType, metadata = metadata)() + + if (window.windowDuration == window.slideDuration) { + val windowStruct = Alias(getWindow(0, 1), WINDOW_COL_NAME)( + exprId = windowAttr.exprId) + + val replacedPlan = p transformExpressions { + case t: TimeWindow => windowAttr + } + + // For backwards compatibility we add a filter to filter out nulls + val filterExpr = IsNotNull(window.timeColumn) - val filterExpr = - window.timeColumn >= windowAttr.getField(WINDOW_START) && - window.timeColumn < windowAttr.getField(WINDOW_END) + replacedPlan.withNewChildren( + Filter(filterExpr, + Project(windowStruct +: child.output, child)) :: Nil) + } else { + val overlappingWindows = + math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt + val windows = + Seq.tabulate(overlappingWindows)(i => getWindow(i, overlappingWindows)) + + val projections = windows.map(_ +: child.output) + + val filterExpr = + window.timeColumn >= windowAttr.getField(WINDOW_START) && + window.timeColumn < windowAttr.getField(WINDOW_END) - val expandedPlan = - Filter(filterExpr, + val substitutedPlan = Filter(filterExpr, Expand(projections, windowAttr +: child.output, child)) - val substitutedPlan = p transformExpressions { - case t: TimeWindow => windowAttr - } + val renamedPlan = p transformExpressions { + case t: TimeWindow => windowAttr + } - substitutedPlan.withNewChildren(expandedPlan :: Nil) + renamedPlan.withNewChildren(substitutedPlan :: Nil) + } } else if (windowExpressions.size > 1) { p.failAnalysis("Multiple time window expressions would result in a cartesian product " + "of rows, therefore they are currently not supported.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 7ff61ee479452..9a9f579b37f58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -152,12 +152,15 @@ object TimeWindow { } /** - * Expression used internally to convert the TimestampType to Long without losing + * Expression used internally to convert the TimestampType to Long and back without losing * precision, i.e. in microseconds. Used in time windowing. */ -case class PreciseTimestamp(child: Expression) extends UnaryExpression with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) - override def dataType: DataType = LongType +case class PreciseTimestampConversion( + child: Expression, + fromType: DataType, + toType: DataType) extends UnaryExpression with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(fromType) + override def dataType: DataType = toType override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) ev.copy(code = eval.code + @@ -165,4 +168,5 @@ case class PreciseTimestamp(child: Expression) extends UnaryExpression with Expe |${ctx.javaType(dataType)} ${ev.value} = ${eval.value}; """.stripMargin) } + override def nullSafeEval(input: Any): Any = input } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 22d5c47a6fb51..6fe356877c268 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql -import java.util.TimeZone - import org.scalatest.BeforeAndAfterEach +import org.apache.spark.sql.catalyst.plans.logical.Expand import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StringType @@ -29,11 +28,27 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B import testImplicits._ + test("simple tumbling window with record at window start") { + val df = Seq( + ("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(window($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"), + Seq( + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1) + ) + ) + } + test("tumbling window groupBy statement") { val df = Seq( ("2016-03-27 19:39:34", 1, "a"), ("2016-03-27 19:39:56", 2, "a"), ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + checkAnswer( df.groupBy(window($"time", "10 seconds")) .agg(count("*").as("counts")) @@ -59,14 +74,18 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B test("tumbling window with multi-column projection") { val df = Seq( - ("2016-03-27 19:39:34", 1, "a"), - ("2016-03-27 19:39:56", 2, "a"), - ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + .select(window($"time", "10 seconds"), $"value") + .orderBy($"window.start".asc) + .select($"window.start".cast("string"), $"window.end".cast("string"), $"value") + + val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand]) + assert(expands.isEmpty, "Tumbling windows shouldn't require expand") checkAnswer( - df.select(window($"time", "10 seconds"), $"value") - .orderBy($"window.start".asc) - .select($"window.start".cast("string"), $"window.end".cast("string"), $"value"), + df, Seq( Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4), Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1), @@ -104,13 +123,17 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B test("sliding window projection") { val df = Seq( - ("2016-03-27 19:39:34", 1, "a"), - ("2016-03-27 19:39:56", 2, "a"), - ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + .select(window($"time", "10 seconds", "3 seconds", "0 second"), $"value") + .orderBy($"window.start".asc, $"value".desc).select("value") + + val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand]) + assert(expands.nonEmpty, "Sliding windows require expand") checkAnswer( - df.select(window($"time", "10 seconds", "3 seconds", "0 second"), $"value") - .orderBy($"window.start".asc, $"value".desc).select("value"), + df, // 2016-03-27 19:39:27 UTC -> 4 bins // 2016-03-27 19:39:34 UTC -> 3 bins // 2016-03-27 19:39:56 UTC -> 3 bins From 9e50a1d37a4cf0c34e20a7c1a910ceaff41535a2 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 26 Jun 2017 11:14:03 -0500 Subject: [PATCH 022/779] [SPARK-13669][SPARK-20898][CORE] Improve the blacklist mechanism to handle external shuffle service unavailable situation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Currently we are running into an issue with Yarn work preserving enabled + external shuffle service. In the work preserving enabled scenario, the failure of NM will not lead to the exit of executors, so executors can still accept and run the tasks. The problem here is when NM is failed, external shuffle service is actually inaccessible, so reduce tasks will always complain about the “Fetch failure”, and the failure of reduce stage will make the parent stage (map stage) rerun. The tricky thing here is Spark scheduler is not aware of the unavailability of external shuffle service, and will reschedule the map tasks on the executor where NM is failed, and again reduce stage will be failed with “Fetch failure”, and after 4 retries, the job is failed. This could also apply to other cluster manager with external shuffle service. So here the main problem is that we should avoid assigning tasks to those bad executors (where shuffle service is unavailable). Current Spark's blacklist mechanism could blacklist executors/nodes by failure tasks, but it doesn't handle this specific fetch failure scenario. So here propose to improve the current application blacklist mechanism to handle fetch failure issue (especially with external shuffle service unavailable issue), to blacklist the executors/nodes where shuffle fetch is unavailable. ## How was this patch tested? Unit test and small cluster verification. Author: jerryshao Closes #17113 from jerryshao/SPARK-13669. --- .../spark/internal/config/package.scala | 5 + .../spark/scheduler/BlacklistTracker.scala | 95 ++++++++++++++----- .../spark/scheduler/TaskSchedulerImpl.scala | 18 +--- .../spark/scheduler/TaskSetManager.scala | 6 ++ .../scheduler/BlacklistTrackerSuite.scala | 55 +++++++++++ .../scheduler/TaskSchedulerImplSuite.scala | 4 +- .../spark/scheduler/TaskSetManagerSuite.scala | 32 +++++++ docs/configuration.md | 9 ++ 8 files changed, 186 insertions(+), 38 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 462c1890fd8df..be63c637a3a13 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -149,6 +149,11 @@ package object config { .internal() .timeConf(TimeUnit.MILLISECONDS) .createOptional + + private[spark] val BLACKLIST_FETCH_FAILURE_ENABLED = + ConfigBuilder("spark.blacklist.application.fetchFailure.enabled") + .booleanConf + .createWithDefault(false) // End blacklist confs private[spark] val UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE = diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index e130e609e4f63..cd8e61d6d0208 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -61,6 +61,7 @@ private[scheduler] class BlacklistTracker ( private val MAX_FAILURES_PER_EXEC = conf.get(config.MAX_FAILURES_PER_EXEC) private val MAX_FAILED_EXEC_PER_NODE = conf.get(config.MAX_FAILED_EXEC_PER_NODE) val BLACKLIST_TIMEOUT_MILLIS = BlacklistTracker.getBlacklistTimeout(conf) + private val BLACKLIST_FETCH_FAILURE_ENABLED = conf.get(config.BLACKLIST_FETCH_FAILURE_ENABLED) /** * A map from executorId to information on task failures. Tracks the time of each task failure, @@ -145,6 +146,74 @@ private[scheduler] class BlacklistTracker ( nextExpiryTime = math.min(execMinExpiry, nodeMinExpiry) } + private def killBlacklistedExecutor(exec: String): Unit = { + if (conf.get(config.BLACKLIST_KILL_ENABLED)) { + allocationClient match { + case Some(a) => + logInfo(s"Killing blacklisted executor id $exec " + + s"since ${config.BLACKLIST_KILL_ENABLED.key} is set.") + a.killExecutors(Seq(exec), true, true) + case None => + logWarning(s"Not attempting to kill blacklisted executor id $exec " + + s"since allocation client is not defined.") + } + } + } + + private def killExecutorsOnBlacklistedNode(node: String): Unit = { + if (conf.get(config.BLACKLIST_KILL_ENABLED)) { + allocationClient match { + case Some(a) => + logInfo(s"Killing all executors on blacklisted host $node " + + s"since ${config.BLACKLIST_KILL_ENABLED.key} is set.") + if (a.killExecutorsOnHost(node) == false) { + logError(s"Killing executors on node $node failed.") + } + case None => + logWarning(s"Not attempting to kill executors on blacklisted host $node " + + s"since allocation client is not defined.") + } + } + } + + def updateBlacklistForFetchFailure(host: String, exec: String): Unit = { + if (BLACKLIST_FETCH_FAILURE_ENABLED) { + // If we blacklist on fetch failures, we are implicitly saying that we believe the failure is + // non-transient, and can't be recovered from (even if this is the first fetch failure, + // stage is retried after just one failure, so we don't always get a chance to collect + // multiple fetch failures). + // If the external shuffle-service is on, then every other executor on this node would + // be suffering from the same issue, so we should blacklist (and potentially kill) all + // of them immediately. + + val now = clock.getTimeMillis() + val expiryTimeForNewBlacklists = now + BLACKLIST_TIMEOUT_MILLIS + + if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) { + if (!nodeIdToBlacklistExpiryTime.contains(host)) { + logInfo(s"blacklisting node $host due to fetch failure of external shuffle service") + + nodeIdToBlacklistExpiryTime.put(host, expiryTimeForNewBlacklists) + listenerBus.post(SparkListenerNodeBlacklisted(now, host, 1)) + _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet) + killExecutorsOnBlacklistedNode(host) + updateNextExpiryTime() + } + } else if (!executorIdToBlacklistStatus.contains(exec)) { + logInfo(s"Blacklisting executor $exec due to fetch failure") + + executorIdToBlacklistStatus.put(exec, BlacklistedExecutor(host, expiryTimeForNewBlacklists)) + // We hardcoded number of failure tasks to 1 for fetch failure, because there's no + // reattempt for such failure. + listenerBus.post(SparkListenerExecutorBlacklisted(now, exec, 1)) + updateNextExpiryTime() + killBlacklistedExecutor(exec) + + val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(exec, HashSet[String]()) + blacklistedExecsOnNode += exec + } + } + } def updateBlacklistForSuccessfulTaskSet( stageId: Int, @@ -174,17 +243,7 @@ private[scheduler] class BlacklistTracker ( listenerBus.post(SparkListenerExecutorBlacklisted(now, exec, newTotal)) executorIdToFailureList.remove(exec) updateNextExpiryTime() - if (conf.get(config.BLACKLIST_KILL_ENABLED)) { - allocationClient match { - case Some(allocationClient) => - logInfo(s"Killing blacklisted executor id $exec " + - s"since spark.blacklist.killBlacklistedExecutors is set.") - allocationClient.killExecutors(Seq(exec), true, true) - case None => - logWarning(s"Not attempting to kill blacklisted executor id $exec " + - s"since allocation client is not defined.") - } - } + killBlacklistedExecutor(exec) // In addition to blacklisting the executor, we also update the data for failures on the // node, and potentially put the entire node into a blacklist as well. @@ -199,19 +258,7 @@ private[scheduler] class BlacklistTracker ( nodeIdToBlacklistExpiryTime.put(node, expiryTimeForNewBlacklists) listenerBus.post(SparkListenerNodeBlacklisted(now, node, blacklistedExecsOnNode.size)) _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet) - if (conf.get(config.BLACKLIST_KILL_ENABLED)) { - allocationClient match { - case Some(allocationClient) => - logInfo(s"Killing all executors on blacklisted host $node " + - s"since spark.blacklist.killBlacklistedExecutors is set.") - if (allocationClient.killExecutorsOnHost(node) == false) { - logError(s"Killing executors on node $node failed.") - } - case None => - logWarning(s"Not attempting to kill executors on blacklisted host $node " + - s"since allocation client is not defined.") - } - } + killExecutorsOnBlacklistedNode(node) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index bba0b294f1afb..91ec172ffeda1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -51,29 +51,21 @@ import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} * acquire a lock on us, so we need to make sure that we don't try to lock the backend while * we are holding a lock on ourselves. */ -private[spark] class TaskSchedulerImpl private[scheduler]( +private[spark] class TaskSchedulerImpl( val sc: SparkContext, val maxTaskFailures: Int, - private[scheduler] val blacklistTrackerOpt: Option[BlacklistTracker], isLocal: Boolean = false) extends TaskScheduler with Logging { import TaskSchedulerImpl._ def this(sc: SparkContext) = { - this( - sc, - sc.conf.get(config.MAX_TASK_FAILURES), - TaskSchedulerImpl.maybeCreateBlacklistTracker(sc)) + this(sc, sc.conf.get(config.MAX_TASK_FAILURES)) } - def this(sc: SparkContext, maxTaskFailures: Int, isLocal: Boolean) = { - this( - sc, - maxTaskFailures, - TaskSchedulerImpl.maybeCreateBlacklistTracker(sc), - isLocal = isLocal) - } + // Lazily initializing blackListTrackOpt to avoid getting empty ExecutorAllocationClient, + // because ExecutorAllocationClient is created after this TaskSchedulerImpl. + private[scheduler] lazy val blacklistTrackerOpt = maybeCreateBlacklistTracker(sc) val conf = sc.conf diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a41b059fa7dec..02d374dc37cd5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -774,6 +774,12 @@ private[spark] class TaskSetManager( tasksSuccessful += 1 } isZombie = true + + if (fetchFailed.bmAddress != null) { + blacklistTracker.foreach(_.updateBlacklistForFetchFailure( + fetchFailed.bmAddress.host, fetchFailed.bmAddress.executorId)) + } + None case ef: ExceptionFailure => diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index 571c6bbb4585d..7ff03c44b0611 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -530,4 +530,59 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M verify(allocationClientMock).killExecutors(Seq("2"), true, true) verify(allocationClientMock).killExecutorsOnHost("hostA") } + + test("fetch failure blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") { + val allocationClientMock = mock[ExecutorAllocationClient] + when(allocationClientMock.killExecutors(any(), any(), any())).thenReturn(Seq("called")) + when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] { + // To avoid a race between blacklisting and killing, it is important that the nodeBlacklist + // is updated before we ask the executor allocation client to kill all the executors + // on a particular host. + override def answer(invocation: InvocationOnMock): Boolean = { + if (blacklist.nodeBlacklist.contains("hostA") == false) { + throw new IllegalStateException("hostA should be on the blacklist") + } + true + } + }) + + conf.set(config.BLACKLIST_FETCH_FAILURE_ENABLED, true) + blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock) + + // Disable auto-kill. Blacklist an executor and make sure killExecutors is not called. + conf.set(config.BLACKLIST_KILL_ENABLED, false) + blacklist.updateBlacklistForFetchFailure("hostA", exec = "1") + + verify(allocationClientMock, never).killExecutors(any(), any(), any()) + verify(allocationClientMock, never).killExecutorsOnHost(any()) + + // Enable auto-kill. Blacklist an executor and make sure killExecutors is called. + conf.set(config.BLACKLIST_KILL_ENABLED, true) + blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock) + clock.advance(1000) + blacklist.updateBlacklistForFetchFailure("hostA", exec = "1") + + verify(allocationClientMock).killExecutors(Seq("1"), true, true) + verify(allocationClientMock, never).killExecutorsOnHost(any()) + + assert(blacklist.executorIdToBlacklistStatus.contains("1")) + assert(blacklist.executorIdToBlacklistStatus("1").node === "hostA") + assert(blacklist.executorIdToBlacklistStatus("1").expiryTime === + 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + assert(blacklist.nextExpiryTime === 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + assert(blacklist.nodeIdToBlacklistExpiryTime.isEmpty) + + // Enable external shuffle service to see if all the executors on this node will be killed. + conf.set(config.SHUFFLE_SERVICE_ENABLED, true) + clock.advance(1000) + blacklist.updateBlacklistForFetchFailure("hostA", exec = "2") + + verify(allocationClientMock, never).killExecutors(Seq("2"), true, true) + verify(allocationClientMock).killExecutorsOnHost("hostA") + + assert(blacklist.nodeIdToBlacklistExpiryTime.contains("hostA")) + assert(blacklist.nodeIdToBlacklistExpiryTime("hostA") === + 2000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + assert(blacklist.nextExpiryTime === 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 8b9d45f734cda..a00337776dadc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -87,7 +87,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B conf.set(config.BLACKLIST_ENABLED, true) sc = new SparkContext(conf) taskScheduler = - new TaskSchedulerImpl(sc, sc.conf.getInt("spark.task.maxFailures", 4), Some(blacklist)) { + new TaskSchedulerImpl(sc, sc.conf.getInt("spark.task.maxFailures", 4)) { override def createTaskSetManager(taskSet: TaskSet, maxFailures: Int): TaskSetManager = { val tsm = super.createTaskSetManager(taskSet, maxFailures) // we need to create a spied tsm just so we can set the TaskSetBlacklist @@ -98,6 +98,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B stageToMockTaskSetBlacklist(taskSet.stageId) = taskSetBlacklist tsmSpy } + + override private[scheduler] lazy val blacklistTrackerOpt = Some(blacklist) } setupHelper() } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index db14c9acfdce5..80fb674725814 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -1140,6 +1140,38 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg .updateBlacklistForFailedTask(anyString(), anyString(), anyInt()) } + test("update application blacklist for shuffle-fetch") { + // Setup a taskset, and fail some one task for fetch failure. + val conf = new SparkConf() + .set(config.BLACKLIST_ENABLED, true) + .set(config.SHUFFLE_SERVICE_ENABLED, true) + .set(config.BLACKLIST_FETCH_FAILURE_ENABLED, true) + sc = new SparkContext("local", "test", conf) + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = FakeTask.createTaskSet(4) + val blacklistTracker = new BlacklistTracker(sc, None) + val tsm = new TaskSetManager(sched, taskSet, 4, Some(blacklistTracker)) + + // make some offers to our taskset, to get tasks we will fail + val taskDescs = Seq( + "exec1" -> "host1", + "exec2" -> "host2" + ).flatMap { case (exec, host) => + // offer each executor twice (simulating 2 cores per executor) + (0 until 2).flatMap{ _ => tsm.resourceOffer(exec, host, TaskLocality.ANY)} + } + assert(taskDescs.size === 4) + + assert(!blacklistTracker.isExecutorBlacklisted(taskDescs(0).executorId)) + assert(!blacklistTracker.isNodeBlacklisted("host1")) + + // Fail the task with fetch failure + tsm.handleFailedTask(taskDescs(0).taskId, TaskState.FAILED, + FetchFailed(BlockManagerId(taskDescs(0).executorId, "host1", 12345), 0, 0, 0, "ignored")) + + assert(blacklistTracker.isNodeBlacklisted("host1")) + } + private def createTaskResult( id: Int, accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = { diff --git a/docs/configuration.md b/docs/configuration.md index f4bec589208be..c8e61537a457c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1479,6 +1479,15 @@ Apart from these, the following properties are also available, and may be useful all of the executors on that node will be killed. + + spark.blacklist.application.fetchFailure.enabled + false + + (Experimental) If set to "true", Spark will blacklist the executor immediately when a fetch + failure happenes. If external shuffle service is enabled, then the whole node will be + blacklisted. + + spark.speculation false From c22810004fb2db249be6477c9801d09b807af851 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 27 Jun 2017 02:35:51 +0800 Subject: [PATCH 023/779] [SPARK-20213][SQL][FOLLOW-UP] introduce SQLExecution.ignoreNestedExecutionId ## What changes were proposed in this pull request? in https://github.com/apache/spark/pull/18064, to work around the nested sql execution id issue, we introduced several internal methods in `Dataset`, like `collectInternal`, `countInternal`, `showInternal`, etc., to avoid nested execution id. However, this approach has poor expansibility. When we hit other nested execution id cases, we may need to add more internal methods in `Dataset`. Our goal is to ignore the nested execution id in some cases, and we can have a better approach to achieve this goal, by introducing `SQLExecution.ignoreNestedExecutionId`. Whenever we find a place which needs to ignore the nested execution, we can just wrap the action with `SQLExecution.ignoreNestedExecutionId`, and this is more expansible than the previous approach. The idea comes from https://github.com/apache/spark/pull/17540/files#diff-ab49028253e599e6e74cc4f4dcb2e3a8R57 by rdblue ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #18419 from cloud-fan/follow. --- .../scala/org/apache/spark/sql/Dataset.scala | 39 ++----------------- .../spark/sql/execution/SQLExecution.scala | 39 +++++++++++++++++-- .../command/AnalyzeTableCommand.scala | 5 ++- .../spark/sql/execution/command/cache.scala | 19 ++++----- .../datasources/csv/CSVDataSource.scala | 6 ++- .../datasources/jdbc/JDBCRelation.scala | 14 +++---- .../sql/execution/streaming/console.scala | 13 +++++-- .../sql/execution/streaming/memory.scala | 33 +++++++++------- 8 files changed, 89 insertions(+), 79 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 6e66e92091ff9..268a37ff5d271 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -246,13 +246,8 @@ class Dataset[T] private[sql]( _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { val numRows = _numRows.max(0) val takeResult = toDF().take(numRows + 1) - showString(takeResult, numRows, truncate, vertical) - } - - private def showString( - dataWithOneMoreRow: Array[Row], numRows: Int, truncate: Int, vertical: Boolean): String = { - val hasMoreData = dataWithOneMoreRow.length > numRows - val data = dataWithOneMoreRow.take(numRows) + val hasMoreData = takeResult.length > numRows + val data = takeResult.take(numRows) lazy val timeZone = DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone) @@ -688,19 +683,6 @@ class Dataset[T] private[sql]( println(showString(numRows, truncate = 0)) } - // An internal version of `show`, which won't set execution id and trigger listeners. - private[sql] def showInternal(_numRows: Int, truncate: Boolean): Unit = { - val numRows = _numRows.max(0) - val takeResult = toDF().takeInternal(numRows + 1) - - if (truncate) { - println(showString(takeResult, numRows, truncate = 20, vertical = false)) - } else { - println(showString(takeResult, numRows, truncate = 0, vertical = false)) - } - } - // scalastyle:on println - /** * Displays the Dataset in a tabular form. For example: * {{{ @@ -2467,11 +2449,6 @@ class Dataset[T] private[sql]( */ def take(n: Int): Array[T] = head(n) - // An internal version of `take`, which won't set execution id and trigger listeners. - private[sql] def takeInternal(n: Int): Array[T] = { - collectFromPlan(limit(n).queryExecution.executedPlan) - } - /** * Returns the first `n` rows in the Dataset as a list. * @@ -2496,11 +2473,6 @@ class Dataset[T] private[sql]( */ def collect(): Array[T] = withAction("collect", queryExecution)(collectFromPlan) - // An internal version of `collect`, which won't set execution id and trigger listeners. - private[sql] def collectInternal(): Array[T] = { - collectFromPlan(queryExecution.executedPlan) - } - /** * Returns a Java list that contains all rows in this Dataset. * @@ -2542,11 +2514,6 @@ class Dataset[T] private[sql]( plan.executeCollect().head.getLong(0) } - // An internal version of `count`, which won't set execution id and trigger listeners. - private[sql] def countInternal(): Long = { - groupBy().count().queryExecution.executedPlan.executeCollect().head.getLong(0) - } - /** * Returns a new Dataset that has exactly `numPartitions` partitions. * @@ -2792,7 +2759,7 @@ class Dataset[T] private[sql]( createTempViewCommand(viewName, replace = true, global = true) } - private[spark] def createTempViewCommand( + private def createTempViewCommand( viewName: String, replace: Boolean, global: Boolean): CreateViewCommand = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index bb206e84325fd..ca8bed5214f87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -29,6 +29,8 @@ object SQLExecution { val EXECUTION_ID_KEY = "spark.sql.execution.id" + private val IGNORE_NESTED_EXECUTION_ID = "spark.sql.execution.ignoreNestedExecutionId" + private val _nextExecutionId = new AtomicLong(0) private def nextExecutionId: Long = _nextExecutionId.getAndIncrement @@ -42,8 +44,11 @@ object SQLExecution { private val testing = sys.props.contains("spark.testing") private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = { + val sc = sparkSession.sparkContext + val isNestedExecution = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null + val hasExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) != null // only throw an exception during tests. a missing execution ID should not fail a job. - if (testing && sparkSession.sparkContext.getLocalProperty(EXECUTION_ID_KEY) == null) { + if (testing && !isNestedExecution && !hasExecutionId) { // Attention testers: when a test fails with this exception, it means that the action that // started execution of a query didn't call withNewExecutionId. The execution ID should be // set by calling withNewExecutionId in the action that begins execution, like @@ -65,7 +70,7 @@ object SQLExecution { val executionId = SQLExecution.nextExecutionId sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) executionIdToQueryExecution.put(executionId, queryExecution) - val r = try { + try { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on // streaming queries would give us call site like "run at :0" @@ -84,7 +89,15 @@ object SQLExecution { executionIdToQueryExecution.remove(executionId) sc.setLocalProperty(EXECUTION_ID_KEY, null) } - r + } else if (sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null) { + // If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore the execution id while evaluating the + // `body`, so that Spark jobs issued in the `body` won't be tracked. + try { + sc.setLocalProperty(EXECUTION_ID_KEY, null) + body + } finally { + sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId) + } } else { // Don't support nested `withNewExecutionId`. This is an example of the nested // `withNewExecutionId`: @@ -100,7 +113,9 @@ object SQLExecution { // all accumulator metrics will be 0. It will confuse people if we show them in Web UI. // // A real case is the `DataFrame.count` method. - throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set") + throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set, please wrap your " + + "action with SQLExecution.ignoreNestedExecutionId if you don't want to track the Spark " + + "jobs issued by the nested execution.") } } @@ -118,4 +133,20 @@ object SQLExecution { sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) } } + + /** + * Wrap an action which may have nested execution id. This method can be used to run an execution + * inside another execution, e.g., `CacheTableCommand` need to call `Dataset.collect`. Note that, + * all Spark jobs issued in the body won't be tracked in UI. + */ + def ignoreNestedExecutionId[T](sparkSession: SparkSession)(body: => T): T = { + val sc = sparkSession.sparkContext + val allowNestedPreviousValue = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) + try { + sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, "true") + body + } finally { + sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, allowNestedPreviousValue) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 06e588f56f1e9..13b8faff844c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTableType} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.internal.SessionState @@ -58,7 +59,9 @@ case class AnalyzeTableCommand( // 2. when total size is changed, `oldRowCount` becomes invalid. // This is to make sure that we only record the right statistics. if (!noscan) { - val newRowCount = sparkSession.table(tableIdentWithDB).countInternal() + val newRowCount = SQLExecution.ignoreNestedExecutionId(sparkSession) { + sparkSession.table(tableIdentWithDB).count() + } if (newRowCount >= 0 && newRowCount != oldRowCount) { newStats = if (newStats.isDefined) { newStats.map(_.copy(rowCount = Some(BigInt(newRowCount)))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 184d0387ebfa9..d36eb7587a3ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SQLExecution case class CacheTableCommand( tableIdent: TableIdentifier, @@ -33,16 +34,16 @@ case class CacheTableCommand( override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq override def run(sparkSession: SparkSession): Seq[Row] = { - plan.foreach { logicalPlan => - Dataset.ofRows(sparkSession, logicalPlan) - .createTempViewCommand(tableIdent.quotedString, replace = false, global = false) - .run(sparkSession) - } - sparkSession.catalog.cacheTable(tableIdent.quotedString) + SQLExecution.ignoreNestedExecutionId(sparkSession) { + plan.foreach { logicalPlan => + Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString) + } + sparkSession.catalog.cacheTable(tableIdent.quotedString) - if (!isLazy) { - // Performs eager caching - sparkSession.table(tableIdent).countInternal() + if (!isLazy) { + // Performs eager caching + sparkSession.table(tableIdent).count() + } } Seq.empty[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index eadc6c94f4b3c..99133bd70989a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -32,6 +32,7 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -144,8 +145,9 @@ object TextInputCSVDataSource extends CSVDataSource { inputPaths: Seq[FileStatus], parsedOptions: CSVOptions): StructType = { val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions) - val maybeFirstLine = - CSVUtils.filterCommentAndEmpty(csv, parsedOptions).takeInternal(1).headOption + val maybeFirstLine = SQLExecution.ignoreNestedExecutionId(sparkSession) { + CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption + } inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index a06f1ce3287e6..b11da7045de22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.Partition import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType @@ -129,14 +130,11 @@ private[sql] case class JDBCRelation( } override def insert(data: DataFrame, overwrite: Boolean): Unit = { - import scala.collection.JavaConverters._ - - val options = jdbcOptions.asProperties.asScala + - ("url" -> jdbcOptions.url, "dbtable" -> jdbcOptions.table) - val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append - - new JdbcRelationProvider().createRelation( - data.sparkSession.sqlContext, mode, options.toMap, data) + SQLExecution.ignoreNestedExecutionId(data.sparkSession) { + data.write + .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) + .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties) + } } override def toString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 9e889ff679450..6fa7c113defaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.types.StructType class ConsoleSink(options: Map[String, String]) extends Sink with Logging { @@ -47,9 +48,11 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging { println(batchIdStr) println("-------------------------------------------") // scalastyle:off println - data.sparkSession.createDataFrame( - data.sparkSession.sparkContext.parallelize(data.collectInternal()), data.schema) - .showInternal(numRowsToShow, isTruncated) + SQLExecution.ignoreNestedExecutionId(data.sparkSession) { + data.sparkSession.createDataFrame( + data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) + .show(numRowsToShow, isTruncated) + } } } @@ -79,7 +82,9 @@ class ConsoleSinkProvider extends StreamSinkProvider // Truncate the displayed data if it is too long, by default it is true val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true) - data.showInternal(numRowsToShow, isTruncated) + SQLExecution.ignoreNestedExecutionId(sqlContext.sparkSession) { + data.show(numRowsToShow, isTruncated) + } ConsoleRelation(sqlContext, data) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index a5dac469f85b6..198a342582804 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -193,21 +194,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi } if (notCommitted) { logDebug(s"Committing batch $batchId to $this") - outputMode match { - case Append | Update => - val rows = AddedData(batchId, data.collectInternal()) - synchronized { batches += rows } - - case Complete => - val rows = AddedData(batchId, data.collectInternal()) - synchronized { - batches.clear() - batches += rows - } - - case _ => - throw new IllegalArgumentException( - s"Output mode $outputMode is not supported by MemorySink") + SQLExecution.ignoreNestedExecutionId(data.sparkSession) { + outputMode match { + case Append | Update => + val rows = AddedData(batchId, data.collect()) + synchronized { batches += rows } + + case Complete => + val rows = AddedData(batchId, data.collect()) + synchronized { + batches.clear() + batches += rows + } + + case _ => + throw new IllegalArgumentException( + s"Output mode $outputMode is not supported by MemorySink") + } } } else { logDebug(s"Skipping already committed batch: $batchId") From 3cb3ccce120fa9f0273133912624b877b42d95fd Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Tue, 27 Jun 2017 17:24:46 +0800 Subject: [PATCH 024/779] [SPARK-21196] Split codegen info of query plan into sequence codegen info of query plan can be very long. In debugging console / web page, it would be more readable if the subtrees and corresponding codegen are split into sequence. Example: ```java codegenStringSeq(sql("select 1").queryExecution.executedPlan) ``` The example will return Seq[(String, String)] of length 1, containing the subtree as string and the corresponding generated code. The subtree as string: > (*Project [1 AS 1#0] > +- Scan OneRowRelation[] The generated code: ```java /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } /* 004 */ /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private UnsafeRow project_result; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder project_holder; /* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter project_rowWriter; /* 012 */ /* 013 */ public GeneratedIterator(Object[] references) { /* 014 */ this.references = references; /* 015 */ } /* 016 */ /* 017 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 018 */ partitionIndex = index; /* 019 */ this.inputs = inputs; /* 020 */ inputadapter_input = inputs[0]; /* 021 */ project_result = new UnsafeRow(1); /* 022 */ project_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(project_result, 0); /* 023 */ project_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(project_holder, 1); /* 024 */ /* 025 */ } /* 026 */ /* 027 */ protected void processNext() throws java.io.IOException { /* 028 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 029 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 030 */ project_rowWriter.write(0, 1); /* 031 */ append(project_result); /* 032 */ if (shouldStop()) return; /* 033 */ } /* 034 */ } /* 035 */ /* 036 */ } ``` ## What changes were proposed in this pull request? add method codegenToSeq: split codegen info of query plan into sequence ## How was this patch tested? unit test cloud-fan gatorsmile Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Wang Gengliang Closes #18409 from gengliangwang/codegen. --- .../spark/sql/execution/QueryExecution.scala | 9 +++++ .../spark/sql/execution/debug/package.scala | 35 ++++++++++++++----- .../sql/execution/debug/DebuggingSuite.scala | 7 ++++ 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index c7cac332a0377..9533144214a10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -245,5 +245,14 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { println(org.apache.spark.sql.execution.debug.codegenString(executedPlan)) // scalastyle:on println } + + /** + * Get WholeStageCodegenExec subtrees and the codegen in a query plan + * + * @return Sequence of WholeStageCodegen subtrees and corresponding codegen + */ + def codegenToSeq(): Seq[(String, String)] = { + org.apache.spark.sql.execution.debug.codegenStringSeq(executedPlan) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 0395c43ba2cbc..a717cbd4a7df9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -50,7 +50,31 @@ package object debug { // scalastyle:on println } + /** + * Get WholeStageCodegenExec subtrees and the codegen in a query plan into one String + * + * @param plan the query plan for codegen + * @return single String containing all WholeStageCodegen subtrees and corresponding codegen + */ def codegenString(plan: SparkPlan): String = { + val codegenSeq = codegenStringSeq(plan) + var output = s"Found ${codegenSeq.size} WholeStageCodegen subtrees.\n" + for (((subtree, code), i) <- codegenSeq.zipWithIndex) { + output += s"== Subtree ${i + 1} / ${codegenSeq.size} ==\n" + output += subtree + output += "\nGenerated code:\n" + output += s"${code}\n" + } + output + } + + /** + * Get WholeStageCodegenExec subtrees and the codegen in a query plan + * + * @param plan the query plan for codegen + * @return Sequence of WholeStageCodegen subtrees and corresponding codegen + */ + def codegenStringSeq(plan: SparkPlan): Seq[(String, String)] = { val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]() plan transform { case s: WholeStageCodegenExec => @@ -58,15 +82,10 @@ package object debug { s case s => s } - var output = s"Found ${codegenSubtrees.size} WholeStageCodegen subtrees.\n" - for ((s, i) <- codegenSubtrees.toSeq.zipWithIndex) { - output += s"== Subtree ${i + 1} / ${codegenSubtrees.size} ==\n" - output += s - output += "\nGenerated code:\n" - val (_, source) = s.doCodeGen() - output += s"${CodeFormatter.format(source)}\n" + codegenSubtrees.toSeq.map { subtree => + val (_, source) = subtree.doCodeGen() + (subtree.toString, CodeFormatter.format(source)) } - output } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 4fc52c99fbeeb..adcaf2d76519f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -38,4 +38,11 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext { assert(res.contains("Subtree 2 / 2")) assert(res.contains("Object[]")) } + + test("debugCodegenStringSeq") { + val res = codegenStringSeq(spark.range(10).groupBy("id").count().queryExecution.executedPlan) + assert(res.length == 2) + assert(res.forall{ case (subtree, code) => + subtree.contains("Range") && code.contains("Object[]")}) + } } From b32bd005e46443bbd487b7a1f1078578c8f4c181 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 27 Jun 2017 13:14:12 +0100 Subject: [PATCH 025/779] [INFRA] Close stale PRs ## What changes were proposed in this pull request? This PR proposes to close stale PRs, mostly the same instances with https://github.com/apache/spark/pull/18017 I believe the author in #14807 removed his account. Closes #7075 Closes #8927 Closes #9202 Closes #9366 Closes #10861 Closes #11420 Closes #12356 Closes #13028 Closes #13506 Closes #14191 Closes #14198 Closes #14330 Closes #14807 Closes #15839 Closes #16225 Closes #16685 Closes #16692 Closes #16995 Closes #17181 Closes #17211 Closes #17235 Closes #17237 Closes #17248 Closes #17341 Closes #17708 Closes #17716 Closes #17721 Closes #17937 Added: Closes #14739 Closes #17139 Closes #17445 Closes #18042 Closes #18359 Added: Closes #16450 Closes #16525 Closes #17738 Added: Closes #16458 Closes #16508 Closes #17714 Added: Closes #17830 Closes #14742 ## How was this patch tested? N/A Author: hyukjinkwon Closes #18417 from HyukjinKwon/close-stale-pr. From fd8c931a30a084ee981b75aa469fc97dda6cfaa9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 28 Jun 2017 00:57:05 +0800 Subject: [PATCH 026/779] [SPARK-19104][SQL] Lambda variables in ExternalMapToCatalyst should be global ## What changes were proposed in this pull request? The issue happens in `ExternalMapToCatalyst`. For example, the following codes create `ExternalMapToCatalyst` to convert Scala Map to catalyst map format. val data = Seq.tabulate(10)(i => NestedData(1, Map("key" -> InnerData("name", i + 100)))) val ds = spark.createDataset(data) The `valueConverter` in `ExternalMapToCatalyst` looks like: if (isnull(lambdavariable(ExternalMapToCatalyst_value52, ExternalMapToCatalyst_value_isNull52, ObjectType(class org.apache.spark.sql.InnerData), true))) null else named_struct(name, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(lambdavariable(ExternalMapToCatalyst_value52, ExternalMapToCatalyst_value_isNull52, ObjectType(class org.apache.spark.sql.InnerData), true)).name, true), value, assertnotnull(lambdavariable(ExternalMapToCatalyst_value52, ExternalMapToCatalyst_value_isNull52, ObjectType(class org.apache.spark.sql.InnerData), true)).value) There is a `CreateNamedStruct` expression (`named_struct`) to create a row of `InnerData.name` and `InnerData.value` that are referred by `ExternalMapToCatalyst_value52`. Because `ExternalMapToCatalyst_value52` are local variable, when `CreateNamedStruct` splits expressions to individual functions, the local variable can't be accessed anymore. ## How was this patch tested? Jenkins tests. Author: Liang-Chi Hsieh Closes #18418 from viirya/SPARK-19104. --- .../catalyst/expressions/objects/objects.scala | 18 ++++++++++++------ .../spark/sql/DatasetPrimitiveSuite.scala | 8 ++++++++ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 073993cccdf8a..4b651836ff4d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -911,6 +911,12 @@ case class ExternalMapToCatalyst private( val entry = ctx.freshName("entry") val entries = ctx.freshName("entries") + val keyElementJavaType = ctx.javaType(keyType) + val valueElementJavaType = ctx.javaType(valueType) + ctx.addMutableState(keyElementJavaType, key, "") + ctx.addMutableState("boolean", valueIsNull, "") + ctx.addMutableState(valueElementJavaType, value, "") + val (defineEntries, defineKeyValue) = child.dataType match { case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => val javaIteratorCls = classOf[java.util.Iterator[_]].getName @@ -922,8 +928,8 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next(); - ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry.getKey(); - ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry.getValue(); + $key = (${ctx.boxedType(keyType)}) $entry.getKey(); + $value = (${ctx.boxedType(valueType)}) $entry.getValue(); """ defineEntries -> defineKeyValue @@ -937,17 +943,17 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next(); - ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry._1(); - ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry._2(); + $key = (${ctx.boxedType(keyType)}) $entry._1(); + $value = (${ctx.boxedType(valueType)}) $entry._2(); """ defineEntries -> defineKeyValue } val valueNullCheck = if (ctx.isPrimitiveType(valueType)) { - s"boolean $valueIsNull = false;" + s"$valueIsNull = false;" } else { - s"boolean $valueIsNull = $value == null;" + s"$valueIsNull = $value == null;" } val arrayCls = classOf[GenericArrayData].getName diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 4126660b5d102..a6847dcfbffc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -39,6 +39,9 @@ case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass) case class ComplexMapClass(map: MapClass, lhmap: LHMapClass) +case class InnerData(name: String, value: Int) +case class NestedData(id: Int, param: Map[String, InnerData]) + package object packageobject { case class PackageClass(value: Int) } @@ -354,4 +357,9 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1)) } + test("SPARK-19104: Lambda variables in ExternalMapToCatalyst should be global") { + val data = Seq.tabulate(10)(i => NestedData(1, Map("key" -> InnerData("name", i + 100)))) + val ds = spark.createDataset(data) + checkDataset(ds, data: _*) + } } From 2d686a19e341a31d976aa42228b7589f87dfd6c2 Mon Sep 17 00:00:00 2001 From: Eric Vandenberg Date: Wed, 28 Jun 2017 09:26:33 +0800 Subject: [PATCH 027/779] [SPARK-21155][WEBUI] Add (? running tasks) into Spark UI progress ## What changes were proposed in this pull request? Add metric on number of running tasks to status bar on Jobs / Active Jobs. ## How was this patch tested? Run a long running (1 minute) query in spark-shell and use localhost:4040 web UI to observe progress. See jira for screen snapshot. Author: Eric Vandenberg Closes #18369 from ericvandenbergfb/runningTasks. --- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 2610f673d27f6..ba798df13c95d 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -356,6 +356,7 @@ private[spark] object UIUtils extends Logging {
{completed}/{total} + { if (failed == 0 && skipped == 0 && started > 0) s"($started running)" } { if (failed > 0) s"($failed failed)" } { if (skipped > 0) s"($skipped skipped)" } { reasonToNumKilled.toSeq.sortBy(-_._2).map { From e793bf248bc3c71b9664f26377bce06b0ffa97a7 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Tue, 27 Jun 2017 23:15:45 -0700 Subject: [PATCH 028/779] [SPARK-20889][SPARKR] Grouped documentation for MATH column methods ## What changes were proposed in this pull request? Grouped documentation for math column methods. Author: actuaryzhang Author: Wayne Zhang Closes #18371 from actuaryzhang/sparkRDocMath. --- R/pkg/R/functions.R | 619 +++++++++++++++----------------------------- R/pkg/R/generics.R | 48 ++-- 2 files changed, 241 insertions(+), 426 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 31028585aaa13..23ccdf941a8c7 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -86,6 +86,31 @@ NULL #' df <- createDataFrame(data.frame(time = as.POSIXct(dts), y = y))} NULL +#' Math functions for Column operations +#' +#' Math functions defined for \code{Column}. +#' +#' @param x Column to compute on. In \code{shiftLeft}, \code{shiftRight} and \code{shiftRightUnsigned}, +#' this is the number of bits to shift. +#' @param y Column to compute on. +#' @param ... additional argument(s). +#' @name column_math_functions +#' @rdname column_math_functions +#' @family math functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' tmp <- mutate(df, v1 = log(df$mpg), v2 = cbrt(df$disp), +#' v3 = bround(df$wt, 1), v4 = bin(df$cyl), +#' v5 = hex(df$wt), v6 = toDegrees(df$gear), +#' v7 = atan2(df$cyl, df$am), v8 = hypot(df$cyl, df$am), +#' v9 = pmod(df$hp, df$cyl), v10 = shiftLeft(df$disp, 1), +#' v11 = conv(df$hp, 10, 16), v12 = sign(df$vs - 0.5), +#' v13 = sqrt(df$disp), v14 = ceil(df$wt)) +#' head(tmp)} +NULL + #' lit #' #' A new \linkS4class{Column} is created to represent the literal value. @@ -112,18 +137,12 @@ setMethod("lit", signature("ANY"), column(jc) }) -#' abs -#' -#' Computes the absolute value. -#' -#' @param x Column to compute on. +#' @details +#' \code{abs}: Computes the absolute value. #' -#' @rdname abs -#' @name abs -#' @family non-aggregate functions +#' @rdname column_math_functions #' @export -#' @examples \dontrun{abs(df$c)} -#' @aliases abs,Column-method +#' @aliases abs abs,Column-method #' @note abs since 1.5.0 setMethod("abs", signature(x = "Column"), @@ -132,19 +151,13 @@ setMethod("abs", column(jc) }) -#' acos -#' -#' Computes the cosine inverse of the given value; the returned angle is in the range -#' 0.0 through pi. -#' -#' @param x Column to compute on. +#' @details +#' \code{acos}: Computes the cosine inverse of the given value; the returned angle is in +#' the range 0.0 through pi. #' -#' @rdname acos -#' @name acos -#' @family math functions +#' @rdname column_math_functions #' @export -#' @examples \dontrun{acos(df$c)} -#' @aliases acos,Column-method +#' @aliases acos acos,Column-method #' @note acos since 1.5.0 setMethod("acos", signature(x = "Column"), @@ -196,19 +209,13 @@ setMethod("ascii", column(jc) }) -#' asin -#' -#' Computes the sine inverse of the given value; the returned angle is in the range -#' -pi/2 through pi/2. -#' -#' @param x Column to compute on. +#' @details +#' \code{asin}: Computes the sine inverse of the given value; the returned angle is in +#' the range -pi/2 through pi/2. #' -#' @rdname asin -#' @name asin -#' @family math functions +#' @rdname column_math_functions #' @export -#' @aliases asin,Column-method -#' @examples \dontrun{asin(df$c)} +#' @aliases asin asin,Column-method #' @note asin since 1.5.0 setMethod("asin", signature(x = "Column"), @@ -217,18 +224,12 @@ setMethod("asin", column(jc) }) -#' atan -#' -#' Computes the tangent inverse of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{atan}: Computes the tangent inverse of the given value. #' -#' @rdname atan -#' @name atan -#' @family math functions +#' @rdname column_math_functions #' @export -#' @aliases atan,Column-method -#' @examples \dontrun{atan(df$c)} +#' @aliases atan atan,Column-method #' @note atan since 1.5.0 setMethod("atan", signature(x = "Column"), @@ -276,19 +277,13 @@ setMethod("base64", column(jc) }) -#' bin -#' -#' An expression that returns the string representation of the binary value of the given long -#' column. For example, bin("12") returns "1100". -#' -#' @param x Column to compute on. +#' @details +#' \code{bin}: An expression that returns the string representation of the binary value +#' of the given long column. For example, bin("12") returns "1100". #' -#' @rdname bin -#' @name bin -#' @family math functions +#' @rdname column_math_functions #' @export -#' @aliases bin,Column-method -#' @examples \dontrun{bin(df$c)} +#' @aliases bin bin,Column-method #' @note bin since 1.5.0 setMethod("bin", signature(x = "Column"), @@ -317,18 +312,12 @@ setMethod("bitwiseNOT", column(jc) }) -#' cbrt -#' -#' Computes the cube-root of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{cbrt}: Computes the cube-root of the given value. #' -#' @rdname cbrt -#' @name cbrt -#' @family math functions +#' @rdname column_math_functions #' @export -#' @aliases cbrt,Column-method -#' @examples \dontrun{cbrt(df$c)} +#' @aliases cbrt cbrt,Column-method #' @note cbrt since 1.4.0 setMethod("cbrt", signature(x = "Column"), @@ -337,18 +326,12 @@ setMethod("cbrt", column(jc) }) -#' Computes the ceiling of the given value -#' -#' Computes the ceiling of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{ceil}: Computes the ceiling of the given value. #' -#' @rdname ceil -#' @name ceil -#' @family math functions +#' @rdname column_math_functions #' @export -#' @aliases ceil,Column-method -#' @examples \dontrun{ceil(df$c)} +#' @aliases ceil ceil,Column-method #' @note ceil since 1.5.0 setMethod("ceil", signature(x = "Column"), @@ -357,6 +340,19 @@ setMethod("ceil", column(jc) }) +#' @details +#' \code{ceiling}: Alias for \code{ceil}. +#' +#' @rdname column_math_functions +#' @aliases ceiling ceiling,Column-method +#' @export +#' @note ceiling since 1.5.0 +setMethod("ceiling", + signature(x = "Column"), + function(x) { + ceil(x) + }) + #' Returns the first column that is not NA #' #' Returns the first column that is not NA, or NA if all inputs are. @@ -405,6 +401,7 @@ setMethod("column", function(x) { col(x) }) + #' corr #' #' Computes the Pearson Correlation Coefficient for two Columns. @@ -493,18 +490,12 @@ setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr column(jc) }) -#' cos -#' -#' Computes the cosine of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{cos}: Computes the cosine of the given value. #' -#' @rdname cos -#' @name cos -#' @family math functions -#' @aliases cos,Column-method +#' @rdname column_math_functions +#' @aliases cos cos,Column-method #' @export -#' @examples \dontrun{cos(df$c)} #' @note cos since 1.5.0 setMethod("cos", signature(x = "Column"), @@ -513,18 +504,12 @@ setMethod("cos", column(jc) }) -#' cosh -#' -#' Computes the hyperbolic cosine of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{cosh}: Computes the hyperbolic cosine of the given value. #' -#' @rdname cosh -#' @name cosh -#' @family math functions -#' @aliases cosh,Column-method +#' @rdname column_math_functions +#' @aliases cosh cosh,Column-method #' @export -#' @examples \dontrun{cosh(df$c)} #' @note cosh since 1.5.0 setMethod("cosh", signature(x = "Column"), @@ -679,18 +664,12 @@ setMethod("encode", column(jc) }) -#' exp -#' -#' Computes the exponential of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{exp}: Computes the exponential of the given value. #' -#' @rdname exp -#' @name exp -#' @family math functions -#' @aliases exp,Column-method +#' @rdname column_math_functions +#' @aliases exp exp,Column-method #' @export -#' @examples \dontrun{exp(df$c)} #' @note exp since 1.5.0 setMethod("exp", signature(x = "Column"), @@ -699,18 +678,12 @@ setMethod("exp", column(jc) }) -#' expm1 -#' -#' Computes the exponential of the given value minus one. -#' -#' @param x Column to compute on. +#' @details +#' \code{expm1}: Computes the exponential of the given value minus one. #' -#' @rdname expm1 -#' @name expm1 -#' @aliases expm1,Column-method -#' @family math functions +#' @rdname column_math_functions +#' @aliases expm1 expm1,Column-method #' @export -#' @examples \dontrun{expm1(df$c)} #' @note expm1 since 1.5.0 setMethod("expm1", signature(x = "Column"), @@ -719,18 +692,12 @@ setMethod("expm1", column(jc) }) -#' factorial -#' -#' Computes the factorial of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{factorial}: Computes the factorial of the given value. #' -#' @rdname factorial -#' @name factorial -#' @aliases factorial,Column-method -#' @family math functions +#' @rdname column_math_functions +#' @aliases factorial factorial,Column-method #' @export -#' @examples \dontrun{factorial(df$c)} #' @note factorial since 1.5.0 setMethod("factorial", signature(x = "Column"), @@ -772,18 +739,12 @@ setMethod("first", column(jc) }) -#' floor -#' -#' Computes the floor of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{floor}: Computes the floor of the given value. #' -#' @rdname floor -#' @name floor -#' @aliases floor,Column-method -#' @family math functions +#' @rdname column_math_functions +#' @aliases floor floor,Column-method #' @export -#' @examples \dontrun{floor(df$c)} #' @note floor since 1.5.0 setMethod("floor", signature(x = "Column"), @@ -792,18 +753,12 @@ setMethod("floor", column(jc) }) -#' hex -#' -#' Computes hex value of the given column. -#' -#' @param x Column to compute on. +#' @details +#' \code{hex}: Computes hex value of the given column. #' -#' @rdname hex -#' @name hex -#' @family math functions -#' @aliases hex,Column-method +#' @rdname column_math_functions +#' @aliases hex hex,Column-method #' @export -#' @examples \dontrun{hex(df$c)} #' @note hex since 1.5.0 setMethod("hex", signature(x = "Column"), @@ -983,18 +938,12 @@ setMethod("length", column(jc) }) -#' log -#' -#' Computes the natural logarithm of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{log}: Computes the natural logarithm of the given value. #' -#' @rdname log -#' @name log -#' @aliases log,Column-method -#' @family math functions +#' @rdname column_math_functions +#' @aliases log log,Column-method #' @export -#' @examples \dontrun{log(df$c)} #' @note log since 1.5.0 setMethod("log", signature(x = "Column"), @@ -1003,18 +952,12 @@ setMethod("log", column(jc) }) -#' log10 -#' -#' Computes the logarithm of the given value in base 10. -#' -#' @param x Column to compute on. +#' @details +#' \code{log10}: Computes the logarithm of the given value in base 10. #' -#' @rdname log10 -#' @name log10 -#' @family math functions -#' @aliases log10,Column-method +#' @rdname column_math_functions +#' @aliases log10 log10,Column-method #' @export -#' @examples \dontrun{log10(df$c)} #' @note log10 since 1.5.0 setMethod("log10", signature(x = "Column"), @@ -1023,18 +966,12 @@ setMethod("log10", column(jc) }) -#' log1p -#' -#' Computes the natural logarithm of the given value plus one. -#' -#' @param x Column to compute on. +#' @details +#' \code{log1p}: Computes the natural logarithm of the given value plus one. #' -#' @rdname log1p -#' @name log1p -#' @family math functions -#' @aliases log1p,Column-method +#' @rdname column_math_functions +#' @aliases log1p log1p,Column-method #' @export -#' @examples \dontrun{log1p(df$c)} #' @note log1p since 1.5.0 setMethod("log1p", signature(x = "Column"), @@ -1043,18 +980,12 @@ setMethod("log1p", column(jc) }) -#' log2 -#' -#' Computes the logarithm of the given column in base 2. -#' -#' @param x Column to compute on. +#' @details +#' \code{log2}: Computes the logarithm of the given column in base 2. #' -#' @rdname log2 -#' @name log2 -#' @family math functions -#' @aliases log2,Column-method +#' @rdname column_math_functions +#' @aliases log2 log2,Column-method #' @export -#' @examples \dontrun{log2(df$c)} #' @note log2 since 1.5.0 setMethod("log2", signature(x = "Column"), @@ -1287,19 +1218,13 @@ setMethod("reverse", column(jc) }) -#' rint -#' -#' Returns the double value that is closest in value to the argument and +#' @details +#' \code{rint}: Returns the double value that is closest in value to the argument and #' is equal to a mathematical integer. #' -#' @param x Column to compute on. -#' -#' @rdname rint -#' @name rint -#' @family math functions -#' @aliases rint,Column-method +#' @rdname column_math_functions +#' @aliases rint rint,Column-method #' @export -#' @examples \dontrun{rint(df$c)} #' @note rint since 1.5.0 setMethod("rint", signature(x = "Column"), @@ -1308,18 +1233,13 @@ setMethod("rint", column(jc) }) -#' round -#' -#' Returns the value of the column \code{e} rounded to 0 decimal places using HALF_UP rounding mode. -#' -#' @param x Column to compute on. +#' @details +#' \code{round}: Returns the value of the column rounded to 0 decimal places +#' using HALF_UP rounding mode. #' -#' @rdname round -#' @name round -#' @family math functions -#' @aliases round,Column-method +#' @rdname column_math_functions +#' @aliases round round,Column-method #' @export -#' @examples \dontrun{round(df$c)} #' @note round since 1.5.0 setMethod("round", signature(x = "Column"), @@ -1328,24 +1248,18 @@ setMethod("round", column(jc) }) -#' bround -#' -#' Returns the value of the column \code{e} rounded to \code{scale} decimal places using HALF_EVEN rounding -#' mode if \code{scale} >= 0 or at integer part when \code{scale} < 0. +#' @details +#' \code{bround}: Returns the value of the column \code{e} rounded to \code{scale} decimal places +#' using HALF_EVEN rounding mode if \code{scale} >= 0 or at integer part when \code{scale} < 0. #' Also known as Gaussian rounding or bankers' rounding that rounds to the nearest even number. #' bround(2.5, 0) = 2, bround(3.5, 0) = 4. #' -#' @param x Column to compute on. #' @param scale round to \code{scale} digits to the right of the decimal point when \code{scale} > 0, #' the nearest even number when \code{scale} = 0, and \code{scale} digits to the left #' of the decimal point when \code{scale} < 0. -#' @param ... further arguments to be passed to or from other methods. -#' @rdname bround -#' @name bround -#' @family math functions -#' @aliases bround,Column-method +#' @rdname column_math_functions +#' @aliases bround bround,Column-method #' @export -#' @examples \dontrun{bround(df$c, 0)} #' @note bround since 2.0.0 setMethod("bround", signature(x = "Column"), @@ -1354,7 +1268,6 @@ setMethod("bround", column(jc) }) - #' rtrim #' #' Trim the spaces from right end for the specified string value. @@ -1375,7 +1288,6 @@ setMethod("rtrim", column(jc) }) - #' @details #' \code{sd}: Alias for \code{stddev_samp}. #' @@ -1429,18 +1341,12 @@ setMethod("sha1", column(jc) }) -#' signum -#' -#' Computes the signum of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{signum}: Computes the signum of the given value. #' -#' @rdname sign -#' @name signum -#' @aliases signum,Column-method -#' @family math functions +#' @rdname column_math_functions +#' @aliases signum signum,Column-method #' @export -#' @examples \dontrun{signum(df$c)} #' @note signum since 1.5.0 setMethod("signum", signature(x = "Column"), @@ -1449,18 +1355,24 @@ setMethod("signum", column(jc) }) -#' sin -#' -#' Computes the sine of the given value. +#' @details +#' \code{sign}: Alias for \code{signum}. #' -#' @param x Column to compute on. +#' @rdname column_math_functions +#' @aliases sign sign,Column-method +#' @export +#' @note sign since 1.5.0 +setMethod("sign", signature(x = "Column"), + function(x) { + signum(x) + }) + +#' @details +#' \code{sin}: Computes the sine of the given value. #' -#' @rdname sin -#' @name sin -#' @family math functions -#' @aliases sin,Column-method +#' @rdname column_math_functions +#' @aliases sin sin,Column-method #' @export -#' @examples \dontrun{sin(df$c)} #' @note sin since 1.5.0 setMethod("sin", signature(x = "Column"), @@ -1469,18 +1381,12 @@ setMethod("sin", column(jc) }) -#' sinh -#' -#' Computes the hyperbolic sine of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{sinh}: Computes the hyperbolic sine of the given value. #' -#' @rdname sinh -#' @name sinh -#' @family math functions -#' @aliases sinh,Column-method +#' @rdname column_math_functions +#' @aliases sinh sinh,Column-method #' @export -#' @examples \dontrun{sinh(df$c)} #' @note sinh since 1.5.0 setMethod("sinh", signature(x = "Column"), @@ -1616,18 +1522,12 @@ setMethod("struct", column(jc) }) -#' sqrt -#' -#' Computes the square root of the specified float value. -#' -#' @param x Column to compute on. +#' @details +#' \code{sqrt}: Computes the square root of the specified float value. #' -#' @rdname sqrt -#' @name sqrt -#' @family math functions -#' @aliases sqrt,Column-method +#' @rdname column_math_functions +#' @aliases sqrt sqrt,Column-method #' @export -#' @examples \dontrun{sqrt(df$c)} #' @note sqrt since 1.5.0 setMethod("sqrt", signature(x = "Column"), @@ -1669,18 +1569,12 @@ setMethod("sumDistinct", column(jc) }) -#' tan -#' -#' Computes the tangent of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{tan}: Computes the tangent of the given value. #' -#' @rdname tan -#' @name tan -#' @family math functions -#' @aliases tan,Column-method +#' @rdname column_math_functions +#' @aliases tan tan,Column-method #' @export -#' @examples \dontrun{tan(df$c)} #' @note tan since 1.5.0 setMethod("tan", signature(x = "Column"), @@ -1689,18 +1583,12 @@ setMethod("tan", column(jc) }) -#' tanh -#' -#' Computes the hyperbolic tangent of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{tanh}: Computes the hyperbolic tangent of the given value. #' -#' @rdname tanh -#' @name tanh -#' @family math functions -#' @aliases tanh,Column-method +#' @rdname column_math_functions +#' @aliases tanh tanh,Column-method #' @export -#' @examples \dontrun{tanh(df$c)} #' @note tanh since 1.5.0 setMethod("tanh", signature(x = "Column"), @@ -1709,18 +1597,13 @@ setMethod("tanh", column(jc) }) -#' toDegrees -#' -#' Converts an angle measured in radians to an approximately equivalent angle measured in degrees. -#' -#' @param x Column to compute on. +#' @details +#' \code{toDegrees}: Converts an angle measured in radians to an approximately equivalent angle +#' measured in degrees. #' -#' @rdname toDegrees -#' @name toDegrees -#' @family math functions -#' @aliases toDegrees,Column-method +#' @rdname column_math_functions +#' @aliases toDegrees toDegrees,Column-method #' @export -#' @examples \dontrun{toDegrees(df$c)} #' @note toDegrees since 1.4.0 setMethod("toDegrees", signature(x = "Column"), @@ -1729,18 +1612,13 @@ setMethod("toDegrees", column(jc) }) -#' toRadians -#' -#' Converts an angle measured in degrees to an approximately equivalent angle measured in radians. -#' -#' @param x Column to compute on. +#' @details +#' \code{toRadians}: Converts an angle measured in degrees to an approximately equivalent angle +#' measured in radians. #' -#' @rdname toRadians -#' @name toRadians -#' @family math functions -#' @aliases toRadians,Column-method +#' @rdname column_math_functions +#' @aliases toRadians toRadians,Column-method #' @export -#' @examples \dontrun{toRadians(df$c)} #' @note toRadians since 1.4.0 setMethod("toRadians", signature(x = "Column"), @@ -1894,19 +1772,13 @@ setMethod("unbase64", column(jc) }) -#' unhex -#' -#' Inverse of hex. Interprets each pair of characters as a hexadecimal number +#' @details +#' \code{unhex}: Inverse of hex. Interprets each pair of characters as a hexadecimal number #' and converts to the byte representation of number. #' -#' @param x Column to compute on. -#' -#' @rdname unhex -#' @name unhex -#' @family math functions -#' @aliases unhex,Column-method +#' @rdname column_math_functions +#' @aliases unhex unhex,Column-method #' @export -#' @examples \dontrun{unhex(df$c)} #' @note unhex since 1.5.0 setMethod("unhex", signature(x = "Column"), @@ -2020,20 +1892,13 @@ setMethod("year", column(jc) }) -#' atan2 -#' -#' Returns the angle theta from the conversion of rectangular coordinates (x, y) to -#' polar coordinates (r, theta). -# -#' @param x Column to compute on. -#' @param y Column to compute on. +#' @details +#' \code{atan2}: Returns the angle theta from the conversion of rectangular coordinates +#' (x, y) to polar coordinates (r, theta). #' -#' @rdname atan2 -#' @name atan2 -#' @family math functions -#' @aliases atan2,Column-method +#' @rdname column_math_functions +#' @aliases atan2 atan2,Column-method #' @export -#' @examples \dontrun{atan2(df$c, x)} #' @note atan2 since 1.5.0 setMethod("atan2", signature(y = "Column"), function(y, x) { @@ -2068,19 +1933,12 @@ setMethod("datediff", signature(y = "Column"), column(jc) }) -#' hypot -#' -#' Computes "sqrt(a^2 + b^2)" without intermediate overflow or underflow. -# -#' @param x Column to compute on. -#' @param y Column to compute on. +#' @details +#' \code{hypot}: Computes "sqrt(a^2 + b^2)" without intermediate overflow or underflow. #' -#' @rdname hypot -#' @name hypot -#' @family math functions -#' @aliases hypot,Column-method +#' @rdname column_math_functions +#' @aliases hypot hypot,Column-method #' @export -#' @examples \dontrun{hypot(df$c, x)} #' @note hypot since 1.4.0 setMethod("hypot", signature(y = "Column"), function(y, x) { @@ -2154,20 +2012,13 @@ setMethod("nanvl", signature(y = "Column"), column(jc) }) -#' pmod -#' -#' Returns the positive value of dividend mod divisor. -#' -#' @param x divisor Column. -#' @param y dividend Column. +#' @details +#' \code{pmod}: Returns the positive value of dividend mod divisor. +#' Column \code{x} is divisor column, and column \code{y} is the dividend column. #' -#' @rdname pmod -#' @name pmod -#' @docType methods -#' @family math functions -#' @aliases pmod,Column-method +#' @rdname column_math_functions +#' @aliases pmod pmod,Column-method #' @export -#' @examples \dontrun{pmod(df$c, x)} #' @note pmod since 1.5.0 setMethod("pmod", signature(y = "Column"), function(y, x) { @@ -2290,31 +2141,6 @@ setMethod("least", column(jc) }) -#' @rdname ceil -#' -#' @name ceiling -#' @aliases ceiling,Column-method -#' @export -#' @examples \dontrun{ceiling(df$c)} -#' @note ceiling since 1.5.0 -setMethod("ceiling", - signature(x = "Column"), - function(x) { - ceil(x) - }) - -#' @rdname sign -#' -#' @name sign -#' @aliases sign,Column-method -#' @export -#' @examples \dontrun{sign(df$c)} -#' @note sign since 1.5.0 -setMethod("sign", signature(x = "Column"), - function(x) { - signum(x) - }) - #' @details #' \code{n_distinct}: Returns the number of distinct items in a group. #' @@ -2564,20 +2390,13 @@ setMethod("sha2", signature(y = "Column", x = "numeric"), column(jc) }) -#' shiftLeft -#' -#' Shift the given value numBits left. If the given value is a long value, this function -#' will return a long value else it will return an integer value. -#' -#' @param y column to compute on. -#' @param x number of bits to shift. +#' @details +#' \code{shiftLeft}: Shifts the given value numBits left. If the given value is a long value, +#' this function will return a long value else it will return an integer value. #' -#' @family math functions -#' @rdname shiftLeft -#' @name shiftLeft -#' @aliases shiftLeft,Column,numeric-method +#' @rdname column_math_functions +#' @aliases shiftLeft shiftLeft,Column,numeric-method #' @export -#' @examples \dontrun{shiftLeft(df$c, 1)} #' @note shiftLeft since 1.5.0 setMethod("shiftLeft", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2587,20 +2406,13 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), column(jc) }) -#' shiftRight -#' -#' (Signed) shift the given value numBits right. If the given value is a long value, it will return -#' a long value else it will return an integer value. -#' -#' @param y column to compute on. -#' @param x number of bits to shift. +#' @details +#' \code{shiftRight}: (Signed) shifts the given value numBits right. If the given value is a long value, +#' it will return a long value else it will return an integer value. #' -#' @family math functions -#' @rdname shiftRight -#' @name shiftRight -#' @aliases shiftRight,Column,numeric-method +#' @rdname column_math_functions +#' @aliases shiftRight shiftRight,Column,numeric-method #' @export -#' @examples \dontrun{shiftRight(df$c, 1)} #' @note shiftRight since 1.5.0 setMethod("shiftRight", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2610,20 +2422,13 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), column(jc) }) -#' shiftRightUnsigned -#' -#' Unsigned shift the given value numBits right. If the given value is a long value, +#' @details +#' \code{shiftRight}: (Unigned) shifts the given value numBits right. If the given value is a long value, #' it will return a long value else it will return an integer value. #' -#' @param y column to compute on. -#' @param x number of bits to shift. -#' -#' @family math functions -#' @rdname shiftRightUnsigned -#' @name shiftRightUnsigned -#' @aliases shiftRightUnsigned,Column,numeric-method +#' @rdname column_math_functions +#' @aliases shiftRightUnsigned shiftRightUnsigned,Column,numeric-method #' @export -#' @examples \dontrun{shiftRightUnsigned(df$c, 1)} #' @note shiftRightUnsigned since 1.5.0 setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2656,20 +2461,14 @@ setMethod("concat_ws", signature(sep = "character", x = "Column"), column(jc) }) -#' conv -#' -#' Convert a number in a string column from one base to another. +#' @details +#' \code{conv}: Converts a number in a string column from one base to another. #' -#' @param x column to convert. #' @param fromBase base to convert from. #' @param toBase base to convert to. -#' -#' @family math functions -#' @rdname conv -#' @aliases conv,Column,numeric,numeric-method -#' @name conv +#' @rdname column_math_functions +#' @aliases conv conv,Column,numeric,numeric-method #' @export -#' @examples \dontrun{conv(df$n, 2, 16)} #' @note conv since 1.5.0 setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"), function(x, fromBase, toBase) { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index f105174cea70d..0248ec585d771 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -931,24 +931,28 @@ setGeneric("avg", function(x, ...) { standardGeneric("avg") }) #' @export setGeneric("base64", function(x) { standardGeneric("base64") }) -#' @rdname bin +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("bin", function(x) { standardGeneric("bin") }) #' @rdname bitwiseNOT #' @export setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) -#' @rdname bround +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("bround", function(x, ...) { standardGeneric("bround") }) -#' @rdname cbrt +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) -#' @rdname ceil +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("ceil", function(x) { standardGeneric("ceil") }) #' @rdname column_aggregate_functions @@ -973,8 +977,9 @@ setGeneric("concat", function(x, ...) { standardGeneric("concat") }) #' @export setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) -#' @rdname conv +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) #' @rdname column_aggregate_functions @@ -1094,8 +1099,9 @@ setGeneric("grouping_bit", function(x) { standardGeneric("grouping_bit") }) #' @name NULL setGeneric("grouping_id", function(x, ...) { standardGeneric("grouping_id") }) -#' @rdname hex +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("hex", function(x) { standardGeneric("hex") }) #' @rdname column_datetime_functions @@ -1103,8 +1109,9 @@ setGeneric("hex", function(x) { standardGeneric("hex") }) #' @name NULL setGeneric("hour", function(x) { standardGeneric("hour") }) -#' @rdname hypot +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) #' @rdname initcap @@ -1235,8 +1242,9 @@ setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) #' @export setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_rank") }) -#' @rdname pmod +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) #' @rdname posexplode @@ -1281,8 +1289,9 @@ setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) #' @export setGeneric("reverse", function(x) { standardGeneric("reverse") }) -#' @rdname rint +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("rint", function(x) { standardGeneric("rint") }) #' @param x empty. Should be used with no argument. @@ -1316,20 +1325,24 @@ setGeneric("sha1", function(x) { standardGeneric("sha1") }) #' @export setGeneric("sha2", function(y, x) { standardGeneric("sha2") }) -#' @rdname shiftLeft +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("shiftLeft", function(y, x) { standardGeneric("shiftLeft") }) -#' @rdname shiftRight +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") }) -#' @rdname shiftRightUnsigned +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") }) -#' @rdname sign +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("signum", function(x) { standardGeneric("signum") }) #' @rdname size @@ -1386,12 +1399,14 @@ setGeneric("substring_index", function(x, delim, count) { standardGeneric("subst #' @name NULL setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) -#' @rdname toDegrees +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) -#' @rdname toRadians +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @rdname column_datetime_functions @@ -1425,8 +1440,9 @@ setGeneric("trim", function(x) { standardGeneric("trim") }) #' @export setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) -#' @rdname unhex +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("unhex", function(x) { standardGeneric("unhex") }) #' @rdname column_datetime_functions From 838effb98a0d3410766771533402ce0386133af3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 28 Jun 2017 14:28:40 +0800 Subject: [PATCH 029/779] Revert "[SPARK-13534][PYSPARK] Using Apache Arrow to increase performance of DataFrame.toPandas" This reverts commit e44697606f429b01808c1a22cb44cb5b89585c5c. --- bin/pyspark | 2 +- dev/deps/spark-deps-hadoop-2.6 | 5 - dev/deps/spark-deps-hadoop-2.7 | 5 - dev/run-pip-tests | 6 - pom.xml | 20 - python/pyspark/serializers.py | 17 - python/pyspark/sql/dataframe.py | 48 +- python/pyspark/sql/tests.py | 79 +- .../apache/spark/sql/internal/SQLConf.scala | 22 - sql/core/pom.xml | 4 - .../scala/org/apache/spark/sql/Dataset.scala | 20 - .../sql/execution/arrow/ArrowConverters.scala | 429 ------ .../arrow/ArrowConvertersSuite.scala | 1222 ----------------- 13 files changed, 13 insertions(+), 1866 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala diff --git a/bin/pyspark b/bin/pyspark index 8eeea7716cc98..98387c2ec5b8a 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - exec "$PYSPARK_DRIVER_PYTHON" -m "$@" + exec "$PYSPARK_DRIVER_PYTHON" -m "$1" exit fi diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 9868c1ab7c2ab..9287bd47cf113 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -13,9 +13,6 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.4.0.jar -arrow-memory-0.4.0.jar -arrow-vector-0.4.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -58,7 +55,6 @@ datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar -flatbuffers-1.2.0-3f79e055.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -81,7 +77,6 @@ hadoop-yarn-server-web-proxy-2.6.5.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar -hppc-0.7.1.jar htrace-core-3.0.4.jar httpclient-4.5.2.jar httpcore-4.4.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 57c78cfe12087..9127413ab6c23 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -13,9 +13,6 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.4.0.jar -arrow-memory-0.4.0.jar -arrow-vector-0.4.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -58,7 +55,6 @@ datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar -flatbuffers-1.2.0-3f79e055.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -81,7 +77,6 @@ hadoop-yarn-server-web-proxy-2.7.3.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar -hppc-0.7.1.jar htrace-core-3.1.0-incubating.jar httpclient-4.5.2.jar httpcore-4.4.4.jar diff --git a/dev/run-pip-tests b/dev/run-pip-tests index 225e9209536f0..d51dde12a03c5 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -83,8 +83,6 @@ for python in "${PYTHON_EXECS[@]}"; do if [ -n "$USE_CONDA" ]; then conda create -y -p "$VIRTUALENV_PATH" python=$python numpy pandas pip setuptools source activate "$VIRTUALENV_PATH" - conda install -y -c conda-forge pyarrow=0.4.0 - TEST_PYARROW=1 else mkdir -p "$VIRTUALENV_PATH" virtualenv --python=$python "$VIRTUALENV_PATH" @@ -122,10 +120,6 @@ for python in "${PYTHON_EXECS[@]}"; do python "$FWDIR"/dev/pip-sanity-check.py echo "Run the tests for context.py" python "$FWDIR"/python/pyspark/context.py - if [ -n "$TEST_PYARROW" ]; then - echo "Run tests for pyarrow" - SPARK_TESTING=1 "$FWDIR"/bin/pyspark pyspark.sql.tests ArrowTests - fi cd "$FWDIR" diff --git a/pom.xml b/pom.xml index f124ba45007b7..5f524079495c0 100644 --- a/pom.xml +++ b/pom.xml @@ -181,7 +181,6 @@ 2.6 1.8 1.0.0 - 0.4.0 ${java.home} @@ -1879,25 +1878,6 @@ paranamer ${paranamer.version} - - org.apache.arrow - arrow-vector - ${arrow.version} - - - com.fasterxml.jackson.core - jackson-annotations - - - com.fasterxml.jackson.core - jackson-databind - - - io.netty - netty-handler - - - diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d5c2a7518b18f..ea5e00e9eeef5 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -182,23 +182,6 @@ def loads(self, obj): raise NotImplementedError -class ArrowSerializer(FramedSerializer): - """ - Serializes an Arrow stream. - """ - - def dumps(self, obj): - raise NotImplementedError - - def loads(self, obj): - import pyarrow as pa - reader = pa.RecordBatchFileReader(pa.BufferReader(obj)) - return reader.read_all() - - def __repr__(self): - return "ArrowSerializer" - - class BatchedSerializer(Serializer): """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 760f113dfd197..0649271ed2246 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -29,8 +29,7 @@ from pyspark import copy_func, since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \ - UTF8Deserializer +from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import _parse_datatype_json_string @@ -1709,8 +1708,7 @@ def toDF(self, *cols): @since(1.3) def toPandas(self): - """ - Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. + """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. This is only available if Pandas is installed and available. @@ -1723,42 +1721,18 @@ def toPandas(self): 1 5 Bob """ import pandas as pd - if self.sql_ctx.getConf("spark.sql.execution.arrow.enable", "false").lower() == "true": - try: - import pyarrow - tables = self._collectAsArrow() - if tables: - table = pyarrow.concat_tables(tables) - return table.to_pandas() - else: - return pd.DataFrame.from_records([], columns=self.columns) - except ImportError as e: - msg = "note: pyarrow must be installed and available on calling Python process " \ - "if using spark.sql.execution.arrow.enable=true" - raise ImportError("%s\n%s" % (e.message, msg)) - else: - dtype = {} - for field in self.schema: - pandas_type = _to_corrected_pandas_type(field.dataType) - if pandas_type is not None: - dtype[field.name] = pandas_type - pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) + dtype = {} + for field in self.schema: + pandas_type = _to_corrected_pandas_type(field.dataType) + if pandas_type is not None: + dtype[field.name] = pandas_type - for f, t in dtype.items(): - pdf[f] = pdf[f].astype(t, copy=False) - return pdf + pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) - def _collectAsArrow(self): - """ - Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed - and available. - - .. note:: Experimental. - """ - with SCCallSiteSync(self._sc) as css: - port = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(port, ArrowSerializer())) + for f, t in dtype.items(): + pdf[f] = pdf[f].astype(t, copy=False) + return pdf ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 326e8548a617c..0a1cd6856b8e8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -58,21 +58,12 @@ from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type -from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests +from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException -_have_arrow = False -try: - import pyarrow - _have_arrow = True -except: - # No Arrow, but that's okay, we'll skip those tests - pass - - class UTCOffsetTimezone(datetime.tzinfo): """ Specifies timezone in UTC offset @@ -2629,74 +2620,6 @@ def range_frame_match(): importlib.reload(window) - -@unittest.skipIf(not _have_arrow, "Arrow not installed") -class ArrowTests(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.spark = SparkSession(cls.sc) - cls.spark.conf.set("spark.sql.execution.arrow.enable", "true") - cls.schema = StructType([ - StructField("1_str_t", StringType(), True), - StructField("2_int_t", IntegerType(), True), - StructField("3_long_t", LongType(), True), - StructField("4_float_t", FloatType(), True), - StructField("5_double_t", DoubleType(), True)]) - cls.data = [("a", 1, 10, 0.2, 2.0), - ("b", 2, 20, 0.4, 4.0), - ("c", 3, 30, 0.8, 6.0)] - - def assertFramesEqual(self, df_with_arrow, df_without): - msg = ("DataFrame from Arrow is not equal" + - ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) + - ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) - self.assertTrue(df_without.equals(df_with_arrow), msg=msg) - - def test_unsupported_datatype(self): - schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)]) - df = self.spark.createDataFrame([([1, 2, 3],)], schema=schema) - with QuietTest(self.sc): - self.assertRaises(Exception, lambda: df.toPandas()) - - def test_null_conversion(self): - df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + - self.data) - pdf = df_null.toPandas() - null_counts = pdf.isnull().sum().tolist() - self.assertTrue(all([c == 1 for c in null_counts])) - - def test_toPandas_arrow_toggle(self): - df = self.spark.createDataFrame(self.data, schema=self.schema) - self.spark.conf.set("spark.sql.execution.arrow.enable", "false") - pdf = df.toPandas() - self.spark.conf.set("spark.sql.execution.arrow.enable", "true") - pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) - - def test_pandas_round_trip(self): - import pandas as pd - import numpy as np - data_dict = {} - for j, name in enumerate(self.schema.names): - data_dict[name] = [self.data[i][j] for i in range(len(self.data))] - # need to convert these to numpy types first - data_dict["2_int_t"] = np.int32(data_dict["2_int_t"]) - data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) - pdf = pd.DataFrame(data=data_dict) - df = self.spark.createDataFrame(self.data, schema=self.schema) - pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) - - def test_filtered_frame(self): - df = self.spark.range(3).toDF("i") - pdf = df.filter("i < 0").toPandas() - self.assertEqual(len(pdf.columns), 1) - self.assertEqual(pdf.columns[0], "i") - self.assertTrue(pdf.empty) - - if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9c8e26a8eeadf..c641e4d3a23e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -847,24 +847,6 @@ object SQLConf { .intConf .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) - val ARROW_EXECUTION_ENABLE = - buildConf("spark.sql.execution.arrow.enable") - .internal() - .doc("Make use of Apache Arrow for columnar data transfers. Currently available " + - "for use with pyspark.sql.DataFrame.toPandas with the following data types: " + - "StringType, BinaryType, BooleanType, DoubleType, FloatType, ByteType, IntegerType, " + - "LongType, ShortType") - .booleanConf - .createWithDefault(false) - - val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH = - buildConf("spark.sql.execution.arrow.maxRecordsPerBatch") - .internal() - .doc("When using Apache Arrow, limit the maximum number of records that can be written " + - "to a single ArrowRecordBatch in memory. If set to zero or negative there is no limit.") - .intConf - .createWithDefault(10000) - object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1123,10 +1105,6 @@ class SQLConf extends Serializable with Logging { def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO) - def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE) - - def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) - /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 661c31ded7148..1bc34a6b069d9 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -103,10 +103,6 @@ jackson-databind ${fasterxml.jackson.version} - - org.apache.arrow - arrow-vector - org.apache.xbean xbean-asm5-shaded diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 268a37ff5d271..7be4aa1ca9562 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -47,7 +47,6 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython @@ -2887,16 +2886,6 @@ class Dataset[T] private[sql]( } } - /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. - */ - private[sql] def collectAsArrowToPython(): Int = { - withNewExecutionId { - val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") - } - } - private[sql] def toPythonIterator(): Int = { withNewExecutionId { PythonRDD.toLocalIteratorAndServe(javaToPython.rdd) @@ -2978,13 +2967,4 @@ class Dataset[T] private[sql]( Dataset(sparkSession, logicalPlan) } } - - /** Convert to an RDD of ArrowPayload byte arrays */ - private[sql] def toArrowPayload: RDD[ArrowPayload] = { - val schemaCaptured = this.schema - val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch - queryExecution.toRdd.mapPartitionsInternal { iter => - ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala deleted file mode 100644 index 6af5c73422377..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ /dev/null @@ -1,429 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.arrow - -import java.io.ByteArrayOutputStream -import java.nio.channels.Channels - -import scala.collection.JavaConverters._ - -import io.netty.buffer.ArrowBuf -import org.apache.arrow.memory.{BufferAllocator, RootAllocator} -import org.apache.arrow.vector._ -import org.apache.arrow.vector.BaseValueVector.BaseMutator -import org.apache.arrow.vector.file._ -import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} -import org.apache.arrow.vector.types.FloatingPointPrecision -import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} -import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils - - -/** - * Store Arrow data in a form that can be serialized by Spark and served to a Python process. - */ -private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable { - - /** - * Convert the ArrowPayload to an ArrowRecordBatch. - */ - def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = { - ArrowConverters.byteArrayToBatch(payload, allocator) - } - - /** - * Get the ArrowPayload as a type that can be served to Python. - */ - def asPythonSerializable: Array[Byte] = payload -} - -private[sql] object ArrowPayload { - - /** - * Create an ArrowPayload from an ArrowRecordBatch and Spark schema. - */ - def apply( - batch: ArrowRecordBatch, - schema: StructType, - allocator: BufferAllocator): ArrowPayload = { - new ArrowPayload(ArrowConverters.batchToByteArray(batch, schema, allocator)) - } -} - -private[sql] object ArrowConverters { - - /** - * Map a Spark DataType to ArrowType. - */ - private[arrow] def sparkTypeToArrowType(dataType: DataType): ArrowType = { - dataType match { - case BooleanType => ArrowType.Bool.INSTANCE - case ShortType => new ArrowType.Int(8 * ShortType.defaultSize, true) - case IntegerType => new ArrowType.Int(8 * IntegerType.defaultSize, true) - case LongType => new ArrowType.Int(8 * LongType.defaultSize, true) - case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) - case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) - case ByteType => new ArrowType.Int(8, true) - case StringType => ArrowType.Utf8.INSTANCE - case BinaryType => ArrowType.Binary.INSTANCE - case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") - } - } - - /** - * Convert a Spark Dataset schema to Arrow schema. - */ - private[arrow] def schemaToArrowSchema(schema: StructType): Schema = { - val arrowFields = schema.fields.map { f => - new Field(f.name, f.nullable, sparkTypeToArrowType(f.dataType), List.empty[Field].asJava) - } - new Schema(arrowFields.toList.asJava) - } - - /** - * Maps Iterator from InternalRow to ArrowPayload. Limit ArrowRecordBatch size in ArrowPayload - * by setting maxRecordsPerBatch or use 0 to fully consume rowIter. - */ - private[sql] def toPayloadIterator( - rowIter: Iterator[InternalRow], - schema: StructType, - maxRecordsPerBatch: Int): Iterator[ArrowPayload] = { - new Iterator[ArrowPayload] { - private val _allocator = new RootAllocator(Long.MaxValue) - private var _nextPayload = if (rowIter.nonEmpty) convert() else null - - override def hasNext: Boolean = _nextPayload != null - - override def next(): ArrowPayload = { - val obj = _nextPayload - if (hasNext) { - if (rowIter.hasNext) { - _nextPayload = convert() - } else { - _allocator.close() - _nextPayload = null - } - } - obj - } - - private def convert(): ArrowPayload = { - val batch = internalRowIterToArrowBatch(rowIter, schema, _allocator, maxRecordsPerBatch) - ArrowPayload(batch, schema, _allocator) - } - } - } - - /** - * Iterate over InternalRows and write to an ArrowRecordBatch, stopping when rowIter is consumed - * or the number of records in the batch equals maxRecordsInBatch. If maxRecordsPerBatch is 0, - * then rowIter will be fully consumed. - */ - private def internalRowIterToArrowBatch( - rowIter: Iterator[InternalRow], - schema: StructType, - allocator: BufferAllocator, - maxRecordsPerBatch: Int = 0): ArrowRecordBatch = { - - val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) => - ColumnWriter(field.dataType, ordinal, allocator).init() - } - - val writerLength = columnWriters.length - var recordsInBatch = 0 - while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || recordsInBatch < maxRecordsPerBatch)) { - val row = rowIter.next() - var i = 0 - while (i < writerLength) { - columnWriters(i).write(row) - i += 1 - } - recordsInBatch += 1 - } - - val (fieldNodes, bufferArrays) = columnWriters.map(_.finish()).unzip - val buffers = bufferArrays.flatten - - val rowLength = if (fieldNodes.nonEmpty) fieldNodes.head.getLength else 0 - val recordBatch = new ArrowRecordBatch(rowLength, - fieldNodes.toList.asJava, buffers.toList.asJava) - - buffers.foreach(_.release()) - recordBatch - } - - /** - * Convert an ArrowRecordBatch to a byte array and close batch to release resources. Once closed, - * the batch can no longer be used. - */ - private[arrow] def batchToByteArray( - batch: ArrowRecordBatch, - schema: StructType, - allocator: BufferAllocator): Array[Byte] = { - val arrowSchema = ArrowConverters.schemaToArrowSchema(schema) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - val out = new ByteArrayOutputStream() - val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) - - // Write a batch to byte stream, ensure the batch, allocator and writer are closed - Utils.tryWithSafeFinally { - val loader = new VectorLoader(root) - loader.load(batch) - writer.writeBatch() // writeBatch can throw IOException - } { - batch.close() - root.close() - writer.close() - } - out.toByteArray - } - - /** - * Convert a byte array to an ArrowRecordBatch. - */ - private[arrow] def byteArrayToBatch( - batchBytes: Array[Byte], - allocator: BufferAllocator): ArrowRecordBatch = { - val in = new ByteArrayReadableSeekableByteChannel(batchBytes) - val reader = new ArrowFileReader(in, allocator) - - // Read a batch from a byte stream, ensure the reader is closed - Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch - } { - reader.close() - } - } -} - -/** - * Interface for writing InternalRows to Arrow Buffers. - */ -private[arrow] trait ColumnWriter { - def init(): this.type - def write(row: InternalRow): Unit - - /** - * Clear the column writer and return the ArrowFieldNode and ArrowBuf. - * This should be called only once after all the data is written. - */ - def finish(): (ArrowFieldNode, Array[ArrowBuf]) -} - -/** - * Base class for flat arrow column writer, i.e., column without children. - */ -private[arrow] abstract class PrimitiveColumnWriter(val ordinal: Int) - extends ColumnWriter { - - def getFieldType(dtype: ArrowType): FieldType = FieldType.nullable(dtype) - - def valueVector: BaseDataValueVector - def valueMutator: BaseMutator - - def setNull(): Unit - def setValue(row: InternalRow): Unit - - protected var count = 0 - protected var nullCount = 0 - - override def init(): this.type = { - valueVector.allocateNew() - this - } - - override def write(row: InternalRow): Unit = { - if (row.isNullAt(ordinal)) { - setNull() - nullCount += 1 - } else { - setValue(row) - } - count += 1 - } - - override def finish(): (ArrowFieldNode, Array[ArrowBuf]) = { - valueMutator.setValueCount(count) - val fieldNode = new ArrowFieldNode(count, nullCount) - val valueBuffers = valueVector.getBuffers(true) - (fieldNode, valueBuffers) - } -} - -private[arrow] class BooleanColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableBitVector - = new NullableBitVector("BooleanValue", getFieldType(dtype), allocator) - override val valueMutator: NullableBitVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, if (row.getBoolean(ordinal)) 1 else 0 ) -} - -private[arrow] class ShortColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableSmallIntVector - = new NullableSmallIntVector("ShortValue", getFieldType(dtype: ArrowType), allocator) - override val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getShort(ordinal)) -} - -private[arrow] class IntegerColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableIntVector - = new NullableIntVector("IntValue", getFieldType(dtype), allocator) - override val valueMutator: NullableIntVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getInt(ordinal)) -} - -private[arrow] class LongColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableBigIntVector - = new NullableBigIntVector("LongValue", getFieldType(dtype), allocator) - override val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getLong(ordinal)) -} - -private[arrow] class FloatColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableFloat4Vector - = new NullableFloat4Vector("FloatValue", getFieldType(dtype), allocator) - override val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getFloat(ordinal)) -} - -private[arrow] class DoubleColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableFloat8Vector - = new NullableFloat8Vector("DoubleValue", getFieldType(dtype), allocator) - override val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getDouble(ordinal)) -} - -private[arrow] class ByteColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableUInt1Vector - = new NullableUInt1Vector("ByteValue", getFieldType(dtype), allocator) - override val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getByte(ordinal)) -} - -private[arrow] class UTF8StringColumnWriter( - dtype: ArrowType, - ordinal: Int, - allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableVarCharVector - = new NullableVarCharVector("UTF8StringValue", getFieldType(dtype), allocator) - override val valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - val str = row.getUTF8String(ordinal) - valueMutator.setSafe(count, str.getByteBuffer, 0, str.numBytes) - } -} - -private[arrow] class BinaryColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableVarBinaryVector - = new NullableVarBinaryVector("BinaryValue", getFieldType(dtype), allocator) - override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - val bytes = row.getBinary(ordinal) - valueMutator.setSafe(count, bytes, 0, bytes.length) - } -} - -private[arrow] class DateColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableDateDayVector - = new NullableDateDayVector("DateValue", getFieldType(dtype), allocator) - override val valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - valueMutator.setSafe(count, row.getInt(ordinal)) - } -} - -private[arrow] class TimeStampColumnWriter( - dtype: ArrowType, - ordinal: Int, - allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableTimeStampMicroVector - = new NullableTimeStampMicroVector("TimeStampValue", getFieldType(dtype), allocator) - override val valueMutator: NullableTimeStampMicroVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - valueMutator.setSafe(count, row.getLong(ordinal)) - } -} - -private[arrow] object ColumnWriter { - - /** - * Create an Arrow ColumnWriter given the type and ordinal of row. - */ - def apply(dataType: DataType, ordinal: Int, allocator: BufferAllocator): ColumnWriter = { - val dtype = ArrowConverters.sparkTypeToArrowType(dataType) - dataType match { - case BooleanType => new BooleanColumnWriter(dtype, ordinal, allocator) - case ShortType => new ShortColumnWriter(dtype, ordinal, allocator) - case IntegerType => new IntegerColumnWriter(dtype, ordinal, allocator) - case LongType => new LongColumnWriter(dtype, ordinal, allocator) - case FloatType => new FloatColumnWriter(dtype, ordinal, allocator) - case DoubleType => new DoubleColumnWriter(dtype, ordinal, allocator) - case ByteType => new ByteColumnWriter(dtype, ordinal, allocator) - case StringType => new UTF8StringColumnWriter(dtype, ordinal, allocator) - case BinaryType => new BinaryColumnWriter(dtype, ordinal, allocator) - case DateType => new DateColumnWriter(dtype, ordinal, allocator) - case TimestampType => new TimeStampColumnWriter(dtype, ordinal, allocator) - case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala deleted file mode 100644 index 159328cc0d958..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ /dev/null @@ -1,1222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.arrow - -import java.io.File -import java.nio.charset.StandardCharsets -import java.sql.{Date, Timestamp} -import java.text.SimpleDateFormat -import java.util.Locale - -import com.google.common.io.Files -import org.apache.arrow.memory.RootAllocator -import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} -import org.apache.arrow.vector.file.json.JsonFileReader -import org.apache.arrow.vector.util.Validator -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.SparkException -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{BinaryType, StructField, StructType} -import org.apache.spark.util.Utils - - -class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { - import testImplicits._ - - private var tempDataPath: String = _ - - override def beforeAll(): Unit = { - super.beforeAll() - tempDataPath = Utils.createTempDir(namePrefix = "arrow").getAbsolutePath - } - - test("collect to arrow record batch") { - val indexData = (1 to 6).toDF("i") - val arrowPayloads = indexData.toArrowPayload.collect() - assert(arrowPayloads.nonEmpty) - assert(arrowPayloads.length == indexData.rdd.getNumPartitions) - val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) - val rowCount = arrowRecordBatches.map(_.getLength).sum - assert(rowCount === indexData.count()) - arrowRecordBatches.foreach(batch => assert(batch.getNodes.size() > 0)) - arrowRecordBatches.foreach(_.close()) - allocator.close() - } - - test("short conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_s", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 16 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 16 - | } ] - | } - | }, { - | "name" : "b_s", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 16 - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 16 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_s", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, -1, 2, -2, 32767, -32768 ] - | }, { - | "name" : "b_s", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1, 0, 0, -2, 0, -32768 ] - | } ] - | } ] - |} - """.stripMargin - - val a_s = List[Short](1, -1, 2, -2, 32767, -32768) - val b_s = List[Option[Short]](Some(1), None, None, Some(-2), None, Some(-32768)) - val df = a_s.zip(b_s).toDF("a_s", "b_s") - - collectAndValidate(df, json, "integer-16bit.json") - } - - test("int conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "b_i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_i", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] - | }, { - | "name" : "b_i", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] - | } ] - | } ] - |} - """.stripMargin - - val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) - val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) - val df = a_i.zip(b_i).toDF("a_i", "b_i") - - collectAndValidate(df, json, "integer-32bit.json") - } - - test("long conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_l", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 64 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | }, { - | "name" : "b_l", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 64 - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_l", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, -1, 2, -2, 9223372036854775807, -9223372036854775808 ] - | }, { - | "name" : "b_l", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1, 0, 0, -2, 0, -9223372036854775808 ] - | } ] - | } ] - |} - """.stripMargin - - val a_l = List[Long](1, -1, 2, -2, 9223372036854775807L, -9223372036854775808L) - val b_l = List[Option[Long]](Some(1), None, None, Some(-2), None, Some(-9223372036854775808L)) - val df = a_l.zip(b_l).toDF("a_l", "b_l") - - collectAndValidate(df, json, "integer-64bit.json") - } - - test("float conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_f", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "SINGLE" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "b_f", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "SINGLE" - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_f", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0 ] - | }, { - | "name" : "b_f", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ] - | } ] - | } ] - |} - """.stripMargin - - val a_f = List(1.0f, 2.0f, 0.01f, 200.0f, 0.0001f, 20000.0f) - val b_f = List[Option[Float]](Some(1.1f), None, None, Some(2.2f), None, Some(3.3f)) - val df = a_f.zip(b_f).toDF("a_f", "b_f") - - collectAndValidate(df, json, "floating_point-single_precision.json") - } - - test("double conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_d", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "DOUBLE" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | }, { - | "name" : "b_d", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "DOUBLE" - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_d", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1.0, 2.0, 0.01, 200.0, 1.0E-4, 20000.0 ] - | }, { - | "name" : "b_d", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ] - | } ] - | } ] - |} - """.stripMargin - - val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0) - val b_d = List[Option[Double]](Some(1.1), None, None, Some(2.2), None, Some(3.3)) - val df = a_d.zip(b_d).toDF("a_d", "b_d") - - collectAndValidate(df, json, "floating_point-double_precision.json") - } - - test("index conversion") { - val data = List[Int](1, 2, 3, 4, 5, 6) - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "i", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, 2, 3, 4, 5, 6 ] - | } ] - | } ] - |} - """.stripMargin - val df = data.toDF("i") - - collectAndValidate(df, json, "indexData-ints.json") - } - - test("mixed numeric type conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 16 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 16 - | } ] - | } - | }, { - | "name" : "b", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "SINGLE" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "c", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "d", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "DOUBLE" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | }, { - | "name" : "e", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 64 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, 2, 3, 4, 5, 6 ] - | }, { - | "name" : "b", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ] - | }, { - | "name" : "c", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, 2, 3, 4, 5, 6 ] - | }, { - | "name" : "d", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ] - | }, { - | "name" : "e", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, 2, 3, 4, 5, 6 ] - | } ] - | } ] - |} - """.stripMargin - - val data = List(1, 2, 3, 4, 5, 6) - val data_tuples = for (d <- data) yield { - (d.toShort, d.toFloat, d.toInt, d.toDouble, d.toLong) - } - val df = data_tuples.toDF("a", "b", "c", "d", "e") - - collectAndValidate(df, json, "mixed_numeric_types.json") - } - - test("string type conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "upper_case", - | "type" : { - | "name" : "utf8" - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } - | }, { - | "name" : "lower_case", - | "type" : { - | "name" : "utf8" - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } - | }, { - | "name" : "null_str", - | "type" : { - | "name" : "utf8" - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 3, - | "columns" : [ { - | "name" : "upper_case", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "OFFSET" : [ 0, 1, 2, 3 ], - | "DATA" : [ "A", "B", "C" ] - | }, { - | "name" : "lower_case", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "OFFSET" : [ 0, 1, 2, 3 ], - | "DATA" : [ "a", "b", "c" ] - | }, { - | "name" : "null_str", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 0 ], - | "OFFSET" : [ 0, 2, 5, 5 ], - | "DATA" : [ "ab", "CDE", "" ] - | } ] - | } ] - |} - """.stripMargin - - val upperCase = Seq("A", "B", "C") - val lowerCase = Seq("a", "b", "c") - val nullStr = Seq("ab", "CDE", null) - val df = (upperCase, lowerCase, nullStr).zipped.toList - .toDF("upper_case", "lower_case", "null_str") - - collectAndValidate(df, json, "stringData.json") - } - - test("boolean type conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_bool", - | "type" : { - | "name" : "bool" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 1 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 4, - | "columns" : [ { - | "name" : "a_bool", - | "count" : 4, - | "VALIDITY" : [ 1, 1, 1, 1 ], - | "DATA" : [ true, true, false, true ] - | } ] - | } ] - |} - """.stripMargin - val df = Seq(true, true, false, true).toDF("a_bool") - collectAndValidate(df, json, "boolData.json") - } - - test("byte type conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_byte", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 8 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 4, - | "columns" : [ { - | "name" : "a_byte", - | "count" : 4, - | "VALIDITY" : [ 1, 1, 1, 1 ], - | "DATA" : [ 1, -1, 64, 127 ] - | } ] - | } ] - |} - | - """.stripMargin - val df = List[Byte](1.toByte, (-1).toByte, 64.toByte, Byte.MaxValue).toDF("a_byte") - collectAndValidate(df, json, "byteData.json") - } - - test("binary type conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_binary", - | "type" : { - | "name" : "binary" - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 3, - | "columns" : [ { - | "name" : "a_binary", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "OFFSET" : [ 0, 3, 4, 6 ], - | "DATA" : [ "616263", "64", "6566" ] - | } ] - | } ] - |} - """.stripMargin - - val data = Seq("abc", "d", "ef") - val rdd = sparkContext.parallelize(data.map(s => Row(s.getBytes("utf-8")))) - val df = spark.createDataFrame(rdd, StructType(Seq(StructField("a_binary", BinaryType)))) - - collectAndValidate(df, json, "binaryData.json") - } - - test("floating-point NaN") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "NaN_f", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "SINGLE" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "NaN_d", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "DOUBLE" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 2, - | "columns" : [ { - | "name" : "NaN_f", - | "count" : 2, - | "VALIDITY" : [ 1, 1 ], - | "DATA" : [ 1.2000000476837158, "NaN" ] - | }, { - | "name" : "NaN_d", - | "count" : 2, - | "VALIDITY" : [ 1, 1 ], - | "DATA" : [ "NaN", 1.2 ] - | } ] - | } ] - |} - """.stripMargin - - val fnan = Seq(1.2F, Float.NaN) - val dnan = Seq(Double.NaN, 1.2) - val df = fnan.zip(dnan).toDF("NaN_f", "NaN_d") - - collectAndValidate(df, json, "nanData-floating_point.json") - } - - test("partitioned DataFrame") { - val json1 = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "b", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 3, - | "columns" : [ { - | "name" : "a", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "DATA" : [ 1, 1, 2 ] - | }, { - | "name" : "b", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "DATA" : [ 1, 2, 1 ] - | } ] - | } ] - |} - """.stripMargin - val json2 = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "b", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 3, - | "columns" : [ { - | "name" : "a", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "DATA" : [ 2, 3, 3 ] - | }, { - | "name" : "b", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "DATA" : [ 2, 1, 2 ] - | } ] - | } ] - |} - """.stripMargin - - val arrowPayloads = testData2.toArrowPayload.collect() - // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload - assert(arrowPayloads.length === 2) - val schema = testData2.schema - - val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json") - val tempFile2 = new File(tempDataPath, "testData2-ints-part2.json") - Files.write(json1, tempFile1, StandardCharsets.UTF_8) - Files.write(json2, tempFile2, StandardCharsets.UTF_8) - - validateConversion(schema, arrowPayloads(0), tempFile1) - validateConversion(schema, arrowPayloads(1), tempFile2) - } - - test("empty frame collect") { - val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect() - assert(arrowPayload.isEmpty) - - val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i") - val filteredArrowPayload = filteredDF.filter("i < 0").toArrowPayload.collect() - assert(filteredArrowPayload.isEmpty) - } - - test("empty partition collect") { - val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i") - val arrowPayloads = emptyPart.toArrowPayload.collect() - assert(arrowPayloads.length === 1) - val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) - assert(arrowRecordBatches.head.getLength == 1) - arrowRecordBatches.foreach(_.close()) - allocator.close() - } - - test("max records in batch conf") { - val totalRecords = 10 - val maxRecordsPerBatch = 3 - spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) - val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") - val arrowPayloads = df.toArrowPayload.collect() - val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) - var recordCount = 0 - arrowRecordBatches.foreach { batch => - assert(batch.getLength > 0) - assert(batch.getLength <= maxRecordsPerBatch) - recordCount += batch.getLength - batch.close() - } - assert(recordCount == totalRecords) - allocator.close() - spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") - } - - testQuietly("unsupported types") { - def runUnsupported(block: => Unit): Unit = { - val msg = intercept[SparkException] { - block - } - assert(msg.getMessage.contains("Unsupported data type")) - assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) - } - - runUnsupported { decimalData.toArrowPayload.collect() } - runUnsupported { arrayData.toDF().toArrowPayload.collect() } - runUnsupported { mapData.toDF().toArrowPayload.collect() } - runUnsupported { complexData.toArrowPayload.collect() } - - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) - val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) - val d2 = new Date(sdf.parse("2016-05-09 13:10:15.000 UTC").getTime) - runUnsupported { Seq(d1, d2).toDF("date").toArrowPayload.collect() } - - val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) - val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) - runUnsupported { Seq(ts1, ts2).toDF("timestamp").toArrowPayload.collect() } - } - - test("test Arrow Validator") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "b_i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_i", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] - | }, { - | "name" : "b_i", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] - | } ] - | } ] - |} - """.stripMargin - val json_diff_col_order = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "b_i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "a_i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_i", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] - | }, { - | "name" : "b_i", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] - | } ] - | } ] - |} - """.stripMargin - - val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) - val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) - val df = a_i.zip(b_i).toDF("a_i", "b_i") - - // Different schema - intercept[IllegalArgumentException] { - collectAndValidate(df, json_diff_col_order, "validator_diff_schema.json") - } - - // Different values - intercept[IllegalArgumentException] { - collectAndValidate(df.sort($"a_i".desc), json, "validator_diff_values.json") - } - } - - /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ - private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = { - // NOTE: coalesce to single partition because can only load 1 batch in validator - val arrowPayload = df.coalesce(1).toArrowPayload.collect().head - val tempFile = new File(tempDataPath, file) - Files.write(json, tempFile, StandardCharsets.UTF_8) - validateConversion(df.schema, arrowPayload, tempFile) - } - - private def validateConversion( - sparkSchema: StructType, - arrowPayload: ArrowPayload, - jsonFile: File): Unit = { - val allocator = new RootAllocator(Long.MaxValue) - val jsonReader = new JsonFileReader(jsonFile, allocator) - - val arrowSchema = ArrowConverters.schemaToArrowSchema(sparkSchema) - val jsonSchema = jsonReader.start() - Validator.compareSchemas(arrowSchema, jsonSchema) - - val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator) - val vectorLoader = new VectorLoader(arrowRoot) - val arrowRecordBatch = arrowPayload.loadBatch(allocator) - vectorLoader.load(arrowRecordBatch) - val jsonRoot = jsonReader.read() - Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) - - jsonRoot.close() - jsonReader.close() - arrowRecordBatch.close() - arrowRoot.close() - allocator.close() - } -} From e68aed70fbf1cfa59ba51df70287d718d737a193 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 28 Jun 2017 10:45:45 -0700 Subject: [PATCH 030/779] [SPARK-21216][SS] Hive strategies missed in Structured Streaming IncrementalExecution ## What changes were proposed in this pull request? If someone creates a HiveSession, the planner in `IncrementalExecution` doesn't take into account the Hive scan strategies. This causes joins of Streaming DataFrame's with Hive tables to fail. ## How was this patch tested? Regression test Author: Burak Yavuz Closes #18426 from brkyvz/hive-join. --- .../streaming/IncrementalExecution.scala | 4 ++ .../sql/hive/execution/HiveDDLSuite.scala | 41 ++++++++++++++++++- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index ab89dc6b705d5..dbe652b3b1ed2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -47,6 +47,10 @@ class IncrementalExecution( sparkSession.sparkContext, sparkSession.sessionState.conf, sparkSession.sessionState.experimentalMethods) { + override def strategies: Seq[Strategy] = + extraPlanningStrategies ++ + sparkSession.sessionState.planner.strategies + override def extraPlanningStrategies: Seq[Strategy] = StatefulAggregationStrategy :: FlatMapGroupsWithStateStrategy :: diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index aca964907d4cd..31fa3d2447467 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -160,7 +160,6 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA test("drop table") { testDropTable(isDatasourceTable = false) } - } class HiveDDLSuite @@ -1956,4 +1955,44 @@ class HiveDDLSuite } } } + + test("SPARK-21216: join with a streaming DataFrame") { + import org.apache.spark.sql.execution.streaming.MemoryStream + import testImplicits._ + + implicit val _sqlContext = spark.sqlContext + + Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word").createOrReplaceTempView("t1") + // Make a table and ensure it will be broadcast. + sql("""CREATE TABLE smallTable(word string, number int) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |STORED AS TEXTFILE + """.stripMargin) + + sql( + """INSERT INTO smallTable + |SELECT word, number from t1 + """.stripMargin) + + val inputData = MemoryStream[Int] + val joined = inputData.toDS().toDF() + .join(spark.table("smallTable"), $"value" === $"number") + + val sq = joined.writeStream + .format("memory") + .queryName("t2") + .start() + try { + inputData.addData(1, 2) + + sq.processAllAvailable() + + checkAnswer( + spark.table("t2"), + Seq(Row(1, "one", 1), Row(2, "two", 2)) + ) + } finally { + sq.stop() + } + } } From b72b8521d9cad878a1a4e4dbb19cf980169dcbc7 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 29 Jun 2017 08:47:31 +0800 Subject: [PATCH 031/779] [SPARK-21222] Move elimination of Distinct clause from analyzer to optimizer ## What changes were proposed in this pull request? Move elimination of Distinct clause from analyzer to optimizer Distinct clause is useless after MAX/MIN clause. For example, "Select MAX(distinct a) FROM src from" is equivalent of "Select MAX(a) FROM src from" However, this optimization is implemented in analyzer. It should be in optimizer. ## How was this patch tested? Unit test gatorsmile cloud-fan Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Wang Gengliang Closes #18429 from gengliangwang/distinct_opt. --- .../sql/catalyst/analysis/Analyzer.scala | 5 -- .../spark/sql/catalyst/dsl/package.scala | 2 + .../sql/catalyst/optimizer/Optimizer.scala | 15 +++++ .../optimizer/EliminateDistinctSuite.scala | 56 +++++++++++++++++++ 4 files changed, 73 insertions(+), 5 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 434b6ffee37fa..53536496d0457 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1197,11 +1197,6 @@ class Analyzer( case u @ UnresolvedFunction(funcId, children, isDistinct) => withPosition(u) { catalog.lookupFunction(funcId, children) match { - // DISTINCT is not meaningful for a Max or a Min. - case max: Max if isDistinct => - AggregateExpression(max, Complete, isDistinct = false) - case min: Min if isDistinct => - AggregateExpression(min, Complete, isDistinct = false) // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within // the context of a Window clause. They do not need to be wrapped in an // AggregateExpression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index beee93d906f0f..f6792569b704e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -159,7 +159,9 @@ package object dsl { def first(e: Expression): Expression = new First(e).toAggregateExpression() def last(e: Expression): Expression = new Last(e).toAggregateExpression() def min(e: Expression): Expression = Min(e).toAggregateExpression() + def minDistinct(e: Expression): Expression = Min(e).toAggregateExpression(isDistinct = true) def max(e: Expression): Expression = Max(e).toAggregateExpression() + def maxDistinct(e: Expression): Expression = Max(e).toAggregateExpression(isDistinct = true) def upper(e: Expression): Expression = Upper(e) def lower(e: Expression): Expression = Lower(e) def sqrt(e: Expression): Expression = Sqrt(e) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b410312030c5d..946fa7bae0199 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -40,6 +40,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) protected val fixedPoint = FixedPoint(conf.optimizerMaxIterations) def batches: Seq[Batch] = { + Batch("Eliminate Distinct", Once, EliminateDistinct) :: // Technically some of the rules in Finish Analysis are not optimizer rules and belong more // in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime). // However, because we also use the analyzer to canonicalized queries (for view definition), @@ -151,6 +152,20 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil } +/** + * Remove useless DISTINCT for MAX and MIN. + * This rule should be applied before RewriteDistinctAggregates. + */ +object EliminateDistinct extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformExpressions { + case ae: AggregateExpression if ae.isDistinct => + ae.aggregateFunction match { + case _: Max | _: Min => ae.copy(isDistinct = false) + case _ => ae + } + } +} + /** * An optimizer used in test code. * diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala new file mode 100644 index 0000000000000..f40691bd1a038 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class EliminateDistinctSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Operator Optimizations", Once, + EliminateDistinct) :: Nil + } + + val testRelation = LocalRelation('a.int) + + test("Eliminate Distinct in Max") { + val query = testRelation + .select(maxDistinct('a).as('result)) + .analyze + val answer = testRelation + .select(max('a).as('result)) + .analyze + assert(query != answer) + comparePlans(Optimize.execute(query), answer) + } + + test("Eliminate Distinct in Min") { + val query = testRelation + .select(minDistinct('a).as('result)) + .analyze + val answer = testRelation + .select(min('a).as('result)) + .analyze + assert(query != answer) + comparePlans(Optimize.execute(query), answer) + } +} From 376d90d556fcd4fd84f70ee42a1323e1f48f829d Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Wed, 28 Jun 2017 19:31:54 -0700 Subject: [PATCH 032/779] [SPARK-20889][SPARKR] Grouped documentation for STRING column methods ## What changes were proposed in this pull request? Grouped documentation for string column methods. Author: actuaryzhang Author: Wayne Zhang Closes #18366 from actuaryzhang/sparkRDocString. --- R/pkg/R/functions.R | 573 +++++++++++++++++++------------------------- R/pkg/R/generics.R | 84 ++++--- 2 files changed, 300 insertions(+), 357 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 23ccdf941a8c7..70ea620b471fe 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -111,6 +111,27 @@ NULL #' head(tmp)} NULL +#' String functions for Column operations +#' +#' String functions defined for \code{Column}. +#' +#' @param x Column to compute on except in the following methods: +#' \itemize{ +#' \item \code{instr}: \code{character}, the substring to check. See 'Details'. +#' \item \code{format_number}: \code{numeric}, the number of decimal place to +#' format to. See 'Details'. +#' } +#' @param y Column to compute on. +#' @param ... additional columns. +#' @name column_string_functions +#' @rdname column_string_functions +#' @family string functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(as.data.frame(Titanic, stringsAsFactors = FALSE))} +NULL + #' lit #' #' A new \linkS4class{Column} is created to represent the literal value. @@ -188,19 +209,17 @@ setMethod("approxCountDistinct", column(jc) }) -#' ascii -#' -#' Computes the numeric value of the first character of the string column, and returns the -#' result as a int column. -#' -#' @param x Column to compute on. +#' @details +#' \code{ascii}: Computes the numeric value of the first character of the string column, +#' and returns the result as an int column. #' -#' @rdname ascii -#' @name ascii -#' @family string functions +#' @rdname column_string_functions #' @export -#' @aliases ascii,Column-method -#' @examples \dontrun{\dontrun{ascii(df$c)}} +#' @aliases ascii ascii,Column-method +#' @examples +#' +#' \dontrun{ +#' head(select(df, ascii(df$Class), ascii(df$Sex)))} #' @note ascii since 1.5.0 setMethod("ascii", signature(x = "Column"), @@ -256,19 +275,22 @@ setMethod("avg", column(jc) }) -#' base64 -#' -#' Computes the BASE64 encoding of a binary column and returns it as a string column. -#' This is the reverse of unbase64. -#' -#' @param x Column to compute on. +#' @details +#' \code{base64}: Computes the BASE64 encoding of a binary column and returns it as +#' a string column. This is the reverse of unbase64. #' -#' @rdname base64 -#' @name base64 -#' @family string functions +#' @rdname column_string_functions #' @export -#' @aliases base64,Column-method -#' @examples \dontrun{base64(df$c)} +#' @aliases base64 base64,Column-method +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, s1 = encode(df$Class, "UTF-8")) +#' str(tmp) +#' tmp2 <- mutate(tmp, s2 = base64(tmp$s1), s3 = decode(tmp$s1, "UTF-8"), +#' s4 = soundex(tmp$Sex)) +#' head(tmp2) +#' head(select(tmp2, unbase64(tmp2$s2)))} #' @note base64 since 1.5.0 setMethod("base64", signature(x = "Column"), @@ -620,20 +642,16 @@ setMethod("dayofyear", column(jc) }) -#' decode -#' -#' Computes the first argument into a string from a binary using the provided character set -#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). +#' @details +#' \code{decode}: Computes the first argument into a string from a binary using the provided +#' character set. #' -#' @param x Column to compute on. -#' @param charset Character set to use +#' @param charset Character set to use (one of "US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", +#' "UTF-16LE", "UTF-16"). #' -#' @rdname decode -#' @name decode -#' @family string functions -#' @aliases decode,Column,character-method +#' @rdname column_string_functions +#' @aliases decode decode,Column,character-method #' @export -#' @examples \dontrun{decode(df$c, "UTF-8")} #' @note decode since 1.6.0 setMethod("decode", signature(x = "Column", charset = "character"), @@ -642,20 +660,13 @@ setMethod("decode", column(jc) }) -#' encode -#' -#' Computes the first argument into a binary from a string using the provided character set -#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). -#' -#' @param x Column to compute on. -#' @param charset Character set to use +#' @details +#' \code{encode}: Computes the first argument into a binary from a string using the provided +#' character set. #' -#' @rdname encode -#' @name encode -#' @family string functions -#' @aliases encode,Column,character-method +#' @rdname column_string_functions +#' @aliases encode encode,Column,character-method #' @export -#' @examples \dontrun{encode(df$c, "UTF-8")} #' @note encode since 1.6.0 setMethod("encode", signature(x = "Column", charset = "character"), @@ -788,21 +799,23 @@ setMethod("hour", column(jc) }) -#' initcap -#' -#' Returns a new string column by converting the first letter of each word to uppercase. -#' Words are delimited by whitespace. -#' -#' For example, "hello world" will become "Hello World". -#' -#' @param x Column to compute on. +#' @details +#' \code{initcap}: Returns a new string column by converting the first letter of +#' each word to uppercase. Words are delimited by whitespace. For example, "hello world" +#' will become "Hello World". #' -#' @rdname initcap -#' @name initcap -#' @family string functions -#' @aliases initcap,Column-method +#' @rdname column_string_functions +#' @aliases initcap initcap,Column-method #' @export -#' @examples \dontrun{initcap(df$c)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, sex_lower = lower(df$Sex), age_upper = upper(df$age), +#' sex_age = concat_ws(" ", lower(df$sex), lower(df$age))) +#' head(tmp) +#' tmp2 <- mutate(tmp, s1 = initcap(tmp$sex_lower), s2 = initcap(tmp$sex_age), +#' s3 = reverse(df$Sex)) +#' head(tmp2)} #' @note initcap since 1.5.0 setMethod("initcap", signature(x = "Column"), @@ -918,18 +931,12 @@ setMethod("last_day", column(jc) }) -#' length -#' -#' Computes the length of a given string or binary column. -#' -#' @param x Column to compute on. +#' @details +#' \code{length}: Computes the length of a given string or binary column. #' -#' @rdname length -#' @name length -#' @aliases length,Column-method -#' @family string functions +#' @rdname column_string_functions +#' @aliases length length,Column-method #' @export -#' @examples \dontrun{length(df$c)} #' @note length since 1.5.0 setMethod("length", signature(x = "Column"), @@ -994,18 +1001,12 @@ setMethod("log2", column(jc) }) -#' lower -#' -#' Converts a string column to lower case. -#' -#' @param x Column to compute on. +#' @details +#' \code{lower}: Converts a string column to lower case. #' -#' @rdname lower -#' @name lower -#' @family string functions -#' @aliases lower,Column-method +#' @rdname column_string_functions +#' @aliases lower lower,Column-method #' @export -#' @examples \dontrun{lower(df$c)} #' @note lower since 1.4.0 setMethod("lower", signature(x = "Column"), @@ -1014,18 +1015,24 @@ setMethod("lower", column(jc) }) -#' ltrim -#' -#' Trim the spaces from left end for the specified string value. -#' -#' @param x Column to compute on. +#' @details +#' \code{ltrim}: Trims the spaces from left end for the specified string value. #' -#' @rdname ltrim -#' @name ltrim -#' @family string functions -#' @aliases ltrim,Column-method +#' @rdname column_string_functions +#' @aliases ltrim ltrim,Column-method #' @export -#' @examples \dontrun{ltrim(df$c)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, SexLpad = lpad(df$Sex, 6, " "), SexRpad = rpad(df$Sex, 7, " ")) +#' head(select(tmp, length(tmp$Sex), length(tmp$SexLpad), length(tmp$SexRpad))) +#' tmp2 <- mutate(tmp, SexLtrim = ltrim(tmp$SexLpad), SexRtrim = rtrim(tmp$SexRpad), +#' SexTrim = trim(tmp$SexLpad)) +#' head(select(tmp2, length(tmp2$Sex), length(tmp2$SexLtrim), +#' length(tmp2$SexRtrim), length(tmp2$SexTrim))) +#' +#' tmp <- mutate(df, SexLpad = lpad(df$Sex, 6, "xx"), SexRpad = rpad(df$Sex, 7, "xx")) +#' head(tmp)} #' @note ltrim since 1.5.0 setMethod("ltrim", signature(x = "Column"), @@ -1198,18 +1205,12 @@ setMethod("quarter", column(jc) }) -#' reverse -#' -#' Reverses the string column and returns it as a new string column. -#' -#' @param x Column to compute on. +#' @details +#' \code{reverse}: Reverses the string column and returns it as a new string column. #' -#' @rdname reverse -#' @name reverse -#' @family string functions -#' @aliases reverse,Column-method +#' @rdname column_string_functions +#' @aliases reverse reverse,Column-method #' @export -#' @examples \dontrun{reverse(df$c)} #' @note reverse since 1.5.0 setMethod("reverse", signature(x = "Column"), @@ -1268,18 +1269,12 @@ setMethod("bround", column(jc) }) -#' rtrim -#' -#' Trim the spaces from right end for the specified string value. -#' -#' @param x Column to compute on. +#' @details +#' \code{rtrim}: Trims the spaces from right end for the specified string value. #' -#' @rdname rtrim -#' @name rtrim -#' @family string functions -#' @aliases rtrim,Column-method +#' @rdname column_string_functions +#' @aliases rtrim rtrim,Column-method #' @export -#' @examples \dontrun{rtrim(df$c)} #' @note rtrim since 1.5.0 setMethod("rtrim", signature(x = "Column"), @@ -1409,18 +1404,12 @@ setMethod("skewness", column(jc) }) -#' soundex -#' -#' Return the soundex code for the specified expression. -#' -#' @param x Column to compute on. +#' @details +#' \code{soundex}: Returns the soundex code for the specified expression. #' -#' @rdname soundex -#' @name soundex -#' @family string functions -#' @aliases soundex,Column-method +#' @rdname column_string_functions +#' @aliases soundex soundex,Column-method #' @export -#' @examples \dontrun{soundex(df$c)} #' @note soundex since 1.5.0 setMethod("soundex", signature(x = "Column"), @@ -1731,18 +1720,12 @@ setMethod("to_timestamp", column(jc) }) -#' trim -#' -#' Trim the spaces from both ends for the specified string column. -#' -#' @param x Column to compute on. +#' @details +#' \code{trim}: Trims the spaces from both ends for the specified string column. #' -#' @rdname trim -#' @name trim -#' @family string functions -#' @aliases trim,Column-method +#' @rdname column_string_functions +#' @aliases trim trim,Column-method #' @export -#' @examples \dontrun{trim(df$c)} #' @note trim since 1.5.0 setMethod("trim", signature(x = "Column"), @@ -1751,19 +1734,13 @@ setMethod("trim", column(jc) }) -#' unbase64 -#' -#' Decodes a BASE64 encoded string column and returns it as a binary column. +#' @details +#' \code{unbase64}: Decodes a BASE64 encoded string column and returns it as a binary column. #' This is the reverse of base64. #' -#' @param x Column to compute on. -#' -#' @rdname unbase64 -#' @name unbase64 -#' @family string functions -#' @aliases unbase64,Column-method +#' @rdname column_string_functions +#' @aliases unbase64 unbase64,Column-method #' @export -#' @examples \dontrun{unbase64(df$c)} #' @note unbase64 since 1.5.0 setMethod("unbase64", signature(x = "Column"), @@ -1787,18 +1764,12 @@ setMethod("unhex", column(jc) }) -#' upper -#' -#' Converts a string column to upper case. -#' -#' @param x Column to compute on. +#' @details +#' \code{upper}: Converts a string column to upper case. #' -#' @rdname upper -#' @name upper -#' @family string functions -#' @aliases upper,Column-method +#' @rdname column_string_functions +#' @aliases upper upper,Column-method #' @export -#' @examples \dontrun{upper(df$c)} #' @note upper since 1.4.0 setMethod("upper", signature(x = "Column"), @@ -1949,19 +1920,19 @@ setMethod("hypot", signature(y = "Column"), column(jc) }) -#' levenshtein -#' -#' Computes the Levenshtein distance of the two given string columns. -#' -#' @param x Column to compute on. -#' @param y Column to compute on. +#' @details +#' \code{levenshtein}: Computes the Levenshtein distance of the two given string columns. #' -#' @rdname levenshtein -#' @name levenshtein -#' @family string functions -#' @aliases levenshtein,Column-method +#' @rdname column_string_functions +#' @aliases levenshtein levenshtein,Column-method #' @export -#' @examples \dontrun{levenshtein(df$c, x)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, d1 = levenshtein(df$Class, df$Sex), +#' d2 = levenshtein(df$Age, df$Sex), +#' d3 = levenshtein(df$Age, df$Age)) +#' head(tmp)} #' @note levenshtein since 1.5.0 setMethod("levenshtein", signature(y = "Column"), function(y, x) { @@ -2061,20 +2032,22 @@ setMethod("countDistinct", column(jc) }) - -#' concat -#' -#' Concatenates multiple input string columns together into a single string column. -#' -#' @param x Column to compute on -#' @param ... other columns +#' @details +#' \code{concat}: Concatenates multiple input string columns together into a single string column. #' -#' @family string functions -#' @rdname concat -#' @name concat -#' @aliases concat,Column-method +#' @rdname column_string_functions +#' @aliases concat concat,Column-method #' @export -#' @examples \dontrun{concat(df$strings, df$strings2)} +#' @examples +#' +#' \dontrun{ +#' # concatenate strings +#' tmp <- mutate(df, s1 = concat(df$Class, df$Sex), +#' s2 = concat(df$Class, df$Sex, df$Age), +#' s3 = concat(df$Class, df$Sex, df$Age, df$Class), +#' s4 = concat_ws("_", df$Class, df$Sex), +#' s5 = concat_ws("+", df$Class, df$Sex, df$Age, df$Survived)) +#' head(tmp)} #' @note concat since 1.5.0 setMethod("concat", signature(x = "Column"), @@ -2243,22 +2216,21 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), column(jc) }) -#' instr -#' -#' Locate the position of the first occurrence of substr column in the given string. -#' Returns null if either of the arguments are null. -#' -#' Note: The position is not zero based, but 1 based index. Returns 0 if substr -#' could not be found in str. +#' @details +#' \code{instr}: Locates the position of the first occurrence of a substring (\code{x}) +#' in the given string column (\code{y}). Returns null if either of the arguments are null. +#' Note: The position is not zero based, but 1 based index. Returns 0 if the substring +#' could not be found in the string column. #' -#' @param y column to check -#' @param x substring to check -#' @family string functions -#' @aliases instr,Column,character-method -#' @rdname instr -#' @name instr +#' @rdname column_string_functions +#' @aliases instr instr,Column,character-method #' @export -#' @examples \dontrun{instr(df$c, 'b')} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, s1 = instr(df$Sex, "m"), s2 = instr(df$Sex, "M"), +#' s3 = locate("m", df$Sex), s4 = locate("m", df$Sex, pos = 4)) +#' head(tmp)} #' @note instr since 1.5.0 setMethod("instr", signature(y = "Column", x = "character"), function(y, x) { @@ -2345,22 +2317,22 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), column(jc) }) -#' format_number -#' -#' Formats numeric column y to a format like '#,###,###.##', rounded to x decimal places -#' with HALF_EVEN round mode, and returns the result as a string column. -#' -#' If x is 0, the result has no decimal point or fractional part. -#' If x < 0, the result will be null. +#' @details +#' \code{format_number}: Formats numeric column \code{y} to a format like '#,###,###.##', +#' rounded to \code{x} decimal places with HALF_EVEN round mode, and returns the result +#' as a string column. +#' If \code{x} is 0, the result has no decimal point or fractional part. +#' If \code{x} < 0, the result will be null. #' -#' @param y column to format -#' @param x number of decimal place to format to -#' @family string functions -#' @rdname format_number -#' @name format_number -#' @aliases format_number,Column,numeric-method +#' @rdname column_string_functions +#' @aliases format_number format_number,Column,numeric-method #' @export -#' @examples \dontrun{format_number(df$n, 4)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, v1 = df$Freq/3) +#' head(select(tmp, format_number(tmp$v1, 0), format_number(tmp$v1, 2), +#' format_string("%4.2f %s", tmp$v1, tmp$Sex)), 10)} #' @note format_number since 1.5.0 setMethod("format_number", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2438,21 +2410,14 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), column(jc) }) -#' concat_ws -#' -#' Concatenates multiple input string columns together into a single string column, -#' using the given separator. +#' @details +#' \code{concat_ws}: Concatenates multiple input string columns together into a single +#' string column, using the given separator. #' -#' @param x column to concatenate. #' @param sep separator to use. -#' @param ... other columns to concatenate. -#' -#' @family string functions -#' @rdname concat_ws -#' @name concat_ws -#' @aliases concat_ws,character,Column-method +#' @rdname column_string_functions +#' @aliases concat_ws concat_ws,character,Column-method #' @export -#' @examples \dontrun{concat_ws('-', df$s, df$d)} #' @note concat_ws since 1.5.0 setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { @@ -2499,19 +2464,14 @@ setMethod("expr", signature(x = "character"), column(jc) }) -#' format_string -#' -#' Formats the arguments in printf-style and returns the result as a string column. +#' @details +#' \code{format_string}: Formats the arguments in printf-style and returns the result +#' as a string column. #' #' @param format a character object of format strings. -#' @param x a Column. -#' @param ... additional Column(s). -#' @family string functions -#' @rdname format_string -#' @name format_string -#' @aliases format_string,character,Column-method +#' @rdname column_string_functions +#' @aliases format_string format_string,character,Column-method #' @export -#' @examples \dontrun{format_string('%d %s', df$a, df$b)} #' @note format_string since 1.5.0 setMethod("format_string", signature(format = "character", x = "Column"), function(format, x, ...) { @@ -2620,23 +2580,17 @@ setMethod("window", signature(x = "Column"), column(jc) }) -#' locate -#' -#' Locate the position of the first occurrence of substr. -#' +#' @details +#' \code{locate}: Locates the position of the first occurrence of substr. #' Note: The position is not zero based, but 1 based index. Returns 0 if substr #' could not be found in str. #' #' @param substr a character string to be matched. #' @param str a Column where matches are sought for each entry. #' @param pos start position of search. -#' @param ... further arguments to be passed to or from other methods. -#' @family string functions -#' @rdname locate -#' @aliases locate,character,Column-method -#' @name locate +#' @rdname column_string_functions +#' @aliases locate locate,character,Column-method #' @export -#' @examples \dontrun{locate('b', df$c, 1)} #' @note locate since 1.5.0 setMethod("locate", signature(substr = "character", str = "Column"), function(substr, str, pos = 1) { @@ -2646,19 +2600,14 @@ setMethod("locate", signature(substr = "character", str = "Column"), column(jc) }) -#' lpad -#' -#' Left-pad the string column with +#' @details +#' \code{lpad}: Left-padded with pad to a length of len. #' -#' @param x the string Column to be left-padded. #' @param len maximum length of each output result. #' @param pad a character string to be padded with. -#' @family string functions -#' @rdname lpad -#' @aliases lpad,Column,numeric,character-method -#' @name lpad +#' @rdname column_string_functions +#' @aliases lpad lpad,Column,numeric,character-method #' @export -#' @examples \dontrun{lpad(df$c, 6, '#')} #' @note lpad since 1.5.0 setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { @@ -2728,20 +2677,27 @@ setMethod("randn", signature(seed = "numeric"), column(jc) }) -#' regexp_extract -#' -#' Extract a specific \code{idx} group identified by a Java regex, from the specified string column. -#' If the regex did not match, or the specified group did not match, an empty string is returned. +#' @details +#' \code{regexp_extract}: Extracts a specific \code{idx} group identified by a Java regex, +#' from the specified string column. If the regex did not match, or the specified group did +#' not match, an empty string is returned. #' -#' @param x a string Column. #' @param pattern a regular expression. #' @param idx a group index. -#' @family string functions -#' @rdname regexp_extract -#' @name regexp_extract -#' @aliases regexp_extract,Column,character,numeric-method +#' @rdname column_string_functions +#' @aliases regexp_extract regexp_extract,Column,character,numeric-method #' @export -#' @examples \dontrun{regexp_extract(df$c, '(\d+)-(\d+)', 1)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, s1 = regexp_extract(df$Class, "(\\d+)\\w+", 1), +#' s2 = regexp_extract(df$Sex, "^(\\w)\\w+", 1), +#' s3 = regexp_replace(df$Class, "\\D+", ""), +#' s4 = substring_index(df$Sex, "a", 1), +#' s5 = substring_index(df$Sex, "a", -1), +#' s6 = translate(df$Sex, "ale", ""), +#' s7 = translate(df$Sex, "a", "-")) +#' head(tmp)} #' @note regexp_extract since 1.5.0 setMethod("regexp_extract", signature(x = "Column", pattern = "character", idx = "numeric"), @@ -2752,19 +2708,14 @@ setMethod("regexp_extract", column(jc) }) -#' regexp_replace -#' -#' Replace all substrings of the specified string value that match regexp with rep. +#' @details +#' \code{regexp_replace}: Replaces all substrings of the specified string value that +#' match regexp with rep. #' -#' @param x a string Column. -#' @param pattern a regular expression. #' @param replacement a character string that a matched \code{pattern} is replaced with. -#' @family string functions -#' @rdname regexp_replace -#' @name regexp_replace -#' @aliases regexp_replace,Column,character,character-method +#' @rdname column_string_functions +#' @aliases regexp_replace regexp_replace,Column,character,character-method #' @export -#' @examples \dontrun{regexp_replace(df$c, '(\\d+)', '--')} #' @note regexp_replace since 1.5.0 setMethod("regexp_replace", signature(x = "Column", pattern = "character", replacement = "character"), @@ -2775,19 +2726,12 @@ setMethod("regexp_replace", column(jc) }) -#' rpad -#' -#' Right-padded with pad to a length of len. +#' @details +#' \code{rpad}: Right-padded with pad to a length of len. #' -#' @param x the string Column to be right-padded. -#' @param len maximum length of each output result. -#' @param pad a character string to be padded with. -#' @family string functions -#' @rdname rpad -#' @name rpad -#' @aliases rpad,Column,numeric,character-method +#' @rdname column_string_functions +#' @aliases rpad rpad,Column,numeric,character-method #' @export -#' @examples \dontrun{rpad(df$c, 6, '#')} #' @note rpad since 1.5.0 setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { @@ -2797,28 +2741,20 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), column(jc) }) -#' substring_index -#' -#' Returns the substring from string str before count occurrences of the delimiter delim. -#' If count is positive, everything the left of the final delimiter (counting from left) is -#' returned. If count is negative, every to the right of the final delimiter (counting from the -#' right) is returned. substring_index performs a case-sensitive match when searching for delim. +#' @details +#' \code{substring_index}: Returns the substring from string str before count occurrences of +#' the delimiter delim. If count is positive, everything the left of the final delimiter +#' (counting from left) is returned. If count is negative, every to the right of the final +#' delimiter (counting from the right) is returned. substring_index performs a case-sensitive +#' match when searching for delim. #' -#' @param x a Column. #' @param delim a delimiter string. #' @param count number of occurrences of \code{delim} before the substring is returned. #' A positive number means counting from the left, while negative means #' counting from the right. -#' @family string functions -#' @rdname substring_index -#' @aliases substring_index,Column,character,numeric-method -#' @name substring_index +#' @rdname column_string_functions +#' @aliases substring_index substring_index,Column,character,numeric-method #' @export -#' @examples -#'\dontrun{ -#'substring_index(df$c, '.', 2) -#'substring_index(df$c, '.', -1) -#'} #' @note substring_index since 1.5.0 setMethod("substring_index", signature(x = "Column", delim = "character", count = "numeric"), @@ -2829,24 +2765,19 @@ setMethod("substring_index", column(jc) }) -#' translate -#' -#' Translate any character in the src by a character in replaceString. +#' @details +#' \code{translate}: Translates any character in the src by a character in replaceString. #' The characters in replaceString is corresponding to the characters in matchingString. #' The translate will happen when any character in the string matching with the character #' in the matchingString. #' -#' @param x a string Column. #' @param matchingString a source string where each character will be translated. #' @param replaceString a target string where each \code{matchingString} character will #' be replaced by the character in \code{replaceString} #' at the same location, if any. -#' @family string functions -#' @rdname translate -#' @name translate -#' @aliases translate,Column,character,character-method +#' @rdname column_string_functions +#' @aliases translate translate,Column,character,character-method #' @export -#' @examples \dontrun{translate(df$c, 'rnlt', '123')} #' @note translate since 1.5.0 setMethod("translate", signature(x = "Column", matchingString = "character", replaceString = "character"), @@ -3419,28 +3350,20 @@ setMethod("collect_set", column(jc) }) -#' split_string -#' -#' Splits string on regular expression. -#' -#' Equivalent to \code{split} SQL function -#' -#' @param x Column to compute on -#' @param pattern Java regular expression +#' @details +#' \code{split_string}: Splits string on regular expression. +#' Equivalent to \code{split} SQL function. #' -#' @rdname split_string -#' @family string functions -#' @aliases split_string,Column-method +#' @rdname column_string_functions +#' @aliases split_string split_string,Column-method #' @export #' @examples -#' \dontrun{ -#' df <- read.text("README.md") -#' -#' head(select(df, split_string(df$value, "\\s+"))) #' +#' \dontrun{ +#' head(select(df, split_string(df$Sex, "a"))) +#' head(select(df, split_string(df$Class, "\\d"))) #' # This is equivalent to the following SQL expression -#' head(selectExpr(df, "split(value, '\\\\s+')")) -#' } +#' head(selectExpr(df, "split(Class, '\\\\d')"))} #' @note split_string 2.3.0 setMethod("split_string", signature(x = "Column", pattern = "character"), @@ -3449,28 +3372,20 @@ setMethod("split_string", column(jc) }) -#' repeat_string -#' -#' Repeats string n times. -#' -#' Equivalent to \code{repeat} SQL function +#' @details +#' \code{repeat_string}: Repeats string n times. +#' Equivalent to \code{repeat} SQL function. #' -#' @param x Column to compute on #' @param n Number of repetitions -#' -#' @rdname repeat_string -#' @family string functions -#' @aliases repeat_string,Column-method +#' @rdname column_string_functions +#' @aliases repeat_string repeat_string,Column-method #' @export #' @examples -#' \dontrun{ -#' df <- read.text("README.md") -#' -#' first(select(df, repeat_string(df$value, 3))) #' +#' \dontrun{ +#' head(select(df, repeat_string(df$Class, 3))) #' # This is equivalent to the following SQL expression -#' first(selectExpr(df, "repeat(value, 3)")) -#' } +#' head(selectExpr(df, "repeat(Class, 3)"))} #' @note repeat_string since 2.3.0 setMethod("repeat_string", signature(x = "Column", n = "numeric"), diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0248ec585d771..dc99e3d94b269 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -917,8 +917,9 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @export setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) -#' @rdname ascii +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) #' @param x Column to compute on or a GroupedData object. @@ -927,8 +928,9 @@ setGeneric("ascii", function(x) { standardGeneric("ascii") }) #' @export setGeneric("avg", function(x, ...) { standardGeneric("avg") }) -#' @rdname base64 +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("base64", function(x) { standardGeneric("base64") }) #' @rdname column_math_functions @@ -969,12 +971,14 @@ setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) #' @export setGeneric("column", function(x) { standardGeneric("column") }) -#' @rdname concat +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("concat", function(x, ...) { standardGeneric("concat") }) -#' @rdname concat_ws +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) #' @rdname column_math_functions @@ -1038,8 +1042,9 @@ setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) #' @name NULL setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) -#' @rdname decode +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("decode", function(x, charset) { standardGeneric("decode") }) #' @param x empty. Should be used with no argument. @@ -1047,8 +1052,9 @@ setGeneric("decode", function(x, charset) { standardGeneric("decode") }) #' @export setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) -#' @rdname encode +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("encode", function(x, charset) { standardGeneric("encode") }) #' @rdname explode @@ -1068,12 +1074,14 @@ setGeneric("expr", function(x) { standardGeneric("expr") }) #' @name NULL setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) -#' @rdname format_number +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) -#' @rdname format_string +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("format_string", function(format, x, ...) { standardGeneric("format_string") }) #' @rdname from_json @@ -1114,8 +1122,9 @@ setGeneric("hour", function(x) { standardGeneric("hour") }) #' @name NULL setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) -#' @rdname initcap +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("initcap", function(x) { standardGeneric("initcap") }) #' @param x empty. Should be used with no argument. @@ -1124,8 +1133,9 @@ setGeneric("initcap", function(x) { standardGeneric("initcap") }) setGeneric("input_file_name", function(x = "missing") { standardGeneric("input_file_name") }) -#' @rdname instr +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("instr", function(y, x) { standardGeneric("instr") }) #' @rdname is.nan @@ -1158,28 +1168,33 @@ setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("l #' @export setGeneric("least", function(x, ...) { standardGeneric("least") }) -#' @rdname levenshtein +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) #' @rdname lit #' @export setGeneric("lit", function(x) { standardGeneric("lit") }) -#' @rdname locate +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("locate", function(substr, str, ...) { standardGeneric("locate") }) -#' @rdname lower +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("lower", function(x) { standardGeneric("lower") }) -#' @rdname lpad +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) -#' @rdname ltrim +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("ltrim", function(x) { standardGeneric("ltrim") }) #' @rdname md5 @@ -1272,21 +1287,25 @@ setGeneric("randn", function(seed) { standardGeneric("randn") }) #' @export setGeneric("rank", function(x, ...) { standardGeneric("rank") }) -#' @rdname regexp_extract +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp_extract") }) -#' @rdname regexp_replace +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("regexp_replace", function(x, pattern, replacement) { standardGeneric("regexp_replace") }) -#' @rdname repeat_string +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) -#' @rdname reverse +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("reverse", function(x) { standardGeneric("reverse") }) #' @rdname column_math_functions @@ -1299,12 +1318,14 @@ setGeneric("rint", function(x) { standardGeneric("rint") }) #' @export setGeneric("row_number", function(x = "missing") { standardGeneric("row_number") }) -#' @rdname rpad +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) -#' @rdname rtrim +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) #' @rdname column_aggregate_functions @@ -1358,12 +1379,14 @@ setGeneric("skewness", function(x) { standardGeneric("skewness") }) #' @export setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) -#' @rdname split_string +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("split_string", function(x, pattern) { standardGeneric("split_string") }) -#' @rdname soundex +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("soundex", function(x) { standardGeneric("soundex") }) #' @param x empty. Should be used with no argument. @@ -1390,8 +1413,9 @@ setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) #' @export setGeneric("struct", function(x, ...) { standardGeneric("struct") }) -#' @rdname substring_index +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) #' @rdname column_aggregate_functions @@ -1428,16 +1452,19 @@ setGeneric("to_timestamp", function(x, format) { standardGeneric("to_timestamp") #' @name NULL setGeneric("to_utc_timestamp", function(y, x) { standardGeneric("to_utc_timestamp") }) -#' @rdname translate +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("translate", function(x, matchingString, replaceString) { standardGeneric("translate") }) -#' @rdname trim +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("trim", function(x) { standardGeneric("trim") }) -#' @rdname unbase64 +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) #' @rdname column_math_functions @@ -1450,8 +1477,9 @@ setGeneric("unhex", function(x) { standardGeneric("unhex") }) #' @name NULL setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timestamp") }) -#' @rdname upper +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("upper", function(x) { standardGeneric("upper") }) #' @rdname column_aggregate_functions From 0c8444cf6d0620cd219ddcf5f50b12ff648639e9 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 29 Jun 2017 10:32:32 +0800 Subject: [PATCH 033/779] [SPARK-14657][SPARKR][ML] RFormula w/o intercept should output reference category when encoding string terms ## What changes were proposed in this pull request? Please see [SPARK-14657](https://issues.apache.org/jira/browse/SPARK-14657) for detail of this bug. I searched online and test some other cases, found when we fit R glm model(or other models powered by R formula) w/o intercept on a dataset including string/category features, one of the categories in the first category feature is being used as reference category, we will not drop any category for that feature. I think we should keep consistent semantics between Spark RFormula and R formula. ## How was this patch tested? Add standard unit tests. cc mengxr Author: Yanbo Liang Closes #12414 from yanboliang/spark-14657. --- .../apache/spark/ml/feature/RFormula.scala | 10 ++- .../spark/ml/feature/RFormulaSuite.scala | 83 +++++++++++++++++++ 2 files changed, 92 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 1fad0a6fc9443..4b44878784c90 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -205,12 +205,20 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) }.toMap // Then we handle one-hot encoding and interactions between terms. + var keepReferenceCategory = false val encodedTerms = resolvedFormula.terms.map { case Seq(term) if dataset.schema(term).dataType == StringType => val encodedCol = tmpColumn("onehot") - encoderStages += new OneHotEncoder() + var encoder = new OneHotEncoder() .setInputCol(indexed(term)) .setOutputCol(encodedCol) + // Formula w/o intercept, one of the categories in the first category feature is + // being used as reference category, we will not drop any category for that feature. + if (!hasIntercept && !keepReferenceCategory) { + encoder = encoder.setDropLast(false) + keepReferenceCategory = true + } + encoderStages += encoder prefixesToRewrite(encodedCol + "_") = term + "_" encodedCol case Seq(term) => diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 41d0062c2cabd..23570d6e0b4cb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -213,6 +213,89 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(result.collect() === expected.collect()) } + test("formula w/o intercept, we should output reference category when encoding string terms") { + /* + R code: + + df <- data.frame(id = c(1, 2, 3, 4), + a = c("foo", "bar", "bar", "baz"), + b = c("zq", "zz", "zz", "zz"), + c = c(4, 4, 5, 5)) + model.matrix(id ~ a + b + c - 1, df) + + abar abaz afoo bzz c + 1 0 0 1 0 4 + 2 1 0 0 1 4 + 3 1 0 0 1 5 + 4 0 1 0 1 5 + + model.matrix(id ~ a:b + c - 1, df) + + c abar:bzq abaz:bzq afoo:bzq abar:bzz abaz:bzz afoo:bzz + 1 4 0 0 1 0 0 0 + 2 4 0 0 0 1 0 0 + 3 5 0 0 0 1 0 0 + 4 5 0 0 0 0 1 0 + */ + val original = Seq((1, "foo", "zq", 4), (2, "bar", "zz", 4), (3, "bar", "zz", 5), + (4, "baz", "zz", 5)).toDF("id", "a", "b", "c") + + val formula1 = new RFormula().setFormula("id ~ a + b + c - 1") + .setStringIndexerOrderType(StringIndexer.alphabetDesc) + val model1 = formula1.fit(original) + val result1 = model1.transform(original) + val resultSchema1 = model1.transformSchema(original.schema) + // Note the column order is different between R and Spark. + val expected1 = Seq( + (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0), + (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0), + (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0), + (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0) + ).toDF("id", "a", "b", "c", "features", "label") + assert(result1.schema.toString == resultSchema1.toString) + assert(result1.collect() === expected1.collect()) + + val attrs1 = AttributeGroup.fromStructField(result1.schema("features")) + val expectedAttrs1 = new AttributeGroup( + "features", + Array[Attribute]( + new BinaryAttribute(Some("a_foo"), Some(1)), + new BinaryAttribute(Some("a_baz"), Some(2)), + new BinaryAttribute(Some("a_bar"), Some(3)), + new BinaryAttribute(Some("b_zz"), Some(4)), + new NumericAttribute(Some("c"), Some(5)))) + assert(attrs1 === expectedAttrs1) + + // There is no impact for string terms interaction. + val formula2 = new RFormula().setFormula("id ~ a:b + c - 1") + .setStringIndexerOrderType(StringIndexer.alphabetDesc) + val model2 = formula2.fit(original) + val result2 = model2.transform(original) + val resultSchema2 = model2.transformSchema(original.schema) + // Note the column order is different between R and Spark. + val expected2 = Seq( + (1, "foo", "zq", 4, Vectors.sparse(7, Array(1, 6), Array(1.0, 4.0)), 1.0), + (2, "bar", "zz", 4, Vectors.sparse(7, Array(4, 6), Array(1.0, 4.0)), 2.0), + (3, "bar", "zz", 5, Vectors.sparse(7, Array(4, 6), Array(1.0, 5.0)), 3.0), + (4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0) + ).toDF("id", "a", "b", "c", "features", "label") + assert(result2.schema.toString == resultSchema2.toString) + assert(result2.collect() === expected2.collect()) + + val attrs2 = AttributeGroup.fromStructField(result2.schema("features")) + val expectedAttrs2 = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_foo:b_zz"), Some(1)), + new NumericAttribute(Some("a_foo:b_zq"), Some(2)), + new NumericAttribute(Some("a_baz:b_zz"), Some(3)), + new NumericAttribute(Some("a_baz:b_zq"), Some(4)), + new NumericAttribute(Some("a_bar:b_zz"), Some(5)), + new NumericAttribute(Some("a_bar:b_zq"), Some(6)), + new NumericAttribute(Some("c"), Some(7)))) + assert(attrs2 === expectedAttrs2) + } + test("index string label") { val formula = new RFormula().setFormula("id ~ a + b") val original = From db44f5f3e8b5bc28c33b154319539d51c05a089c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 28 Jun 2017 19:36:00 -0700 Subject: [PATCH 034/779] [SPARK-21224][R] Specify a schema by using a DDL-formatted string when reading in R ## What changes were proposed in this pull request? This PR proposes to support a DDL-formetted string as schema as below: ```r mockLines <- c("{\"name\":\"Michael\"}", "{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"Justin\", \"age\":19}") jsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLines, jsonPath) df <- read.df(jsonPath, "json", "name STRING, age DOUBLE") collect(df) ``` ## How was this patch tested? Tests added in `test_streaming.R` and `test_sparkSQL.R` and manual tests. Author: hyukjinkwon Closes #18431 from HyukjinKwon/r-ddl-schema. --- R/pkg/R/SQLContext.R | 38 +++++++++++++------ R/pkg/tests/fulltests/test_sparkSQL.R | 20 +++++++++- R/pkg/tests/fulltests/test_streaming.R | 23 +++++++++++ .../org/apache/spark/sql/api/r/SQLUtils.scala | 15 -------- 4 files changed, 67 insertions(+), 29 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index e3528bc7c3135..3b7f71bbbffb8 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -584,7 +584,7 @@ tableToDF <- function(tableName) { #' #' @param path The path of files to load #' @param source The name of external data source -#' @param schema The data schema defined in structType +#' @param schema The data schema defined in structType or a DDL-formatted string. #' @param na.strings Default string value for NA when source is "csv" #' @param ... additional external data source specific named properties. #' @return SparkDataFrame @@ -600,6 +600,8 @@ tableToDF <- function(tableName) { #' structField("info", "map")) #' df2 <- read.df(mapTypeJsonPath, "json", schema, multiLine = TRUE) #' df3 <- loadDF("data/test_table", "parquet", mergeSchema = "true") +#' stringSchema <- "name STRING, info MAP" +#' df4 <- read.df(mapTypeJsonPath, "json", stringSchema, multiLine = TRUE) #' } #' @name read.df #' @method read.df default @@ -623,14 +625,19 @@ read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.string if (source == "csv" && is.null(options[["nullValue"]])) { options[["nullValue"]] <- na.strings } + read <- callJMethod(sparkSession, "read") + read <- callJMethod(read, "format", source) if (!is.null(schema)) { - stopifnot(class(schema) == "structType") - sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, - source, schema$jobj, options) - } else { - sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, - source, options) + if (class(schema) == "structType") { + read <- callJMethod(read, "schema", schema$jobj) + } else if (is.character(schema)) { + read <- callJMethod(read, "schema", schema) + } else { + stop("schema should be structType or character.") + } } + read <- callJMethod(read, "options", options) + sdf <- handledCallJMethod(read, "load") dataFrame(sdf) } @@ -717,8 +724,8 @@ read.jdbc <- function(url, tableName, #' "spark.sql.sources.default" will be used. #' #' @param source The name of external data source -#' @param schema The data schema defined in structType, this is required for file-based streaming -#' data source +#' @param schema The data schema defined in structType or a DDL-formatted string, this is +#' required for file-based streaming data source #' @param ... additional external data source specific named options, for instance \code{path} for #' file-based streaming data source #' @return SparkDataFrame @@ -733,6 +740,8 @@ read.jdbc <- function(url, tableName, #' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") #' #' df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) +#' stringSchema <- "name STRING, info MAP" +#' df1 <- read.stream("json", path = jsonDir, schema = stringSchema, maxFilesPerTrigger = 1) #' } #' @name read.stream #' @note read.stream since 2.2.0 @@ -750,10 +759,15 @@ read.stream <- function(source = NULL, schema = NULL, ...) { read <- callJMethod(sparkSession, "readStream") read <- callJMethod(read, "format", source) if (!is.null(schema)) { - stopifnot(class(schema) == "structType") - read <- callJMethod(read, "schema", schema$jobj) + if (class(schema) == "structType") { + read <- callJMethod(read, "schema", schema$jobj) + } else if (is.character(schema)) { + read <- callJMethod(read, "schema", schema) + } else { + stop("schema should be structType or character.") + } } read <- callJMethod(read, "options", options) sdf <- handledCallJMethod(read, "load") - dataFrame(callJMethod(sdf, "toDF")) + dataFrame(sdf) } diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 911b73b9ee551..a2bcb5aefe16d 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -3248,9 +3248,9 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume # It makes sure that we can omit path argument in read.df API and then it calls # DataFrameWriter.load() without path. expect_error(read.df(source = "json"), - paste("Error in loadDF : analysis error - Unable to infer schema for JSON.", + paste("Error in load : analysis error - Unable to infer schema for JSON.", "It must be specified manually")) - expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist") + expect_error(read.df("arbitrary_path"), "Error in load : analysis error - Path does not exist") expect_error(read.json("arbitrary_path"), "Error in json : analysis error - Path does not exist") expect_error(read.text("arbitrary_path"), "Error in text : analysis error - Path does not exist") expect_error(read.orc("arbitrary_path"), "Error in orc : analysis error - Path does not exist") @@ -3268,6 +3268,22 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume "Unnamed arguments ignored: 2, 3, a.") }) +test_that("Specify a schema by using a DDL-formatted string when reading", { + # Test read.df with a user defined schema in a DDL-formatted string. + df1 <- read.df(jsonPath, "json", "name STRING, age DOUBLE") + expect_is(df1, "SparkDataFrame") + expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) + + expect_error(read.df(jsonPath, "json", "name stri"), "DataType stri is not supported.") + + # Test loadDF with a user defined schema in a DDL-formatted string. + df2 <- loadDF(jsonPath, "json", "name STRING, age DOUBLE") + expect_is(df2, "SparkDataFrame") + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) + + expect_error(loadDF(jsonPath, "json", "name stri"), "DataType stri is not supported.") +}) + test_that("Collect on DataFrame when NAs exists at the top of a timestamp column", { ldf <- data.frame(col1 = c(0, 1, 2), col2 = c(as.POSIXct("2017-01-01 00:00:01"), diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index d691de7cd725d..54f40bbd5f517 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -46,6 +46,8 @@ schema <- structType(structField("name", "string"), structField("age", "integer"), structField("count", "double")) +stringSchema <- "name STRING, age INTEGER, count DOUBLE" + test_that("read.stream, write.stream, awaitTermination, stopQuery", { df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) expect_true(isStreaming(df)) @@ -111,6 +113,27 @@ test_that("Stream other format", { unlink(parquetPath) }) +test_that("Specify a schema by using a DDL-formatted string when reading", { + # Test read.stream with a user defined schema in a DDL-formatted string. + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + df <- read.df(jsonPath, "json", schema) + write.df(df, parquetPath, "parquet", "overwrite") + + df <- read.stream(path = parquetPath, schema = stringSchema) + expect_true(isStreaming(df)) + counts <- count(group_by(df, "name")) + q <- write.stream(counts, "memory", queryName = "people3", outputMode = "complete") + + expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") + expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3) + + expect_error(read.stream(path = parquetPath, schema = "name stri"), + "DataType stri is not supported.") + + unlink(parquetPath) +}) + test_that("Non-streaming DataFrame", { c <- as.DataFrame(cars) expect_false(isStreaming(c)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index d94e528a3ad47..9bd2987057dbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -193,21 +193,6 @@ private[sql] object SQLUtils extends Logging { } } - def loadDF( - sparkSession: SparkSession, - source: String, - options: java.util.Map[String, String]): DataFrame = { - sparkSession.read.format(source).options(options).load() - } - - def loadDF( - sparkSession: SparkSession, - source: String, - schema: StructType, - options: java.util.Map[String, String]): DataFrame = { - sparkSession.read.format(source).schema(schema).options(options).load() - } - def readSqlObject(dis: DataInputStream, dataType: Char): Object = { dataType match { case 's' => From fc92d25f2a27e81ef2d5031dcf856af1cc1d8c31 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Wed, 28 Jun 2017 20:06:29 -0700 Subject: [PATCH 035/779] Revert "[SPARK-21094][R] Terminate R's worker processes in the parent of R's daemon to prevent a leak" This reverts commit 6b3d02285ee0debc73cbcab01b10398a498fbeb8. --- R/pkg/inst/worker/daemon.R | 59 +++----------------------------------- 1 file changed, 4 insertions(+), 55 deletions(-) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index 6e385b2a27622..3a318b71ea06d 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -30,55 +30,8 @@ port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout) -# Waits indefinitely for a socket connecion by default. -selectTimeout <- NULL - -# Exit code that children send to the parent to indicate they exited. -exitCode <- 1 - while (TRUE) { - ready <- socketSelect(list(inputCon), timeout = selectTimeout) - - # Note that the children should be terminated in the parent. If each child terminates - # itself, it appears that the resource is not released properly, that causes an unexpected - # termination of this daemon due to, for example, running out of file descriptors - # (see SPARK-21093). Therefore, the current implementation tries to retrieve children - # that are exited (but not terminated) and then sends a kill signal to terminate them properly - # in the parent. - # - # There are two paths that it attempts to send a signal to terminate the children in the parent. - # - # 1. Every second if any socket connection is not available and if there are child workers - # running. - # 2. Right after a socket connection is available. - # - # In other words, the parent attempts to send the signal to the children every second if - # any worker is running or right before launching other worker children from the following - # new socket connection. - - # Only the process IDs of children sent data to the parent are returned below. The children - # send a custom exit code to the parent after being exited and the parent tries - # to terminate them only if they sent the exit code. - children <- parallel:::selectChildren(timeout = 0) - - if (is.integer(children)) { - lapply(children, function(child) { - # This data should be raw bytes if any data was sent from this child. - # Otherwise, this returns the PID. - data <- parallel:::readChild(child) - if (is.raw(data)) { - # This checks if the data from this child is the exit code that indicates an exited child. - if (unserialize(data) == exitCode) { - # If so, we terminate this child. - tools::pskill(child, tools::SIGUSR1) - } - } - }) - } else if (is.null(children)) { - # If it is NULL, there are no children. Waits indefinitely for a socket connecion. - selectTimeout <- NULL - } - + ready <- socketSelect(list(inputCon)) if (ready) { port <- SparkR:::readInt(inputCon) # There is a small chance that it could be interrupted by signal, retry one time @@ -91,16 +44,12 @@ while (TRUE) { } p <- parallel:::mcfork() if (inherits(p, "masterProcess")) { - # Reach here because this is a child process. close(inputCon) Sys.setenv(SPARKR_WORKER_PORT = port) try(source(script)) - # Note that this mcexit does not fully terminate this child. So, this writes back - # a custom exit code so that the parent can read and terminate this child. - parallel:::mcexit(0L, send = exitCode) - } else { - # Forking succeeded and we need to check if they finished their jobs every second. - selectTimeout <- 1 + # Set SIGUSR1 so that child can exit + tools::pskill(Sys.getpid(), tools::SIGUSR1) + parallel:::mcexit(0L) } } } From 25c2edf6f9da9d4d45fc628cf97de657f2a2cc7e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 29 Jun 2017 11:21:50 +0800 Subject: [PATCH 036/779] [SPARK-21229][SQL] remove QueryPlan.preCanonicalized ## What changes were proposed in this pull request? `QueryPlan.preCanonicalized` is only overridden in a few places, and it does introduce an extra concept to `QueryPlan` which may confuse people. This PR removes it and override `canonicalized` in these places ## How was this patch tested? existing tests Author: Wenchen Fan Closes #18440 from cloud-fan/minor. --- .../sql/catalyst/catalog/interface.scala | 23 +++++++++++-------- .../spark/sql/catalyst/plans/QueryPlan.scala | 13 ++++------- .../sql/execution/DataSourceScanExec.scala | 8 +++++-- .../datasources/LogicalRelation.scala | 5 +++- 4 files changed, 27 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index b63bef9193332..da50b0e7e8e42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -27,7 +27,8 @@ import com.google.common.base.Objects import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, Literal} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, ExprId, Literal} +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier @@ -425,15 +426,17 @@ case class CatalogRelation( Objects.hashCode(tableMeta.identifier, output) } - override def preCanonicalized: LogicalPlan = copy(tableMeta = CatalogTable( - identifier = tableMeta.identifier, - tableType = tableMeta.tableType, - storage = CatalogStorageFormat.empty, - schema = tableMeta.schema, - partitionColumnNames = tableMeta.partitionColumnNames, - bucketSpec = tableMeta.bucketSpec, - createTime = -1 - )) + override lazy val canonicalized: LogicalPlan = copy( + tableMeta = tableMeta.copy( + storage = CatalogStorageFormat.empty, + createTime = -1 + ), + dataCols = dataCols.zipWithIndex.map { + case (attr, index) => attr.withExprId(ExprId(index)) + }, + partitionCols = partitionCols.zipWithIndex.map { + case (attr, index) => attr.withExprId(ExprId(index + dataCols.length)) + }) override def computeStats: Statistics = { // For data source tables, we will create a `LogicalRelation` and won't call this method, for diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 01b3da3f7c482..7addbaaa9afa5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -188,12 +188,13 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * Plans where `this.canonicalized == other.canonicalized` will always evaluate to the same * result. * - * Some nodes should overwrite this to provide proper canonicalize logic. + * Some nodes should overwrite this to provide proper canonicalize logic, but they should remove + * expressions cosmetic variations themselves. */ lazy val canonicalized: PlanType = { val canonicalizedChildren = children.map(_.canonicalized) var id = -1 - preCanonicalized.mapExpressions { + mapExpressions { case a: Alias => id += 1 // As the root of the expression, Alias will always take an arbitrary exprId, we need to @@ -206,18 +207,12 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT // Top level `AttributeReference` may also be used for output like `Alias`, we should // normalize the epxrId too. id += 1 - ar.withExprId(ExprId(id)) + ar.withExprId(ExprId(id)).canonicalized case other => QueryPlan.normalizeExprId(other, allAttributes) }.withNewChildren(canonicalizedChildren) } - /** - * Do some simple transformation on this plan before canonicalizing. Implementations can override - * this method to provide customized canonicalize logic without rewriting the whole logic. - */ - protected def preCanonicalized: PlanType = this - /** * Returns true when the given query plan will return the same results as this query plan. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 74fc23a52a141..a0def68d88e0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -138,8 +138,12 @@ case class RowDataSourceScanExec( } // Only care about `relation` and `metadata` when canonicalizing. - override def preCanonicalized: SparkPlan = - copy(rdd = null, outputPartitioning = null, metastoreTableIdentifier = None) + override lazy val canonicalized: SparkPlan = + copy( + output.map(QueryPlan.normalizeExprId(_, output)), + rdd = null, + outputPartitioning = null, + metastoreTableIdentifier = None) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index c1b2895f1747e..6ba190b9e5dcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.util.Utils @@ -43,7 +44,9 @@ case class LogicalRelation( } // Only care about relation when canonicalizing. - override def preCanonicalized: LogicalPlan = copy(catalogTable = None) + override lazy val canonicalized: LogicalPlan = copy( + output = output.map(QueryPlan.normalizeExprId(_, output)), + catalogTable = None) @transient override def computeStats: Statistics = { catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse( From 82e24912d6e15a9e4fbadd83da9a08d4f80a592b Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Thu, 29 Jun 2017 11:32:29 +0800 Subject: [PATCH 037/779] [SPARK-21237][SQL] Invalidate stats once table data is changed ## What changes were proposed in this pull request? Invalidate spark's stats after data changing commands: - InsertIntoHadoopFsRelationCommand - InsertIntoHiveTable - LoadDataCommand - TruncateTableCommand - AlterTableSetLocationCommand - AlterTableDropPartitionCommand ## How was this patch tested? Added test cases. Author: wangzhenhua Closes #18449 from wzhfy/removeStats. --- .../catalyst/catalog/ExternalCatalog.scala | 3 +- .../catalyst/catalog/InMemoryCatalog.scala | 4 +- .../sql/catalyst/catalog/SessionCatalog.scala | 2 +- .../catalog/ExternalCatalogSuite.scala | 2 +- .../catalog/SessionCatalogSuite.scala | 2 +- .../command/AnalyzeColumnCommand.scala | 4 +- .../command/AnalyzeTableCommand.scala | 76 +--------- .../sql/execution/command/CommandUtils.scala | 102 ++++++++++++++ .../spark/sql/execution/command/ddl.scala | 9 +- .../spark/sql/execution/command/tables.scala | 7 + .../InsertIntoHadoopFsRelationCommand.scala | 5 + .../spark/sql/StatisticsCollectionSuite.scala | 85 ++++++++++-- .../apache/spark/sql/test/SQLTestUtils.scala | 14 ++ .../spark/sql/hive/HiveExternalCatalog.scala | 24 ++-- .../hive/execution/InsertIntoHiveTable.scala | 4 +- .../spark/sql/hive/StatisticsSuite.scala | 130 ++++++++++++++---- 16 files changed, 340 insertions(+), 133 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 12ba5aedde026..0254b6bb6d136 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -160,7 +160,8 @@ abstract class ExternalCatalog */ def alterTableSchema(db: String, table: String, schema: StructType): Unit - def alterTableStats(db: String, table: String, stats: CatalogStatistics): Unit + /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. */ + def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit def getTable(db: String, table: String): CatalogTable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 9820522a230e3..747190faa3c8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -315,10 +315,10 @@ class InMemoryCatalog( override def alterTableStats( db: String, table: String, - stats: CatalogStatistics): Unit = synchronized { + stats: Option[CatalogStatistics]): Unit = synchronized { requireTableExists(db, table) val origTable = catalog(db).tables(table).table - catalog(db).tables(table).table = origTable.copy(stats = Some(stats)) + catalog(db).tables(table).table = origTable.copy(stats = stats) } override def getTable(db: String, table: String): CatalogTable = synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index cf02da8993658..7ece77df7fc14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -380,7 +380,7 @@ class SessionCatalog( * Alter Spark's statistics of an existing metastore table identified by the provided table * identifier. */ - def alterTableStats(identifier: TableIdentifier, newStats: CatalogStatistics): Unit = { + def alterTableStats(identifier: TableIdentifier, newStats: Option[CatalogStatistics]): Unit = { val db = formatDatabaseName(identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(identifier.table) val tableIdentifier = TableIdentifier(table, Some(db)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 557b0970b54e5..c22d55fc96a65 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -260,7 +260,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac val oldTableStats = catalog.getTable("db2", "tbl1").stats assert(oldTableStats.isEmpty) val newStats = CatalogStatistics(sizeInBytes = 1) - catalog.alterTableStats("db2", "tbl1", newStats) + catalog.alterTableStats("db2", "tbl1", Some(newStats)) val newTableStats = catalog.getTable("db2", "tbl1").stats assert(newTableStats.get == newStats) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index a6dc21b03d446..fc3893e197792 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -454,7 +454,7 @@ abstract class SessionCatalogSuite extends AnalysisTest { val oldTableStats = catalog.getTableMetadata(tableId).stats assert(oldTableStats.isEmpty) val newStats = CatalogStatistics(sizeInBytes = 1) - catalog.alterTableStats(tableId, newStats) + catalog.alterTableStats(tableId, Some(newStats)) val newTableStats = catalog.getTableMetadata(tableId).stats assert(newTableStats.get == newStats) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 2f273b63e8348..6588993ef9ad9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -42,7 +42,7 @@ case class AnalyzeColumnCommand( if (tableMeta.tableType == CatalogTableType.VIEW) { throw new AnalysisException("ANALYZE TABLE is not supported on views.") } - val sizeInBytes = AnalyzeTableCommand.calculateTotalSize(sessionState, tableMeta) + val sizeInBytes = CommandUtils.calculateTotalSize(sessionState, tableMeta) // Compute stats for each column val (rowCount, newColStats) = computeColumnStats(sparkSession, tableIdentWithDB, columnNames) @@ -54,7 +54,7 @@ case class AnalyzeColumnCommand( // Newly computed column stats should override the existing ones. colStats = tableMeta.stats.map(_.colStats).getOrElse(Map.empty) ++ newColStats) - sessionState.catalog.alterTableStats(tableIdentWithDB, statistics) + sessionState.catalog.alterTableStats(tableIdentWithDB, Some(statistics)) // Refresh the cached data source table in the catalog. sessionState.catalog.refreshTable(tableIdentWithDB) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 13b8faff844c7..d780ef42f3fae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -17,18 +17,10 @@ package org.apache.spark.sql.execution.command -import java.net.URI - -import scala.util.control.NonFatal - -import org.apache.hadoop.fs.{FileSystem, Path} - -import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableType} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.internal.SessionState /** @@ -46,7 +38,7 @@ case class AnalyzeTableCommand( if (tableMeta.tableType == CatalogTableType.VIEW) { throw new AnalysisException("ANALYZE TABLE is not supported on views.") } - val newTotalSize = AnalyzeTableCommand.calculateTotalSize(sessionState, tableMeta) + val newTotalSize = CommandUtils.calculateTotalSize(sessionState, tableMeta) val oldTotalSize = tableMeta.stats.map(_.sizeInBytes.toLong).getOrElse(0L) val oldRowCount = tableMeta.stats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L) @@ -74,7 +66,7 @@ case class AnalyzeTableCommand( // Update the metastore if the above statistics of the table are different from those // recorded in the metastore. if (newStats.isDefined) { - sessionState.catalog.alterTableStats(tableIdentWithDB, newStats.get) + sessionState.catalog.alterTableStats(tableIdentWithDB, newStats) // Refresh the cached data source table in the catalog. sessionState.catalog.refreshTable(tableIdentWithDB) } @@ -82,65 +74,3 @@ case class AnalyzeTableCommand( Seq.empty[Row] } } - -object AnalyzeTableCommand extends Logging { - - def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): Long = { - if (catalogTable.partitionColumnNames.isEmpty) { - calculateLocationSize(sessionState, catalogTable.identifier, catalogTable.storage.locationUri) - } else { - // Calculate table size as a sum of the visible partitions. See SPARK-21079 - val partitions = sessionState.catalog.listPartitions(catalogTable.identifier) - partitions.map(p => - calculateLocationSize(sessionState, catalogTable.identifier, p.storage.locationUri) - ).sum - } - } - - private def calculateLocationSize( - sessionState: SessionState, - tableId: TableIdentifier, - locationUri: Option[URI]): Long = { - // This method is mainly based on - // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) - // in Hive 0.13 (except that we do not use fs.getContentSummary). - // TODO: Generalize statistics collection. - // TODO: Why fs.getContentSummary returns wrong size on Jenkins? - // Can we use fs.getContentSummary in future? - // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use - // countFileSize to count the table size. - val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") - - def calculateLocationSize(fs: FileSystem, path: Path): Long = { - val fileStatus = fs.getFileStatus(path) - val size = if (fileStatus.isDirectory) { - fs.listStatus(path) - .map { status => - if (!status.getPath.getName.startsWith(stagingDir)) { - calculateLocationSize(fs, status.getPath) - } else { - 0L - } - }.sum - } else { - fileStatus.getLen - } - - size - } - - locationUri.map { p => - val path = new Path(p) - try { - val fs = path.getFileSystem(sessionState.newHadoopConf()) - calculateLocationSize(fs, path) - } catch { - case NonFatal(e) => - logWarning( - s"Failed to get the size of table ${tableId.table} in the " + - s"database ${tableId.database} because of ${e.toString}", e) - 0L - } - }.getOrElse(0L) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala new file mode 100644 index 0000000000000..92397607f38fd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -0,0 +1,102 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.command + +import java.net.URI + +import scala.util.control.NonFatal + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable} +import org.apache.spark.sql.internal.SessionState + + +object CommandUtils extends Logging { + + /** Change statistics after changing data by commands. */ + def updateTableStats(sparkSession: SparkSession, table: CatalogTable): Unit = { + if (table.stats.nonEmpty) { + val catalog = sparkSession.sessionState.catalog + catalog.alterTableStats(table.identifier, None) + } + } + + def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): BigInt = { + if (catalogTable.partitionColumnNames.isEmpty) { + calculateLocationSize(sessionState, catalogTable.identifier, catalogTable.storage.locationUri) + } else { + // Calculate table size as a sum of the visible partitions. See SPARK-21079 + val partitions = sessionState.catalog.listPartitions(catalogTable.identifier) + partitions.map { p => + calculateLocationSize(sessionState, catalogTable.identifier, p.storage.locationUri) + }.sum + } + } + + def calculateLocationSize( + sessionState: SessionState, + identifier: TableIdentifier, + locationUri: Option[URI]): Long = { + // This method is mainly based on + // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) + // in Hive 0.13 (except that we do not use fs.getContentSummary). + // TODO: Generalize statistics collection. + // TODO: Why fs.getContentSummary returns wrong size on Jenkins? + // Can we use fs.getContentSummary in future? + // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use + // countFileSize to count the table size. + val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") + + def getPathSize(fs: FileSystem, path: Path): Long = { + val fileStatus = fs.getFileStatus(path) + val size = if (fileStatus.isDirectory) { + fs.listStatus(path) + .map { status => + if (!status.getPath.getName.startsWith(stagingDir)) { + getPathSize(fs, status.getPath) + } else { + 0L + } + }.sum + } else { + fileStatus.getLen + } + + size + } + + locationUri.map { p => + val path = new Path(p) + try { + val fs = path.getFileSystem(sessionState.newHadoopConf()) + getPathSize(fs, path) + } catch { + case NonFatal(e) => + logWarning( + s"Failed to get the size of table ${identifier.table} in the " + + s"database ${identifier.database} because of ${e.toString}", e) + 0L + } + }.getOrElse(0L) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 413f5f3ba539c..ac897c1b22d77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -433,9 +433,11 @@ case class AlterTableAddPartitionCommand( sparkSession.sessionState.conf.resolver) // inherit table storage format (possibly except for location) CatalogTablePartition(normalizedSpec, table.storage.copy( - locationUri = location.map(CatalogUtils.stringToURI(_)))) + locationUri = location.map(CatalogUtils.stringToURI))) } catalog.createPartitions(table.identifier, parts, ignoreIfExists = ifNotExists) + + CommandUtils.updateTableStats(sparkSession, table) Seq.empty[Row] } @@ -519,6 +521,9 @@ case class AlterTableDropPartitionCommand( catalog.dropPartitions( table.identifier, normalizedSpecs, ignoreIfNotExists = ifExists, purge = purge, retainData = retainData) + + CommandUtils.updateTableStats(sparkSession, table) + Seq.empty[Row] } @@ -768,6 +773,8 @@ case class AlterTableSetLocationCommand( // No partition spec is specified, so we set the location for the table itself catalog.alterTable(table.withNewStorage(locationUri = Some(locUri))) } + + CommandUtils.updateTableStats(sparkSession, table) Seq.empty[Row] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index b937a8a9f375b..8ded1060f7bf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -400,6 +400,7 @@ case class LoadDataCommand( // Refresh the metadata cache to ensure the data visible to the users catalog.refreshTable(targetTable.identifier) + CommandUtils.updateTableStats(sparkSession, targetTable) Seq.empty[Row] } } @@ -487,6 +488,12 @@ case class TruncateTableCommand( case NonFatal(e) => log.warn(s"Exception when attempting to uncache table $tableIdentWithDB", e) } + + if (table.stats.nonEmpty) { + // empty table after truncation + val newStats = CatalogStatistics(sizeInBytes = 0, rowCount = Some(0)) + catalog.alterTableStats(tableName, Some(newStats)) + } Seq.empty[Row] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 00aa1240886e4..ab26f2affbce5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -161,6 +161,11 @@ case class InsertIntoHadoopFsRelationCommand( fileIndex.foreach(_.refresh()) // refresh data cache if table is cached sparkSession.catalog.refreshByPath(outputPath.toString) + + if (catalogTable.nonEmpty) { + CommandUtils.updateTableStats(sparkSession, catalogTable.get) + } + } else { logInfo("Skipping insertion into a relation that already exists.") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 9824062f969b3..b031c52dad8b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -40,17 +40,6 @@ import org.apache.spark.sql.types._ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with SharedSQLContext { import testImplicits._ - private def checkTableStats(tableName: String, expectedRowCount: Option[Int]) - : Option[CatalogStatistics] = { - val df = spark.table(tableName) - val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation => - assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) - rel.catalogTable.get.stats - } - assert(stats.size == 1) - stats.head - } - test("estimates the size of a limit 0 on outer join") { withTempView("test") { Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") @@ -96,11 +85,11 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared // noscan won't count the number of rows sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") - checkTableStats(tableName, expectedRowCount = None) + checkTableStats(tableName, hasSizeInBytes = true, expectedRowCounts = None) // without noscan, we count the number of rows sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") - checkTableStats(tableName, expectedRowCount = Some(2)) + checkTableStats(tableName, hasSizeInBytes = true, expectedRowCounts = Some(2)) } } @@ -168,6 +157,60 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared assert(stats.simpleString == expectedString) } } + + test("change stats after truncate command") { + val table = "change_stats_truncate_table" + withTable(table) { + spark.range(100).select($"id", $"id" % 5 as "value").write.saveAsTable(table) + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS id, value") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(100)) + assert(fetched1.get.sizeInBytes > 0) + assert(fetched1.get.colStats.size == 2) + + // truncate table command + sql(s"TRUNCATE TABLE $table") + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched2.get.sizeInBytes == 0) + assert(fetched2.get.colStats.isEmpty) + } + } + + test("change stats after set location command") { + val table = "change_stats_set_location_table" + withTable(table) { + spark.range(100).select($"id", $"id" % 5 as "value").write.saveAsTable(table) + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS id, value") + val fetched1 = checkTableStats( + table, hasSizeInBytes = true, expectedRowCounts = Some(100)) + assert(fetched1.get.sizeInBytes > 0) + assert(fetched1.get.colStats.size == 2) + + // set location command + withTempDir { newLocation => + sql(s"ALTER TABLE $table SET LOCATION '${newLocation.toURI.toString}'") + checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + } + } + } + + test("change stats after insert command for datasource table") { + val table = "change_stats_insert_datasource_table" + withTable(table) { + sql(s"CREATE TABLE $table (i int, j string) USING PARQUET") + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched1.get.sizeInBytes == 0) + assert(fetched1.get.colStats.size == 2) + + // insert into command + sql(s"INSERT INTO TABLE $table SELECT 1, 'abc'") + checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + } + } + } @@ -219,6 +262,22 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils private val randomName = new Random(31) + def checkTableStats( + tableName: String, + hasSizeInBytes: Boolean, + expectedRowCounts: Option[Int]): Option[CatalogStatistics] = { + val stats = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).stats + if (hasSizeInBytes || expectedRowCounts.nonEmpty) { + assert(stats.isDefined) + assert(stats.get.sizeInBytes >= 0) + assert(stats.get.rowCount === expectedRowCounts) + } else { + assert(stats.isEmpty) + } + + stats + } + /** * Compute column stats for the given DataFrame and compare it with colStats. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index f6d47734d7e83..d74a7cce25ed6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -149,6 +149,7 @@ private[sql] trait SQLTestUtils .getExecutorInfos.map(_.numRunningTasks()).sum == 0) } } + /** * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` * returns. @@ -164,6 +165,19 @@ private[sql] trait SQLTestUtils } } + /** + * Creates the specified number of temporary directories, which is then passed to `f` and will be + * deleted after `f` returns. + */ + protected def withTempPaths(numPaths: Int)(f: Seq[File] => Unit): Unit = { + val files = Array.fill[File](numPaths)(Utils.createTempDir().getCanonicalFile) + try f(files) finally { + // wait for all tasks to finish before deleting files + waitForTasksToFinish() + files.foreach(Utils.deleteRecursively) + } + } + /** * Drops functions after calling `f`. A function is represented by (functionName, isTemporary). */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 6e7c475fa34c9..2a17849fa8a34 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -631,21 +631,23 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat override def alterTableStats( db: String, table: String, - stats: CatalogStatistics): Unit = withClient { + stats: Option[CatalogStatistics]): Unit = withClient { requireTableExists(db, table) val rawTable = getRawTable(db, table) // convert table statistics to properties so that we can persist them through hive client - var statsProperties: Map[String, String] = - Map(STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString()) - if (stats.rowCount.isDefined) { - statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() - } - val colNameTypeMap: Map[String, DataType] = - rawTable.schema.fields.map(f => (f.name, f.dataType)).toMap - stats.colStats.foreach { case (colName, colStat) => - colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => - statsProperties += (columnStatKeyPropName(colName, k) -> v) + val statsProperties = new mutable.HashMap[String, String]() + if (stats.isDefined) { + statsProperties += STATISTICS_TOTAL_SIZE -> stats.get.sizeInBytes.toString() + if (stats.get.rowCount.isDefined) { + statsProperties += STATISTICS_NUM_ROWS -> stats.get.rowCount.get.toString() + } + val colNameTypeMap: Map[String, DataType] = + rawTable.schema.fields.map(f => (f.name, f.dataType)).toMap + stats.get.colStats.foreach { case (colName, colStat) => + colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => + statsProperties += (columnStatKeyPropName(colName, k) -> v) + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 392b7cfaa8eff..223d375232393 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.execution.command.{CommandUtils, RunnableCommand} import org.apache.spark.sql.execution.datasources.FileFormatWriter import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} @@ -434,6 +434,8 @@ case class InsertIntoHiveTable( sparkSession.catalog.uncacheTable(table.identifier.quotedString) sparkSession.sessionState.catalog.refreshTable(table.identifier) + CommandUtils.updateTableStats(sparkSession, table) + // It would be nice to just return the childRdd unchanged so insert operations could be chained, // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 64deb3818d5d1..5fd266c2d033c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -30,10 +30,12 @@ import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.hive.HiveExternalCatalog._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ + class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { test("Hive serde tables should fallback to HDFS for size estimation") { @@ -219,23 +221,6 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } - private def checkTableStats( - tableName: String, - hasSizeInBytes: Boolean, - expectedRowCounts: Option[Int]): Option[CatalogStatistics] = { - val stats = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).stats - - if (hasSizeInBytes || expectedRowCounts.nonEmpty) { - assert(stats.isDefined) - assert(stats.get.sizeInBytes > 0) - assert(stats.get.rowCount === expectedRowCounts) - } else { - assert(stats.isEmpty) - } - - stats - } - test("test table-level statistics for hive tables created in HiveExternalCatalog") { val textTable = "textTable" withTable(textTable) { @@ -326,7 +311,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto descOutput: Seq[String], propKey: String): Option[BigInt] = { val str = descOutput - .filterNot(_.contains(HiveExternalCatalog.STATISTICS_PREFIX)) + .filterNot(_.contains(STATISTICS_PREFIX)) .filter(_.contains(propKey)) if (str.isEmpty) { None @@ -448,6 +433,103 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto "ALTER TABLE unset_prop_table UNSET TBLPROPERTIES ('prop1')") } + /** + * To see if stats exist, we need to check spark's stats properties instead of catalog + * statistics, because hive would change stats in metastore and thus change catalog statistics. + */ + private def getStatsProperties(tableName: String): Map[String, String] = { + val hTable = hiveClient.getTable(spark.sessionState.catalog.getCurrentDatabase, tableName) + hTable.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) + } + + test("change stats after insert command for hive table") { + val table = s"change_stats_insert_hive_table" + withTable(table) { + sql(s"CREATE TABLE $table (i int, j string)") + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched1.get.sizeInBytes == 0) + assert(fetched1.get.colStats.size == 2) + + // insert into command + sql(s"INSERT INTO TABLE $table SELECT 1, 'abc'") + assert(getStatsProperties(table).isEmpty) + } + } + + test("change stats after load data command") { + val table = "change_stats_load_table" + withTable(table) { + sql(s"CREATE TABLE $table (i INT, j STRING) STORED AS PARQUET") + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched1.get.sizeInBytes == 0) + assert(fetched1.get.colStats.size == 2) + + withTempDir { loadPath => + // load data command + val file = new File(loadPath + "/data") + val writer = new PrintWriter(file) + writer.write("2,xyz") + writer.close() + sql(s"LOAD DATA INPATH '${loadPath.toURI.toString}' INTO TABLE $table") + assert(getStatsProperties(table).isEmpty) + } + } + } + + test("change stats after add/drop partition command") { + val table = "change_stats_part_table" + withTable(table) { + sql(s"CREATE TABLE $table (i INT, j STRING) PARTITIONED BY (ds STRING, hr STRING)") + // table has two partitions initially + for (ds <- Seq("2008-04-08"); hr <- Seq("11", "12")) { + sql(s"INSERT OVERWRITE TABLE $table PARTITION (ds='$ds',hr='$hr') SELECT 1, 'a'") + } + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)) + assert(fetched1.get.sizeInBytes > 0) + assert(fetched1.get.colStats.size == 2) + + withTempPaths(numPaths = 2) { case Seq(dir1, dir2) => + val file1 = new File(dir1 + "/data") + val writer1 = new PrintWriter(file1) + writer1.write("1,a") + writer1.close() + + val file2 = new File(dir2 + "/data") + val writer2 = new PrintWriter(file2) + writer2.write("1,a") + writer2.close() + + // add partition command + sql( + s""" + |ALTER TABLE $table ADD + |PARTITION (ds='2008-04-09', hr='11') LOCATION '${dir1.toURI.toString}' + |PARTITION (ds='2008-04-09', hr='12') LOCATION '${dir2.toURI.toString}' + """.stripMargin) + assert(getStatsProperties(table).isEmpty) + + // generate stats again + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(4)) + assert(fetched2.get.sizeInBytes > 0) + assert(fetched2.get.colStats.size == 2) + + // drop partition command + sql(s"ALTER TABLE $table DROP PARTITION (ds='2008-04-08'), PARTITION (hr='12')") + // only one partition left + assert(spark.sessionState.catalog.listPartitions(TableIdentifier(table)) + .map(_.spec).toSet == Set(Map("ds" -> "2008-04-09", "hr" -> "11"))) + assert(getStatsProperties(table).isEmpty) + } + } + } + test("add/drop partitions - managed table") { val catalog = spark.sessionState.catalog val managedTable = "partitionedTable" @@ -483,23 +565,19 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto assert(catalog.listPartitions(TableIdentifier(managedTable)).map(_.spec).toSet == Set(Map("ds" -> "2008-04-09", "hr" -> "11"))) - val stats2 = checkTableStats( - managedTable, hasSizeInBytes = true, expectedRowCounts = Some(4)) - assert(stats1 == stats2) - sql(s"ANALYZE TABLE $managedTable COMPUTE STATISTICS") - val stats3 = checkTableStats( + val stats2 = checkTableStats( managedTable, hasSizeInBytes = true, expectedRowCounts = Some(1)) - assert(stats2.get.sizeInBytes > stats3.get.sizeInBytes) + assert(stats1.get.sizeInBytes > stats2.get.sizeInBytes) sql(s"ALTER TABLE $managedTable ADD PARTITION (ds='2008-04-08', hr='12')") sql(s"ANALYZE TABLE $managedTable COMPUTE STATISTICS") val stats4 = checkTableStats( managedTable, hasSizeInBytes = true, expectedRowCounts = Some(1)) - assert(stats2.get.sizeInBytes > stats4.get.sizeInBytes) - assert(stats4.get.sizeInBytes == stats3.get.sizeInBytes) + assert(stats1.get.sizeInBytes > stats4.get.sizeInBytes) + assert(stats4.get.sizeInBytes == stats2.get.sizeInBytes) } } From a946be35ac177737e99942ad42de6f319f186138 Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Thu, 29 Jun 2017 14:25:51 +0800 Subject: [PATCH 038/779] [SPARK-3577] Report Spill size on disk for UnsafeExternalSorter ## What changes were proposed in this pull request? Report Spill size on disk for UnsafeExternalSorter ## How was this patch tested? Tested by running a job on cluster and verify the spill size on disk. Author: Sital Kedia Closes #17471 from sitalkedia/fix_disk_spill_size. --- .../unsafe/sort/UnsafeExternalSorter.java | 9 +++---- .../sort/UnsafeExternalSorterSuite.java | 25 +++++++++++++++++++ 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index f312fa2b2ddd7..82d03e3e9190c 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -54,7 +54,6 @@ public final class UnsafeExternalSorter extends MemoryConsumer { private final BlockManager blockManager; private final SerializerManager serializerManager; private final TaskContext taskContext; - private ShuffleWriteMetrics writeMetrics; /** The buffer size to use when writing spills using DiskBlockObjectWriter */ private final int fileBufferSizeBytes; @@ -144,10 +143,6 @@ private UnsafeExternalSorter( // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024 this.fileBufferSizeBytes = 32 * 1024; - // The spill metrics are stored in a new ShuffleWriteMetrics, - // and then discarded (this fixes SPARK-16827). - // TODO: Instead, separate spill metrics should be stored and reported (tracked in SPARK-3577). - this.writeMetrics = new ShuffleWriteMetrics(); if (existingInMemorySorter == null) { this.inMemSorter = new UnsafeInMemorySorter( @@ -199,6 +194,7 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { spillWriters.size(), spillWriters.size() > 1 ? " times" : " time"); + ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); // We only write out contents of the inMemSorter if it is not empty. if (inMemSorter.numRecords() > 0) { final UnsafeSorterSpillWriter spillWriter = @@ -226,6 +222,7 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { // pages, we might not be able to get memory for the pointer array. taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten()); totalSpillBytes += spillSize; return spillSize; } @@ -502,6 +499,7 @@ public long spill() throws IOException { UnsafeInMemorySorter.SortedIterator inMemIterator = ((UnsafeInMemorySorter.SortedIterator) upstream).clone(); + ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); // Iterate over the records that have not been returned and spill them. final UnsafeSorterSpillWriter spillWriter = new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords); @@ -540,6 +538,7 @@ public long spill() throws IOException { inMemSorter.free(); inMemSorter = null; taskContext.taskMetrics().incMemoryBytesSpilled(released); + taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten()); totalSpillBytes += released; return released; } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 771d39016c188..d31d7c1c0900c 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -405,6 +405,31 @@ public void forcedSpillingWithoutComparator() throws Exception { assertSpillFilesWereCleanedUp(); } + @Test + public void testDiskSpilledBytes() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + long[] record = new long[100]; + int recordSize = record.length * 8; + int n = (int) pageSizeBytes / recordSize * 3; + for (int i = 0; i < n; i++) { + record[0] = (long) i; + sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false); + } + // We will have at-least 2 memory pages allocated because of rounding happening due to + // integer division of pageSizeBytes and recordSize. + assertTrue(sorter.getNumberOfAllocatedPages() >= 2); + assertTrue(taskContext.taskMetrics().diskBytesSpilled() == 0); + UnsafeExternalSorter.SpillableIterator iter = + (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator(); + assertTrue(iter.spill() > 0); + assertTrue(taskContext.taskMetrics().diskBytesSpilled() > 0); + assertEquals(0, iter.spill()); + // Even if we did not spill second time, the disk spilled bytes should still be non-zero + assertTrue(taskContext.taskMetrics().diskBytesSpilled() > 0); + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + @Test public void testPeakMemoryUsed() throws Exception { final long recordLengthBytes = 8; From 9f6b3e65ccfa0daec31b58c5a6386b3a890c2149 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 29 Jun 2017 14:37:42 +0800 Subject: [PATCH 039/779] [SPARK-21238][SQL] allow nested SQL execution ## What changes were proposed in this pull request? This is kind of another follow-up for https://github.com/apache/spark/pull/18064 . In #18064 , we wrap every SQL command with SQL execution, which makes nested SQL execution very likely to happen. #18419 trid to improve it a little bit, by introduing `SQLExecition.ignoreNestedExecutionId`. However, this is not friendly to data source developers, they may need to update their code to use this `ignoreNestedExecutionId` API. This PR proposes a new solution, to just allow nested execution. The downside is that, we may have multiple executions for one query. We can improve this by updating the data organization in SQLListener, to have 1-n mapping from query to execution, instead of 1-1 mapping. This can be done in a follow-up. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #18450 from cloud-fan/execution-id. --- .../spark/sql/execution/SQLExecution.scala | 88 ++++--------------- .../command/AnalyzeTableCommand.scala | 4 +- .../spark/sql/execution/command/cache.scala | 16 ++-- .../datasources/csv/CSVDataSource.scala | 4 +- .../datasources/jdbc/JDBCRelation.scala | 8 +- .../sql/execution/streaming/console.scala | 12 +-- .../sql/execution/streaming/memory.scala | 32 ++++--- .../sql/execution/SQLExecutionSuite.scala | 24 ----- 8 files changed, 50 insertions(+), 138 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index ca8bed5214f87..e991da7df0bde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -22,15 +22,12 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, - SparkListenerSQLExecutionStart} +import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart} object SQLExecution { val EXECUTION_ID_KEY = "spark.sql.execution.id" - private val IGNORE_NESTED_EXECUTION_ID = "spark.sql.execution.ignoreNestedExecutionId" - private val _nextExecutionId = new AtomicLong(0) private def nextExecutionId: Long = _nextExecutionId.getAndIncrement @@ -45,10 +42,8 @@ object SQLExecution { private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = { val sc = sparkSession.sparkContext - val isNestedExecution = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null - val hasExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) != null // only throw an exception during tests. a missing execution ID should not fail a job. - if (testing && !isNestedExecution && !hasExecutionId) { + if (testing && sc.getLocalProperty(EXECUTION_ID_KEY) == null) { // Attention testers: when a test fails with this exception, it means that the action that // started execution of a query didn't call withNewExecutionId. The execution ID should be // set by calling withNewExecutionId in the action that begins execution, like @@ -66,56 +61,27 @@ object SQLExecution { queryExecution: QueryExecution)(body: => T): T = { val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) - if (oldExecutionId == null) { - val executionId = SQLExecution.nextExecutionId - sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) - executionIdToQueryExecution.put(executionId, queryExecution) - try { - // sparkContext.getCallSite() would first try to pick up any call site that was previously - // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on - // streaming queries would give us call site like "run at :0" - val callSite = sparkSession.sparkContext.getCallSite() - - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) - try { - body - } finally { - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) - } - } finally { - executionIdToQueryExecution.remove(executionId) - sc.setLocalProperty(EXECUTION_ID_KEY, null) - } - } else if (sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null) { - // If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore the execution id while evaluating the - // `body`, so that Spark jobs issued in the `body` won't be tracked. + val executionId = SQLExecution.nextExecutionId + sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) + executionIdToQueryExecution.put(executionId, queryExecution) + try { + // sparkContext.getCallSite() would first try to pick up any call site that was previously + // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on + // streaming queries would give us call site like "run at :0" + val callSite = sparkSession.sparkContext.getCallSite() + + sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) try { - sc.setLocalProperty(EXECUTION_ID_KEY, null) body } finally { - sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId) + sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) } - } else { - // Don't support nested `withNewExecutionId`. This is an example of the nested - // `withNewExecutionId`: - // - // class DataFrame { - // def foo: T = withNewExecutionId { something.createNewDataFrame().collect() } - // } - // - // Note: `collect` will call withNewExecutionId - // In this case, only the "executedPlan" for "collect" will be executed. The "executedPlan" - // for the outer DataFrame won't be executed. So it's meaningless to create a new Execution - // for the outer DataFrame. Even if we track it, since its "executedPlan" doesn't run, - // all accumulator metrics will be 0. It will confuse people if we show them in Web UI. - // - // A real case is the `DataFrame.count` method. - throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set, please wrap your " + - "action with SQLExecution.ignoreNestedExecutionId if you don't want to track the Spark " + - "jobs issued by the nested execution.") + } finally { + executionIdToQueryExecution.remove(executionId) + sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId) } } @@ -133,20 +99,4 @@ object SQLExecution { sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) } } - - /** - * Wrap an action which may have nested execution id. This method can be used to run an execution - * inside another execution, e.g., `CacheTableCommand` need to call `Dataset.collect`. Note that, - * all Spark jobs issued in the body won't be tracked in UI. - */ - def ignoreNestedExecutionId[T](sparkSession: SparkSession)(body: => T): T = { - val sc = sparkSession.sparkContext - val allowNestedPreviousValue = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) - try { - sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, "true") - body - } finally { - sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, allowNestedPreviousValue) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index d780ef42f3fae..42e2a9ca5c4e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -51,9 +51,7 @@ case class AnalyzeTableCommand( // 2. when total size is changed, `oldRowCount` becomes invalid. // This is to make sure that we only record the right statistics. if (!noscan) { - val newRowCount = SQLExecution.ignoreNestedExecutionId(sparkSession) { - sparkSession.table(tableIdentWithDB).count() - } + val newRowCount = sparkSession.table(tableIdentWithDB).count() if (newRowCount >= 0 && newRowCount != oldRowCount) { newStats = if (newStats.isDefined) { newStats.map(_.copy(rowCount = Some(BigInt(newRowCount)))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index d36eb7587a3ef..47952f2f227a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -34,16 +34,14 @@ case class CacheTableCommand( override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq override def run(sparkSession: SparkSession): Seq[Row] = { - SQLExecution.ignoreNestedExecutionId(sparkSession) { - plan.foreach { logicalPlan => - Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString) - } - sparkSession.catalog.cacheTable(tableIdent.quotedString) + plan.foreach { logicalPlan => + Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString) + } + sparkSession.catalog.cacheTable(tableIdent.quotedString) - if (!isLazy) { - // Performs eager caching - sparkSession.table(tableIdent).count() - } + if (!isLazy) { + // Performs eager caching + sparkSession.table(tableIdent).count() } Seq.empty[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 99133bd70989a..2031381dd2e10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -145,9 +145,7 @@ object TextInputCSVDataSource extends CSVDataSource { inputPaths: Seq[FileStatus], parsedOptions: CSVOptions): StructType = { val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions) - val maybeFirstLine = SQLExecution.ignoreNestedExecutionId(sparkSession) { - CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption - } + val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index b11da7045de22..a521fd1323852 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -130,11 +130,9 @@ private[sql] case class JDBCRelation( } override def insert(data: DataFrame, overwrite: Boolean): Unit = { - SQLExecution.ignoreNestedExecutionId(data.sparkSession) { - data.write - .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) - .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties) - } + data.write + .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) + .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties) } override def toString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 6fa7c113defaa..3baea6376069f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -48,11 +48,9 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging { println(batchIdStr) println("-------------------------------------------") // scalastyle:off println - SQLExecution.ignoreNestedExecutionId(data.sparkSession) { - data.sparkSession.createDataFrame( - data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) - .show(numRowsToShow, isTruncated) - } + data.sparkSession.createDataFrame( + data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) + .show(numRowsToShow, isTruncated) } } @@ -82,9 +80,7 @@ class ConsoleSinkProvider extends StreamSinkProvider // Truncate the displayed data if it is too long, by default it is true val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true) - SQLExecution.ignoreNestedExecutionId(sqlContext.sparkSession) { - data.show(numRowsToShow, isTruncated) - } + data.show(numRowsToShow, isTruncated) ConsoleRelation(sqlContext, data) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 198a342582804..4979873ee3c7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -194,23 +194,21 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi } if (notCommitted) { logDebug(s"Committing batch $batchId to $this") - SQLExecution.ignoreNestedExecutionId(data.sparkSession) { - outputMode match { - case Append | Update => - val rows = AddedData(batchId, data.collect()) - synchronized { batches += rows } - - case Complete => - val rows = AddedData(batchId, data.collect()) - synchronized { - batches.clear() - batches += rows - } - - case _ => - throw new IllegalArgumentException( - s"Output mode $outputMode is not supported by MemorySink") - } + outputMode match { + case Append | Update => + val rows = AddedData(batchId, data.collect()) + synchronized { batches += rows } + + case Complete => + val rows = AddedData(batchId, data.collect()) + synchronized { + batches.clear() + batches += rows + } + + case _ => + throw new IllegalArgumentException( + s"Output mode $outputMode is not supported by MemorySink") } } else { logDebug(s"Skipping already committed batch: $batchId") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala index fe78a76568837..f6b006b98edd1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -26,22 +26,9 @@ import org.apache.spark.sql.SparkSession class SQLExecutionSuite extends SparkFunSuite { test("concurrent query execution (SPARK-10548)") { - // Try to reproduce the issue with the old SparkContext val conf = new SparkConf() .setMaster("local[*]") .setAppName("test") - val badSparkContext = new BadSparkContext(conf) - try { - testConcurrentQueryExecution(badSparkContext) - fail("unable to reproduce SPARK-10548") - } catch { - case e: IllegalArgumentException => - assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY)) - } finally { - badSparkContext.stop() - } - - // Verify that the issue is fixed with the latest SparkContext val goodSparkContext = new SparkContext(conf) try { testConcurrentQueryExecution(goodSparkContext) @@ -134,17 +121,6 @@ class SQLExecutionSuite extends SparkFunSuite { } } -/** - * A bad [[SparkContext]] that does not clone the inheritable thread local properties - * when passing them to children threads. - */ -private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) { - protected[spark] override val localProperties = new InheritableThreadLocal[Properties] { - override protected def childValue(parent: Properties): Properties = new Properties(parent) - override protected def initialValue(): Properties = new Properties() - } -} - object SQLExecutionSuite { @volatile var canProgress = false } From a2d5623548194f15989e7b68118d744673e33819 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Thu, 29 Jun 2017 01:23:13 -0700 Subject: [PATCH 040/779] [SPARK-20889][SPARKR] Grouped documentation for NONAGGREGATE column methods ## What changes were proposed in this pull request? Grouped documentation for nonaggregate column methods. Author: actuaryzhang Author: Wayne Zhang Closes #18422 from actuaryzhang/sparkRDocNonAgg. --- R/pkg/R/functions.R | 360 ++++++++++++++++++-------------------------- R/pkg/R/generics.R | 55 ++++--- 2 files changed, 182 insertions(+), 233 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 70ea620b471fe..cb09e847d739a 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -132,23 +132,39 @@ NULL #' df <- createDataFrame(as.data.frame(Titanic, stringsAsFactors = FALSE))} NULL -#' lit +#' Non-aggregate functions for Column operations #' -#' A new \linkS4class{Column} is created to represent the literal value. -#' If the parameter is a \linkS4class{Column}, it is returned unchanged. +#' Non-aggregate functions defined for \code{Column}. #' -#' @param x a literal value or a Column. +#' @param x Column to compute on. In \code{lit}, it is a literal value or a Column. +#' In \code{expr}, it contains an expression character object to be parsed. +#' @param y Column to compute on. +#' @param ... additional Columns. +#' @name column_nonaggregate_functions +#' @rdname column_nonaggregate_functions +#' @seealso coalesce,SparkDataFrame-method #' @family non-aggregate functions -#' @rdname lit -#' @name lit +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))} +NULL + +#' @details +#' \code{lit}: A new Column is created to represent the literal value. +#' If the parameter is a Column, it is returned unchanged. +#' +#' @rdname column_nonaggregate_functions #' @export -#' @aliases lit,ANY-method +#' @aliases lit lit,ANY-method #' @examples +#' #' \dontrun{ -#' lit(df$name) -#' select(df, lit("x")) -#' select(df, lit("2015-01-01")) -#'} +#' tmp <- mutate(df, v1 = lit(df$mpg), v2 = lit("x"), v3 = lit("2015-01-01"), +#' v4 = negate(df$mpg), v5 = expr('length(model)'), +#' v6 = greatest(df$vs, df$am), v7 = least(df$vs, df$am), +#' v8 = column("mpg")) +#' head(tmp)} #' @note lit since 1.5.0 setMethod("lit", signature("ANY"), function(x) { @@ -314,18 +330,16 @@ setMethod("bin", column(jc) }) -#' bitwiseNOT -#' -#' Computes bitwise NOT. -#' -#' @param x Column to compute on. +#' @details +#' \code{bitwiseNOT}: Computes bitwise NOT. #' -#' @rdname bitwiseNOT -#' @name bitwiseNOT -#' @family non-aggregate functions +#' @rdname column_nonaggregate_functions #' @export -#' @aliases bitwiseNOT,Column-method -#' @examples \dontrun{bitwiseNOT(df$c)} +#' @aliases bitwiseNOT bitwiseNOT,Column-method +#' @examples +#' +#' \dontrun{ +#' head(select(df, bitwiseNOT(cast(df$vs, "int"))))} #' @note bitwiseNOT since 1.5.0 setMethod("bitwiseNOT", signature(x = "Column"), @@ -375,16 +389,12 @@ setMethod("ceiling", ceil(x) }) -#' Returns the first column that is not NA -#' -#' Returns the first column that is not NA, or NA if all inputs are. +#' @details +#' \code{coalesce}: Returns the first column that is not NA, or NA if all inputs are. #' -#' @rdname coalesce -#' @name coalesce -#' @family non-aggregate functions +#' @rdname column_nonaggregate_functions #' @export #' @aliases coalesce,Column-method -#' @examples \dontrun{coalesce(df$c, df$d, df$e)} #' @note coalesce(Column) since 2.1.1 setMethod("coalesce", signature(x = "Column"), @@ -824,22 +834,24 @@ setMethod("initcap", column(jc) }) -#' is.nan -#' -#' Return true if the column is NaN, alias for \link{isnan} -#' -#' @param x Column to compute on. +#' @details +#' \code{isnan}: Returns true if the column is NaN. +#' @rdname column_nonaggregate_functions +#' @aliases isnan isnan,Column-method +#' @note isnan since 2.0.0 +setMethod("isnan", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "isnan", x@jc) + column(jc) + }) + +#' @details +#' \code{is.nan}: Alias for \link{isnan}. #' -#' @rdname is.nan -#' @name is.nan -#' @family non-aggregate functions -#' @aliases is.nan,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases is.nan is.nan,Column-method #' @export -#' @examples -#' \dontrun{ -#' is.nan(df$c) -#' isnan(df$c) -#' } #' @note is.nan since 2.0.0 setMethod("is.nan", signature(x = "Column"), @@ -847,17 +859,6 @@ setMethod("is.nan", isnan(x) }) -#' @rdname is.nan -#' @name isnan -#' @aliases isnan,Column-method -#' @note isnan since 2.0.0 -setMethod("isnan", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "isnan", x@jc) - column(jc) - }) - #' @details #' \code{kurtosis}: Returns the kurtosis of the values in a group. #' @@ -1129,27 +1130,24 @@ setMethod("minute", column(jc) }) -#' monotonically_increasing_id -#' -#' Return a column that generates monotonically increasing 64-bit integers. -#' -#' The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. -#' The current implementation puts the partition ID in the upper 31 bits, and the record number -#' within each partition in the lower 33 bits. The assumption is that the SparkDataFrame has -#' less than 1 billion partitions, and each partition has less than 8 billion records. -#' -#' As an example, consider a SparkDataFrame with two partitions, each with 3 records. +#' @details +#' \code{monotonically_increasing_id}: Returns a column that generates monotonically increasing +#' 64-bit integers. The generated ID is guaranteed to be monotonically increasing and unique, +#' but not consecutive. The current implementation puts the partition ID in the upper 31 bits, +#' and the record number within each partition in the lower 33 bits. The assumption is that the +#' SparkDataFrame has less than 1 billion partitions, and each partition has less than 8 billion +#' records. As an example, consider a SparkDataFrame with two partitions, each with 3 records. #' This expression would return the following IDs: #' 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. -#' #' This is equivalent to the MONOTONICALLY_INCREASING_ID function in SQL. +#' The method should be used with no argument. #' -#' @rdname monotonically_increasing_id -#' @aliases monotonically_increasing_id,missing-method -#' @name monotonically_increasing_id -#' @family misc functions +#' @rdname column_nonaggregate_functions +#' @aliases monotonically_increasing_id monotonically_increasing_id,missing-method #' @export -#' @examples \dontrun{select(df, monotonically_increasing_id())} +#' @examples +#' +#' \dontrun{head(select(df, monotonically_increasing_id()))} setMethod("monotonically_increasing_id", signature("missing"), function() { @@ -1171,18 +1169,12 @@ setMethod("month", column(jc) }) -#' negate -#' -#' Unary minus, i.e. negate the expression. -#' -#' @param x Column to compute on. +#' @details +#' \code{negate}: Unary minus, i.e. negate the expression. #' -#' @rdname negate -#' @name negate -#' @family non-aggregate functions -#' @aliases negate,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases negate negate,Column-method #' @export -#' @examples \dontrun{negate(df$c)} #' @note negate since 1.5.0 setMethod("negate", signature(x = "Column"), @@ -1481,23 +1473,19 @@ setMethod("stddev_samp", column(jc) }) -#' struct -#' -#' Creates a new struct column that composes multiple input columns. -#' -#' @param x a column to compute on. -#' @param ... optional column(s) to be included. +#' @details +#' \code{struct}: Creates a new struct column that composes multiple input columns. #' -#' @rdname struct -#' @name struct -#' @family non-aggregate functions -#' @aliases struct,characterOrColumn-method +#' @rdname column_nonaggregate_functions +#' @aliases struct struct,characterOrColumn-method #' @export #' @examples +#' #' \dontrun{ -#' struct(df$c, df$d) -#' struct("col1", "col2") -#' } +#' tmp <- mutate(df, v1 = struct(df$mpg, df$cyl), v2 = struct("hp", "wt", "vs"), +#' v3 = create_array(df$mpg, df$cyl, df$hp), +#' v4 = create_map(lit("x"), lit(1.0), lit("y"), lit(-1.0))) +#' head(tmp)} #' @note struct since 1.6.0 setMethod("struct", signature(x = "characterOrColumn"), @@ -1959,20 +1947,13 @@ setMethod("months_between", signature(y = "Column"), column(jc) }) -#' nanvl -#' -#' Returns col1 if it is not NaN, or col2 if col1 is NaN. -#' Both inputs should be floating point columns (DoubleType or FloatType). -#' -#' @param x first Column. -#' @param y second Column. +#' @details +#' \code{nanvl}: Returns the first column (\code{y}) if it is not NaN, or the second column (\code{x}) if +#' the first column is NaN. Both inputs should be floating point columns (DoubleType or FloatType). #' -#' @rdname nanvl -#' @name nanvl -#' @family non-aggregate functions -#' @aliases nanvl,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases nanvl nanvl,Column-method #' @export -#' @examples \dontrun{nanvl(df$c, x)} #' @note nanvl since 1.5.0 setMethod("nanvl", signature(y = "Column"), function(y, x) { @@ -2060,20 +2041,13 @@ setMethod("concat", column(jc) }) -#' greatest -#' -#' Returns the greatest value of the list of column names, skipping null values. +#' @details +#' \code{greatest}: Returns the greatest value of the list of column names, skipping null values. #' This function takes at least 2 parameters. It will return null if all parameters are null. #' -#' @param x Column to compute on -#' @param ... other columns -#' -#' @family non-aggregate functions -#' @rdname greatest -#' @name greatest -#' @aliases greatest,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases greatest greatest,Column-method #' @export -#' @examples \dontrun{greatest(df$c, df$d)} #' @note greatest since 1.5.0 setMethod("greatest", signature(x = "Column"), @@ -2087,20 +2061,13 @@ setMethod("greatest", column(jc) }) -#' least -#' -#' Returns the least value of the list of column names, skipping null values. +#' @details +#' \code{least}: Returns the least value of the list of column names, skipping null values. #' This function takes at least 2 parameters. It will return null if all parameters are null. #' -#' @param x Column to compute on -#' @param ... other columns -#' -#' @family non-aggregate functions -#' @rdname least -#' @aliases least,Column-method -#' @name least +#' @rdname column_nonaggregate_functions +#' @aliases least least,Column-method #' @export -#' @examples \dontrun{least(df$c, df$d)} #' @note least since 1.5.0 setMethod("least", signature(x = "Column"), @@ -2445,18 +2412,13 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri column(jc) }) -#' expr -#' -#' Parses the expression string into the column that it represents, similar to -#' SparkDataFrame.selectExpr +#' @details +#' \code{expr}: Parses the expression string into the column that it represents, similar to +#' \code{SparkDataFrame.selectExpr} #' -#' @param x an expression character object to be parsed. -#' @family non-aggregate functions -#' @rdname expr -#' @aliases expr,character-method -#' @name expr +#' @rdname column_nonaggregate_functions +#' @aliases expr expr,character-method #' @export -#' @examples \dontrun{expr('length(name)')} #' @note expr since 1.5.0 setMethod("expr", signature(x = "character"), function(x) { @@ -2617,18 +2579,19 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), column(jc) }) -#' rand -#' -#' Generate a random column with independent and identically distributed (i.i.d.) samples +#' @details +#' \code{rand}: Generates a random column with independent and identically distributed (i.i.d.) samples #' from U[0.0, 1.0]. #' +#' @rdname column_nonaggregate_functions #' @param seed a random seed. Can be missing. -#' @family non-aggregate functions -#' @rdname rand -#' @name rand -#' @aliases rand,missing-method +#' @aliases rand rand,missing-method #' @export -#' @examples \dontrun{rand()} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, r1 = rand(), r2 = rand(10), r3 = randn(), r4 = randn(10)) +#' head(tmp)} #' @note rand since 1.5.0 setMethod("rand", signature(seed = "missing"), function(seed) { @@ -2636,8 +2599,7 @@ setMethod("rand", signature(seed = "missing"), column(jc) }) -#' @rdname rand -#' @name rand +#' @rdname column_nonaggregate_functions #' @aliases rand,numeric-method #' @export #' @note rand(numeric) since 1.5.0 @@ -2647,18 +2609,13 @@ setMethod("rand", signature(seed = "numeric"), column(jc) }) -#' randn -#' -#' Generate a column with independent and identically distributed (i.i.d.) samples from +#' @details +#' \code{randn}: Generates a column with independent and identically distributed (i.i.d.) samples from #' the standard normal distribution. #' -#' @param seed a random seed. Can be missing. -#' @family non-aggregate functions -#' @rdname randn -#' @name randn -#' @aliases randn,missing-method +#' @rdname column_nonaggregate_functions +#' @aliases randn randn,missing-method #' @export -#' @examples \dontrun{randn()} #' @note randn since 1.5.0 setMethod("randn", signature(seed = "missing"), function(seed) { @@ -2666,8 +2623,7 @@ setMethod("randn", signature(seed = "missing"), column(jc) }) -#' @rdname randn -#' @name randn +#' @rdname column_nonaggregate_functions #' @aliases randn,numeric-method #' @export #' @note randn(numeric) since 1.5.0 @@ -2819,20 +2775,26 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc, format) column(jc) }) -#' when -#' -#' Evaluates a list of conditions and returns one of multiple possible result expressions. + +#' @details +#' \code{when}: Evaluates a list of conditions and returns one of multiple possible result expressions. #' For unmatched expressions null is returned. #' +#' @rdname column_nonaggregate_functions #' @param condition the condition to test on. Must be a Column expression. #' @param value result expression. -#' @family non-aggregate functions -#' @rdname when -#' @name when -#' @aliases when,Column-method -#' @seealso \link{ifelse} +#' @aliases when when,Column-method #' @export -#' @examples \dontrun{when(df$age == 2, df$age + 1)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, mpg_na = otherwise(when(df$mpg > 20, df$mpg), lit(NaN)), +#' mpg2 = ifelse(df$mpg > 20 & df$am > 0, 0, 1), +#' mpg3 = ifelse(df$mpg > 20, df$mpg, 20.0)) +#' head(tmp) +#' tmp <- mutate(tmp, ind_na1 = is.nan(tmp$mpg_na), ind_na2 = isnan(tmp$mpg_na)) +#' head(select(tmp, coalesce(tmp$mpg_na, tmp$mpg))) +#' head(select(tmp, nanvl(tmp$mpg_na, tmp$hp)))} #' @note when since 1.5.0 setMethod("when", signature(condition = "Column", value = "ANY"), function(condition, value) { @@ -2842,25 +2804,16 @@ setMethod("when", signature(condition = "Column", value = "ANY"), column(jc) }) -#' ifelse -#' -#' Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied. +#' @details +#' \code{ifelse}: Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied. #' Otherwise \code{no} is returned for unmatched conditions. #' +#' @rdname column_nonaggregate_functions #' @param test a Column expression that describes the condition. #' @param yes return values for \code{TRUE} elements of test. #' @param no return values for \code{FALSE} elements of test. -#' @family non-aggregate functions -#' @rdname ifelse -#' @name ifelse -#' @aliases ifelse,Column-method -#' @seealso \link{when} +#' @aliases ifelse ifelse,Column-method #' @export -#' @examples -#' \dontrun{ -#' ifelse(df$a > 1 & df$b > 2, 0, 1) -#' ifelse(df$a > 1, df$a, 1) -#' } #' @note ifelse since 1.5.0 setMethod("ifelse", signature(test = "Column", yes = "ANY", no = "ANY"), @@ -3263,19 +3216,12 @@ setMethod("posexplode", column(jc) }) -#' create_array -#' -#' Creates a new array column. The input columns must all have the same data type. -#' -#' @param x Column to compute on -#' @param ... additional Column(s). +#' @details +#' \code{create_array}: Creates a new array column. The input columns must all have the same data type. #' -#' @family non-aggregate functions -#' @rdname create_array -#' @name create_array -#' @aliases create_array,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases create_array create_array,Column-method #' @export -#' @examples \dontrun{create_array(df$x, df$y, df$z)} #' @note create_array since 2.3.0 setMethod("create_array", signature(x = "Column"), @@ -3288,22 +3234,15 @@ setMethod("create_array", column(jc) }) -#' create_map -#' -#' Creates a new map column. The input columns must be grouped as key-value pairs, +#' @details +#' \code{create_map}: Creates a new map column. The input columns must be grouped as key-value pairs, #' e.g. (key1, value1, key2, value2, ...). #' The key columns must all have the same data type, and can't be null. #' The value columns must all have the same data type. #' -#' @param x Column to compute on -#' @param ... additional Column(s). -#' -#' @family non-aggregate functions -#' @rdname create_map -#' @name create_map -#' @aliases create_map,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases create_map create_map,Column-method #' @export -#' @examples \dontrun{create_map(lit("x"), lit(1.0), lit("y"), lit(-1.0))} #' @note create_map since 2.3.0 setMethod("create_map", signature(x = "Column"), @@ -3554,21 +3493,18 @@ setMethod("grouping_id", column(jc) }) -#' input_file_name -#' -#' Creates a string column with the input file name for a given row +#' @details +#' \code{input_file_name}: Creates a string column with the input file name for a given row. +#' The method should be used with no argument. #' -#' @rdname input_file_name -#' @name input_file_name -#' @family non-aggregate functions -#' @aliases input_file_name,missing-method +#' @rdname column_nonaggregate_functions +#' @aliases input_file_name input_file_name,missing-method #' @export #' @examples -#' \dontrun{ -#' df <- read.text("README.md") #' -#' head(select(df, input_file_name())) -#' } +#' \dontrun{ +#' tmp <- read.text("README.md") +#' head(select(tmp, input_file_name()))} #' @note input_file_name since 2.3.0 setMethod("input_file_name", signature("missing"), function() { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index dc99e3d94b269..1deb057bb1b82 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -422,9 +422,8 @@ setGeneric("cache", function(x) { standardGeneric("cache") }) setGeneric("checkpoint", function(x, eager = TRUE) { standardGeneric("checkpoint") }) #' @rdname coalesce -#' @param x a Column or a SparkDataFrame. -#' @param ... additional argument(s). If \code{x} is a Column, additional Columns can be optionally -#' provided. +#' @param x a SparkDataFrame. +#' @param ... additional argument(s). #' @export setGeneric("coalesce", function(x, ...) { standardGeneric("coalesce") }) @@ -863,8 +862,9 @@ setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) #' @export setGeneric("startsWith", function(x, prefix) { standardGeneric("startsWith") }) -#' @rdname when +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("when", function(condition, value) { standardGeneric("when") }) #' @rdname otherwise @@ -938,8 +938,9 @@ setGeneric("base64", function(x) { standardGeneric("base64") }) #' @name NULL setGeneric("bin", function(x) { standardGeneric("bin") }) -#' @rdname bitwiseNOT +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) #' @rdname column_math_functions @@ -995,12 +996,14 @@ setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") #' @export setGeneric("crc32", function(x) { standardGeneric("crc32") }) -#' @rdname create_array +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("create_array", function(x, ...) { standardGeneric("create_array") }) -#' @rdname create_map +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("create_map", function(x, ...) { standardGeneric("create_map") }) #' @rdname hash @@ -1065,8 +1068,9 @@ setGeneric("explode", function(x) { standardGeneric("explode") }) #' @export setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) -#' @rdname expr +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("expr", function(x) { standardGeneric("expr") }) #' @rdname column_datetime_diff_functions @@ -1093,8 +1097,9 @@ setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") #' @name NULL setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) -#' @rdname greatest +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) #' @rdname column_aggregate_functions @@ -1127,9 +1132,9 @@ setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) #' @name NULL setGeneric("initcap", function(x) { standardGeneric("initcap") }) -#' @param x empty. Should be used with no argument. -#' @rdname input_file_name +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("input_file_name", function(x = "missing") { standardGeneric("input_file_name") }) @@ -1138,8 +1143,9 @@ setGeneric("input_file_name", #' @name NULL setGeneric("instr", function(y, x) { standardGeneric("instr") }) -#' @rdname is.nan +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("isnan", function(x) { standardGeneric("isnan") }) #' @rdname column_aggregate_functions @@ -1164,8 +1170,9 @@ setGeneric("last_day", function(x) { standardGeneric("last_day") }) #' @export setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("lead") }) -#' @rdname least +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("least", function(x, ...) { standardGeneric("least") }) #' @rdname column_string_functions @@ -1173,8 +1180,9 @@ setGeneric("least", function(x, ...) { standardGeneric("least") }) #' @name NULL setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) -#' @rdname lit +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("lit", function(x) { standardGeneric("lit") }) #' @rdname column_string_functions @@ -1206,9 +1214,9 @@ setGeneric("md5", function(x) { standardGeneric("md5") }) #' @name NULL setGeneric("minute", function(x) { standardGeneric("minute") }) -#' @param x empty. Should be used with no argument. -#' @rdname monotonically_increasing_id +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("monotonically_increasing_id", function(x = "missing") { standardGeneric("monotonically_increasing_id") }) @@ -1226,12 +1234,14 @@ setGeneric("months_between", function(y, x) { standardGeneric("months_between") #' @export setGeneric("n", function(x) { standardGeneric("n") }) -#' @rdname nanvl +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) -#' @rdname negate +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("negate", function(x) { standardGeneric("negate") }) #' @rdname not @@ -1275,12 +1285,14 @@ setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer") #' @name NULL setGeneric("quarter", function(x) { standardGeneric("quarter") }) -#' @rdname rand +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("rand", function(seed) { standardGeneric("rand") }) -#' @rdname randn +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("randn", function(seed) { standardGeneric("randn") }) #' @rdname rank @@ -1409,8 +1421,9 @@ setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") }) #' @name NULL setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) -#' @rdname struct +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("struct", function(x, ...) { standardGeneric("struct") }) #' @rdname column_string_functions From 70085e83d1ee728b23f7df15f570eb8d77f67a7a Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Thu, 29 Jun 2017 09:51:12 +0100 Subject: [PATCH 041/779] [SPARK-21210][DOC][ML] Javadoc 8 fixes for ML shared param traits PR #15999 included fixes for doc strings in the ML shared param traits (occurrences of `>` and `>=`). This PR simply uses the HTML-escaped version of the param doc to embed into the Scaladoc, to ensure that when `SharedParamsCodeGen` is run, the generated javadoc will be compliant for Java 8. ## How was this patch tested? Existing tests Author: Nick Pentreath Closes #18420 from MLnick/shared-params-javadoc8. --- .../apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 5 ++++- .../org/apache/spark/ml/param/shared/sharedParams.scala | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index c94b8b4e9dfda..013817a41baf5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.param.shared import java.io.PrintWriter import scala.reflect.ClassTag +import scala.xml.Utility /** * Code generator for shared params (sharedParams.scala). Run under the Spark folder with @@ -167,6 +168,8 @@ private[shared] object SharedParamsCodeGen { "def" } + val htmlCompliantDoc = Utility.escape(doc) + s""" |/** | * Trait for shared param $name$defaultValueDoc. @@ -174,7 +177,7 @@ private[shared] object SharedParamsCodeGen { |private[ml] trait Has$Name extends Params { | | /** - | * Param for $doc. + | * Param for $htmlCompliantDoc. | * @group ${groupStr(0)} | */ | final val $name: $Param = new $Param(this, "$name", "$doc"$isValid) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index e3e03dfd43dd6..50619607a5054 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -176,7 +176,7 @@ private[ml] trait HasThreshold extends Params { private[ml] trait HasThresholds extends Params { /** - * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold. + * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold. * @group param */ final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold", (t: Array[Double]) => t.forall(_ >= 0) && t.count(_ == 0) <= 1) From d106a74c53f493c3c18741a9b19cb821dace4ba2 Mon Sep 17 00:00:00 2001 From: jinxing Date: Thu, 29 Jun 2017 09:59:36 +0100 Subject: [PATCH 042/779] [SPARK-21240] Fix code style for constructing and stopping a SparkContext in UT. ## What changes were proposed in this pull request? Same with SPARK-20985. Fix code style for constructing and stopping a `SparkContext`. Assure the context is stopped to avoid other tests complain that there's only one `SparkContext` can exist. Author: jinxing Closes #18454 from jinxing64/SPARK-21240. --- .../scala/org/apache/spark/scheduler/MapStatusSuite.scala | 6 ++---- .../apache/spark/sql/execution/ui/SQLListenerSuite.scala | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index e6120139f4958..276169e02f01d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -26,6 +26,7 @@ import org.roaringbitmap.RoaringBitmap import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.internal.config +import org.apache.spark.LocalSparkContext._ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.storage.BlockManagerId @@ -160,12 +161,9 @@ class MapStatusSuite extends SparkFunSuite { .set("spark.serializer", classOf[KryoSerializer].getName) .setMaster("local") .setAppName("SPARK-21133") - val sc = new SparkContext(conf) - try { + withSpark(new SparkContext(conf)) { sc => val count = sc.parallelize(0 until 3000, 10).repartition(2001).collect().length assert(count === 3000) - } finally { - sc.stop() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index e6cd41e4facf1..82eff5e6491ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -25,6 +25,7 @@ import org.mockito.Mockito.mock import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.config +import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SparkSession} @@ -496,8 +497,7 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { .setAppName("test") .set(config.MAX_TASK_FAILURES, 1) // Don't retry the tasks to run this test quickly .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly - val sc = new SparkContext(conf) - try { + withSpark(new SparkContext(conf)) { sc => SparkSession.sqlListener.set(null) val spark = new SparkSession(sc) import spark.implicits._ @@ -522,8 +522,6 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { assert(spark.sharedState.listener.executionIdToData.size <= 100) assert(spark.sharedState.listener.jobIdToExecutionId.size <= 100) assert(spark.sharedState.listener.stageIdToStageMetrics.size <= 100) - } finally { - sc.stop() } } } From d7da2b94d6107341b33ca9224e9bfa4c9a92ed88 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Thu, 29 Jun 2017 10:01:12 +0100 Subject: [PATCH 043/779] =?UTF-8?q?[SPARK-21135][WEB=20UI]=20On=20history?= =?UTF-8?q?=20server=20page=EF=BC=8Cduration=20of=20incompleted=20applicat?= =?UTF-8?q?ions=20should=20be=20hidden=20instead=20of=20showing=20up=20as?= =?UTF-8?q?=200?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Hide duration of incompleted applications. ## How was this patch tested? manual tests Author: fjh100456 Closes #18351 from fjh100456/master. --- .../spark/ui/static/historypage-template.html | 4 ++-- .../org/apache/spark/ui/static/historypage.js | 15 ++++++++++----- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html index bfe31aae555ba..6cff0068d8bcb 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -44,7 +44,7 @@ Completed - + Duration @@ -74,7 +74,7 @@ {{attemptId}} {{startTime}} {{endTime}} - {{duration}} + {{duration}} {{sparkUser}} {{lastUpdated}} Download diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 5ec1ce15a2127..9edd3ba0e0ba6 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -182,12 +182,17 @@ $(document).ready(function() { for (i = 0; i < completedCells.length; i++) { completedCells[i].style.display='none'; } - } - var durationCells = document.getElementsByClassName("durationClass"); - for (i = 0; i < durationCells.length; i++) { - var timeInMilliseconds = parseInt(durationCells[i].title); - durationCells[i].innerHTML = formatDuration(timeInMilliseconds); + var durationCells = document.getElementsByClassName("durationColumn"); + for (i = 0; i < durationCells.length; i++) { + durationCells[i].style.display='none'; + } + } else { + var durationCells = document.getElementsByClassName("durationClass"); + for (i = 0; i < durationCells.length; i++) { + var timeInMilliseconds = parseInt(durationCells[i].title); + durationCells[i].innerHTML = formatDuration(timeInMilliseconds); + } } if ($(selector.concat(" tr")).length < 20) { From 29bd251dd5914fc3b6146eb4fe0b45f1c84dba62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E6=B2=BB=E5=9B=BD10192065?= Date: Thu, 29 Jun 2017 20:53:48 +0800 Subject: [PATCH 044/779] [SPARK-21225][CORE] Considering CPUS_PER_TASK when allocating task slots for each WorkerOffer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit JIRA Issue:https://issues.apache.org/jira/browse/SPARK-21225 In the function "resourceOffers", It declare a variable "tasks" for storage the tasks which have allocated a executor. It declared like this: `val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores))` But, I think this code only conside a situation for that one task per core. If the user set "spark.task.cpus" as 2 or 3, It really don't need so much Mem. I think It can motify as follow: val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores / CPUS_PER_TASK)) to instead. Motify like this the other earning is that it's more easy to understand the way how the tasks allocate offers. Author: 杨治国10192065 Closes #18435 from JackYangzg/motifyTaskCoreDisp. --- .../scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 91ec172ffeda1..737b383631148 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -345,7 +345,7 @@ private[spark] class TaskSchedulerImpl( val shuffledOffers = shuffleOffers(filteredOffers) // Build a list of tasks to assign to each worker. - val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores)) + val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores / CPUS_PER_TASK)) val availableCpus = shuffledOffers.map(o => o.cores).toArray val sortedTaskSets = rootPool.getSortedTaskSetQueue for (taskSet <- sortedTaskSets) { From 18066f2e61f430b691ed8a777c9b4e5786bf9dbc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 29 Jun 2017 21:28:48 +0800 Subject: [PATCH 045/779] [SPARK-21052][SQL] Add hash map metrics to join ## What changes were proposed in this pull request? This adds the average hash map probe metrics to join operator such as `BroadcastHashJoin` and `ShuffledHashJoin`. This PR adds the API to `HashedRelation` to get average hash map probe. ## How was this patch tested? Related test cases are added. Author: Liang-Chi Hsieh Closes #18301 from viirya/SPARK-21052. --- .../aggregate/HashAggregateExec.scala | 15 +- .../TungstenAggregationIterator.scala | 34 ++-- .../joins/BroadcastHashJoinExec.scala | 30 ++- .../spark/sql/execution/joins/HashJoin.scala | 8 +- .../sql/execution/joins/HashedRelation.scala | 43 +++- .../joins/ShuffledHashJoinExec.scala | 6 +- .../sql/execution/metric/SQLMetrics.scala | 32 ++- .../execution/metric/SQLMetricsSuite.scala | 188 ++++++++++++++++-- 8 files changed, 296 insertions(+), 60 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 5027a615ced7a..56f61c30c4a38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -60,7 +60,7 @@ case class HashAggregateExec( "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"), "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"), - "avgHashmapProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hashmap probe")) + "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -94,7 +94,7 @@ case class HashAggregateExec( val numOutputRows = longMetric("numOutputRows") val peakMemory = longMetric("peakMemory") val spillSize = longMetric("spillSize") - val avgHashmapProbe = longMetric("avgHashmapProbe") + val avgHashProbe = longMetric("avgHashProbe") child.execute().mapPartitions { iter => @@ -119,7 +119,7 @@ case class HashAggregateExec( numOutputRows, peakMemory, spillSize, - avgHashmapProbe) + avgHashProbe) if (!hasInput && groupingExpressions.isEmpty) { numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) @@ -344,7 +344,7 @@ case class HashAggregateExec( sorter: UnsafeKVExternalSorter, peakMemory: SQLMetric, spillSize: SQLMetric, - avgHashmapProbe: SQLMetric): KVIterator[UnsafeRow, UnsafeRow] = { + avgHashProbe: SQLMetric): KVIterator[UnsafeRow, UnsafeRow] = { // update peak execution memory val mapMemory = hashMap.getPeakMemoryUsedBytes @@ -355,8 +355,7 @@ case class HashAggregateExec( metrics.incPeakExecutionMemory(maxMemory) // Update average hashmap probe - val avgProbes = hashMap.getAverageProbesPerLookup() - avgHashmapProbe.add(avgProbes.ceil.toLong) + avgHashProbe.set(hashMap.getAverageProbesPerLookup()) if (sorter == null) { // not spilled @@ -584,7 +583,7 @@ case class HashAggregateExec( val doAgg = ctx.freshName("doAggregateWithKeys") val peakMemory = metricTerm(ctx, "peakMemory") val spillSize = metricTerm(ctx, "spillSize") - val avgHashmapProbe = metricTerm(ctx, "avgHashmapProbe") + val avgHashProbe = metricTerm(ctx, "avgHashProbe") def generateGenerateCode(): String = { if (isFastHashMapEnabled) { @@ -611,7 +610,7 @@ case class HashAggregateExec( s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();"} else ""} $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm, $peakMemory, $spillSize, - $avgHashmapProbe); + $avgHashProbe); } """) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 8efa95d48aea0..cfa930607360c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -89,7 +89,7 @@ class TungstenAggregationIterator( numOutputRows: SQLMetric, peakMemory: SQLMetric, spillSize: SQLMetric, - avgHashmapProbe: SQLMetric) + avgHashProbe: SQLMetric) extends AggregationIterator( groupingExpressions, originalInputAttributes, @@ -367,6 +367,22 @@ class TungstenAggregationIterator( } } + TaskContext.get().addTaskCompletionListener(_ => { + // At the end of the task, update the task's peak memory usage. Since we destroy + // the map to create the sorter, their memory usages should not overlap, so it is safe + // to just use the max of the two. + val mapMemory = hashMap.getPeakMemoryUsedBytes + val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) + val maxMemory = Math.max(mapMemory, sorterMemory) + val metrics = TaskContext.get().taskMetrics() + peakMemory.set(maxMemory) + spillSize.set(metrics.memoryBytesSpilled - spillSizeBefore) + metrics.incPeakExecutionMemory(maxMemory) + + // Updating average hashmap probe + avgHashProbe.set(hashMap.getAverageProbesPerLookup()) + }) + /////////////////////////////////////////////////////////////////////////// // Part 7: Iterator's public methods. /////////////////////////////////////////////////////////////////////////// @@ -409,22 +425,6 @@ class TungstenAggregationIterator( } } - // If this is the last record, update the task's peak memory usage. Since we destroy - // the map to create the sorter, their memory usages should not overlap, so it is safe - // to just use the max of the two. - if (!hasNext) { - val mapMemory = hashMap.getPeakMemoryUsedBytes - val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) - val maxMemory = Math.max(mapMemory, sorterMemory) - val metrics = TaskContext.get().taskMetrics() - peakMemory += maxMemory - spillSize += metrics.memoryBytesSpilled - spillSizeBefore - metrics.incPeakExecutionMemory(maxMemory) - - // Update average hashmap probe if this is the last record. - val averageProbes = hashMap.getAverageProbesPerLookup() - avgHashmapProbe.add(averageProbes.ceil.toLong) - } numOutputRows += 1 res } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 0bc261d593df4..bfa1e9d49a545 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Dist import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType +import org.apache.spark.util.TaskCompletionListener /** * Performs an inner hash join of two child relations. When the output RDD of this operator is @@ -46,7 +47,8 @@ case class BroadcastHashJoinExec( extends BinaryExecNode with HashJoin with CodegenSupport { override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) override def requiredChildDistribution: Seq[Distribution] = { val mode = HashedRelationBroadcastMode(buildKeys) @@ -60,12 +62,13 @@ case class BroadcastHashJoinExec( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") + val avgHashProbe = longMetric("avgHashProbe") val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() streamedPlan.execute().mapPartitions { streamedIter => val hashed = broadcastRelation.value.asReadOnlyCopy() TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize) - join(streamedIter, hashed, numOutputRows) + join(streamedIter, hashed, numOutputRows, avgHashProbe) } } @@ -90,6 +93,23 @@ case class BroadcastHashJoinExec( } } + /** + * Returns the codes used to add a task completion listener to update avg hash probe + * at the end of the task. + */ + private def genTaskListener(avgHashProbe: String, relationTerm: String): String = { + val listenerClass = classOf[TaskCompletionListener].getName + val taskContextClass = classOf[TaskContext].getName + s""" + | $taskContextClass$$.MODULE$$.get().addTaskCompletionListener(new $listenerClass() { + | @Override + | public void onTaskCompletion($taskContextClass context) { + | $avgHashProbe.set($relationTerm.getAverageProbesPerLookup()); + | } + | }); + """.stripMargin + } + /** * Returns a tuple of Broadcast of HashedRelation and the variable name for it. */ @@ -99,10 +119,16 @@ case class BroadcastHashJoinExec( val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) val relationTerm = ctx.freshName("relation") val clsName = broadcastRelation.value.getClass.getName + + // At the end of the task, we update the avg hash probe. + val avgHashProbe = metricTerm(ctx, "avgHashProbe") + val addTaskListener = genTaskListener(avgHashProbe, relationTerm) + ctx.addMutableState(clsName, relationTerm, s""" | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy(); | incPeakExecutionMemory($relationTerm.estimatedSize()); + | $addTaskListener """.stripMargin) (broadcastRelation, relationTerm) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 1aef5f6864263..b09edf380c2d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -193,7 +194,8 @@ trait HashJoin { protected def join( streamedIter: Iterator[InternalRow], hashed: HashedRelation, - numOutputRows: SQLMetric): Iterator[InternalRow] = { + numOutputRows: SQLMetric, + avgHashProbe: SQLMetric): Iterator[InternalRow] = { val joinedIter = joinType match { case _: InnerLike => @@ -211,6 +213,10 @@ trait HashJoin { s"BroadcastHashJoin should not take $x as the JoinType") } + // At the end of the task, we update the avg hash probe. + TaskContext.get().addTaskCompletionListener(_ => + avgHashProbe.set(hashed.getAverageProbesPerLookup())) + val resultProj = createResultProjection joinedIter.map { r => numOutputRows += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 2dd1dc3da96c9..3c702856114f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -79,6 +79,11 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation { * Release any used resources. */ def close(): Unit + + /** + * Returns the average number of probes per key lookup. + */ + def getAverageProbesPerLookup(): Double } private[execution] object HashedRelation { @@ -242,7 +247,8 @@ private[joins] class UnsafeHashedRelation( binaryMap = new BytesToBytesMap( taskMemoryManager, (nKeys * 1.5 + 1).toInt, // reduce hash collision - pageSizeBytes) + pageSizeBytes, + true) var i = 0 var keyBuffer = new Array[Byte](1024) @@ -273,6 +279,8 @@ private[joins] class UnsafeHashedRelation( override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException { read(in.readInt, in.readLong, in.readBytes) } + + override def getAverageProbesPerLookup(): Double = binaryMap.getAverageProbesPerLookup() } private[joins] object UnsafeHashedRelation { @@ -290,7 +298,8 @@ private[joins] object UnsafeHashedRelation { taskMemoryManager, // Only 70% of the slots can be used before growing, more capacity help to reduce collision (sizeEstimate * 1.5 + 1).toInt, - pageSizeBytes) + pageSizeBytes, + true) // Create a mapping of buildKeys -> rows val keyGenerator = UnsafeProjection.create(key) @@ -344,7 +353,7 @@ private[joins] object UnsafeHashedRelation { * determined by `key1 - minKey`. * * The map is created as sparse mode, then key-value could be appended into it. Once finish - * appending, caller could all optimize() to try to turn the map into dense mode, which is faster + * appending, caller could call optimize() to try to turn the map into dense mode, which is faster * to probe. * * see http://java-performance.info/implementing-world-fastest-java-int-to-int-hash-map/ @@ -385,6 +394,10 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap // The number of unique keys. private var numKeys = 0L + // Tracking average number of probes per key lookup. + private var numKeyLookups = 0L + private var numProbes = 0L + // needed by serializer def this() = { this( @@ -469,6 +482,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { if (isDense) { + numKeyLookups += 1 + numProbes += 1 if (key >= minKey && key <= maxKey) { val value = array((key - minKey).toInt) if (value > 0) { @@ -477,11 +492,14 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap } } else { var pos = firstSlot(key) + numKeyLookups += 1 + numProbes += 1 while (array(pos + 1) != 0) { if (array(pos) == key) { return getRow(array(pos + 1), resultRow) } pos = nextSlot(pos) + numProbes += 1 } } null @@ -509,6 +527,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { if (isDense) { + numKeyLookups += 1 + numProbes += 1 if (key >= minKey && key <= maxKey) { val value = array((key - minKey).toInt) if (value > 0) { @@ -517,11 +537,14 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap } } else { var pos = firstSlot(key) + numKeyLookups += 1 + numProbes += 1 while (array(pos + 1) != 0) { if (array(pos) == key) { return valueIter(array(pos + 1), resultRow) } pos = nextSlot(pos) + numProbes += 1 } } null @@ -573,8 +596,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap private def updateIndex(key: Long, address: Long): Unit = { var pos = firstSlot(key) assert(numKeys < array.length / 2) + numKeyLookups += 1 + numProbes += 1 while (array(pos) != key && array(pos + 1) != 0) { pos = nextSlot(pos) + numProbes += 1 } if (array(pos + 1) == 0) { // this is the first value for this key, put the address in array. @@ -686,6 +712,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap writeLong(maxKey) writeLong(numKeys) writeLong(numValues) + writeLong(numKeyLookups) + writeLong(numProbes) writeLong(array.length) writeLongArray(writeBuffer, array, array.length) @@ -727,6 +755,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap maxKey = readLong() numKeys = readLong() numValues = readLong() + numKeyLookups = readLong() + numProbes = readLong() val length = readLong().toInt mask = length - 2 @@ -742,6 +772,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap override def read(kryo: Kryo, in: Input): Unit = { read(in.readBoolean, in.readLong, in.readBytes) } + + /** + * Returns the average number of probes per key lookup. + */ + def getAverageProbesPerLookup(): Double = numProbes.toDouble / numKeyLookups } private[joins] class LongHashedRelation( @@ -793,6 +828,8 @@ private[joins] class LongHashedRelation( resultRow = new UnsafeRow(nFields) map = in.readObject().asInstanceOf[LongToUnsafeRowMap] } + + override def getAverageProbesPerLookup(): Double = map.getAverageProbesPerLookup() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index afb6e5e3dd235..f1df41ca49c27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -42,7 +42,8 @@ case class ShuffledHashJoinExec( override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"), - "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map")) + "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"), + "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil @@ -62,9 +63,10 @@ case class ShuffledHashJoinExec( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") + val avgHashProbe = longMetric("avgHashProbe") streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => val hashed = buildHashedRelation(buildIter) - join(streamIter, hashed, numOutputRows) + join(streamIter, hashed, numOutputRows, avgHashProbe) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 49cab04de2bf0..b4653c1b564f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -57,6 +57,12 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends Accumulato override def add(v: Long): Unit = _value += v + // We can set a double value to `SQLMetric` which stores only long value, if it is + // average metrics. + def set(v: Double): Unit = SQLMetrics.setDoubleForAverageMetrics(this, v) + + def set(v: Long): Unit = _value = v + def +=(v: Long): Unit = _value += v override def value: Long = _value @@ -74,6 +80,19 @@ object SQLMetrics { private val TIMING_METRIC = "timing" private val AVERAGE_METRIC = "average" + private val baseForAvgMetric: Int = 10 + + /** + * Converts a double value to long value by multiplying a base integer, so we can store it in + * `SQLMetrics`. It only works for average metrics. When showing the metrics on UI, we restore + * it back to a double value up to the decimal places bound by the base integer. + */ + private[sql] def setDoubleForAverageMetrics(metric: SQLMetric, v: Double): Unit = { + assert(metric.metricType == AVERAGE_METRIC, + s"Can't set a double to a metric of metrics type: ${metric.metricType}") + metric.set((v * baseForAvgMetric).toLong) + } + def createMetric(sc: SparkContext, name: String): SQLMetric = { val acc = new SQLMetric(SUM_METRIC) acc.register(sc, name = Some(name), countFailedValues = false) @@ -104,15 +123,14 @@ object SQLMetrics { /** * Create a metric to report the average information (including min, med, max) like - * avg hashmap probe. Because `SQLMetric` stores long values, we take the ceil of the average - * values before storing them. This metric is used to record an average value computed in the - * end of a task. It should be set once. The initial values (zeros) of this metrics will be - * excluded after. + * avg hash probe. As average metrics are double values, this kind of metrics should be + * only set with `SQLMetric.set` method instead of other methods like `SQLMetric.add`. + * The initial values (zeros) of this metrics will be excluded after. */ def createAverageMetric(sc: SparkContext, name: String): SQLMetric = { // The final result of this metric in physical operator UI may looks like: // probe avg (min, med, max): - // (1, 2, 6) + // (1.2, 2.2, 6.3) val acc = new SQLMetric(AVERAGE_METRIC) acc.register(sc, name = Some(s"$name (min, med, max)"), countFailedValues = false) acc @@ -127,7 +145,7 @@ object SQLMetrics { val numberFormat = NumberFormat.getIntegerInstance(Locale.US) numberFormat.format(values.sum) } else if (metricsType == AVERAGE_METRIC) { - val numberFormat = NumberFormat.getIntegerInstance(Locale.US) + val numberFormat = NumberFormat.getNumberInstance(Locale.US) val validValues = values.filter(_ > 0) val Seq(min, med, max) = { @@ -137,7 +155,7 @@ object SQLMetrics { val sorted = validValues.sorted Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) } - metric.map(numberFormat.format) + metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric)) } s"\n($min, $med, $max)" } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index a12ce2b9eba34..cb3405b2fe19b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -47,9 +47,10 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { private def getSparkPlanMetrics( df: DataFrame, expectedNumOfJobs: Int, - expectedNodeIds: Set[Long]): Option[Map[Long, (String, Map[String, Any])]] = { + expectedNodeIds: Set[Long], + enableWholeStage: Boolean = false): Option[Map[Long, (String, Map[String, Any])]] = { val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + withSQLConf("spark.sql.codegen.wholeStage" -> enableWholeStage.toString) { df.collect() } sparkContext.listenerBus.waitUntilEmpty(10000) @@ -110,6 +111,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } } + /** + * Generates a `DataFrame` by filling randomly generated bytes for hash collision. + */ + private def generateRandomBytesDF(numRows: Int = 65535): DataFrame = { + val random = new Random() + val manyBytes = (0 until numRows).map { _ => + val byteArrSize = random.nextInt(100) + val bytes = new Array[Byte](byteArrSize) + random.nextBytes(bytes) + (bytes, random.nextInt(100)) + } + manyBytes.toSeq.toDF("a", "b") + } + test("LocalTableScanExec computes metrics in collect and take") { val df1 = spark.createDataset(Seq(1, 2, 3)) val logical = df1.queryExecution.logical @@ -151,9 +166,9 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df = testData2.groupBy().count() // 2 partitions val expected1 = Seq( Map("number of output rows" -> 2L, - "avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)"), + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"), Map("number of output rows" -> 1L, - "avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)")) + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")) testSparkPlanMetrics(df, 1, Map( 2L -> ("HashAggregate", expected1(0)), 0L -> ("HashAggregate", expected1(1))) @@ -163,9 +178,9 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df2 = testData2.groupBy('a).count() val expected2 = Seq( Map("number of output rows" -> 4L, - "avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)"), + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"), Map("number of output rows" -> 3L, - "avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)")) + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")) testSparkPlanMetrics(df2, 1, Map( 2L -> ("HashAggregate", expected2(0)), 0L -> ("HashAggregate", expected2(1))) @@ -173,19 +188,42 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } test("Aggregate metrics: track avg probe") { - val random = new Random() - val manyBytes = (0 until 65535).map { _ => - val byteArrSize = random.nextInt(100) - val bytes = new Array[Byte](byteArrSize) - random.nextBytes(bytes) - (bytes, random.nextInt(100)) - } - val df = manyBytes.toSeq.toDF("a", "b").repartition(1).groupBy('a).count() - val metrics = getSparkPlanMetrics(df, 1, Set(2L, 0L)).get - Seq(metrics(2L)._2("avg hashmap probe (min, med, max)"), - metrics(0L)._2("avg hashmap probe (min, med, max)")).foreach { probes => - probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => - assert(probe.toInt > 1) + // The executed plan looks like: + // HashAggregate(keys=[a#61], functions=[count(1)], output=[a#61, count#71L]) + // +- Exchange hashpartitioning(a#61, 5) + // +- HashAggregate(keys=[a#61], functions=[partial_count(1)], output=[a#61, count#76L]) + // +- Exchange RoundRobinPartitioning(1) + // +- LocalTableScan [a#61] + // + // Assume the execution plan with node id is: + // Wholestage disabled: + // HashAggregate(nodeId = 0) + // Exchange(nodeId = 1) + // HashAggregate(nodeId = 2) + // Exchange (nodeId = 3) + // LocalTableScan(nodeId = 4) + // + // Wholestage enabled: + // WholeStageCodegen(nodeId = 0) + // HashAggregate(nodeId = 1) + // Exchange(nodeId = 2) + // WholeStageCodegen(nodeId = 3) + // HashAggregate(nodeId = 4) + // Exchange(nodeId = 5) + // LocalTableScan(nodeId = 6) + Seq(true, false).foreach { enableWholeStage => + val df = generateRandomBytesDF().repartition(1).groupBy('a).count() + val nodeIds = if (enableWholeStage) { + Set(4L, 1L) + } else { + Set(2L, 0L) + } + val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get + nodeIds.foreach { nodeId => + val probes = metrics(nodeId)._2("avg hash probe (min, med, max)") + probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => + assert(probe.toDouble > 1.0) + } } } } @@ -267,10 +305,120 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df = df1.join(broadcast(df2), "key") testSparkPlanMetrics(df, 2, Map( 1L -> ("BroadcastHashJoin", Map( - "number of output rows" -> 2L))) + "number of output rows" -> 2L, + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"))) ) } + test("BroadcastHashJoin metrics: track avg probe") { + // The executed plan looks like: + // Project [a#210, b#211, b#221] + // +- BroadcastHashJoin [a#210], [a#220], Inner, BuildRight + // :- Project [_1#207 AS a#210, _2#208 AS b#211] + // : +- Filter isnotnull(_1#207) + // : +- LocalTableScan [_1#207, _2#208] + // +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, binary, true])) + // +- Project [_1#217 AS a#220, _2#218 AS b#221] + // +- Filter isnotnull(_1#217) + // +- LocalTableScan [_1#217, _2#218] + // + // Assume the execution plan with node id is + // WholeStageCodegen disabled: + // Project(nodeId = 0) + // BroadcastHashJoin(nodeId = 1) + // ...(ignored) + // + // WholeStageCodegen enabled: + // WholeStageCodegen(nodeId = 0) + // Project(nodeId = 1) + // BroadcastHashJoin(nodeId = 2) + // Project(nodeId = 3) + // Filter(nodeId = 4) + // ...(ignored) + Seq(true, false).foreach { enableWholeStage => + val df1 = generateRandomBytesDF() + val df2 = generateRandomBytesDF() + val df = df1.join(broadcast(df2), "a") + val nodeIds = if (enableWholeStage) { + Set(2L) + } else { + Set(1L) + } + val metrics = getSparkPlanMetrics(df, 2, nodeIds, enableWholeStage).get + nodeIds.foreach { nodeId => + val probes = metrics(nodeId)._2("avg hash probe (min, med, max)") + probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => + assert(probe.toDouble > 1.0) + } + } + } + } + + test("ShuffledHashJoin metrics") { + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "40", + "spark.sql.shuffle.partitions" -> "2", + "spark.sql.join.preferSortMergeJoin" -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = (1 to 10).map(i => (i, i.toString)).toSeq.toDF("key", "value") + // Assume the execution plan is + // ... -> ShuffledHashJoin(nodeId = 1) -> Project(nodeId = 0) + val df = df1.join(df2, "key") + val metrics = getSparkPlanMetrics(df, 1, Set(1L)) + testSparkPlanMetrics(df, 1, Map( + 1L -> ("ShuffledHashJoin", Map( + "number of output rows" -> 2L, + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"))) + ) + } + } + + test("ShuffledHashJoin metrics: track avg probe") { + // The executed plan looks like: + // Project [a#308, b#309, b#319] + // +- ShuffledHashJoin [a#308], [a#318], Inner, BuildRight + // :- Exchange hashpartitioning(a#308, 2) + // : +- Project [_1#305 AS a#308, _2#306 AS b#309] + // : +- Filter isnotnull(_1#305) + // : +- LocalTableScan [_1#305, _2#306] + // +- Exchange hashpartitioning(a#318, 2) + // +- Project [_1#315 AS a#318, _2#316 AS b#319] + // +- Filter isnotnull(_1#315) + // +- LocalTableScan [_1#315, _2#316] + // + // Assume the execution plan with node id is + // WholeStageCodegen disabled: + // Project(nodeId = 0) + // ShuffledHashJoin(nodeId = 1) + // ...(ignored) + // + // WholeStageCodegen enabled: + // WholeStageCodegen(nodeId = 0) + // Project(nodeId = 1) + // ShuffledHashJoin(nodeId = 2) + // ...(ignored) + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "5000000", + "spark.sql.shuffle.partitions" -> "2", + "spark.sql.join.preferSortMergeJoin" -> "false") { + Seq(true, false).foreach { enableWholeStage => + val df1 = generateRandomBytesDF(65535 * 5) + val df2 = generateRandomBytesDF(65535) + val df = df1.join(df2, "a") + val nodeIds = if (enableWholeStage) { + Set(2L) + } else { + Set(1L) + } + val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get + nodeIds.foreach { nodeId => + val probes = metrics(nodeId)._2("avg hash probe (min, med, max)") + probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => + assert(probe.toDouble > 1.0) + } + } + } + } + } + test("BroadcastHashJoin(outer) metrics") { val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") From f9151bebca986d44cdab7699959fec2bc050773a Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Thu, 29 Jun 2017 16:03:15 -0700 Subject: [PATCH 046/779] [SPARK-21188][CORE] releaseAllLocksForTask should synchronize the whole method ## What changes were proposed in this pull request? Since the objects `readLocksByTask`, `writeLocksByTask` and `info`s are coupled and supposed to be modified by other threads concurrently, all the read and writes of them in the method `releaseAllLocksForTask` should be protected by a single synchronized block like other similar methods. ## How was this patch tested? existing tests Author: Feng Liu Closes #18400 from liufengdb/synchronize. --- .../spark/storage/BlockInfoManager.scala | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index 7064872ec1c77..219a0e799cc73 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -341,15 +341,11 @@ private[storage] class BlockInfoManager extends Logging { * * @return the ids of blocks whose pins were released */ - def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = { + def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = synchronized { val blocksWithReleasedLocks = mutable.ArrayBuffer[BlockId]() - val readLocks = synchronized { - readLocksByTask.remove(taskAttemptId).getOrElse(ImmutableMultiset.of[BlockId]()) - } - val writeLocks = synchronized { - writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty) - } + val readLocks = readLocksByTask.remove(taskAttemptId).getOrElse(ImmutableMultiset.of[BlockId]()) + val writeLocks = writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty) for (blockId <- writeLocks) { infos.get(blockId).foreach { info => @@ -358,21 +354,19 @@ private[storage] class BlockInfoManager extends Logging { } blocksWithReleasedLocks += blockId } + readLocks.entrySet().iterator().asScala.foreach { entry => val blockId = entry.getElement val lockCount = entry.getCount blocksWithReleasedLocks += blockId - synchronized { - get(blockId).foreach { info => - info.readerCount -= lockCount - assert(info.readerCount >= 0) - } + get(blockId).foreach { info => + info.readerCount -= lockCount + assert(info.readerCount >= 0) } } - synchronized { - notifyAll() - } + notifyAll() + blocksWithReleasedLocks } From 4996c53949376153f9ebdc74524fed7226968808 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 30 Jun 2017 10:56:48 +0800 Subject: [PATCH 047/779] [SPARK-21253][CORE] Fix a bug that StreamCallback may not be notified if network errors happen ## What changes were proposed in this pull request? If a network error happens before processing StreamResponse/StreamFailure events, StreamCallback.onFailure won't be called. This PR fixes `failOutstandingRequests` to also notify outstanding StreamCallbacks. ## How was this patch tested? The new unit tests. Author: Shixiong Zhu Closes #18472 from zsxwing/fix-stream-2. --- .../spark/network/client/TransportClient.java | 2 +- .../client/TransportResponseHandler.java | 38 ++++++++++++++----- .../TransportResponseHandlerSuite.java | 31 ++++++++++++++- 3 files changed, 59 insertions(+), 12 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index a6f527c118218..8f354ad78bbaa 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -179,7 +179,7 @@ public void stream(String streamId, StreamCallback callback) { // written to the socket atomically, so that callbacks are called in the right order // when responses arrive. synchronized (this) { - handler.addStreamCallback(callback); + handler.addStreamCallback(streamId, callback); channel.writeAndFlush(new StreamRequest(streamId)).addListener(future -> { if (future.isSuccess()) { long timeTaken = System.currentTimeMillis() - startTime; diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 41bead546cad6..be9f18203c8e4 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -24,6 +24,8 @@ import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicLong; +import scala.Tuple2; + import com.google.common.annotations.VisibleForTesting; import io.netty.channel.Channel; import org.slf4j.Logger; @@ -56,7 +58,7 @@ public class TransportResponseHandler extends MessageHandler { private final Map outstandingRpcs; - private final Queue streamCallbacks; + private final Queue> streamCallbacks; private volatile boolean streamActive; /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ @@ -88,9 +90,9 @@ public void removeRpcRequest(long requestId) { outstandingRpcs.remove(requestId); } - public void addStreamCallback(StreamCallback callback) { + public void addStreamCallback(String streamId, StreamCallback callback) { timeOfLastRequestNs.set(System.nanoTime()); - streamCallbacks.offer(callback); + streamCallbacks.offer(Tuple2.apply(streamId, callback)); } @VisibleForTesting @@ -104,15 +106,31 @@ public void deactivateStream() { */ private void failOutstandingRequests(Throwable cause) { for (Map.Entry entry : outstandingFetches.entrySet()) { - entry.getValue().onFailure(entry.getKey().chunkIndex, cause); + try { + entry.getValue().onFailure(entry.getKey().chunkIndex, cause); + } catch (Exception e) { + logger.warn("ChunkReceivedCallback.onFailure throws exception", e); + } } for (Map.Entry entry : outstandingRpcs.entrySet()) { - entry.getValue().onFailure(cause); + try { + entry.getValue().onFailure(cause); + } catch (Exception e) { + logger.warn("RpcResponseCallback.onFailure throws exception", e); + } + } + for (Tuple2 entry : streamCallbacks) { + try { + entry._2().onFailure(entry._1(), cause); + } catch (Exception e) { + logger.warn("StreamCallback.onFailure throws exception", e); + } } // It's OK if new fetches appear, as they will fail immediately. outstandingFetches.clear(); outstandingRpcs.clear(); + streamCallbacks.clear(); } @Override @@ -190,8 +208,9 @@ public void handle(ResponseMessage message) throws Exception { } } else if (message instanceof StreamResponse) { StreamResponse resp = (StreamResponse) message; - StreamCallback callback = streamCallbacks.poll(); - if (callback != null) { + Tuple2 entry = streamCallbacks.poll(); + if (entry != null) { + StreamCallback callback = entry._2(); if (resp.byteCount > 0) { StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, callback); @@ -216,8 +235,9 @@ public void handle(ResponseMessage message) throws Exception { } } else if (message instanceof StreamFailure) { StreamFailure resp = (StreamFailure) message; - StreamCallback callback = streamCallbacks.poll(); - if (callback != null) { + Tuple2 entry = streamCallbacks.poll(); + if (entry != null) { + StreamCallback callback = entry._2(); try { callback.onFailure(resp.streamId, new RuntimeException(resp.error)); } catch (IOException ioe) { diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 09fc80d12d510..b4032c4c3f031 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import java.io.IOException; import java.nio.ByteBuffer; import io.netty.channel.Channel; @@ -127,7 +128,7 @@ public void testActiveStreams() throws Exception { StreamResponse response = new StreamResponse("stream", 1234L, null); StreamCallback cb = mock(StreamCallback.class); - handler.addStreamCallback(cb); + handler.addStreamCallback("stream", cb); assertEquals(1, handler.numOutstandingRequests()); handler.handle(response); assertEquals(1, handler.numOutstandingRequests()); @@ -135,9 +136,35 @@ public void testActiveStreams() throws Exception { assertEquals(0, handler.numOutstandingRequests()); StreamFailure failure = new StreamFailure("stream", "uh-oh"); - handler.addStreamCallback(cb); + handler.addStreamCallback("stream", cb); assertEquals(1, handler.numOutstandingRequests()); handler.handle(failure); assertEquals(0, handler.numOutstandingRequests()); } + + @Test + public void failOutstandingStreamCallbackOnClose() throws Exception { + Channel c = new LocalChannel(); + c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + TransportResponseHandler handler = new TransportResponseHandler(c); + + StreamCallback cb = mock(StreamCallback.class); + handler.addStreamCallback("stream-1", cb); + handler.channelInactive(); + + verify(cb).onFailure(eq("stream-1"), isA(IOException.class)); + } + + @Test + public void failOutstandingStreamCallbackOnException() throws Exception { + Channel c = new LocalChannel(); + c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + TransportResponseHandler handler = new TransportResponseHandler(c); + + StreamCallback cb = mock(StreamCallback.class); + handler.addStreamCallback("stream-1", cb); + handler.exceptionCaught(new IOException("Oops!")); + + verify(cb).onFailure(eq("stream-1"), isA(IOException.class)); + } } From 80f7ac3a601709dd9471092244612023363f54cd Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 30 Jun 2017 11:02:22 +0800 Subject: [PATCH 048/779] [SPARK-21253][CORE] Disable spark.reducer.maxReqSizeShuffleToMem ## What changes were proposed in this pull request? Disable spark.reducer.maxReqSizeShuffleToMem because it breaks the old shuffle service. Credits to wangyum Closes #18466 ## How was this patch tested? Jenkins Author: Shixiong Zhu Author: Yuming Wang Closes #18467 from zsxwing/SPARK-21253. --- .../scala/org/apache/spark/internal/config/package.scala | 3 ++- docs/configuration.md | 8 -------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index be63c637a3a13..8dee0d970c4c6 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -323,10 +323,11 @@ package object config { private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM = ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem") + .internal() .doc("The blocks of a shuffle request will be fetched to disk when size of the request is " + "above this threshold. This is to avoid a giant request takes too much memory.") .bytesConf(ByteUnit.BYTE) - .createWithDefaultString("200m") + .createWithDefault(Long.MaxValue) private[spark] val TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES = ConfigBuilder("spark.taskMetrics.trackUpdatedBlockStatuses") diff --git a/docs/configuration.md b/docs/configuration.md index c8e61537a457c..bd6a1f9e240e2 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -528,14 +528,6 @@ Apart from these, the following properties are also available, and may be useful By allowing it to limit the number of fetch requests, this scenario can be mitigated. - - spark.reducer.maxReqSizeShuffleToMem - 200m - - The blocks of a shuffle request will be fetched to disk when size of the request is above - this threshold. This is to avoid a giant request takes too much memory. - - spark.shuffle.compress true From 88a536babf119b7e331d02aac5d52b57658803bf Mon Sep 17 00:00:00 2001 From: IngoSchuster Date: Fri, 30 Jun 2017 11:16:09 +0800 Subject: [PATCH 049/779] [SPARK-21176][WEB UI] Limit number of selector threads for admin ui proxy servlets to 8 ## What changes were proposed in this pull request? Please see also https://issues.apache.org/jira/browse/SPARK-21176 This change limits the number of selector threads that jetty creates to maximum 8 per proxy servlet (Jetty default is number of processors / 2). The newHttpClient for Jettys ProxyServlet class is overwritten to avoid the Jetty defaults (which are designed for high-performance http servers). Once https://github.com/eclipse/jetty.project/issues/1643 is available, the code could be cleaned up to avoid the method override. I really need this on v2.1.1 - what is the best way for a backport automatic merge works fine)? Shall I create another PR? ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) The patch was tested manually on a Spark cluster with a head node that has 88 processors using JMX to verify that the number of selector threads is now limited to 8 per proxy. gurvindersingh zsxwing can you please review the change? Author: IngoSchuster Author: Ingo Schuster Closes #18437 from IngoSchuster/master. --- .../main/scala/org/apache/spark/ui/JettyUtils.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index edf328b5ae538..b9371c7ad7b45 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -26,6 +26,8 @@ import scala.language.implicitConversions import scala.xml.Node import org.eclipse.jetty.client.api.Response +import org.eclipse.jetty.client.HttpClient +import org.eclipse.jetty.client.http.HttpClientTransportOverHTTP import org.eclipse.jetty.proxy.ProxyServlet import org.eclipse.jetty.server._ import org.eclipse.jetty.server.handler._ @@ -208,6 +210,16 @@ private[spark] object JettyUtils extends Logging { rewrittenURI.toString() } + override def newHttpClient(): HttpClient = { + // SPARK-21176: Use the Jetty logic to calculate the number of selector threads (#CPUs/2), + // but limit it to 8 max. + // Otherwise, it might happen that we exhaust the threadpool since in reverse proxy mode + // a proxy is instantiated for each executor. If the head node has many processors, this + // can quickly add up to an unreasonably high number of threads. + val numSelectors = math.max(1, math.min(8, Runtime.getRuntime().availableProcessors() / 2)) + new HttpClient(new HttpClientTransportOverHTTP(numSelectors), null) + } + override def filterServerResponseHeader( clientRequest: HttpServletRequest, serverResponse: Response, From cfc696f4a4289acf132cb26baf7c02c5b6305277 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 29 Jun 2017 20:56:37 -0700 Subject: [PATCH 050/779] [SPARK-21253][CORE][HOTFIX] Fix Scala 2.10 build ## What changes were proposed in this pull request? A follow up PR to fix Scala 2.10 build for #18472 ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #18478 from zsxwing/SPARK-21253-2. --- .../apache/spark/network/client/TransportResponseHandler.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index be9f18203c8e4..340b8b96aabc6 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -92,7 +92,7 @@ public void removeRpcRequest(long requestId) { public void addStreamCallback(String streamId, StreamCallback callback) { timeOfLastRequestNs.set(System.nanoTime()); - streamCallbacks.offer(Tuple2.apply(streamId, callback)); + streamCallbacks.offer(new Tuple2<>(streamId, callback)); } @VisibleForTesting From e2f32ee45ac907f1f53fde7e412676a849a94872 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 30 Jun 2017 12:34:09 +0800 Subject: [PATCH 051/779] [SPARK-21258][SQL] Fix WindowExec complex object aggregation with spilling ## What changes were proposed in this pull request? `WindowExec` currently improperly stores complex objects (UnsafeRow, UnsafeArrayData, UnsafeMapData, UTF8String) during aggregation by keeping a reference in the buffer used by `GeneratedMutableProjections` to the actual input data. Things go wrong when the input object (or the backing bytes) are reused for other things. This could happen in window functions when it starts spilling to disk. When reading the back the spill files the `UnsafeSorterSpillReader` reuses the buffer to which the `UnsafeRow` points, leading to weird corruption scenario's. Note that this only happens for aggregate functions that preserve (parts of) their input, for example `FIRST`, `LAST`, `MIN` & `MAX`. This was not seen before, because the spilling logic was not doing actual spills as much and actually used an in-memory page. This page was not cleaned up during window processing and made sure unsafe objects point to their own dedicated memory location. This was changed by https://github.com/apache/spark/pull/16909, after this PR Spark spills more eagerly. This PR provides a surgical fix because we are close to releasing Spark 2.2. This change just makes sure that there cannot be any object reuse at the expensive of a little bit of performance. We will follow-up with a more subtle solution at a later point. ## How was this patch tested? Added a regression test to `DataFrameWindowFunctionsSuite`. Author: Herman van Hovell Closes #18470 from hvanhovell/SPARK-21258. --- .../execution/window/AggregateProcessor.scala | 7 ++- .../sql/DataFrameWindowFunctionsSuite.scala | 47 ++++++++++++++++++- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala index bc141b36e63b4..2195c6ea95948 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala @@ -145,10 +145,13 @@ private[window] final class AggregateProcessor( /** Update the buffer. */ def update(input: InternalRow): Unit = { - updateProjection(join(buffer, input)) + // TODO(hvanhovell) this sacrifices performance for correctness. We should make sure that + // MutableProjection makes copies of the complex input objects it buffer. + val copy = input.copy() + updateProjection(join(buffer, copy)) var i = 0 while (i < numImperatives) { - imperatives(i).update(buffer, input) + imperatives(i).update(buffer, copy) i += 1 } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 1255c49104718..204858fa29787 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{DataType, LongType, StructType} +import org.apache.spark.sql.types._ /** * Window function testing for DataFrame API. @@ -423,4 +424,48 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { df.select(selectList: _*).where($"value" < 2), Seq(Row(3, "1", null, 3.0, 4.0, 3.0), Row(5, "1", false, 4.0, 5.0, 5.0))) } + + test("SPARK-21258: complex object in combination with spilling") { + // Make sure we trigger the spilling path. + withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "17") { + val sampleSchema = new StructType(). + add("f0", StringType). + add("f1", LongType). + add("f2", ArrayType(new StructType(). + add("f20", StringType))). + add("f3", ArrayType(new StructType(). + add("f30", StringType))) + + val w0 = Window.partitionBy("f0").orderBy("f1") + val w1 = w0.rowsBetween(Long.MinValue, Long.MaxValue) + + val c0 = first(struct($"f2", $"f3")).over(w0) as "c0" + val c1 = last(struct($"f2", $"f3")).over(w1) as "c1" + + val input = + """{"f1":1497820153720,"f2":[{"f20":"x","f21":0}],"f3":[{"f30":"x","f31":0}]} + |{"f1":1497802179638} + |{"f1":1497802189347} + |{"f1":1497802189593} + |{"f1":1497802189597} + |{"f1":1497802189599} + |{"f1":1497802192103} + |{"f1":1497802193414} + |{"f1":1497802193577} + |{"f1":1497802193709} + |{"f1":1497802202883} + |{"f1":1497802203006} + |{"f1":1497802203743} + |{"f1":1497802203834} + |{"f1":1497802203887} + |{"f1":1497802203893} + |{"f1":1497802203976} + |{"f1":1497820168098} + |""".stripMargin.split("\n").toSeq + + import testImplicits._ + + spark.read.schema(sampleSchema).json(input.toDS()).select(c0, c1).foreach { _ => () } + } + } } From fddb63f46345be36c40d9a7f3660920af6502bbd Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Thu, 29 Jun 2017 21:35:01 -0700 Subject: [PATCH 052/779] [SPARK-20889][SPARKR] Grouped documentation for MISC column methods ## What changes were proposed in this pull request? Grouped documentation for column misc methods. Author: actuaryzhang Author: Wayne Zhang Closes #18448 from actuaryzhang/sparkRDocMisc. --- R/pkg/R/functions.R | 98 +++++++++++++++++++++------------------------ R/pkg/R/generics.R | 15 ++++--- 2 files changed, 55 insertions(+), 58 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index cb09e847d739a..67cb7a7f6db08 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -150,6 +150,27 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))} NULL +#' Miscellaneous functions for Column operations +#' +#' Miscellaneous functions defined for \code{Column}. +#' +#' @param x Column to compute on. In \code{sha2}, it is one of 224, 256, 384, or 512. +#' @param y Column to compute on. +#' @param ... additional Columns. +#' @name column_misc_functions +#' @rdname column_misc_functions +#' @family misc functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)[, 1:2]) +#' tmp <- mutate(df, v1 = crc32(df$model), v2 = hash(df$model), +#' v3 = hash(df$model, df$mpg), v4 = md5(df$model), +#' v5 = sha1(df$model), v6 = sha2(df$model, 256)) +#' head(tmp) +#' } +NULL + #' @details #' \code{lit}: A new Column is created to represent the literal value. #' If the parameter is a Column, it is returned unchanged. @@ -569,19 +590,13 @@ setMethod("count", column(jc) }) -#' crc32 -#' -#' Calculates the cyclic redundancy check value (CRC32) of a binary column and -#' returns the value as a bigint. -#' -#' @param x Column to compute on. +#' @details +#' \code{crc32}: Calculates the cyclic redundancy check value (CRC32) of a binary column +#' and returns the value as a bigint. #' -#' @rdname crc32 -#' @name crc32 -#' @family misc functions -#' @aliases crc32,Column-method +#' @rdname column_misc_functions +#' @aliases crc32 crc32,Column-method #' @export -#' @examples \dontrun{crc32(df$c)} #' @note crc32 since 1.5.0 setMethod("crc32", signature(x = "Column"), @@ -590,19 +605,13 @@ setMethod("crc32", column(jc) }) -#' hash -#' -#' Calculates the hash code of given columns, and returns the result as a int column. -#' -#' @param x Column to compute on. -#' @param ... additional Column(s) to be included. +#' @details +#' \code{hash}: Calculates the hash code of given columns, and returns the result +#' as an int column. #' -#' @rdname hash -#' @name hash -#' @family misc functions -#' @aliases hash,Column-method +#' @rdname column_misc_functions +#' @aliases hash hash,Column-method #' @export -#' @examples \dontrun{hash(df$c)} #' @note hash since 2.0.0 setMethod("hash", signature(x = "Column"), @@ -1055,19 +1064,13 @@ setMethod("max", column(jc) }) -#' md5 -#' -#' Calculates the MD5 digest of a binary column and returns the value +#' @details +#' \code{md5}: Calculates the MD5 digest of a binary column and returns the value #' as a 32 character hex string. #' -#' @param x Column to compute on. -#' -#' @rdname md5 -#' @name md5 -#' @family misc functions -#' @aliases md5,Column-method +#' @rdname column_misc_functions +#' @aliases md5 md5,Column-method #' @export -#' @examples \dontrun{md5(df$c)} #' @note md5 since 1.5.0 setMethod("md5", signature(x = "Column"), @@ -1307,19 +1310,13 @@ setMethod("second", column(jc) }) -#' sha1 -#' -#' Calculates the SHA-1 digest of a binary column and returns the value +#' @details +#' \code{sha1}: Calculates the SHA-1 digest of a binary column and returns the value #' as a 40 character hex string. #' -#' @param x Column to compute on. -#' -#' @rdname sha1 -#' @name sha1 -#' @family misc functions -#' @aliases sha1,Column-method +#' @rdname column_misc_functions +#' @aliases sha1 sha1,Column-method #' @export -#' @examples \dontrun{sha1(df$c)} #' @note sha1 since 1.5.0 setMethod("sha1", signature(x = "Column"), @@ -2309,19 +2306,14 @@ setMethod("format_number", signature(y = "Column", x = "numeric"), column(jc) }) -#' sha2 -#' -#' Calculates the SHA-2 family of hash functions of a binary column and -#' returns the value as a hex string. +#' @details +#' \code{sha2}: Calculates the SHA-2 family of hash functions of a binary column and +#' returns the value as a hex string. The second argument \code{x} specifies the number +#' of bits, and is one of 224, 256, 384, or 512. #' -#' @param y column to compute SHA-2 on. -#' @param x one of 224, 256, 384, or 512. -#' @family misc functions -#' @rdname sha2 -#' @name sha2 -#' @aliases sha2,Column,numeric-method +#' @rdname column_misc_functions +#' @aliases sha2 sha2,Column,numeric-method #' @export -#' @examples \dontrun{sha2(df$c, 256)} #' @note sha2 since 1.5.0 setMethod("sha2", signature(y = "Column", x = "numeric"), function(y, x) { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 1deb057bb1b82..bdd4b360f4973 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -992,8 +992,9 @@ setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) #' @name NULL setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) -#' @rdname crc32 +#' @rdname column_misc_functions #' @export +#' @name NULL setGeneric("crc32", function(x) { standardGeneric("crc32") }) #' @rdname column_nonaggregate_functions @@ -1006,8 +1007,9 @@ setGeneric("create_array", function(x, ...) { standardGeneric("create_array") }) #' @name NULL setGeneric("create_map", function(x, ...) { standardGeneric("create_map") }) -#' @rdname hash +#' @rdname column_misc_functions #' @export +#' @name NULL setGeneric("hash", function(x, ...) { standardGeneric("hash") }) #' @param x empty. Should be used with no argument. @@ -1205,8 +1207,9 @@ setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) #' @name NULL setGeneric("ltrim", function(x) { standardGeneric("ltrim") }) -#' @rdname md5 +#' @rdname column_misc_functions #' @export +#' @name NULL setGeneric("md5", function(x) { standardGeneric("md5") }) #' @rdname column_datetime_functions @@ -1350,12 +1353,14 @@ setGeneric("sd", function(x, na.rm = FALSE) { standardGeneric("sd") }) #' @name NULL setGeneric("second", function(x) { standardGeneric("second") }) -#' @rdname sha1 +#' @rdname column_misc_functions #' @export +#' @name NULL setGeneric("sha1", function(x) { standardGeneric("sha1") }) -#' @rdname sha2 +#' @rdname column_misc_functions #' @export +#' @name NULL setGeneric("sha2", function(y, x) { standardGeneric("sha2") }) #' @rdname column_math_functions From 52981715bb8d653a1141f55b36da804412eb783a Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Thu, 29 Jun 2017 23:00:50 -0700 Subject: [PATCH 053/779] [SPARK-20889][SPARKR] Grouped documentation for COLLECTION column methods ## What changes were proposed in this pull request? Grouped documentation for column collection methods. Author: actuaryzhang Author: Wayne Zhang Closes #18458 from actuaryzhang/sparkRDocCollection. --- R/pkg/R/functions.R | 204 +++++++++++++++++++------------------------- R/pkg/R/generics.R | 27 ++++-- 2 files changed, 108 insertions(+), 123 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 67cb7a7f6db08..a1f5c4f8cc18d 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -171,6 +171,35 @@ NULL #' } NULL +#' Collection functions for Column operations +#' +#' Collection functions defined for \code{Column}. +#' +#' @param x Column to compute on. Note the difference in the following methods: +#' \itemize{ +#' \item \code{to_json}: it is the column containing the struct or array of the structs. +#' \item \code{from_json}: it is the column containing the JSON string. +#' } +#' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains +#' additional named properties to control how it is converted, accepts the same +#' options as the JSON data source. +#' @name column_collection_functions +#' @rdname column_collection_functions +#' @family collection functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) +#' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) +#' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) +#' head(tmp2) +#' head(select(tmp, posexplode(tmp$v1))) +#' head(select(tmp, sort_array(tmp$v1))) +#' head(select(tmp, sort_array(tmp$v1, asc = FALSE)))} +NULL + #' @details #' \code{lit}: A new Column is created to represent the literal value. #' If the parameter is a Column, it is returned unchanged. @@ -1642,30 +1671,23 @@ setMethod("to_date", column(jc) }) -#' to_json -#' -#' Converts a column containing a \code{structType} or array of \code{structType} into a Column -#' of JSON string. Resolving the Column can fail if an unsupported type is encountered. -#' -#' @param x Column containing the struct or array of the structs -#' @param ... additional named properties to control how it is converted, accepts the same options -#' as the JSON data source. +#' @details +#' \code{to_json}: Converts a column containing a \code{structType} or array of \code{structType} +#' into a Column of JSON string. Resolving the Column can fail if an unsupported type is encountered. #' -#' @family non-aggregate functions -#' @rdname to_json -#' @name to_json -#' @aliases to_json,Column-method +#' @rdname column_collection_functions +#' @aliases to_json to_json,Column-method #' @export #' @examples +#' #' \dontrun{ #' # Converts a struct into a JSON object -#' df <- sql("SELECT named_struct('date', cast('2000-01-01' as date)) as d") -#' select(df, to_json(df$d, dateFormat = 'dd/MM/yyyy')) +#' df2 <- sql("SELECT named_struct('date', cast('2000-01-01' as date)) as d") +#' select(df2, to_json(df2$d, dateFormat = 'dd/MM/yyyy')) #' #' # Converts an array of structs into a JSON array -#' df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") -#' select(df, to_json(df$people)) -#'} +#' df2 <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") +#' df2 <- mutate(df2, people_json = to_json(df2$people))} #' @note to_json since 2.2.0 setMethod("to_json", signature(x = "Column"), function(x, ...) { @@ -2120,28 +2142,28 @@ setMethod("date_format", signature(y = "Column", x = "character"), column(jc) }) -#' from_json -#' -#' Parses a column containing a JSON string into a Column of \code{structType} with the specified -#' \code{schema} or array of \code{structType} if \code{as.json.array} is set to \code{TRUE}. -#' If the string is unparseable, the Column will contains the value NA. +#' @details +#' \code{from_json}: Parses a column containing a JSON string into a Column of \code{structType} +#' with the specified \code{schema} or array of \code{structType} if \code{as.json.array} is set +#' to \code{TRUE}. If the string is unparseable, the Column will contain the value NA. #' -#' @param x Column containing the JSON string. +#' @rdname column_collection_functions #' @param schema a structType object to use as the schema to use when parsing the JSON string. #' @param as.json.array indicating if input string is JSON array of objects or a single object. -#' @param ... additional named properties to control how the json is parsed, accepts the same -#' options as the JSON data source. -#' -#' @family non-aggregate functions -#' @rdname from_json -#' @name from_json -#' @aliases from_json,Column,structType-method +#' @aliases from_json from_json,Column,structType-method #' @export #' @examples +#' #' \dontrun{ -#' schema <- structType(structField("name", "string"), -#' select(df, from_json(df$value, schema, dateFormat = "dd/MM/yyyy")) -#'} +#' df2 <- sql("SELECT named_struct('date', cast('2000-01-01' as date)) as d") +#' df2 <- mutate(df2, d2 = to_json(df2$d, dateFormat = 'dd/MM/yyyy')) +#' schema <- structType(structField("date", "string")) +#' head(select(df2, from_json(df2$d2, schema, dateFormat = 'dd/MM/yyyy'))) + +#' df2 <- sql("SELECT named_struct('name', 'Bob') as people") +#' df2 <- mutate(df2, people_json = to_json(df2$people)) +#' schema <- structType(structField("name", "string")) +#' head(select(df2, from_json(df2$people_json, schema)))} #' @note from_json since 2.2.0 setMethod("from_json", signature(x = "Column", schema = "structType"), function(x, schema, as.json.array = FALSE, ...) { @@ -3101,18 +3123,14 @@ setMethod("row_number", ###################### Collection functions###################### -#' array_contains -#' -#' Returns null if the array is null, true if the array contains the value, and false otherwise. +#' @details +#' \code{array_contains}: Returns null if the array is null, true if the array contains +#' the value, and false otherwise. #' -#' @param x A Column #' @param value A value to be checked if contained in the column -#' @rdname array_contains -#' @aliases array_contains,Column-method -#' @name array_contains -#' @family collection functions +#' @rdname column_collection_functions +#' @aliases array_contains array_contains,Column-method #' @export -#' @examples \dontrun{array_contains(df$c, 1)} #' @note array_contains since 1.6.0 setMethod("array_contains", signature(x = "Column", value = "ANY"), @@ -3121,18 +3139,12 @@ setMethod("array_contains", column(jc) }) -#' explode -#' -#' Creates a new row for each element in the given array or map column. -#' -#' @param x Column to compute on +#' @details +#' \code{explode}: Creates a new row for each element in the given array or map column. #' -#' @rdname explode -#' @name explode -#' @family collection functions -#' @aliases explode,Column-method +#' @rdname column_collection_functions +#' @aliases explode explode,Column-method #' @export -#' @examples \dontrun{explode(df$c)} #' @note explode since 1.5.0 setMethod("explode", signature(x = "Column"), @@ -3141,18 +3153,12 @@ setMethod("explode", column(jc) }) -#' size -#' -#' Returns length of array or map. -#' -#' @param x Column to compute on +#' @details +#' \code{size}: Returns length of array or map. #' -#' @rdname size -#' @name size -#' @aliases size,Column-method -#' @family collection functions +#' @rdname column_collection_functions +#' @aliases size size,Column-method #' @export -#' @examples \dontrun{size(df$c)} #' @note size since 1.5.0 setMethod("size", signature(x = "Column"), @@ -3161,25 +3167,16 @@ setMethod("size", column(jc) }) -#' sort_array -#' -#' Sorts the input array in ascending or descending order according +#' @details +#' \code{sort_array}: Sorts the input array in ascending or descending order according #' to the natural ordering of the array elements. #' -#' @param x A Column to sort +#' @rdname column_collection_functions #' @param asc A logical flag indicating the sorting order. #' TRUE, sorting is in ascending order. #' FALSE, sorting is in descending order. -#' @rdname sort_array -#' @name sort_array -#' @aliases sort_array,Column-method -#' @family collection functions +#' @aliases sort_array sort_array,Column-method #' @export -#' @examples -#' \dontrun{ -#' sort_array(df$c) -#' sort_array(df$c, FALSE) -#' } #' @note sort_array since 1.6.0 setMethod("sort_array", signature(x = "Column"), @@ -3188,18 +3185,13 @@ setMethod("sort_array", column(jc) }) -#' posexplode -#' -#' Creates a new row for each element with position in the given array or map column. -#' -#' @param x Column to compute on +#' @details +#' \code{posexplode}: Creates a new row for each element with position in the given array +#' or map column. #' -#' @rdname posexplode -#' @name posexplode -#' @family collection functions -#' @aliases posexplode,Column-method +#' @rdname column_collection_functions +#' @aliases posexplode posexplode,Column-method #' @export -#' @examples \dontrun{posexplode(df$c)} #' @note posexplode since 2.1.0 setMethod("posexplode", signature(x = "Column"), @@ -3325,27 +3317,24 @@ setMethod("repeat_string", column(jc) }) -#' explode_outer -#' -#' Creates a new row for each element in the given array or map column. +#' @details +#' \code{explode}: Creates a new row for each element in the given array or map column. #' Unlike \code{explode}, if the array/map is \code{null} or empty #' then \code{null} is produced. #' -#' @param x Column to compute on #' -#' @rdname explode_outer -#' @name explode_outer -#' @family collection functions -#' @aliases explode_outer,Column-method +#' @rdname column_collection_functions +#' @aliases explode_outer explode_outer,Column-method #' @export #' @examples +#' #' \dontrun{ -#' df <- createDataFrame(data.frame( +#' df2 <- createDataFrame(data.frame( #' id = c(1, 2, 3), text = c("a,b,c", NA, "d,e") #' )) #' -#' head(select(df, df$id, explode_outer(split_string(df$text, ",")))) -#' } +#' head(select(df2, df2$id, explode_outer(split_string(df2$text, ",")))) +#' head(select(df2, df2$id, posexplode_outer(split_string(df2$text, ","))))} #' @note explode_outer since 2.3.0 setMethod("explode_outer", signature(x = "Column"), @@ -3354,27 +3343,14 @@ setMethod("explode_outer", column(jc) }) -#' posexplode_outer -#' -#' Creates a new row for each element with position in the given array or map column. -#' Unlike \code{posexplode}, if the array/map is \code{null} or empty +#' @details +#' \code{posexplode_outer}: Creates a new row for each element with position in the given +#' array or map column. Unlike \code{posexplode}, if the array/map is \code{null} or empty #' then the row (\code{null}, \code{null}) is produced. #' -#' @param x Column to compute on -#' -#' @rdname posexplode_outer -#' @name posexplode_outer -#' @family collection functions -#' @aliases posexplode_outer,Column-method +#' @rdname column_collection_functions +#' @aliases posexplode_outer posexplode_outer,Column-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(data.frame( -#' id = c(1, 2, 3), text = c("a,b,c", NA, "d,e") -#' )) -#' -#' head(select(df, df$id, posexplode_outer(split_string(df$text, ",")))) -#' } #' @note posexplode_outer since 2.3.0 setMethod("posexplode_outer", signature(x = "Column"), diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index bdd4b360f4973..b901b74e4728d 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -913,8 +913,9 @@ setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) #' @name NULL setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) -#' @rdname array_contains +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) #' @rdname column_string_functions @@ -1062,12 +1063,14 @@ setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") #' @name NULL setGeneric("encode", function(x, charset) { standardGeneric("encode") }) -#' @rdname explode +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("explode", function(x) { standardGeneric("explode") }) -#' @rdname explode_outer +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) #' @rdname column_nonaggregate_functions @@ -1090,8 +1093,9 @@ setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) #' @name NULL setGeneric("format_string", function(format, x, ...) { standardGeneric("format_string") }) -#' @rdname from_json +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") }) #' @rdname column_datetime_functions @@ -1275,12 +1279,14 @@ setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_ra #' @name NULL setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) -#' @rdname posexplode +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("posexplode", function(x) { standardGeneric("posexplode") }) -#' @rdname posexplode_outer +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer") }) #' @rdname column_datetime_functions @@ -1383,8 +1389,9 @@ setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUns #' @name NULL setGeneric("signum", function(x) { standardGeneric("signum") }) -#' @rdname size +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("size", function(x) { standardGeneric("size") }) #' @rdname column_aggregate_functions @@ -1392,8 +1399,9 @@ setGeneric("size", function(x) { standardGeneric("size") }) #' @name NULL setGeneric("skewness", function(x) { standardGeneric("skewness") }) -#' @rdname sort_array +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) #' @rdname column_string_functions @@ -1456,8 +1464,9 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @name NULL setGeneric("to_date", function(x, format) { standardGeneric("to_date") }) -#' @rdname to_json +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("to_json", function(x, ...) { standardGeneric("to_json") }) #' @rdname column_datetime_functions From 49d767d838691fc7d964be2c4349662f5500ff2b Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Fri, 30 Jun 2017 20:02:15 +0800 Subject: [PATCH 054/779] [SPARK-18710][ML] Add offset in GLM ## What changes were proposed in this pull request? Add support for offset in GLM. This is useful for at least two reasons: 1. Account for exposure: e.g., when modeling the number of accidents, we may need to use miles driven as an offset to access factors on frequency. 2. Test incremental effects of new variables: we can use predictions from the existing model as offset and run a much smaller model on only new variables. This avoids re-estimating the large model with all variables (old + new) and can be very important for efficient large-scaled analysis. ## How was this patch tested? New test. yanboliang srowen felixcheung sethah Author: actuaryzhang Closes #16699 from actuaryzhang/offset. --- .../apache/spark/ml/feature/Instance.scala | 21 + .../IterativelyReweightedLeastSquares.scala | 14 +- .../spark/ml/optim/WeightedLeastSquares.scala | 2 +- .../GeneralizedLinearRegression.scala | 184 +++-- ...erativelyReweightedLeastSquaresSuite.scala | 40 +- .../GeneralizedLinearRegressionSuite.scala | 634 ++++++++++-------- 6 files changed, 534 insertions(+), 361 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala index cce3ca45ccd8f..dd56fbbfa2b63 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala @@ -27,3 +27,24 @@ import org.apache.spark.ml.linalg.Vector * @param features The vector of features for this data point. */ private[ml] case class Instance(label: Double, weight: Double, features: Vector) + +/** + * Case class that represents an instance of data point with + * label, weight, offset and features. + * This is mainly used in GeneralizedLinearRegression currently. + * + * @param label Label for this data point. + * @param weight The weight of this instance. + * @param offset The offset used for this data point. + * @param features The vector of features for this data point. + */ +private[ml] case class OffsetInstance( + label: Double, + weight: Double, + offset: Double, + features: Vector) { + + /** Converts to an [[Instance]] object by leaving out the offset. */ + def toInstance: Instance = Instance(label, weight, features) + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala index 9c495512422ba..6961b45f55e4d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.optim import org.apache.spark.internal.Logging -import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.linalg._ import org.apache.spark.rdd.RDD @@ -43,7 +43,7 @@ private[ml] class IterativelyReweightedLeastSquaresModel( * find M-estimator in robust regression and other optimization problems. * * @param initialModel the initial guess model. - * @param reweightFunc the reweight function which is used to update offsets and weights + * @param reweightFunc the reweight function which is used to update working labels and weights * at each iteration. * @param fitIntercept whether to fit intercept. * @param regParam L2 regularization parameter used by WLS. @@ -57,13 +57,13 @@ private[ml] class IterativelyReweightedLeastSquaresModel( */ private[ml] class IterativelyReweightedLeastSquares( val initialModel: WeightedLeastSquaresModel, - val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double), + val reweightFunc: (OffsetInstance, WeightedLeastSquaresModel) => (Double, Double), val fitIntercept: Boolean, val regParam: Double, val maxIter: Int, val tol: Double) extends Logging with Serializable { - def fit(instances: RDD[Instance]): IterativelyReweightedLeastSquaresModel = { + def fit(instances: RDD[OffsetInstance]): IterativelyReweightedLeastSquaresModel = { var converged = false var iter = 0 @@ -75,10 +75,10 @@ private[ml] class IterativelyReweightedLeastSquares( oldModel = model - // Update offsets and weights using reweightFunc + // Update working labels and weights using reweightFunc val newInstances = instances.map { instance => - val (newOffset, newWeight) = reweightFunc(instance, oldModel) - Instance(newOffset, newWeight, instance.features) + val (newLabel, newWeight) = reweightFunc(instance, oldModel) + Instance(newLabel, newWeight, instance.features) } // Estimate new model diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 56ab9675700a0..32b0af72ba9bb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.optim import org.apache.spark.internal.Logging -import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.linalg._ import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index bff0d9bbb46ff..ce3460ae43566 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -26,8 +26,8 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.PredictorParams -import org.apache.spark.ml.feature.Instance -import org.apache.spark.ml.linalg.{BLAS, Vector} +import org.apache.spark.ml.feature.{Instance, OffsetInstance} +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} import org.apache.spark.ml.optim._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -138,6 +138,27 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam @Since("2.0.0") def getLinkPredictionCol: String = $(linkPredictionCol) + /** + * Param for offset column name. If this is not set or empty, we treat all instance offsets + * as 0.0. The feature specified as offset has a constant coefficient of 1.0. + * @group param + */ + @Since("2.3.0") + final val offsetCol: Param[String] = new Param[String](this, "offsetCol", "The offset " + + "column name. If this is not set or empty, we treat all instance offsets as 0.0") + + /** @group getParam */ + @Since("2.3.0") + def getOffsetCol: String = $(offsetCol) + + /** Checks whether weight column is set and nonempty. */ + private[regression] def hasWeightCol: Boolean = + isSet(weightCol) && $(weightCol).nonEmpty + + /** Checks whether offset column is set and nonempty. */ + private[regression] def hasOffsetCol: Boolean = + isSet(offsetCol) && $(offsetCol).nonEmpty + /** Checks whether we should output link prediction. */ private[regression] def hasLinkPredictionCol: Boolean = { isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty @@ -172,6 +193,11 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam } val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) + + if (hasOffsetCol) { + SchemaUtils.checkNumericType(schema, $(offsetCol)) + } + if (hasLinkPredictionCol) { SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType) } else { @@ -306,6 +332,16 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val @Since("2.0.0") def setWeightCol(value: String): this.type = set(weightCol, value) + /** + * Sets the value of param [[offsetCol]]. + * If this is not set or empty, we treat all instance offsets as 0.0. + * Default is not set, so all instances have offset 0.0. + * + * @group setParam + */ + @Since("2.3.0") + def setOffsetCol(value: String): this.type = set(offsetCol, value) + /** * Sets the solver algorithm used for optimization. * Currently only supports "irls" which is also the default solver. @@ -329,7 +365,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, weightCol, predictionCol, linkPredictionCol, + instr.logParams(labelCol, featuresCol, weightCol, offsetCol, predictionCol, linkPredictionCol, family, solver, fitIntercept, link, maxIter, regParam, tol) instr.logNumFeatures(numFeatures) @@ -343,15 +379,16 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val "GeneralizedLinearRegression was given data with 0 features, and with Param fitIntercept " + "set to false. To fit a model with 0 features, fitIntercept must be set to true." ) - val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - val instances: RDD[Instance] = - dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } + val w = if (!hasWeightCol) lit(1.0) else col($(weightCol)) + val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType) val model = if (familyAndLink.family == Gaussian && familyAndLink.link == Identity) { // TODO: Make standardizeFeatures and standardizeLabel configurable. + val instances: RDD[Instance] = + dataset.select(col($(labelCol)), w, offset, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, offset: Double, features: Vector) => + Instance(label - offset, weight, features) + } val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) val wlsModel = optimizer.fit(instances) @@ -362,6 +399,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val wlsModel.diagInvAtWA.toArray, 1, getSolver) model.setSummary(Some(trainingSummary)) } else { + val instances: RDD[OffsetInstance] = + dataset.select(col($(labelCol)), w, offset, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, offset: Double, features: Vector) => + OffsetInstance(label, weight, offset, features) + } // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) val optimizer = new IterativelyReweightedLeastSquares(initialModel, @@ -425,12 +467,12 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * Get the initial guess model for [[IterativelyReweightedLeastSquares]]. */ def initialize( - instances: RDD[Instance], + instances: RDD[OffsetInstance], fitIntercept: Boolean, regParam: Double): WeightedLeastSquaresModel = { val newInstances = instances.map { instance => val mu = family.initialize(instance.label, instance.weight) - val eta = predict(mu) + val eta = predict(mu) - instance.offset Instance(eta, instance.weight, instance.features) } // TODO: Make standardizeFeatures and standardizeLabel configurable. @@ -441,16 +483,16 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine } /** - * The reweight function used to update offsets and weights + * The reweight function used to update working labels and weights * at each iteration of [[IterativelyReweightedLeastSquares]]. */ - val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double) = { - (instance: Instance, model: WeightedLeastSquaresModel) => { - val eta = model.predict(instance.features) + val reweightFunc: (OffsetInstance, WeightedLeastSquaresModel) => (Double, Double) = { + (instance: OffsetInstance, model: WeightedLeastSquaresModel) => { + val eta = model.predict(instance.features) + instance.offset val mu = fitted(eta) - val offset = eta + (instance.label - mu) * link.deriv(mu) - val weight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) - (offset, weight) + val newLabel = eta - instance.offset + (instance.label - mu) * link.deriv(mu) + val newWeight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) + (newLabel, newWeight) } } } @@ -950,15 +992,22 @@ class GeneralizedLinearRegressionModel private[ml] ( private lazy val familyAndLink = FamilyAndLink(this) override protected def predict(features: Vector): Double = { - val eta = predictLink(features) + predict(features, 0.0) + } + + /** + * Calculates the predicted value when offset is set. + */ + private def predict(features: Vector, offset: Double): Double = { + val eta = predictLink(features, offset) familyAndLink.fitted(eta) } /** - * Calculate the link prediction (linear predictor) of the given instance. + * Calculates the link prediction (linear predictor) of the given instance. */ - private def predictLink(features: Vector): Double = { - BLAS.dot(features, coefficients) + intercept + private def predictLink(features: Vector, offset: Double): Double = { + BLAS.dot(features, coefficients) + intercept + offset } override def transform(dataset: Dataset[_]): DataFrame = { @@ -967,14 +1016,16 @@ class GeneralizedLinearRegressionModel private[ml] ( } override protected def transformImpl(dataset: Dataset[_]): DataFrame = { - val predictUDF = udf { (features: Vector) => predict(features) } - val predictLinkUDF = udf { (features: Vector) => predictLink(features) } + val predictUDF = udf { (features: Vector, offset: Double) => predict(features, offset) } + val predictLinkUDF = udf { (features: Vector, offset: Double) => predictLink(features, offset) } + + val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType) var output = dataset if ($(predictionCol).nonEmpty) { - output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)), offset)) } if (hasLinkPredictionCol) { - output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol)))) + output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol)), offset)) } output.toDF() } @@ -1146,9 +1197,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( /** Degrees of freedom. */ @Since("2.0.0") - lazy val degreesOfFreedom: Long = { - numInstances - rank - } + lazy val degreesOfFreedom: Long = numInstances - rank /** The residual degrees of freedom. */ @Since("2.0.0") @@ -1156,18 +1205,20 @@ class GeneralizedLinearRegressionSummary private[regression] ( /** The residual degrees of freedom for the null model. */ @Since("2.0.0") - lazy val residualDegreeOfFreedomNull: Long = if (model.getFitIntercept) { - numInstances - 1 - } else { - numInstances + lazy val residualDegreeOfFreedomNull: Long = { + if (model.getFitIntercept) numInstances - 1 else numInstances } - private def weightCol: Column = { - if (!model.isDefined(model.weightCol) || model.getWeightCol.isEmpty) { - lit(1.0) - } else { - col(model.getWeightCol) - } + private def label: Column = col(model.getLabelCol).cast(DoubleType) + + private def prediction: Column = col(predictionCol) + + private def weight: Column = { + if (!model.hasWeightCol) lit(1.0) else col(model.getWeightCol) + } + + private def offset: Column = { + if (!model.hasOffsetCol) lit(0.0) else col(model.getOffsetCol).cast(DoubleType) } private[regression] lazy val devianceResiduals: DataFrame = { @@ -1175,25 +1226,23 @@ class GeneralizedLinearRegressionSummary private[regression] ( val r = math.sqrt(math.max(family.deviance(y, mu, weight), 0.0)) if (y > mu) r else -1.0 * r } - val w = weightCol predictions.select( - drUDF(col(model.getLabelCol), col(predictionCol), w).as("devianceResiduals")) + drUDF(label, prediction, weight).as("devianceResiduals")) } private[regression] lazy val pearsonResiduals: DataFrame = { val prUDF = udf { mu: Double => family.variance(mu) } - val w = weightCol - predictions.select(col(model.getLabelCol).minus(col(predictionCol)) - .multiply(sqrt(w)).divide(sqrt(prUDF(col(predictionCol)))).as("pearsonResiduals")) + predictions.select(label.minus(prediction) + .multiply(sqrt(weight)).divide(sqrt(prUDF(prediction))).as("pearsonResiduals")) } private[regression] lazy val workingResiduals: DataFrame = { val wrUDF = udf { (y: Double, mu: Double) => (y - mu) * link.deriv(mu) } - predictions.select(wrUDF(col(model.getLabelCol), col(predictionCol)).as("workingResiduals")) + predictions.select(wrUDF(label, prediction).as("workingResiduals")) } private[regression] lazy val responseResiduals: DataFrame = { - predictions.select(col(model.getLabelCol).minus(col(predictionCol)).as("responseResiduals")) + predictions.select(label.minus(prediction).as("responseResiduals")) } /** @@ -1225,16 +1274,35 @@ class GeneralizedLinearRegressionSummary private[regression] ( */ @Since("2.0.0") lazy val nullDeviance: Double = { - val w = weightCol - val wtdmu: Double = if (model.getFitIntercept) { - val agg = predictions.agg(sum(w.multiply(col(model.getLabelCol))), sum(w)).first() - agg.getDouble(0) / agg.getDouble(1) + val intercept: Double = if (!model.getFitIntercept) { + 0.0 } else { - link.unlink(0.0) + /* + Estimate intercept analytically when there is no offset, or when there is offset but + the model is Gaussian family with identity link. Otherwise, fit an intercept only model. + */ + if (!model.hasOffsetCol || + (model.hasOffsetCol && family == Gaussian && link == Identity)) { + val agg = predictions.agg(sum(weight.multiply( + label.minus(offset))), sum(weight)).first() + link.link(agg.getDouble(0) / agg.getDouble(1)) + } else { + // Create empty feature column and fit intercept only model using param setting from model + val featureNull = "feature_" + java.util.UUID.randomUUID.toString + val paramMap = model.extractParamMap() + paramMap.put(model.featuresCol, featureNull) + if (family.name != "tweedie") { + paramMap.remove(model.variancePower) + } + val emptyVectorUDF = udf{ () => Vectors.zeros(0) } + model.parent.fit( + dataset.withColumn(featureNull, emptyVectorUDF()), paramMap + ).intercept + } } - predictions.select(col(model.getLabelCol).cast(DoubleType), w).rdd.map { - case Row(y: Double, weight: Double) => - family.deviance(y, wtdmu, weight) + predictions.select(label, offset, weight).rdd.map { + case Row(y: Double, offset: Double, weight: Double) => + family.deviance(y, link.unlink(intercept + offset), weight) }.sum() } @@ -1243,8 +1311,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( */ @Since("2.0.0") lazy val deviance: Double = { - val w = weightCol - predictions.select(col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map { + predictions.select(label, prediction, weight).rdd.map { case Row(label: Double, pred: Double, weight: Double) => family.deviance(label, pred, weight) }.sum() @@ -1269,10 +1336,9 @@ class GeneralizedLinearRegressionSummary private[regression] ( /** Akaike Information Criterion (AIC) for the fitted model. */ @Since("2.0.0") lazy val aic: Double = { - val w = weightCol - val weightSum = predictions.select(w).agg(sum(w)).first().getDouble(0) + val weightSum = predictions.select(weight).agg(sum(weight)).first().getDouble(0) val t = predictions.select( - col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map { + label, prediction, weight).rdd.map { case Row(label: Double, pred: Double, weight: Double) => (label, pred, weight) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala index 50260952ecb66..6d143504fcf58 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.optim import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -26,8 +26,8 @@ import org.apache.spark.rdd.RDD class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext { - private var instances1: RDD[Instance] = _ - private var instances2: RDD[Instance] = _ + private var instances1: RDD[OffsetInstance] = _ + private var instances2: RDD[OffsetInstance] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -39,10 +39,10 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes w <- c(1, 2, 3, 4) */ instances1 = sc.parallelize(Seq( - Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)), - Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)), - Instance(0.0, 4.0, Vectors.dense(3.0, 3.0)) + OffsetInstance(1.0, 1.0, 0.0, Vectors.dense(0.0, 5.0).toSparse), + OffsetInstance(0.0, 2.0, 0.0, Vectors.dense(1.0, 2.0)), + OffsetInstance(1.0, 3.0, 0.0, Vectors.dense(2.0, 1.0)), + OffsetInstance(0.0, 4.0, 0.0, Vectors.dense(3.0, 3.0)) ), 2) /* R code: @@ -52,10 +52,10 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes w <- c(1, 2, 3, 4) */ instances2 = sc.parallelize(Seq( - Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), - Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), - Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) + OffsetInstance(2.0, 1.0, 0.0, Vectors.dense(0.0, 5.0).toSparse), + OffsetInstance(8.0, 2.0, 0.0, Vectors.dense(1.0, 7.0)), + OffsetInstance(3.0, 3.0, 0.0, Vectors.dense(2.0, 11.0)), + OffsetInstance(9.0, 4.0, 0.0, Vectors.dense(3.0, 13.0)) ), 2) } @@ -156,7 +156,7 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes var idx = 0 for (fitIntercept <- Seq(false, true)) { val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, elasticNetParam = 0.0, - standardizeFeatures = false, standardizeLabel = false).fit(instances2) + standardizeFeatures = false, standardizeLabel = false).fit(instances2.map(_.toInstance)) val irls = new IterativelyReweightedLeastSquares(initial, L1RegressionReweightFunc, fitIntercept, regParam = 0.0, maxIter = 200, tol = 1e-7).fit(instances2) val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1)) @@ -169,29 +169,29 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes object IterativelyReweightedLeastSquaresSuite { def BinomialReweightFunc( - instance: Instance, + instance: OffsetInstance, model: WeightedLeastSquaresModel): (Double, Double) = { - val eta = model.predict(instance.features) + val eta = model.predict(instance.features) + instance.offset val mu = 1.0 / (1.0 + math.exp(-1.0 * eta)) - val z = eta + (instance.label - mu) / (mu * (1.0 - mu)) + val z = eta - instance.offset + (instance.label - mu) / (mu * (1.0 - mu)) val w = mu * (1 - mu) * instance.weight (z, w) } def PoissonReweightFunc( - instance: Instance, + instance: OffsetInstance, model: WeightedLeastSquaresModel): (Double, Double) = { - val eta = model.predict(instance.features) + val eta = model.predict(instance.features) + instance.offset val mu = math.exp(eta) - val z = eta + (instance.label - mu) / mu + val z = eta - instance.offset + (instance.label - mu) / mu val w = mu * instance.weight (z, w) } def L1RegressionReweightFunc( - instance: Instance, + instance: OffsetInstance, model: WeightedLeastSquaresModel): (Double, Double) = { - val eta = model.predict(instance.features) + val eta = model.predict(instance.features) + instance.offset val e = math.max(math.abs(eta - instance.label), 1e-7) val w = 1 / e val y = instance.label diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index f7c7c001a36af..cfaa57314bd66 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LogisticRegressionSuite._ -import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} @@ -797,77 +797,160 @@ class GeneralizedLinearRegressionSuite } } - test("glm summary: gaussian family with weight") { + test("generalized linear regression with weight and offset") { /* - R code: + R code: + library(statmod) - A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) - b <- c(17, 19, 23, 29) - w <- c(1, 2, 3, 4) - df <- as.data.frame(cbind(A, b)) - */ - val datasetWithWeight = Seq( - Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), - Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), - Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) + df <- as.data.frame(matrix(c( + 0.2, 1.0, 2.0, 0.0, 5.0, + 0.5, 2.1, 0.5, 1.0, 2.0, + 0.9, 0.4, 1.0, 2.0, 1.0, + 0.7, 0.7, 0.0, 3.0, 3.0), 4, 5, byrow = TRUE)) + families <- list(gaussian, binomial, poisson, Gamma, tweedie(1.5)) + f1 <- V1 ~ -1 + V4 + V5 + f2 <- V1 ~ V4 + V5 + for (f in c(f1, f2)) { + for (fam in families) { + model <- glm(f, df, family = fam, weights = V2, offset = V3) + print(as.vector(coef(model))) + } + } + [1] 0.5169222 -0.3344444 + [1] 0.9419107 -0.6864404 + [1] 0.1812436 -0.6568422 + [1] -0.2869094 0.7857710 + [1] 0.1055254 0.2979113 + [1] -0.05990345 0.53188982 -0.32118415 + [1] -0.2147117 0.9911750 -0.6356096 + [1] -1.5616130 0.6646470 -0.3192581 + [1] 0.3390397 -0.3406099 0.6870259 + [1] 0.3665034 0.1039416 0.1484616 + */ + val dataset = Seq( + OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)), + OffsetInstance(0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)), + OffsetInstance(0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)), + OffsetInstance(0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0)) ).toDF() + + val expected = Seq( + Vectors.dense(0, 0.5169222, -0.3344444), + Vectors.dense(0, 0.9419107, -0.6864404), + Vectors.dense(0, 0.1812436, -0.6568422), + Vectors.dense(0, -0.2869094, 0.785771), + Vectors.dense(0, 0.1055254, 0.2979113), + Vectors.dense(-0.05990345, 0.53188982, -0.32118415), + Vectors.dense(-0.2147117, 0.991175, -0.6356096), + Vectors.dense(-1.561613, 0.664647, -0.3192581), + Vectors.dense(0.3390397, -0.3406099, 0.6870259), + Vectors.dense(0.3665034, 0.1039416, 0.1484616)) + + import GeneralizedLinearRegression._ + + var idx = 0 + + for (fitIntercept <- Seq(false, true)) { + for (family <- Seq("gaussian", "binomial", "poisson", "gamma", "tweedie")) { + val trainer = new GeneralizedLinearRegression().setFamily(family) + .setFitIntercept(fitIntercept).setOffsetCol("offset") + .setWeightCol("weight").setLinkPredictionCol("linkPrediction") + if (family == "tweedie") trainer.setVariancePower(1.5) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, s"Model mismatch: GLM with family = $family," + + s" and fitIntercept = $fitIntercept.") + + val familyLink = FamilyAndLink(trainer) + model.transform(dataset).select("features", "offset", "prediction", "linkPrediction") + .collect().foreach { + case Row(features: DenseVector, offset: Double, prediction1: Double, + linkPrediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + offset + val prediction2 = familyLink.fitted(eta) + val linkPrediction2 = eta + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"family = $family, and fitIntercept = $fitIntercept.") + assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " + + s"GLM with family = $family, and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } + + test("glm summary: gaussian family with weight and offset") { /* - R code: + R code: - model <- glm(formula = "b ~ .", family="gaussian", data = df, weights = w) - summary(model) + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(17, 19, 23, 29) + w <- c(1, 2, 3, 4) + off <- c(2, 3, 1, 4) + df <- as.data.frame(cbind(A, b)) + */ + val dataset = Seq( + OffsetInstance(17.0, 1.0, 2.0, Vectors.dense(0.0, 5.0).toSparse), + OffsetInstance(19.0, 2.0, 3.0, Vectors.dense(1.0, 7.0)), + OffsetInstance(23.0, 3.0, 1.0, Vectors.dense(2.0, 11.0)), + OffsetInstance(29.0, 4.0, 4.0, Vectors.dense(3.0, 13.0)) + ).toDF() + /* + R code: - Deviance Residuals: - 1 2 3 4 - 1.920 -1.358 -1.109 0.960 + model <- glm(formula = "b ~ .", family = "gaussian", data = df, + weights = w, offset = off) + summary(model) - Coefficients: - Estimate Std. Error t value Pr(>|t|) - (Intercept) 18.080 9.608 1.882 0.311 - V1 6.080 5.556 1.094 0.471 - V2 -0.600 1.960 -0.306 0.811 + Deviance Residuals: + 1 2 3 4 + 0.9600 -0.6788 -0.5543 0.4800 - (Dispersion parameter for gaussian family taken to be 7.68) + Coefficients: + Estimate Std. Error t value Pr(>|t|) + (Intercept) 5.5400 4.8040 1.153 0.455 + V1 -0.9600 2.7782 -0.346 0.788 + V2 1.7000 0.9798 1.735 0.333 - Null deviance: 202.00 on 3 degrees of freedom - Residual deviance: 7.68 on 1 degrees of freedom - AIC: 18.783 + (Dispersion parameter for gaussian family taken to be 1.92) - Number of Fisher Scoring iterations: 2 + Null deviance: 152.10 on 3 degrees of freedom + Residual deviance: 1.92 on 1 degrees of freedom + AIC: 13.238 - residuals(model, type="pearson") - 1 2 3 4 - 1.920000 -1.357645 -1.108513 0.960000 + Number of Fisher Scoring iterations: 2 - residuals(model, type="working") + residuals(model, type = "pearson") + 1 2 3 4 + 0.9600000 -0.6788225 -0.5542563 0.4800000 + residuals(model, type = "working") 1 2 3 4 - 1.92 -0.96 -0.64 0.48 - - residuals(model, type="response") + 0.96 -0.48 -0.32 0.24 + residuals(model, type = "response") 1 2 3 4 - 1.92 -0.96 -0.64 0.48 + 0.96 -0.48 -0.32 0.24 */ val trainer = new GeneralizedLinearRegression() - .setWeightCol("weight") + .setWeightCol("weight").setOffsetCol("offset") + + val model = trainer.fit(dataset) - val model = trainer.fit(datasetWithWeight) - - val coefficientsR = Vectors.dense(Array(6.080, -0.600)) - val interceptR = 18.080 - val devianceResidualsR = Array(1.920, -1.358, -1.109, 0.960) - val pearsonResidualsR = Array(1.920000, -1.357645, -1.108513, 0.960000) - val workingResidualsR = Array(1.92, -0.96, -0.64, 0.48) - val responseResidualsR = Array(1.92, -0.96, -0.64, 0.48) - val seCoefR = Array(5.556, 1.960, 9.608) - val tValsR = Array(1.094, -0.306, 1.882) - val pValsR = Array(0.471, 0.811, 0.311) - val dispersionR = 7.68 - val nullDevianceR = 202.00 - val residualDevianceR = 7.68 + val coefficientsR = Vectors.dense(Array(-0.96, 1.7)) + val interceptR = 5.54 + val devianceResidualsR = Array(0.96, -0.67882, -0.55426, 0.48) + val pearsonResidualsR = Array(0.96, -0.67882, -0.55426, 0.48) + val workingResidualsR = Array(0.96, -0.48, -0.32, 0.24) + val responseResidualsR = Array(0.96, -0.48, -0.32, 0.24) + val seCoefR = Array(2.7782, 0.9798, 4.804) + val tValsR = Array(-0.34555, 1.73506, 1.15321) + val pValsR = Array(0.78819, 0.33286, 0.45478) + val dispersionR = 1.92 + val nullDevianceR = 152.1 + val residualDevianceR = 1.92 val residualDegreeOfFreedomNullR = 3 val residualDegreeOfFreedomR = 1 - val aicR = 18.783 + val aicR = 13.23758 assert(model.hasSummary) val summary = model.summary @@ -912,7 +995,7 @@ class GeneralizedLinearRegressionSuite assert(summary.aic ~== aicR absTol 1E-3) assert(summary.solver === "irls") - val summary2: GeneralizedLinearRegressionSummary = model.evaluate(datasetWithWeight) + val summary2: GeneralizedLinearRegressionSummary = model.evaluate(dataset) assert(summary.predictions.columns.toSet === summary2.predictions.columns.toSet) assert(summary.predictionCol === summary2.predictionCol) assert(summary.rank === summary2.rank) @@ -925,79 +1008,79 @@ class GeneralizedLinearRegressionSuite assert(summary.aic === summary2.aic) } - test("glm summary: binomial family with weight") { + test("glm summary: binomial family with weight and offset") { /* - R code: + R code: - A <- matrix(c(0, 1, 2, 3, 5, 2, 1, 3), 4, 2) - b <- c(1, 0.5, 1, 0) - w <- c(1, 2.0, 0.3, 4.7) - df <- as.data.frame(cbind(A, b)) + df <- as.data.frame(matrix(c( + 0.2, 1.0, 2.0, 0.0, 5.0, + 0.5, 2.1, 0.5, 1.0, 2.0, + 0.9, 0.4, 1.0, 2.0, 1.0, + 0.7, 0.7, 0.0, 3.0, 3.0), 4, 5, byrow = TRUE)) */ - val datasetWithWeight = Seq( - Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(0.5, 2.0, Vectors.dense(1.0, 2.0)), - Instance(1.0, 0.3, Vectors.dense(2.0, 1.0)), - Instance(0.0, 4.7, Vectors.dense(3.0, 3.0)) + val dataset = Seq( + OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)), + OffsetInstance(0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)), + OffsetInstance(0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)), + OffsetInstance(0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0)) ).toDF() - /* - R code: - - model <- glm(formula = "b ~ . -1", family="binomial", data = df, weights = w) - summary(model) - - Deviance Residuals: - 1 2 3 4 - 0.2404 0.1965 1.2824 -0.6916 + R code: - Coefficients: - Estimate Std. Error z value Pr(>|z|) - x1 -1.6901 1.2764 -1.324 0.185 - x2 0.7059 0.9449 0.747 0.455 + model <- glm(formula = "V1 ~ V4 + V5", family = "binomial", data = df, + weights = V2, offset = V3) + summary(model) - (Dispersion parameter for binomial family taken to be 1) + Deviance Residuals: + 1 2 3 4 + 0.002584 -0.003800 0.012478 -0.001796 - Null deviance: 8.3178 on 4 degrees of freedom - Residual deviance: 2.2193 on 2 degrees of freedom - AIC: 5.9915 + Coefficients: + Estimate Std. Error z value Pr(>|z|) + (Intercept) -0.2147 3.5687 -0.060 0.952 + V4 0.9912 1.2344 0.803 0.422 + V5 -0.6356 0.9669 -0.657 0.511 - Number of Fisher Scoring iterations: 5 + (Dispersion parameter for binomial family taken to be 1) - residuals(model, type="pearson") - 1 2 3 4 - 0.171217 0.197406 2.085864 -0.495332 + Null deviance: 2.17560881 on 3 degrees of freedom + Residual deviance: 0.00018005 on 1 degrees of freedom + AIC: 10.245 - residuals(model, type="working") - 1 2 3 4 - 1.029315 0.281881 15.502768 -1.052203 + Number of Fisher Scoring iterations: 4 - residuals(model, type="response") - 1 2 3 4 - 0.028480 0.069123 0.935495 -0.049613 + residuals(model, type = "pearson") + 1 2 3 4 + 0.002586113 -0.003799744 0.012372235 -0.001796892 + residuals(model, type = "working") + 1 2 3 4 + 0.006477857 -0.005244163 0.063541250 -0.004691064 + residuals(model, type = "response") + 1 2 3 4 + 0.0010324375 -0.0013110318 0.0060225522 -0.0009832738 */ val trainer = new GeneralizedLinearRegression() .setFamily("Binomial") .setWeightCol("weight") - .setFitIntercept(false) - - val model = trainer.fit(datasetWithWeight) - - val coefficientsR = Vectors.dense(Array(-1.690134, 0.705929)) - val interceptR = 0.0 - val devianceResidualsR = Array(0.2404, 0.1965, 1.2824, -0.6916) - val pearsonResidualsR = Array(0.171217, 0.197406, 2.085864, -0.495332) - val workingResidualsR = Array(1.029315, 0.281881, 15.502768, -1.052203) - val responseResidualsR = Array(0.02848, 0.069123, 0.935495, -0.049613) - val seCoefR = Array(1.276417, 0.944934) - val tValsR = Array(-1.324124, 0.747068) - val pValsR = Array(0.185462, 0.455023) - val dispersionR = 1.0 - val nullDevianceR = 8.3178 - val residualDevianceR = 2.2193 - val residualDegreeOfFreedomNullR = 4 - val residualDegreeOfFreedomR = 2 - val aicR = 5.991537 + .setOffsetCol("offset") + + val model = trainer.fit(dataset) + + val coefficientsR = Vectors.dense(Array(0.99117, -0.63561)) + val interceptR = -0.21471 + val devianceResidualsR = Array(0.00258, -0.0038, 0.01248, -0.0018) + val pearsonResidualsR = Array(0.00259, -0.0038, 0.01237, -0.0018) + val workingResidualsR = Array(0.00648, -0.00524, 0.06354, -0.00469) + val responseResidualsR = Array(0.00103, -0.00131, 0.00602, -0.00098) + val seCoefR = Array(1.23439, 0.9669, 3.56866) + val tValsR = Array(0.80297, -0.65737, -0.06017) + val pValsR = Array(0.42199, 0.51094, 0.95202) + val dispersionR = 1 + val nullDevianceR = 2.17561 + val residualDevianceR = 0.00018 + val residualDegreeOfFreedomNullR = 3 + val residualDegreeOfFreedomR = 1 + val aicR = 10.24453 val summary = model.summary val devianceResiduals = summary.residuals() @@ -1040,81 +1123,79 @@ class GeneralizedLinearRegressionSuite assert(summary.solver === "irls") } - test("glm summary: poisson family with weight") { + test("glm summary: poisson family with weight and offset") { /* - R code: + R code: - A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) - b <- c(2, 8, 3, 9) - w <- c(1, 2, 3, 4) - df <- as.data.frame(cbind(A, b)) + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(2, 8, 3, 9) + w <- c(1, 2, 3, 4) + off <- c(2, 3, 1, 4) + df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = Seq( - Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), - Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), - Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) + val dataset = Seq( + OffsetInstance(2.0, 1.0, 2.0, Vectors.dense(0.0, 5.0).toSparse), + OffsetInstance(8.0, 2.0, 3.0, Vectors.dense(1.0, 7.0)), + OffsetInstance(3.0, 3.0, 1.0, Vectors.dense(2.0, 11.0)), + OffsetInstance(9.0, 4.0, 4.0, Vectors.dense(3.0, 13.0)) ).toDF() /* - R code: - - model <- glm(formula = "b ~ .", family="poisson", data = df, weights = w) - summary(model) - - Deviance Residuals: - 1 2 3 4 - -0.28952 0.11048 0.14839 -0.07268 - - Coefficients: - Estimate Std. Error z value Pr(>|z|) - (Intercept) 6.2999 1.6086 3.916 8.99e-05 *** - V1 3.3241 1.0184 3.264 0.00110 ** - V2 -1.0818 0.3522 -3.071 0.00213 ** - --- - Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 - - (Dispersion parameter for poisson family taken to be 1) - - Null deviance: 15.38066 on 3 degrees of freedom - Residual deviance: 0.12333 on 1 degrees of freedom - AIC: 41.803 - - Number of Fisher Scoring iterations: 3 + R code: - residuals(model, type="pearson") - 1 2 3 4 - -0.28043145 0.11099310 0.14963714 -0.07253611 + model <- glm(formula = "b ~ .", family = "poisson", data = df, + weights = w, offset = off) + summary(model) - residuals(model, type="working") - 1 2 3 4 - -0.17960679 0.02813593 0.05113852 -0.01201650 + Deviance Residuals: + 1 2 3 4 + -2.0480 1.2315 1.8293 -0.7107 - residuals(model, type="response") - 1 2 3 4 - -0.4378554 0.2189277 0.1459518 -0.1094638 + Coefficients: + Estimate Std. Error z value Pr(>|z|) + (Intercept) -4.5678 1.9625 -2.328 0.0199 + V1 -2.8784 1.1683 -2.464 0.0137 + V2 0.8859 0.4170 2.124 0.0336 + + (Dispersion parameter for poisson family taken to be 1) + + Null deviance: 22.5585 on 3 degrees of freedom + Residual deviance: 9.5622 on 1 degrees of freedom + AIC: 51.242 + + Number of Fisher Scoring iterations: 5 + + residuals(model, type = "pearson") + 1 2 3 4 + -1.7480418 1.3037611 2.0750099 -0.6972966 + residuals(model, type = "working") + 1 2 3 4 + -0.6891489 0.3833588 0.9710682 -0.1096590 + residuals(model, type = "response") + 1 2 3 4 + -4.433948 2.216974 1.477983 -1.108487 */ val trainer = new GeneralizedLinearRegression() .setFamily("Poisson") .setWeightCol("weight") - .setFitIntercept(true) - - val model = trainer.fit(datasetWithWeight) - - val coefficientsR = Vectors.dense(Array(3.3241, -1.0818)) - val interceptR = 6.2999 - val devianceResidualsR = Array(-0.28952, 0.11048, 0.14839, -0.07268) - val pearsonResidualsR = Array(-0.28043145, 0.11099310, 0.14963714, -0.07253611) - val workingResidualsR = Array(-0.17960679, 0.02813593, 0.05113852, -0.01201650) - val responseResidualsR = Array(-0.4378554, 0.2189277, 0.1459518, -0.1094638) - val seCoefR = Array(1.0184, 0.3522, 1.6086) - val tValsR = Array(3.264, -3.071, 3.916) - val pValsR = Array(0.00110, 0.00213, 0.00009) - val dispersionR = 1.0 - val nullDevianceR = 15.38066 - val residualDevianceR = 0.12333 + .setOffsetCol("offset") + + val model = trainer.fit(dataset) + + val coefficientsR = Vectors.dense(Array(-2.87843, 0.88589)) + val interceptR = -4.56784 + val devianceResidualsR = Array(-2.04796, 1.23149, 1.82933, -0.71066) + val pearsonResidualsR = Array(-1.74804, 1.30376, 2.07501, -0.6973) + val workingResidualsR = Array(-0.68915, 0.38336, 0.97107, -0.10966) + val responseResidualsR = Array(-4.43395, 2.21697, 1.47798, -1.10849) + val seCoefR = Array(1.16826, 0.41703, 1.96249) + val tValsR = Array(-2.46387, 2.12428, -2.32757) + val pValsR = Array(0.01374, 0.03365, 0.01993) + val dispersionR = 1 + val nullDevianceR = 22.55853 + val residualDevianceR = 9.5622 val residualDegreeOfFreedomNullR = 3 val residualDegreeOfFreedomR = 1 - val aicR = 41.803 + val aicR = 51.24218 val summary = model.summary val devianceResiduals = summary.residuals() @@ -1157,78 +1238,79 @@ class GeneralizedLinearRegressionSuite assert(summary.solver === "irls") } - test("glm summary: gamma family with weight") { + test("glm summary: gamma family with weight and offset") { /* - R code: + R code: - A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) - b <- c(2, 8, 3, 9) - w <- c(1, 2, 3, 4) - df <- as.data.frame(cbind(A, b)) + A <- matrix(c(0, 5, 1, 2, 2, 1, 3, 3), 4, 2, byrow = TRUE) + b <- c(1, 2, 1, 2) + w <- c(1, 2, 3, 4) + off <- c(0, 0.5, 1, 0) + df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = Seq( - Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), - Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), - Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) + val dataset = Seq( + OffsetInstance(1.0, 1.0, 0.0, Vectors.dense(0.0, 5.0)), + OffsetInstance(2.0, 2.0, 0.5, Vectors.dense(1.0, 2.0)), + OffsetInstance(1.0, 3.0, 1.0, Vectors.dense(2.0, 1.0)), + OffsetInstance(2.0, 4.0, 0.0, Vectors.dense(3.0, 3.0)) ).toDF() /* - R code: - - model <- glm(formula = "b ~ .", family="Gamma", data = df, weights = w) - summary(model) + R code: - Deviance Residuals: - 1 2 3 4 - -0.26343 0.05761 0.12818 -0.03484 + model <- glm(formula = "b ~ .", family = "Gamma", data = df, + weights = w, offset = off) + summary(model) - Coefficients: - Estimate Std. Error t value Pr(>|t|) - (Intercept) -0.81511 0.23449 -3.476 0.178 - V1 -0.72730 0.16137 -4.507 0.139 - V2 0.23894 0.05481 4.359 0.144 + Deviance Residuals: + 1 2 3 4 + -0.17095 0.19867 -0.23604 0.03241 - (Dispersion parameter for Gamma family taken to be 0.07986091) + Coefficients: + Estimate Std. Error t value Pr(>|t|) + (Intercept) -0.56474 0.23866 -2.366 0.255 + V1 0.07695 0.06931 1.110 0.467 + V2 0.28068 0.07320 3.835 0.162 - Null deviance: 2.937462 on 3 degrees of freedom - Residual deviance: 0.090358 on 1 degrees of freedom - AIC: 23.202 + (Dispersion parameter for Gamma family taken to be 0.1212174) - Number of Fisher Scoring iterations: 4 + Null deviance: 2.02568 on 3 degrees of freedom + Residual deviance: 0.12546 on 1 degrees of freedom + AIC: 0.93388 - residuals(model, type="pearson") - 1 2 3 4 - -0.24082508 0.05839241 0.13135766 -0.03463621 + Number of Fisher Scoring iterations: 4 - residuals(model, type="working") + residuals(model, type = "pearson") + 1 2 3 4 + -0.16134949 0.20807694 -0.22544551 0.03258777 + residuals(model, type = "working") 1 2 3 4 - 0.091414181 -0.005374314 -0.027196998 0.001890910 - - residuals(model, type="response") - 1 2 3 4 - -0.6344390 0.3172195 0.2114797 -0.1586097 + 0.135315831 -0.084390309 0.113219135 -0.008279688 + residuals(model, type = "response") + 1 2 3 4 + -0.1923918 0.2565224 -0.1496381 0.0320653 */ val trainer = new GeneralizedLinearRegression() .setFamily("Gamma") .setWeightCol("weight") + .setOffsetCol("offset") + + val model = trainer.fit(dataset) - val model = trainer.fit(datasetWithWeight) - - val coefficientsR = Vectors.dense(Array(-0.72730, 0.23894)) - val interceptR = -0.81511 - val devianceResidualsR = Array(-0.26343, 0.05761, 0.12818, -0.03484) - val pearsonResidualsR = Array(-0.24082508, 0.05839241, 0.13135766, -0.03463621) - val workingResidualsR = Array(0.091414181, -0.005374314, -0.027196998, 0.001890910) - val responseResidualsR = Array(-0.6344390, 0.3172195, 0.2114797, -0.1586097) - val seCoefR = Array(0.16137, 0.05481, 0.23449) - val tValsR = Array(-4.507, 4.359, -3.476) - val pValsR = Array(0.139, 0.144, 0.178) - val dispersionR = 0.07986091 - val nullDevianceR = 2.937462 - val residualDevianceR = 0.090358 + val coefficientsR = Vectors.dense(Array(0.07695, 0.28068)) + val interceptR = -0.56474 + val devianceResidualsR = Array(-0.17095, 0.19867, -0.23604, 0.03241) + val pearsonResidualsR = Array(-0.16135, 0.20808, -0.22545, 0.03259) + val workingResidualsR = Array(0.13532, -0.08439, 0.11322, -0.00828) + val responseResidualsR = Array(-0.19239, 0.25652, -0.14964, 0.03207) + val seCoefR = Array(0.06931, 0.0732, 0.23866) + val tValsR = Array(1.11031, 3.83453, -2.3663) + val pValsR = Array(0.46675, 0.16241, 0.25454) + val dispersionR = 0.12122 + val nullDevianceR = 2.02568 + val residualDevianceR = 0.12546 val residualDegreeOfFreedomNullR = 3 val residualDegreeOfFreedomR = 1 - val aicR = 23.202 + val aicR = 0.93388 val summary = model.summary val devianceResiduals = summary.residuals() @@ -1271,77 +1353,81 @@ class GeneralizedLinearRegressionSuite assert(summary.solver === "irls") } - test("glm summary: tweedie family with weight") { + test("glm summary: tweedie family with weight and offset") { /* R code: - library(statmod) df <- as.data.frame(matrix(c( - 1.0, 1.0, 0.0, 5.0, - 0.5, 2.0, 1.0, 2.0, - 1.0, 3.0, 2.0, 1.0, - 0.0, 4.0, 3.0, 3.0), 4, 4, byrow = TRUE)) + 1.0, 1.0, 1.0, 0.0, 5.0, + 0.5, 2.0, 3.0, 1.0, 2.0, + 1.0, 3.0, 2.0, 2.0, 1.0, + 0.0, 4.0, 0.0, 3.0, 3.0), 4, 5, byrow = TRUE)) + */ + val dataset = Seq( + OffsetInstance(1.0, 1.0, 1.0, Vectors.dense(0.0, 5.0)), + OffsetInstance(0.5, 2.0, 3.0, Vectors.dense(1.0, 2.0)), + OffsetInstance(1.0, 3.0, 2.0, Vectors.dense(2.0, 1.0)), + OffsetInstance(0.0, 4.0, 0.0, Vectors.dense(3.0, 3.0)) + ).toDF() + /* + R code: - model <- glm(V1 ~ -1 + V3 + V4, data = df, weights = V2, - family = tweedie(var.power = 1.6, link.power = 0)) + library(statmod) + model <- glm(V1 ~ V4 + V5, data = df, weights = V2, offset = V3, + family = tweedie(var.power = 1.6, link.power = 0.0)) summary(model) Deviance Residuals: 1 2 3 4 - 0.6210 -0.0515 1.6935 -3.2539 + 0.8917 -2.1396 1.2252 -1.7946 Coefficients: - Estimate Std. Error t value Pr(>|t|) - V3 -0.4087 0.5205 -0.785 0.515 - V4 -0.1212 0.4082 -0.297 0.794 + Estimate Std. Error t value Pr(>|t|) + (Intercept) -0.03047 3.65000 -0.008 0.995 + V4 -1.14577 1.41674 -0.809 0.567 + V5 -0.36585 0.97065 -0.377 0.771 - (Dispersion parameter for Tweedie family taken to be 3.830036) + (Dispersion parameter for Tweedie family taken to be 6.334961) - Null deviance: 20.702 on 4 degrees of freedom - Residual deviance: 13.844 on 2 degrees of freedom + Null deviance: 12.784 on 3 degrees of freedom + Residual deviance: 10.095 on 1 degrees of freedom AIC: NA - Number of Fisher Scoring iterations: 11 - - residuals(model, type="pearson") - 1 2 3 4 - 0.7383616 -0.0509458 2.2348337 -1.4552090 - residuals(model, type="working") - 1 2 3 4 - 0.83354150 -0.04103552 1.55676369 -1.00000000 - residuals(model, type="response") - 1 2 3 4 - 0.45460738 -0.02139574 0.60888055 -0.20392801 + Number of Fisher Scoring iterations: 18 + + residuals(model, type = "pearson") + 1 2 3 4 + 1.1472554 -1.4642569 1.4935199 -0.8025842 + residuals(model, type = "working") + 1 2 3 4 + 1.3624928 -0.8322375 0.9894580 -1.0000000 + residuals(model, type = "response") + 1 2 3 4 + 0.57671828 -2.48040354 0.49735052 -0.01040646 */ - val datasetWithWeight = Seq( - Instance(1.0, 1.0, Vectors.dense(0.0, 5.0)), - Instance(0.5, 2.0, Vectors.dense(1.0, 2.0)), - Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)), - Instance(0.0, 4.0, Vectors.dense(3.0, 3.0)) - ).toDF() - val trainer = new GeneralizedLinearRegression() .setFamily("tweedie") .setVariancePower(1.6) .setLinkPower(0.0) .setWeightCol("weight") - .setFitIntercept(false) - - val model = trainer.fit(datasetWithWeight) - val coefficientsR = Vectors.dense(Array(-0.408746, -0.12125)) - val interceptR = 0.0 - val devianceResidualsR = Array(0.621047, -0.051515, 1.693473, -3.253946) - val pearsonResidualsR = Array(0.738362, -0.050946, 2.234834, -1.455209) - val workingResidualsR = Array(0.833541, -0.041036, 1.556764, -1.0) - val responseResidualsR = Array(0.454607, -0.021396, 0.608881, -0.203928) - val seCoefR = Array(0.520519, 0.408215) - val tValsR = Array(-0.785267, -0.297024) - val pValsR = Array(0.514549, 0.794457) - val dispersionR = 3.830036 - val nullDevianceR = 20.702 - val residualDevianceR = 13.844 - val residualDegreeOfFreedomNullR = 4 - val residualDegreeOfFreedomR = 2 + .setOffsetCol("offset") + + val model = trainer.fit(dataset) + + val coefficientsR = Vectors.dense(Array(-1.14577, -0.36585)) + val interceptR = -0.03047 + val devianceResidualsR = Array(0.89171, -2.13961, 1.2252, -1.79463) + val pearsonResidualsR = Array(1.14726, -1.46426, 1.49352, -0.80258) + val workingResidualsR = Array(1.36249, -0.83224, 0.98946, -1) + val responseResidualsR = Array(0.57672, -2.4804, 0.49735, -0.01041) + val seCoefR = Array(1.41674, 0.97065, 3.65) + val tValsR = Array(-0.80873, -0.37691, -0.00835) + val pValsR = Array(0.56707, 0.77053, 0.99468) + val dispersionR = 6.33496 + val nullDevianceR = 12.78358 + val residualDevianceR = 10.09488 + val residualDegreeOfFreedomNullR = 3 + val residualDegreeOfFreedomR = 1 val summary = model.summary From 3c2fc19d478256f8dc0ae7219fdd188030218c07 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 30 Jun 2017 20:30:26 +0800 Subject: [PATCH 055/779] [SPARK-18294][CORE] Implement commit protocol to support `mapred` package's committer ## What changes were proposed in this pull request? This PR makes the following changes: - Implement a new commit protocol `HadoopMapRedCommitProtocol` which support the old `mapred` package's committer; - Refactor SparkHadoopWriter and SparkHadoopMapReduceWriter, now they are combined together, thus we can support write through both mapred and mapreduce API by the new SparkHadoopWriter, a lot of duplicated codes are removed. After this change, it should be pretty easy for us to support the committer from both the new and the old hadoop API at high level. ## How was this patch tested? No major behavior change, passed the existing test cases. Author: Xingbo Jiang Closes #18438 from jiangxb1987/SparkHadoopWriter. --- .../io/HadoopMapRedCommitProtocol.scala | 36 ++ .../internal/io/HadoopWriteConfigUtil.scala | 79 ++++ .../io/SparkHadoopMapReduceWriter.scala | 181 -------- .../spark/internal/io/SparkHadoopWriter.scala | 393 ++++++++++++++---- .../apache/spark/rdd/PairRDDFunctions.scala | 72 +--- .../spark/rdd/PairRDDFunctionsSuite.scala | 2 +- .../OutputCommitCoordinatorSuite.scala | 35 +- 7 files changed, 461 insertions(+), 337 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala create mode 100644 core/src/main/scala/org/apache/spark/internal/io/HadoopWriteConfigUtil.scala delete mode 100644 core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala new file mode 100644 index 0000000000000..ddbd624b380d4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import org.apache.hadoop.mapred._ +import org.apache.hadoop.mapreduce.{TaskAttemptContext => NewTaskAttemptContext} + +/** + * An [[FileCommitProtocol]] implementation backed by an underlying Hadoop OutputCommitter + * (from the old mapred API). + * + * Unlike Hadoop's OutputCommitter, this implementation is serializable. + */ +class HadoopMapRedCommitProtocol(jobId: String, path: String) + extends HadoopMapReduceCommitProtocol(jobId, path) { + + override def setupCommitter(context: NewTaskAttemptContext): OutputCommitter = { + val config = context.getConfiguration.asInstanceOf[JobConf] + config.getOutputCommitter + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopWriteConfigUtil.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopWriteConfigUtil.scala new file mode 100644 index 0000000000000..9b987e0e1bb67 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopWriteConfigUtil.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import scala.reflect.ClassTag + +import org.apache.hadoop.mapreduce._ + +import org.apache.spark.SparkConf + +/** + * Interface for create output format/committer/writer used during saving an RDD using a Hadoop + * OutputFormat (both from the old mapred API and the new mapreduce API) + * + * Notes: + * 1. Implementations should throw [[IllegalArgumentException]] when wrong hadoop API is + * referenced; + * 2. Implementations must be serializable, as the instance instantiated on the driver + * will be used for tasks on executors; + * 3. Implementations should have a constructor with exactly one argument: + * (conf: SerializableConfiguration) or (conf: SerializableJobConf). + */ +abstract class HadoopWriteConfigUtil[K, V: ClassTag] extends Serializable { + + // -------------------------------------------------------------------------- + // Create JobContext/TaskAttemptContext + // -------------------------------------------------------------------------- + + def createJobContext(jobTrackerId: String, jobId: Int): JobContext + + def createTaskAttemptContext( + jobTrackerId: String, + jobId: Int, + splitId: Int, + taskAttemptId: Int): TaskAttemptContext + + // -------------------------------------------------------------------------- + // Create committer + // -------------------------------------------------------------------------- + + def createCommitter(jobId: Int): HadoopMapReduceCommitProtocol + + // -------------------------------------------------------------------------- + // Create writer + // -------------------------------------------------------------------------- + + def initWriter(taskContext: TaskAttemptContext, splitId: Int): Unit + + def write(pair: (K, V)): Unit + + def closeWriter(taskContext: TaskAttemptContext): Unit + + // -------------------------------------------------------------------------- + // Create OutputFormat + // -------------------------------------------------------------------------- + + def initOutputFormat(jobContext: JobContext): Unit + + // -------------------------------------------------------------------------- + // Verify hadoop config + // -------------------------------------------------------------------------- + + def assertConf(jobContext: JobContext, conf: SparkConf): Unit +} diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala deleted file mode 100644 index 376ff9bb19f74..0000000000000 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala +++ /dev/null @@ -1,181 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.internal.io - -import java.text.SimpleDateFormat -import java.util.{Date, Locale} - -import scala.reflect.ClassTag -import scala.util.DynamicVariable - -import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapred.{JobConf, JobID} -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl - -import org.apache.spark.{SparkConf, SparkException, TaskContext} -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.executor.OutputMetrics -import org.apache.spark.internal.Logging -import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage -import org.apache.spark.rdd.RDD -import org.apache.spark.util.{SerializableConfiguration, Utils} - -/** - * A helper object that saves an RDD using a Hadoop OutputFormat - * (from the newer mapreduce API, not the old mapred API). - */ -private[spark] -object SparkHadoopMapReduceWriter extends Logging { - - /** - * Basic work flow of this command is: - * 1. Driver side setup, prepare the data source and hadoop configuration for the write job to - * be issued. - * 2. Issues a write job consists of one or more executor side tasks, each of which writes all - * rows within an RDD partition. - * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any - * exception is thrown during task commitment, also aborts that task. - * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is - * thrown during job commitment, also aborts the job. - */ - def write[K, V: ClassTag]( - rdd: RDD[(K, V)], - hadoopConf: Configuration): Unit = { - // Extract context and configuration from RDD. - val sparkContext = rdd.context - val stageId = rdd.id - val sparkConf = rdd.conf - val conf = new SerializableConfiguration(hadoopConf) - - // Set up a job. - val jobTrackerId = SparkHadoopWriterUtils.createJobTrackerID(new Date()) - val jobAttemptId = new TaskAttemptID(jobTrackerId, stageId, TaskType.MAP, 0, 0) - val jobContext = new TaskAttemptContextImpl(conf.value, jobAttemptId) - val format = jobContext.getOutputFormatClass - - if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(sparkConf)) { - // FileOutputFormat ignores the filesystem parameter - val jobFormat = format.newInstance - jobFormat.checkOutputSpecs(jobContext) - } - - val committer = FileCommitProtocol.instantiate( - className = classOf[HadoopMapReduceCommitProtocol].getName, - jobId = stageId.toString, - outputPath = conf.value.get("mapreduce.output.fileoutputformat.outputdir"), - isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol] - committer.setupJob(jobContext) - - // Try to write all RDD partitions as a Hadoop OutputFormat. - try { - val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => { - executeTask( - context = context, - jobTrackerId = jobTrackerId, - sparkStageId = context.stageId, - sparkPartitionId = context.partitionId, - sparkAttemptNumber = context.attemptNumber, - committer = committer, - hadoopConf = conf.value, - outputFormat = format.asInstanceOf[Class[OutputFormat[K, V]]], - iterator = iter) - }) - - committer.commitJob(jobContext, ret) - logInfo(s"Job ${jobContext.getJobID} committed.") - } catch { - case cause: Throwable => - logError(s"Aborting job ${jobContext.getJobID}.", cause) - committer.abortJob(jobContext) - throw new SparkException("Job aborted.", cause) - } - } - - /** Write an RDD partition out in a single Spark task. */ - private def executeTask[K, V: ClassTag]( - context: TaskContext, - jobTrackerId: String, - sparkStageId: Int, - sparkPartitionId: Int, - sparkAttemptNumber: Int, - committer: FileCommitProtocol, - hadoopConf: Configuration, - outputFormat: Class[_ <: OutputFormat[K, V]], - iterator: Iterator[(K, V)]): TaskCommitMessage = { - // Set up a task. - val attemptId = new TaskAttemptID(jobTrackerId, sparkStageId, TaskType.REDUCE, - sparkPartitionId, sparkAttemptNumber) - val taskContext = new TaskAttemptContextImpl(hadoopConf, attemptId) - committer.setupTask(taskContext) - - val (outputMetrics, callback) = SparkHadoopWriterUtils.initHadoopOutputMetrics(context) - - // Initiate the writer. - val taskFormat = outputFormat.newInstance() - // If OutputFormat is Configurable, we should set conf to it. - taskFormat match { - case c: Configurable => c.setConf(hadoopConf) - case _ => () - } - var writer = taskFormat.getRecordWriter(taskContext) - .asInstanceOf[RecordWriter[K, V]] - require(writer != null, "Unable to obtain RecordWriter") - var recordsWritten = 0L - - // Write all rows in RDD partition. - try { - val ret = Utils.tryWithSafeFinallyAndFailureCallbacks { - // Write rows out, release resource and commit the task. - while (iterator.hasNext) { - val pair = iterator.next() - writer.write(pair._1, pair._2) - - // Update bytes written metric every few records - SparkHadoopWriterUtils.maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten) - recordsWritten += 1 - } - if (writer != null) { - writer.close(taskContext) - writer = null - } - committer.commitTask(taskContext) - }(catchBlock = { - // If there is an error, release resource and then abort the task. - try { - if (writer != null) { - writer.close(taskContext) - writer = null - } - } finally { - committer.abortTask(taskContext) - logError(s"Task ${taskContext.getTaskAttemptID} aborted.") - } - }) - - outputMetrics.setBytesWritten(callback()) - outputMetrics.setRecordsWritten(recordsWritten) - - ret - } catch { - case t: Throwable => - throw new SparkException("Task failed while writing rows", t) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index acc9c38571007..7d846f9354df6 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -17,143 +17,374 @@ package org.apache.spark.internal.io -import java.io.IOException -import java.text.{NumberFormat, SimpleDateFormat} +import java.text.NumberFormat import java.util.{Date, Locale} +import scala.reflect.ClassTag + +import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.mapred._ -import org.apache.hadoop.mapreduce.TaskType +import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, +OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, +TaskAttemptContext => NewTaskAttemptContext, TaskAttemptID => NewTaskAttemptID, TaskType} +import org.apache.hadoop.mapreduce.task.{TaskAttemptContextImpl => NewTaskAttemptContextImpl} -import org.apache.spark.SerializableWritable +import org.apache.spark.{SerializableWritable, SparkConf, SparkException, TaskContext} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.rdd.HadoopRDD -import org.apache.spark.util.SerializableJobConf +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage +import org.apache.spark.rdd.{HadoopRDD, RDD} +import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf, Utils} /** - * Internal helper class that saves an RDD using a Hadoop OutputFormat. - * - * Saves the RDD using a JobConf, which should contain an output key class, an output value class, - * a filename to write to, etc, exactly like in a Hadoop MapReduce job. + * A helper object that saves an RDD using a Hadoop OutputFormat. + */ +private[spark] +object SparkHadoopWriter extends Logging { + import SparkHadoopWriterUtils._ + + /** + * Basic work flow of this command is: + * 1. Driver side setup, prepare the data source and hadoop configuration for the write job to + * be issued. + * 2. Issues a write job consists of one or more executor side tasks, each of which writes all + * rows within an RDD partition. + * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any + * exception is thrown during task commitment, also aborts that task. + * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is + * thrown during job commitment, also aborts the job. + */ + def write[K, V: ClassTag]( + rdd: RDD[(K, V)], + config: HadoopWriteConfigUtil[K, V]): Unit = { + // Extract context and configuration from RDD. + val sparkContext = rdd.context + val stageId = rdd.id + + // Set up a job. + val jobTrackerId = createJobTrackerID(new Date()) + val jobContext = config.createJobContext(jobTrackerId, stageId) + config.initOutputFormat(jobContext) + + // Assert the output format/key/value class is set in JobConf. + config.assertConf(jobContext, rdd.conf) + + val committer = config.createCommitter(stageId) + committer.setupJob(jobContext) + + // Try to write all RDD partitions as a Hadoop OutputFormat. + try { + val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => { + executeTask( + context = context, + config = config, + jobTrackerId = jobTrackerId, + sparkStageId = context.stageId, + sparkPartitionId = context.partitionId, + sparkAttemptNumber = context.attemptNumber, + committer = committer, + iterator = iter) + }) + + committer.commitJob(jobContext, ret) + logInfo(s"Job ${jobContext.getJobID} committed.") + } catch { + case cause: Throwable => + logError(s"Aborting job ${jobContext.getJobID}.", cause) + committer.abortJob(jobContext) + throw new SparkException("Job aborted.", cause) + } + } + + /** Write a RDD partition out in a single Spark task. */ + private def executeTask[K, V: ClassTag]( + context: TaskContext, + config: HadoopWriteConfigUtil[K, V], + jobTrackerId: String, + sparkStageId: Int, + sparkPartitionId: Int, + sparkAttemptNumber: Int, + committer: FileCommitProtocol, + iterator: Iterator[(K, V)]): TaskCommitMessage = { + // Set up a task. + val taskContext = config.createTaskAttemptContext( + jobTrackerId, sparkStageId, sparkPartitionId, sparkAttemptNumber) + committer.setupTask(taskContext) + + val (outputMetrics, callback) = initHadoopOutputMetrics(context) + + // Initiate the writer. + config.initWriter(taskContext, sparkPartitionId) + var recordsWritten = 0L + + // Write all rows in RDD partition. + try { + val ret = Utils.tryWithSafeFinallyAndFailureCallbacks { + while (iterator.hasNext) { + val pair = iterator.next() + config.write(pair) + + // Update bytes written metric every few records + maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten) + recordsWritten += 1 + } + + config.closeWriter(taskContext) + committer.commitTask(taskContext) + }(catchBlock = { + // If there is an error, release resource and then abort the task. + try { + config.closeWriter(taskContext) + } finally { + committer.abortTask(taskContext) + logError(s"Task ${taskContext.getTaskAttemptID} aborted.") + } + }) + + outputMetrics.setBytesWritten(callback()) + outputMetrics.setRecordsWritten(recordsWritten) + + ret + } catch { + case t: Throwable => + throw new SparkException("Task failed while writing rows", t) + } + } +} + +/** + * A helper class that reads JobConf from older mapred API, creates output Format/Committer/Writer. */ private[spark] -class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { +class HadoopMapRedWriteConfigUtil[K, V: ClassTag](conf: SerializableJobConf) + extends HadoopWriteConfigUtil[K, V] with Logging { - private val now = new Date() - private val conf = new SerializableJobConf(jobConf) + private var outputFormat: Class[_ <: OutputFormat[K, V]] = null + private var writer: RecordWriter[K, V] = null - private var jobID = 0 - private var splitID = 0 - private var attemptID = 0 - private var jID: SerializableWritable[JobID] = null - private var taID: SerializableWritable[TaskAttemptID] = null + private def getConf: JobConf = conf.value - @transient private var writer: RecordWriter[AnyRef, AnyRef] = null - @transient private var format: OutputFormat[AnyRef, AnyRef] = null - @transient private var committer: OutputCommitter = null - @transient private var jobContext: JobContext = null - @transient private var taskContext: TaskAttemptContext = null + // -------------------------------------------------------------------------- + // Create JobContext/TaskAttemptContext + // -------------------------------------------------------------------------- - def preSetup() { - setIDs(0, 0, 0) - HadoopRDD.addLocalConfiguration("", 0, 0, 0, conf.value) + override def createJobContext(jobTrackerId: String, jobId: Int): NewJobContext = { + val jobAttemptId = new SerializableWritable(new JobID(jobTrackerId, jobId)) + new JobContextImpl(getConf, jobAttemptId.value) + } - val jCtxt = getJobContext() - getOutputCommitter().setupJob(jCtxt) + override def createTaskAttemptContext( + jobTrackerId: String, + jobId: Int, + splitId: Int, + taskAttemptId: Int): NewTaskAttemptContext = { + // Update JobConf. + HadoopRDD.addLocalConfiguration(jobTrackerId, jobId, splitId, taskAttemptId, conf.value) + // Create taskContext. + val attemptId = new TaskAttemptID(jobTrackerId, jobId, TaskType.MAP, splitId, taskAttemptId) + new TaskAttemptContextImpl(getConf, attemptId) } + // -------------------------------------------------------------------------- + // Create committer + // -------------------------------------------------------------------------- - def setup(jobid: Int, splitid: Int, attemptid: Int) { - setIDs(jobid, splitid, attemptid) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(now), - jobid, splitID, attemptID, conf.value) + override def createCommitter(jobId: Int): HadoopMapReduceCommitProtocol = { + // Update JobConf. + HadoopRDD.addLocalConfiguration("", 0, 0, 0, getConf) + // Create commit protocol. + FileCommitProtocol.instantiate( + className = classOf[HadoopMapRedCommitProtocol].getName, + jobId = jobId.toString, + outputPath = getConf.get("mapred.output.dir"), + isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol] } - def open() { + // -------------------------------------------------------------------------- + // Create writer + // -------------------------------------------------------------------------- + + override def initWriter(taskContext: NewTaskAttemptContext, splitId: Int): Unit = { val numfmt = NumberFormat.getInstance(Locale.US) numfmt.setMinimumIntegerDigits(5) numfmt.setGroupingUsed(false) - val outputName = "part-" + numfmt.format(splitID) - val path = FileOutputFormat.getOutputPath(conf.value) + val outputName = "part-" + numfmt.format(splitId) + val path = FileOutputFormat.getOutputPath(getConf) val fs: FileSystem = { if (path != null) { - path.getFileSystem(conf.value) + path.getFileSystem(getConf) } else { - FileSystem.get(conf.value) + FileSystem.get(getConf) } } - getOutputCommitter().setupTask(getTaskContext()) - writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL) + writer = getConf.getOutputFormat + .getRecordWriter(fs, getConf, outputName, Reporter.NULL) + .asInstanceOf[RecordWriter[K, V]] + + require(writer != null, "Unable to obtain RecordWriter") } - def write(key: AnyRef, value: AnyRef) { + override def write(pair: (K, V)): Unit = { + require(writer != null, "Must call createWriter before write.") + writer.write(pair._1, pair._2) + } + + override def closeWriter(taskContext: NewTaskAttemptContext): Unit = { if (writer != null) { - writer.write(key, value) - } else { - throw new IOException("Writer is null, open() has not been called") + writer.close(Reporter.NULL) } } - def close() { - writer.close(Reporter.NULL) - } + // -------------------------------------------------------------------------- + // Create OutputFormat + // -------------------------------------------------------------------------- - def commit() { - SparkHadoopMapRedUtil.commitTask(getOutputCommitter(), getTaskContext(), jobID, splitID) + override def initOutputFormat(jobContext: NewJobContext): Unit = { + if (outputFormat == null) { + outputFormat = getConf.getOutputFormat.getClass + .asInstanceOf[Class[_ <: OutputFormat[K, V]]] + } } - def commitJob() { - val cmtr = getOutputCommitter() - cmtr.commitJob(getJobContext()) + private def getOutputFormat(): OutputFormat[K, V] = { + require(outputFormat != null, "Must call initOutputFormat first.") + + outputFormat.newInstance() } - // ********* Private Functions ********* + // -------------------------------------------------------------------------- + // Verify hadoop config + // -------------------------------------------------------------------------- + + override def assertConf(jobContext: NewJobContext, conf: SparkConf): Unit = { + val outputFormatInstance = getOutputFormat() + val keyClass = getConf.getOutputKeyClass + val valueClass = getConf.getOutputValueClass + if (outputFormatInstance == null) { + throw new SparkException("Output format class not set") + } + if (keyClass == null) { + throw new SparkException("Output key class not set") + } + if (valueClass == null) { + throw new SparkException("Output value class not set") + } + SparkHadoopUtil.get.addCredentials(getConf) + + logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + + valueClass.getSimpleName + ")") - private def getOutputFormat(): OutputFormat[AnyRef, AnyRef] = { - if (format == null) { - format = conf.value.getOutputFormat() - .asInstanceOf[OutputFormat[AnyRef, AnyRef]] + if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(conf)) { + // FileOutputFormat ignores the filesystem parameter + val ignoredFs = FileSystem.get(getConf) + getOutputFormat().checkOutputSpecs(ignoredFs, getConf) } - format + } +} + +/** + * A helper class that reads Configuration from newer mapreduce API, creates output + * Format/Committer/Writer. + */ +private[spark] +class HadoopMapReduceWriteConfigUtil[K, V: ClassTag](conf: SerializableConfiguration) + extends HadoopWriteConfigUtil[K, V] with Logging { + + private var outputFormat: Class[_ <: NewOutputFormat[K, V]] = null + private var writer: NewRecordWriter[K, V] = null + + private def getConf: Configuration = conf.value + + // -------------------------------------------------------------------------- + // Create JobContext/TaskAttemptContext + // -------------------------------------------------------------------------- + + override def createJobContext(jobTrackerId: String, jobId: Int): NewJobContext = { + val jobAttemptId = new NewTaskAttemptID(jobTrackerId, jobId, TaskType.MAP, 0, 0) + new NewTaskAttemptContextImpl(getConf, jobAttemptId) + } + + override def createTaskAttemptContext( + jobTrackerId: String, + jobId: Int, + splitId: Int, + taskAttemptId: Int): NewTaskAttemptContext = { + val attemptId = new NewTaskAttemptID( + jobTrackerId, jobId, TaskType.REDUCE, splitId, taskAttemptId) + new NewTaskAttemptContextImpl(getConf, attemptId) + } + + // -------------------------------------------------------------------------- + // Create committer + // -------------------------------------------------------------------------- + + override def createCommitter(jobId: Int): HadoopMapReduceCommitProtocol = { + FileCommitProtocol.instantiate( + className = classOf[HadoopMapReduceCommitProtocol].getName, + jobId = jobId.toString, + outputPath = getConf.get("mapreduce.output.fileoutputformat.outputdir"), + isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol] } - private def getOutputCommitter(): OutputCommitter = { - if (committer == null) { - committer = conf.value.getOutputCommitter + // -------------------------------------------------------------------------- + // Create writer + // -------------------------------------------------------------------------- + + override def initWriter(taskContext: NewTaskAttemptContext, splitId: Int): Unit = { + val taskFormat = getOutputFormat() + // If OutputFormat is Configurable, we should set conf to it. + taskFormat match { + case c: Configurable => c.setConf(getConf) + case _ => () } - committer + + writer = taskFormat.getRecordWriter(taskContext) + .asInstanceOf[NewRecordWriter[K, V]] + + require(writer != null, "Unable to obtain RecordWriter") + } + + override def write(pair: (K, V)): Unit = { + require(writer != null, "Must call createWriter before write.") + writer.write(pair._1, pair._2) } - private def getJobContext(): JobContext = { - if (jobContext == null) { - jobContext = new JobContextImpl(conf.value, jID.value) + override def closeWriter(taskContext: NewTaskAttemptContext): Unit = { + if (writer != null) { + writer.close(taskContext) + writer = null + } else { + logWarning("Writer has been closed.") } - jobContext } - private def getTaskContext(): TaskAttemptContext = { - if (taskContext == null) { - taskContext = newTaskAttemptContext(conf.value, taID.value) + // -------------------------------------------------------------------------- + // Create OutputFormat + // -------------------------------------------------------------------------- + + override def initOutputFormat(jobContext: NewJobContext): Unit = { + if (outputFormat == null) { + outputFormat = jobContext.getOutputFormatClass + .asInstanceOf[Class[_ <: NewOutputFormat[K, V]]] } - taskContext } - protected def newTaskAttemptContext( - conf: JobConf, - attemptId: TaskAttemptID): TaskAttemptContext = { - new TaskAttemptContextImpl(conf, attemptId) + private def getOutputFormat(): NewOutputFormat[K, V] = { + require(outputFormat != null, "Must call initOutputFormat first.") + + outputFormat.newInstance() } - private def setIDs(jobid: Int, splitid: Int, attemptid: Int) { - jobID = jobid - splitID = splitid - attemptID = attemptid + // -------------------------------------------------------------------------- + // Verify hadoop config + // -------------------------------------------------------------------------- - jID = new SerializableWritable[JobID](SparkHadoopWriterUtils.createJobID(now, jobid)) - taID = new SerializableWritable[TaskAttemptID]( - new TaskAttemptID(new TaskID(jID.value, TaskType.MAP, splitID), attemptID)) + override def assertConf(jobContext: NewJobContext, conf: SparkConf): Unit = { + if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(conf)) { + getOutputFormat().checkOutputSpecs(jobContext) + } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 58762cc0838cd..4628fa8ba270e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -27,7 +27,6 @@ import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} @@ -36,13 +35,11 @@ import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewO import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.annotation.Experimental -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.internal.io.{SparkHadoopMapReduceWriter, SparkHadoopWriter, - SparkHadoopWriterUtils} +import org.apache.spark.internal.io._ import org.apache.spark.internal.Logging import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf, Utils} import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.util.random.StratifiedSamplingUtils @@ -1082,9 +1079,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * result of using direct output committer with speculation enabled. */ def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope { - SparkHadoopMapReduceWriter.write( + val config = new HadoopMapReduceWriteConfigUtil[K, V](new SerializableConfiguration(conf)) + SparkHadoopWriter.write( rdd = self, - hadoopConf = conf) + config = config) } /** @@ -1094,62 +1092,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * MapReduce job. */ def saveAsHadoopDataset(conf: JobConf): Unit = self.withScope { - // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). - val hadoopConf = conf - val outputFormatInstance = hadoopConf.getOutputFormat - val keyClass = hadoopConf.getOutputKeyClass - val valueClass = hadoopConf.getOutputValueClass - if (outputFormatInstance == null) { - throw new SparkException("Output format class not set") - } - if (keyClass == null) { - throw new SparkException("Output key class not set") - } - if (valueClass == null) { - throw new SparkException("Output value class not set") - } - SparkHadoopUtil.get.addCredentials(hadoopConf) - - logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + - valueClass.getSimpleName + ")") - - if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(self.conf)) { - // FileOutputFormat ignores the filesystem parameter - val ignoredFs = FileSystem.get(hadoopConf) - hadoopConf.getOutputFormat.checkOutputSpecs(ignoredFs, hadoopConf) - } - - val writer = new SparkHadoopWriter(hadoopConf) - writer.preSetup() - - val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { - // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it - // around by taking a mod. We expect that no task will be attempted 2 billion times. - val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt - - val (outputMetrics, callback) = SparkHadoopWriterUtils.initHadoopOutputMetrics(context) - - writer.setup(context.stageId, context.partitionId, taskAttemptId) - writer.open() - var recordsWritten = 0L - - Utils.tryWithSafeFinallyAndFailureCallbacks { - while (iter.hasNext) { - val record = iter.next() - writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) - - // Update bytes written metric every few records - SparkHadoopWriterUtils.maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten) - recordsWritten += 1 - } - }(finallyBlock = writer.close()) - writer.commit() - outputMetrics.setBytesWritten(callback()) - outputMetrics.setRecordsWritten(recordsWritten) - } - - self.context.runJob(self, writeToFile) - writer.commitJob() + val config = new HadoopMapRedWriteConfigUtil[K, V](new SerializableJobConf(conf)) + SparkHadoopWriter.write( + rdd = self, + config = config) } /** diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 02df157be377c..44dd955ce8690 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -561,7 +561,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { pairs.saveAsHadoopFile( "ignored", pairs.keyClass, pairs.valueClass, classOf[FakeFormatWithCallback], conf) } - assert(e.getMessage contains "failed to write") + assert(e.getCause.getMessage contains "failed to write") assert(FakeWriterWithCallback.calledBy === "write,callback,close") assert(FakeWriterWithCallback.exception != null, "exception should be captured") diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index e51e6a0d3ff6b..1579b614ea5b0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -18,12 +18,14 @@ package org.apache.spark.scheduler import java.io.File +import java.util.Date import java.util.concurrent.TimeoutException import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.hadoop.mapred.{JobConf, OutputCommitter, TaskAttemptContext, TaskAttemptID} +import org.apache.hadoop.mapred._ +import org.apache.hadoop.mapreduce.TaskType import org.mockito.Matchers import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -31,7 +33,7 @@ import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter import org.apache.spark._ -import org.apache.spark.internal.io.SparkHadoopWriter +import org.apache.spark.internal.io.{FileCommitProtocol, HadoopMapRedCommitProtocol, SparkHadoopWriterUtils} import org.apache.spark.rdd.{FakeOutputCommitter, RDD} import org.apache.spark.util.{ThreadUtils, Utils} @@ -214,6 +216,8 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { */ private case class OutputCommitFunctions(tempDirPath: String) { + private val jobId = new SerializableWritable(SparkHadoopWriterUtils.createJobID(new Date, 0)) + // Mock output committer that simulates a successful commit (after commit is authorized) private def successfulOutputCommitter = new FakeOutputCommitter { override def commitTask(context: TaskAttemptContext): Unit = { @@ -256,14 +260,23 @@ private case class OutputCommitFunctions(tempDirPath: String) { def jobConf = new JobConf { override def getOutputCommitter(): OutputCommitter = outputCommitter } - val sparkHadoopWriter = new SparkHadoopWriter(jobConf) { - override def newTaskAttemptContext( - conf: JobConf, - attemptId: TaskAttemptID): TaskAttemptContext = { - mock(classOf[TaskAttemptContext]) - } - } - sparkHadoopWriter.setup(ctx.stageId, ctx.partitionId, ctx.attemptNumber) - sparkHadoopWriter.commit() + + // Instantiate committer. + val committer = FileCommitProtocol.instantiate( + className = classOf[HadoopMapRedCommitProtocol].getName, + jobId = jobId.value.getId.toString, + outputPath = jobConf.get("mapred.output.dir"), + isAppend = false) + + // Create TaskAttemptContext. + // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it + // around by taking a mod. We expect that no task will be attempted 2 billion times. + val taskAttemptId = (ctx.taskAttemptId % Int.MaxValue).toInt + val attemptId = new TaskAttemptID( + new TaskID(jobId.value, TaskType.MAP, ctx.partitionId), taskAttemptId) + val taskContext = new TaskAttemptContextImpl(jobConf, attemptId) + + committer.setupTask(taskContext) + committer.commitTask(taskContext) } } From 528c9281aecc49e9bff204dd303962c705c6f237 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 30 Jun 2017 23:25:14 +0800 Subject: [PATCH 056/779] [ML] Fix scala-2.10 build failure of GeneralizedLinearRegressionSuite. ## What changes were proposed in this pull request? Fix scala-2.10 build failure of ```GeneralizedLinearRegressionSuite```. ## How was this patch tested? Build with scala-2.10. Author: Yanbo Liang Closes #18489 from yanboliang/glr. --- .../ml/regression/GeneralizedLinearRegressionSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index cfaa57314bd66..83f1344a7bcb1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -1075,7 +1075,7 @@ class GeneralizedLinearRegressionSuite val seCoefR = Array(1.23439, 0.9669, 3.56866) val tValsR = Array(0.80297, -0.65737, -0.06017) val pValsR = Array(0.42199, 0.51094, 0.95202) - val dispersionR = 1 + val dispersionR = 1.0 val nullDevianceR = 2.17561 val residualDevianceR = 0.00018 val residualDegreeOfFreedomNullR = 3 @@ -1114,7 +1114,7 @@ class GeneralizedLinearRegressionSuite assert(x._1 ~== x._2 absTol 1E-3) } summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } - assert(summary.dispersion ~== dispersionR absTol 1E-3) + assert(summary.dispersion === dispersionR) assert(summary.nullDeviance ~== nullDevianceR absTol 1E-3) assert(summary.deviance ~== residualDevianceR absTol 1E-3) assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) @@ -1190,7 +1190,7 @@ class GeneralizedLinearRegressionSuite val seCoefR = Array(1.16826, 0.41703, 1.96249) val tValsR = Array(-2.46387, 2.12428, -2.32757) val pValsR = Array(0.01374, 0.03365, 0.01993) - val dispersionR = 1 + val dispersionR = 1.0 val nullDevianceR = 22.55853 val residualDevianceR = 9.5622 val residualDegreeOfFreedomNullR = 3 @@ -1229,7 +1229,7 @@ class GeneralizedLinearRegressionSuite assert(x._1 ~== x._2 absTol 1E-3) } summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } - assert(summary.dispersion ~== dispersionR absTol 1E-3) + assert(summary.dispersion === dispersionR) assert(summary.nullDeviance ~== nullDevianceR absTol 1E-3) assert(summary.deviance ~== residualDevianceR absTol 1E-3) assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) From 1fe08d62f022e12f2f0161af5d8f9eac51baf1b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9B=BE=E6=9E=97=E8=A5=BF?= Date: Fri, 30 Jun 2017 19:28:43 +0100 Subject: [PATCH 057/779] [SPARK-21223] Change fileToAppInfo in FsHistoryProvider to fix concurrent issue. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What issue does this PR address ? Jira:https://issues.apache.org/jira/browse/SPARK-21223 fix the Thread-safety issue in FsHistoryProvider Currently, Spark HistoryServer use a HashMap named fileToAppInfo in class FsHistoryProvider to store the map of eventlog path and attemptInfo. When use ThreadPool to Replay the log files in the list and merge the list of old applications with new ones, multi thread may update fileToAppInfo at the same time, which may cause Thread-safety issues, such as falling into an infinite loop because of calling resize func of the hashtable. Author: 曾林西 Closes #18430 from zenglinxi0615/master. --- .../apache/spark/deploy/history/FsHistoryProvider.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index d05ca142b618b..b2a50bd055712 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.history import java.io.{FileNotFoundException, IOException, OutputStream} import java.util.UUID -import java.util.concurrent.{Executors, ExecutorService, Future, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, Executors, ExecutorService, Future, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable @@ -122,7 +122,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) @volatile private var applications: mutable.LinkedHashMap[String, FsApplicationHistoryInfo] = new mutable.LinkedHashMap() - val fileToAppInfo = new mutable.HashMap[Path, FsApplicationAttemptInfo]() + val fileToAppInfo = new ConcurrentHashMap[Path, FsApplicationAttemptInfo]() // List of application logs to be deleted by event log cleaner. private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] @@ -321,7 +321,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // scan for modified applications, replay and merge them val logInfos: Seq[FileStatus] = statusList .filter { entry => - val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L) + val fileInfo = fileToAppInfo.get(entry.getPath()) + val prevFileSize = if (fileInfo != null) fileInfo.fileSize else 0L !entry.isDirectory() && // FsHistoryProvider generates a hidden file which can't be read. Accidentally // reading a garbage file is safe, but we would log an error which can be scary to @@ -475,7 +476,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) fileStatus.getLen(), appListener.appSparkVersion.getOrElse("") ) - fileToAppInfo(logPath) = attemptInfo + fileToAppInfo.put(logPath, attemptInfo) logDebug(s"Application log ${attemptInfo.logPath} loaded successfully: $attemptInfo") Some(attemptInfo) } else { From eed9c4ef859fdb75a816a3e0ce2d593b34b23444 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 30 Jun 2017 14:23:56 -0700 Subject: [PATCH 058/779] [SPARK-21129][SQL] Arguments of SQL function call should not be named expressions ### What changes were proposed in this pull request? Function argument should not be named expressions. It could cause two issues: - Misleading error message - Unexpected query results when the column name is `distinct`, which is not a reserved word in our parser. ``` spark-sql> select count(distinct c1, distinct c2) from t1; Error in query: cannot resolve '`distinct`' given input columns: [c1, c2]; line 1 pos 26; 'Project [unresolvedalias('count(c1#30, 'distinct), None)] +- SubqueryAlias t1 +- CatalogRelation `default`.`t1`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [c1#30, c2#31] ``` After the fix, the error message becomes ``` spark-sql> select count(distinct c1, distinct c2) from t1; Error in query: extraneous input 'c2' expecting {')', ',', '.', '[', 'OR', 'AND', 'IN', NOT, 'BETWEEN', 'LIKE', RLIKE, 'IS', EQ, '<=>', '<>', '!=', '<', LTE, '>', GTE, '+', '-', '*', '/', '%', 'DIV', '&', '|', '||', '^'}(line 1, pos 35) == SQL == select count(distinct c1, distinct c2) from t1 -----------------------------------^^^ ``` ### How was this patch tested? Added a test case to parser suite. Author: Xiao Li Author: gatorsmile Closes #18338 from gatorsmile/parserDistinctAggFunc. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 3 +- .../spark/sql/catalyst/dsl/package.scala | 1 + .../sql/catalyst/parser/AstBuilder.scala | 9 +++++- .../parser/ExpressionParserSuite.scala | 6 ++-- .../sql/catalyst/parser/PlanParserSuite.scala | 6 ++++ .../resources/sql-tests/inputs/struct.sql | 7 ++++ .../sql-tests/results/struct.sql.out | 32 ++++++++++++++++++- 7 files changed, 59 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 9456031736528..7ffa150096333 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -561,6 +561,7 @@ primaryExpression | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase | CAST '(' expression AS dataType ')' #cast + | STRUCT '(' (argument+=namedExpression (',' argument+=namedExpression)*)? ')' #struct | FIRST '(' expression (IGNORE NULLS)? ')' #first | LAST '(' expression (IGNORE NULLS)? ')' #last | POSITION '(' substr=valueExpression IN str=valueExpression ')' #position @@ -569,7 +570,7 @@ primaryExpression | qualifiedName '.' ASTERISK #star | '(' namedExpression (',' namedExpression)+ ')' #rowConstructor | '(' query ')' #subqueryExpression - | qualifiedName '(' (setQuantifier? namedExpression (',' namedExpression)*)? ')' + | qualifiedName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')' (OVER windowSpec)? #functionCall | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index f6792569b704e..7c100afcd738f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -170,6 +170,7 @@ package object dsl { case Seq() => UnresolvedStar(None) case target => UnresolvedStar(Option(target)) } + def namedStruct(e: Expression*): Expression = CreateNamedStruct(e) def callFunction[T, U]( func: T => U, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index ef79cbcaa0ce6..8eac3ef2d3568 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1061,6 +1061,13 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging Cast(expression(ctx.expression), visitSparkDataType(ctx.dataType)) } + /** + * Create a [[CreateStruct]] expression. + */ + override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) { + CreateStruct(ctx.argument.asScala.map(expression)) + } + /** * Create a [[First]] expression. */ @@ -1091,7 +1098,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging // Create the function call. val name = ctx.qualifiedName.getText val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) - val arguments = ctx.namedExpression().asScala.map(expression) match { + val arguments = ctx.argument.asScala.map(expression) match { case Seq(UnresolvedStar(None)) if name.toLowerCase(Locale.ROOT) == "count" && !isDistinct => // Transform COUNT(*) into COUNT(1). diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 4d08f016a4a16..45f9f72dccc45 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -231,7 +231,7 @@ class ExpressionParserSuite extends PlanTest { assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b)) assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b)) assertEqual("`select`(all a, b)", 'select.function('a, 'b)) - assertEqual("foo(a as x, b as e)", 'foo.function('a as 'x, 'b as 'e)) + intercept("foo(a x)", "extraneous input 'x'") } test("window function expressions") { @@ -330,7 +330,9 @@ class ExpressionParserSuite extends PlanTest { assertEqual("a.b", UnresolvedAttribute("a.b")) assertEqual("`select`.b", UnresolvedAttribute("select.b")) assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis. - assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b")) + assertEqual( + "struct(a, b).b", + namedStruct(NamePlaceholder, 'a, NamePlaceholder, 'b).getField("b")) } test("reference") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index bf15b85d5b510..5b2573fa4d601 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -223,6 +223,12 @@ class PlanParserSuite extends AnalysisTest { assertEqual(s"$sql grouping sets((a, b), (a), ())", GroupingSets(Seq(Seq('a, 'b), Seq('a), Seq()), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c")))) + + val m = intercept[ParseException] { + parsePlan("SELECT a, b, count(distinct a, distinct b) as c FROM d GROUP BY a, b") + }.getMessage + assert(m.contains("extraneous input 'b'")) + } test("limit") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/struct.sql b/sql/core/src/test/resources/sql-tests/inputs/struct.sql index e56344dc4de80..93a1238ab18c2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/struct.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/struct.sql @@ -18,3 +18,10 @@ SELECT ID, STRUCT(ST.*,CAST(ID AS STRING) AS E) NST FROM tbl_x; -- Prepend a column to a struct SELECT ID, STRUCT(CAST(ID AS STRING) AS AA, ST.*) NST FROM tbl_x; + +-- Select a column from a struct +SELECT ID, STRUCT(ST.*).C NST FROM tbl_x; +SELECT ID, STRUCT(ST.C, ST.D).D NST FROM tbl_x; + +-- Select an alias from a struct +SELECT ID, STRUCT(ST.C as STC, ST.D as STD).STD FROM tbl_x; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/struct.sql.out b/sql/core/src/test/resources/sql-tests/results/struct.sql.out index 3e32f46195464..1da33bc736f0b 100644 --- a/sql/core/src/test/resources/sql-tests/results/struct.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/struct.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 6 +-- Number of queries: 9 -- !query 0 @@ -58,3 +58,33 @@ struct> 1 {"AA":"1","C":"gamma","D":"delta"} 2 {"AA":"2","C":"epsilon","D":"eta"} 3 {"AA":"3","C":"theta","D":"iota"} + + +-- !query 6 +SELECT ID, STRUCT(ST.*).C NST FROM tbl_x +-- !query 6 schema +struct +-- !query 6 output +1 gamma +2 epsilon +3 theta + + +-- !query 7 +SELECT ID, STRUCT(ST.C, ST.D).D NST FROM tbl_x +-- !query 7 schema +struct +-- !query 7 output +1 delta +2 eta +3 iota + + +-- !query 8 +SELECT ID, STRUCT(ST.C as STC, ST.D as STD).STD FROM tbl_x +-- !query 8 schema +struct +-- !query 8 output +1 delta +2 eta +3 iota From fd1325522549937232f37215db53d6478f48644c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 30 Jun 2017 15:11:27 -0700 Subject: [PATCH 059/779] [SPARK-21052][SQL][FOLLOW-UP] Add hash map metrics to join ## What changes were proposed in this pull request? Remove `numHashCollisions` in `BytesToBytesMap`. And change `getAverageProbesPerLookup()` to `getAverageProbesPerLookup` as suggested. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #18480 from viirya/SPARK-21052-followup. --- .../spark/unsafe/map/BytesToBytesMap.java | 33 ------------------- .../spark/sql/execution/joins/HashJoin.scala | 2 +- .../sql/execution/joins/HashedRelation.scala | 8 ++--- 3 files changed, 5 insertions(+), 38 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 4bef21b6b4e4d..3b6200e74f1e1 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -160,14 +160,10 @@ public final class BytesToBytesMap extends MemoryConsumer { private final boolean enablePerfMetrics; - private long timeSpentResizingNs = 0; - private long numProbes = 0; private long numKeyLookups = 0; - private long numHashCollisions = 0; - private long peakMemoryUsedBytes = 0L; private final int initialCapacity; @@ -489,10 +485,6 @@ public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location l ); if (areEqual) { return; - } else { - if (enablePerfMetrics) { - numHashCollisions++; - } } } } @@ -859,16 +851,6 @@ public long getPeakMemoryUsedBytes() { return peakMemoryUsedBytes; } - /** - * Returns the total amount of time spent resizing this map (in nanoseconds). - */ - public long getTimeSpentResizingNs() { - if (!enablePerfMetrics) { - throw new IllegalStateException(); - } - return timeSpentResizingNs; - } - /** * Returns the average number of probes per key lookup. */ @@ -879,13 +861,6 @@ public double getAverageProbesPerLookup() { return (1.0 * numProbes) / numKeyLookups; } - public long getNumHashCollisions() { - if (!enablePerfMetrics) { - throw new IllegalStateException(); - } - return numHashCollisions; - } - @VisibleForTesting public int getNumDataPages() { return dataPages.size(); @@ -923,10 +898,6 @@ public void reset() { void growAndRehash() { assert(longArray != null); - long resizeStartTime = -1; - if (enablePerfMetrics) { - resizeStartTime = System.nanoTime(); - } // Store references to the old data structures to be used when we re-hash final LongArray oldLongArray = longArray; final int oldCapacity = (int) oldLongArray.size() / 2; @@ -951,9 +922,5 @@ void growAndRehash() { longArray.set(newPos * 2 + 1, hashcode); } freeArray(oldLongArray); - - if (enablePerfMetrics) { - timeSpentResizingNs += System.nanoTime() - resizeStartTime; - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index b09edf380c2d4..0396168d3f311 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -215,7 +215,7 @@ trait HashJoin { // At the end of the task, we update the avg hash probe. TaskContext.get().addTaskCompletionListener(_ => - avgHashProbe.set(hashed.getAverageProbesPerLookup())) + avgHashProbe.set(hashed.getAverageProbesPerLookup)) val resultProj = createResultProjection joinedIter.map { r => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 3c702856114f9..2038cb9edb67d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -83,7 +83,7 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation { /** * Returns the average number of probes per key lookup. */ - def getAverageProbesPerLookup(): Double + def getAverageProbesPerLookup: Double } private[execution] object HashedRelation { @@ -280,7 +280,7 @@ private[joins] class UnsafeHashedRelation( read(in.readInt, in.readLong, in.readBytes) } - override def getAverageProbesPerLookup(): Double = binaryMap.getAverageProbesPerLookup() + override def getAverageProbesPerLookup: Double = binaryMap.getAverageProbesPerLookup } private[joins] object UnsafeHashedRelation { @@ -776,7 +776,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap /** * Returns the average number of probes per key lookup. */ - def getAverageProbesPerLookup(): Double = numProbes.toDouble / numKeyLookups + def getAverageProbesPerLookup: Double = numProbes.toDouble / numKeyLookups } private[joins] class LongHashedRelation( @@ -829,7 +829,7 @@ private[joins] class LongHashedRelation( map = in.readObject().asInstanceOf[LongToUnsafeRowMap] } - override def getAverageProbesPerLookup(): Double = map.getAverageProbesPerLookup() + override def getAverageProbesPerLookup: Double = map.getAverageProbesPerLookup } /** From 4eb41879ce774dec1d16b2281ab1fbf41f9d418a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 1 Jul 2017 09:25:29 +0800 Subject: [PATCH 060/779] [SPARK-17528][SQL] data should be copied properly before saving into InternalRow ## What changes were proposed in this pull request? For performance reasons, `UnsafeRow.getString`, `getStruct`, etc. return a "pointer" that points to a memory region of this unsafe row. This makes the unsafe projection a little dangerous, because all of its output rows share one instance. When we implement SQL operators, we should be careful to not cache the input rows because they may be produced by unsafe projection from child operator and thus its content may change overtime. However, when we updating values of InternalRow(e.g. in mutable projection and safe projection), we only copy UTF8String, we should also copy InternalRow, ArrayData and MapData. This PR fixes this, and also fixes the copy of vairous InternalRow, ArrayData and MapData implementations. ## How was this patch tested? new regression tests Author: Wenchen Fan Closes #18483 from cloud-fan/fix-copy. --- .../apache/spark/unsafe/types/UTF8String.java | 6 + .../spark/sql/catalyst/InternalRow.scala | 27 ++++- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../expressions/SpecificInternalRow.scala | 12 -- .../expressions/aggregate/collect.scala | 2 +- .../expressions/aggregate/interfaces.scala | 6 + .../expressions/codegen/CodeGenerator.scala | 6 +- .../codegen/GenerateSafeProjection.scala | 2 - .../spark/sql/catalyst/expressions/rows.scala | 23 ++-- .../sql/catalyst/util/GenericArrayData.scala | 10 +- .../scala/org/apache/spark/sql/RowTest.scala | 4 - .../catalyst/expressions/MapDataSuite.scala | 57 ---------- .../codegen/GeneratedProjectionSuite.scala | 36 ++++++ .../sql/catalyst/util/ComplexDataSuite.scala | 107 ++++++++++++++++++ .../execution/vectorized/ColumnarBatch.java | 2 +- .../SortBasedAggregationIterator.scala | 15 +-- .../columnar/GenerateColumnAccessor.scala | 1 - .../execution/window/AggregateProcessor.scala | 7 +- 18 files changed, 212 insertions(+), 113 deletions(-) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 40b9fc9534f44..9de4ca71ff6d4 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1088,6 +1088,12 @@ public UTF8String clone() { return fromBytes(getBytes()); } + public UTF8String copy() { + byte[] bytes = new byte[numBytes]; + copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + return fromBytes(bytes); + } + @Override public int compareTo(@Nonnull final UTF8String other) { int len = Math.min(numBytes, other.numBytes); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 256f64e320be8..29110640d64f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types.{DataType, Decimal, StructType} +import org.apache.spark.unsafe.types.UTF8String /** * An abstract class for row used internally in Spark SQL, which only contains the columns as @@ -33,6 +35,10 @@ abstract class InternalRow extends SpecializedGetters with Serializable { def setNullAt(i: Int): Unit + /** + * Updates the value at column `i`. Note that after updating, the given value will be kept in this + * row, and the caller side should guarantee that this value won't be changed afterwards. + */ def update(i: Int, value: Any): Unit // default implementation (slow) @@ -58,7 +64,15 @@ abstract class InternalRow extends SpecializedGetters with Serializable { def copy(): InternalRow /** Returns true if there are any NULL values in this row. */ - def anyNull: Boolean + def anyNull: Boolean = { + val len = numFields + var i = 0 + while (i < len) { + if (isNullAt(i)) { return true } + i += 1 + } + false + } /* ---------------------- utility methods for Scala ---------------------- */ @@ -94,4 +108,15 @@ object InternalRow { /** Returns an empty [[InternalRow]]. */ val empty = apply() + + /** + * Copies the given value if it's string/struct/array/map type. + */ + def copyValue(value: Any): Any = value match { + case v: UTF8String => v.copy() + case v: InternalRow => v.copy() + case v: ArrayData => v.copy() + case v: MapData => v.copy() + case _ => value + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 43df19ba009a8..3862e64b9d828 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1047,7 +1047,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String final $rowClass $result = new $rowClass(${fieldsCasts.length}); final InternalRow $tmpRow = $c; $fieldsEvalCode - $evPrim = $result.copy(); + $evPrim = $result; """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala index 74e0b4691d4cc..75feaf670c84a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ /** @@ -220,17 +219,6 @@ final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGen override def isNullAt(i: Int): Boolean = values(i).isNull - override def copy(): InternalRow = { - val newValues = new Array[Any](values.length) - var i = 0 - while (i < values.length) { - newValues(i) = values(i).boxed - i += 1 - } - - new GenericInternalRow(newValues) - } - override protected def genericGet(i: Int): Any = values(i).boxed override def update(ordinal: Int, value: Any) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 26cd9ab665383..0d2f9889a27d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -52,7 +52,7 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here. // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator if (value != null) { - buffer += value + buffer += InternalRow.copyValue(value) } buffer } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index fffcc7c9ef53a..7af4901435857 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -317,6 +317,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * Updates its aggregation buffer, located in `mutableAggBuffer`, based on the given `inputRow`. * * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. + * + * Note that, the input row may be produced by unsafe projection and it may not be safe to cache + * some fields of the input row, as the values can be changed unexpectedly. */ def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit @@ -326,6 +329,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`. + * + * Note that, the input row may be produced by unsafe projection and it may not be safe to cache + * some fields of the input row, as the values can be changed unexpectedly. */ def merge(mutableAggBuffer: InternalRow, inputAggBuffer: InternalRow): Unit } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 5158949b95629..b15bf2ca7c116 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -408,9 +408,11 @@ class CodegenContext { dataType match { case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" - // The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes) - case StringType => s"$row.update($ordinal, $value.clone())" case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) + // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy + // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. + case StringType | _: StructType | _: ArrayType | _: MapType => + s"$row.update($ordinal, $value.copy())" case _ => s"$row.update($ordinal, $value)" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f708aeff2b146..dd0419d2286d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -131,8 +131,6 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] case s: StructType => createCodeForStruct(ctx, input, s) case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) - // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe. - case StringType => ExprCode("", "false", s"$input.clone()") case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) case _ => ExprCode("", "false", input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 751b821e1b009..65539a2f00e6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -50,16 +50,6 @@ trait BaseGenericInternalRow extends InternalRow { override def getMap(ordinal: Int): MapData = getAs(ordinal) override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) - override def anyNull: Boolean = { - val len = numFields - var i = 0 - while (i < len) { - if (isNullAt(i)) { return true } - i += 1 - } - false - } - override def toString: String = { if (numFields == 0) { "[empty row]" @@ -79,6 +69,17 @@ trait BaseGenericInternalRow extends InternalRow { } } + override def copy(): GenericInternalRow = { + val len = numFields + val newValues = new Array[Any](len) + var i = 0 + while (i < len) { + newValues(i) = InternalRow.copyValue(genericGet(i)) + i += 1 + } + new GenericInternalRow(newValues) + } + override def equals(o: Any): Boolean = { if (!o.isInstanceOf[BaseGenericInternalRow]) { return false @@ -206,6 +207,4 @@ class GenericInternalRow(val values: Array[Any]) extends BaseGenericInternalRow override def setNullAt(i: Int): Unit = { values(i) = null} override def update(i: Int, value: Any): Unit = { values(i) = value } - - override def copy(): GenericInternalRow = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index dd660c80a9c3c..9e39ed9c3a778 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -49,7 +49,15 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { def this(seqOrArray: Any) = this(GenericArrayData.anyToSeq(seqOrArray)) - override def copy(): ArrayData = new GenericArrayData(array.clone()) + override def copy(): ArrayData = { + val newValues = new Array[Any](array.length) + var i = 0 + while (i < array.length) { + newValues(i) = InternalRow.copyValue(array(i)) + i += 1 + } + new GenericArrayData(newValues) + } override def numElements(): Int = array.length diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index c9c9599e7f463..25699de33d717 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -121,10 +121,6 @@ class RowTest extends FunSpec with Matchers { externalRow should be theSameInstanceAs externalRow.copy() } - it("copy should return same ref for internal rows") { - internalRow should be theSameInstanceAs internalRow.copy() - } - it("toSeq should not expose internal state for external rows") { val modifiedValues = modifyValues(externalRow.toSeq) externalRow.toSeq should not equal modifiedValues diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala deleted file mode 100644 index 25a675a90276d..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import scala.collection._ - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.util.ArrayBasedMapData -import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType} -import org.apache.spark.unsafe.types.UTF8String - -class MapDataSuite extends SparkFunSuite { - - test("inequality tests") { - def u(str: String): UTF8String = UTF8String.fromString(str) - - // test data - val testMap1 = Map(u("key1") -> 1) - val testMap2 = Map(u("key1") -> 1, u("key2") -> 2) - val testMap3 = Map(u("key1") -> 1) - val testMap4 = Map(u("key1") -> 1, u("key2") -> 2) - - // ArrayBasedMapData - val testArrayMap1 = ArrayBasedMapData(testMap1.toMap) - val testArrayMap2 = ArrayBasedMapData(testMap2.toMap) - val testArrayMap3 = ArrayBasedMapData(testMap3.toMap) - val testArrayMap4 = ArrayBasedMapData(testMap4.toMap) - assert(testArrayMap1 !== testArrayMap3) - assert(testArrayMap2 !== testArrayMap4) - - // UnsafeMapData - val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType))) - val row = new GenericInternalRow(1) - def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = { - row.update(0, map) - val unsafeRow = unsafeConverter.apply(row) - unsafeRow.getMap(0).copy - } - assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3)) - assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4)) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 58ea5b9cb52d3..0cd0d8859145f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -172,4 +172,40 @@ class GeneratedProjectionSuite extends SparkFunSuite { assert(unsafe1 === unsafe3) assert(unsafe1.getStruct(1, 7) === unsafe3.getStruct(1, 7)) } + + test("MutableProjection should not cache content from the input row") { + val mutableProj = GenerateMutableProjection.generate( + Seq(BoundReference(0, new StructType().add("i", StringType), true))) + val row = new GenericInternalRow(1) + mutableProj.target(row) + + val unsafeProj = GenerateUnsafeProjection.generate( + Seq(BoundReference(0, new StructType().add("i", StringType), true))) + val unsafeRow = unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("a")))) + + mutableProj.apply(unsafeRow) + assert(row.getStruct(0, 1).getString(0) == "a") + + // Even if the input row of the mutable projection has been changed, the target mutable row + // should keep same. + unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("b")))) + assert(row.getStruct(0, 1).getString(0).toString == "a") + } + + test("SafeProjection should not cache content from the input row") { + val safeProj = GenerateSafeProjection.generate( + Seq(BoundReference(0, new StructType().add("i", StringType), true))) + + val unsafeProj = GenerateUnsafeProjection.generate( + Seq(BoundReference(0, new StructType().add("i", StringType), true))) + val unsafeRow = unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("a")))) + + val row = safeProj.apply(unsafeRow) + assert(row.getStruct(0, 1).getString(0) == "a") + + // Even if the input row of the mutable projection has been changed, the target mutable row + // should keep same. + unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("b")))) + assert(row.getStruct(0, 1).getString(0).toString == "a") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala new file mode 100644 index 0000000000000..9d285916bcf42 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.collection._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow, SpecificInternalRow, UnsafeMapData, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +class ComplexDataSuite extends SparkFunSuite { + def utf8(str: String): UTF8String = UTF8String.fromString(str) + + test("inequality tests for MapData") { + // test data + val testMap1 = Map(utf8("key1") -> 1) + val testMap2 = Map(utf8("key1") -> 1, utf8("key2") -> 2) + val testMap3 = Map(utf8("key1") -> 1) + val testMap4 = Map(utf8("key1") -> 1, utf8("key2") -> 2) + + // ArrayBasedMapData + val testArrayMap1 = ArrayBasedMapData(testMap1.toMap) + val testArrayMap2 = ArrayBasedMapData(testMap2.toMap) + val testArrayMap3 = ArrayBasedMapData(testMap3.toMap) + val testArrayMap4 = ArrayBasedMapData(testMap4.toMap) + assert(testArrayMap1 !== testArrayMap3) + assert(testArrayMap2 !== testArrayMap4) + + // UnsafeMapData + val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType))) + val row = new GenericInternalRow(1) + def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = { + row.update(0, map) + val unsafeRow = unsafeConverter.apply(row) + unsafeRow.getMap(0).copy + } + assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3)) + assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4)) + } + + test("GenericInternalRow.copy return a new instance that is independent from the old one") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(utf8("a"))) + + val genericRow = new GenericInternalRow(Array[Any](unsafeRow.getUTF8String(0))) + val copiedGenericRow = genericRow.copy() + assert(copiedGenericRow.getString(0) == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied internal row should not be changed externally. + assert(copiedGenericRow.getString(0) == "a") + } + + test("SpecificMutableRow.copy return a new instance that is independent from the old one") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(utf8("a"))) + + val mutableRow = new SpecificInternalRow(Seq(StringType)) + mutableRow(0) = unsafeRow.getUTF8String(0) + val copiedMutableRow = mutableRow.copy() + assert(copiedMutableRow.getString(0) == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied internal row should not be changed externally. + assert(copiedMutableRow.getString(0) == "a") + } + + test("GenericArrayData.copy return a new instance that is independent from the old one") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(utf8("a"))) + + val genericArray = new GenericArrayData(Array[Any](unsafeRow.getUTF8String(0))) + val copiedGenericArray = genericArray.copy() + assert(copiedGenericArray.getUTF8String(0).toString == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied array data should not be changed externally. + assert(copiedGenericArray.getUTF8String(0).toString == "a") + } + + test("copy on nested complex type") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(utf8("a"))) + + val arrayOfRow = new GenericArrayData(Array[Any](InternalRow(unsafeRow.getUTF8String(0)))) + val copied = arrayOfRow.copy() + assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied data should not be changed externally. + assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a") + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index e23a64350cbc5..34dc3af9b85c8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -149,7 +149,7 @@ public InternalRow copy() { } else if (dt instanceof DoubleType) { row.setDouble(i, getDouble(i)); } else if (dt instanceof StringType) { - row.update(i, getUTF8String(i)); + row.update(i, getUTF8String(i).copy()); } else if (dt instanceof BinaryType) { row.update(i, getBinary(i)); } else if (dt instanceof DecimalType) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index bea2dce1a7657..a5a444b160c63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -86,17 +86,6 @@ class SortBasedAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: InternalRow = newBuffer - // This safe projection is used to turn the input row into safe row. This is necessary - // because the input row may be produced by unsafe projection in child operator and all the - // produced rows share one byte array. However, when we update the aggregate buffer according to - // the input row, we may cache some values from input row, e.g. `Max` will keep the max value from - // input row via MutableProjection, `CollectList` will keep all values in an array via - // ImperativeAggregate framework. These values may get changed unexpectedly if the underlying - // unsafe projection update the shared byte array. By applying a safe projection to the input row, - // we can cut down the connection from input row to the shared byte array, and thus it's safe to - // cache values from input row while updating the aggregation buffer. - private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType)) - protected def initialize(): Unit = { if (inputIterator.hasNext) { initializeBuffer(sortBasedAggregationBuffer) @@ -119,7 +108,7 @@ class SortBasedAggregationIterator( // We create a variable to track if we see the next group. var findNextPartition = false // firstRowInNextGroup is the first row of this group. We first process it. - processRow(sortBasedAggregationBuffer, safeProj(firstRowInNextGroup)) + processRow(sortBasedAggregationBuffer, firstRowInNextGroup) // The search will stop when we see the next group or there is no // input row left in the iter. @@ -130,7 +119,7 @@ class SortBasedAggregationIterator( // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { - processRow(sortBasedAggregationBuffer, safeProj(currentRow)) + processRow(sortBasedAggregationBuffer, currentRow) } else { // We find a new group. findNextPartition = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index d3fa0dcd2d7c3..fc977f2fd5530 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -56,7 +56,6 @@ class MutableUnsafeRow(val writer: UnsafeRowWriter) extends BaseGenericInternalR // all other methods inherited from GenericMutableRow are not need override protected def genericGet(ordinal: Int): Any = throw new UnsupportedOperationException override def numFields: Int = throw new UnsupportedOperationException - override def copy(): InternalRow = throw new UnsupportedOperationException } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala index 2195c6ea95948..bc141b36e63b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala @@ -145,13 +145,10 @@ private[window] final class AggregateProcessor( /** Update the buffer. */ def update(input: InternalRow): Unit = { - // TODO(hvanhovell) this sacrifices performance for correctness. We should make sure that - // MutableProjection makes copies of the complex input objects it buffer. - val copy = input.copy() - updateProjection(join(buffer, copy)) + updateProjection(join(buffer, input)) var i = 0 while (i < numImperatives) { - imperatives(i).update(buffer, copy) + imperatives(i).update(buffer, input) i += 1 } } From 61b5df567eb8ae0df4059cb0e334316fff462de9 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Sat, 1 Jul 2017 10:01:44 +0800 Subject: [PATCH 061/779] [SPARK-21127][SQL] Update statistics after data changing commands ## What changes were proposed in this pull request? Update stats after the following data changing commands: - InsertIntoHadoopFsRelationCommand - InsertIntoHiveTable - LoadDataCommand - TruncateTableCommand - AlterTableSetLocationCommand - AlterTableDropPartitionCommand ## How was this patch tested? Added new test cases. Author: wangzhenhua Author: Zhenhua Wang Closes #18334 from wzhfy/changeStatsForOperation. --- .../apache/spark/sql/internal/SQLConf.scala | 10 + .../sql/execution/command/CommandUtils.scala | 17 +- .../spark/sql/execution/command/ddl.scala | 15 +- .../spark/sql/StatisticsCollectionSuite.scala | 77 +++++--- .../spark/sql/hive/StatisticsSuite.scala | 187 +++++++++++------- 5 files changed, 207 insertions(+), 99 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c641e4d3a23e1..25152f3e32d6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -774,6 +774,14 @@ object SQLConf { .doubleConf .createWithDefault(0.05) + val AUTO_UPDATE_SIZE = + buildConf("spark.sql.statistics.autoUpdate.size") + .doc("Enables automatic update for table size once table's data is changed. Note that if " + + "the total number of files of the table is very large, this can be expensive and slow " + + "down data change commands.") + .booleanConf + .createWithDefault(false) + val CBO_ENABLED = buildConf("spark.sql.cbo.enabled") .doc("Enables CBO for estimation of plan statistics when set true.") @@ -1083,6 +1091,8 @@ class SQLConf extends Serializable with Logging { def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED) + def autoUpdateSize: Boolean = getConf(SQLConf.AUTO_UPDATE_SIZE) + def joinReorderEnabled: Boolean = getConf(SQLConf.JOIN_REORDER_ENABLED) def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index 92397607f38fd..fce12cc96620c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -36,7 +36,14 @@ object CommandUtils extends Logging { def updateTableStats(sparkSession: SparkSession, table: CatalogTable): Unit = { if (table.stats.nonEmpty) { val catalog = sparkSession.sessionState.catalog - catalog.alterTableStats(table.identifier, None) + if (sparkSession.sessionState.conf.autoUpdateSize) { + val newTable = catalog.getTableMetadata(table.identifier) + val newSize = CommandUtils.calculateTotalSize(sparkSession.sessionState, newTable) + val newStats = CatalogStatistics(sizeInBytes = newSize) + catalog.alterTableStats(table.identifier, Some(newStats)) + } else { + catalog.alterTableStats(table.identifier, None) + } } } @@ -84,7 +91,9 @@ object CommandUtils extends Logging { size } - locationUri.map { p => + val startTime = System.nanoTime() + logInfo(s"Starting to calculate the total file size under path $locationUri.") + val size = locationUri.map { p => val path = new Path(p) try { val fs = path.getFileSystem(sessionState.newHadoopConf()) @@ -97,6 +106,10 @@ object CommandUtils extends Logging { 0L } }.getOrElse(0L) + val durationInMs = (System.nanoTime() - startTime) / (1000 * 1000) + logInfo(s"It took $durationInMs ms to calculate the total file size under path $locationUri.") + + size } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index ac897c1b22d77..ba7ca84f229fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -437,7 +437,20 @@ case class AlterTableAddPartitionCommand( } catalog.createPartitions(table.identifier, parts, ignoreIfExists = ifNotExists) - CommandUtils.updateTableStats(sparkSession, table) + if (table.stats.nonEmpty) { + if (sparkSession.sessionState.conf.autoUpdateSize) { + val addedSize = parts.map { part => + CommandUtils.calculateLocationSize(sparkSession.sessionState, table.identifier, + part.storage.locationUri) + }.sum + if (addedSize > 0) { + val newStats = CatalogStatistics(sizeInBytes = table.stats.get.sizeInBytes + addedSize) + catalog.alterTableStats(table.identifier, Some(newStats)) + } + } else { + catalog.alterTableStats(table.identifier, None) + } + } Seq.empty[Row] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index b031c52dad8b5..d9392de37a815 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.internal.StaticSQLConf +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.test.SQLTestData.ArrayData import org.apache.spark.sql.types._ @@ -178,36 +178,63 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("change stats after set location command") { val table = "change_stats_set_location_table" - withTable(table) { - spark.range(100).select($"id", $"id" % 5 as "value").write.saveAsTable(table) - // analyze to get initial stats - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS id, value") - val fetched1 = checkTableStats( - table, hasSizeInBytes = true, expectedRowCounts = Some(100)) - assert(fetched1.get.sizeInBytes > 0) - assert(fetched1.get.colStats.size == 2) - - // set location command - withTempDir { newLocation => - sql(s"ALTER TABLE $table SET LOCATION '${newLocation.toURI.toString}'") - checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + spark.range(100).select($"id", $"id" % 5 as "value").write.saveAsTable(table) + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS id, value") + val fetched1 = checkTableStats( + table, hasSizeInBytes = true, expectedRowCounts = Some(100)) + assert(fetched1.get.sizeInBytes > 0) + assert(fetched1.get.colStats.size == 2) + + // set location command + val initLocation = spark.sessionState.catalog.getTableMetadata(TableIdentifier(table)) + .storage.locationUri.get.toString + withTempDir { newLocation => + sql(s"ALTER TABLE $table SET LOCATION '${newLocation.toURI.toString}'") + if (autoUpdate) { + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched2.get.sizeInBytes == 0) + assert(fetched2.get.colStats.isEmpty) + + // set back to the initial location + sql(s"ALTER TABLE $table SET LOCATION '$initLocation'") + val fetched3 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched3.get.sizeInBytes == fetched1.get.sizeInBytes) + } else { + checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + } + } + } } } } test("change stats after insert command for datasource table") { val table = "change_stats_insert_datasource_table" - withTable(table) { - sql(s"CREATE TABLE $table (i int, j string) USING PARQUET") - // analyze to get initial stats - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") - val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) - assert(fetched1.get.sizeInBytes == 0) - assert(fetched1.get.colStats.size == 2) - - // insert into command - sql(s"INSERT INTO TABLE $table SELECT 1, 'abc'") - checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + sql(s"CREATE TABLE $table (i int, j string) USING PARQUET") + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched1.get.sizeInBytes == 0) + assert(fetched1.get.colStats.size == 2) + + // insert into command + sql(s"INSERT INTO TABLE $table SELECT 1, 'abc'") + if (autoUpdate) { + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched2.get.sizeInBytes > 0) + assert(fetched2.get.colStats.isEmpty) + } else { + checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + } + } + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 5fd266c2d033c..c601038a2b0af 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -444,88 +444,133 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("change stats after insert command for hive table") { val table = s"change_stats_insert_hive_table" - withTable(table) { - sql(s"CREATE TABLE $table (i int, j string)") - // analyze to get initial stats - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") - val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) - assert(fetched1.get.sizeInBytes == 0) - assert(fetched1.get.colStats.size == 2) - - // insert into command - sql(s"INSERT INTO TABLE $table SELECT 1, 'abc'") - assert(getStatsProperties(table).isEmpty) + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + sql(s"CREATE TABLE $table (i int, j string)") + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched1.get.sizeInBytes == 0) + assert(fetched1.get.colStats.size == 2) + + // insert into command + sql(s"INSERT INTO TABLE $table SELECT 1, 'abc'") + if (autoUpdate) { + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched2.get.sizeInBytes > 0) + assert(fetched2.get.colStats.isEmpty) + val statsProp = getStatsProperties(table) + assert(statsProp(STATISTICS_TOTAL_SIZE).toLong == fetched2.get.sizeInBytes) + } else { + assert(getStatsProperties(table).isEmpty) + } + } + } } } test("change stats after load data command") { val table = "change_stats_load_table" - withTable(table) { - sql(s"CREATE TABLE $table (i INT, j STRING) STORED AS PARQUET") - // analyze to get initial stats - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") - val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) - assert(fetched1.get.sizeInBytes == 0) - assert(fetched1.get.colStats.size == 2) - - withTempDir { loadPath => - // load data command - val file = new File(loadPath + "/data") - val writer = new PrintWriter(file) - writer.write("2,xyz") - writer.close() - sql(s"LOAD DATA INPATH '${loadPath.toURI.toString}' INTO TABLE $table") - assert(getStatsProperties(table).isEmpty) + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + sql(s"CREATE TABLE $table (i INT, j STRING) STORED AS PARQUET") + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched1.get.sizeInBytes == 0) + assert(fetched1.get.colStats.size == 2) + + withTempDir { loadPath => + // load data command + val file = new File(loadPath + "/data") + val writer = new PrintWriter(file) + writer.write("2,xyz") + writer.close() + sql(s"LOAD DATA INPATH '${loadPath.toURI.toString}' INTO TABLE $table") + if (autoUpdate) { + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched2.get.sizeInBytes > 0) + assert(fetched2.get.colStats.isEmpty) + val statsProp = getStatsProperties(table) + assert(statsProp(STATISTICS_TOTAL_SIZE).toLong == fetched2.get.sizeInBytes) + } else { + assert(getStatsProperties(table).isEmpty) + } + } + } } } } test("change stats after add/drop partition command") { val table = "change_stats_part_table" - withTable(table) { - sql(s"CREATE TABLE $table (i INT, j STRING) PARTITIONED BY (ds STRING, hr STRING)") - // table has two partitions initially - for (ds <- Seq("2008-04-08"); hr <- Seq("11", "12")) { - sql(s"INSERT OVERWRITE TABLE $table PARTITION (ds='$ds',hr='$hr') SELECT 1, 'a'") - } - // analyze to get initial stats - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") - val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)) - assert(fetched1.get.sizeInBytes > 0) - assert(fetched1.get.colStats.size == 2) - - withTempPaths(numPaths = 2) { case Seq(dir1, dir2) => - val file1 = new File(dir1 + "/data") - val writer1 = new PrintWriter(file1) - writer1.write("1,a") - writer1.close() - - val file2 = new File(dir2 + "/data") - val writer2 = new PrintWriter(file2) - writer2.write("1,a") - writer2.close() - - // add partition command - sql( - s""" - |ALTER TABLE $table ADD - |PARTITION (ds='2008-04-09', hr='11') LOCATION '${dir1.toURI.toString}' - |PARTITION (ds='2008-04-09', hr='12') LOCATION '${dir2.toURI.toString}' - """.stripMargin) - assert(getStatsProperties(table).isEmpty) - - // generate stats again - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") - val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(4)) - assert(fetched2.get.sizeInBytes > 0) - assert(fetched2.get.colStats.size == 2) - - // drop partition command - sql(s"ALTER TABLE $table DROP PARTITION (ds='2008-04-08'), PARTITION (hr='12')") - // only one partition left - assert(spark.sessionState.catalog.listPartitions(TableIdentifier(table)) - .map(_.spec).toSet == Set(Map("ds" -> "2008-04-09", "hr" -> "11"))) - assert(getStatsProperties(table).isEmpty) + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + sql(s"CREATE TABLE $table (i INT, j STRING) PARTITIONED BY (ds STRING, hr STRING)") + // table has two partitions initially + for (ds <- Seq("2008-04-08"); hr <- Seq("11", "12")) { + sql(s"INSERT OVERWRITE TABLE $table PARTITION (ds='$ds',hr='$hr') SELECT 1, 'a'") + } + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)) + assert(fetched1.get.sizeInBytes > 0) + assert(fetched1.get.colStats.size == 2) + + withTempPaths(numPaths = 2) { case Seq(dir1, dir2) => + val file1 = new File(dir1 + "/data") + val writer1 = new PrintWriter(file1) + writer1.write("1,a") + writer1.close() + + val file2 = new File(dir2 + "/data") + val writer2 = new PrintWriter(file2) + writer2.write("1,a") + writer2.close() + + // add partition command + sql( + s""" + |ALTER TABLE $table ADD + |PARTITION (ds='2008-04-09', hr='11') LOCATION '${dir1.toURI.toString}' + |PARTITION (ds='2008-04-09', hr='12') LOCATION '${dir2.toURI.toString}' + """.stripMargin) + if (autoUpdate) { + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched2.get.sizeInBytes > fetched1.get.sizeInBytes) + assert(fetched2.get.colStats.isEmpty) + val statsProp = getStatsProperties(table) + assert(statsProp(STATISTICS_TOTAL_SIZE).toLong == fetched2.get.sizeInBytes) + } else { + assert(getStatsProperties(table).isEmpty) + } + + // now the table has four partitions, generate stats again + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched3 = checkTableStats( + table, hasSizeInBytes = true, expectedRowCounts = Some(4)) + assert(fetched3.get.sizeInBytes > 0) + assert(fetched3.get.colStats.size == 2) + + // drop partition command + sql(s"ALTER TABLE $table DROP PARTITION (ds='2008-04-08'), PARTITION (hr='12')") + assert(spark.sessionState.catalog.listPartitions(TableIdentifier(table)) + .map(_.spec).toSet == Set(Map("ds" -> "2008-04-09", "hr" -> "11"))) + // only one partition left + if (autoUpdate) { + val fetched4 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetched4.get.sizeInBytes < fetched1.get.sizeInBytes) + assert(fetched4.get.colStats.isEmpty) + val statsProp = getStatsProperties(table) + assert(statsProp(STATISTICS_TOTAL_SIZE).toLong == fetched4.get.sizeInBytes) + } else { + assert(getStatsProperties(table).isEmpty) + } + } + } } } } From b1d719e7c9faeb5661a7e712b3ecefca56bf356f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 30 Jun 2017 21:10:23 -0700 Subject: [PATCH 062/779] [SPARK-21273][SQL] Propagate logical plan stats using visitor pattern and mixin ## What changes were proposed in this pull request? We currently implement statistics propagation directly in logical plan. Given we already have two different implementations, it'd make sense to actually decouple the two and add stats propagation using mixin. This would reduce the coupling between logical plan and statistics handling. This can also be a powerful pattern in the future to add additional properties (e.g. constraints). ## How was this patch tested? Should be covered by existing test cases. Author: Reynold Xin Closes #18479 from rxin/stats-trait. --- .../sql/catalyst/catalog/interface.scala | 2 +- .../plans/logical/LocalRelation.scala | 5 +- .../catalyst/plans/logical/LogicalPlan.scala | 61 +------ .../plans/logical/LogicalPlanVisitor.scala | 87 ++++++++++ .../plans/logical/basicLogicalOperators.scala | 128 +------------- .../sql/catalyst/plans/logical/hints.scala | 5 - .../BasicStatsPlanVisitor.scala | 82 +++++++++ .../statsEstimation/LogicalPlanStats.scala | 50 ++++++ .../SizeInBytesOnlyStatsPlanVisitor.scala | 163 ++++++++++++++++++ .../BasicStatsEstimationSuite.scala | 44 ----- .../StatsEstimationTestBase.scala | 2 +- .../spark/sql/execution/ExistingRDD.scala | 4 +- .../execution/columnar/InMemoryRelation.scala | 2 +- .../datasources/LogicalRelation.scala | 7 +- .../sql/execution/streaming/memory.scala | 3 +- .../PruneFileSourcePartitionsSuite.scala | 2 +- 16 files changed, 409 insertions(+), 238 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/LogicalPlanStats.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index da50b0e7e8e42..9531456434a15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -438,7 +438,7 @@ case class CatalogRelation( case (attr, index) => attr.withExprId(ExprId(index + dataCols.length)) }) - override def computeStats: Statistics = { + override def computeStats(): Statistics = { // For data source tables, we will create a `LogicalRelation` and won't call this method, for // hive serde tables, we will always generate a statistics. // TODO: unify the table stats generation. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index dc2add64b68b7..1c986fbde7ada 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -66,9 +66,8 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) } } - override def computeStats: Statistics = - Statistics(sizeInBytes = - output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) + override def computeStats(): Statistics = + Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) def toSQL(inlineTableName: String): String = { require(data.nonEmpty) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 0d30aa76049a5..8649603b1a9f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -22,11 +22,16 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.LogicalPlanStats import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.types.StructType -abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstraints with Logging { +abstract class LogicalPlan + extends QueryPlan[LogicalPlan] + with LogicalPlanStats + with QueryPlanConstraints + with Logging { private var _analyzed: Boolean = false @@ -80,40 +85,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai } } - /** A cache for the estimated statistics, such that it will only be computed once. */ - private var statsCache: Option[Statistics] = None - - /** - * Returns the estimated statistics for the current logical plan node. Under the hood, this - * method caches the return value, which is computed based on the configuration passed in the - * first time. If the configuration changes, the cache can be invalidated by calling - * [[invalidateStatsCache()]]. - */ - final def stats: Statistics = statsCache.getOrElse { - statsCache = Some(computeStats) - statsCache.get - } - - /** Invalidates the stats cache. See [[stats]] for more information. */ - final def invalidateStatsCache(): Unit = { - statsCache = None - children.foreach(_.invalidateStatsCache()) - } - - /** - * Computes [[Statistics]] for this plan. The default implementation assumes the output - * cardinality is the product of all child plan's cardinality, i.e. applies in the case - * of cartesian joins. - * - * [[LeafNode]]s must override this. - */ - protected def computeStats: Statistics = { - if (children.isEmpty) { - throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") - } - Statistics(sizeInBytes = children.map(_.stats.sizeInBytes).product) - } - override def verboseStringWithSuffix: String = { super.verboseString + statsCache.map(", " + _.toString).getOrElse("") } @@ -300,6 +271,9 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai abstract class LeafNode extends LogicalPlan { override final def children: Seq[LogicalPlan] = Nil override def producedAttributes: AttributeSet = outputSet + + /** Leaf nodes that can survive analysis must define their own statistics. */ + def computeStats(): Statistics = throw new UnsupportedOperationException } /** @@ -331,23 +305,6 @@ abstract class UnaryNode extends LogicalPlan { } override protected def validConstraints: Set[Expression] = child.constraints - - override def computeStats: Statistics = { - // There should be some overhead in Row object, the size should not be zero when there is - // no columns, this help to prevent divide-by-zero error. - val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8 - val outputRowSize = output.map(_.dataType.defaultSize).sum + 8 - // Assume there will be the same number of rows as child has. - var sizeInBytes = (child.stats.sizeInBytes * outputRowSize) / childRowSize - if (sizeInBytes == 0) { - // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero - // (product of children). - sizeInBytes = 1 - } - - // Don't propagate rowCount and attributeStats, since they are not estimated here. - Statistics(sizeInBytes = sizeInBytes, hints = child.stats.hints) - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala new file mode 100644 index 0000000000000..b23045810a4f6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +/** + * A visitor pattern for traversing a [[LogicalPlan]] tree and compute some properties. + */ +trait LogicalPlanVisitor[T] { + + def visit(p: LogicalPlan): T = p match { + case p: Aggregate => visitAggregate(p) + case p: Distinct => visitDistinct(p) + case p: Except => visitExcept(p) + case p: Expand => visitExpand(p) + case p: Filter => visitFilter(p) + case p: Generate => visitGenerate(p) + case p: GlobalLimit => visitGlobalLimit(p) + case p: Intersect => visitIntersect(p) + case p: Join => visitJoin(p) + case p: LocalLimit => visitLocalLimit(p) + case p: Pivot => visitPivot(p) + case p: Project => visitProject(p) + case p: Range => visitRange(p) + case p: Repartition => visitRepartition(p) + case p: RepartitionByExpression => visitRepartitionByExpr(p) + case p: Sample => visitSample(p) + case p: ScriptTransformation => visitScriptTransform(p) + case p: Union => visitUnion(p) + case p: ResolvedHint => visitHint(p) + case p: LogicalPlan => default(p) + } + + def default(p: LogicalPlan): T + + def visitAggregate(p: Aggregate): T + + def visitDistinct(p: Distinct): T + + def visitExcept(p: Except): T + + def visitExpand(p: Expand): T + + def visitFilter(p: Filter): T + + def visitGenerate(p: Generate): T + + def visitGlobalLimit(p: GlobalLimit): T + + def visitHint(p: ResolvedHint): T + + def visitIntersect(p: Intersect): T + + def visitJoin(p: Join): T + + def visitLocalLimit(p: LocalLimit): T + + def visitPivot(p: Pivot): T + + def visitProject(p: Project): T + + def visitRange(p: Range): T + + def visitRepartition(p: Repartition): T + + def visitRepartitionByExpr(p: RepartitionByExpression): T + + def visitSample(p: Sample): T + + def visitScriptTransform(p: ScriptTransformation): T + + def visitUnion(p: Union): T +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index e89caabf252d7..0bd3166352d35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -63,14 +63,6 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend override def validConstraints: Set[Expression] = child.constraints.union(getAliasedConstraints(projectList)) - - override def computeStats: Statistics = { - if (conf.cboEnabled) { - ProjectEstimation.estimate(this).getOrElse(super.computeStats) - } else { - super.computeStats - } - } } /** @@ -137,14 +129,6 @@ case class Filter(condition: Expression, child: LogicalPlan) .filterNot(SubqueryExpression.hasCorrelatedSubquery) child.constraints.union(predicates.toSet) } - - override def computeStats: Statistics = { - if (conf.cboEnabled) { - FilterEstimation(this).estimate.getOrElse(super.computeStats) - } else { - super.computeStats - } - } } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { @@ -190,15 +174,6 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation Some(children.flatMap(_.maxRows).min) } } - - override def computeStats: Statistics = { - val leftSize = left.stats.sizeInBytes - val rightSize = right.stats.sizeInBytes - val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize - Statistics( - sizeInBytes = sizeInBytes, - hints = left.stats.hints.resetForJoin()) - } } case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { @@ -207,10 +182,6 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le override def output: Seq[Attribute] = left.output override protected def validConstraints: Set[Expression] = leftConstraints - - override def computeStats: Statistics = { - left.stats.copy() - } } /** Factory for constructing new `Union` nodes. */ @@ -247,11 +218,6 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { children.length > 1 && childrenResolved && allChildrenCompatible } - override def computeStats: Statistics = { - val sizeInBytes = children.map(_.stats.sizeInBytes).sum - Statistics(sizeInBytes = sizeInBytes) - } - /** * Maps the constraints containing a given (original) sequence of attributes to those with a * given (reference) sequence of attributes. Given the nature of union, we expect that the @@ -355,25 +321,6 @@ case class Join( case UsingJoin(_, _) => false case _ => resolvedExceptNatural } - - override def computeStats: Statistics = { - def simpleEstimation: Statistics = joinType match { - case LeftAnti | LeftSemi => - // LeftSemi and LeftAnti won't ever be bigger than left - left.stats - case _ => - // Make sure we don't propagate isBroadcastable in other joins, because - // they could explode the size. - val stats = super.computeStats - stats.copy(hints = stats.hints.resetForJoin()) - } - - if (conf.cboEnabled) { - JoinEstimation.estimate(this).getOrElse(simpleEstimation) - } else { - simpleEstimation - } - } } /** @@ -522,14 +469,13 @@ case class Range( override def newInstance(): Range = copy(output = output.map(_.newInstance())) - override def computeStats: Statistics = { - val sizeInBytes = LongType.defaultSize * numElements - Statistics( sizeInBytes = sizeInBytes ) - } - override def simpleString: String = { s"Range ($start, $end, step=$step, splits=$numSlices)" } + + override def computeStats(): Statistics = { + Statistics(sizeInBytes = LongType.defaultSize * numElements) + } } case class Aggregate( @@ -554,25 +500,6 @@ case class Aggregate( val nonAgg = aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty) child.constraints.union(getAliasedConstraints(nonAgg)) } - - override def computeStats: Statistics = { - def simpleEstimation: Statistics = { - if (groupingExpressions.isEmpty) { - Statistics( - sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1), - rowCount = Some(1), - hints = child.stats.hints) - } else { - super.computeStats - } - } - - if (conf.cboEnabled) { - AggregateEstimation.estimate(this).getOrElse(simpleEstimation) - } else { - simpleEstimation - } - } } case class Window( @@ -671,11 +598,6 @@ case class Expand( override def references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references)) - override def computeStats: Statistics = { - val sizeInBytes = super.computeStats.sizeInBytes * projections.length - Statistics(sizeInBytes = sizeInBytes) - } - // This operator can reuse attributes (for example making them null when doing a roll up) so // the constraints of the child may no longer be valid. override protected def validConstraints: Set[Expression] = Set.empty[Expression] @@ -742,16 +664,6 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN case _ => None } } - override def computeStats: Statistics = { - val limit = limitExpr.eval().asInstanceOf[Int] - val childStats = child.stats - val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit) - // Don't propagate column stats, because we don't know the distribution after a limit operation - Statistics( - sizeInBytes = EstimationUtils.getOutputSize(output, rowCount, childStats.attributeStats), - rowCount = Some(rowCount), - hints = childStats.hints) - } } case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { @@ -762,24 +674,6 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo case _ => None } } - override def computeStats: Statistics = { - val limit = limitExpr.eval().asInstanceOf[Int] - val childStats = child.stats - if (limit == 0) { - // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero - // (product of children). - Statistics( - sizeInBytes = 1, - rowCount = Some(0), - hints = childStats.hints) - } else { - // The output row count of LocalLimit should be the sum of row counts from each partition. - // However, since the number of partitions is not available here, we just use statistics of - // the child. Because the distribution after a limit operation is unknown, we do not propagate - // the column stats. - childStats.copy(attributeStats = AttributeMap(Nil)) - } - } } /** @@ -828,18 +722,6 @@ case class Sample( } override def output: Seq[Attribute] = child.output - - override def computeStats: Statistics = { - val ratio = upperBound - lowerBound - val childStats = child.stats - var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio) - if (sizeInBytes == 0) { - sizeInBytes = 1 - } - val sampledRowCount = childStats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio)) - // Don't propagate column stats, because we don't know the distribution after a sample operation - Statistics(sizeInBytes, sampledRowCount, hints = childStats.hints) - } } /** @@ -893,7 +775,7 @@ case class RepartitionByExpression( case object OneRowRelation extends LeafNode { override def maxRows: Option[Long] = Some(1) override def output: Seq[Attribute] = Nil - override def computeStats: Statistics = Statistics(sizeInBytes = 1) + override def computeStats(): Statistics = Statistics(sizeInBytes = 1) } /** A logical plan for `dropDuplicates`. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index 8479c702d7561..29a43528124d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -42,11 +42,6 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo()) override def output: Seq[Attribute] = child.output override lazy val canonicalized: LogicalPlan = child.canonicalized - - override def computeStats: Statistics = { - val stats = child.stats - stats.copy(hints = hints) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala new file mode 100644 index 0000000000000..93908b04fb643 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.statsEstimation + +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.LongType + +/** + * An [[LogicalPlanVisitor]] that computes a the statistics used in a cost-based optimizer. + */ +object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { + + /** Falls back to the estimation computed by [[SizeInBytesOnlyStatsPlanVisitor]]. */ + private def fallback(p: LogicalPlan): Statistics = SizeInBytesOnlyStatsPlanVisitor.visit(p) + + override def default(p: LogicalPlan): Statistics = fallback(p) + + override def visitAggregate(p: Aggregate): Statistics = { + AggregateEstimation.estimate(p).getOrElse(fallback(p)) + } + + override def visitDistinct(p: Distinct): Statistics = fallback(p) + + override def visitExcept(p: Except): Statistics = fallback(p) + + override def visitExpand(p: Expand): Statistics = fallback(p) + + override def visitFilter(p: Filter): Statistics = { + FilterEstimation(p).estimate.getOrElse(fallback(p)) + } + + override def visitGenerate(p: Generate): Statistics = fallback(p) + + override def visitGlobalLimit(p: GlobalLimit): Statistics = fallback(p) + + override def visitHint(p: ResolvedHint): Statistics = fallback(p) + + override def visitIntersect(p: Intersect): Statistics = fallback(p) + + override def visitJoin(p: Join): Statistics = { + JoinEstimation.estimate(p).getOrElse(fallback(p)) + } + + override def visitLocalLimit(p: LocalLimit): Statistics = fallback(p) + + override def visitPivot(p: Pivot): Statistics = fallback(p) + + override def visitProject(p: Project): Statistics = { + ProjectEstimation.estimate(p).getOrElse(fallback(p)) + } + + override def visitRange(p: logical.Range): Statistics = { + val sizeInBytes = LongType.defaultSize * p.numElements + Statistics(sizeInBytes = sizeInBytes) + } + + override def visitRepartition(p: Repartition): Statistics = fallback(p) + + override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = fallback(p) + + override def visitSample(p: Sample): Statistics = fallback(p) + + override def visitScriptTransform(p: ScriptTransformation): Statistics = fallback(p) + + override def visitUnion(p: Union): Statistics = fallback(p) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/LogicalPlanStats.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/LogicalPlanStats.scala new file mode 100644 index 0000000000000..8660d93550192 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/LogicalPlanStats.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.statsEstimation + +import org.apache.spark.sql.catalyst.plans.logical._ + +/** + * A trait to add statistics propagation to [[LogicalPlan]]. + */ +trait LogicalPlanStats { self: LogicalPlan => + + /** + * Returns the estimated statistics for the current logical plan node. Under the hood, this + * method caches the return value, which is computed based on the configuration passed in the + * first time. If the configuration changes, the cache can be invalidated by calling + * [[invalidateStatsCache()]]. + */ + def stats: Statistics = statsCache.getOrElse { + if (conf.cboEnabled) { + statsCache = Option(BasicStatsPlanVisitor.visit(self)) + } else { + statsCache = Option(SizeInBytesOnlyStatsPlanVisitor.visit(self)) + } + statsCache.get + } + + /** A cache for the estimated statistics, such that it will only be computed once. */ + protected var statsCache: Option[Statistics] = None + + /** Invalidates the stats cache. See [[stats]] for more information. */ + final def invalidateStatsCache(): Unit = { + statsCache = None + children.foreach(_.invalidateStatsCache()) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala new file mode 100644 index 0000000000000..559f12072e448 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.statsEstimation + +import org.apache.spark.sql.catalyst.expressions.AttributeMap +import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical._ + +/** + * An [[LogicalPlanVisitor]] that computes a single dimension for plan stats: size in bytes. + */ +object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { + + /** + * A default, commonly used estimation for unary nodes. We assume the input row number is the + * same as the output row number, and compute sizes based on the column types. + */ + private def visitUnaryNode(p: UnaryNode): Statistics = { + // There should be some overhead in Row object, the size should not be zero when there is + // no columns, this help to prevent divide-by-zero error. + val childRowSize = p.child.output.map(_.dataType.defaultSize).sum + 8 + val outputRowSize = p.output.map(_.dataType.defaultSize).sum + 8 + // Assume there will be the same number of rows as child has. + var sizeInBytes = (p.child.stats.sizeInBytes * outputRowSize) / childRowSize + if (sizeInBytes == 0) { + // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero + // (product of children). + sizeInBytes = 1 + } + + // Don't propagate rowCount and attributeStats, since they are not estimated here. + Statistics(sizeInBytes = sizeInBytes, hints = p.child.stats.hints) + } + + /** + * For leaf nodes, use its computeStats. For other nodes, we assume the size in bytes is the + * sum of all of the children's. + */ + override def default(p: LogicalPlan): Statistics = p match { + case p: LeafNode => p.computeStats() + case _: LogicalPlan => Statistics(sizeInBytes = p.children.map(_.stats.sizeInBytes).product) + } + + override def visitAggregate(p: Aggregate): Statistics = { + if (p.groupingExpressions.isEmpty) { + Statistics( + sizeInBytes = EstimationUtils.getOutputSize(p.output, outputRowCount = 1), + rowCount = Some(1), + hints = p.child.stats.hints) + } else { + visitUnaryNode(p) + } + } + + override def visitDistinct(p: Distinct): Statistics = default(p) + + override def visitExcept(p: Except): Statistics = p.left.stats.copy() + + override def visitExpand(p: Expand): Statistics = { + val sizeInBytes = visitUnaryNode(p).sizeInBytes * p.projections.length + Statistics(sizeInBytes = sizeInBytes) + } + + override def visitFilter(p: Filter): Statistics = visitUnaryNode(p) + + override def visitGenerate(p: Generate): Statistics = default(p) + + override def visitGlobalLimit(p: GlobalLimit): Statistics = { + val limit = p.limitExpr.eval().asInstanceOf[Int] + val childStats = p.child.stats + val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit) + // Don't propagate column stats, because we don't know the distribution after limit + Statistics( + sizeInBytes = EstimationUtils.getOutputSize(p.output, rowCount, childStats.attributeStats), + rowCount = Some(rowCount), + hints = childStats.hints) + } + + override def visitHint(p: ResolvedHint): Statistics = p.child.stats.copy(hints = p.hints) + + override def visitIntersect(p: Intersect): Statistics = { + val leftSize = p.left.stats.sizeInBytes + val rightSize = p.right.stats.sizeInBytes + val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize + Statistics( + sizeInBytes = sizeInBytes, + hints = p.left.stats.hints.resetForJoin()) + } + + override def visitJoin(p: Join): Statistics = { + p.joinType match { + case LeftAnti | LeftSemi => + // LeftSemi and LeftAnti won't ever be bigger than left + p.left.stats + case _ => + // Make sure we don't propagate isBroadcastable in other joins, because + // they could explode the size. + val stats = default(p) + stats.copy(hints = stats.hints.resetForJoin()) + } + } + + override def visitLocalLimit(p: LocalLimit): Statistics = { + val limit = p.limitExpr.eval().asInstanceOf[Int] + val childStats = p.child.stats + if (limit == 0) { + // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero + // (product of children). + Statistics(sizeInBytes = 1, rowCount = Some(0), hints = childStats.hints) + } else { + // The output row count of LocalLimit should be the sum of row counts from each partition. + // However, since the number of partitions is not available here, we just use statistics of + // the child. Because the distribution after a limit operation is unknown, we do not propagate + // the column stats. + childStats.copy(attributeStats = AttributeMap(Nil)) + } + } + + override def visitPivot(p: Pivot): Statistics = default(p) + + override def visitProject(p: Project): Statistics = visitUnaryNode(p) + + override def visitRange(p: logical.Range): Statistics = { + p.computeStats() + } + + override def visitRepartition(p: Repartition): Statistics = default(p) + + override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = default(p) + + override def visitSample(p: Sample): Statistics = { + val ratio = p.upperBound - p.lowerBound + var sizeInBytes = EstimationUtils.ceil(BigDecimal(p.child.stats.sizeInBytes) * ratio) + if (sizeInBytes == 0) { + sizeInBytes = 1 + } + val sampleRows = p.child.stats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio)) + // Don't propagate column stats, because we don't know the distribution after a sample operation + Statistics(sizeInBytes, sampleRows, hints = p.child.stats.hints) + } + + override def visitScriptTransform(p: ScriptTransformation): Statistics = default(p) + + override def visitUnion(p: Union): Statistics = { + Statistics(sizeInBytes = p.children.map(_.stats.sizeInBytes).sum) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 912c5fed63450..31a8cbdee9777 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -77,37 +77,6 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { checkStats(globalLimit, stats) } - test("sample estimation") { - val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan) - checkStats(sample, Statistics(sizeInBytes = 60, rowCount = Some(5))) - - // Child doesn't have rowCount in stats - val childStats = Statistics(sizeInBytes = 120) - val childPlan = DummyLogicalPlan(childStats, childStats) - val sample2 = - Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan) - checkStats(sample2, Statistics(sizeInBytes = 14)) - } - - test("estimate statistics when the conf changes") { - val expectedDefaultStats = - Statistics( - sizeInBytes = 40, - rowCount = Some(10), - attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4)))) - val expectedCboStats = - Statistics( - sizeInBytes = 4, - rowCount = Some(1), - attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4)))) - - val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) - checkStats( - plan, expectedStatsCboOn = expectedCboStats, expectedStatsCboOff = expectedDefaultStats) - } - /** Check estimated stats when cbo is turned on/off. */ private def checkStats( plan: LogicalPlan, @@ -132,16 +101,3 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { private def checkStats(plan: LogicalPlan, expectedStats: Statistics): Unit = checkStats(plan, expectedStats, expectedStats) } - -/** - * This class is used for unit-testing the cbo switch, it mimics a logical plan which computes - * a simple statistics or a cbo estimated statistics based on the conf. - */ -private case class DummyLogicalPlan( - defaultStats: Statistics, - cboStats: Statistics) extends LogicalPlan { - override def output: Seq[Attribute] = Nil - override def children: Seq[LogicalPlan] = Nil - override def computeStats: Statistics = - if (conf.cboEnabled) cboStats else defaultStats -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index eaa33e44a6a5a..31dea2e3e7f1d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -65,7 +65,7 @@ case class StatsTestPlan( attributeStats: AttributeMap[ColumnStat], size: Option[BigInt] = None) extends LeafNode { override def output: Seq[Attribute] = outputList - override def computeStats: Statistics = Statistics( + override def computeStats(): Statistics = Statistics( // If sizeInBytes is useless in testing, we just use a fake value sizeInBytes = size.getOrElse(Int.MaxValue), rowCount = Some(rowCount), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 66f66a289a065..dcb918eeb9d10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -88,7 +88,7 @@ case class ExternalRDD[T]( override protected def stringArgs: Iterator[Any] = Iterator(output) - @transient override def computeStats: Statistics = Statistics( + override def computeStats(): Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) @@ -156,7 +156,7 @@ case class LogicalRDD( override protected def stringArgs: Iterator[Any] = Iterator(output) - @transient override def computeStats: Statistics = Statistics( + override def computeStats(): Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 2972132336de0..39cf8fcac5116 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -69,7 +69,7 @@ case class InMemoryRelation( @transient val partitionStatistics = new PartitionStatistics(output) - override def computeStats: Statistics = { + override def computeStats(): Statistics = { if (batchStats.value == 0L) { // Underlying columnar RDD hasn't been materialized, no useful statistics information // available, return the default statistics. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 6ba190b9e5dcf..699f1bad9c4ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -48,9 +48,10 @@ case class LogicalRelation( output = output.map(QueryPlan.normalizeExprId(_, output)), catalogTable = None) - @transient override def computeStats: Statistics = { - catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse( - Statistics(sizeInBytes = relation.sizeInBytes)) + override def computeStats(): Statistics = { + catalogTable + .flatMap(_.stats.map(_.toPlanStats(output))) + .getOrElse(Statistics(sizeInBytes = relation.sizeInBytes)) } /** Used to lookup original attribute capitalization */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 4979873ee3c7f..587ae2bfb63fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -230,6 +230,5 @@ case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum - override def computeStats: Statistics = - Statistics(sizePerRow * sink.allData.size) + override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala index 3a724aa14f2a9..94384185d190a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -86,7 +86,7 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te case relation: LogicalRelation => relation } assert(relations.size === 1, s"Size wrong for:\n ${df.queryExecution}") - val size2 = relations(0).computeStats.sizeInBytes + val size2 = relations(0).stats.sizeInBytes assert(size2 == relations(0).catalogTable.get.stats.get.sizeInBytes) assert(size2 < tableStats.get.sizeInBytes) } From 37ef32e515ea071afe63b56ba0d4299bb76e8a75 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Sat, 1 Jul 2017 14:57:57 +0800 Subject: [PATCH 063/779] [SPARK-21275][ML] Update GLM test to use supportedFamilyNames ## What changes were proposed in this pull request? Update GLM test to use supportedFamilyNames as suggested here: https://github.com/apache/spark/pull/16699#discussion-diff-100574976R855 Author: actuaryzhang Closes #18495 from actuaryzhang/mlGlmTest2. --- .../GeneralizedLinearRegressionSuite.scala | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 83f1344a7bcb1..a47bd17f47bb1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -749,15 +749,15 @@ class GeneralizedLinearRegressionSuite library(statmod) y <- c(1.0, 0.5, 0.7, 0.3) w <- c(1, 2, 3, 4) - for (fam in list(gaussian(), poisson(), binomial(), Gamma(), tweedie(1.6))) { + for (fam in list(binomial(), Gamma(), gaussian(), poisson(), tweedie(1.6))) { model1 <- glm(y ~ 1, family = fam) model2 <- glm(y ~ 1, family = fam, weights = w) print(as.vector(c(coef(model1), coef(model2)))) } - [1] 0.625 0.530 - [1] -0.4700036 -0.6348783 [1] 0.5108256 0.1201443 [1] 1.600000 1.886792 + [1] 0.625 0.530 + [1] -0.4700036 -0.6348783 [1] 1.325782 1.463641 */ @@ -768,13 +768,13 @@ class GeneralizedLinearRegressionSuite Instance(0.3, 4.0, Vectors.zeros(0)) ).toDF() - val expected = Seq(0.625, 0.530, -0.4700036, -0.6348783, 0.5108256, 0.1201443, - 1.600000, 1.886792, 1.325782, 1.463641) + val expected = Seq(0.5108256, 0.1201443, 1.600000, 1.886792, 0.625, 0.530, + -0.4700036, -0.6348783, 1.325782, 1.463641) import GeneralizedLinearRegression._ var idx = 0 - for (family <- Seq("gaussian", "poisson", "binomial", "gamma", "tweedie")) { + for (family <- GeneralizedLinearRegression.supportedFamilyNames.sortWith(_ < _)) { for (useWeight <- Seq(false, true)) { val trainer = new GeneralizedLinearRegression().setFamily(family) if (useWeight) trainer.setWeightCol("weight") @@ -807,7 +807,7 @@ class GeneralizedLinearRegressionSuite 0.5, 2.1, 0.5, 1.0, 2.0, 0.9, 0.4, 1.0, 2.0, 1.0, 0.7, 0.7, 0.0, 3.0, 3.0), 4, 5, byrow = TRUE)) - families <- list(gaussian, binomial, poisson, Gamma, tweedie(1.5)) + families <- list(binomial, Gamma, gaussian, poisson, tweedie(1.5)) f1 <- V1 ~ -1 + V4 + V5 f2 <- V1 ~ V4 + V5 for (f in c(f1, f2)) { @@ -816,15 +816,15 @@ class GeneralizedLinearRegressionSuite print(as.vector(coef(model))) } } - [1] 0.5169222 -0.3344444 [1] 0.9419107 -0.6864404 - [1] 0.1812436 -0.6568422 [1] -0.2869094 0.7857710 + [1] 0.5169222 -0.3344444 + [1] 0.1812436 -0.6568422 [1] 0.1055254 0.2979113 - [1] -0.05990345 0.53188982 -0.32118415 [1] -0.2147117 0.9911750 -0.6356096 - [1] -1.5616130 0.6646470 -0.3192581 [1] 0.3390397 -0.3406099 0.6870259 + [1] -0.05990345 0.53188982 -0.32118415 + [1] -1.5616130 0.6646470 -0.3192581 [1] 0.3665034 0.1039416 0.1484616 */ val dataset = Seq( @@ -835,23 +835,22 @@ class GeneralizedLinearRegressionSuite ).toDF() val expected = Seq( - Vectors.dense(0, 0.5169222, -0.3344444), Vectors.dense(0, 0.9419107, -0.6864404), - Vectors.dense(0, 0.1812436, -0.6568422), Vectors.dense(0, -0.2869094, 0.785771), + Vectors.dense(0, 0.5169222, -0.3344444), + Vectors.dense(0, 0.1812436, -0.6568422), Vectors.dense(0, 0.1055254, 0.2979113), - Vectors.dense(-0.05990345, 0.53188982, -0.32118415), Vectors.dense(-0.2147117, 0.991175, -0.6356096), - Vectors.dense(-1.561613, 0.664647, -0.3192581), Vectors.dense(0.3390397, -0.3406099, 0.6870259), + Vectors.dense(-0.05990345, 0.53188982, -0.32118415), + Vectors.dense(-1.561613, 0.664647, -0.3192581), Vectors.dense(0.3665034, 0.1039416, 0.1484616)) import GeneralizedLinearRegression._ var idx = 0 - for (fitIntercept <- Seq(false, true)) { - for (family <- Seq("gaussian", "binomial", "poisson", "gamma", "tweedie")) { + for (family <- GeneralizedLinearRegression.supportedFamilyNames.sortWith(_ < _)) { val trainer = new GeneralizedLinearRegression().setFamily(family) .setFitIntercept(fitIntercept).setOffsetCol("offset") .setWeightCol("weight").setLinkPredictionCol("linkPrediction") From e0b047eafed92eadf6842a9df964438095e12d41 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Sat, 1 Jul 2017 15:37:41 +0800 Subject: [PATCH 064/779] [SPARK-18518][ML] HasSolver supports override ## What changes were proposed in this pull request? 1, make param support non-final with `finalFields` option 2, generate `HasSolver` with `finalFields = false` 3, override `solver` in LiR, GLR, and make MLPC inherit `HasSolver` ## How was this patch tested? existing tests Author: Ruifeng Zheng Author: Zheng RuiFeng Closes #16028 from zhengruifeng/param_non_final. --- .../MultilayerPerceptronClassifier.scala | 19 ++++---- .../ml/param/shared/SharedParamsCodeGen.scala | 11 +++-- .../spark/ml/param/shared/sharedParams.scala | 8 ++-- .../GeneralizedLinearRegression.scala | 21 ++++++++- .../ml/regression/LinearRegression.scala | 46 +++++++++++++++---- python/pyspark/ml/classification.py | 18 +------- python/pyspark/ml/regression.py | 5 ++ 7 files changed, 82 insertions(+), 46 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index ec39f964e213a..ceba11edc93be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -27,13 +27,16 @@ import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer} import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasStepSize, HasTol} +import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.sql.Dataset /** Params for Multilayer Perceptron. */ private[classification] trait MultilayerPerceptronParams extends PredictorParams - with HasSeed with HasMaxIter with HasTol with HasStepSize { + with HasSeed with HasMaxIter with HasTol with HasStepSize with HasSolver { + + import MultilayerPerceptronClassifier._ + /** * Layer sizes including input size and output size. * @@ -78,14 +81,10 @@ private[classification] trait MultilayerPerceptronParams extends PredictorParams * @group expertParam */ @Since("2.0.0") - final val solver: Param[String] = new Param[String](this, "solver", + final override val solver: Param[String] = new Param[String](this, "solver", "The solver algorithm for optimization. Supported options: " + - s"${MultilayerPerceptronClassifier.supportedSolvers.mkString(", ")}. (Default l-bfgs)", - ParamValidators.inArray[String](MultilayerPerceptronClassifier.supportedSolvers)) - - /** @group expertGetParam */ - @Since("2.0.0") - final def getSolver: String = $(solver) + s"${supportedSolvers.mkString(", ")}. (Default l-bfgs)", + ParamValidators.inArray[String](supportedSolvers)) /** * The initial weights of the model. @@ -101,7 +100,7 @@ private[classification] trait MultilayerPerceptronParams extends PredictorParams final def getInitialWeights: Vector = $(initialWeights) setDefault(maxIter -> 100, tol -> 1e-6, blockSize -> 128, - solver -> MultilayerPerceptronClassifier.LBFGS, stepSize -> 0.03) + solver -> LBFGS, stepSize -> 0.03) } /** Label to vector converter. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 013817a41baf5..23e0d45d943a0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -80,8 +80,7 @@ private[shared] object SharedParamsCodeGen { " 0)", isValid = "ParamValidators.gt(0)"), ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " + "all instance weights as 1.0"), - ParamDesc[String]("solver", "the solver algorithm for optimization. If this is not set or " + - "empty, default value is 'auto'", Some("\"auto\"")), + ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false), ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"), isValid = "ParamValidators.gtEq(2)", isExpertParam = true)) @@ -99,6 +98,7 @@ private[shared] object SharedParamsCodeGen { defaultValueStr: Option[String] = None, isValid: String = "", finalMethods: Boolean = true, + finalFields: Boolean = true, isExpertParam: Boolean = false) { require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.") @@ -167,6 +167,11 @@ private[shared] object SharedParamsCodeGen { } else { "def" } + val fieldStr = if (param.finalFields) { + "final val" + } else { + "val" + } val htmlCompliantDoc = Utility.escape(doc) @@ -180,7 +185,7 @@ private[shared] object SharedParamsCodeGen { | * Param for $htmlCompliantDoc. | * @group ${groupStr(0)} | */ - | final val $name: $Param = new $Param(this, "$name", "$doc"$isValid) + | $fieldStr $name: $Param = new $Param(this, "$name", "$doc"$isValid) |$setDefault | /** @group ${groupStr(1)} */ | $methodStr get$Name: $T = $$($name) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 50619607a5054..1a8f499798b80 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -374,17 +374,15 @@ private[ml] trait HasWeightCol extends Params { } /** - * Trait for shared param solver (default: "auto"). + * Trait for shared param solver. */ private[ml] trait HasSolver extends Params { /** - * Param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. + * Param for the solver algorithm for optimization. * @group param */ - final val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'") - - setDefault(solver, "auto") + val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization") /** @group getParam */ final def getSolver: String = $(solver) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index ce3460ae43566..c600b87bdc64a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -164,7 +164,18 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty } - import GeneralizedLinearRegression._ + /** + * The solver algorithm for optimization. + * Supported options: "irls" (iteratively reweighted least squares). + * Default: "irls" + * + * @group param + */ + @Since("2.3.0") + final override val solver: Param[String] = new Param[String](this, "solver", + "The solver algorithm for optimization. Supported options: " + + s"${supportedSolvers.mkString(", ")}. (Default irls)", + ParamValidators.inArray[String](supportedSolvers)) @Since("2.0.0") override def validateAndTransformSchema( @@ -350,7 +361,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val */ @Since("2.0.0") def setSolver(value: String): this.type = set(solver, value) - setDefault(solver -> "irls") + setDefault(solver -> IRLS) /** * Sets the link prediction (linear predictor) column name. @@ -442,6 +453,12 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine Gamma -> Inverse, Gamma -> Identity, Gamma -> Log ) + /** String name for "irls" (iteratively reweighted least squares) solver. */ + private[regression] val IRLS = "irls" + + /** Set of solvers that GeneralizedLinearRegression supports. */ + private[regression] val supportedSolvers = Array(IRLS) + /** Set of family names that GeneralizedLinearRegression supports. */ private[regression] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name).toArray :+ "tweedie" diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index db5ac4f14bd3b..ce5e0797915df 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -34,7 +34,7 @@ import org.apache.spark.ml.optim.WeightedLeastSquares import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.optim.aggregator.LeastSquaresAggregator import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics @@ -53,7 +53,23 @@ import org.apache.spark.storage.StorageLevel private[regression] trait LinearRegressionParams extends PredictorParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver - with HasAggregationDepth + with HasAggregationDepth { + + import LinearRegression._ + + /** + * The solver algorithm for optimization. + * Supported options: "l-bfgs", "normal" and "auto". + * Default: "auto" + * + * @group param + */ + @Since("2.3.0") + final override val solver: Param[String] = new Param[String](this, "solver", + "The solver algorithm for optimization. Supported options: " + + s"${supportedSolvers.mkString(", ")}. (Default auto)", + ParamValidators.inArray[String](supportedSolvers)) +} /** * Linear regression. @@ -78,6 +94,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams with DefaultParamsWritable with Logging { + import LinearRegression._ + @Since("1.4.0") def this() = this(Identifiable.randomUID("linReg")) @@ -175,12 +193,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String * @group setParam */ @Since("1.6.0") - def setSolver(value: String): this.type = { - require(Set("auto", "l-bfgs", "normal").contains(value), - s"Solver $value was not supported. Supported options: auto, l-bfgs, normal") - set(solver, value) - } - setDefault(solver -> "auto") + def setSolver(value: String): this.type = set(solver, value) + setDefault(solver -> AUTO) /** * Suggested depth for treeAggregate (greater than or equal to 2). @@ -210,8 +224,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth) instr.logNumFeatures(numFeatures) - if (($(solver) == "auto" && - numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") { + if (($(solver) == AUTO && + numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == NORMAL) { // For low dimensional data, WeightedLeastSquares is more efficient since the // training algorithm only requires one pass through the data. (SPARK-10668) @@ -444,6 +458,18 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] { */ @Since("2.1.0") val MAX_FEATURES_FOR_NORMAL_SOLVER: Int = WeightedLeastSquares.MAX_NUM_FEATURES + + /** String name for "auto". */ + private[regression] val AUTO = "auto" + + /** String name for "normal". */ + private[regression] val NORMAL = "normal" + + /** String name for "l-bfgs". */ + private[regression] val LBFGS = "l-bfgs" + + /** Set of solvers that LinearRegression supports. */ + private[regression] val supportedSolvers = Array(AUTO, NORMAL, LBFGS) } /** diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 9b345ac73f3d9..948806a5c936c 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1265,8 +1265,8 @@ def theta(self): @inherit_doc class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasMaxIter, HasTol, HasSeed, HasStepSize, JavaMLWritable, - JavaMLReadable): + HasMaxIter, HasTol, HasSeed, HasStepSize, HasSolver, + JavaMLWritable, JavaMLReadable): """ Classifier trainer based on the Multilayer Perceptron. Each layer has sigmoid activation function, output layer has softmax. @@ -1407,20 +1407,6 @@ def getStepSize(self): """ return self.getOrDefault(self.stepSize) - @since("2.0.0") - def setSolver(self, value): - """ - Sets the value of :py:attr:`solver`. - """ - return self._set(solver=value) - - @since("2.0.0") - def getSolver(self): - """ - Gets the value of solver or its default value. - """ - return self.getOrDefault(self.solver) - @since("2.0.0") def setInitialWeights(self, value): """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 2d17f95b0c44f..84d843369e105 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -95,6 +95,9 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction .. versionadded:: 1.4.0 """ + solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " + + "options: auto, normal, l-bfgs.", typeConverter=TypeConverters.toString) + @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, @@ -1371,6 +1374,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " + "Only applicable to the Tweedie family.", typeConverter=TypeConverters.toFloat) + solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " + + "options: irls.", typeConverter=TypeConverters.toString) @keyword_only def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", From 6beca9ce94f484de2f9ffb946bef8334781b3122 Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Sat, 1 Jul 2017 15:53:49 +0100 Subject: [PATCH 065/779] [SPARK-21170][CORE] Utils.tryWithSafeFinallyAndFailureCallbacks throws IllegalArgumentException: Self-suppression not permitted ## What changes were proposed in this pull request? Not adding the exception to the suppressed if it is the same instance as originalThrowable. ## How was this patch tested? Added new tests to verify this, these tests fail without source code changes and passes with the change. Author: Devaraj K Closes #18384 from devaraj-kavali/SPARK-21170. --- .../scala/org/apache/spark/util/Utils.scala | 30 +++---- .../org/apache/spark/util/UtilsSuite.scala | 88 ++++++++++++++++++- 2 files changed, 99 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index bbb7999e2a144..26f61e25da4d3 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1348,14 +1348,10 @@ private[spark] object Utils extends Logging { try { finallyBlock } catch { - case t: Throwable => - if (originalThrowable != null) { - originalThrowable.addSuppressed(t) - logWarning(s"Suppressing exception in finally: " + t.getMessage, t) - throw originalThrowable - } else { - throw t - } + case t: Throwable if (originalThrowable != null && originalThrowable != t) => + originalThrowable.addSuppressed(t) + logWarning(s"Suppressing exception in finally: ${t.getMessage}", t) + throw originalThrowable } } } @@ -1387,22 +1383,20 @@ private[spark] object Utils extends Logging { catchBlock } catch { case t: Throwable => - originalThrowable.addSuppressed(t) - logWarning(s"Suppressing exception in catch: " + t.getMessage, t) + if (originalThrowable != t) { + originalThrowable.addSuppressed(t) + logWarning(s"Suppressing exception in catch: ${t.getMessage}", t) + } } throw originalThrowable } finally { try { finallyBlock } catch { - case t: Throwable => - if (originalThrowable != null) { - originalThrowable.addSuppressed(t) - logWarning(s"Suppressing exception in finally: " + t.getMessage, t) - throw originalThrowable - } else { - throw t - } + case t: Throwable if (originalThrowable != null && originalThrowable != t) => + originalThrowable.addSuppressed(t) + logWarning(s"Suppressing exception in finally: ${t.getMessage}", t) + throw originalThrowable } } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index f7bc8f888b0d5..4ce143f18bbf1 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -38,7 +38,7 @@ import org.apache.commons.math3.stat.inference.ChiSquareTest import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit @@ -1024,4 +1024,90 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(redactedConf("spark.sensitive.property") === Utils.REDACTION_REPLACEMENT_TEXT) } + + test("tryWithSafeFinally") { + var e = new Error("Block0") + val finallyBlockError = new Error("Finally Block") + var isErrorOccurred = false + // if the try and finally blocks throw different exception instances + try { + Utils.tryWithSafeFinally { throw e }(finallyBlock = { throw finallyBlockError }) + } catch { + case t: Error => + assert(t.getSuppressed.head == finallyBlockError) + isErrorOccurred = true + } + assert(isErrorOccurred) + // if the try and finally blocks throw the same exception instance then it should not + // try to add to suppressed and get IllegalArgumentException + e = new Error("Block1") + isErrorOccurred = false + try { + Utils.tryWithSafeFinally { throw e }(finallyBlock = { throw e }) + } catch { + case t: Error => + assert(t.getSuppressed.length == 0) + isErrorOccurred = true + } + assert(isErrorOccurred) + // if the try throws the exception and finally doesn't throw exception + e = new Error("Block2") + isErrorOccurred = false + try { + Utils.tryWithSafeFinally { throw e }(finallyBlock = {}) + } catch { + case t: Error => + assert(t.getSuppressed.length == 0) + isErrorOccurred = true + } + assert(isErrorOccurred) + // if the try and finally block don't throw exception + Utils.tryWithSafeFinally {}(finallyBlock = {}) + } + + test("tryWithSafeFinallyAndFailureCallbacks") { + var e = new Error("Block0") + val catchBlockError = new Error("Catch Block") + val finallyBlockError = new Error("Finally Block") + var isErrorOccurred = false + TaskContext.setTaskContext(TaskContext.empty()) + // if the try, catch and finally blocks throw different exception instances + try { + Utils.tryWithSafeFinallyAndFailureCallbacks { throw e }( + catchBlock = { throw catchBlockError }, finallyBlock = { throw finallyBlockError }) + } catch { + case t: Error => + assert(t.getSuppressed.head == catchBlockError) + assert(t.getSuppressed.last == finallyBlockError) + isErrorOccurred = true + } + assert(isErrorOccurred) + // if the try, catch and finally blocks throw the same exception instance then it should not + // try to add to suppressed and get IllegalArgumentException + e = new Error("Block1") + isErrorOccurred = false + try { + Utils.tryWithSafeFinallyAndFailureCallbacks { throw e }(catchBlock = { throw e }, + finallyBlock = { throw e }) + } catch { + case t: Error => + assert(t.getSuppressed.length == 0) + isErrorOccurred = true + } + assert(isErrorOccurred) + // if the try throws the exception, catch and finally don't throw exceptions + e = new Error("Block2") + isErrorOccurred = false + try { + Utils.tryWithSafeFinallyAndFailureCallbacks { throw e }(catchBlock = {}, finallyBlock = {}) + } catch { + case t: Error => + assert(t.getSuppressed.length == 0) + isErrorOccurred = true + } + assert(isErrorOccurred) + // if the try, catch and finally blocks don't throw exceptions + Utils.tryWithSafeFinallyAndFailureCallbacks {}(catchBlock = {}, finallyBlock = {}) + TaskContext.unset + } } From c605fee01f180588ecb2f48710a7b84073bd3b9a Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Sun, 2 Jul 2017 08:50:48 +0100 Subject: [PATCH 066/779] [SPARK-21260][SQL][MINOR] Remove the unused OutputFakerExec ## What changes were proposed in this pull request? OutputFakerExec was added long ago and is not used anywhere now so we should remove it. ## How was this patch tested? N/A Author: Xingbo Jiang Closes #18473 from jiangxb1987/OutputFakerExec. --- .../spark/sql/execution/basicPhysicalOperators.scala | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index f3ca8397047fe..2151c339b9b87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -584,17 +584,6 @@ case class CoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecN } } -/** - * A plan node that does nothing but lie about the output of its child. Used to spice a - * (hopefully structurally equivalent) tree from a different optimization sequence into an already - * resolved tree. - */ -case class OutputFakerExec(output: Seq[Attribute], child: SparkPlan) extends SparkPlan { - def children: Seq[SparkPlan] = child :: Nil - - protected override def doExecute(): RDD[InternalRow] = child.execute() -} - /** * Physical plan for a subquery. */ From c19680be1c532dded1e70edce7a981ba28af09ad Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 2 Jul 2017 16:17:03 +0800 Subject: [PATCH 067/779] [SPARK-19852][PYSPARK][ML] Python StringIndexer supports 'keep' to handle invalid data ## What changes were proposed in this pull request? This PR is to maintain API parity with changes made in SPARK-17498 to support a new option 'keep' in StringIndexer to handle unseen labels or NULL values with PySpark. Note: This is updated version of #17237 , the primary author of this PR is VinceShieh . ## How was this patch tested? Unit tests. Author: VinceShieh Author: Yanbo Liang Closes #18453 from yanboliang/spark-19852. --- python/pyspark/ml/feature.py | 6 ++++++ python/pyspark/ml/tests.py | 21 +++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 77de1cc18246d..25ad06f682ed9 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2132,6 +2132,12 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.", typeConverter=TypeConverters.toString) + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " + + "labels or NULL values). Options are 'skip' (filter out rows with " + + "invalid data), error (throw an error), or 'keep' (put invalid data " + + "in a special additional bucket, at index numLabels).", + typeConverter=TypeConverters.toString) + @keyword_only def __init__(self, inputCol=None, outputCol=None, handleInvalid="error", stringOrderType="frequencyDesc"): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 17a39472e1fe5..ffb8b0a890ff8 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -551,6 +551,27 @@ def test_rformula_string_indexer_order_type(self): for i in range(0, len(expected)): self.assertTrue(all(observed[i]["features"].toArray() == expected[i])) + def test_string_indexer_handle_invalid(self): + df = self.spark.createDataFrame([ + (0, "a"), + (1, "d"), + (2, None)], ["id", "label"]) + + si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep", + stringOrderType="alphabetAsc") + model1 = si1.fit(df) + td1 = model1.transform(df) + actual1 = td1.select("id", "indexed").collect() + expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)] + self.assertEqual(actual1, expected1) + + si2 = si1.setHandleInvalid("skip") + model2 = si2.fit(df) + td2 = model2.transform(df) + actual2 = td2.select("id", "indexed").collect() + expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)] + self.assertEqual(actual2, expected2) + class HasInducedError(Params): From d4107196d59638845bd19da6aab074424d90ddaf Mon Sep 17 00:00:00 2001 From: Rui Zha Date: Sun, 2 Jul 2017 17:37:47 -0700 Subject: [PATCH 068/779] [SPARK-18004][SQL] Make sure the date or timestamp related predicate can be pushed down to Oracle correctly ## What changes were proposed in this pull request? Move `compileValue` method in JDBCRDD to JdbcDialect, and override the `compileValue` method in OracleDialect to rewrite the Oracle-specific timestamp and date literals in where clause. ## How was this patch tested? An integration test has been added. Author: Rui Zha Author: Zharui Closes #18451 from SharpRay/extend-compileValue-to-dialects. --- .../sql/jdbc/OracleIntegrationSuite.scala | 45 +++++++++++++++++++ .../execution/datasources/jdbc/JDBCRDD.scala | 35 +++++---------- .../apache/spark/sql/jdbc/JdbcDialects.scala | 27 ++++++++++- .../apache/spark/sql/jdbc/OracleDialect.scala | 15 ++++++- 4 files changed, 95 insertions(+), 27 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index b2f096964427e..e14810a32edc6 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -223,4 +223,49 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo val types = rows(0).toSeq.map(x => x.getClass.toString) assert(types(1).equals("class java.sql.Timestamp")) } + + test("SPARK-18004: Make sure date or timestamp related predicate is pushed down correctly") { + val props = new Properties() + props.put("oracle.jdbc.mapDateToTimestamp", "false") + + val schema = StructType(Seq( + StructField("date_type", DateType, true), + StructField("timestamp_type", TimestampType, true) + )) + + val tableName = "test_date_timestamp_pushdown" + val dateVal = Date.valueOf("2017-06-22") + val timestampVal = Timestamp.valueOf("2017-06-22 21:30:07") + + val data = spark.sparkContext.parallelize(Seq( + Row(dateVal, timestampVal) + )) + + val dfWrite = spark.createDataFrame(data, schema) + dfWrite.write.jdbc(jdbcUrl, tableName, props) + + val dfRead = spark.read.jdbc(jdbcUrl, tableName, props) + + val millis = System.currentTimeMillis() + val dt = new java.sql.Date(millis) + val ts = new java.sql.Timestamp(millis) + + // Query Oracle table with date and timestamp predicates + // which should be pushed down to Oracle. + val df = dfRead.filter(dfRead.col("date_type").lt(dt)) + .filter(dfRead.col("timestamp_type").lt(ts)) + + val metadata = df.queryExecution.sparkPlan.metadata + // The "PushedFilters" part should be exist in Datafrome's + // physical plan and the existence of right literals in + // "PushedFilters" is used to prove that the predicates + // pushing down have been effective. + assert(metadata.get("PushedFilters").ne(None)) + assert(metadata("PushedFilters").contains(dt.toString)) + assert(metadata("PushedFilters").contains(ts.toString)) + + val row = df.collect()(0) + assert(row.getDate(0).equals(dateVal)) + assert(row.getTimestamp(1).equals(timestampVal)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 2bdc43254133e..0f53b5c7c6f0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, Date, PreparedStatement, ResultSet, SQLException, Timestamp} +import java.sql.{Connection, PreparedStatement, ResultSet, SQLException} import scala.util.control.NonFatal -import org.apache.commons.lang3.StringUtils - import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -86,20 +84,6 @@ object JDBCRDD extends Logging { new StructType(columns.map(name => fieldMap(name))) } - /** - * Converts value to SQL expression. - */ - private def compileValue(value: Any): Any = value match { - case stringValue: String => s"'${escapeSql(stringValue)}'" - case timestampValue: Timestamp => "'" + timestampValue + "'" - case dateValue: Date => "'" + dateValue + "'" - case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ") - case _ => value - } - - private def escapeSql(value: String): String = - if (value == null) null else StringUtils.replace(value, "'", "''") - /** * Turns a single Filter into a String representing a SQL expression. * Returns None for an unhandled filter. @@ -108,15 +92,16 @@ object JDBCRDD extends Logging { def quote(colName: String): String = dialect.quoteIdentifier(colName) Option(f match { - case EqualTo(attr, value) => s"${quote(attr)} = ${compileValue(value)}" + case EqualTo(attr, value) => s"${quote(attr)} = ${dialect.compileValue(value)}" case EqualNullSafe(attr, value) => val col = quote(attr) - s"(NOT ($col != ${compileValue(value)} OR $col IS NULL OR " + - s"${compileValue(value)} IS NULL) OR ($col IS NULL AND ${compileValue(value)} IS NULL))" - case LessThan(attr, value) => s"${quote(attr)} < ${compileValue(value)}" - case GreaterThan(attr, value) => s"${quote(attr)} > ${compileValue(value)}" - case LessThanOrEqual(attr, value) => s"${quote(attr)} <= ${compileValue(value)}" - case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= ${compileValue(value)}" + s"(NOT ($col != ${dialect.compileValue(value)} OR $col IS NULL OR " + + s"${dialect.compileValue(value)} IS NULL) OR " + + s"($col IS NULL AND ${dialect.compileValue(value)} IS NULL))" + case LessThan(attr, value) => s"${quote(attr)} < ${dialect.compileValue(value)}" + case GreaterThan(attr, value) => s"${quote(attr)} > ${dialect.compileValue(value)}" + case LessThanOrEqual(attr, value) => s"${quote(attr)} <= ${dialect.compileValue(value)}" + case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= ${dialect.compileValue(value)}" case IsNull(attr) => s"${quote(attr)} IS NULL" case IsNotNull(attr) => s"${quote(attr)} IS NOT NULL" case StringStartsWith(attr, value) => s"${quote(attr)} LIKE '${value}%'" @@ -124,7 +109,7 @@ object JDBCRDD extends Logging { case StringContains(attr, value) => s"${quote(attr)} LIKE '%${value}%'" case In(attr, value) if value.isEmpty => s"CASE WHEN ${quote(attr)} IS NULL THEN NULL ELSE FALSE END" - case In(attr, value) => s"${quote(attr)} IN (${compileValue(value)})" + case In(attr, value) => s"${quote(attr)} IN (${dialect.compileValue(value)})" case Not(f) => compileFilter(f, dialect).map(p => s"(NOT ($p))").getOrElse(null) case Or(f1, f2) => // We can't compile Or filter unless both sub-filters are compiled successfully. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index a86a86d408906..7c38ed68c0413 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.jdbc -import java.sql.Connection +import java.sql.{Connection, Date, Timestamp} + +import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} import org.apache.spark.sql.types._ @@ -123,6 +125,29 @@ abstract class JdbcDialect extends Serializable { def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { } + /** + * Escape special characters in SQL string literals. + * @param value The string to be escaped. + * @return Escaped string. + */ + @Since("2.3.0") + protected[jdbc] def escapeSql(value: String): String = + if (value == null) null else StringUtils.replace(value, "'", "''") + + /** + * Converts value to SQL expression. + * @param value The value to be converted. + * @return Converted value. + */ + @Since("2.3.0") + def compileValue(value: Any): Any = value match { + case stringValue: String => s"'${escapeSql(stringValue)}'" + case timestampValue: Timestamp => "'" + timestampValue + "'" + case dateValue: Date => "'" + dateValue + "'" + case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ") + case _ => value + } + /** * Return Some[true] iff `TRUNCATE TABLE` causes cascading default. * Some[true] : TRUNCATE TABLE causes cascading. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 20e634c06b610..3b44c1de93a61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.jdbc -import java.sql.Types +import java.sql.{Date, Timestamp, Types} import org.apache.spark.sql.types._ @@ -64,5 +64,18 @@ private case object OracleDialect extends JdbcDialect { case _ => None } + override def compileValue(value: Any): Any = value match { + // The JDBC drivers support date literals in SQL statements written in the + // format: {d 'yyyy-mm-dd'} and timestamp literals in SQL statements written + // in the format: {ts 'yyyy-mm-dd hh:mm:ss.f...'}. For details, see + // 'Oracle Database JDBC Developer’s Guide and Reference, 11g Release 1 (11.1)' + // Appendix A Reference Information. + case stringValue: String => s"'${escapeSql(stringValue)}'" + case timestampValue: Timestamp => "{ts '" + timestampValue + "'}" + case dateValue: Date => "{d '" + dateValue + "'}" + case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ") + case _ => value + } + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } From d913db16a0de0983961f9d0c5f9b146be7226ac1 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Mon, 3 Jul 2017 13:31:01 +0800 Subject: [PATCH 069/779] [SPARK-21250][WEB-UI] Add a url in the table of 'Running Executors' in worker page to visit job page. ## What changes were proposed in this pull request? Add a url in the table of 'Running Executors' in worker page to visit job page. When I click URL of 'Name', the current page jumps to the job page. Of course this is only in the table of 'Running Executors'. This URL of 'Name' is in the table of 'Finished Executors' does not exist, the click will not jump to any page. fix before: ![1](https://user-images.githubusercontent.com/26266482/27679397-30ddc262-5ceb-11e7-839b-0889d1f42480.png) fix after: ![2](https://user-images.githubusercontent.com/26266482/27679405-3588ef12-5ceb-11e7-9756-0a93815cd698.png) ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #18464 from guoxiaolongzte/SPARK-21250. --- .../apache/spark/deploy/worker/ui/WorkerPage.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 1ad973122b609..ea39b0dce0a41 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -23,8 +23,8 @@ import scala.xml.Node import org.json4s.JValue +import org.apache.spark.deploy.{ExecutorState, JsonProtocol} import org.apache.spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse} -import org.apache.spark.deploy.JsonProtocol import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} import org.apache.spark.ui.{UIUtils, WebUIPage} @@ -112,7 +112,15 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") {
  • ID: {executor.appId}
  • -
  • Name: {executor.appDesc.name}
  • +
  • Name: + { + if ({executor.state == ExecutorState.RUNNING} && executor.appDesc.appUiUrl.nonEmpty) { + {executor.appDesc.name} + } else { + {executor.appDesc.name} + } + } +
  • User: {executor.appDesc.user}
From a9339db99f0620d4828eb903523be55dfbf2fb64 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 3 Jul 2017 19:52:39 +0800 Subject: [PATCH 070/779] [SPARK-21137][CORE] Spark reads many small files slowly ## What changes were proposed in this pull request? Parallelize FileInputFormat.listStatus in Hadoop API via LIST_STATUS_NUM_THREADS to speed up examination of file sizes for wholeTextFiles et al ## How was this patch tested? Existing tests, which will exercise the key path here: using a local file system. Author: Sean Owen Closes #18441 from srowen/SPARK-21137. --- .../main/scala/org/apache/spark/rdd/BinaryFileRDD.scala | 7 ++++++- .../main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index 50d977a92da51..a14bad47dfe10 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -20,6 +20,7 @@ package org.apache.spark.rdd import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.task.JobContextImpl import org.apache.spark.{Partition, SparkContext} @@ -35,8 +36,12 @@ private[spark] class BinaryFileRDD[T]( extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) { override def getPartitions: Array[Partition] = { - val inputFormat = inputFormatClass.newInstance val conf = getConf + // setMinPartitions below will call FileInputFormat.listStatus(), which can be quite slow when + // traversing a large number of directories and files. Parallelize it. + conf.setIfUnset(FileInputFormat.LIST_STATUS_NUM_THREADS, + Runtime.getRuntime.availableProcessors().toString) + val inputFormat = inputFormatClass.newInstance inputFormat match { case configurable: Configurable => configurable.setConf(conf) diff --git a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala index 8e1baae796fc5..9f3d0745c33c9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala @@ -20,6 +20,7 @@ package org.apache.spark.rdd import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.{Text, Writable} import org.apache.hadoop.mapreduce.InputSplit +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.task.JobContextImpl import org.apache.spark.{Partition, SparkContext} @@ -38,8 +39,12 @@ private[spark] class WholeTextFileRDD( extends NewHadoopRDD[Text, Text](sc, inputFormatClass, keyClass, valueClass, conf) { override def getPartitions: Array[Partition] = { - val inputFormat = inputFormatClass.newInstance val conf = getConf + // setMinPartitions below will call FileInputFormat.listStatus(), which can be quite slow when + // traversing a large number of directories and files. Parallelize it. + conf.setIfUnset(FileInputFormat.LIST_STATUS_NUM_THREADS, + Runtime.getRuntime.availableProcessors().toString) + val inputFormat = inputFormatClass.newInstance inputFormat match { case configurable: Configurable => configurable.setConf(conf) From eb7a5a66bbd5837c01f13c76b68de2a6034976f3 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Mon, 3 Jul 2017 09:01:42 -0700 Subject: [PATCH 071/779] [TEST] Load test table based on case sensitivity ## What changes were proposed in this pull request? It is strange that we will get "table not found" error if **the first sql** uses upper case table names, when developers write tests with `TestHiveSingleton`, **although case insensitivity**. This is because in `TestHiveQueryExecution`, test tables are loaded based on exact matching instead of case sensitivity. ## How was this patch tested? Added a new test case. Author: Zhenhua Wang Closes #18504 from wzhfy/testHive. --- .../apache/spark/sql/hive/test/TestHive.scala | 7 ++- .../apache/spark/sql/hive/TestHiveSuite.scala | 45 +++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 4e1792321c89b..801f9b9923641 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -449,6 +449,8 @@ private[hive] class TestHiveSparkSession( private val loadedTables = new collection.mutable.HashSet[String] + def getLoadedTables: collection.mutable.HashSet[String] = loadedTables + def loadTestTable(name: String) { if (!(loadedTables contains name)) { // Marks the table as loaded first to prevent infinite mutually recursive table loading. @@ -553,7 +555,10 @@ private[hive] class TestHiveQueryExecution( val referencedTables = describedTables ++ logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.table } - val referencedTestTables = referencedTables.filter(sparkSession.testTables.contains) + val resolver = sparkSession.sessionState.conf.resolver + val referencedTestTables = sparkSession.testTables.keys.filter { testTable => + referencedTables.exists(resolver(_, testTable)) + } logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(sparkSession.loadTestTable) // Proceed with analysis. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala new file mode 100644 index 0000000000000..193fa83dbad99 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.hive.test.{TestHiveSingleton, TestHiveSparkSession} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils + + +class TestHiveSuite extends TestHiveSingleton with SQLTestUtils { + test("load test table based on case sensitivity") { + val testHiveSparkSession = spark.asInstanceOf[TestHiveSparkSession] + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + sql("SELECT * FROM SRC").queryExecution.analyzed + assert(testHiveSparkSession.getLoadedTables.contains("src")) + assert(testHiveSparkSession.getLoadedTables.size == 1) + } + testHiveSparkSession.reset() + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val err = intercept[AnalysisException] { + sql("SELECT * FROM SRC").queryExecution.analyzed + } + assert(err.message.contains("Table or view not found")) + } + testHiveSparkSession.reset() + } +} From 17bdc36ef16a544b693c628db276fe32db87fe7a Mon Sep 17 00:00:00 2001 From: aokolnychyi Date: Mon, 3 Jul 2017 09:35:49 -0700 Subject: [PATCH 072/779] [SPARK-21102][SQL] Refresh command is too aggressive in parsing ### Idea This PR adds validation to REFRESH sql statements. Currently, users can specify whatever they want as resource path. For example, spark.sql("REFRESH ! $ !") will be executed without any exceptions. ### Implementation I am not sure that my current implementation is the most optimal, so any feedback is appreciated. My first idea was to make the grammar as strict as possible. Unfortunately, there were some problems. I tried the approach below: SqlBase.g4 ``` ... | REFRESH TABLE tableIdentifier #refreshTable | REFRESH resourcePath #refreshResource ... resourcePath : STRING | (IDENTIFIER | number | nonReserved | '/' | '-')+ // other symbols can be added if needed ; ``` It is not flexible enough and requires to explicitly mention all possible symbols. Therefore, I came up with the current approach that is implemented in the code. Let me know your opinion on which one is better. Author: aokolnychyi Closes #18368 from aokolnychyi/spark-21102. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../spark/sql/execution/SparkSqlParser.scala | 20 +++++++++++++++--- .../sql/execution/SparkSqlParserSuite.scala | 21 ++++++++++++++++++- 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 7ffa150096333..29f554451ed4a 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -149,7 +149,7 @@ statement | (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)? tableIdentifier partitionSpec? describeColName? #describeTable | REFRESH TABLE tableIdentifier #refreshTable - | REFRESH .*? #refreshResource + | REFRESH (STRING | .*?) #refreshResource | CACHE LAZY? TABLE tableIdentifier (AS? query)? #cacheTable | UNCACHE TABLE (IF EXISTS)? tableIdentifier #uncacheTable | CLEAR CACHE #clearCache diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 3c58c6e1b6780..2b79eb5eac0f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -230,11 +230,25 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } /** - * Create a [[RefreshTable]] logical plan. + * Create a [[RefreshResource]] logical plan. */ override def visitRefreshResource(ctx: RefreshResourceContext): LogicalPlan = withOrigin(ctx) { - val resourcePath = remainder(ctx.REFRESH.getSymbol).trim - RefreshResource(resourcePath) + val path = if (ctx.STRING != null) string(ctx.STRING) else extractUnquotedResourcePath(ctx) + RefreshResource(path) + } + + private def extractUnquotedResourcePath(ctx: RefreshResourceContext): String = withOrigin(ctx) { + val unquotedPath = remainder(ctx.REFRESH.getSymbol).trim + validate( + unquotedPath != null && !unquotedPath.isEmpty, + "Resource paths cannot be empty in REFRESH statements. Use / to match everything", + ctx) + val forbiddenSymbols = Seq(" ", "\n", "\r", "\t") + validate( + !forbiddenSymbols.exists(unquotedPath.contains(_)), + "REFRESH statements cannot contain ' ', '\\n', '\\r', '\\t' inside unquoted resource paths", + ctx) + unquotedPath } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index bd9c2ebd6fab9..d238c76fbeeff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Concat, SortOrder} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, RepartitionByExpression, Sort} import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources.CreateTable +import org.apache.spark.sql.execution.datasources.{CreateTable, RefreshResource} import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} @@ -66,6 +66,25 @@ class SparkSqlParserSuite extends AnalysisTest { } } + test("refresh resource") { + assertEqual("REFRESH prefix_path", RefreshResource("prefix_path")) + assertEqual("REFRESH /", RefreshResource("/")) + assertEqual("REFRESH /path///a", RefreshResource("/path///a")) + assertEqual("REFRESH pat1h/112/_1a", RefreshResource("pat1h/112/_1a")) + assertEqual("REFRESH pat1h/112/_1a/a-1", RefreshResource("pat1h/112/_1a/a-1")) + assertEqual("REFRESH path-with-dash", RefreshResource("path-with-dash")) + assertEqual("REFRESH \'path with space\'", RefreshResource("path with space")) + assertEqual("REFRESH \"path with space 2\"", RefreshResource("path with space 2")) + intercept("REFRESH a b", "REFRESH statements cannot contain") + intercept("REFRESH a\tb", "REFRESH statements cannot contain") + intercept("REFRESH a\nb", "REFRESH statements cannot contain") + intercept("REFRESH a\rb", "REFRESH statements cannot contain") + intercept("REFRESH a\r\nb", "REFRESH statements cannot contain") + intercept("REFRESH @ $a$", "REFRESH statements cannot contain") + intercept("REFRESH ", "Resource paths cannot be empty in REFRESH statements") + intercept("REFRESH", "Resource paths cannot be empty in REFRESH statements") + } + test("show functions") { assertEqual("show functions", ShowFunctionsCommand(None, None, true, true)) assertEqual("show all functions", ShowFunctionsCommand(None, None, true, true)) From 363bfe30ba44852a8fac946a37032f76480f6f1b Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 3 Jul 2017 10:14:03 -0700 Subject: [PATCH 073/779] [SPARK-20073][SQL] Prints an explicit warning message in case of NULL-safe equals ## What changes were proposed in this pull request? This pr added code to print the same warning messages with `===` cases when using NULL-safe equals (`<=>`). ## How was this patch tested? Existing tests. Author: Takeshi Yamamuro Closes #18436 from maropu/SPARK-20073. --- .../src/main/scala/org/apache/spark/sql/Column.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 7e1f1d83cb3de..bd1669b6dba69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -464,7 +464,15 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def <=> (other: Any): Column = withExpr { EqualNullSafe(expr, lit(other).expr) } + def <=> (other: Any): Column = withExpr { + val right = lit(other).expr + if (this.expr == right) { + logWarning( + s"Constructing trivially true equals predicate, '${this.expr} <=> $right'. " + + "Perhaps you need to use aliases.") + } + EqualNullSafe(expr, right) + } /** * Equality test that is safe for null values. From f953ca56eccdaef29ac580d44613a028415ba3f5 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 3 Jul 2017 10:51:44 -0700 Subject: [PATCH 074/779] [SPARK-21284][SQL] rename SessionCatalog.registerFunction parameter name ## What changes were proposed in this pull request? Looking at the code in `SessionCatalog.registerFunction`, the parameter `ignoreIfExists` is a wrong name. When `ignoreIfExists` is true, we will override the function if it already exists. So `overrideIfExists` should be the corrected name. ## How was this patch tested? N/A Author: Wenchen Fan Closes #18510 from cloud-fan/minor. --- .../sql/catalyst/catalog/SessionCatalog.scala | 6 +++--- .../catalog/SessionCatalogSuite.scala | 20 ++++++++++--------- .../sql/execution/command/functions.scala | 3 +-- .../spark/sql/internal/CatalogSuite.scala | 2 +- .../spark/sql/hive/HiveSessionCatalog.scala | 2 +- .../ObjectHashAggregateExecBenchmark.scala | 2 +- 6 files changed, 18 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 7ece77df7fc14..a86604e4353ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1104,10 +1104,10 @@ class SessionCatalog( */ def registerFunction( funcDefinition: CatalogFunction, - ignoreIfExists: Boolean, + overrideIfExists: Boolean, functionBuilder: Option[FunctionBuilder] = None): Unit = { val func = funcDefinition.identifier - if (functionRegistry.functionExists(func) && !ignoreIfExists) { + if (functionRegistry.functionExists(func) && !overrideIfExists) { throw new AnalysisException(s"Function $func already exists") } val info = new ExpressionInfo(funcDefinition.className, func.database.orNull, func.funcName) @@ -1219,7 +1219,7 @@ class SessionCatalog( // catalog. So, it is possible that qualifiedName is not exactly the same as // catalogFunction.identifier.unquotedString (difference is on case-sensitivity). // At here, we preserve the input from the user. - registerFunction(catalogFunction.copy(identifier = qualifiedName), ignoreIfExists = false) + registerFunction(catalogFunction.copy(identifier = qualifiedName), overrideIfExists = false) // Now, we need to create the Expression. functionRegistry.lookupFunction(qualifiedName, children) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index fc3893e197792..8f856a0daad15 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -1175,9 +1175,9 @@ abstract class SessionCatalogSuite extends AnalysisTest { val tempFunc1 = (e: Seq[Expression]) => e.head val tempFunc2 = (e: Seq[Expression]) => e.last catalog.registerFunction( - newFunc("temp1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc1)) + newFunc("temp1", None), overrideIfExists = false, functionBuilder = Some(tempFunc1)) catalog.registerFunction( - newFunc("temp2", None), ignoreIfExists = false, functionBuilder = Some(tempFunc2)) + newFunc("temp2", None), overrideIfExists = false, functionBuilder = Some(tempFunc2)) val arguments = Seq(Literal(1), Literal(2), Literal(3)) assert(catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(1)) assert(catalog.lookupFunction(FunctionIdentifier("temp2"), arguments) === Literal(3)) @@ -1189,12 +1189,12 @@ abstract class SessionCatalogSuite extends AnalysisTest { // Temporary function already exists val e = intercept[AnalysisException] { catalog.registerFunction( - newFunc("temp1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc3)) + newFunc("temp1", None), overrideIfExists = false, functionBuilder = Some(tempFunc3)) }.getMessage assert(e.contains("Function temp1 already exists")) // Temporary function is overridden catalog.registerFunction( - newFunc("temp1", None), ignoreIfExists = true, functionBuilder = Some(tempFunc3)) + newFunc("temp1", None), overrideIfExists = true, functionBuilder = Some(tempFunc3)) assert( catalog.lookupFunction( FunctionIdentifier("temp1"), arguments) === Literal(arguments.length)) @@ -1208,7 +1208,7 @@ abstract class SessionCatalogSuite extends AnalysisTest { val tempFunc1 = (e: Seq[Expression]) => e.head catalog.registerFunction( - newFunc("temp1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc1)) + newFunc("temp1", None), overrideIfExists = false, functionBuilder = Some(tempFunc1)) // Returns true when the function is temporary assert(catalog.isTemporaryFunction(FunctionIdentifier("temp1"))) @@ -1259,7 +1259,7 @@ abstract class SessionCatalogSuite extends AnalysisTest { withBasicCatalog { catalog => val tempFunc = (e: Seq[Expression]) => e.head catalog.registerFunction( - newFunc("func1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc)) + newFunc("func1", None), overrideIfExists = false, functionBuilder = Some(tempFunc)) val arguments = Seq(Literal(1), Literal(2), Literal(3)) assert(catalog.lookupFunction(FunctionIdentifier("func1"), arguments) === Literal(1)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) @@ -1300,7 +1300,7 @@ abstract class SessionCatalogSuite extends AnalysisTest { withBasicCatalog { catalog => val tempFunc1 = (e: Seq[Expression]) => e.head catalog.registerFunction( - newFunc("func1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc1)) + newFunc("func1", None), overrideIfExists = false, functionBuilder = Some(tempFunc1)) assert(catalog.lookupFunction( FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) == Literal(1)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) @@ -1318,8 +1318,10 @@ abstract class SessionCatalogSuite extends AnalysisTest { val tempFunc2 = (e: Seq[Expression]) => e.last catalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) catalog.createFunction(newFunc("not_me", Some("db2")), ignoreIfExists = false) - catalog.registerFunction(funcMeta1, ignoreIfExists = false, functionBuilder = Some(tempFunc1)) - catalog.registerFunction(funcMeta2, ignoreIfExists = false, functionBuilder = Some(tempFunc2)) + catalog.registerFunction( + funcMeta1, overrideIfExists = false, functionBuilder = Some(tempFunc1)) + catalog.registerFunction( + funcMeta2, overrideIfExists = false, functionBuilder = Some(tempFunc2)) assert(catalog.listFunctions("db1", "*").map(_._1).toSet == Set(FunctionIdentifier("func1"), FunctionIdentifier("yes_me"))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index f39a3269efaf1..a91ad413f4d1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -58,9 +58,8 @@ case class CreateFunctionCommand( s"is not allowed: '${databaseName.get}'") } // We first load resources and then put the builder in the function registry. - // Please note that it is allowed to overwrite an existing temp function. catalog.loadFunctionResources(resources) - catalog.registerFunction(func, ignoreIfExists = false) + catalog.registerFunction(func, overrideIfExists = false) } else { // For a permanent, we will store the metadata into underlying external catalog. // This function will be loaded into the FunctionRegistry when a query uses it. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index b2d568ce320e6..6acac1a9aa317 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -79,7 +79,7 @@ class CatalogSuite val tempFunc = (e: Seq[Expression]) => e.head val funcMeta = CatalogFunction(FunctionIdentifier(name, None), "className", Nil) sessionCatalog.registerFunction( - funcMeta, ignoreIfExists = false, functionBuilder = Some(tempFunc)) + funcMeta, overrideIfExists = false, functionBuilder = Some(tempFunc)) } private def dropFunction(name: String, db: Option[String] = None): Unit = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index da87f0218e3ad..0d0269f694300 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -161,7 +161,7 @@ private[sql] class HiveSessionCatalog( FunctionIdentifier(functionName.toLowerCase(Locale.ROOT), database) val func = CatalogFunction(functionIdentifier, className, Nil) // Put this Hive built-in function to our function registry. - registerFunction(func, ignoreIfExists = false) + registerFunction(func, overrideIfExists = false) // Now, we need to create the Expression. functionRegistry.lookupFunction(functionIdentifier, children) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala index 73383ae4d4118..e599d1ab1d486 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala @@ -221,7 +221,7 @@ class ObjectHashAggregateExecBenchmark extends BenchmarkBase with TestHiveSingle val sessionCatalog = sparkSession.sessionState.catalog.asInstanceOf[HiveSessionCatalog] val functionIdentifier = FunctionIdentifier(functionName, database = None) val func = CatalogFunction(functionIdentifier, clazz.getName, resources = Nil) - sessionCatalog.registerFunction(func, ignoreIfExists = false) + sessionCatalog.registerFunction(func, overrideIfExists = false) } private def percentile_approx( From c79c10ebaf3d63b697b8d6d1a7e55aa2d406af69 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 3 Jul 2017 16:18:54 -0700 Subject: [PATCH 075/779] [TEST] Different behaviors of SparkContext Conf when building SparkSession ## What changes were proposed in this pull request? If the created ACTIVE sparkContext is not EXPLICITLY passed through the Builder's API `sparkContext()`, the conf of this sparkContext will also contain the conf set through the API `config()`; otherwise, the conf of this sparkContext will NOT contain the conf set through the API `config()` ## How was this patch tested? N/A Author: gatorsmile Closes #18517 from gatorsmile/fixTestCase2. --- .../spark/sql/SparkSessionBuilderSuite.scala | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index 386d13d07a95f..4f6d5f79d466e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -98,12 +98,31 @@ class SparkSessionBuilderSuite extends SparkFunSuite { val session = SparkSession.builder().config("key2", "value2").getOrCreate() assert(session.conf.get("key1") == "value1") assert(session.conf.get("key2") == "value2") + assert(session.sparkContext == sparkContext2) assert(session.sparkContext.conf.get("key1") == "value1") + // If the created sparkContext is not passed through the Builder's API sparkContext, + // the conf of this sparkContext will also contain the conf set through the API config. assert(session.sparkContext.conf.get("key2") == "value2") assert(session.sparkContext.conf.get("spark.app.name") == "test") session.stop() } + test("create SparkContext first then pass context to SparkSession") { + sparkContext.stop() + val conf = new SparkConf().setAppName("test").setMaster("local").set("key1", "value1") + val newSC = new SparkContext(conf) + val session = SparkSession.builder().sparkContext(newSC).config("key2", "value2").getOrCreate() + assert(session.conf.get("key1") == "value1") + assert(session.conf.get("key2") == "value2") + assert(session.sparkContext == newSC) + assert(session.sparkContext.conf.get("key1") == "value1") + // If the created sparkContext is passed through the Builder's API sparkContext, + // the conf of this sparkContext will not contain the conf set through the API config. + assert(!session.sparkContext.conf.contains("key2")) + assert(session.sparkContext.conf.get("spark.app.name") == "test") + session.stop() + } + test("SPARK-15887: hive-site.xml should be loaded") { val session = SparkSession.builder().master("local").getOrCreate() assert(session.sessionState.newHadoopConf().get("hive.in.test") == "true") From 6657e00de36b59011d3fe78e8613fb64e54c957a Mon Sep 17 00:00:00 2001 From: liuxian Date: Tue, 4 Jul 2017 09:16:40 +0800 Subject: [PATCH 076/779] [SPARK-21283][CORE] FileOutputStream should be created as append mode ## What changes were proposed in this pull request? `FileAppender` is used to write `stderr` and `stdout` files in `ExecutorRunner`, But before writing `ErrorStream` into the the `stderr` file, the header information has been written into ,if FileOutputStream is not created as append mode, the header information will be lost ## How was this patch tested? unit test case Author: liuxian Closes #18507 from 10110346/wip-lx-0703. --- .../scala/org/apache/spark/util/logging/FileAppender.scala | 2 +- .../test/scala/org/apache/spark/util/FileAppenderSuite.scala | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala index fdb1495899bc3..8a0cc709bccc5 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala @@ -94,7 +94,7 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi /** Open the file output stream */ protected def openFile() { - outputStream = new FileOutputStream(file, false) + outputStream = new FileOutputStream(file, true) logDebug(s"Opened file $file") } diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 7e2da8e141532..cd0ed5b036bf9 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -52,10 +52,13 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { test("basic file appender") { val testString = (1 to 1000).mkString(", ") val inputStream = new ByteArrayInputStream(testString.getBytes(StandardCharsets.UTF_8)) + // The `header` should not be covered + val header = "Add header" + Files.write(header, testFile, StandardCharsets.UTF_8) val appender = new FileAppender(inputStream, testFile) inputStream.close() appender.awaitTermination() - assert(Files.toString(testFile, StandardCharsets.UTF_8) === testString) + assert(Files.toString(testFile, StandardCharsets.UTF_8) === header + testString) } test("rolling file appender - time-based rolling") { From a848d552ef6b5d0d3bb3b2da903478437a8b10aa Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 4 Jul 2017 11:35:08 +0900 Subject: [PATCH 077/779] [SPARK-21264][PYTHON] Call cross join path in join without 'on' and with 'how' ## What changes were proposed in this pull request? Currently, it throws a NPE when missing columns but join type is speicified in join at PySpark as below: ```python spark.conf.set("spark.sql.crossJoin.enabled", "false") spark.range(1).join(spark.range(1), how="inner").show() ``` ``` Traceback (most recent call last): ... py4j.protocol.Py4JJavaError: An error occurred while calling o66.join. : java.lang.NullPointerException at org.apache.spark.sql.Dataset.join(Dataset.scala:931) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) ... ``` ```python spark.conf.set("spark.sql.crossJoin.enabled", "true") spark.range(1).join(spark.range(1), how="inner").show() ``` ``` ... py4j.protocol.Py4JJavaError: An error occurred while calling o84.join. : java.lang.NullPointerException at org.apache.spark.sql.Dataset.join(Dataset.scala:931) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) ... ``` This PR suggests to follow Scala's one as below: ```scala scala> spark.conf.set("spark.sql.crossJoin.enabled", "false") scala> spark.range(1).join(spark.range(1), Seq.empty[String], "inner").show() ``` ``` org.apache.spark.sql.AnalysisException: Detected cartesian product for INNER join between logical plans Range (0, 1, step=1, splits=Some(8)) and Range (0, 1, step=1, splits=Some(8)) Join condition is missing or trivial. Use the CROSS JOIN syntax to allow cartesian products between these relations.; ... ``` ```scala scala> spark.conf.set("spark.sql.crossJoin.enabled", "true") scala> spark.range(1).join(spark.range(1), Seq.empty[String], "inner").show() ``` ``` +---+---+ | id| id| +---+---+ | 0| 0| +---+---+ ``` **After** ```python spark.conf.set("spark.sql.crossJoin.enabled", "false") spark.range(1).join(spark.range(1), how="inner").show() ``` ``` Traceback (most recent call last): ... pyspark.sql.utils.AnalysisException: u'Detected cartesian product for INNER join between logical plans\nRange (0, 1, step=1, splits=Some(8))\nand\nRange (0, 1, step=1, splits=Some(8))\nJoin condition is missing or trivial.\nUse the CROSS JOIN syntax to allow cartesian products between these relations.;' ``` ```python spark.conf.set("spark.sql.crossJoin.enabled", "true") spark.range(1).join(spark.range(1), how="inner").show() ``` ``` +---+---+ | id| id| +---+---+ | 0| 0| +---+---+ ``` ## How was this patch tested? Added tests in `python/pyspark/sql/tests.py`. Author: hyukjinkwon Closes #18484 from HyukjinKwon/SPARK-21264. --- python/pyspark/sql/dataframe.py | 2 ++ python/pyspark/sql/tests.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0649271ed2246..27a6dad8917d3 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -833,6 +833,8 @@ def join(self, other, on=None, how=None): else: if how is None: how = "inner" + if on is None: + on = self._jseq([]) assert isinstance(how, basestring), "how should be basestring" jdf = self._jdf.join(other._jdf, on, how) return DataFrame(jdf, self.sql_ctx) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 0a1cd6856b8e8..c105969b26b97 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2021,6 +2021,22 @@ def test_toDF_with_schema_string(self): self.assertEqual(df.schema.simpleString(), "struct") self.assertEqual(df.collect(), [Row(key=i) for i in range(100)]) + def test_join_without_on(self): + df1 = self.spark.range(1).toDF("a") + df2 = self.spark.range(1).toDF("b") + + try: + self.spark.conf.set("spark.sql.crossJoin.enabled", "false") + self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect()) + + self.spark.conf.set("spark.sql.crossJoin.enabled", "true") + actual = df1.join(df2, how="inner").collect() + expected = [Row(a=0, b=0)] + self.assertEqual(actual, expected) + finally: + # We should unset this. Otherwise, other tests are affected. + self.spark.conf.unset("spark.sql.crossJoin.enabled") + # Regression test for invalid join methods when on is None, Spark-14761 def test_invalid_join_method(self): df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"]) From 8ca4ebefa6301d9cb633ea15cf71f49c2d7f8607 Mon Sep 17 00:00:00 2001 From: Thomas Decaux Date: Tue, 4 Jul 2017 12:17:48 +0100 Subject: [PATCH 078/779] [MINOR] Add french stop word "les" ## What changes were proposed in this pull request? Added "les" as french stop word (plurial of le) Author: Thomas Decaux Closes #18514 from ebuildy/patch-1. --- .../resources/org/apache/spark/ml/feature/stopwords/french.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/french.txt b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/french.txt index 94b8f8f39a3e1..a59a0424616cc 100644 --- a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/french.txt +++ b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/french.txt @@ -15,6 +15,7 @@ il je la le +les leur lui ma @@ -152,4 +153,4 @@ eusses eût eussions eussiez -eussent \ No newline at end of file +eussent From 2b1e94b9add82b30bc94f639fa97492624bf0dce Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 4 Jul 2017 12:18:42 +0100 Subject: [PATCH 079/779] [MINOR][SPARK SUBMIT] Print out R file usage in spark-submit ## What changes were proposed in this pull request? Currently, running the shell below: ```bash $ ./bin/spark-submit tmp.R a b c ``` with R file, `tmp.R` as below: ```r #!/usr/bin/env Rscript library(SparkR) sparkRSQL.init(sparkR.init(master = "local")) collect(createDataFrame(list(list(1)))) print(commandArgs(trailingOnly = TRUE)) ``` working fine as below: ```bash _1 1 1 [1] "a" "b" "c" ``` However, it looks not printed in usage documentation as below: ```bash $ ./bin/spark-submit ``` ``` Usage: spark-submit [options] [app arguments] ... ``` For `./bin/sparkR`, it looks fine as below: ```bash $ ./bin/sparkR tmp.R ``` ``` Running R applications through 'sparkR' is not supported as of Spark 2.0. Use ./bin/spark-submit ``` Running the script below: ```bash $ ./bin/spark-submit ``` **Before** ``` Usage: spark-submit [options] [app arguments] ... ``` **After** ``` Usage: spark-submit [options] [app arguments] ... ``` ## How was this patch tested? Manually tested. Author: hyukjinkwon Closes #18505 from HyukjinKwon/minor-doc-summit. --- .../scala/org/apache/spark/deploy/SparkSubmitArguments.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 3d9a14c51618b..7800d3d624e3e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -504,7 +504,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S outStream.println("Unknown/unsupported param " + unknownParam) } val command = sys.env.get("_SPARK_CMD_USAGE").getOrElse( - """Usage: spark-submit [options] [app arguments] + """Usage: spark-submit [options] [app arguments] |Usage: spark-submit --kill [submission ID] --master [spark://...] |Usage: spark-submit --status [submission ID] --master [spark://...] |Usage: spark-submit run-example [options] example-class [example args]""".stripMargin) From d492cc5a21cd67b3999b85d97f5c41c3734b1ba3 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 4 Jul 2017 20:45:58 +0800 Subject: [PATCH 080/779] [SPARK-19507][SPARK-21296][PYTHON] Avoid per-record type dispatch in schema verification and improve exception message ## What changes were proposed in this pull request? **Context** While reviewing https://github.com/apache/spark/pull/17227, I realised here we type-dispatch per record. The PR itself is fine in terms of performance as is but this prints a prefix, `"obj"` in exception message as below: ``` from pyspark.sql.types import * schema = StructType([StructField('s', IntegerType(), nullable=False)]) spark.createDataFrame([["1"]], schema) ... TypeError: obj.s: IntegerType can not accept object '1' in type ``` I suggested to get rid of this but during investigating this, I realised my approach might bring a performance regression as it is a hot path. Only for SPARK-19507 and https://github.com/apache/spark/pull/17227, It needs more changes to cleanly get rid of the prefix and I rather decided to fix both issues together. **Propersal** This PR tried to - get rid of per-record type dispatch as we do in many code paths in Scala so that it improves the performance (roughly ~25% improvement) - SPARK-21296 This was tested with a simple code `spark.createDataFrame(range(1000000), "int")`. However, I am quite sure the actual improvement in practice is larger than this, in particular, when the schema is complicated. - improve error message in exception describing field information as prose - SPARK-19507 ## How was this patch tested? Manually tested and unit tests were added in `python/pyspark/sql/tests.py`. Benchmark - codes: https://gist.github.com/HyukjinKwon/c3397469c56cb26c2d7dd521ed0bc5a3 Error message - codes: https://gist.github.com/HyukjinKwon/b1b2c7f65865444c4a8836435100e398 **Before** Benchmark: - Results: https://gist.github.com/HyukjinKwon/4a291dab45542106301a0c1abcdca924 Error message - Results: https://gist.github.com/HyukjinKwon/57b1916395794ce924faa32b14a3fe19 **After** Benchmark - Results: https://gist.github.com/HyukjinKwon/21496feecc4a920e50c4e455f836266e Error message - Results: https://gist.github.com/HyukjinKwon/7a494e4557fe32a652ce1236e504a395 Closes #17227 Author: hyukjinkwon Author: David Gingrich Closes #18521 from HyukjinKwon/python-type-dispatch. --- python/pyspark/rdd.py | 1 - python/pyspark/sql/session.py | 12 +- python/pyspark/sql/tests.py | 203 ++++++++++++++++++++++++++++++- python/pyspark/sql/types.py | 219 +++++++++++++++++++++++----------- 4 files changed, 352 insertions(+), 83 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 60141792d499b..7dfa17f68a943 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -627,7 +627,6 @@ def sortPartition(iterator): def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): """ Sorts this RDD, which is assumed to consist of (key, value) pairs. - # noqa >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] >>> sc.parallelize(tmp).sortByKey().first() diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index e3bf0f35ea15e..2cc0e2d1d7b8d 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -33,7 +33,7 @@ from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.streaming import DataStreamReader -from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \ +from pyspark.sql.types import Row, DataType, StringType, StructType, _make_type_verifier, \ _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string from pyspark.sql.utils import install_exception_handler @@ -514,17 +514,21 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr schema = [str(x) for x in data.columns] data = [r.tolist() for r in data.to_records(index=False)] - verify_func = _verify_type if verifySchema else lambda _, t: True if isinstance(schema, StructType): + verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True + def prepare(obj): - verify_func(obj, schema) + verify_func(obj) return obj elif isinstance(schema, DataType): dataType = schema schema = StructType().add("value", schema) + verify_func = _make_type_verifier( + dataType, name="field value") if verifySchema else lambda _: True + def prepare(obj): - verify_func(obj, dataType) + verify_func(obj) return obj, else: if isinstance(schema, list): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c105969b26b97..16ba8bd73f400 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -57,7 +57,7 @@ from pyspark import SparkContext from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row from pyspark.sql.types import * -from pyspark.sql.types import UserDefinedType, _infer_type +from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window @@ -852,7 +852,7 @@ def test_convert_row_to_dict(self): self.assertEqual(1.0, row.asDict()['d']['key'].c) def test_udt(self): - from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type + from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier from pyspark.sql.tests import ExamplePointUDT, ExamplePoint def check_datatype(datatype): @@ -868,8 +868,8 @@ def check_datatype(datatype): check_datatype(structtype_with_udt) p = ExamplePoint(1.0, 2.0) self.assertEqual(_infer_type(p), ExamplePointUDT()) - _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) - self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT())) + _make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0)) + self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0])) check_datatype(PythonOnlyUDT()) structtype_with_udt = StructType([StructField("label", DoubleType(), False), @@ -877,8 +877,10 @@ def check_datatype(datatype): check_datatype(structtype_with_udt) p = PythonOnlyPoint(1.0, 2.0) self.assertEqual(_infer_type(p), PythonOnlyUDT()) - _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT()) - self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT())) + _make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0)) + self.assertRaises( + ValueError, + lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0])) def test_simple_udt_in_df(self): schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) @@ -2636,6 +2638,195 @@ def range_frame_match(): importlib.reload(window) + +class DataTypeVerificationTests(unittest.TestCase): + + def test_verify_type_exception_msg(self): + self.assertRaisesRegexp( + ValueError, + "test_name", + lambda: _make_type_verifier(StringType(), nullable=False, name="test_name")(None)) + + schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))]) + self.assertRaisesRegexp( + TypeError, + "field b in field a", + lambda: _make_type_verifier(schema)([["data"]])) + + def test_verify_type_ok_nullable(self): + obj = None + types = [IntegerType(), FloatType(), StringType(), StructType([])] + for data_type in types: + try: + _make_type_verifier(data_type, nullable=True)(obj) + except Exception: + self.fail("verify_type(%s, %s, nullable=True)" % (obj, data_type)) + + def test_verify_type_not_nullable(self): + import array + import datetime + import decimal + + schema = StructType([ + StructField('s', StringType(), nullable=False), + StructField('i', IntegerType(), nullable=True)]) + + class MyObj: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + # obj, data_type + success_spec = [ + # String + ("", StringType()), + (u"", StringType()), + (1, StringType()), + (1.0, StringType()), + ([], StringType()), + ({}, StringType()), + + # UDT + (ExamplePoint(1.0, 2.0), ExamplePointUDT()), + + # Boolean + (True, BooleanType()), + + # Byte + (-(2**7), ByteType()), + (2**7 - 1, ByteType()), + + # Short + (-(2**15), ShortType()), + (2**15 - 1, ShortType()), + + # Integer + (-(2**31), IntegerType()), + (2**31 - 1, IntegerType()), + + # Long + (2**64, LongType()), + + # Float & Double + (1.0, FloatType()), + (1.0, DoubleType()), + + # Decimal + (decimal.Decimal("1.0"), DecimalType()), + + # Binary + (bytearray([1, 2]), BinaryType()), + + # Date/Timestamp + (datetime.date(2000, 1, 2), DateType()), + (datetime.datetime(2000, 1, 2, 3, 4), DateType()), + (datetime.datetime(2000, 1, 2, 3, 4), TimestampType()), + + # Array + ([], ArrayType(IntegerType())), + (["1", None], ArrayType(StringType(), containsNull=True)), + ([1, 2], ArrayType(IntegerType())), + ((1, 2), ArrayType(IntegerType())), + (array.array('h', [1, 2]), ArrayType(IntegerType())), + + # Map + ({}, MapType(StringType(), IntegerType())), + ({"a": 1}, MapType(StringType(), IntegerType())), + ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True)), + + # Struct + ({"s": "a", "i": 1}, schema), + ({"s": "a", "i": None}, schema), + ({"s": "a"}, schema), + ({"s": "a", "f": 1.0}, schema), + (Row(s="a", i=1), schema), + (Row(s="a", i=None), schema), + (Row(s="a", i=1, f=1.0), schema), + (["a", 1], schema), + (["a", None], schema), + (("a", 1), schema), + (MyObj(s="a", i=1), schema), + (MyObj(s="a", i=None), schema), + (MyObj(s="a"), schema), + ] + + # obj, data_type, exception class + failure_spec = [ + # String (match anything but None) + (None, StringType(), ValueError), + + # UDT + (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError), + + # Boolean + (1, BooleanType(), TypeError), + ("True", BooleanType(), TypeError), + ([1], BooleanType(), TypeError), + + # Byte + (-(2**7) - 1, ByteType(), ValueError), + (2**7, ByteType(), ValueError), + ("1", ByteType(), TypeError), + (1.0, ByteType(), TypeError), + + # Short + (-(2**15) - 1, ShortType(), ValueError), + (2**15, ShortType(), ValueError), + + # Integer + (-(2**31) - 1, IntegerType(), ValueError), + (2**31, IntegerType(), ValueError), + + # Float & Double + (1, FloatType(), TypeError), + (1, DoubleType(), TypeError), + + # Decimal + (1.0, DecimalType(), TypeError), + (1, DecimalType(), TypeError), + ("1.0", DecimalType(), TypeError), + + # Binary + (1, BinaryType(), TypeError), + + # Date/Timestamp + ("2000-01-02", DateType(), TypeError), + (946811040, TimestampType(), TypeError), + + # Array + (["1", None], ArrayType(StringType(), containsNull=False), ValueError), + ([1, "2"], ArrayType(IntegerType()), TypeError), + + # Map + ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError), + ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError), + ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False), + ValueError), + + # Struct + ({"s": "a", "i": "1"}, schema, TypeError), + (Row(s="a"), schema, ValueError), # Row can't have missing field + (Row(s="a", i="1"), schema, TypeError), + (["a"], schema, ValueError), + (["a", "1"], schema, TypeError), + (MyObj(s="a", i="1"), schema, TypeError), + (MyObj(s=None, i="1"), schema, ValueError), + ] + + # Check success cases + for obj, data_type in success_spec: + try: + _make_type_verifier(data_type, nullable=False)(obj) + except Exception: + self.fail("verify_type(%s, %s, nullable=False)" % (obj, data_type)) + + # Check failure cases + for obj, data_type, exp in failure_spec: + msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp) + with self.assertRaises(exp, msg=msg): + _make_type_verifier(data_type, nullable=False)(obj) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 26b54a7fb3709..f5505ed4722ad 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1249,121 +1249,196 @@ def _infer_schema_type(obj, dataType): } -def _verify_type(obj, dataType, nullable=True): +def _make_type_verifier(dataType, nullable=True, name=None): """ - Verify the type of obj against dataType, raise a TypeError if they do not match. - - Also verify the value of obj against datatype, raise a ValueError if it's not within the allowed - range, e.g. using 128 as ByteType will overflow. Note that, Python float is not checked, so it - will become infinity when cast to Java float if it overflows. - - >>> _verify_type(None, StructType([])) - >>> _verify_type("", StringType()) - >>> _verify_type(0, LongType()) - >>> _verify_type(list(range(3)), ArrayType(ShortType())) - >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL + Make a verifier that checks the type of obj against dataType and raises a TypeError if they do + not match. + + This verifier also checks the value of obj against datatype and raises a ValueError if it's not + within the allowed range, e.g. using 128 as ByteType will overflow. Note that, Python float is + not checked, so it will become infinity when cast to Java float if it overflows. + + >>> _make_type_verifier(StructType([]))(None) + >>> _make_type_verifier(StringType())("") + >>> _make_type_verifier(LongType())(0) + >>> _make_type_verifier(ArrayType(ShortType()))(list(range(3))) + >>> _make_type_verifier(ArrayType(StringType()))(set()) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... TypeError:... - >>> _verify_type({}, MapType(StringType(), IntegerType())) - >>> _verify_type((), StructType([])) - >>> _verify_type([], StructType([])) - >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> _make_type_verifier(MapType(StringType(), IntegerType()))({}) + >>> _make_type_verifier(StructType([]))(()) + >>> _make_type_verifier(StructType([]))([]) + >>> _make_type_verifier(StructType([]))([1]) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ValueError:... >>> # Check if numeric values are within the allowed range. - >>> _verify_type(12, ByteType()) - >>> _verify_type(1234, ByteType()) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> _make_type_verifier(ByteType())(12) + >>> _make_type_verifier(ByteType())(1234) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ValueError:... - >>> _verify_type(None, ByteType(), False) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> _make_type_verifier(ByteType(), False)(None) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ValueError:... - >>> _verify_type([1, None], ArrayType(ShortType(), False)) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> _make_type_verifier( + ... ArrayType(ShortType(), False))([1, None]) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ValueError:... - >>> _verify_type({None: 1}, MapType(StringType(), IntegerType())) + >>> _make_type_verifier(MapType(StringType(), IntegerType()))({None: 1}) Traceback (most recent call last): ... ValueError:... >>> schema = StructType().add("a", IntegerType()).add("b", StringType(), False) - >>> _verify_type((1, None), schema) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> _make_type_verifier(schema)((1, None)) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ValueError:... """ - if obj is None: - if nullable: - return - else: - raise ValueError("This field is not nullable, but got None") - # StringType can work with any types - if isinstance(dataType, StringType): - return + if name is None: + new_msg = lambda msg: msg + new_name = lambda n: "field %s" % n + else: + new_msg = lambda msg: "%s: %s" % (name, msg) + new_name = lambda n: "field %s in %s" % (n, name) - if isinstance(dataType, UserDefinedType): - if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): - raise ValueError("%r is not an instance of type %r" % (obj, dataType)) - _verify_type(dataType.toInternal(obj), dataType.sqlType()) - return + def verify_nullability(obj): + if obj is None: + if nullable: + return True + else: + raise ValueError(new_msg("This field is not nullable, but got None")) + else: + return False _type = type(dataType) - assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj) - if _type is StructType: - # check the type and fields later - pass - else: + def assert_acceptable_types(obj): + assert _type in _acceptable_types, \ + new_msg("unknown datatype: %s for object %r" % (dataType, obj)) + + def verify_acceptable_types(obj): # subclass of them can not be fromInternal in JVM if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj))) + raise TypeError(new_msg("%s can not accept object %r in type %s" + % (dataType, obj, type(obj)))) + + if isinstance(dataType, StringType): + # StringType can work with any types + verify_value = lambda _: _ + + elif isinstance(dataType, UserDefinedType): + verifier = _make_type_verifier(dataType.sqlType(), name=name) - if isinstance(dataType, ByteType): - if obj < -128 or obj > 127: - raise ValueError("object of ByteType out of range, got: %s" % obj) + def verify_udf(obj): + if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): + raise ValueError(new_msg("%r is not an instance of type %r" % (obj, dataType))) + verifier(dataType.toInternal(obj)) + + verify_value = verify_udf + + elif isinstance(dataType, ByteType): + def verify_byte(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + if obj < -128 or obj > 127: + raise ValueError(new_msg("object of ByteType out of range, got: %s" % obj)) + + verify_value = verify_byte elif isinstance(dataType, ShortType): - if obj < -32768 or obj > 32767: - raise ValueError("object of ShortType out of range, got: %s" % obj) + def verify_short(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + if obj < -32768 or obj > 32767: + raise ValueError(new_msg("object of ShortType out of range, got: %s" % obj)) + + verify_value = verify_short elif isinstance(dataType, IntegerType): - if obj < -2147483648 or obj > 2147483647: - raise ValueError("object of IntegerType out of range, got: %s" % obj) + def verify_integer(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + if obj < -2147483648 or obj > 2147483647: + raise ValueError( + new_msg("object of IntegerType out of range, got: %s" % obj)) + + verify_value = verify_integer elif isinstance(dataType, ArrayType): - for i in obj: - _verify_type(i, dataType.elementType, dataType.containsNull) + element_verifier = _make_type_verifier( + dataType.elementType, dataType.containsNull, name="element in array %s" % name) + + def verify_array(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + for i in obj: + element_verifier(i) + + verify_value = verify_array elif isinstance(dataType, MapType): - for k, v in obj.items(): - _verify_type(k, dataType.keyType, False) - _verify_type(v, dataType.valueType, dataType.valueContainsNull) + key_verifier = _make_type_verifier(dataType.keyType, False, name="key of map %s" % name) + value_verifier = _make_type_verifier( + dataType.valueType, dataType.valueContainsNull, name="value of map %s" % name) + + def verify_map(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + for k, v in obj.items(): + key_verifier(k) + value_verifier(v) + + verify_value = verify_map elif isinstance(dataType, StructType): - if isinstance(obj, dict): - for f in dataType.fields: - _verify_type(obj.get(f.name), f.dataType, f.nullable) - elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False): - # the order in obj could be different than dataType.fields - for f in dataType.fields: - _verify_type(obj[f.name], f.dataType, f.nullable) - elif isinstance(obj, (tuple, list)): - if len(obj) != len(dataType.fields): - raise ValueError("Length of object (%d) does not match with " - "length of fields (%d)" % (len(obj), len(dataType.fields))) - for v, f in zip(obj, dataType.fields): - _verify_type(v, f.dataType, f.nullable) - elif hasattr(obj, "__dict__"): - d = obj.__dict__ - for f in dataType.fields: - _verify_type(d.get(f.name), f.dataType, f.nullable) - else: - raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj))) + verifiers = [] + for f in dataType.fields: + verifier = _make_type_verifier(f.dataType, f.nullable, name=new_name(f.name)) + verifiers.append((f.name, verifier)) + + def verify_struct(obj): + assert_acceptable_types(obj) + + if isinstance(obj, dict): + for f, verifier in verifiers: + verifier(obj.get(f)) + elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False): + # the order in obj could be different than dataType.fields + for f, verifier in verifiers: + verifier(obj[f]) + elif isinstance(obj, (tuple, list)): + if len(obj) != len(verifiers): + raise ValueError( + new_msg("Length of object (%d) does not match with " + "length of fields (%d)" % (len(obj), len(verifiers)))) + for v, (_, verifier) in zip(obj, verifiers): + verifier(v) + elif hasattr(obj, "__dict__"): + d = obj.__dict__ + for f, verifier in verifiers: + verifier(d.get(f)) + else: + raise TypeError(new_msg("StructType can not accept object %r in type %s" + % (obj, type(obj)))) + verify_value = verify_struct + + else: + def verify_default(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + + verify_value = verify_default + + def verify(obj): + if not verify_nullability(obj): + verify_value(obj) + + return verify # This is used to unpickle a Row from JVM From 29b1f6b09f98e216af71e893a9da0c4717c80679 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 4 Jul 2017 08:54:07 -0700 Subject: [PATCH 081/779] [SPARK-21256][SQL] Add withSQLConf to Catalyst Test ### What changes were proposed in this pull request? SQLConf is moved to Catalyst. We are adding more and more test cases for verifying the conf-specific behaviors. It is nice to add a helper function to simplify the test cases. ### How was this patch tested? N/A Author: gatorsmile Closes #18469 from gatorsmile/withSQLConf. --- .../InferFiltersFromConstraintsSuite.scala | 5 +-- .../optimizer/OuterJoinEliminationSuite.scala | 6 +--- .../optimizer/PruneFiltersSuite.scala | 6 +--- .../plans/ConstraintPropagationSuite.scala | 24 +++++++------- .../spark/sql/catalyst/plans/PlanTest.scala | 32 ++++++++++++++++++- .../AggregateEstimationSuite.scala | 9 ++---- .../BasicStatsEstimationSuite.scala | 12 +++---- .../spark/sql/SparkSessionBuilderSuite.scala | 3 ++ .../apache/spark/sql/test/SQLTestUtils.scala | 30 ++++------------- 9 files changed, 64 insertions(+), 63 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index cdc9f25cf8777..d2dd469e2d74f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -206,13 +206,10 @@ class InferFiltersFromConstraintsSuite extends PlanTest { } test("No inferred filter when constraint propagation is disabled") { - try { - SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) + withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, originalQuery) - } finally { - SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala index 623ff3d446a5f..893c111c2906b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala @@ -234,9 +234,7 @@ class OuterJoinEliminationSuite extends PlanTest { } test("no outer join elimination if constraint propagation is disabled") { - try { - SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) - + withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { val x = testRelation.subquery('x) val y = testRelation1.subquery('y) @@ -251,8 +249,6 @@ class OuterJoinEliminationSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, originalQuery.analyze) - } finally { - SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala index 706634cdd29b8..6d1a05f3c998e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED class PruneFiltersSuite extends PlanTest { @@ -149,8 +148,7 @@ class PruneFiltersSuite extends PlanTest { ("tr1.a".attr > 10 || "tr1.c".attr < 10) && 'd.attr < 100) - SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) - try { + withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { val optimized = Optimize.execute(queryWithUselessFilter.analyze) // When constraint propagation is disabled, the useless filter won't be pruned. // It gets pushed down. Because the rule `CombineFilters` runs only once, there are redundant @@ -160,8 +158,6 @@ class PruneFiltersSuite extends PlanTest { .join(tr2.where('d.attr < 100).where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)).analyze comparePlans(optimized, correctAnswer) - } finally { - SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index a3948d90b0e4d..a37e06d922642 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType} -class ConstraintPropagationSuite extends SparkFunSuite { +class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { private def resolveColumn(tr: LocalRelation, columnName: String): Expression = resolveColumn(tr.analyze, columnName) @@ -400,26 +400,26 @@ class ConstraintPropagationSuite extends SparkFunSuite { } test("enable/disable constraint propagation") { - try { - val tr = LocalRelation('a.int, 'b.string, 'c.int) - val filterRelation = tr.where('a.attr > 10) + val tr = LocalRelation('a.int, 'b.string, 'c.int) + val filterRelation = tr.where('a.attr > 10) - SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true) + withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "true") { assert(filterRelation.analyze.constraints.nonEmpty) + } - SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) + withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { assert(filterRelation.analyze.constraints.isEmpty) + } - val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) - .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3) + val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) + .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3) - SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true) + withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "true") { assert(aliasedRelation.analyze.constraints.nonEmpty) + } - SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) + withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { assert(aliasedRelation.analyze.constraints.isEmpty) - } finally { - SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 6883d23d477e4..e9679d3361509 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression @@ -28,8 +29,9 @@ import org.apache.spark.sql.internal.SQLConf /** * Provides helper methods for comparing plans. */ -abstract class PlanTest extends SparkFunSuite with PredicateHelper { +trait PlanTest extends SparkFunSuite with PredicateHelper { + // TODO(gatorsmile): remove this from PlanTest and all the analyzer/optimizer rules protected val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true) /** @@ -142,4 +144,32 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { plan1 == plan2 } } + + /** + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL + * configurations. + */ + protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val conf = SQLConf.get + val (keys, values) = pairs.unzip + val currentValues = keys.map { key => + if (conf.contains(key)) { + Some(conf.getConfString(key)) + } else { + None + } + } + (keys, values).zipped.foreach { (k, v) => + if (SQLConf.staticConfKeys.contains(k)) { + throw new AnalysisException(s"Cannot modify the value of a static config: $k") + } + conf.setConfString(k, v) + } + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => conf.setConfString(key, value) + case (key, None) => conf.unsetConf(key) + } + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala index 30ddf03bd3c4f..23f95a6cc2ac2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql.catalyst.statsEstimation import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ import org.apache.spark.sql.internal.SQLConf -class AggregateEstimationSuite extends StatsEstimationTestBase { +class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest { /** Columns for testing */ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( @@ -100,9 +101,7 @@ class AggregateEstimationSuite extends StatsEstimationTestBase { size = Some(4 * (8 + 4)), attributeStats = AttributeMap(Seq("key12").map(nameToColInfo))) - val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED) - try { - SQLConf.get.setConf(SQLConf.CBO_ENABLED, false) + withSQLConf(SQLConf.CBO_ENABLED.key -> "false") { val noGroupAgg = Aggregate(groupingExpressions = Nil, aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child) assert(noGroupAgg.stats == @@ -114,8 +113,6 @@ class AggregateEstimationSuite extends StatsEstimationTestBase { assert(hasGroupAgg.stats == // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4))) - } finally { - SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 31a8cbdee9777..5fd21a06a109d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql.catalyst.statsEstimation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal} +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.IntegerType -class BasicStatsEstimationSuite extends StatsEstimationTestBase { +class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { val attribute = attr("key") val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4) @@ -82,18 +83,15 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { plan: LogicalPlan, expectedStatsCboOn: Statistics, expectedStatsCboOff: Statistics): Unit = { - val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED) - try { + withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { // Invalidate statistics plan.invalidateStatsCache() - SQLConf.get.setConf(SQLConf.CBO_ENABLED, true) assert(plan.stats == expectedStatsCboOn) + } + withSQLConf(SQLConf.CBO_ENABLED.key -> "false") { plan.invalidateStatsCache() - SQLConf.get.setConf(SQLConf.CBO_ENABLED, false) assert(plan.stats == expectedStatsCboOff) - } finally { - SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index 4f6d5f79d466e..cdac6827082c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.internal.SQLConf /** * Test cases for the builder pattern of [[SparkSession]]. @@ -67,6 +68,8 @@ class SparkSessionBuilderSuite extends SparkFunSuite { assert(activeSession != defaultSession) assert(session == activeSession) assert(session.conf.get("spark-config2") == "a") + assert(session.sessionState.conf == SQLConf.get) + assert(SQLConf.get.getConfString("spark-config2") == "a") SparkSession.clearActiveSession() assert(SparkSession.builder().getOrCreate() == defaultSession) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index d74a7cce25ed6..92ee7d596acd1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -35,9 +35,11 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.FilterExec +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{UninterruptibleThread, Utils} /** @@ -53,7 +55,8 @@ import org.apache.spark.util.{UninterruptibleThread, Utils} private[sql] trait SQLTestUtils extends SparkFunSuite with Eventually with BeforeAndAfterAll - with SQLTestData { self => + with SQLTestData + with PlanTest { self => protected def sparkContext = spark.sparkContext @@ -89,28 +92,9 @@ private[sql] trait SQLTestUtils } } - /** - * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL - * configurations. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map { key => - if (spark.conf.contains(key)) { - Some(spark.conf.get(key)) - } else { - None - } - } - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } + protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + SparkSession.setActiveSession(spark) + super.withSQLConf(pairs: _*)(f) } /** From a3c29fcbbda02c1528b4185bcb880c91077d480c Mon Sep 17 00:00:00 2001 From: "YIHAODIAN\\wangshuangshuang" Date: Tue, 4 Jul 2017 09:44:27 -0700 Subject: [PATCH 082/779] [SPARK-19726][SQL] Faild to insert null timestamp value to mysql using spark jdbc ## What changes were proposed in this pull request? when creating table like following: > create table timestamp_test(id int(11), time_stamp timestamp not null default current_timestamp); The result of Excuting "insert into timestamp_test values (111, null)" is different between Spark and JDBC. ``` mysql> select * from timestamp_test; +------+---------------------+ | id | time_stamp | +------+---------------------+ | 111 | 1970-01-01 00:00:00 | -> spark | 111 | 2017-06-27 19:32:38 | -> mysql +------+---------------------+ 2 rows in set (0.00 sec) ``` Because in such case ```StructField.nullable``` is false, so the generated codes of ```InvokeLike``` and ```BoundReference``` don't check whether the field is null or not. Instead, they directly use ```CodegenContext.INPUT_ROW.getLong(1)```, however, ```UnsafeRow.setNullAt(1)``` will put 0 in the underlying memory. The PR will ```always``` set ```StructField.nullable``` true after obtaining metadata from jdbc connection, Since we can insert null to not null timestamp column in MySQL. In this way, spark will propagate null to underlying DB engine, and let DB to choose how to process NULL. ## How was this patch tested? Added tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: YIHAODIAN\wangshuangshuang Author: Shuangshuang Wang Closes #18445 from shuangshuangwang/SPARK-19726. --- .../sql/execution/datasources/jdbc/JDBCRDD.scala | 2 +- .../sql/execution/datasources/jdbc/JdbcUtils.scala | 12 ++++++++++-- .../org/apache/spark/sql/jdbc/JDBCWriteSuite.scala | 8 ++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 0f53b5c7c6f0f..57e9bc9b70454 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -59,7 +59,7 @@ object JDBCRDD extends Logging { try { val rs = statement.executeQuery() try { - JdbcUtils.getSchema(rs, dialect) + JdbcUtils.getSchema(rs, dialect, alwaysNullable = true) } finally { rs.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index ca61c2efe2ddf..55b2539c13381 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -266,10 +266,14 @@ object JdbcUtils extends Logging { /** * Takes a [[ResultSet]] and returns its Catalyst schema. * + * @param alwaysNullable If true, all the columns are nullable. * @return A [[StructType]] giving the Catalyst schema. * @throws SQLException if the schema contains an unsupported type. */ - def getSchema(resultSet: ResultSet, dialect: JdbcDialect): StructType = { + def getSchema( + resultSet: ResultSet, + dialect: JdbcDialect, + alwaysNullable: Boolean = false): StructType = { val rsmd = resultSet.getMetaData val ncols = rsmd.getColumnCount val fields = new Array[StructField](ncols) @@ -290,7 +294,11 @@ object JdbcUtils extends Logging { rsmd.getClass.getName == "org.apache.hive.jdbc.HiveResultSetMetaData" => true } } - val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls + val nullable = if (alwaysNullable) { + true + } else { + rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls + } val metadata = new MetadataBuilder() .putString("name", columnName) .putLong("scale", fieldScale) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index bf1fd160704fa..92f50a095f19b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters.propertiesAsScalaMapConverter import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} @@ -506,4 +507,11 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { "schema struct")) } } + + test("SPARK-19726: INSERT null to a NOT NULL column") { + val e = intercept[SparkException] { + sql("INSERT INTO PEOPLE1 values (null, null)") + }.getMessage + assert(e.contains("NULL not allowed for column \"NAME\"")) + } } From 1b50e0e0d6fd9d1b815a3bb37647ea659222e3f1 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 4 Jul 2017 09:48:40 -0700 Subject: [PATCH 083/779] [SPARK-20256][SQL] SessionState should be created more lazily ## What changes were proposed in this pull request? `SessionState` is designed to be created lazily. However, in reality, it created immediately in `SparkSession.Builder.getOrCreate` ([here](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala#L943)). This PR aims to recover the lazy behavior by keeping the options into `initialSessionOptions`. The benefit is like the following. Users can start `spark-shell` and use RDD operations without any problems. **BEFORE** ```scala $ bin/spark-shell java.lang.IllegalArgumentException: Error while instantiating 'org.apache.spark.sql.hive.HiveSessionStateBuilder' ... Caused by: org.apache.spark.sql.AnalysisException: org.apache.hadoop.hive.ql.metadata.HiveException: MetaException(message:java.security.AccessControlException: Permission denied: user=spark, access=READ, inode="/apps/hive/warehouse":hive:hdfs:drwx------ ``` As reported in SPARK-20256, this happens when the warehouse directory is not allowed for this user. **AFTER** ```scala $ bin/spark-shell ... Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 2.3.0-SNAPSHOT /_/ Using Scala version 2.11.8 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_112) Type in expressions to have them evaluated. Type :help for more information. scala> sc.range(0, 10, 1).count() res0: Long = 10 ``` ## How was this patch tested? Manual. This closes #18512 . Author: Dongjoon Hyun Closes #18501 from dongjoon-hyun/SPARK-20256. --- .../scala/org/apache/spark/sql/SparkSession.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 2c38f7d7c88da..0ddcd2111aa58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -117,6 +117,12 @@ class SparkSession private( existingSharedState.getOrElse(new SharedState(sparkContext)) } + /** + * Initial options for session. This options are applied once when sessionState is created. + */ + @transient + private[sql] val initialSessionOptions = new scala.collection.mutable.HashMap[String, String] + /** * State isolated across sessions, including SQL configurations, temporary tables, registered * functions, and everything else that accepts a [[org.apache.spark.sql.internal.SQLConf]]. @@ -132,9 +138,11 @@ class SparkSession private( parentSessionState .map(_.clone(this)) .getOrElse { - SparkSession.instantiateSessionState( + val state = SparkSession.instantiateSessionState( SparkSession.sessionStateClassName(sparkContext.conf), self) + initialSessionOptions.foreach { case (k, v) => state.conf.setConfString(k, v) } + state } } @@ -940,7 +948,7 @@ object SparkSession { } session = new SparkSession(sparkContext, None, None, extensions) - options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) } + options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) } defaultSession.set(session) // Register a successfully instantiated context to the singleton. This should be at the From 4d6d8192c807006ff89488a1d38bc6f7d41de5cf Mon Sep 17 00:00:00 2001 From: dardelet Date: Tue, 4 Jul 2017 17:58:44 +0100 Subject: [PATCH 084/779] [SPARK-21268][MLLIB] Move center calculations to a distributed map in KMeans ## What changes were proposed in this pull request? The scal() and creation of newCenter vector is done in the driver, after a collectAsMap operation while it could be done in the distributed RDD. This PR moves this code before the collectAsMap for more efficiency ## How was this patch tested? This was tested manually by running the KMeansExample and verifying that the new code ran without error and gave same output as before. Author: dardelet Author: Guillaume Dardelet Closes #18491 from dardelet/move-center-calculation-to-distributed-map-kmean. --- .../org/apache/spark/mllib/clustering/KMeans.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index fa72b72e2d921..98e50c5b45cfd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -272,8 +272,8 @@ class KMeans private ( val costAccum = sc.doubleAccumulator val bcCenters = sc.broadcast(centers) - // Find the sum and count of points mapping to each center - val totalContribs = data.mapPartitions { points => + // Find the new centers + val newCenters = data.mapPartitions { points => val thisCenters = bcCenters.value val dims = thisCenters.head.vector.size @@ -292,15 +292,16 @@ class KMeans private ( }.reduceByKey { case ((sum1, count1), (sum2, count2)) => axpy(1.0, sum2, sum1) (sum1, count1 + count2) + }.mapValues { case (sum, count) => + scal(1.0 / count, sum) + new VectorWithNorm(sum) }.collectAsMap() bcCenters.destroy(blocking = false) // Update the cluster centers and costs converged = true - totalContribs.foreach { case (j, (sum, count)) => - scal(1.0 / count, sum) - val newCenter = new VectorWithNorm(sum) + newCenters.foreach { case (j, newCenter) => if (converged && KMeans.fastSquaredDistance(newCenter, centers(j)) > epsilon * epsilon) { converged = false } From cec392150451a64c9c2902b7f8f4b3b38f25cbea Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Tue, 4 Jul 2017 12:18:51 -0700 Subject: [PATCH 085/779] [SPARK-20889][SPARKR] Grouped documentation for WINDOW column methods ## What changes were proposed in this pull request? Grouped documentation for column window methods. Author: actuaryzhang Closes #18481 from actuaryzhang/sparkRDocWindow. --- R/pkg/R/functions.R | 225 ++++++++++++++------------------------------ R/pkg/R/generics.R | 28 +++--- 2 files changed, 88 insertions(+), 165 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index a1f5c4f8cc18d..8c12308c1d7c1 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -200,6 +200,34 @@ NULL #' head(select(tmp, sort_array(tmp$v1, asc = FALSE)))} NULL +#' Window functions for Column operations +#' +#' Window functions defined for \code{Column}. +#' +#' @param x In \code{lag} and \code{lead}, it is the column as a character string or a Column +#' to compute on. In \code{ntile}, it is the number of ntile groups. +#' @param offset In \code{lag}, the number of rows back from the current row from which to obtain +#' a value. In \code{lead}, the number of rows after the current row from which to +#' obtain a value. If not specified, the default is 1. +#' @param defaultValue (optional) default to use when the offset row does not exist. +#' @param ... additional argument(s). +#' @name column_window_functions +#' @rdname column_window_functions +#' @family window functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' tmp <- mutate(df, dist = over(cume_dist(), ws), dense_rank = over(dense_rank(), ws), +#' lag = over(lag(df$mpg), ws), lead = over(lead(df$mpg, 1), ws), +#' percent_rank = over(percent_rank(), ws), +#' rank = over(rank(), ws), row_number = over(row_number(), ws)) +#' # Get ntile group id (1-4) for hp +#' tmp <- mutate(tmp, ntile = over(ntile(4), ws)) +#' head(tmp)} +NULL + #' @details #' \code{lit}: A new Column is created to represent the literal value. #' If the parameter is a Column, it is returned unchanged. @@ -2844,27 +2872,16 @@ setMethod("ifelse", ###################### Window functions###################### -#' cume_dist -#' -#' Window function: returns the cumulative distribution of values within a window partition, -#' i.e. the fraction of rows that are below the current row. -#' -#' N = total number of rows in the partition -#' cume_dist(x) = number of values before (and including) x / N -#' +#' @details +#' \code{cume_dist}: Returns the cumulative distribution of values within a window partition, +#' i.e. the fraction of rows that are below the current row: +#' (number of values before and including x) / (total number of rows in the partition). #' This is equivalent to the \code{CUME_DIST} function in SQL. +#' The method should be used with no argument. #' -#' @rdname cume_dist -#' @name cume_dist -#' @family window functions -#' @aliases cume_dist,missing-method +#' @rdname column_window_functions +#' @aliases cume_dist cume_dist,missing-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(cume_dist(), ws), df$hp, df$am) -#' } #' @note cume_dist since 1.6.0 setMethod("cume_dist", signature("missing"), @@ -2873,28 +2890,19 @@ setMethod("cume_dist", column(jc) }) -#' dense_rank -#' -#' Window function: returns the rank of rows within a window partition, without any gaps. +#' @details +#' \code{dense_rank}: Returns the rank of rows within a window partition, without any gaps. #' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking #' sequence when there are ties. That is, if you were ranking a competition using dense_rank #' and had three people tie for second place, you would say that all three were in second #' place and that the next person came in third. Rank would give me sequential numbers, making #' the person that came in third place (after the ties) would register as coming in fifth. -#' #' This is equivalent to the \code{DENSE_RANK} function in SQL. +#' The method should be used with no argument. #' -#' @rdname dense_rank -#' @name dense_rank -#' @family window functions -#' @aliases dense_rank,missing-method +#' @rdname column_window_functions +#' @aliases dense_rank dense_rank,missing-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(dense_rank(), ws), df$hp, df$am) -#' } #' @note dense_rank since 1.6.0 setMethod("dense_rank", signature("missing"), @@ -2903,34 +2911,15 @@ setMethod("dense_rank", column(jc) }) -#' lag -#' -#' Window function: returns the value that is \code{offset} rows before the current row, and +#' @details +#' \code{lag}: Returns the value that is \code{offset} rows before the current row, and #' \code{defaultValue} if there is less than \code{offset} rows before the current row. For example, #' an \code{offset} of one will return the previous row at any given point in the window partition. -#' #' This is equivalent to the \code{LAG} function in SQL. #' -#' @param x the column as a character string or a Column to compute on. -#' @param offset the number of rows back from the current row from which to obtain a value. -#' If not specified, the default is 1. -#' @param defaultValue (optional) default to use when the offset row does not exist. -#' @param ... further arguments to be passed to or from other methods. -#' @rdname lag -#' @name lag -#' @aliases lag,characterOrColumn-method -#' @family window functions +#' @rdname column_window_functions +#' @aliases lag lag,characterOrColumn-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' -#' # Partition by am (transmission) and order by hp (horsepower) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' -#' # Lag mpg values by 1 row on the partition-and-ordered table -#' out <- select(df, over(lag(df$mpg), ws), df$mpg, df$hp, df$am) -#' } #' @note lag since 1.6.0 setMethod("lag", signature(x = "characterOrColumn"), @@ -2946,35 +2935,16 @@ setMethod("lag", column(jc) }) -#' lead -#' -#' Window function: returns the value that is \code{offset} rows after the current row, and +#' @details +#' \code{lead}: Returns the value that is \code{offset} rows after the current row, and #' \code{defaultValue} if there is less than \code{offset} rows after the current row. #' For example, an \code{offset} of one will return the next row at any given point #' in the window partition. -#' #' This is equivalent to the \code{LEAD} function in SQL. #' -#' @param x the column as a character string or a Column to compute on. -#' @param offset the number of rows after the current row from which to obtain a value. -#' If not specified, the default is 1. -#' @param defaultValue (optional) default to use when the offset row does not exist. -#' -#' @rdname lead -#' @name lead -#' @family window functions -#' @aliases lead,characterOrColumn,numeric-method +#' @rdname column_window_functions +#' @aliases lead lead,characterOrColumn,numeric-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' -#' # Partition by am (transmission) and order by hp (horsepower) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' -#' # Lead mpg values by 1 row on the partition-and-ordered table -#' out <- select(df, over(lead(df$mpg), ws), df$mpg, df$hp, df$am) -#' } #' @note lead since 1.6.0 setMethod("lead", signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), @@ -2990,31 +2960,15 @@ setMethod("lead", column(jc) }) -#' ntile -#' -#' Window function: returns the ntile group id (from 1 to n inclusive) in an ordered window +#' @details +#' \code{ntile}: Returns the ntile group id (from 1 to n inclusive) in an ordered window #' partition. For example, if n is 4, the first quarter of the rows will get value 1, the second #' quarter will get 2, the third quarter will get 3, and the last quarter will get 4. -#' #' This is equivalent to the \code{NTILE} function in SQL. #' -#' @param x Number of ntile groups -#' -#' @rdname ntile -#' @name ntile -#' @aliases ntile,numeric-method -#' @family window functions +#' @rdname column_window_functions +#' @aliases ntile ntile,numeric-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' -#' # Partition by am (transmission) and order by hp (horsepower) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' -#' # Get ntile group id (1-4) for hp -#' out <- select(df, over(ntile(4), ws), df$hp, df$am) -#' } #' @note ntile since 1.6.0 setMethod("ntile", signature(x = "numeric"), @@ -3023,27 +2977,15 @@ setMethod("ntile", column(jc) }) -#' percent_rank -#' -#' Window function: returns the relative rank (i.e. percentile) of rows within a window partition. -#' -#' This is computed by: -#' -#' (rank of row in its partition - 1) / (number of rows in the partition - 1) -#' -#' This is equivalent to the PERCENT_RANK function in SQL. +#' @details +#' \code{percent_rank}: Returns the relative rank (i.e. percentile) of rows within a window partition. +#' This is computed by: (rank of row in its partition - 1) / (number of rows in the partition - 1). +#' This is equivalent to the \code{PERCENT_RANK} function in SQL. +#' The method should be used with no argument. #' -#' @rdname percent_rank -#' @name percent_rank -#' @family window functions -#' @aliases percent_rank,missing-method +#' @rdname column_window_functions +#' @aliases percent_rank percent_rank,missing-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(percent_rank(), ws), df$hp, df$am) -#' } #' @note percent_rank since 1.6.0 setMethod("percent_rank", signature("missing"), @@ -3052,29 +2994,19 @@ setMethod("percent_rank", column(jc) }) -#' rank -#' -#' Window function: returns the rank of rows within a window partition. -#' +#' @details +#' \code{rank}: Returns the rank of rows within a window partition. #' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking #' sequence when there are ties. That is, if you were ranking a competition using dense_rank #' and had three people tie for second place, you would say that all three were in second #' place and that the next person came in third. Rank would give me sequential numbers, making #' the person that came in third place (after the ties) would register as coming in fifth. +#' This is equivalent to the \code{RANK} function in SQL. +#' The method should be used with no argument. #' -#' This is equivalent to the RANK function in SQL. -#' -#' @rdname rank -#' @name rank -#' @family window functions -#' @aliases rank,missing-method +#' @rdname column_window_functions +#' @aliases rank rank,missing-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(rank(), ws), df$hp, df$am) -#' } #' @note rank since 1.6.0 setMethod("rank", signature(x = "missing"), @@ -3083,11 +3015,7 @@ setMethod("rank", column(jc) }) -# Expose rank() in the R base package -#' @param x a numeric, complex, character or logical vector. -#' @param ... additional argument(s) passed to the method. -#' @name rank -#' @rdname rank +#' @rdname column_window_functions #' @aliases rank,ANY-method #' @export setMethod("rank", @@ -3096,23 +3024,14 @@ setMethod("rank", base::rank(x, ...) }) -#' row_number -#' -#' Window function: returns a sequential number starting at 1 within a window partition. -#' -#' This is equivalent to the ROW_NUMBER function in SQL. +#' @details +#' \code{row_number}: Returns a sequential number starting at 1 within a window partition. +#' This is equivalent to the \code{ROW_NUMBER} function in SQL. +#' The method should be used with no argument. #' -#' @rdname row_number -#' @name row_number -#' @aliases row_number,missing-method -#' @family window functions +#' @rdname column_window_functions +#' @aliases row_number row_number,missing-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(row_number(), ws), df$hp, df$am) -#' } #' @note row_number since 1.6.0 setMethod("row_number", signature("missing"), diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index b901b74e4728d..beac18e412736 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1013,9 +1013,9 @@ setGeneric("create_map", function(x, ...) { standardGeneric("create_map") }) #' @name NULL setGeneric("hash", function(x, ...) { standardGeneric("hash") }) -#' @param x empty. Should be used with no argument. -#' @rdname cume_dist +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("cume_dist", function(x = "missing") { standardGeneric("cume_dist") }) #' @rdname column_datetime_diff_functions @@ -1053,9 +1053,9 @@ setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) #' @name NULL setGeneric("decode", function(x, charset) { standardGeneric("decode") }) -#' @param x empty. Should be used with no argument. -#' @rdname dense_rank +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) #' @rdname column_string_functions @@ -1159,8 +1159,9 @@ setGeneric("isnan", function(x) { standardGeneric("isnan") }) #' @name NULL setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) -#' @rdname lag +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("lag", function(x, ...) { standardGeneric("lag") }) #' @rdname last @@ -1172,8 +1173,9 @@ setGeneric("last", function(x, ...) { standardGeneric("last") }) #' @name NULL setGeneric("last_day", function(x) { standardGeneric("last_day") }) -#' @rdname lead +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("lead") }) #' @rdname column_nonaggregate_functions @@ -1260,8 +1262,9 @@ setGeneric("not", function(x) { standardGeneric("not") }) #' @name NULL setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) -#' @rdname ntile +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("ntile", function(x) { standardGeneric("ntile") }) #' @rdname column_aggregate_functions @@ -1269,9 +1272,9 @@ setGeneric("ntile", function(x) { standardGeneric("ntile") }) #' @name NULL setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) -#' @param x empty. Should be used with no argument. -#' @rdname percent_rank +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_rank") }) #' @rdname column_math_functions @@ -1304,8 +1307,9 @@ setGeneric("rand", function(seed) { standardGeneric("rand") }) #' @name NULL setGeneric("randn", function(seed) { standardGeneric("randn") }) -#' @rdname rank +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("rank", function(x, ...) { standardGeneric("rank") }) #' @rdname column_string_functions @@ -1334,9 +1338,9 @@ setGeneric("reverse", function(x) { standardGeneric("reverse") }) #' @name NULL setGeneric("rint", function(x) { standardGeneric("rint") }) -#' @param x empty. Should be used with no argument. -#' @rdname row_number +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("row_number", function(x = "missing") { standardGeneric("row_number") }) #' @rdname column_string_functions From daabf425ec0272951b11f286e4bec7a48f42cc0d Mon Sep 17 00:00:00 2001 From: wangmiao1981 Date: Tue, 4 Jul 2017 12:37:29 -0700 Subject: [PATCH 086/779] [MINOR][SPARKR] ignore Rplots.pdf test output after running R tests ## What changes were proposed in this pull request? After running R tests in local build, it outputs Rplots.pdf. This one should be ignored in the git repository. Author: wangmiao1981 Closes #18518 from wangmiao1981/ignore. --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 1d91b43c23fa7..cf9780db37ad7 100644 --- a/.gitignore +++ b/.gitignore @@ -25,6 +25,7 @@ R-unit-tests.log R/unit-tests.out R/cran-check.out R/pkg/vignettes/sparkr-vignettes.html +R/pkg/tests/fulltests/Rplots.pdf build/*.jar build/apache-maven* build/scala* From de14086e1f6a2474bb9ba1452ada94e0ce58cf9c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 5 Jul 2017 10:40:02 +0800 Subject: [PATCH 087/779] [SPARK-21295][SQL] Use qualified names in error message for missing references ### What changes were proposed in this pull request? It is strange to see the following error message. Actually, the column is from another table. ``` cannot resolve '`right.a`' given input columns: [a, c, d]; ``` After the PR, the error message looks like ``` cannot resolve '`right.a`' given input columns: [left.a, right.c, right.d]; ``` ### How was this patch tested? Added a test case Author: gatorsmile Closes #18520 from gatorsmile/removeSQLConf. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../results/columnresolution-negative.sql.out | 10 +++++----- .../results/columnresolution-views.sql.out | 2 +- .../sql-tests/results/columnresolution.sql.out | 18 +++++++++--------- .../sql-tests/results/group-by.sql.out | 2 +- .../sql-tests/results/table-aliases.sql.out | 2 +- .../org/apache/spark/sql/SubquerySuite.scala | 4 ++-- 7 files changed, 20 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index fb81a7006bc5e..85c52792ef659 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -86,7 +86,7 @@ trait CheckAnalysis extends PredicateHelper { case operator: LogicalPlan => operator transformExpressionsUp { case a: Attribute if !a.resolved => - val from = operator.inputSet.map(_.name).mkString(", ") + val from = operator.inputSet.map(_.qualifiedName).mkString(", ") a.failAnalysis(s"cannot resolve '${a.sql}' given input columns: [$from]") case e: Expression if e.checkInputDataTypes().isFailure => diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out index 60bd8e9cc99db..9e60e592c2bd1 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out @@ -90,7 +90,7 @@ SELECT mydb1.t1.i1 FROM t1, mydb1.t1 struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 -- !query 11 @@ -161,7 +161,7 @@ SELECT db1.t1.i1 FROM t1, mydb2.t1 struct<> -- !query 18 output org.apache.spark.sql.AnalysisException -cannot resolve '`db1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 +cannot resolve '`db1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 -- !query 19 @@ -186,7 +186,7 @@ SELECT mydb1.t1 FROM t1 struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`mydb1.t1`' given input columns: [t1.i1]; line 1 pos 7 -- !query 22 @@ -204,7 +204,7 @@ SELECT t1 FROM mydb1.t1 struct<> -- !query 23 output org.apache.spark.sql.AnalysisException -cannot resolve '`t1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`t1`' given input columns: [t1.i1]; line 1 pos 7 -- !query 24 @@ -221,7 +221,7 @@ SELECT mydb1.t1.i1 FROM t1 struct<> -- !query 25 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 -- !query 26 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out index 616421d6f2b28..7c451c2aa5b5c 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out @@ -105,7 +105,7 @@ SELECT global_temp.view1.i1 FROM global_temp.view1 struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -cannot resolve '`global_temp.view1.i1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`global_temp.view1.i1`' given input columns: [view1.i1]; line 1 pos 7 -- !query 13 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out index 764cad0e3943c..d3ca4443cce55 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out @@ -96,7 +96,7 @@ SELECT mydb1.t1.i1 FROM t1 struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 -- !query 12 @@ -105,7 +105,7 @@ SELECT mydb1.t1.i1 FROM mydb1.t1 struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 -- !query 13 @@ -154,7 +154,7 @@ SELECT mydb1.t1.i1 FROM mydb1.t1 struct<> -- !query 18 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 -- !query 19 @@ -270,7 +270,7 @@ SELECT * FROM mydb1.t3 WHERE c1 IN struct<> -- !query 32 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t4.c3`' given input columns: [c2, c3]; line 2 pos 42 +cannot resolve '`mydb1.t4.c3`' given input columns: [t4.c2, t4.c3]; line 2 pos 42 -- !query 33 @@ -287,7 +287,7 @@ SELECT mydb1.t1.i1 FROM t1, mydb2.t1 struct<> -- !query 34 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 -- !query 35 @@ -296,7 +296,7 @@ SELECT mydb1.t1.i1 FROM mydb1.t1, mydb2.t1 struct<> -- !query 35 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 -- !query 36 @@ -313,7 +313,7 @@ SELECT mydb1.t1.i1 FROM t1, mydb1.t1 struct<> -- !query 37 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 -- !query 38 @@ -402,7 +402,7 @@ SELECT mydb1.t5.t5.i1 FROM mydb1.t5 struct<> -- !query 48 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t5.t5.i1`' given input columns: [i1, t5]; line 1 pos 7 +cannot resolve '`mydb1.t5.t5.i1`' given input columns: [t5.i1, t5.t5]; line 1 pos 7 -- !query 49 @@ -411,7 +411,7 @@ SELECT mydb1.t5.t5.i2 FROM mydb1.t5 struct<> -- !query 49 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t5.t5.i2`' given input columns: [i1, t5]; line 1 pos 7 +cannot resolve '`mydb1.t5.t5.i2`' given input columns: [t5.i1, t5.t5]; line 1 pos 7 -- !query 50 diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 14679850c692e..e23ebd4e822fa 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -202,7 +202,7 @@ SELECT a AS k, COUNT(b) FROM testData GROUP BY k struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve '`k`' given input columns: [a, b]; line 1 pos 47 +cannot resolve '`k`' given input columns: [testdata.a, testdata.b]; line 1 pos 47 -- !query 22 diff --git a/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out b/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out index c318018dced29..7abbcd834a523 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out @@ -60,4 +60,4 @@ SELECT a AS col1, b AS col2 FROM testData AS t(c, d) struct<> -- !query 6 output org.apache.spark.sql.AnalysisException -cannot resolve '`a`' given input columns: [c, d]; line 1 pos 7 +cannot resolve '`a`' given input columns: [t.c, t.d]; line 1 pos 7 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 820cff655c4ff..c0a3b5add313a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -870,9 +870,9 @@ class SubquerySuite extends QueryTest with SharedSQLContext { test("SPARK-20688: correctly check analysis for scalar sub-queries") { withTempView("t") { - Seq(1 -> "a").toDF("i", "j").createTempView("t") + Seq(1 -> "a").toDF("i", "j").createOrReplaceTempView("t") val e = intercept[AnalysisException](sql("SELECT (SELECT count(*) FROM t WHERE a = 1)")) - assert(e.message.contains("cannot resolve '`a`' given input columns: [i, j]")) + assert(e.message.contains("cannot resolve '`a`' given input columns: [t.i, t.j]")) } } } From ce10545d3401c555e56a214b7c2f334274803660 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 5 Jul 2017 11:24:38 +0800 Subject: [PATCH 088/779] [SPARK-21300][SQL] ExternalMapToCatalyst should null-check map key prior to converting to internal value. ## What changes were proposed in this pull request? `ExternalMapToCatalyst` should null-check map key prior to converting to internal value to throw an appropriate Exception instead of something like NPE. ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #18524 from ueshin/issues/SPARK-21300. --- .../spark/sql/catalyst/JavaTypeInference.scala | 1 + .../spark/sql/catalyst/ScalaReflection.scala | 1 + .../catalyst/expressions/objects/objects.scala | 16 +++++++++++++++- .../encoders/ExpressionEncoderSuite.scala | 8 +++++++- 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 7683ee7074e7d..90ec699877dec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -418,6 +418,7 @@ object JavaTypeInference { inputObject, ObjectType(keyType.getRawType), serializerFor(_, keyType), + keyNullable = true, ObjectType(valueType.getRawType), serializerFor(_, valueType), valueNullable = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d580cf4d3391c..f3c1e4150017d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -494,6 +494,7 @@ object ScalaReflection extends ScalaReflection { inputObject, dataTypeFor(keyType), serializerFor(_, keyType, keyPath, seenTypeSet), + keyNullable = !keyType.typeSymbol.asClass.isPrimitive, dataTypeFor(valueType), serializerFor(_, valueType, valuePath, seenTypeSet), valueNullable = !valueType.typeSymbol.asClass.isPrimitive) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 4b651836ff4d2..d6d06aecc077b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -841,18 +841,21 @@ object ExternalMapToCatalyst { inputMap: Expression, keyType: DataType, keyConverter: Expression => Expression, + keyNullable: Boolean, valueType: DataType, valueConverter: Expression => Expression, valueNullable: Boolean): ExternalMapToCatalyst = { val id = curId.getAndIncrement() val keyName = "ExternalMapToCatalyst_key" + id + val keyIsNull = "ExternalMapToCatalyst_key_isNull" + id val valueName = "ExternalMapToCatalyst_value" + id val valueIsNull = "ExternalMapToCatalyst_value_isNull" + id ExternalMapToCatalyst( keyName, + keyIsNull, keyType, - keyConverter(LambdaVariable(keyName, "false", keyType, false)), + keyConverter(LambdaVariable(keyName, keyIsNull, keyType, keyNullable)), valueName, valueIsNull, valueType, @@ -868,6 +871,8 @@ object ExternalMapToCatalyst { * * @param key the name of the map key variable that used when iterate the map, and used as input for * the `keyConverter` + * @param keyIsNull the nullability of the map key variable that used when iterate the map, and + * used as input for the `keyConverter` * @param keyType the data type of the map key variable that used when iterate the map, and used as * input for the `keyConverter` * @param keyConverter A function that take the `key` as input, and converts it to catalyst format. @@ -883,6 +888,7 @@ object ExternalMapToCatalyst { */ case class ExternalMapToCatalyst private( key: String, + keyIsNull: String, keyType: DataType, keyConverter: Expression, value: String, @@ -913,6 +919,7 @@ case class ExternalMapToCatalyst private( val keyElementJavaType = ctx.javaType(keyType) val valueElementJavaType = ctx.javaType(valueType) + ctx.addMutableState("boolean", keyIsNull, "") ctx.addMutableState(keyElementJavaType, key, "") ctx.addMutableState("boolean", valueIsNull, "") ctx.addMutableState(valueElementJavaType, value, "") @@ -950,6 +957,12 @@ case class ExternalMapToCatalyst private( defineEntries -> defineKeyValue } + val keyNullCheck = if (ctx.isPrimitiveType(keyType)) { + s"$keyIsNull = false;" + } else { + s"$keyIsNull = $key == null;" + } + val valueNullCheck = if (ctx.isPrimitiveType(valueType)) { s"$valueIsNull = false;" } else { @@ -972,6 +985,7 @@ case class ExternalMapToCatalyst private( $defineEntries while($entries.hasNext()) { $defineKeyValue + $keyNullCheck $valueNullCheck ${genKeyConverter.code} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 080f11b769388..bb1955a1ae242 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -355,12 +355,18 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { checkNullable[String](true) } - test("null check for map key") { + test("null check for map key: String") { val encoder = ExpressionEncoder[Map[String, Int]]() val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2)))) assert(e.getMessage.contains("Cannot use null as map key")) } + test("null check for map key: Integer") { + val encoder = ExpressionEncoder[Map[Integer, String]]() + val e = intercept[RuntimeException](encoder.toRow(Map((1, "a"), (null, "b")))) + assert(e.getMessage.contains("Cannot use null as map key")) + } + private def encodeDecodeTest[T : ExpressionEncoder]( input: T, testName: String): Unit = { From e9a93f8140c913b91781b35e0e1b051c30244882 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Tue, 4 Jul 2017 21:05:05 -0700 Subject: [PATCH 089/779] [SPARK-20889][SPARKR][FOLLOWUP] Clean up grouped doc for column methods ## What changes were proposed in this pull request? Add doc for methods that were left out, and fix various style and consistency issues. Author: actuaryzhang Closes #18493 from actuaryzhang/sparkRDocCleanup. --- R/pkg/R/functions.R | 100 ++++++++++++++++++++------------------------ R/pkg/R/generics.R | 7 ++-- 2 files changed, 49 insertions(+), 58 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 8c12308c1d7c1..c529d83060f50 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -38,10 +38,10 @@ NULL #' #' Date time functions defined for \code{Column}. #' -#' @param x Column to compute on. +#' @param x Column to compute on. In \code{window}, it must be a time Column of \code{TimestampType}. #' @param format For \code{to_date} and \code{to_timestamp}, it is the string to use to parse -#' x Column to DateType or TimestampType. For \code{trunc}, it is the string used -#' for specifying the truncation method. For example, "year", "yyyy", "yy" for +#' Column \code{x} to DateType or TimestampType. For \code{trunc}, it is the string +#' to use to specify the truncation method. For example, "year", "yyyy", "yy" for #' truncate by year, or "month", "mon", "mm" for truncate by month. #' @param ... additional argument(s). #' @name column_datetime_functions @@ -122,7 +122,7 @@ NULL #' format to. See 'Details'. #' } #' @param y Column to compute on. -#' @param ... additional columns. +#' @param ... additional Columns. #' @name column_string_functions #' @rdname column_string_functions #' @family string functions @@ -167,8 +167,7 @@ NULL #' tmp <- mutate(df, v1 = crc32(df$model), v2 = hash(df$model), #' v3 = hash(df$model, df$mpg), v4 = md5(df$model), #' v5 = sha1(df$model), v6 = sha2(df$model, 256)) -#' head(tmp) -#' } +#' head(tmp)} NULL #' Collection functions for Column operations @@ -190,7 +189,6 @@ NULL #' \dontrun{ #' # Dataframe used throughout this doc #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) -#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) @@ -394,7 +392,7 @@ setMethod("base64", }) #' @details -#' \code{bin}: An expression that returns the string representation of the binary value +#' \code{bin}: Returns the string representation of the binary value #' of the given long column. For example, bin("12") returns "1100". #' #' @rdname column_math_functions @@ -722,7 +720,7 @@ setMethod("dayofyear", #' \code{decode}: Computes the first argument into a string from a binary using the provided #' character set. #' -#' @param charset Character set to use (one of "US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", +#' @param charset character set to use (one of "US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", #' "UTF-16LE", "UTF-16"). #' #' @rdname column_string_functions @@ -855,7 +853,7 @@ setMethod("hex", }) #' @details -#' \code{hour}: Extracts the hours as an integer from a given date/timestamp/string. +#' \code{hour}: Extracts the hour as an integer from a given date/timestamp/string. #' #' @rdname column_datetime_functions #' @aliases hour hour,Column-method @@ -1177,7 +1175,7 @@ setMethod("min", }) #' @details -#' \code{minute}: Extracts the minutes as an integer from a given date/timestamp/string. +#' \code{minute}: Extracts the minute as an integer from a given date/timestamp/string. #' #' @rdname column_datetime_functions #' @aliases minute minute,Column-method @@ -1354,7 +1352,7 @@ setMethod("sd", }) #' @details -#' \code{second}: Extracts the seconds as an integer from a given date/timestamp/string. +#' \code{second}: Extracts the second as an integer from a given date/timestamp/string. #' #' @rdname column_datetime_functions #' @aliases second second,Column-method @@ -1464,20 +1462,18 @@ setMethod("soundex", column(jc) }) -#' Return the partition ID as a column -#' -#' Return the partition ID as a SparkDataFrame column. +#' @details +#' \code{spark_partition_id}: Returns the partition ID as a SparkDataFrame column. #' Note that this is nondeterministic because it depends on data partitioning and #' task scheduling. +#' This is equivalent to the \code{SPARK_PARTITION_ID} function in SQL. #' -#' This is equivalent to the SPARK_PARTITION_ID function in SQL. -#' -#' @rdname spark_partition_id -#' @name spark_partition_id -#' @aliases spark_partition_id,missing-method +#' @rdname column_nonaggregate_functions +#' @aliases spark_partition_id spark_partition_id,missing-method #' @export #' @examples -#' \dontrun{select(df, spark_partition_id())} +#' +#' \dontrun{head(select(df, spark_partition_id()))} #' @note spark_partition_id since 2.0.0 setMethod("spark_partition_id", signature("missing"), @@ -2028,7 +2024,7 @@ setMethod("pmod", signature(y = "Column"), column(jc) }) -#' @param rsd maximum estimation error allowed (default = 0.05) +#' @param rsd maximum estimation error allowed (default = 0.05). #' #' @rdname column_aggregate_functions #' @aliases approxCountDistinct,Column-method @@ -2220,8 +2216,8 @@ setMethod("from_json", signature(x = "Column", schema = "structType"), #' @examples #' #' \dontrun{ -#' tmp <- mutate(df, from_utc = from_utc_timestamp(df$time, 'PST'), -#' to_utc = to_utc_timestamp(df$time, 'PST')) +#' tmp <- mutate(df, from_utc = from_utc_timestamp(df$time, "PST"), +#' to_utc = to_utc_timestamp(df$time, "PST")) #' head(tmp)} #' @note from_utc_timestamp since 1.5.0 setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), @@ -2255,7 +2251,7 @@ setMethod("instr", signature(y = "Column", x = "character"), #' @details #' \code{next_day}: Given a date column, returns the first date which is later than the value of #' the date column that is on the specified day of the week. For example, -#' \code{next_day('2015-07-27', "Sunday")} returns 2015-08-02 because that is the first Sunday +#' \code{next_day("2015-07-27", "Sunday")} returns 2015-08-02 because that is the first Sunday #' after 2015-07-27. Day of the week parameter is case insensitive, and accepts first three or #' two characters: "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". #' @@ -2295,7 +2291,7 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), #' tmp <- mutate(df, t1 = add_months(df$time, 1), #' t2 = date_add(df$time, 2), #' t3 = date_sub(df$time, 3), -#' t4 = next_day(df$time, 'Sun')) +#' t4 = next_day(df$time, "Sun")) #' head(tmp)} #' @note add_months since 1.5.0 setMethod("add_months", signature(y = "Column", x = "numeric"), @@ -2404,8 +2400,8 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), }) #' @details -#' \code{shiftRight}: (Unigned) shifts the given value numBits right. If the given value is a long value, -#' it will return a long value else it will return an integer value. +#' \code{shiftRightUnsigned}: (Unigned) shifts the given value numBits right. If the given value is +#' a long value, it will return a long value else it will return an integer value. #' #' @rdname column_math_functions #' @aliases shiftRightUnsigned shiftRightUnsigned,Column,numeric-method @@ -2513,14 +2509,13 @@ setMethod("from_unixtime", signature(x = "Column"), column(jc) }) -#' window -#' -#' Bucketize rows into one or more time windows given a timestamp specifying column. Window -#' starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window +#' @details +#' \code{window}: Bucketizes rows into one or more time windows given a timestamp specifying column. +#' Window starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window #' [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in -#' the order of months are not supported. +#' the order of months are not supported. It returns an output column of struct called 'window' +#' by default with the nested columns 'start' and 'end' #' -#' @param x a time Column. Must be of TimestampType. #' @param windowDuration a string specifying the width of the window, e.g. '1 second', #' '1 day 12 hours', '2 minutes'. Valid interval strings are 'week', #' 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. Note that @@ -2536,27 +2531,22 @@ setMethod("from_unixtime", signature(x = "Column"), #' window intervals. For example, in order to have hourly tumbling windows #' that start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide #' \code{startTime} as \code{"15 minutes"}. -#' @param ... further arguments to be passed to or from other methods. -#' @return An output column of struct called 'window' by default with the nested columns 'start' -#' and 'end'. -#' @family date time functions -#' @rdname window -#' @name window -#' @aliases window,Column-method +#' @rdname column_datetime_functions +#' @aliases window window,Column-method #' @export #' @examples -#'\dontrun{ -#' # One minute windows every 15 seconds 10 seconds after the minute, e.g. 09:00:10-09:01:10, -#' # 09:00:25-09:01:25, 09:00:40-09:01:40, ... -#' window(df$time, "1 minute", "15 seconds", "10 seconds") #' -#' # One minute tumbling windows 15 seconds after the minute, e.g. 09:00:15-09:01:15, -#' # 09:01:15-09:02:15... -#' window(df$time, "1 minute", startTime = "15 seconds") +#' \dontrun{ +#' # One minute windows every 15 seconds 10 seconds after the minute, e.g. 09:00:10-09:01:10, +#' # 09:00:25-09:01:25, 09:00:40-09:01:40, ... +#' window(df$time, "1 minute", "15 seconds", "10 seconds") #' -#' # Thirty-second windows every 10 seconds, e.g. 09:00:00-09:00:30, 09:00:10-09:00:40, ... -#' window(df$time, "30 seconds", "10 seconds") -#'} +#' # One minute tumbling windows 15 seconds after the minute, e.g. 09:00:15-09:01:15, +#' # 09:01:15-09:02:15... +#' window(df$time, "1 minute", startTime = "15 seconds") +#' +#' # Thirty-second windows every 10 seconds, e.g. 09:00:00-09:00:30, 09:00:10-09:00:40, ... +#' window(df$time, "30 seconds", "10 seconds")} #' @note window since 2.0.0 setMethod("window", signature(x = "Column"), function(x, windowDuration, slideDuration = NULL, startTime = NULL) { @@ -3046,7 +3036,7 @@ setMethod("row_number", #' \code{array_contains}: Returns null if the array is null, true if the array contains #' the value, and false otherwise. #' -#' @param value A value to be checked if contained in the column +#' @param value a value to be checked if contained in the column #' @rdname column_collection_functions #' @aliases array_contains array_contains,Column-method #' @export @@ -3091,7 +3081,7 @@ setMethod("size", #' to the natural ordering of the array elements. #' #' @rdname column_collection_functions -#' @param asc A logical flag indicating the sorting order. +#' @param asc a logical flag indicating the sorting order. #' TRUE, sorting is in ascending order. #' FALSE, sorting is in descending order. #' @aliases sort_array sort_array,Column-method @@ -3218,7 +3208,7 @@ setMethod("split_string", #' \code{repeat_string}: Repeats string n times. #' Equivalent to \code{repeat} SQL function. #' -#' @param n Number of repetitions +#' @param n number of repetitions. #' @rdname column_string_functions #' @aliases repeat_string repeat_string,Column-method #' @export @@ -3347,7 +3337,7 @@ setMethod("grouping_bit", #' \code{grouping_id}: Returns the level of grouping. #' Equals to \code{ #' grouping_bit(c1) * 2^(n - 1) + grouping_bit(c2) * 2^(n - 2) + ... + grouping_bit(cn) -#' } +#' }. #' #' @rdname column_aggregate_functions #' @aliases grouping_id grouping_id,Column-method diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index beac18e412736..92098741f72f9 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1418,9 +1418,9 @@ setGeneric("split_string", function(x, pattern) { standardGeneric("split_string" #' @name NULL setGeneric("soundex", function(x) { standardGeneric("soundex") }) -#' @param x empty. Should be used with no argument. -#' @rdname spark_partition_id +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("spark_partition_id", function(x = "missing") { standardGeneric("spark_partition_id") }) #' @rdname column_aggregate_functions @@ -1538,8 +1538,9 @@ setGeneric("var_samp", function(x) { standardGeneric("var_samp") }) #' @name NULL setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) -#' @rdname window +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("window", function(x, ...) { standardGeneric("window") }) #' @rdname column_datetime_functions From f2c3b1dd69423cf52880e0ffa5f673ad6041b40e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 5 Jul 2017 14:17:26 +0800 Subject: [PATCH 090/779] [SPARK-21304][SQL] remove unnecessary isNull variable for collection related encoder expressions ## What changes were proposed in this pull request? For these collection-related encoder expressions, we don't need to create `isNull` variable if the loop element is not nullable. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #18529 from cloud-fan/minor. --- .../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../expressions/objects/objects.scala | 77 ++++++++++++------- 2 files changed, 50 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index f3c1e4150017d..bea0de4d90c2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -335,7 +335,7 @@ object ScalaReflection extends ScalaReflection { // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t - CollectObjectsToMap( + CatalystToExternalMap( p => deserializerFor(keyType, Some(p), walkedTypePath), p => deserializerFor(valueType, Some(p), walkedTypePath), getPath, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index d6d06aecc077b..ce07f4a25c189 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -465,7 +465,11 @@ object MapObjects { customCollectionCls: Option[Class[_]] = None): MapObjects = { val id = curId.getAndIncrement() val loopValue = s"MapObjects_loopValue$id" - val loopIsNull = s"MapObjects_loopIsNull$id" + val loopIsNull = if (elementNullable) { + s"MapObjects_loopIsNull$id" + } else { + "false" + } val loopVar = LambdaVariable(loopValue, loopIsNull, elementType, elementNullable) MapObjects( loopValue, loopIsNull, elementType, function(loopVar), inputData, customCollectionCls) @@ -517,7 +521,6 @@ case class MapObjects private( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val elementJavaType = ctx.javaType(loopVarDataType) - ctx.addMutableState("boolean", loopIsNull, "") ctx.addMutableState(elementJavaType, loopValue, "") val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) @@ -588,12 +591,14 @@ case class MapObjects private( case _ => genFunction.value } - val loopNullCheck = inputDataType match { - case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" - // The element of primitive array will never be null. - case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive => - s"$loopIsNull = false" - case _ => s"$loopIsNull = $loopValue == null;" + val loopNullCheck = if (loopIsNull != "false") { + ctx.addMutableState("boolean", loopIsNull, "") + inputDataType match { + case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" + case _ => s"$loopIsNull = $loopValue == null;" + } + } else { + "" } val (initCollection, addElement, getResult): (String, String => String, String) = @@ -667,11 +672,11 @@ case class MapObjects private( } } -object CollectObjectsToMap { +object CatalystToExternalMap { private val curId = new java.util.concurrent.atomic.AtomicInteger() /** - * Construct an instance of CollectObjectsToMap case class. + * Construct an instance of CatalystToExternalMap case class. * * @param keyFunction The function applied on the key collection elements. * @param valueFunction The function applied on the value collection elements. @@ -682,15 +687,19 @@ object CollectObjectsToMap { keyFunction: Expression => Expression, valueFunction: Expression => Expression, inputData: Expression, - collClass: Class[_]): CollectObjectsToMap = { + collClass: Class[_]): CatalystToExternalMap = { val id = curId.getAndIncrement() - val keyLoopValue = s"CollectObjectsToMap_keyLoopValue$id" + val keyLoopValue = s"CatalystToExternalMap_keyLoopValue$id" val mapType = inputData.dataType.asInstanceOf[MapType] val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false) - val valueLoopValue = s"CollectObjectsToMap_valueLoopValue$id" - val valueLoopIsNull = s"CollectObjectsToMap_valueLoopIsNull$id" + val valueLoopValue = s"CatalystToExternalMap_valueLoopValue$id" + val valueLoopIsNull = if (mapType.valueContainsNull) { + s"CatalystToExternalMap_valueLoopIsNull$id" + } else { + "false" + } val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType) - CollectObjectsToMap( + CatalystToExternalMap( keyLoopValue, keyFunction(keyLoopVar), valueLoopValue, valueLoopIsNull, valueFunction(valueLoopVar), inputData, collClass) @@ -716,7 +725,7 @@ object CollectObjectsToMap { * @param inputData An expression that when evaluated returns a map object. * @param collClass The type of the resulting collection. */ -case class CollectObjectsToMap private( +case class CatalystToExternalMap private( keyLoopValue: String, keyLambdaFunction: Expression, valueLoopValue: String, @@ -748,7 +757,6 @@ case class CollectObjectsToMap private( ctx.addMutableState(keyElementJavaType, keyLoopValue, "") val genKeyFunction = keyLambdaFunction.genCode(ctx) val valueElementJavaType = ctx.javaType(mapType.valueType) - ctx.addMutableState("boolean", valueLoopIsNull, "") ctx.addMutableState(valueElementJavaType, valueLoopValue, "") val genValueFunction = valueLambdaFunction.genCode(ctx) val genInputData = inputData.genCode(ctx) @@ -781,7 +789,12 @@ case class CollectObjectsToMap private( val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction) val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) - val valueLoopNullCheck = s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" + val valueLoopNullCheck = if (valueLoopIsNull != "false") { + ctx.addMutableState("boolean", valueLoopIsNull, "") + s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" + } else { + "" + } val builderClass = classOf[Builder[_, _]].getName val constructBuilder = s""" @@ -847,9 +860,17 @@ object ExternalMapToCatalyst { valueNullable: Boolean): ExternalMapToCatalyst = { val id = curId.getAndIncrement() val keyName = "ExternalMapToCatalyst_key" + id - val keyIsNull = "ExternalMapToCatalyst_key_isNull" + id + val keyIsNull = if (keyNullable) { + "ExternalMapToCatalyst_key_isNull" + id + } else { + "false" + } val valueName = "ExternalMapToCatalyst_value" + id - val valueIsNull = "ExternalMapToCatalyst_value_isNull" + id + val valueIsNull = if (valueNullable) { + "ExternalMapToCatalyst_value_isNull" + id + } else { + "false" + } ExternalMapToCatalyst( keyName, @@ -919,9 +940,7 @@ case class ExternalMapToCatalyst private( val keyElementJavaType = ctx.javaType(keyType) val valueElementJavaType = ctx.javaType(valueType) - ctx.addMutableState("boolean", keyIsNull, "") ctx.addMutableState(keyElementJavaType, key, "") - ctx.addMutableState("boolean", valueIsNull, "") ctx.addMutableState(valueElementJavaType, value, "") val (defineEntries, defineKeyValue) = child.dataType match { @@ -957,16 +976,18 @@ case class ExternalMapToCatalyst private( defineEntries -> defineKeyValue } - val keyNullCheck = if (ctx.isPrimitiveType(keyType)) { - s"$keyIsNull = false;" - } else { + val keyNullCheck = if (keyIsNull != "false") { + ctx.addMutableState("boolean", keyIsNull, "") s"$keyIsNull = $key == null;" + } else { + "" } - val valueNullCheck = if (ctx.isPrimitiveType(valueType)) { - s"$valueIsNull = false;" - } else { + val valueNullCheck = if (valueIsNull != "false") { + ctx.addMutableState("boolean", valueIsNull, "") s"$valueIsNull = $value == null;" + } else { + "" } val arrayCls = classOf[GenericArrayData].getName From a38643256691947ff7f7c474b85c052a7d5d8553 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 5 Jul 2017 14:25:26 +0800 Subject: [PATCH 091/779] [SPARK-18623][SQL] Add `returnNullable` to `StaticInvoke` and modify it to handle properly. ## What changes were proposed in this pull request? Add `returnNullable` to `StaticInvoke` the same as #15780 is trying to add to `Invoke` and modify to handle properly. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Author: Takuya UESHIN Closes #16056 from ueshin/issues/SPARK-18623. --- .../sql/catalyst/JavaTypeInference.scala | 21 +++++---- .../spark/sql/catalyst/ScalaReflection.scala | 44 ++++++++++++------- .../sql/catalyst/encoders/RowEncoder.scala | 27 ++++++++---- .../expressions/objects/objects.scala | 42 ++++++++++++++---- 4 files changed, 91 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 90ec699877dec..21363d3ba82c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -216,7 +216,7 @@ object JavaTypeInference { ObjectType(c), "valueOf", getPath :: Nil, - propagateNull = true) + returnNullable = false) case c if c == classOf[java.sql.Date] => StaticInvoke( @@ -224,7 +224,7 @@ object JavaTypeInference { ObjectType(c), "toJavaDate", getPath :: Nil, - propagateNull = true) + returnNullable = false) case c if c == classOf[java.sql.Timestamp] => StaticInvoke( @@ -232,7 +232,7 @@ object JavaTypeInference { ObjectType(c), "toJavaTimestamp", getPath :: Nil, - propagateNull = true) + returnNullable = false) case c if c == classOf[java.lang.String] => Invoke(getPath, "toString", ObjectType(classOf[String])) @@ -300,7 +300,8 @@ object JavaTypeInference { ArrayBasedMapData.getClass, ObjectType(classOf[JMap[_, _]]), "toJavaMap", - keyData :: valueData :: Nil) + keyData :: valueData :: Nil, + returnNullable = false) case other => val properties = getJavaBeanReadableAndWritableProperties(other) @@ -367,28 +368,32 @@ object JavaTypeInference { classOf[UTF8String], StringType, "fromString", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case c if c == classOf[java.sql.Timestamp] => StaticInvoke( DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case c if c == classOf[java.sql.Date] => StaticInvoke( DateTimeUtils.getClass, DateType, "fromJavaDate", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case c if c == classOf[java.math.BigDecimal] => StaticInvoke( Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case c if c == classOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index bea0de4d90c2f..814f2c10b9097 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -206,51 +206,53 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Long] => val boxedType = classOf[java.lang.Long] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Double] => val boxedType = classOf[java.lang.Double] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Float] => val boxedType = classOf[java.lang.Float] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Short] => val boxedType = classOf[java.lang.Short] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Byte] => val boxedType = classOf[java.lang.Byte] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Boolean] => val boxedType = classOf[java.lang.Boolean] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", - getPath :: Nil) + getPath :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", - getPath :: Nil) + getPath :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.lang.String] => Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) @@ -446,7 +448,8 @@ object ScalaReflection extends ScalaReflection { classOf[UnsafeArrayData], ArrayType(dt, false), "fromPrimitiveArray", - input :: Nil) + input :: Nil, + returnNullable = false) } else { NewInstance( classOf[GenericArrayData], @@ -504,49 +507,56 @@ object ScalaReflection extends ScalaReflection { classOf[UTF8String], StringType, "fromString", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( DateTimeUtils.getClass, DateType, "fromJavaDate", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[BigDecimal] => StaticInvoke( Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.math.BigDecimal] => StaticInvoke( Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.math.BigInteger] => StaticInvoke( Decimal.getClass, DecimalType.BigIntDecimal, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[scala.math.BigInt] => StaticInvoke( Decimal.getClass, DecimalType.BigIntDecimal, "apply", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t if t <:< localTypeOf[java.lang.Integer] => Invoke(inputObject, "intValue", IntegerType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 0f8282d3b2f1f..cc32fac67e924 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -96,28 +96,32 @@ object RowEncoder { DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case DateType => StaticInvoke( DateTimeUtils.getClass, DateType, "fromJavaDate", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case d: DecimalType => StaticInvoke( Decimal.getClass, d, "fromDecimal", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case StringType => StaticInvoke( classOf[UTF8String], StringType, "fromString", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case t @ ArrayType(et, cn) => et match { @@ -126,7 +130,8 @@ object RowEncoder { classOf[ArrayData], t, "toArrayData", - inputObject :: Nil) + inputObject :: Nil, + returnNullable = false) case _ => MapObjects( element => serializerFor(ValidateExternalType(element, et), et), inputObject, @@ -254,14 +259,16 @@ object RowEncoder { DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", - input :: Nil) + input :: Nil, + returnNullable = false) case DateType => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", - input :: Nil) + input :: Nil, + returnNullable = false) case _: DecimalType => Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), @@ -280,7 +287,8 @@ object RowEncoder { scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), "make", - arrayData :: Nil) + arrayData :: Nil, + returnNullable = false) case MapType(kt, vt, valueNullable) => val keyArrayType = ArrayType(kt, false) @@ -293,7 +301,8 @@ object RowEncoder { ArrayBasedMapData.getClass, ObjectType(classOf[Map[_, _]]), "toScalaMap", - keyData :: valueData :: Nil) + keyData :: valueData :: Nil, + returnNullable = false) case schema @ StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index ce07f4a25c189..24c06d8b14b54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -118,17 +118,20 @@ trait InvokeLike extends Expression with NonSQLExpression { * @param arguments An optional list of expressions to pass as arguments to the function. * @param propagateNull When true, and any of the arguments is null, null will be returned instead * of calling the function. + * @param returnNullable When false, indicating the invoked method will always return + * non-null value. */ case class StaticInvoke( staticObject: Class[_], dataType: DataType, functionName: String, arguments: Seq[Expression] = Nil, - propagateNull: Boolean = true) extends InvokeLike { + propagateNull: Boolean = true, + returnNullable: Boolean = true) extends InvokeLike { val objectName = staticObject.getName.stripSuffix("$") - override def nullable: Boolean = true + override def nullable: Boolean = needNullCheck || returnNullable override def children: Seq[Expression] = arguments override def eval(input: InternalRow): Any = @@ -141,19 +144,40 @@ case class StaticInvoke( val callFunc = s"$objectName.$functionName($argString)" - // If the function can return null, we do an extra check to make sure our null bit is still set - // correctly. - val postNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" + val prepareIsNull = if (nullable) { + s"boolean ${ev.isNull} = $resultIsNull;" } else { + ev.isNull = "false" "" } + val evaluate = if (returnNullable) { + if (ctx.defaultValue(dataType) == "null") { + s""" + ${ev.value} = $callFunc; + ${ev.isNull} = ${ev.value} == null; + """ + } else { + val boxedResult = ctx.freshName("boxedResult") + s""" + ${ctx.boxedType(dataType)} $boxedResult = $callFunc; + ${ev.isNull} = $boxedResult == null; + if (!${ev.isNull}) { + ${ev.value} = $boxedResult; + } + """ + } + } else { + s"${ev.value} = $callFunc;" + } + val code = s""" $argCode - boolean ${ev.isNull} = $resultIsNull; - final $javaType ${ev.value} = $resultIsNull ? ${ctx.defaultValue(dataType)} : $callFunc; - $postNullCheck + $prepareIsNull + $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!$resultIsNull) { + $evaluate + } """ ev.copy(code = code) } From 4852b7d447e872079c2c81428354adc825a87b27 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Wed, 5 Jul 2017 18:41:00 +0800 Subject: [PATCH 092/779] [SPARK-21310][ML][PYSPARK] Expose offset in PySpark ## What changes were proposed in this pull request? Add offset to PySpark in GLM as in #16699. ## How was this patch tested? Python test Author: actuaryzhang Closes #18534 from actuaryzhang/pythonOffset. --- python/pyspark/ml/regression.py | 25 +++++++++++++++++++++---- python/pyspark/ml/tests.py | 14 ++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 84d843369e105..f0ff7a5f59abf 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1376,17 +1376,20 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha typeConverter=TypeConverters.toFloat) solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " + "options: irls.", typeConverter=TypeConverters.toString) + offsetCol = Param(Params._dummy(), "offsetCol", "The offset column name. If this is not set " + + "or empty, we treat all instance offsets as 0.0", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, - variancePower=0.0, linkPower=None): + variancePower=0.0, linkPower=None, offsetCol=None): """ __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \ - variancePower=0.0, linkPower=None) + variancePower=0.0, linkPower=None, offsetCol=None) """ super(GeneralizedLinearRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -1402,12 +1405,12 @@ def __init__(self, labelCol="label", featuresCol="features", predictionCol="pred def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, - variancePower=0.0, linkPower=None): + variancePower=0.0, linkPower=None, offsetCol=None): """ setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \ - variancePower=0.0, linkPower=None) + variancePower=0.0, linkPower=None, offsetCol=None) Sets params for generalized linear regression. """ kwargs = self._input_kwargs @@ -1486,6 +1489,20 @@ def getLinkPower(self): """ return self.getOrDefault(self.linkPower) + @since("2.3.0") + def setOffsetCol(self, value): + """ + Sets the value of :py:attr:`offsetCol`. + """ + return self._set(offsetCol=value) + + @since("2.3.0") + def getOffsetCol(self): + """ + Gets the value of offsetCol or its default value. + """ + return self.getOrDefault(self.offsetCol) + class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index ffb8b0a890ff8..7870047651601 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1291,6 +1291,20 @@ def test_tweedie_distribution(self): self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4)) self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4)) + def test_offset(self): + + df = self.spark.createDataFrame( + [(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)), + (0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)), + (0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)), + (0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0))], ["label", "weight", "offset", "features"]) + + glr = GeneralizedLinearRegression(family="poisson", weightCol="weight", offsetCol="offset") + model = glr.fit(df) + self.assertTrue(np.allclose(model.coefficients.toArray(), [0.664647, -0.3192581], + atol=1E-4)) + self.assertTrue(np.isclose(model.intercept, -1.561613, atol=1E-4)) + class FPGrowthTests(SparkSessionTestCase): def setUp(self): From 873f3ad2b89c955f42fced49dc129e8efa77d044 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 5 Jul 2017 20:32:47 +0800 Subject: [PATCH 093/779] [SPARK-16167][SQL] RowEncoder should preserve array/map type nullability. ## What changes were proposed in this pull request? Currently `RowEncoder` doesn't preserve nullability of `ArrayType` or `MapType`. It returns always `containsNull = true` for `ArrayType`, `valueContainsNull = true` for `MapType` and also the nullability of itself is always `true`. This pr fixes the nullability of them. ## How was this patch tested? Add tests to check if `RowEncoder` preserves array/map nullability. Author: Takuya UESHIN Author: Takuya UESHIN Closes #13873 from ueshin/issues/SPARK-16167. --- .../sql/catalyst/encoders/RowEncoder.scala | 25 +++++++++++--- .../catalyst/encoders/RowEncoderSuite.scala | 33 +++++++++++++++++++ 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index cc32fac67e924..43c35bbdf383a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -123,7 +123,7 @@ object RowEncoder { inputObject :: Nil, returnNullable = false) - case t @ ArrayType(et, cn) => + case t @ ArrayType(et, containsNull) => et match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => StaticInvoke( @@ -132,8 +132,16 @@ object RowEncoder { "toArrayData", inputObject :: Nil, returnNullable = false) + case _ => MapObjects( - element => serializerFor(ValidateExternalType(element, et), et), + element => { + val value = serializerFor(ValidateExternalType(element, et), et) + if (!containsNull) { + AssertNotNull(value, Seq.empty) + } else { + value + } + }, inputObject, ObjectType(classOf[Object])) } @@ -155,10 +163,19 @@ object RowEncoder { ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) val convertedValues = serializerFor(values, ArrayType(vt, valueNullable)) - NewInstance( + val nonNullOutput = NewInstance( classOf[ArrayBasedMapData], convertedKeys :: convertedValues :: Nil, - dataType = t) + dataType = t, + propagateNull = false) + + if (inputObject.nullable) { + If(IsNull(inputObject), + Literal.create(null, inputType), + nonNullOutput) + } else { + nonNullOutput + } case StructType(fields) => val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 1a5569a77dc7a..6ed175f86ca77 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -273,6 +273,39 @@ class RowEncoderSuite extends SparkFunSuite { assert(e4.getMessage.contains("java.lang.String is not a valid external type")) } + for { + elementType <- Seq(IntegerType, StringType) + containsNull <- Seq(true, false) + nullable <- Seq(true, false) + } { + test("RowEncoder should preserve array nullability: " + + s"ArrayType($elementType, containsNull = $containsNull), nullable = $nullable") { + val schema = new StructType().add("array", ArrayType(elementType, containsNull), nullable) + val encoder = RowEncoder(schema).resolveAndBind() + assert(encoder.serializer.length == 1) + assert(encoder.serializer.head.dataType == ArrayType(elementType, containsNull)) + assert(encoder.serializer.head.nullable == nullable) + } + } + + for { + keyType <- Seq(IntegerType, StringType) + valueType <- Seq(IntegerType, StringType) + valueContainsNull <- Seq(true, false) + nullable <- Seq(true, false) + } { + test("RowEncoder should preserve map nullability: " + + s"MapType($keyType, $valueType, valueContainsNull = $valueContainsNull), " + + s"nullable = $nullable") { + val schema = new StructType().add( + "map", MapType(keyType, valueType, valueContainsNull), nullable) + val encoder = RowEncoder(schema).resolveAndBind() + assert(encoder.serializer.length == 1) + assert(encoder.serializer.head.dataType == MapType(keyType, valueType, valueContainsNull)) + assert(encoder.serializer.head.nullable == nullable) + } + } + private def encodeDecodeTest(schema: StructType): Unit = { test(s"encode/decode: ${schema.simpleString}") { val encoder = RowEncoder(schema).resolveAndBind() From 5787ace463b2abde50d2ca24e8dd111e3a7c158e Mon Sep 17 00:00:00 2001 From: ouyangxiaochen Date: Wed, 5 Jul 2017 20:46:42 +0800 Subject: [PATCH 094/779] [SPARK-20383][SQL] Supporting Create [temporary] Function with the keyword 'OR REPLACE' and 'IF NOT EXISTS' ## What changes were proposed in this pull request? support to create [temporary] function with the keyword 'OR REPLACE' and 'IF NOT EXISTS' ## How was this patch tested? manual test and added test cases Please review http://spark.apache.org/contributing.html before opening a pull request. Author: ouyangxiaochen Closes #17681 from ouyangxiaochen/spark-419. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 3 +- .../catalyst/catalog/ExternalCatalog.scala | 9 ++++ .../catalyst/catalog/InMemoryCatalog.scala | 6 +++ .../sql/catalyst/catalog/SessionCatalog.scala | 23 ++++++++ .../spark/sql/catalyst/catalog/events.scala | 10 ++++ .../catalog/ExternalCatalogEventSuite.scala | 9 ++++ .../catalog/ExternalCatalogSuite.scala | 9 ++++ .../spark/sql/execution/SparkSqlParser.scala | 8 +-- .../sql/execution/command/functions.scala | 46 +++++++++++----- .../execution/command/DDLCommandSuite.scala | 52 ++++++++++++++++++- .../sql/execution/command/DDLSuite.scala | 51 ++++++++++++++++++ .../spark/sql/hive/HiveExternalCatalog.scala | 9 ++++ 12 files changed, 216 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 29f554451ed4a..ef9f88a9026c9 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -126,7 +126,8 @@ statement tableIdentifier ('(' colTypeList ')')? tableProvider (OPTIONS tablePropertyList)? #createTempViewUsing | ALTER VIEW tableIdentifier AS? query #alterViewQuery - | CREATE TEMPORARY? FUNCTION qualifiedName AS className=STRING + | CREATE (OR REPLACE)? TEMPORARY? FUNCTION (IF NOT EXISTS)? + qualifiedName AS className=STRING (USING resource (',' resource)*)? #createFunction | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 0254b6bb6d136..6000d483db209 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -332,6 +332,15 @@ abstract class ExternalCatalog protected def doDropFunction(db: String, funcName: String): Unit + final def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(AlterFunctionPreEvent(db, name)) + doAlterFunction(db, funcDefinition) + postToAll(AlterFunctionEvent(db, name)) + } + + protected def doAlterFunction(db: String, funcDefinition: CatalogFunction): Unit + final def renameFunction(db: String, oldName: String, newName: String): Unit = { postToAll(RenameFunctionPreEvent(db, oldName, newName)) doRenameFunction(db, oldName, newName) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 747190faa3c8c..d253c72a62739 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -590,6 +590,12 @@ class InMemoryCatalog( catalog(db).functions.remove(funcName) } + override protected def doAlterFunction(db: String, func: CatalogFunction): Unit = synchronized { + requireDbExists(db) + requireFunctionExists(db, func.identifier.funcName) + catalog(db).functions.put(func.identifier.funcName, func) + } + override protected def doRenameFunction( db: String, oldName: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index a86604e4353ab..c40d5f6031a21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1055,6 +1055,29 @@ class SessionCatalog( } } + /** + * overwirte a metastore function in the database specified in `funcDefinition`.. + * If no database is specified, assume the function is in the current database. + */ + def alterFunction(funcDefinition: CatalogFunction): Unit = { + val db = formatDatabaseName(funcDefinition.identifier.database.getOrElse(getCurrentDatabase)) + requireDbExists(db) + val identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db)) + val newFuncDefinition = funcDefinition.copy(identifier = identifier) + if (functionExists(identifier)) { + if (functionRegistry.functionExists(identifier)) { + // If we have loaded this function into the FunctionRegistry, + // also drop it from there. + // For a permanent function, because we loaded it to the FunctionRegistry + // when it's first used, we also need to drop it from the FunctionRegistry. + functionRegistry.dropFunction(identifier) + } + externalCatalog.alterFunction(db, newFuncDefinition) + } else { + throw new NoSuchFunctionException(db = db, func = identifier.toString) + } + } + /** * Retrieve the metadata of a metastore function. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala index 459973a13bb10..742a51e640383 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala @@ -139,6 +139,16 @@ case class DropFunctionPreEvent(database: String, name: String) extends Function */ case class DropFunctionEvent(database: String, name: String) extends FunctionEvent +/** + * Event fired before a function is altered. + */ +case class AlterFunctionPreEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired after a function has been altered. + */ +case class AlterFunctionEvent(database: String, name: String) extends FunctionEvent + /** * Event fired before a function is renamed. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala index 2539ea615ff92..087c26f23f383 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala @@ -176,6 +176,15 @@ class ExternalCatalogEventSuite extends SparkFunSuite { } checkEvents(RenameFunctionPreEvent("db5", "fn7", "fn4") :: Nil) + // ALTER + val alteredFunctionDefinition = CatalogFunction( + identifier = FunctionIdentifier("fn4", Some("db5")), + className = "org.apache.spark.AlterFunction", + resources = Seq.empty) + catalog.alterFunction("db5", alteredFunctionDefinition) + checkEvents( + AlterFunctionPreEvent("db5", "fn4") :: AlterFunctionEvent("db5", "fn4") :: Nil) + // DROP intercept[AnalysisException] { catalog.dropFunction("db5", "fn7") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index c22d55fc96a65..66e895a4690c1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -752,6 +752,14 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac } } + test("alter function") { + val catalog = newBasicCatalog() + assert(catalog.getFunction("db2", "func1").className == funcClass) + val myNewFunc = catalog.getFunction("db2", "func1").copy(className = newFuncClass) + catalog.alterFunction("db2", myNewFunc) + assert(catalog.getFunction("db2", "func1").className == newFuncClass) + } + test("list functions") { val catalog = newBasicCatalog() catalog.createFunction("db2", newFunc("func2")) @@ -916,6 +924,7 @@ abstract class CatalogTestUtils { lazy val partWithEmptyValue = CatalogTablePartition(Map("a" -> "3", "b" -> ""), storageFormat) lazy val funcClass = "org.apache.spark.myFunc" + lazy val newFuncClass = "org.apache.spark.myNewFunc" /** * Creates a basic catalog, with the following structure: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 2b79eb5eac0f1..2f8e416e7df1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -687,8 +687,8 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * * For example: * {{{ - * CREATE [TEMPORARY] FUNCTION [db_name.]function_name AS class_name - * [USING JAR|FILE|ARCHIVE 'file_uri' [, JAR|FILE|ARCHIVE 'file_uri']]; + * CREATE [OR REPLACE] [TEMPORARY] FUNCTION [IF NOT EXISTS] [db_name.]function_name + * AS class_name [USING JAR|FILE|ARCHIVE 'file_uri' [, JAR|FILE|ARCHIVE 'file_uri']]; * }}} */ override def visitCreateFunction(ctx: CreateFunctionContext): LogicalPlan = withOrigin(ctx) { @@ -709,7 +709,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { functionIdentifier.funcName, string(ctx.className), resources, - ctx.TEMPORARY != null) + ctx.TEMPORARY != null, + ctx.EXISTS != null, + ctx.REPLACE != null) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index a91ad413f4d1b..4f92ffee687aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -31,13 +31,13 @@ import org.apache.spark.sql.types.{StringType, StructField, StructType} * The DDL command that creates a function. * To create a temporary function, the syntax of using this command in SQL is: * {{{ - * CREATE TEMPORARY FUNCTION functionName + * CREATE [OR REPLACE] TEMPORARY FUNCTION functionName * AS className [USING JAR\FILE 'uri' [, JAR|FILE 'uri']] * }}} * * To create a permanent function, the syntax in SQL is: * {{{ - * CREATE FUNCTION [databaseName.]functionName + * CREATE [OR REPLACE] FUNCTION [IF NOT EXISTS] [databaseName.]functionName * AS className [USING JAR\FILE 'uri' [, JAR|FILE 'uri']] * }}} */ @@ -46,26 +46,46 @@ case class CreateFunctionCommand( functionName: String, className: String, resources: Seq[FunctionResource], - isTemp: Boolean) + isTemp: Boolean, + ifNotExists: Boolean, + replace: Boolean) extends RunnableCommand { + if (ifNotExists && replace) { + throw new AnalysisException("CREATE FUNCTION with both IF NOT EXISTS and REPLACE" + + " is not allowed.") + } + + // Disallow to define a temporary function with `IF NOT EXISTS` + if (ifNotExists && isTemp) { + throw new AnalysisException( + "It is not allowed to define a TEMPORARY function with IF NOT EXISTS.") + } + + // Temporary function names should not contain database prefix like "database.function" + if (databaseName.isDefined && isTemp) { + throw new AnalysisException(s"Specifying a database in CREATE TEMPORARY FUNCTION " + + s"is not allowed: '${databaseName.get}'") + } + override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog val func = CatalogFunction(FunctionIdentifier(functionName, databaseName), className, resources) if (isTemp) { - if (databaseName.isDefined) { - throw new AnalysisException(s"Specifying a database in CREATE TEMPORARY FUNCTION " + - s"is not allowed: '${databaseName.get}'") - } // We first load resources and then put the builder in the function registry. catalog.loadFunctionResources(resources) - catalog.registerFunction(func, overrideIfExists = false) + catalog.registerFunction(func, overrideIfExists = replace) } else { - // For a permanent, we will store the metadata into underlying external catalog. - // This function will be loaded into the FunctionRegistry when a query uses it. - // We do not load it into FunctionRegistry right now. - // TODO: should we also parse "IF NOT EXISTS"? - catalog.createFunction(func, ignoreIfExists = false) + // Handles `CREATE OR REPLACE FUNCTION AS ... USING ...` + if (replace && catalog.functionExists(func.identifier)) { + // alter the function in the metastore + catalog.alterFunction(CatalogFunction(func.identifier, className, resources)) + } else { + // For a permanent, we will store the metadata into underlying external catalog. + // This function will be loaded into the FunctionRegistry when a query uses it. + // We do not load it into FunctionRegistry right now. + catalog.createFunction(CatalogFunction(func.identifier, className, resources), ifNotExists) + } } Seq.empty[Row] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 8a6bc62fec96c..5643c58d9f847 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -181,8 +181,29 @@ class DDLCommandSuite extends PlanTest { |'com.matthewrathbone.example.SimpleUDFExample' USING ARCHIVE '/path/to/archive', |FILE '/path/to/file' """.stripMargin + val sql3 = + """ + |CREATE OR REPLACE TEMPORARY FUNCTION helloworld3 as + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + """.stripMargin + val sql4 = + """ + |CREATE OR REPLACE FUNCTION hello.world1 as + |'com.matthewrathbone.example.SimpleUDFExample' USING ARCHIVE '/path/to/archive', + |FILE '/path/to/file' + """.stripMargin + val sql5 = + """ + |CREATE FUNCTION IF NOT EXISTS hello.world2 as + |'com.matthewrathbone.example.SimpleUDFExample' USING ARCHIVE '/path/to/archive', + |FILE '/path/to/file' + """.stripMargin val parsed1 = parser.parsePlan(sql1) val parsed2 = parser.parsePlan(sql2) + val parsed3 = parser.parsePlan(sql3) + val parsed4 = parser.parsePlan(sql4) + val parsed5 = parser.parsePlan(sql5) val expected1 = CreateFunctionCommand( None, "helloworld", @@ -190,7 +211,7 @@ class DDLCommandSuite extends PlanTest { Seq( FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar1"), FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar2")), - isTemp = true) + isTemp = true, ifNotExists = false, replace = false) val expected2 = CreateFunctionCommand( Some("hello"), "world", @@ -198,9 +219,36 @@ class DDLCommandSuite extends PlanTest { Seq( FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), - isTemp = false) + isTemp = false, ifNotExists = false, replace = false) + val expected3 = CreateFunctionCommand( + None, + "helloworld3", + "com.matthewrathbone.example.SimpleUDFExample", + Seq( + FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar1"), + FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar2")), + isTemp = true, ifNotExists = false, replace = true) + val expected4 = CreateFunctionCommand( + Some("hello"), + "world1", + "com.matthewrathbone.example.SimpleUDFExample", + Seq( + FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), + FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), + isTemp = false, ifNotExists = false, replace = true) + val expected5 = CreateFunctionCommand( + Some("hello"), + "world2", + "com.matthewrathbone.example.SimpleUDFExample", + Seq( + FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), + FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), + isTemp = false, ifNotExists = true, replace = false) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) + comparePlans(parsed3, expected3) + comparePlans(parsed4, expected4) + comparePlans(parsed5, expected5) } test("drop function") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index e4dd077715d0f..5c40d8bb4b1ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2270,6 +2270,57 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("create temporary function with if not exists") { + withUserDefinedFunction("func1" -> true) { + val sql1 = + """ + |CREATE TEMPORARY FUNCTION IF NOT EXISTS func1 as + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + """.stripMargin + val e = intercept[AnalysisException] { + sql(sql1) + }.getMessage + assert(e.contains("It is not allowed to define a TEMPORARY function with IF NOT EXISTS")) + } + } + + test("create function with both if not exists and replace") { + withUserDefinedFunction("func1" -> false) { + val sql1 = + """ + |CREATE OR REPLACE FUNCTION IF NOT EXISTS func1 as + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + """.stripMargin + val e = intercept[AnalysisException] { + sql(sql1) + }.getMessage + assert(e.contains("CREATE FUNCTION with both IF NOT EXISTS and REPLACE is not allowed")) + } + } + + test("create temporary function by specifying a database") { + val dbName = "mydb" + withDatabase(dbName) { + sql(s"CREATE DATABASE $dbName") + sql(s"USE $dbName") + withUserDefinedFunction("func1" -> true) { + val sql1 = + s""" + |CREATE TEMPORARY FUNCTION $dbName.func1 as + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + """.stripMargin + val e = intercept[AnalysisException] { + sql(sql1) + }.getMessage + assert(e.contains(s"Specifying a database in CREATE TEMPORARY FUNCTION " + + s"is not allowed: '$dbName'")) + } + } + } + Seq(true, false).foreach { caseSensitive => test(s"alter table add columns with existing column name - caseSensitive $caseSensitive") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitive") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 2a17849fa8a34..306b38048e3a5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -1132,6 +1132,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.dropFunction(db, name) } + override protected def doAlterFunction( + db: String, funcDefinition: CatalogFunction): Unit = withClient { + requireDbExists(db) + val functionName = funcDefinition.identifier.funcName.toLowerCase(Locale.ROOT) + requireFunctionExists(db, functionName) + val functionIdentifier = funcDefinition.identifier.copy(funcName = functionName) + client.alterFunction(db, funcDefinition.copy(identifier = functionIdentifier)) + } + override protected def doRenameFunction( db: String, oldName: String, From e3e2b5da3671a6c6d152b4de481a8aa3e57a6e42 Mon Sep 17 00:00:00 2001 From: "he.qiao" Date: Wed, 5 Jul 2017 21:13:25 +0800 Subject: [PATCH 095/779] [SPARK-21286][TEST] Modified StorageTabSuite unit test ## What changes were proposed in this pull request? The old unit test not effect ## How was this patch tested? unit test Author: he.qiao Closes #18511 from Geek-He/dev_0703. --- .../scala/org/apache/spark/ui/storage/StorageTabSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index 66dda382eb653..1cb52593e7060 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -74,7 +74,7 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { // Submitting RDDInfos with duplicate IDs does nothing val rddInfo0Cached = new RDDInfo(0, "freedom", 100, StorageLevel.MEMORY_ONLY, Seq(10)) rddInfo0Cached.numCachedPartitions = 1 - val stageInfo0Cached = new StageInfo(0, 0, "0", 100, Seq(rddInfo0), Seq.empty, "details") + val stageInfo0Cached = new StageInfo(0, 0, "0", 100, Seq(rddInfo0Cached), Seq.empty, "details") bus.postToAll(SparkListenerStageSubmitted(stageInfo0Cached)) assert(storageListener._rddInfoMap.size === 4) assert(storageListener.rddInfoList.size === 2) From 960298ee66b9b8a80f84df679ce5b4b3846267f4 Mon Sep 17 00:00:00 2001 From: sadikovi Date: Wed, 5 Jul 2017 14:40:44 +0100 Subject: [PATCH 096/779] [SPARK-20858][DOC][MINOR] Document ListenerBus event queue size ## What changes were proposed in this pull request? This change adds a new configuration option `spark.scheduler.listenerbus.eventqueue.size` to the configuration docs to specify the capacity of the spark listener bus event queue. Default value is 10000. This is doc PR for [SPARK-15703](https://issues.apache.org/jira/browse/SPARK-15703). I added option to the `Scheduling` section, however it might be more related to `Spark UI` section. ## How was this patch tested? Manually verified correct rendering of configuration option. Author: sadikovi Author: Ivan Sadikov Closes #18476 from sadikovi/SPARK-20858. --- docs/configuration.md | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index bd6a1f9e240e2..c785a664c67b1 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -725,7 +725,7 @@ Apart from these, the following properties are also available, and may be useful spark.ui.retainedJobs 1000 - How many jobs the Spark UI and status APIs remember before garbage collecting. + How many jobs the Spark UI and status APIs remember before garbage collecting. This is a target maximum, and fewer elements may be retained in some circumstances. @@ -733,7 +733,7 @@ Apart from these, the following properties are also available, and may be useful spark.ui.retainedStages 1000 - How many stages the Spark UI and status APIs remember before garbage collecting. + How many stages the Spark UI and status APIs remember before garbage collecting. This is a target maximum, and fewer elements may be retained in some circumstances. @@ -741,7 +741,7 @@ Apart from these, the following properties are also available, and may be useful spark.ui.retainedTasks 100000 - How many tasks the Spark UI and status APIs remember before garbage collecting. + How many tasks the Spark UI and status APIs remember before garbage collecting. This is a target maximum, and fewer elements may be retained in some circumstances. @@ -1389,6 +1389,15 @@ Apart from these, the following properties are also available, and may be useful The interval length for the scheduler to revive the worker resource offers to run tasks. + + spark.scheduler.listenerbus.eventqueue.capacity + 10000 + + Capacity for event queue in Spark listener bus, must be greater than 0. Consider increasing + value (e.g. 20000) if listener events are dropped. Increasing this value may result in the + driver using more memory. + + spark.blacklist.enabled @@ -1475,8 +1484,8 @@ Apart from these, the following properties are also available, and may be useful spark.blacklist.application.fetchFailure.enabled false - (Experimental) If set to "true", Spark will blacklist the executor immediately when a fetch - failure happenes. If external shuffle service is enabled, then the whole node will be + (Experimental) If set to "true", Spark will blacklist the executor immediately when a fetch + failure happenes. If external shuffle service is enabled, then the whole node will be blacklisted. From 742da0868534dab3d4d7b7edbe5ba9dc8bf26cc8 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 5 Jul 2017 10:59:10 -0700 Subject: [PATCH 097/779] [SPARK-19439][PYSPARK][SQL] PySpark's registerJavaFunction Should Support UDAFs ## What changes were proposed in this pull request? Support register Java UDAFs in PySpark so that user can use Java UDAF in PySpark. Besides that I also add api in `UDFRegistration` ## How was this patch tested? Unit test is added Author: Jeff Zhang Closes #17222 from zjffdu/SPARK-19439. --- python/pyspark/sql/context.py | 23 ++++++++ python/pyspark/sql/tests.py | 10 ++++ .../apache/spark/sql/UDFRegistration.scala | 33 +++++++++-- .../org/apache/spark/sql/JavaUDAFSuite.java | 55 +++++++++++++++++++ .../org/apache/spark/sql}/MyDoubleAvg.java | 2 +- .../org/apache/spark/sql}/MyDoubleSum.java | 8 +-- sql/hive/pom.xml | 7 +++ .../spark/sql/hive/JavaDataFrameSuite.java | 2 +- .../execution/AggregationQuerySuite.scala | 5 +- 9 files changed, 132 insertions(+), 13 deletions(-) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java rename sql/{hive/src/test/java/org/apache/spark/sql/hive/aggregate => core/src/test/java/test/org/apache/spark/sql}/MyDoubleAvg.java (99%) rename sql/{hive/src/test/java/org/apache/spark/sql/hive/aggregate => core/src/test/java/test/org/apache/spark/sql}/MyDoubleSum.java (98%) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 426f07cd9410d..c44ab247fd3d3 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -232,6 +232,23 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) + @ignore_unicode_prefix + @since(2.3) + def registerJavaUDAF(self, name, javaClassName): + """Register a java UDAF so it can be used in SQL statements. + + :param name: name of the UDAF + :param javaClassName: fully qualified name of java class + + >>> sqlContext.registerJavaUDAF("javaUDAF", + ... "test.org.apache.spark.sql.MyDoubleAvg") + >>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) + >>> df.registerTempTable("df") + >>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() + [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] + """ + self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) + # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): """ @@ -551,6 +568,12 @@ def __init__(self, sqlContext): def register(self, name, f, returnType=StringType()): return self.sqlContext.registerFunction(name, f, returnType) + def registerJavaFunction(self, name, javaClassName, returnType=None): + self.sqlContext.registerJavaFunction(name, javaClassName, returnType) + + def registerJavaUDAF(self, name, javaClassName): + self.sqlContext.registerJavaUDAF(name, javaClassName) + register.__doc__ = SQLContext.registerFunction.__doc__ diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 16ba8bd73f400..c0e3b8d132396 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -481,6 +481,16 @@ def test_udf_registration_returns_udf(self): df.select(add_three("id").alias("plus_three")).collect() ) + def test_non_existed_udf(self): + spark = self.spark + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", + lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf")) + + def test_non_existed_udaf(self): + spark = self.spark + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", + lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) + def test_multiLine_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index ad01b889429c7..8bdc0221888d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql -import java.io.IOException import java.lang.reflect.{ParameterizedType, Type} import scala.reflect.runtime.universe.TypeTag @@ -456,9 +455,9 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends .map(_.asInstanceOf[ParameterizedType]) .filter(e => e.getRawType.isInstanceOf[Class[_]] && e.getRawType.asInstanceOf[Class[_]].getCanonicalName.startsWith("org.apache.spark.sql.api.java.UDF")) if (udfInterfaces.length == 0) { - throw new IOException(s"UDF class ${className} doesn't implement any UDF interface") + throw new AnalysisException(s"UDF class ${className} doesn't implement any UDF interface") } else if (udfInterfaces.length > 1) { - throw new IOException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}") + throw new AnalysisException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}") } else { try { val udf = clazz.newInstance() @@ -491,19 +490,41 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) - case n => logError(s"UDF class with ${n} type arguments is not supported ") + case n => + throw new AnalysisException(s"UDF class with ${n} type arguments is not supported.") } } catch { case e @ (_: InstantiationException | _: IllegalArgumentException) => - logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") + throw new AnalysisException(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") } } } catch { - case e: ClassNotFoundException => logError(s"Can not load class ${className}, please make sure it is on the classpath") + case e: ClassNotFoundException => throw new AnalysisException(s"Can not load class ${className}, please make sure it is on the classpath") } } + /** + * Register a Java UDAF class using reflection, for use from pyspark + * + * @param name UDAF name + * @param className fully qualified class name of UDAF + */ + private[sql] def registerJavaUDAF(name: String, className: String): Unit = { + try { + val clazz = Utils.classForName(className) + if (!classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) { + throw new AnalysisException(s"class $className doesn't implement interface UserDefinedAggregateFunction") + } + val udaf = clazz.newInstance().asInstanceOf[UserDefinedAggregateFunction] + register(name, udaf) + } catch { + case e: ClassNotFoundException => throw new AnalysisException(s"Can not load class ${className}, please make sure it is on the classpath") + case e @ (_: InstantiationException | _: IllegalArgumentException) => + throw new AnalysisException(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") + } + } + /** * Register a user-defined function with 1 arguments. * @since 1.3.0 diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java new file mode 100644 index 0000000000000..ddbaa45a483cb --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + + +public class JavaUDAFSuite { + + private transient SparkSession spark; + + @Before + public void setUp() { + spark = SparkSession.builder() + .master("local[*]") + .appName("testing") + .getOrCreate(); + } + + @After + public void tearDown() { + spark.stop(); + spark = null; + } + + @SuppressWarnings("unchecked") + @Test + public void udf1Test() { + spark.range(1, 10).toDF("value").registerTempTable("df"); + spark.udf().registerJavaUDAF("myDoubleAvg", MyDoubleAvg.class.getName()); + Row result = spark.sql("SELECT myDoubleAvg(value) as my_avg from df").head(); + Assert.assertEquals(105.0, result.getDouble(0), 1.0e-6); + } + +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java similarity index 99% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java rename to sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java index ae0c097c362ab..447a71d284fbb 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.hive.aggregate; +package test.org.apache.spark.sql; import java.util.ArrayList; import java.util.List; diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java similarity index 98% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java rename to sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java index d17fb3e5194f3..93d20330c717f 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java @@ -15,18 +15,18 @@ * limitations under the License. */ -package org.apache.spark.sql.hive.aggregate; +package test.org.apache.spark.sql; import java.util.ArrayList; import java.util.List; +import org.apache.spark.sql.Row; import org.apache.spark.sql.expressions.MutableAggregationBuffer; import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; /** * An example {@link UserDefinedAggregateFunction} to calculate the sum of a diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 09dcc4055e000..f9462e79a69f3 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -57,6 +57,13 @@ spark-sql_${scala.binary.version} ${project.version} + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-tags_${scala.binary.version} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java index aefc9cc77da88..636ce10da3734 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -31,7 +31,7 @@ import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.hive.test.TestHive$; -import org.apache.spark.sql.hive.aggregate.MyDoubleSum; +import test.org.apache.spark.sql.MyDoubleSum; public class JavaDataFrameSuite { private transient SQLContext hc; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 84f915977bd88..f245a79f805a2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -20,16 +20,19 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConverters._ import scala.util.Random +import test.org.apache.spark.sql.MyDoubleAvg +import test.org.apache.spark.sql.MyDoubleSum + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ + class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction { def inputSchema: StructType = schema From c8e7f445b98fce0b419b26f43dd3a75bf7c7375b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 5 Jul 2017 11:06:15 -0700 Subject: [PATCH 098/779] [SPARK-21307][SQL] Remove SQLConf parameters from the parser-related classes. ### What changes were proposed in this pull request? This PR is to remove SQLConf parameters from the parser-related classes. ### How was this patch tested? The existing test cases. Author: gatorsmile Closes #18531 from gatorsmile/rmSQLConfParser. --- .../sql/catalyst/catalog/SessionCatalog.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 6 +- .../sql/catalyst/parser/ParseDriver.scala | 8 +- .../parser/ExpressionParserSuite.scala | 167 +++++++++--------- .../spark/sql/execution/SparkSqlParser.scala | 11 +- .../org/apache/spark/sql/functions.scala | 3 +- .../internal/BaseSessionStateBuilder.scala | 2 +- .../sql/internal/VariableSubstitution.scala | 4 +- .../sql/execution/SparkSqlParserSuite.scala | 10 +- .../execution/command/DDLCommandSuite.scala | 4 +- .../internal/VariableSubstitutionSuite.scala | 31 ++-- 11 files changed, 121 insertions(+), 127 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c40d5f6031a21..336d3d65d0dd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -74,7 +74,7 @@ class SessionCatalog( functionRegistry, conf, new Configuration(), - new CatalystSqlParser(conf), + CatalystSqlParser, DummyFunctionResourceLoader) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 8eac3ef2d3568..b6a4686bb9ec9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -45,11 +45,9 @@ import org.apache.spark.util.random.RandomSampler * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or * TableIdentifier. */ -class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging { +class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { import ParserUtils._ - def this() = this(new SQLConf()) - protected def typedVisit[T](ctx: ParseTree): T = { ctx.accept(this).asInstanceOf[T] } @@ -1457,7 +1455,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * Special characters can be escaped by using Hive/C-style escaping. */ private def createString(ctx: StringLiteralContext): String = { - if (conf.escapedStringLiterals) { + if (SQLConf.get.escapedStringLiterals) { ctx.STRING().asScala.map(stringWithoutUnescape).mkString } else { ctx.STRING().asScala.map(string).mkString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 09598ffe770c6..7e1fcfefc64a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.Origin -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} /** @@ -122,13 +121,8 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { /** * Concrete SQL parser for Catalyst-only SQL statements. */ -class CatalystSqlParser(conf: SQLConf) extends AbstractSqlParser { - val astBuilder = new AstBuilder(conf) -} - -/** For test-only. */ object CatalystSqlParser extends AbstractSqlParser { - val astBuilder = new AstBuilder(new SQLConf()) + val astBuilder = new AstBuilder } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 45f9f72dccc45..ac7325257a15a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -167,12 +167,12 @@ class ExpressionParserSuite extends PlanTest { } test("like expressions with ESCAPED_STRING_LITERALS = true") { - val conf = new SQLConf() - conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, "true") - val parser = new CatalystSqlParser(conf) - assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser) - assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser) - assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser) + val parser = CatalystSqlParser + withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> "true") { + assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser) + assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser) + assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser) + } } test("is null expressions") { @@ -435,86 +435,85 @@ class ExpressionParserSuite extends PlanTest { } test("strings") { + val parser = CatalystSqlParser Seq(true, false).foreach { escape => - val conf = new SQLConf() - conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, escape.toString) - val parser = new CatalystSqlParser(conf) - - // tests that have same result whatever the conf is - // Single Strings. - assertEqual("\"hello\"", "hello", parser) - assertEqual("'hello'", "hello", parser) - - // Multi-Strings. - assertEqual("\"hello\" 'world'", "helloworld", parser) - assertEqual("'hello' \" \" 'world'", "hello world", parser) - - // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a - // regular '%'; to get the correct result you need to add another escaped '\'. - // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? - assertEqual("'pattern%'", "pattern%", parser) - assertEqual("'no-pattern\\%'", "no-pattern\\%", parser) - - // tests that have different result regarding the conf - if (escape) { - // When SQLConf.ESCAPED_STRING_LITERALS is enabled, string literal parsing fallbacks to - // Spark 1.6 behavior. - - // 'LIKE' string literals. - assertEqual("'pattern\\\\%'", "pattern\\\\%", parser) - assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser) - - // Escaped characters. - // Unescape string literal "'\\0'" for ASCII NUL (X'00') doesn't work - // when ESCAPED_STRING_LITERALS is enabled. - // It is parsed literally. - assertEqual("'\\0'", "\\0", parser) - - // Note: Single quote follows 1.6 parsing behavior when ESCAPED_STRING_LITERALS is enabled. - val e = intercept[ParseException](parser.parseExpression("'\''")) - assert(e.message.contains("extraneous input '''")) - - // The unescape special characters (e.g., "\\t") for 2.0+ don't work - // when ESCAPED_STRING_LITERALS is enabled. They are parsed literally. - assertEqual("'\\\"'", "\\\"", parser) // Double quote - assertEqual("'\\b'", "\\b", parser) // Backspace - assertEqual("'\\n'", "\\n", parser) // Newline - assertEqual("'\\r'", "\\r", parser) // Carriage return - assertEqual("'\\t'", "\\t", parser) // Tab character - - // The unescape Octals for 2.0+ don't work when ESCAPED_STRING_LITERALS is enabled. - // They are parsed literally. - assertEqual("'\\110\\145\\154\\154\\157\\041'", "\\110\\145\\154\\154\\157\\041", parser) - // The unescape Unicode for 2.0+ doesn't work when ESCAPED_STRING_LITERALS is enabled. - // They are parsed literally. - assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", - "\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029", parser) - } else { - // Default behavior - - // 'LIKE' string literals. - assertEqual("'pattern\\\\%'", "pattern\\%", parser) - assertEqual("'pattern\\\\\\%'", "pattern\\\\%", parser) - - // Escaped characters. - // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html - assertEqual("'\\0'", "\u0000", parser) // ASCII NUL (X'00') - assertEqual("'\\''", "\'", parser) // Single quote - assertEqual("'\\\"'", "\"", parser) // Double quote - assertEqual("'\\b'", "\b", parser) // Backspace - assertEqual("'\\n'", "\n", parser) // Newline - assertEqual("'\\r'", "\r", parser) // Carriage return - assertEqual("'\\t'", "\t", parser) // Tab character - assertEqual("'\\Z'", "\u001A", parser) // ASCII 26 - CTRL + Z (EOF on windows) - - // Octals - assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!", parser) - - // Unicode - assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)", - parser) + withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> escape.toString) { + // tests that have same result whatever the conf is + // Single Strings. + assertEqual("\"hello\"", "hello", parser) + assertEqual("'hello'", "hello", parser) + + // Multi-Strings. + assertEqual("\"hello\" 'world'", "helloworld", parser) + assertEqual("'hello' \" \" 'world'", "hello world", parser) + + // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a + // regular '%'; to get the correct result you need to add another escaped '\'. + // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? + assertEqual("'pattern%'", "pattern%", parser) + assertEqual("'no-pattern\\%'", "no-pattern\\%", parser) + + // tests that have different result regarding the conf + if (escape) { + // When SQLConf.ESCAPED_STRING_LITERALS is enabled, string literal parsing fallbacks to + // Spark 1.6 behavior. + + // 'LIKE' string literals. + assertEqual("'pattern\\\\%'", "pattern\\\\%", parser) + assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser) + + // Escaped characters. + // Unescape string literal "'\\0'" for ASCII NUL (X'00') doesn't work + // when ESCAPED_STRING_LITERALS is enabled. + // It is parsed literally. + assertEqual("'\\0'", "\\0", parser) + + // Note: Single quote follows 1.6 parsing behavior when ESCAPED_STRING_LITERALS is + // enabled. + val e = intercept[ParseException](parser.parseExpression("'\''")) + assert(e.message.contains("extraneous input '''")) + + // The unescape special characters (e.g., "\\t") for 2.0+ don't work + // when ESCAPED_STRING_LITERALS is enabled. They are parsed literally. + assertEqual("'\\\"'", "\\\"", parser) // Double quote + assertEqual("'\\b'", "\\b", parser) // Backspace + assertEqual("'\\n'", "\\n", parser) // Newline + assertEqual("'\\r'", "\\r", parser) // Carriage return + assertEqual("'\\t'", "\\t", parser) // Tab character + + // The unescape Octals for 2.0+ don't work when ESCAPED_STRING_LITERALS is enabled. + // They are parsed literally. + assertEqual("'\\110\\145\\154\\154\\157\\041'", "\\110\\145\\154\\154\\157\\041", parser) + // The unescape Unicode for 2.0+ doesn't work when ESCAPED_STRING_LITERALS is enabled. + // They are parsed literally. + assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", + "\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029", parser) + } else { + // Default behavior + + // 'LIKE' string literals. + assertEqual("'pattern\\\\%'", "pattern\\%", parser) + assertEqual("'pattern\\\\\\%'", "pattern\\\\%", parser) + + // Escaped characters. + // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html + assertEqual("'\\0'", "\u0000", parser) // ASCII NUL (X'00') + assertEqual("'\\''", "\'", parser) // Single quote + assertEqual("'\\\"'", "\"", parser) // Double quote + assertEqual("'\\b'", "\b", parser) // Backspace + assertEqual("'\\n'", "\n", parser) // Newline + assertEqual("'\\r'", "\r", parser) // Carriage return + assertEqual("'\\t'", "\t", parser) // Tab character + assertEqual("'\\Z'", "\u001A", parser) // ASCII 26 - CTRL + Z (EOF on windows) + + // Octals + assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!", parser) + + // Unicode + assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)", + parser) + } } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 2f8e416e7df1b..618d027d8dc07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -39,10 +39,11 @@ import org.apache.spark.sql.types.StructType /** * Concrete parser for Spark SQL statements. */ -class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser { - val astBuilder = new SparkSqlAstBuilder(conf) +class SparkSqlParser extends AbstractSqlParser { - private val substitutor = new VariableSubstitution(conf) + val astBuilder = new SparkSqlAstBuilder + + private val substitutor = new VariableSubstitution protected override def parse[T](command: String)(toResult: SqlBaseParser => T): T = { super.parse(substitutor.substitute(command))(toResult) @@ -52,9 +53,11 @@ class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser { /** * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier. */ -class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { +class SparkSqlAstBuilder extends AstBuilder { import org.apache.spark.sql.catalyst.parser.ParserUtils._ + private def conf: SQLConf = SQLConf.get + /** * Create a [[SetCommand]] logical plan. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 839cbf42024e3..3c67960d13e09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -1276,7 +1275,7 @@ object functions { */ def expr(expr: String): Column = { val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse { - new SparkSqlParser(new SQLConf) + new SparkSqlParser } Column(parser.parseExpression(expr)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 2532b2ddb72df..9d0148117fadf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -114,7 +114,7 @@ abstract class BaseSessionStateBuilder( * Note: this depends on the `conf` field. */ protected lazy val sqlParser: ParserInterface = { - extensions.buildParser(session, new SparkSqlParser(conf)) + extensions.buildParser(session, new SparkSqlParser) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala index 4e7c813be9922..2b9c574aaaf0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala @@ -25,7 +25,9 @@ import org.apache.spark.internal.config._ * * Variable substitution is controlled by `SQLConf.variableSubstituteEnabled`. */ -class VariableSubstitution(conf: SQLConf) { +class VariableSubstitution { + + private def conf = SQLConf.get private val provider = new ConfigProvider { override def get(key: String): Option[String] = Option(conf.getConfString(key, "")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index d238c76fbeeff..2e29fa43f73d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -37,8 +37,7 @@ import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType */ class SparkSqlParserSuite extends AnalysisTest { - val newConf = new SQLConf - private lazy val parser = new SparkSqlParser(newConf) + private lazy val parser = new SparkSqlParser /** * Normalizes plans: @@ -285,6 +284,7 @@ class SparkSqlParserSuite extends AnalysisTest { } test("query organization") { + val conf = SQLConf.get // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows val baseSql = "select * from t" val basePlan = @@ -293,20 +293,20 @@ class SparkSqlParserSuite extends AnalysisTest { assertEqual(s"$baseSql distribute by a, b", RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil, basePlan, - numPartitions = newConf.numShufflePartitions)) + numPartitions = conf.numShufflePartitions)) assertEqual(s"$baseSql distribute by a sort by b", Sort(SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, global = false, RepartitionByExpression(UnresolvedAttribute("a") :: Nil, basePlan, - numPartitions = newConf.numShufflePartitions))) + numPartitions = conf.numShufflePartitions))) assertEqual(s"$baseSql cluster by a, b", Sort(SortOrder(UnresolvedAttribute("a"), Ascending) :: SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, global = false, RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil, basePlan, - numPartitions = newConf.numShufflePartitions))) + numPartitions = conf.numShufflePartitions))) } test("pipeline concatenation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 5643c58d9f847..750574830381f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -29,13 +29,13 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.datasources.CreateTable -import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} +import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} // TODO: merge this with DDLSuite (SPARK-14441) class DDLCommandSuite extends PlanTest { - private lazy val parser = new SparkSqlParser(new SQLConf) + private lazy val parser = new SparkSqlParser private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = { val e = intercept[ParseException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala index d5a946aeaac31..c5e5b70e21335 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala @@ -18,12 +18,11 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.plans.PlanTest -class VariableSubstitutionSuite extends SparkFunSuite { +class VariableSubstitutionSuite extends SparkFunSuite with PlanTest { - private lazy val conf = new SQLConf - private lazy val sub = new VariableSubstitution(conf) + private lazy val sub = new VariableSubstitution test("system property") { System.setProperty("varSubSuite.var", "abcd") @@ -35,26 +34,26 @@ class VariableSubstitutionSuite extends SparkFunSuite { } test("Spark configuration variable") { - conf.setConfString("some-random-string-abcd", "1234abcd") - assert(sub.substitute("${hiveconf:some-random-string-abcd}") == "1234abcd") - assert(sub.substitute("${sparkconf:some-random-string-abcd}") == "1234abcd") - assert(sub.substitute("${spark:some-random-string-abcd}") == "1234abcd") - assert(sub.substitute("${some-random-string-abcd}") == "1234abcd") + withSQLConf("some-random-string-abcd" -> "1234abcd") { + assert(sub.substitute("${hiveconf:some-random-string-abcd}") == "1234abcd") + assert(sub.substitute("${sparkconf:some-random-string-abcd}") == "1234abcd") + assert(sub.substitute("${spark:some-random-string-abcd}") == "1234abcd") + assert(sub.substitute("${some-random-string-abcd}") == "1234abcd") + } } test("multiple substitutes") { val q = "select ${bar} ${foo} ${doo} this is great" - conf.setConfString("bar", "1") - conf.setConfString("foo", "2") - conf.setConfString("doo", "3") - assert(sub.substitute(q) == "select 1 2 3 this is great") + withSQLConf("bar" -> "1", "foo" -> "2", "doo" -> "3") { + assert(sub.substitute(q) == "select 1 2 3 this is great") + } } test("test nested substitutes") { val q = "select ${bar} ${foo} this is great" - conf.setConfString("bar", "1") - conf.setConfString("foo", "${bar}") - assert(sub.substitute(q) == "select 1 1 this is great") + withSQLConf("bar" -> "1", "foo" -> "${bar}") { + assert(sub.substitute(q) == "select 1 1 this is great") + } } } From c8d0aba198c0f593c2b6b656c23b3d0fb7ea98a2 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 5 Jul 2017 16:33:23 -0700 Subject: [PATCH 099/779] [SPARK-21278][PYSPARK] Upgrade to Py4J 0.10.6 ## What changes were proposed in this pull request? This PR aims to bump Py4J in order to fix the following float/double bug. Py4J 0.10.5 fixes this (https://github.com/bartdag/py4j/issues/272) and the latest Py4J is 0.10.6. **BEFORE** ``` >>> df = spark.range(1) >>> df.select(df['id'] + 17.133574204226083).show() +--------------------+ |(id + 17.1335742042)| +--------------------+ | 17.1335742042| +--------------------+ ``` **AFTER** ``` >>> df = spark.range(1) >>> df.select(df['id'] + 17.133574204226083).show() +-------------------------+ |(id + 17.133574204226083)| +-------------------------+ | 17.133574204226083| +-------------------------+ ``` ## How was this patch tested? Manual. Author: Dongjoon Hyun Closes #18546 from dongjoon-hyun/SPARK-21278. --- LICENSE | 2 +- bin/pyspark | 2 +- bin/pyspark2.cmd | 2 +- core/pom.xml | 2 +- .../apache/spark/api/python/PythonUtils.scala | 2 +- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- python/README.md | 2 +- python/docs/Makefile | 2 +- python/lib/py4j-0.10.4-src.zip | Bin 74096 -> 0 bytes python/lib/py4j-0.10.6-src.zip | Bin 0 -> 80352 bytes python/setup.py | 2 +- .../org/apache/spark/deploy/yarn/Client.scala | 2 +- .../spark/deploy/yarn/YarnClusterSuite.scala | 2 +- sbin/spark-config.sh | 2 +- 15 files changed, 13 insertions(+), 13 deletions(-) delete mode 100644 python/lib/py4j-0.10.4-src.zip create mode 100644 python/lib/py4j-0.10.6-src.zip diff --git a/LICENSE b/LICENSE index 66a2e8f132953..39fe0dc462385 100644 --- a/LICENSE +++ b/LICENSE @@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.4 - http://py4j.sourceforge.net/) + (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.6 - http://py4j.sourceforge.net/) (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) (BSD licence) sbt and sbt-launch-lib.bash (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) diff --git a/bin/pyspark b/bin/pyspark index 98387c2ec5b8a..d3b512eeb1209 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -57,7 +57,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.4-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index f211c0873ad2f..46d4d5c883cfb 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.4-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.6-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py diff --git a/core/pom.xml b/core/pom.xml index 326dde4f274bb..91ee941471495 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -335,7 +335,7 @@ net.sf.py4j py4j - 0.10.4 + 0.10.6 org.apache.spark diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index c4e55b5e89027..92e228a9dd10c 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -32,7 +32,7 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.4-src.zip").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.6-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 9287bd47cf113..c1325318d52fa 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -156,7 +156,7 @@ parquet-jackson-1.8.2.jar pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar -py4j-0.10.4.jar +py4j-0.10.6.jar pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 9127413ab6c23..ac5abd21807b6 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -157,7 +157,7 @@ parquet-jackson-1.8.2.jar pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar -py4j-0.10.4.jar +py4j-0.10.6.jar pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar diff --git a/python/README.md b/python/README.md index 0a5c8010b8486..84ec88141cb00 100644 --- a/python/README.md +++ b/python/README.md @@ -29,4 +29,4 @@ The Python packaging for Spark is not intended to replace all of the other use c ## Python Requirements -At its core PySpark depends on Py4J (currently version 0.10.4), but additional sub-packages have their own requirements (including numpy and pandas). \ No newline at end of file +At its core PySpark depends on Py4J (currently version 0.10.6), but additional sub-packages have their own requirements (including numpy and pandas). diff --git a/python/docs/Makefile b/python/docs/Makefile index 5e4cfb8ab6fe3..09898f29950ed 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -7,7 +7,7 @@ SPHINXBUILD ?= sphinx-build PAPER ?= BUILDDIR ?= _build -export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.4-src.zip) +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.6-src.zip) # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) diff --git a/python/lib/py4j-0.10.4-src.zip b/python/lib/py4j-0.10.4-src.zip deleted file mode 100644 index 8c3829e328726df4dfad7b848d6daaed03495760..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 74096 zcmY(qV~{R9&@DQ)?b&1Qv2EM7ZQI&o;~CqwZQHi(`+n!1b5Gs;NK#2xex$0qv)1Y; zNP~i*0sT))&s3EBUz7jcf&Vu;c(Pd0EBtR0C?LvzO%f{;IB_}u?IGX*0U`Y#6C*=o zYX^HL7di*eVGZ5NB?ctl+ghj(2W^e$w>;vSJ@$f#yuDzFf7e&r7MV+=%bWhUq_X^r zH#f`P&42VV$1|9dGMLYINZV=CIs65z8arM2I~)(ioa!ao@ltTF#5lo^S4-PblK zy-)^8!w`v0o|QoxZ99dFDMc75zp?k3fo2|d1Z5ZRo8&x)3LBju_h)PcAkPz;!weMI z{)}eE=T9UMfkS~x(jv?~6%`-Wl>*`w0Zcx<@i_P*VX4?uGh!yOxFp444%{297-k?m z73R}T0KhRE(Gd<_T1K7%!Q$a?xK0?PJ!+t3Os~N@G}+=JO!)L>@SVh~N)E#5XjaVA zRRDFHHyv^zAeZ2SfNR`pgV!L@{#q3~20qP;IFg(Lug1yL$V^9<3?>ECLu##>*(a?(5%JZX?j+I~RO3kc(>$>z<^^d3qMZcn#>Yyk>3K3x@ z41k6#Wfu>{KQP9Ch!`waWhjP4OnHeDZz?j5zfmZ*-&VpVG=IBJk}bN_0U->h(1&=Y zp~`^s;&XijeW46KegJUco0?Fqo%t~6*a47k(LUcZ?bpCjm|U>zb{nnFC_C=*zfTp~ z{PT<6x^kEuR+R|~$-&ES&H*Z5d-0<0iUrH`F1V~G9oquImm??3f(8A$X4u&kMt+q_ zM3ZcJW_!RqWM%)WpjL8jM{N$}ufkEtuO{oyk}uXA%2x4V&5?AxV8#2I!e8miKmH@m z?=@RFwUHG-LGB1rUTKaoJ^I&ya=SIlLYGdM^Y4Ah{#Dxc8KFh3uJhJMq&_YnLD@VdMa{Ec{;W{F%n!HV&j7gCSgph=&y;u>)}hhC|R z7tx;^DWaz6-v*>G65LUz(D0J_hWhps8Wk@_6) z9jV8Ps*QfoVZ)cr>*8}*4hem=l%?0QDO~1rkQ+kLzWVKlhV1*h0 z-SccbgI_x%J&EHeXW-iq0Xa{m%l3p)( zg(;{&Kw2i{#Ud&wy`v|b8$9xax-=MFPvZNXggoBAovLZJkK-%pv>f=jQgf3?&T7=+S9>@qnQ zRmjEhp3~G66(z%1Q*ip6#&qJ5z$ur6fMC1>*D0cEnph38HZ}}sekyF0RZ0eSp+3lr zG=`V;jB&z;OBL;T8F0{4cO@D?D$c@8Q8$N9gkA(YLMb-g6MbZo4^$dsg#gubMz7Sj z?9mGRIwrZpb>e@WkswsdjMJbsttq8VbF0)T7+oz&Y-4JmJ2oyP0{Bod)5?A!W zN6+KN@>p+TiPB1m6j&4lm=f0{@_bgsMv3m_$n*1lpBP~P0J*y|_V%v*tE=qHRu$5I z>Arrx=zhM=-`|K|TCIHCS$@2*=K4BvwY`Cvn5E{6v$gayrR+{+Ly}nR-QlJw_H0Is z3k50*n`oXXqn;O{WWQ`h)jR8nKVGl>X^RD8(E)yX{X-MckeW(0{ceMy2Dm!;QV-8K zRcu>*utsrGu)q+hp>c{@(enlj0>;Rp?gdq5a4B?UQ-mn=jOS6}Y(p51I4NZ1ax{bP zOU#)HX$G6S=<{Zvx(eKG_bsMq^YktLX99?OzMG)vatO}` z=LmeO)`zKY3PMpisuG*WtjWn)^FkJRWF+>akWin*E`W=!ZEus_05F=5NBubT*qXjp z!i3)XwB5RTy@5E?(hQ_3OV`)K*oe0XV=Av|hn;VFLx}sndTN?YW<)!NqgNlefECuK z&e)1IyPH9zZXZr!HFV@&JSs@sV2ngllwiY!Mi)JJcwYT(_@`htK0DaRm>^P-UIe5# zDwN7f=lYn7N8DxK5h}jo8!)7j zqA?>t6;gG}u^s5J(K_Az8DZc3pPwO)K>_hFA`hmhwY)0i!&|OTd6yj7DSA=l#Ho1} zOJOvsywZXd@ViS1YO3eM48DK32G8KPwYHJreW*YyT1BPF6veW09BO0;p0&(5Ln(1KwEj|rg_3tHSjY=5Pw#6?6U0-UxuaSKE+$Dzny%=wb1>#>hla@`@$lPdb z;l6+++M)?hkm_5S7igBrBVjS$>|T^^A)2KJ+yBIRh9^-)pqjN|B4I#c#gx(V(iN23GQ&d{yID$&I*SXv+uZHy%@ zZwj#hjuzu!opRqD2?frtnQuc)D2JL|n+(wobld!^V@h11N4tzz2qNuO7#ImZx^M;r zJE;$IC8f=}SJVCUhN)m5?5U6P#05(jy+jORX*PZ)rap;)NaIx;DqS48a`5;su(33qP z(C-Ndau(t96nS4zW16r5$8(~+B^KwUGgM1`HL6E^(ikk)X%C9B4b1c@)2)Pv%j~3{ zQxe=XMuJwFhdTVXpd&-&Vf2uvrf}4bb0aJJRgEO?je(Zzs$rGiq>g^9>TZ|*%2coV zMS@6NLHXCMMtvSgM~)l%;?ov13z1@oJXE(bk3d~ zjVD5q%4SS4-yY#!$P1)m#lizf!L2AwMEsg)t+?rIZmof9}g+jCpRqSTQI27ks>hF21Y$50fb&D0p*5q$>)lMrlVy5nXva^{R= zA2R7>%cNS;#AtA=y^D+YD^SCH_>c_2n<*0Th9{j33SHzkda@@GaAS|iD`0O0BV-%3 z>nu|T+EgH;nn8cv9I4NAwUNBZOIl95l^5!PAB!u$+yt$?nj<_ZY#-cQt0}hwCcU!i za&YB9GUxJ>_^oUJu@Tu(e4AZ#b^qR)!2C1%?)u2R$TQUfw?LH|XB4|{OoH}{!wGDC z%j|FuNwFD`;uf;(gdlxz=2>8{#e>Dw)PmZD$L&KSxDoLaE` z)rzSEffj;y3Y&y5y_yg5D;(J<=h>S^&fvDdj>M_EiT!cCg%Ely+^VdJ$|eTvbX#Bh zx;On6DE%SC>4C=*pv7v_z}1JUvz+W+286fuVR6$%pbiG%zx%^#FG#(xVYbF#FbDSF znxIQhZXLx_?_F5VQBb3x22$8+!8I>Q-?Z!N64B=i34EpYAT+>lzbK%d14{6~Up#N~ zMkOKbCfzjlO1uPaXyJf;75(vezt#SnzwCXS8a5n&g8&pUxzB&j-3|V-vt?KWXRdZz z_vz|j!wZOg*n(v}i#|}QJ70??w-7$96SeZ4{~N{dPUZ4yvq9XrfjW)dK<*b-plz_kN75`RWl39lW6qTS~CUa7W0 z55eDBgJ3Ryb@m80E0mb^sgH`-TasMw);Itf*FN4$-a>Dv`+er}dAfx>9dl?$skzHL z!G#Sl%1C7?CMVaPV9um{_!!2xrbp<1%_~W|dHC#8t#d5PW*s}lZwA3;sC%@ve{Wa2 zq<-IF3NwpE$<5&N@>L@=NaLo3q8tMnleoKHk+#Z?hxBF;Jc^_ zaZ*~q!NP?%`jWFK;%mDap(_iQlH^*JtUKXvxL#TE+=FBD`AP{ZTujVq(YA9mKTN`P zs?U`$dFRV_#BPn3tBhJrA*tETuhg&hjw=rj&85vfHxwf$t~2A&5umj^;7W_v{0c>m zm{OCOWDa9E_CecHo1TuuP)3q;r!(6W=Dk`f+~|8VcU-UH!_iTA&pNU(Z8}~}EPKS< z*Q?be=XWaR!i!Hv)Fvpb=yZA44k_6WFV8d1`0VT#s7)KK_r9gFb534`VG)|cj3c6WB*=cB)OJLr66UAbhYgf zLEEs>`vc~@*MrNn>n~uX7j((JMFvc?Wv#NoPHQRl!W2U zFE(O3P8S_B9ca#wI?F2K* zbe^lV6!uWP4AS6ago29*mlp>%+~JX{i73lIEN!+?*|*X;P7a|(QD@ThVM}I606O^X1xFLIGx>3 zrz2D8F*14rJF+mqvU)tHA{uY0?%GYizW}7WV3Tje^B29#m6R7op{Kohny#Lc^3W5MAWr=z<0Lq`!puV%Z)ZYxW;Noa ze68c`cnd(E>q|G2bNL$Q?V8Enc}UV$j4h zqq3p4NxTto^`}@!o|U7gqZHgmkeRamwHRdW(;JgngzRA0b+r80oH2I;t!4KjI8*~T zF|u)g9+s<#r;+k;cfW%r!x^=;r30p8QKVYwyebFbeR?%Fe#ag@-j~P6s;>%??Pq5C z#)LQb5YbDGw`<~1csP5<379~S`H(|&b{}#)gm)DR-o{2VRq2uvO;b6O!M(?(7zeAE z8VD(cwRiu7A9w<*|82aNhRqZj?Kk@2zQ50fisE{E5s|<;Uw?|WFNJ_unB(dbhRAn! z;0kwvSjO!eH<)TqWx(U4fvU!rRqv!Pg>fq6 zR5fSr;FLq1wlLJ_Ir4`@w%j&$Y3oXp4-t;Ze02B?GXPpvTODb8#>RtWzWBVqyQ{rB zQ=q?dZ-Eua*h`u3uMKQ=NXIk&MA7H^@ zthG>VS0%h}YqI@}-BC?)Jha(&cBjkMGaZE6ie16Y+M~=sK*)$Rpi*mxHG)7axd9aw zHk)W5<5@$xxs2*b3mwD!WBiWt_$0CpkHypeA zyhjDze{7R$bL?jgLbIr2tJ2B0A1wEyuy1kILaq8%Y3qP!q%P-iDKV=>UfG*N&o(W# z#eXSPp`UyL#k_qf zW^`@;QI0sP(HgU%kC*&XjOCezb8jowfgs)6rl3~Lz*%>*_<^4>SXd(u-{Sr0?P>S* z4tD;Y<5bz*ya6ppbcc-q^YL5Sil4c&LNgHgr4b=Rx|&9xypEpE5b zm?CtdzM4yj*m_w3G-ZzQX&ku+>7t%0_x9Giqo^6aG=m+`LI*@67QA z`~Ro$?muHazCnP1jG+HhW&S@J&)(L-@PG7}Zq@17MFymvQ!24-=>}@#3*r!JFAbu- zfDnN{jslfqk`&eM%G}o6n_3WQih2@yD3W8+5ubxAy$P(p#V_nW?Jgt zNOLg3!6Tux^o$q9)s0gY?a-;4a@qH7V8OcpsYNjqIwxovTqHyysq8_dw|1SaJ}dBj zBZS&QQ=lhD^ofsU+w6v`nG*ZTF9Ek#c=$eWD9>3CS{4MOmm{@w{sF*>pC_pO{se`_ zCPy$gevK{ev^bHhTp!m<4JDS-vS&}`5EM`P9YZmp(WPZZR3!Yu6~c@~=WCD9!kD@p zHn&oVBJOZ?r#A<2^Qm0S2-k!nu8CB=mtf3IKH@g@KsO_Fk=KfJF9DMi+>7(#Ka!;+ z!T2OF^Yk%>C1+ercu;AI-%93tLM&(_tcB}?+hr~*{Y67DC&K|39J=MqJ!2+Ux2WtF z9;3H&o*8TOujT=mzk7c*-dSJTl73N_v>j{*|bt~-C5@f!)lqjYzLnz0w zD(k{aCKkmj3U^7$Ju7#k6U<%qbzb7lWz4f@3;Y)fN%)i#N}0as154zWN`g-V2K1FH z?Ie`P0k}6T{SuR}sDF16o38X`v`aGb%`&yaW37Wnh$Blv%^S;9y+?IJn7VU;{G zP3Os@N6SCz+WoCX6o1_>5(g_-%loSgw1RuvzFjdtim7BUZ{E~i^&A}QbBA3O$GUIm z9*{#x;bbh23 zES1x}kki`J=PQtlb>Wlv12xB?5R&Im=Jp_>N^cBaeN+f@)IN@vWQ zZ<504AzqZS{OCRzDIgaxXdtCAN%SLCy-g>TXJv{eqlBi<2!n)eb3eVwAsyE+w=QP( z7-*_iG;oHVUBe+e94S??TowSR=UlU#aNd=uZU|9urUbO_p`}AN0<|znxnMjRBzVYE zhk|q}d5^Nwt2iRlZ%%&ZE()>!ZTx2KXC;?RR5<6p{Z3VFJ2og$%dM|_qy_$<92~tn z4_K$zi|uGupdlx=&mjV-+IRrRHNiv(lcQduggt14WMDX?Dp5&e$IY;5t3Gcs5USE} zWRyn@7i?D=LW=m*P^BSXD;-jk_{}$S%cD6{ta`>|7#*sZ7->&q zDKi36Fw=QKTy2R!qPa+GGDRk5WP85t(fsZ^oEz59T(Ggxv(Ed-d+X0X*>gmxB7$Vo z9`JN@I|7gPyW=AjITdS_t1L=j<`seD6jFY5o3Dmc%au9Jlg}B;fKhndGHT;oH!KZ!92oZ*{B?XYe@a^xnuJ!> z;k6>Pixz&*@~VB*S^2TREj&pLN|2eUytb${LFk*D-m-EdLwSV83#|qz*?-WLkHbsr zfJ*jeowK{$Bb}sUrzN2co8cuL`;N&&Z{Fn`f&VXFp9fC*i0$)TX8HoIodaXJlH~p| z`^%#|Phq~bmUlZzGvgT+C0~cKbw#7L|nursAQUdyP zK+G;$Z0lLQ9HuONI;c+3s3oFSJK$LU|2qBuh+MbzX_VjEaky{ti&!pd|U z^_Wzz{R9^zp)$<~Pd@1ELZFM_=jlP)%CdR$ zdc~4QGen70Y5GtCoWut(0v%C^n^9zd)ce8C*At$PS1O1X>3H^^0xV(dST9pMZ^W%# z(nu4w8i>OD7ccjK)AchS>#q{>$d8EJkOdXF(*^xN)23B(EXsMQ#Xk~@n4wB_{Cg{V z|Jr4)JIuL6ya>WC((cs!?7zctKPfR^0{2KIl02vYcHN?A5ibCg*&~PKFfj~+$$6+C zH7dUf{(hVa9-Y#;WmM`((@JH`vL*CaHk}gT56APV^IJX%#6xXp%t%Ah^AWs=9j-F5 z>~<=0N$P|}CVj8xF8*j^du}RgLg*WN@V|sE(BJ0TFBVxn*S*eprWC)#H)*xtFi!7p zYb$lj1!NNadC0XhjHfQ~wJw^rrxkr1TfB+KdGDB(!Cd)c8&aZusgfNTp$m~#LPfOd zsu*9ukR_ZuwEp=lxT~hVrvc`qz27y3q>YQg6*GUPhqH)MD=CU#bQdnaPG4aj$F`a@1p%(V%)Z>=P<^2*q6SbRc)!ZPEDC!F4>FdQltnIC z@@D5>GXg#7`x?sMoOEL61MMj=7t2V_q-XRI8kC}n5Iu}1yZ8}j3?a@}If;k3w(eA9 z=mwOLba2I(aTMM0A(lzH5dkHO6DUMK7^PX~WVllNgb(O;76ByLt!_uGv50Vczbg3- zT<6m%k+o+^;qI|q$$3=UkA)fN9<)J?7?2naxDux>JmZH^WjGoIQT9g;D*o!YNJqdJ z?ir0lCyuq7CR`XkFDuR`tnd=m@Q9XCL!%UwI*OEI?%X-O9jbo~=tpekuIV+fLb*j9 zQ`k|4V>b3!rp{y;k_6CH2r5zfL0Y`KF{e50ztKZcV!ZY43pXcE^Lqovo*U4!y?=gU zmu~Sn#Y^>0d7`IlBB$X$fh}5N1suZ~GL3AC{Uk4?ThAo}ziLk~efvi-@YoCcVRE1u z(;SO)$?ld7O>7%TyT%nsune&c7AlD)yy|X$Ju-$CA9&!h_>vtU6b{gUnj3Uj1ZU#< z)71w|H2ZMl(|x$$gsW0UVDfCLo-*NPbcVo)Lvkg8v~ozvZ{*?N@=&@5w2a<1Ei5}l z59b6bb!+y6byiMlQ_k}`hiWda@cU^ZQVQ1Vr8rtzNlBCEwa(5e_UzA-HOQH&E~=?z zj%a%lZ0^+rmcd7Xgk3wa`;t`_mC`;0i`X1A#XhohR(R8PjKjjgJ>^)y*|ybA58+`! z^uz*RM!WqHhDN{rH*~K^%>6EqS|5S$h6(L-S|w()rV*i#6vpj>o2Xljh)w)Mp8TuM z*cvK7`4EW1*S8+TfDm+91?RU|IOk2x$$U4B)TrCdl)UtQpN!tZ+0PIym%*wZPpUV~ zLi8%=S{Pvx8E;GOv$|UkxbK_e^Aj@{Y^||}uHlK!jvUrzhxos#b;l2RxG zU)(}5Q0g9aC8v2=z5Zifna*m65d?|c*$l$Pzg4%=!K|(*tlygoa7yj$4B$$(RI~u< z)r#qhU4HAu|AITMfbAKAzf)(AIvR6y1t0v%Hj(z%9=sXk5N}b1wp9ASk5@W|9)_w6 zm3`BSi0)@h?>Z+zZaQwPg8#ZD?_I}R4S@8mgEYh0V`nq+4&p+Y_e+C@bFK!)qs4L)#T;VuK1^H^O$=VXOa z`AxQfK34HAhP7m_bIL^<*NzKYTV$@ha3P+toYn-lXivT+(9IN(Q0XDeh|cPdcSqCJ2wv9=?`H{Lyr4XFS&}HW+Tyh*I>hpgN~O&8S}nnGcs)) zU3e$SenoagtkZ3q3qw$yQ*?$hIY_UaN=j+rH+mRja1CBDfJ-6GV2~rPNjBh zHRUV9)Q2Jd!xV($lzNeilCQ5@R0e)5L2~=l6;cTv$;ru3^u7hLG96bdEbtz-j zvXPVCq&)%$9?ev>JO*8N^JkEktgW}T>wA*G-Cj9Sn~+WRXK!%uzgTTH5%5F-RZ+*q z24EW#8#`Sr3YM`+1x0c&b0Zcl0M7V7aw@U$Sh>l!g3%B!8P~Qq{A>t%0mWa-Pbxgl zRdYBFp|+?iHBP`qcd~5fMLCS4VHEM2lh91FOz|NGQ{CVRPFGd14yp(nQ2etd-?=;J z0Um`7wQcBd6ZY4Vf{pa9_-1PM(<}jgJb=&vje03bpqej`S+e&=tHuYnby^T_?2^Lq zpp>Kk>VCec*4h&yAA&WFe(Krmcki;brVT#NbZT`z$gby9T6v8L+2#IJ-=BXu^so;K z=4{6zpW~aM4`?UI4kuhIW)0nLRHZK7I_iuuE(9`5o2;Lq2SO;GfU5*1BFyH_@_lN; zrA1kNe35Eva>)@z2fWRk!qDc1xeY5XtLHlQwG279;(YJea$)p?cd zf%jhIj|L48^(U#Z*>!>I<%faqKY{M!&b0^EOTi>sxY@A=A!prwMh$y{>F*D z?8bE+jHU^_8-w6nS=|oBA6u2&12Y0{Ks6)?NpqdIPvx*RY9;1q9B9n)adQ8K;Oe2g zVy$4Qtj@hJebNrgH!i4Jm1qU4oTlXLOv-mFycN!A;eMC?JdIkD|6EMOCd{nbF=R?7 zZfe(SS~qS=PgYuQXG&k3DnuIoS3BNL72eLz*w2v%!&V*(Ig__Rue0UXd)LpS=eIu_ zRb9>RXl(5g%I5Wr`c369EQw1a{7WzhwCyZ=L^GuP*up1(nOKb`fL-p_uT{B@7p5hh?NtgK!@!Xqg=|I!i8s7$?IS*xsl0$#Gaw`p31q$<)&l zGjf_{Hx`S$sJ*92<4|Ivaht=Mma3#cTuPk-n;kcnj4~p9)cNv!M-wmQTc+COSpfgv zP4no|LGc=WyTxX_3aV8SGql;xUcFoickT+I);~-|tt0=UcORFZkE{7pX56-GA3vcn z)%uKyP0nVOo*ey~i;>a6&H>cbCx$MczBL@wlZIdKU#B|d;*au?j+65TdQpBQ57yKh zR9_ujsSf;QcF49?5t3kn#$)*%rAk06IAg3|?AyVbEsz(+JlOLIUMtV?E%-Gt#czwB zR!ra~W%dP9GBY}Sv7xEVT>Z^k#q_}2$C5;sGvbi$i-+IN=1)i*;Ih>k2Y_DNi!*o| zXme)-r)R}Ev6JO$^ypS4s>MH{SIIh_oS=e+QAltamy79Wnv(8alL4u+v4I)yiq#R# z0@d03c7d~fJyL0gkCr06*II7XO0pLK`;S%2GbNWDJ*?&8B%9Km;^x`?=``)&Oe^+1 z;VG;&eA%@N>bVMU9t{1;)7ePMK2TfpgwvDanSTsw164Q)wknp=e~~B7h`UqLz7rb@ zO<#!JWlLR%m$<4s8=$By`P@`rTg%Vm{k9D;RKV!Z6y(iJxg^B&5l)IMViPZ#H?bI@ z2ekb+8mBNaC!Hyw${k|{VH44vZ(0^EZ~VFBlW%lZBZ#QR13$Qq;C3M%udtW;0G`Vc z*I43Rdsv*)vU2(96e1f%qQj1vNM<{xE(SGZ^q>E9pXIvl==pvq``2N?ZolY*4aZTUE{MvR3;>RYU_a8#TrP0BkESL5<`lC`_Hy%i;wo(WYIMLJbcT)9<=oldSvD(e1p`}BqPbRt}L&bF-SD}&O{jWg|o?7(lu#u{e@*=gS zN@Wop`SP`GeLBd)wm@doCKyXioe(hn>2ar#A3qo#;w7;@`XX2!A3gc_tN8dUfbZL~ zwl%FbU;SC{fhwSz>SXxZ_kpLE4#RVOuk5+(QIk!(cEvTpRMNAv>H@uw&M-OJM(;pl zCQC6c0cmvPM*LF#I(!dEXgpYijdXPt2DcnvcJmGSgwC`%2l! zGC}s&^_P!y^Ly2vadwbH%zT@xDkbOF+-56G>#@kpN6p1cf-oB7Z@`0>;bX886UzN0 z5yi&er7afGdTAcqT^TR2tNiay(Zc*PYGtYg=>vPY^mQ-`kI3yGmMeaCDBCBpKei~sXG8h#UiMU-VUZdv1B;H`x~XdOgWbsK&q^N& zF4zT(PJC0xgK00Hz1pZVk4$)Te{tfoF`oq@=u&nEc?UeTrV@Z>Gx@PE=l~p4Kw9SOVVKjE0Ky~JI-?h*-cjKJZ>L8 zf9h60pPtQg@pH^ENBwz%wD>(f773pLguZTfAIGL{cVF-C^TXggH~gPFaf_vQz2?53 zFIP`rD0;rP;k&;-uh090bpF0xi8(pBKR<|ek%W|UgM)>QlUEOK%k$^fWp@vIcW0}m z^tpsN#@tF$+MGNd?+=T?gS%#!;-~*DLN4b|AbOkozxCTUMDG&hnueFs>l5JX5#$1p z7OLq#bbcPbSYKMH)PmW+KMT&|5`TDq8Vmd{T<86{L)Pu0aL5TGt@?c^tkssCpsVcl z&&(J^Mn-mlcRLyH;B!mG9D#Od(vmOZga$w$lUpd1W@C1DcIgn!=JqFTVaKj8=0#<7_&*&NY4~Z> z^83iS5^sL=O*x;(RSBXl1Rt<%b44(R9d>WzSZpxf%D98(LeYO` zda1Ei#%;2OfS|p)x>Z1aa!ONBel&~`59qn+s$ZJE+#ulCLo^Hv7Z-{))`-TGVvzTt zl|(KLm|?FZCx?oHi&}lA8kFz`xq5<{2I?uq+oOI=P_EDu1+&2ajQ4YYK!UVgikT7`vTVl49^`)kWolJwaf}DVxkzqM3 zg$cv7jtM!MHn>eOjtiQN`2|5)*j4TS@Nw{D>gMS6wplotz_(Bnn`kaQc6i8 zG2}3s0i(Np)JHG;R6APkz|%f=X<4 zx|mOODVnip3QFx83y|;$sf07da`m-Pd{KJR@dD-PA3cMVSOh$qR7jQ^**e~xzP7G^ zHnibJexwH))xaL71Mb5b!L-1&ML-6KXi}yd(vure#Evfbu&tjNpYW`akOFf5e1Uat z0>ub%Z0w*=@iAWClGP2b>jMFsE{wx75EibUY+bp7$0g@mG=d2#;G&(oauvjuT&i-) z>jy3HiYzr#1nA$Ls?t&(PBpB=-PX zI((lm_KrN=obE0t-(!2GV)gUsGIhItSSSlair`a~%Qk-(3lIqdeU0{H%BhMNiuy?c z?6f$!In32SaL4g4js@kF(F9}qjr($i%4)hnFG9)VfratFfyBopn@jML4GgpGosnD{>Nb+9pUbvcAoB^t7FABwcSEQ&hC>aKhIh|%E1eAe zWhQYj({4&nz@gZkajIhe2sa9^{abkI{PGtYV3;}^3D#1WM9Pg+J49HzFNC0;(QFq3znCATg8MJ&S`!l!A zB?!DhB+-YAUC{91WcI1%1N~Kvz2`n^X1yRqoW&N_-6`B$ATVk3390@8$v-2{a4wa-N6OSw_xVJXCyrZ^C1N8z(-O>?fU#;wIOw4K!mLg+ zo@)+GC&$yhIvrhXAPz|i;_zM<{wErE1uu3~IA}^bR!piq=CxvmGeMZyiung?^ zo<*sIpOEzeH9IIwTs@^^p{2JZRA=K-;=8 zm@s{bLt&f>l(5`hHciQu=MKt`J=13`Xa3*I5)@q*oy(jD#b&7}rD;jdfdI1L7D4t?2}J9m9pOfH@qqWkd*KM}pJ<&7tmb@OgRO63M#-O!m1Cz+ zx(1&2Xw#H-)6_ZGqnI(6^iAa9#0@)vacT?N#edw@S(A|NI0imBIpkrZ{@XdRuX5z9 z8~mP{{V*D!oyckS#!b4~%By9nVF?xg0^1RDMa z=|-x|0Av0ajod|w;{iw!^a>f)VC&IaV-&p(RuJfCwgeK&Fu=X3R6lT)P-j7+zoUel zGxjv5n}lQZF`I&8z#}A#QZLv-WQdqlqpdMdX&gXA^+NY(QGk7?-OKs*_2={Xyt43C zFD3fc2t1p@w@QgnvnBwoM1UrBU?OF}Xna&OPZvf=)SApA<0aM z_76&l(>X2~JPFfM<)`Vs1Z?#4wW{F|X9}mbY^-$5_SllmcO1Z2qH!R=a*Zb-F>vBa zt&yc1{e2$(#f%VJv))60Wf{cxA-IQ3gi)S}wAv`2l2mX&+1JK%&L?2*^DZzBtPJa< zBT&-vW{lmiIK^4K0hrmfpCj>8=FO*E8kr1Q4Qbk0W69aIRCz)W(Pf*V$w--PiI85A3}8Ec&^b} zF1qshS-^!}V^i#iXns53E=(J+zMq<-(E^^hPz>3jyf&SwajRDsdH4nz1-ji~C~*5R zBBL{tS8!hg50YN6oNqdhyDh8g;Y>JU;oy*}3wr+vl=$~<`Uv5L^glXc>dr2i{>Gk> ziEg19o&yICTXfxyJ(d_VQjC#yjx!si&YQjfgs4|~b&?8=@akwXS?rDsfg*?z%$PLF z7tiEFk?yDy3-_&XxZ-Y{f=(#)4p1244~tuW1uwBZQ__&#}h<41(bV+@$Z-845xB6dYrE)6A@W!6vC!K}qYL|3QBEXlj8B&*- zOA3!RhSQ8cl{TEp)sx{$Cm<#=Q63WopM@skE-nX|^Dv?A@1GdO2lBI{&73_>auxvj z=pq8lzRv_Mp7cRs2!W|ZtP1y{CvZaCAFd?jQ0aJklgyarxAY-%40yZ<D8xrSqx!$D^d09x%re6VYHK4l^y28_z<`rCQZDl>TndyYBvZQ!^ z;uEP)yD{RG(MOGkg=w?KkdZxJjv_w^a<36DMAK_USDaF4i{c-mT@Y?giDLzIe`M6< zw+Yx#J&*sJ@&);$y*v*31ld^-h|G*y_oj(T;lV0U)`>$Z#fY8#)5qZ07K6y*!P zM;rHs(YfY}K(4G@7Xk}jU!52kFJVXzUtij(nyinQSW0K{Ry=?e38&Q=TGdoXEZ`K?Y1SA{R1v&~dR_j4T-f1Xv16XrDu{Ri z?>3d_wm%KS`(upeK1xu<32t!iv;|GcuNFwwN?E)&Q_h^v zV zF_KXz2ZKuI$8)Bj;sE!}*kmd8p?~&dM@g#oe<=TvxY864usFle8uhUZRsPu4(!<>z z{0^_y^ueOQ{*$py;SYRC*~xB^6LVy;Mx^Gn9uP<`cG;2|zC-$1(PVT0fO2{=67$#< z)C@_KH|!U$5F4zZ2ud+3P8d!ith4 zqNXIvLa3{6AG*;G+P-HE6G)k^cX@PSyW>8Z@0=7y?-$%8<|sYUBprA?$8(uXR)C0d z_s)axI8vyCt_#EhM#dF-koF{%G6S`&jkZDTaxJ#y2wN$ON5IFe*mpt2s!m%r+K5j^ zE!Z8%=K1eKibgi!SiIn>!Q9v}L!bY>PEqX2?Ok;nsYN5BXoz|IJgkbMlhST~L_GUi zcER1qph{>gV50-Ai-G*?uzsB{2;73^!ws+A;x2L$FKE?1jf4090bW3%zlnQ=|6(bPzya*SMt%SA!A6Fq@FiQOqLI5^Ct0*pH z=uU=>=pS%2vE}yhv|!1GV)miKRQf?P62RX}E))bzi@~FFW&!|T(*y-bq`~l|7r%y6 zaOcmzIRCJzL$K%=^oKj1NmcxiLlhv+w#L$9IQ+}wlNU$lB({-1>nVOkK^nlHUt9vC z!+J7CAkyw)T zqUjZVJNF|R)i|pE5G$}FQde-I;JPcQsc&>Qs zLK|9DO-T+A=LXx3WZo1Af&ivhw7xlsQ_PWEv zHQXms4-S)Yc&tLRWD(7}IZpv+YM(xR#!H{K&!#YBqI|@EtAJqm-w(d$F%Een@O@UN z=TD5}@a2-?EWY6#N~RxlI(#necVX`ig4*L0)BWVC-w#ALWo?5#4y2M}>!kG!@fvLd z1IPxUKKnjLQ#Nw(dciS{$Uf>7={Q}njnlGO&hvYEtu+6fpCWByQ!nSEbH8*6g>=JR2B1ks41X3)ZH7p==GbVfF*LWho zVC~q@3;502B(>DI0RqxOAH>#AKC7xea!kkm@ai5Oj_{lF`6*7s)moxP@|vE?I-t=o zC}8kPzM9+sB5+;}cSk}4h6uE6tC-^sTya6N5E2|$6Si9YhyykF#|gX~eA4%aLJ2}+ zGzciK*%Er}DcBlx#;u|hj10*zCw^ciz)%xR0792I-4$CJ1aZ{^-FCX*c-WiAL3oK$ zXxW$|<~X6vOHztu01Hw_%XwZKYs?ryjE&h%O{qaIOd~x<*ukl5ta=qnUp70%fEH&z z15#jp2u%uoS#WLk$6p`4Jd2NM=$~8c^wZxQ9d=kF^$##e{q_gm@eTcpt*t_^!>t9C zzHHeY{2z}y8&b^U0Y(-4V-4%!%Y%!HokkTJ{jVEIvCBd1#t_$g_!<#&2O0oK*J6Kc zr-4{%Xbdb>V%-LsVc|%(01fA{ZSB(p7&AJsyPh*K^49hw&{O=~)WdYs*6Qu5`gPQ6 z+g9cC;WixZK1NPkLqp>3kcqcNCU$^qd^u3j^#>i9f*BnM`kti<^(EWD28W(4BeDya zO6x25nQ{6e!-pa6R2RVPc4M2c**%+Ota1O@m`1*ev zS%)ng^Ca*Xou2<8V6fZEKR3W^=-BW(6T!n|Et*t)efU zz~1oZTOK~ly73q^{N_7WqH4n%>{*A_q+xa7SwilpDV$YXk?@6$7z>AL$`rU|L^)%y zg<=xFPG-X;$!KVv;lG?J`o$V zQm(R$GJt8E{Rk!I0(}IZVE7w+d?JBxwr3dB1~~XAK|O#oOfA^3x~Wa&Qq!Ch=I9q- zEZ{YN>DjmK0h$t-sYu36W$3&Gfb=lK(AA92U~Gj90h|wpOm(P6sBkFUKA98y4}?6M z>Kr<~kZGiod3)_+o!aw#3c+H#8w7g>Fal(d+XV{sI|NX#7gSek5fE_Lt|XFXKf$j9 zhdgYIB&hX$-CW2dfsa9B41jDK>wzO<2#R(Qn$W@rtU@_0VZt-xLyd4CB0jycb+J-R zqxG(UCAC>?U$27wBNwa$+2V&o);5ZAM<8+2WBJlB>1ccluorkRdYe24h~W6MTX zddSj0Tjgq6JuBz1q3_u;$ELrH^U9atvg zJmVxm3}pZGs}EQ=9ules$eR@P`i0jJeG%gGiRKJ9FpruR4#VRP ziPzL~NeopsoGvtOOR0+w?h<{Bkz^(&r83FPxCqRVa~i2O+XbFn>#tC|qeIPA^=`f7 z=&**4BHtThtAGT*r39c}c|&9bR}DrV!c^O1#RAd*qdve?m00VZnkBEL{;ij|vamgP ze^a8nOsv-!+~G&8iHq``mWeO}$26Txp20-pY{a9gt1>zcty5*Q&_VB=xwWv`qM?{0k4 zz>Q&5!vnAkAIuO4#HZd&qo`yENcy;H?8voYUh5CPkn$@iq(8z! z%1Nl0{`3p#!h&M@!!M+iiNN$%0IP$puujLT0Dv>8SjqiqgvTHO>AzHl^qc-_#KVwST4=ti3tm@5@r2jU^)wsVsFZsWbr1`bVXm&#O93hCnOU8ymiX`Tg ziIr-IsscuC0=W7b)yzZ>xafK=N z2h_#G9I~Y?F^-qF!s0{rc!Z*R_O}~XJuw@;g=Q2IVpJA&cN(ImWl2AH&*Zfd(vfYG zN2_)3RBW)3qLGvdE~s$S#nY?FxkpGJGEgIe&yg_LL?;8_g)-H(`a{#DJ(i6~GxEaX zTu9IyZ3e6ST>tdbeYGW`7(bAl`#V2yM^Tx3~9niH{Z2+HK0=)i0(LV1P%=#5utG1knEe zv(5^VA+sH-wv9}*&a$n=zEra}#-f=+zwqNKsAfAIIr3htA=S{90T_uy0-v^CXzkSm z!2Bh6s}-Sy{@9_{5}AQd-GR)FH|3@TyW*2)QH`i!*Do@_TEiU&6WqK7a#JsIh~9=L z6XF5E({?2UdSHFmWgQ~7Dd~@tjeM`sq*8;C&T&raB3V6%Et8^yJX#?l(+EWd$Mvla z>|9TN<1jvH@)_G&6c&p+6MN0rlW9T8ZxovDa1=@OM2~>0Wsi(nHg~@(m~on6JkW9N zCke%u#+g(QMq)!D^PuX;>$1z{*ABAF3>Yp0TU#G>O3>{sbsGQ3 zdk%HzbKvgx;1~+G2Z zzNFT+0g*&uZeNFipwupoX}x%i;!>M`ua98~NHdt|HT11fDkAV1UWt>?am*PI>1W~+oa{(^V3YG$Q0V^y=s09K_K z!1Yi4F{ZR{LFir;;qu>UU8d1^4IB#e`YYF7}#4nL?vt;A=lIbKPo3_ocL3bsb3wq`{&$iZ_8MPq56 zh6?GZM!zW3wZ}@Oupba$JE$X&q+gUNQshC9~F45r56 zsBN$zPfhoL$3#?<; zuYD!g1XBm~k^#J8i9R%L-Bs*051MboCmK1}WX@Z&3DhQ-9zeIH zgwTm16I0BPhv|CF$)Q`6Mveh*Dtpe)qlz(4K9-Wn{|fk}bJ|%eY&gcH_!-)?kNvIQ zLLPNVHW`jgyj91Bm52|v&aDe*8>gmAwXGRQyQCo6x&=CBv=y`mTm$Tsf>~@z0sqK< zD{L#Z2Y-6k@WM8-}@Ir;4eH&}yY8 zyA|=rE|YK2q*+z5b*rcrbzlR=_8isu`RRGD*Fn^}?{?C4EWn1qV&iUmMQzNRUMbi< z$z)22n#4E6i;+OddbIELH-or=3pn(tEoxM+0@I^9Bv@LrogBhxcOjAp!KozijYtv3 zJ#7t>SY|;N99)4L9b@k3cK;O~xxZE!LG#C`R2w&WFrP=3PMjRLpp0+MQQ?+p@SXKN0*omFM=BcwBQ;{k%##o+0PDhzixO1GOf z+7ekEG0hz{*X@pq$gVAM4@Tbgws~`eDTkRDM#6z!jHFR&WKL7{v4ovmB>EJHaBO3GFW!VSWs zy~=HRq+l^=DE!}VkE<&j{ul^+C;1OLI4|K2eOyy~vV-@%ed;lCEohp(vTw^!2~ z5@LDNp{8`D3{`I=AIn@(3RixDt7Ghf zb5(+~q%)yCGl~ zbV8)A3K;XkRm< z_9){o7Kl%lv>JaO?aJ|eI2Hp_PpyACs<)&7#(ZyQdI*^#!iPbQHdEHDDDmX;(=Sdk zPOjmnu7H4dI-c8KuVznv(ic9uGv)2J|8{vk;*53Q&_WGgOUEn}@8qE%j9oHX7^b_) z+zrVjcFkw8gNgg=Fl2*Kw%w5JWD|0L{X5yOzjDK8w6tS#IrMwC*Kt)=yh`;*c3nzX zGb`^izEaJkJuOXoQa(bDNnKzwyT^!6JgV&jV=Fv`Ci(l`!0vh5*S}yLF6jpeFfvF; z{^B(9ZCH`O{{ zHcn8yg;`$VGT|($J}ha($TBL3x##+DPSVc))g>5n&|qg$eeN*vGNiPcHF6tpXB_*( zmK)cl#g9U@Je6M&-7@el$~`|iJUagTNW(6bEw{v6n$y|np`$X%piGO@+`BRK8;_C*`h`ux!W&p?*e*YfZ@W>< z<&h5tU4y3bE?4&WwdL3&?D`!rv_KF@ZOToniFa?|uS{z?dsrrpr~W4L7xPk5fU?-Z z!p1QR*xx<7e_p)*Zn2ucj%t(CDiqFOs)?P88X1wf*275rRLYUX<4*Vy)EoNRk@;`~ zA+Q!bGIxT!$Z& z>9}^o1!$=Rz-}QU1rRiBi-3c1k;yM?zopHv=g>v1?o=JMAv4?xt})S8=0RM=8RQ4n zxB|K19lhP|7i?<*zVG5vY^@g3+`yC?Y2RJ1+||0APhHCUgh?-2Njkr7o;BqA>J-v4 zLp(mx^Kc1vr1ij5t_5EED~>{ z0r%Q+da6UtvSm2hvC=_Dz1>@!*47U8ZT35VdxM6AthX(0!)C!}x1Sl5yWCzM?rT8i zM_0eh1iFO8o$7UF8Du}e=fx^!WW}$aG!LuC{ne0KDOy;F#*`|cp!eN^-tpMV0{|H{`ZIa9c zOaQ)?pgR&W2n$)ybH`G=qtb$Z?~40YYfsjZ_d4JNJJdv7*;ebAlU^uVZNG@MGh>n0 z4lTzn|8eX??(ic|Ci`tjjbKIW2R(^I(yk=tC?$5i8%!n$+qjQ6WHQHjanhK>Y?`Zw zl@|!l3N+Eqw%dCebkhOvUl7rLemJcDf2CP#d~B^5=L7TkdWIuUd7&p{c{CD9KQzXl z$&lm=o0|=xuv`H!4n6G>`CeDv@SBH29COws$&VXON==#lAO8 z4idb%k}a4_=hzinVFn+Vncp=hXi`Dc1=B|XZ#oS^qoHy}t zdAmn!&lrSTREs%bP&WK-cRfL_#e=(gGuf-?@av1~ijrkaP?s$ES`0{odo|}P&?yhJmt2xZI!N8&&%zt* z*V|;@zdShj!!o?cBc>`3;tPVk-ZK@HdNFgv9Acb$E{84SG!I45450k&g^X+UDA$WtsN1xjdoL4pnwo6#iLJQZK7U zWAOM*${iZil%MMU>4{F*Td~br;ZQr!^}tzBVFri*E2CEr?3iN zeR^bSEba@0%uez)BMetgog(v>L!t2&dLF8Xn@e$YywujP1)YKn^X4 z^ujeDHY?0kX)Y<198ZFpMl7TrSmOkxum(EMO7}>!362y%zCb!hzfR4WocQlGht7>$ za3%o4sp)uaN^+v|Wy}@dfGWUjl5|!VP4r=KXlwlE_qSs&ZE9|GZipdCuTyK*8O!4o zYNUF_K2NS*qts{BC!$dgQj76aHa@<*qdt z80iKcd`fv9hV5|K&^v24{BgpJS6|jE?a)jz8klQKBR+Q}W z4`)YRnM$^5=r|s7cdZ-J=*3~%|GN0N*WmX5H7Lsm;6FR8FwJ>ut{MC!Cw!9ALB9N2 zyr7gVk3i9HHbs#YuJ|%XaD31`$9>}SvbZy+I^V;0hW?9R{35B6eR`5gsNOL7rdXYz zrd3zut*uqVY*e>WRn1CkxaclDr*}S}@Gar$UQ1$;EK!?r1K}i3!MgCPq<95^&1ff=RZns)<9`(0kIO**}%BG|G}TB$C*L0 zyrmQ{^o+lo8t$`!gEYw3K-Q_u73+h~`a@Gq^SST4U6pc%#so(kBO!%}r6jUp(g`l@ zKhkv1yL8Z?UND`6;&A4FoMU z4RjLOzs`CPY5$Rx0NObGEt~*vqZf}0O82^zRPCQD)OH?0R}4sd2l{@P(k|uz*8uR@ z?_icLVa)d4wxjtF)pXkDizo-n){|#sqNGw8B2E9R?fV(x>p?t*r0{& z;%fV-?%I#wvfYJxsc_s~nEjD_#q=x5cQ#6p;q+~670l>w_P$P`b1-}qsk#kplj0} z)mZfr-LNp~yosF;aPK|hLV5{e!jBKX58e!!BTTl*s`{GlSl}!HkkN*(?8|9%03-MX za_8!xO_>ESgkCOyR!dJYR680{4TKuUE?1^J*j+$1N z(9~c!Cc4!q7Xw90I^djICQ^oOkFLQ-@uZr~43V>!Z&BiQiqQDJfNI5ePU73yx{+J$ zA8^;C&o4z~Qg*P<(bE2oMr#d&7r(sE{3|K{0ABqQ@V}q_{tqqpQ__Zz83riCNi9g5 z#=B*n>1ELJWXpYof_~6KP@uIprl>P9kIXfHu<))g9Y+@=j*YJDqTwf@Pu`8@B!El=NgkH<;GANR z=Uv%2bJ^hSN_~W;9)M^Zx+inb*ePn1NShUL0H1w<%oJUyPT2(O^1ZQ9Yvux8X-XO& zcoXH0S8x*SN8iJ#t6IMX3(x%#1FwSHf6Z(faN>;-w~X1EcQW+wrU2AH*ff{BlDs4LA9|ytlh}PJ zta}r+mVdy-ur5J7pc$(1pc`oIpzPQiTmB)vY2AUk+KVUDL!cEF?aI-V$61vv#vwH+ z;twH99SuERrrq{8`5mXxtD_Uy5!%x>^qZ6~A9{Grs{eKna+nS-PjEWkMRwe1jNeXloW~A(?*EM8g}f|jzj`C2QhS_3X~gSdJ7u+*qxJ2; zH9kCmUSp{3D{HU?q+@_>CnIBz?IOY8{{L;FX2*Di#;~B7ZY0gjg=of97up~P|XL(1u+dL8c-={DE2?v5X1PE0WRPJ&s{kQ}uaTB+-NRwbY zg$nWR%(XjvIdRGd5&{8@?cLdHc6N5ozMFl#^8I|6*HF?As@U$F*&nv?CsrW#WQsSB zi!YzdzIyLdkj{Y)JzSEgSzVNP3nukUVO;yJO@zB*o2%Q}6L^>?DN;oxkz4FcQ$wk! zdw8aB3J(>ApahOAe5~tqgHIF}mJqvi!9J=Uy+2ZG z=ok0SC{=a&c)kyZ54xMv=Y;I%U~?w>#pQ=nnB${c1T;J|Nj?`^d-LvIg;aK75ZBYs#RhT${GZGc0t!gd*cJe ze@BLSTwAGf*f9^7wws7HJEghgButWUM_^!RS?ZpZCx~$;s_XD-7YP;Wj_Q;vR0%%V zy17J%gpshG{f@~n$Sh#4O)(czIGpi5vK-^Up2z@?y@>C-G#}yzPY&{{E~gz~?7)8X zB+TexYFR=}4M2jd{KvJ>$T2J&Qi|v*UfVkgVwNa6+caIZDn(3lL&%?(pCm1!>)A z`d~9-T?(3tb=gwA25Jb)CkDwaFkD)oTDKriRZ)(KS}lcvze8ap?R=`BN=X#Pf~j-I zZDYJMWa_0 z6t6eUI3n>hai`f=Ted={M-dlWzO~3cFr!*msU%Ms$XS2IKq#E=12!8v*LU@L(GdW$ zJfj|~5d<{k!^lZzxTbNF6*3(ST&q-g^)i)!{yegb%5rLk)tIF){c4MB%Z-P5FNREpLf+`*PL+^&{)0{ zevx)O==@e`8xtQ{9eLr+D$ zN?4GxCH$qTwjtbv;a-}0%$%*Vhh{(}mFQ{>6tL){O`_gZAY=k_PhSkH`KYp-9E-C$ z%p%IRqjPoEkEI;R;htu-5|jiUh@*+y8O^AV26G)bUL!uPc5Qvnfe((Gy&~x7;TmqQ zxirqR-)cW$i}1FeJrn{G%9v>?pORS)HBqDhstwCN0KI!nZg|J;OZ12sumUk`sngFu zI_Y30uDr$HH(!*$RIOS%$|#|-dTNr|g6n!ZWl{r&w!8_>ubT?QZ@#CZ*HD`HyNzi`7R3OJPs;0$ z*0~c0}+7jA70PKr-Hl6=5lfV>jSY)gSSt~R^Q#%0-{ZJW$pG%MKr zsprDUQYxvaF_2IiApt40DW3rl+HC{DF5}i@oTnM;9bb|sYqES@ui%-cS9}NhB~~L4 zEd2DsN1M$EaO1&|Hhl2_FDLlR`TJhUaV;VeepNPHq9BwoJwjwIyTvW^tGp=@JE&JE zr)w%$_M3GH?djU{zT0pmGR(m8R(@WZtwFlY5kGthAM>AH7{#-s>rZSh;eY~dYn%*M zHm_RHNPI|M7cx8YLYB!O6)Y1r!lyk!qpMxzPudU@jG1q}b&vmbw+ z_~&w6%{FUb@oVF?VclnDZMOsp68l_uCE#);^lDiG?OsDrnVI>{bc?AVcLje22e@c4 zfr&<)A~JTfn`X1yrkKdjX(j>>FXq6R6f*=uca8WYp_op>i>h)Bp0wI%g!D7Wsu9cN zymwtcYL}K?Rv>S@h8~{Ib;S>o*O6kgd-B3UZ8*qVSrI@-CaV z4{fo9lfC|((hW1a0O*SkY#87-9jG(`-J?MhHrZk8BTdnsRR{suO?nD5u9~On_pT|8 z*A9Z)mM#8eSb3-|(&mNFx~iP@+-bNbTX3!ihqp6{v!9zz&5Lo+Uzc`-_a^R~Ah+yF zJ|wUfF?4QIgxRu)wlu;e_5(is0ySfN3rW=}0&b9fvtKkJDwXh$>Wga*(U=Q`@NNiF zBdY_4!fewP%j$j&$77+MPUHi*MR9-&19=&YE7vZXdP`BYN;pL;I>;NMOz>1>-{HKV z9)&P*Rj(_?_^$T#;j)`rlkhdVDu$H#sZR*N{BKQ2bSM$3D*r)_r#m2R66V2hGF~$Z zWr-&X6%nho>K@37n)Ui99K}f?oQ54}Z%XBw?snm{v{}xRJ&|SFd;Ug8_Bgw%Z*Gar zg^`F;!Rmppv9gN-Q)4Q-Ygu|h=O~zp-!Zpbqqv!Zwy-D_*Pl`$1xpc)0@Zuw)LNA& z&16O_N%>|>LInjUM2poD@HqfKgeE9NJy?z-hgtmEF1ji$e!hlx>BR$JP=0w7d z^fDvLwA%estLdQOLEP?G9H9EL;9@#NhYjB!RyMc2Vd3FP&?pH`9HEh2_^oWjDO397 zdp6Ss3J=5gyJq#$`!8Q!d<^!<=#%+;$6R)^+ja`)LT(fuIwR?BC&@doN8)i_0Et$> zwNLo6NSCv#N)d~=op$5zW*bQbCup;~xh+m&Q-ktR?3>$+I?9JG@^g`nrCEP*Guq3J zV-d@TJ50E8;1l=Q^d>W6yo2XS#CHeJ9<&rY9=~##bNg!`f8O_D9Qhucc`+jr&N#96 zD_v=5#>ajBAAlUsW|(3BJ&1^sU4H%H`xhTwyf0BI0z<|qPz7A`{UUO%wOcq3Y6(vY z&x?w@huJC^3p6h_IW^%84NP=)W6?peH)vk?2ReHse5JL^#QhjR`g_LbshAZ~! za(EpXJAf@W!>j5s5dkC{xh;PF!3R@_5ObvUc|e4zU&Kq(FK1C$=Fk9KSP~(U=Hw@~ zn6`AkknZngy@w$&C~5IH)A)bUjXyHs?-%eN=LYjHyN2)88TfNX+3krp6O`;#7CbGN z-J?9o%|4O8vdM_L1yhXg4%^#PF~1^%9$>3)N(M^Bc7^(kuKImQ^G@l30f&@yVHdxv zN|-fN3aHBa^9^8Z-CC5d+YK_A88DL+jtyf1LUj%(sNnc~5-WU-k9db=mK7Npiy@=h z1+=VDXh_0{JQrT|SH$f4nY$k3Hh^$|A7KnzR^^IKxC?)bD;&XDHpP*In%%mc8D$d` zAuc#BAQm1ow^ZQmqnhi;c_%v2JbjFc18D!*02&68kUJt$aFkGl#vRzk6fVsuuV-Y9 zdM4%$1xU>iok*|mqTxib88dmpzy4?ir5K8W<3G7HGcm6(zk2^O2owgEm0!5i+Rf-F zok^Q{40+sdArTiY;9{}NdxsH53Vi(~aNrg`q2AuVImur0-@NWsz$wbkyk>O{@fV7N zlkOsD&ReY!FLaM1G@jJ{(%5TL^wpnxjXqH39awDKwKq_4bS|bdY%bH-xZj&Zu{{MI95MFBC|P9clT*o0&2o$Zr^^@!cSIdm}LVHf%JV7W==$LIK>e7LHFUn189 zl%>(GK&PMGHrs9A^>G!;a$E-bkyEBnF<6H!;z9pFQV7(o?(q;(!EE>dGHdwtG#GJi z^85`8y_`(fvr;jsMSyGZ9P1Z zMAU{)dM?VXMA%dn_w?ETupSv1-W#7tJ5|kM=s*GmHvN|Jh0C`K3JFR{; zq@QU^Y2V0b;gb*8(k_cAfOvnSCJxY=$Ej>IQk?5VHdqVOpt`DEkc^Z|@lR|@s)f{B&@_3XNVtE*Rlkjll9xS*uo^(inZ zr{u9&&`tN4oakfS_CTqAuNFIO3zvq19ng;W%?Rr{YK4=N_JC+AjEq#mofSvQI{V+M zbYTW;V>T=(S}F9MF3Yrf*i>uG#R;M=vG&<1()E z&ErS|RsFm4!QY^f3r0-*3zTFvc|e(JilT8kc)fbXSDYAjoca{b*cPN9gZHf}Zlgnv z1#e!Kfw@nx)~sLThM98b0CLv zVf#8BsGc#`4f^YU;^D+skw+>LZQ|v-%{_(50WBRUD#~BbK1UZKbcADbh1J8;QAj>k zBbn^HgoL~m!iv4Us$>vt_0+;sABfW(57;mN2&uC}m-^+B{b#zOWyaCcx{cYQz(}+h z&C=1ia)eAJJUu5hpR06A==(TwhbQ4u3@1SWoYN@Q`tkH7qZQGAP`5v;{}zP9uC7y_ zgr)5-X0EZ83|S2?URCK`Qi3NnC(_^XIKj=(H_cs|Z5oQ9!fy+7a@gi$6Gvruws!z5 zt!r)~w89}H5kF4K=^n;mCMcjX3cwqV-7H;{yGly_5Ys0*RscfoLocc|E#PR^UBC3! zC^1jfDm9>w&!_J}k|h|;=a?gE@Ys@wk}SEbR=K6gKk|rw9zM*zn}3ijYWm@mZ>C?) zzJT8_oFTKJckuMJnIr3bv;JjGw=TutQQ_le09{Gwch&PvWsDQa2+;VtTYK886>1_D zH_XdYc(q;znga+Uz2&jPF^?#J)qqpsx_DX>`r&l?)pYQ5icCeCh{d8x?ue2dA?ts_ zu!syM@6D*2a-|rf7)EKx?`YyX&JT?Y@T`;QW{r!!a)Kci#Sz0H(aS$Q7*8nkBB;OT9=+AW8|3AaWk!a~amj!F!W7 z8C{f=eD!{h0Z1@ZE<#7AmNJJ=nx*tVY2)%G;L?JIPEYoV34Q6@$q2gLzZ@bXO0@t9Ti$_K%1OD#|z#d1$Cy@fjy_oNY(j zO2v&FJwWzLs-2WXk7FopQEs=mh3Sb4{EhB&Zv?cFIfMxnQfUo?BJ+&%C$t&Tgu2o+ zb!v#*6CvQ7syy)VDXWd$5D@6sO=%fS5K8vXxGom}B7||vm}tVJ3LI42FF8NnoaLaU zxYF|{17)sl&Dmy#J0$@m*#$gs`);$G{iDae>9O{^T56A5H`{m>&!Rutc6S8_q`-wM z|Dsb!r?%5XX@JViC)pgiX!AVA7d8V0%@4m)nHW42 z;QewUxg-xa5ykPsw?F>$U#~&k%vaJ^bi3mSVgB>$;Pc1M=T3=t=(Zw_xXU2XW zpiguXu=~ypkmC2^`|QQ=;;$=IPRy*_H zLWfO0n;p|vfeH0=0>cf#3mxHqc>UAa@6aWAK2mqKQL`mc-L8{0P z006i$000mG003}#G-@w!a&L5RV{dFOaCx;FYjYYm@caD=m41Pm-o$C9?a;b27+)Mx z3>0uuPm(#rAD`T}iE0%UWy&+a+-pdvU?x9-h6_Ol%8y`YD6|ILpMx z?E4i5JPijHJVU71Y<$lbQN}iL!fr(*5*~cN+b6p?6;>#+ZMi-|@uft3X)znRc1TJR9vWNCeM-GLF~;M4dkCvZqf@pIBCp?}1g)W9K{p zO#%@f*aZjoa6h#$Vm3(}GS}Vgvwb35m-*o?PBONSd@o)LH}ErbB#7dxG!FKea3zK% zDbW)@K2uaM9WVn=*6!-ZXCgud-O|`v;FuzX4(_Y9=1>7B;y|=wM-0>_&Xj=^a z1u?`AiIukN=BmQrArm}F_~B_6QSlp*y5Nk$pBf84b_cLd^n(H$vMh$c=&C=75PoDa zes?U(df9)~cdz*FIfea?&Tskv70_QLr7C5CX3ufcYobkX=CM~N!oAqz>Kn8TqLVs>So#EVk2QBNWKh>YSWsR># zMrzcurv1V1{YxWZ2hJJw=VsC#uy7VHCKq}n1TGWz%DFms$oFx}dNmmu_#K#O;mrDr ziLuiyNV)+o z;5WboxXib(SiOCg<^(QIPU7S&PFQim8j+@4LNi8C_IUos*o82g(XvBD-NEGQst>=X zau9Vtft3&y%#-7kSZgRMJ&3~)wyD$_^hcu{T}Pyc+@24e(M1ik0T8Yu2%GBxn7W4G zBf|W}_2OdkdhDtdQ^QIDi(;k*2>^m2IFbvs22T(U6Pe8FTtkGta7NP_WD5`(!!-~` z5Uj^nL+5o3_#T1Ip95#wsLT$*7sI+0)(8L_IsK~|Yz{E=K_53|mhM-ca)wgy5?0py z`n*PWYvwRNf|3?cM#F*Ah!hy4kuz>Y5e0L>92U)tJT;M-*aX7M{5UGkhMY)~xgP&M zp1f|BvCXl)jxX047~z%k5UWI~I)Dm##k~(K-9haLm3cY2 znKERj5}8{|T9ttlMtCxT=#2LWGKHBFQW}wJl(g41)$APT05rmF{RrYMP4dD|ivFP8 zl|pcSEU*cMiiKvJFvpSQT4NaJC@K#$;oQ+gC)bPVH7#k_8x~ZFe4;8^S2#V^4tTg) zbM~>z4m~!GBhitSwyM!8Xm>o=3q+NF*#Y!oNxQ?&u+U2OqNS*8y;GDRQMaX=wr$(0 zv~AnA?MmCWZQHhO+h%3w=^m%g`0wo)@emL3xc1z!=3X(s`U5UeNgKgTUM?t|WzSKH zhku3AMNrd)!tw^GkY&m|5pM~9$a3k@V5$4-Wap}zVd3p+c!tOw_XbhwLkh2v9heN&3 z`HMGC;x1=AW`^kY@G(RHp@eXOVC?>?m|Q>0fSxAW+mubcTI|XY3~(!s+zqrA*En>> z(6G^#G@#T|6d@&SjNT|ZcHoYJDV0luWIE3WG?N?@OC*7KMc`GABT$UFpx?3}O|yus z&k`q+^qG*G;aTR`pN({Zxp$rH8DsMxV7owE@3VfyK!de0Og7uKlFeY7RO0JRefhjK z){oq>dlUhood5ymxVLKdK}}Sf8s{lq#GqCv@hFAEXd^=@yYGDtz&}FxxkFs>AmCpx z9w`<^2mX*0q_Gt%&S?k!!^i&O;}-8hV`iwSC_JL}*Zbh!s^44**AMZMKY4E=NgXr+0 z&y?Xn=te1E!t%vqv2su-OIz20e`wU%yMV#AHUrZ8UGTC-^KO(l{V&T$8Xu|CuWO} z&*^caji59?+UP^=fwJi|DDL9Anz*nGEPu4RLuFB%GFnl0z&o1+=&WJ9`3}sB|I_qA~6-(&=Aq zplpM5jKE(gk+~c$ZL^!i88M^o|IH<~o)lv)H=)(4p{W;_fAMl?_%QIJ7Xc{eT1Llc z`tclvN)NPy#_S9F%qIMY4S3Xf`ZM}(IW;REZkZ2W+5{tPS>;FS0V59&LB2c8U&V=Q zG2^rA&AwPC`YqI~B_yLnzYW^CAJSsmiUXb%2aRFu@I*~5mDJ~B?qz}^Us`*`UCHb1 zo-1Odj)@@9*DPf)ylD_kU!73owJ2W;fAwjJRBW;48tv&U6p+p|lB0j;AD65}cVOYl z2k6W)koj>oB+4lXHh9IaXI)xtj@_r=tI&6n7_8o19ikRVjvdY>UM!!&(hUX`HoDAR zuSMP1HYv=dP0G_MuvFEW`!iH?Qt#0KuF}M$M46^jGg6} z=*X!|TrLUC?damvEo5jL!j3RI^l^emg4Topk}ioSm4#=DNE_Bjpx|Z10-Go~h-I2# zOUl6rXRZFBzzq1$P5kTml^&qHyrdBCHA?cUK|mIg@qVH#96Fw#zL#E`0n$r4MIe)) zM~=OT&ZWEiADR{LUetT)McW##`gF?j_kDC{0s&T3Wdrw@Dg+j>$^^A+N3J9ZB(h5b z=UpBTHS(`Eb*{<;G$?YDULYx)bql0l)90N?2&=?`m8EfZaThvTh&q}v9|^O6uzm8q{AfE~SwEfzPyZHp)pRUvpM9%v z^6Bcy-%%$~H+Seo3I8qBHj z&n;o@Ul6P8YrZel1{ za1^6SZy5QkCmn##L{QV0Cnee)zAK?*4@yR>m!|PWUB|tf;>U|VkZO(aK{zJSHS7WZ zs}(Lo8&FcpCM0uEz*_0NcisDoz|b*-<|v_1$w4bKUcY2}k+cvK5wj5>sP2=D8U^v$ zo9_H4@4F0V_JiKrfx^|E!%_4ypQAuvc@(5+pkkDn9}mh52$>kUTra|m<>;_*9mLvK zGoULYpljp@5iq8!Lg1e#wY-CS{RFVPaeK=17SayONIUk7eXgdfKsO$4@-qb|9W?4u zt!NpPOsC(&CQ=$0plWQH=;W?(bGc(~^#HGXu7X?h>5q-}AFwrVzMb(e`plh!!p_MZP7rMUk>OsHt&tDi>m=Wm7~1bf}6!1 zRRPqUBH{d>Z2yy2USjl5T&Nx2CKD6@0OFVE_S^ge*n6;8(mPq0{Zifj3+q&^Ha_!z z@P{;QBxN9QIP%4K{4|1w@#?Y=pls<)CE>W%p}9Rg%(@3flwq`?tl>Pn*1J>et?S!d zZW`5Cgz#D~``?Y7@aHnFmn+q36_{N5FJrXWO&?@f3{lL)ANjyvT;e;JE8*QIAKIqt zpGKZs_ErK?WH6s#AJT#ahZ!2d?tX$J0U$nd*RbYWh@c2yf6S^?d5{-iV&Xo(`gWJO1hQ!98)d~MkI-V|3V#HQLR+PV2tOkkU;PRYxJ6I%*X! zOfLU}?!Zm!#4=z2c%JX?XNy}2 zy2BPXG&xzp$1DjEz%;jq&!C(|heXq7Wj4_UPMlyeN*2YmU4y+-vhM|jZ6E_IOv{_9 zXS@&<1IZY8V{k4wQKgWeE{wvPFX2^x#&NRlY%~fW;Gyf5>(OUY3Cp$N5j@Ky;_xmp!0~LxGD9C`ZR?L$@4_;}W z=fh)12w*ymhg{cUXuaFij0hye)B9sKS1Go)4ewt%Q`fGzvc)#jlLBmD#H2c1oEYQG zMwmN~t167L&u~1}oFPP*R!k6ue9D8I_C&TZ{t+RVbcko?ffe?PB0fVI$1pk2Kd7wh zd0haapo5G~rd;;ZpAybe3NeJwRC5tEk=WvTQ&XTJlV1D~ug1623QMEGT8MKnk_l3z zUh_n%9XL*gvv)2LG9-qM*L#$fF%-x*!93xfGesyOM7`M_?NJqVgx*exzKD`9ru;4M z?j)9Suh(lXs25xBniX*Rz^78cSGK9@cKT0W0{HYh!by93lh6;##mM# zDi6HFy8~NFJ0KQk`1eON7TPL!UOAkqU&zAjVLDmC85SA`vXYK=0{O$GxPv0OoNl*l zWy`lCG}IJIO8!ot9|X)uP&s&*I5HWGAEOVEH8J6|;>di}t9O7j6ITEz=6Bfb^^ltG zBDzy$^?}V{bnmzcx)we`#sWOY_wd(YN>OhxCm>fJK7O>N;Fy)+m!A(_H2~7h&i8J= zKXeVq%R)u0vwUxd%MC2{#p=aDbZg(i^Vzsn-s9<2HDIIDHBuWkek^yb{Zo(@$t37H z@>D)%CZ&v8{}O&wn$KfGl&|8E@!07%PFkocAE|9e6IKK*YUcQ$cycB239 z7a;)t3+Pv4H}1O+0svt7tM>mF==XnKG&Hd_GPf~sv|{{K{$e#PyF~^R-`m={0Vyb* zjhM_f6pKE60?v|Vq&!9u+k=*!EcS+M8`OhdC-3oM63Xl`A^Oo%@6$)KTq83|_ka`+ zEXtX5VV3bWd5S?ng(=UQTPiAF%8a;!)~zw!*&(w`a6EjpQaK( z&rnRQy)#eIAy`B+^}QepB7mY-3tw+vmY=6*XEfh8Pdr}^-fwokcSPCPQ*yoCLP$>- zUc^Dg0ypk)0>}Qbbqq*CF{%2Pm2|8m0NcmgVPN=2&xYqXvqJ_O@*u;;pg z_5u9nvzJB|p$i z$|`df(mWUMiRprO6jj(N?qpyRt3z2TsIkJFV{)gGiQZPL2?&^8QPcp~axOc8k7HxP zn|Y9EFV!o`9y0Nh0KX=6%zA@{3V{3t3n4=ijs`8z;Vkha$6!M$Vj4r-lRNjC25pV| zDX4aLaCPJ(C8)23Fm>PEGR@hQy5-D?xrzD((Ucwr1-le8;XHb79YUkWDqkoffCVF%3N!+mMv%oO zG9>{jK+wSnn02eaA?DlxCkkxc^JD`g<(ROF^Zew`JLjn#A~m5Q3oP>t)SYKVmE9k< zObT=$OW!yk>53y1zfu37YHSkn%>nD8R0$x_09?x=ZdpLVYIaT&hD2Fs?33bHmle7o zFRelxTi7L6GXXRPv`wl#i`xPm#y@qtrq3%uj5p7hHh33;NA@S-g)BJz?GCkRKzUk=>lp3Bg2)0{q#bU|ypdlSK?xlbO^L0mOuNC(y}#jVC?HBUad$ zt(d1owb|rRMZBiprq8tg)pov8Oh8vzyM#S)5nbC#BaXP+dZCmY-w%1Auv}wTh|z^M;)M zws3#Kr};f#B;a|1e4XEl{<-`hnQuOu-fMyA+j zSVB)^c+%~FzGmk?^|0KTf@9eSBme33W#$ zPkG+{P=S)7a>E0xo6rnPba}S$ER3ziSoFYP#w1w5PO;VP)3_F1)8@}%p)GjrY=F;2 zK4RlbjLykaUxB2TDu-k^dqME@*?=mjx%@EVm=I3WTT$gk-hxZ@5RXuS7J*J)fzDvN zSJNUz1XE%7p3dZ3c6U%cTC2M`HDu#!hJD9bloO3#-0fxB910*euR+FA-Pg~`{gZHy z|8osV96lGnR07MLY~xox`#Z3w##J>#C@h*RMtUO}a|%*m)sb4D9=fL>Sbfat;$%X| zS?b|Zagq{S)W+X9iW+RrQ)kUCOV;(y1zkxE`wlrE5OnzIYCS?Esg8wWRBWnv2TW38d6GvAQ$Nw&vYaA=5P2s!Fzr*_sfv`1X zvIhE41{8#}{w?DyCx80fXCBs7sn4TiV;Et6vbVdrvWpUJIAq*m(^>0p>RUToXFHQ) z*8<(}>*Qq%o&SaTFnd1i5nG>6F{Sf6Wkl*dvS-v$XiU(w(n89;Ya8P(a>75q8Ka)g zGbQHBL+!0gsC5j-s8)_QY$VV#^&nSN?+K_sYYx&R)KeiSxwg;!qo)ExQodH}vB3K=6&5e74K9%UOed=R} z>TNzhCEPx_{7f`zxE3c4Df!QJj6XLF2BB{@uv%0bj2zMj4&A<7v@Tn47$bD~9#&lBaM;nyy(6X=p_r}n z-x|YmG$h|s1)`&i+=$0b9kzV#BP<7eizTDD1*_&FW;}bjr4Tuk&}pCj_(iW?_tq&>#OLqIra0Yr!Z4P>+jEhI^6|h z3m-2Lqj-Bu?pJrSe!)n81^fyyrA%&aK@ZQBU*FQWsVFoWl&rG77;Nbt7P~K&OV^2F zOk0DtH7TzoC7b*rnr^{cV9cFkbDGWs1gyrVq64CO=+G5DT(A1*G7+Z!I|N z&tGprX}n^ypR_Di2sgpDc+DH_=wA({k`3uWrSbZ8pRDt+_7*Y8i;F$Pfn2?3N?AHe zg}DP~1S2|i?CUDkvCm;b8DKJ`?*<(u{H^uL*IbwIhuI@dI7<;}5a$?S*USmaQm2dM zF%&14y&3frQAbt>2K`2sOuMSpQk7db!o()|$z6S|%F`&Y1~NHC z&Z08vmnIwMtGS*x59Nnaqk_G&v%l}P-I62opqPFGyPrgU!{w6z%R*W4u~P9?8{&6$ zNCZ?obFJdf7}gwjyf_UWq6RP?0ymJ!6-X=SyQPn0?tqVu{^Lej0nmuSdH4%8kLs7a zZV%~!X!8QoXH<|z(k5fYzYAqFXA_cAlK|Nl&rG}xnHwQ?^?JzH(JyTZIqM}X)*Dfp z=i}DCP-sf!#|U(=W=-n(Hraotvi-gER)bdQO|_~IR59_bJi)W!Q)goBpq$i5xdvopAP5MVn*KgX1*JEdIoOOd()UA= zY56B7EUp&inyXsZkD43opxQ565pqFGkjhcD78g!}Mem3m9I<}#f|Qn}8w|)?5a=JK z3Tw!M3t%3lH58&S!k^(pC>kaM01<{baWG<7VlRFqEJ9ing>;=zbvDgS8=iC|GY%h5 zAXkmuNI;(2MhbLLxH<7R|AHFqJd~i2bbwW!ZX@81QEg7{B(t*!V90G8>0x^6d5*b- zj5n<2p6GcVL$fw$F`?Q~wt-dbanaXMy45}oS%jR2RnBCdhyA(r`LG<1#zLs<_diQ;4Z7Og1Ryc7^?vRQ_MN6S3m8e8c;z}UminX{VT^9 zoWy|9X32zn!Yg8Wn*=ws^R8)Tfrcn115HNUEJ40xc;6W&+9Bf(}lF*%Km^ z9z0oqV#8ptIP}Nt9ol>@>_BK8AUeb(`MDSw*>}&d3#|eL!P6x3x<>R9HYinW0BNo$ zLb(J4Yf-CzFu45-*KeP&3d)^PuXs*|B3*!AT!oLqE)mlndMT9z=<`ym$l)Ej0sCj{ zOgAR9#j~vQMQ(FDaR+CX5f%fF+$Q=k^xO4Zp>Wk-p7&P8g2Fb*Mu@|&)ZCkz3g36m zF8y>1m&p!qE7$rtpCbG7@RK>PV5H6y!3o@;n**=cZ}6)fE^QMxf%rl5zDd&JxzZ~N zGzBOMbp@zyQPF_E_AlQCj7>2qJRFxy*LsXoaTmUwHU$DXPPPYeAbrV7o6WR|oHL!4 zV74K{n0f2$d3PM8JHgj=nX8h3rtp3aTp#B0^(@ZCJR1LJ67DN}9PDJ%C62fBL#=-Lt>c~%7yNapReYcFMbHH3B%4$`d@Z;_=mf7{NtQ7AUu^ySz zVMb$joKe4i@h0tY_lpW+4WI|yR+$U8TP#>lr}aMOmY>(WZM>#NcJVwrC!hQ zP^uLsuj+Xvqz9|8(tzH1qGm{Fh~UA1z=+%s*0d|CjMw#SItQ44^HHrGz<~j9(&hD` z{D8BA-g2mSqkP4Kho_mRnXSH(d^eu{{=H`sl`1kUddJD#xFU!l@tOLRCRf48vcMGC zbC5bu*jJEkCb$N}NRH9rvb6kVkpPuYW7A8eIy=5uTazJKSBKMsY zhXaj^C8y)+yXd{VYrH2A8BS7AyPbc#11k7G($%11TRIk&n!DY}fAy(X5BpHvY z3>N-6x4IJJlFvI!BE(&6k76aS5=mnRuo*bk#=_@{DIbWE*}gtEAUo((vt7L-y}N^5 z={u9>gZoGHvD~Q>gs#*CY%IW%Mo9;!t?=8En-loE$|$)ip()X}%8GOC@JN6rbV@k{ zXMzgB*Gw{Cdr_ypsJWMg0~CKi4;7Kbd%5J((a0HtaXB4FO;ZI&Y*`PSx{u(E4r6lZ zq^1>6^Np5xV8?^QrW8h8al^o9U)+JKVR1}6eB0NF4n-c2Ev&20nsasPZV9L9)Cda6 z%(vaHnDMyfg|-NV*U+uDEJ3LfUeTT74Cbo^Hu-r;G(f5wqzvi_d*KyRG7zs>tY%yL zSWQGOV1BP2tSA22#_h#VpRma^!R|FNas4De~-wWS+GYQNOdWL>!nnp0H-7P4m`fan& z7fXp|EheYj-SN_Ebd!Vr#k>!GKZbsT@q+S4=CD=tT`tdArEytvw06*OSa9WJcNERx zfyBB%W6(43^4`YA>}?~k7XcNDwhCNkI`u+RG5 zw3&_pPl>O)1i3fdpeiDSTOoTdf?H2@3@@`9%&%YAYApQjgtsZ(4iW^~)@iV)2;e6y zK{uxAsIF*Y2R?b3bJ*Ls+)15#1NhA}HfL|OsrTh!kS@(2NBmN%`-Hl=5K_7(AgY2I z)FRUPU|M!m1Qh_%T0WlHi4@N*LrS#7*u~qZgR;p)$j(EcbVa@g`(>8FgO4X8Qc^_8 zn?tobMwomk!^#Rfh!h{lfY(Z~PX-W`MbA`#Zpm#4I~D_w$(GrLX!C}FW(hC?^!e)$ zNRnbCL#aa5Mcwk++dqL-M~dZywY}5WLc?v*Tte>KCi^5zztZCH z9dzW4XB#+5Ah!mbbB4rAzVy(hU*=Jz{qkYIT>$gOc9@Py;k}nh!EI4@6_Y(i>edNY z_;kMeS-Sl{cihr#Z}JAN7P1}ieI&iBmTS?xBnp5SZhx>D)(PMs4`Qhjd*~s4{K6D`BEwuRUOx7ED*?)TdtAC-r03bQ*uQA)GT;0A6 z^aC`<<|ER!$GCWZ7QAv8rb$+r!^>=-birgOl3puSiwag7%FfAs%(M~&5^%nCbxXBb zNe~ces(u`H-8_?zsz04h<^I#6NfUelKmi5-Q2VuMaQ@SxF}3(T-nQ`kAA4q5ecLvZ z0j1}ZlExEQGYF`Hr5tIY!n!+cNoJ#?(xL%}y{TwKHdb?wAIp#kwB0nKD>q*MW|sC6nEEMO8Qpk1sRtNtz0l+dnUq=3Lh(!xu$W|e zMbSbrj3t3lKF3l%>XdmPrTr)vOw&SLwb}z)BRhH}?sDQ3Y^^(>$RVs%l~kQCr?!g( zHLCP4|DT6UE@|$6Ge~Wg!psYu(d-N&T8&!`RT$l)>X0dybre($0nD@tY|-1$k$cUm z<#TYnDRMlx@VDcgac^g22Q;(eCgjdXj-c^1K4}z5CFo26+@*fR;rs(?BS`&-3-WWJfktsV==o-n#za zJ?Dy+N4N$w4FD4C$KkciX_~|lrX8-=uY1B0wU`y7yX`fm|E*nckVL;h7zZT_7B9Dp z^8-Q@xFZs;JOG2tji*Nt$YFjKPeG_B0w+R7S_TZaBs$4CZB}hXjG;Xp1D`=A_@o^> zbhJkpj(twTnnK&|?RGgbV?zMnzcSp2A=ZA~eR2ryjooY3ctdV9AKf0ochcU7?we~H z{0awVZZCPyIlsAhB|Ww4K6W?k>NxxqeEW!bUBJXCBe#zN`^+lre;Z17PwQqF8^^IK zmC8^$J-P5x_Goh8AQ9*h(ahRd0#4YU_vFm$Tk2}s!w!4zB^EYIGJR_iDAys}O~}^} z8TtYG&(%a@UDJ2*yP5`20RXW6b2VAm{+i=1HovB)iSe&Deyk?t^j{+*pK9iNhRh@6 zvf1N+l74rR;1hs!Si?I-w%>w^SP!@HG~*4%Y_LP1rgzdx9?c&w-77T~It5L1P~slC zYB0wpN=h+IZ#}{My;pu)^6KCx-RpdvvRImrnpwg@`L4L)G|j1K9LiaM-Jq5Tg8<>D zUEA+lH(9GH#X}=-Jfe2vvlo6;Bb|2=os76sw zg86Q>=uNhM4?Z-2P2vKd^}S`619%NkZh|z`wK(7-8bqn=%TZ~99mO=gpfG&m>nM56 ztOpO4$RM%v=wK6{VXNk(zO@u_SCqt75WG`-G<-h8b}5r?x%1NX!3y^OEdG5t{%a8P zXnuS9`aZ}Mf)0RL2^%0Y!=aD`ty1QyQd^zZuxeE|7%68b_SN!~TDgO?z$$BycPj7h z5e*S2l_L75%+KPw)s0f5Z31$AhPjZj1Bfrtycoa`f|i_TD~0%HtZG!S6GgPitSA}d zRAq({4in`Pqgq0DN;f{hze$vr}r4 zJjgcf?e)l%Tn$*oN4cd(-11Vg8S`2wb^}>|`Yt^R*LK$))}eW2o5;f0P2MZ#^x)q)u;z4E zwGiIDdMMrf^W)&@E%snHnR}mOrH)`r*T^MK1kDS!DaNpv0KbST=pNEU#*KpnRRf?# z_{xKKH5^+tz#(xz(OV{G?OQsrISao;;~x!+TrYj2X9=kje0mYLPIWq7-Gk%E*2hPw zF86bAQ-YNje*+J8J0e!{KfvHqO@WY|0vkSCU&;+!zUXJHF_MMuT@3|@0+5G65B7UL zlLlYBImV+T%M7fb&j2Dm68Pv zc-BSmafeC%xGI(9Uiaev_^Tba^RfHCbe{D^>_@#zwcy5Y{wxepR&$zcC7u36YCMwv zr$XR%`2X&l003nW006lEsSuV1t_FGrj*bQ%|9u{ktSV!d$&b+eNUfU1FE}>h02fd5 zM`ocSc|&Sb$+(*AWo<+gt)p}3x%bxh&VV5;1)KS{JDm*%y64tQ4oF}>>XDbf9IoA+0AHXCm}zR}2=ir(FgeHIMmPsi(yf>ZR8exFYAl$DRF3T(jZz8~ zN}q}7j?77wst9i1168pY)00+;GKtf|B=T{ak#~S=uKVT&AG*k30xl;* zUIvQ0fWEv5G2$-)HHIPulxCQBpkheQsSSp>1YWS;queyig|DI0D+tmcICaA?ZhV<9 zs=sgxbN>ZXtAl{*p-I56D{)$H945jLnQEX`UMbxMDP`SYy#& zXHNrPykYq-I%Gq0qbFOkSKTD8s$*j;nSGH+u07(ltmp827?3`~e?HU7wAjFpFG>7; zNA-d@k?sN74fc8a{4`Xs`-{#WG%}3g$e$Wz2Vh4B02vlX>?3k`Ahu#hFYy_6-UL)i z_w-V=q1cj-7vRmML?CnOyRC-s+X(t~kk+1C0edz0;%m^5rzWxOoux~wF-xr4SEogX z$ZZCuCSeEFa<}&%6Yg%KI{9XTtAuF%3qTm=ozKp(DNFwp_bBWl+OWQ(j2z~8-ZKKPY7t*Z9`$$Oc&`QY&_|*f?+mbWY(zCrz!MPGO zi38KoO}t5a759sH{QeV+uOSBvz9v^9W~-IT5I9&R>jk*P{B(pe2agWoIx3|lZck#df z&jyDmeET;E5&)o);s1A{_rJ^6$iUj#(7?#*zbp7!(`@#4oBOw;X8?tdq)-4C0-{`y z)F37@iIg`cv(@S$aj-|iFYx6yL+l`$30L9=wZnS<{*d_aU%pfEI7 z40qw?wXJ#yCYy+AqBbuxm4JFLQ777Z2MJk^-ugB~S_s&1ZIV$GwZ{4o5Y_-voPR2x zy4I755c1s?ct=#?n}MLE*mNWv8Up#A^tK~7J#_=Z9%PbM1f|G#Pb7+qG^*>K5HEY2 zozfAt7+etFn1Bj#jCOW<5Q6o~M}@xXngk<02@aRx<)Eb%Oo(F_!m6gv(S(B~$}(2M zfRht9ctZKx_X2r#Py9RbA`0Dq${V#64~)yA6DTPJ@MFPKymXt#_~HX+BDw}XP?EOq z=eKbXlZKr*nb0WIuxU3aj+0esQnQnjC)iCK?4c~8J_6Q!lG5Zf@dC8vwW#{FvK5r0K#0Oh9k^d zMJUWNRS6W*k~ZvtS-W5J)zl7f4QHumPxMx?w`?qEJj~rf_Ap8j;XT`r5FYGPyffti{KPy3mE_?+kh>(6k)3v z$s3hBIeSSgbt7|P`5S<@<@JUzeJeZnbe?!`%qsM18#(!I%nE)nE0R+8GVaGFINELiz#LT~9t=~={UM!;L zHRNXC@XzZ!crV;ie!>xI0h=qk)b7i?c0KPX+Tr_~(+jil?c@Px`0hY(cl`JBMr)nu_VIicvCQviT5N~Q8em#Mk zPALf~tR7U`^7?aVk8RCQj_>;0|JoCys-nqin&SJMI{h)gXm7v?t%+`nuX4_uKz81be zieH;;{9S=z9Kac-h!ZnW+T)fh$O&&II&r;(OthXKFBo%vmPV~bvyz5)z!Q?GhlRTT zK{PdjkaL?DV$^xDwrhT9#==~x8Q{sxJuF#!Jz~eP=$bo z1@FC`kyegSb@?{Pe^iZ`l)s2lKFt&`)*j`l1R)c$HsxZJaon#8svl5kwXzBy9Qj1} z6z&S}LPi65Iuu07IDO&>je!0@hnKTHyS&;A8mlXaP$%?v#EaZk!>uc}8g#a%WIfNQ zzn=i6#F2|MYUVwR1+DkMZ;Q-bicRcDm@>{4e$L`+7!CmXx;N2VHm6W~evcDbrG3iu zi_x+(aFgLLxAW$fD?vyvzAHrA{W1PC?y_{S5X2PkSYpbYn_+p;UbpJg|$D@1J`_#3Z!VrzuWoP+m zm0ktcJz*K!|J79j`qWWFJIKzv!es4kfZ3PrE)B-Yi zTDP3Xjrq!s?ehJBm03-SgfyPhF{XnVP$1cJBgLd1srR%@hUCgNF|v;t+60q)Oja(s zo&O%{-7(mA39|eOO+_Q`WUZ|Nk`r>pRVPrapBW}JJxR^fFiO-+Tooq`wHs)M@Z|%+J%Q?RUj0WtU?ziw>sQuum`2%qf_H zEj89M&40ZzZ<1OO8HE*zKM|nAH`?WY%b4+1FvGq-L#^s|N!Io-`azxn|McTM^Xo!e zU&395Bk72JOZsgUoDYy6XG_&tUO1zrUi?HfWY@M1;i-Gi_L!*%e94JCh=Ig-+@yuf z-H79w8{SM*kO#-GLim*0YvdIafd~E}H}{UxNb^;FnwBT%Lrs4`D}wH}3PWm{$y^nY z{zDd%Qr%bnm%ZzOYSVd2!>iyV?o@jliA+Anqy_JDhnIYTuQTtEC(Ym$J=Fy~6|g}L zD8wi)Jz`nvHj-lgediDD4CUPclA5}`V`l-+Ac~t0JdzFs{lS`dj%gz{Q01qjZB2(f zuK%R<_imm6>2?*#RAJ3&luy7Q%ZEn&%8y9@3)f&()%XRK^05R<75D45`E%(V7;d=K z$K0lxS><5G5FdF}+Iz8_vVT&GUsX<662e%#qO$PK4{sf*P9%6lr((v++lX1a=`rDG zKrAZ>Fru&t+B($qbB$?lp_^B`v{(LX%Y`GSt=RixGI6e^aV=pqlpLMy3jPK^7>)`` zH9t^LA%41Xrk)<-|K`Bb1RsUP{YMPKitH2 z&f)|&?|3Y&0s=o>AbGAys^zUXA$!2j8htYN8Z&HY*y z%t!zL*#D1)#Mt70#ej2S~8*C+?j z3SXZxN860N9R9P8kc9w>y50it59WL{OMM2%Cy5{jkh(^482_r@B&W({Y zPgkKFtf*WvcP?gLs!J{4WIP$cm}`%BHhB;V(x~GU1D{cqYq^29T|Uka!yXUvAFAVT zkHo6)7G;3D>)Re8@4{4)Q!OiL{Gl$VCN ztjs}6s|PXtVr4u_{Ye9av4R{{Oj+M{;4v^%w5Nb7JEQ%jE7?M94jv(%;vD{Z%F-?| zSat=wEJOXP50TPYH)&V6X!~y83qd+udMcB;Zlm6n6?HSib&efM`6j3{Pge&G7VyB zUfFO{>ynm^2oC9zWAHBfV$PyQk8_eMmOvS(qr#_*p4my`J*PRuxkM6;5H?t`0gz;p z0c1{V9@@14CLd$V)uEFlry%$r+rX}EkUVv64@c>wHG@obE!+BlH#KXXBjeM1={`0=W?(pV zGbHIUeuL}2fsL(D{=+5T$}wXVM^jpEX>TApeXh6WtZQ#Z^p5u5mbTnilm6K5VueGkkVZ&Y^&1H&jzDu5>l?) z);VFf^cJ`mj(8y#Rd7prW+|*T)&Q%J&Xh-L~@6S2K- zfd3Q%t$xgq+25gF(J}x4{Qrj#m>D>mxc!eeagBTJcqn}JH+P_IN?*tlsg);B2vf2) zIIy#&>DcNldQ^h7oMRNOWl>k(oggO=U1;IP^V^2pV*iaoEORYEbwN_|`n9Fy@g{~9 zYE+kjXRfE6cl#Y7*ZRDp#|4Lx&A^Bjrm57r2=;?#`sS5yW8-5B?)-XQ*RlxG@uB)?ZYZIG}90=ezZhgGtR zA&m4Q?Ym3=!6vS5HT)KKSnA#Q1neeB1Oya<*zV*5>4@^DgS1Ch>|t?A#+GBmdY^oU z8%m5OMgS7&F@P`r&R{D4e=+t>?V*LymS$|*wryv}wr$&Xc5K@=cWm3XZ6}qgr|Wbd zT=doY0dszuYm9ds*gr;)BAuQtdbqkeZY|VI$M!hZ)X>Yt=sto(4Mw34c;u+|n?C$K zfo@OGl7HY(x@YAT59mqzYNLzeLVo<+$;@$p|$O>iV#BQ>> zI=@Tq7({ttq06I9HW^Li_E+y%s%%joJqPe_S!pr-RnBo(Uz1py9{p-ydO8r9?(X|G zJGYt{cbZ=oJ4a6^ndaa?n#^Vhyg*?1L5!$%4u^XB9?@B`cbH%= z?e{|nwnMtO=pkN&{Q*a%WnCc&2>sZ_m0|nKkpxPJLM&fFReG%%O*?#{+WuHtJb8)A zgp{%1Uit8TJhz!_?{%LQaGVoj65bH=t#6nTOKVF@jVxUcZ##*Md00kQPe(_Gi;tsA zwvVBm&S%|Av6F*`3~7+8#uILz&Ih$4$>lPL7q?$;$fK+tCU6 z%KB6L^HSrZc{cDT6b4Zqs(=Zmu?XjEOHbr-?&L(g93A@Qs_6+`1 zUS+g8>bB*`=fi1=AO}`)du)8`l_&SRSMWr`ZCXLW&=)U~;TSH+#&V7OYtM_QGqA zxr9XGH;S_rZJW-tsT8amGjo30eUOco^pA*u3o(+rQP|+T*BH%S`W_#_q=-ok{;jF1 zN)wnaKC7m1A~g9&kG;2UbZC{jmEXvY`zIQ!9YzZKy;y~X26HK|X(Zm|YaRMI?en6$ zW7eP->w>e3xfna~t&OQJBXR^>{;H{bnKc5~KpMab)?v7=ex7uEkNhEcjS=>PU-9{7`mu2Ed+DzCl6|r|^YW58Dn=Noc zQNGD`ezx#s2)dA8edyS&oW1dW*;0FXm8fH-RritX?EjSd(h64PF95(;f}8{F`$`ufj)l0iNCluThmt-tD+BMK_X0ga*q1O7HouO zl-3p^NGkO~W||mME7KpI1)OU@6*G#VATSBhwu5;~PEmH?&JcIVAv*afA?KdL46Go4 z9`?=z^%P}A5J#%$OOWI+&jaxcvisG$y*1G@x&R*c8N^eEK0y{34Tw_BVv?l>=#af< z2ZuEo82o?+ffkh?g{l0dZZ6=hL=l2jUxc><{HBOu7nn>C?;(X{l*RhECd1GTS%l^U zBosm}iH@Yz54ojGbXDA$J{vA(`+Z%(ff zal9QAihv;E<#=P*8Y{*}wWDYgA2CR&QtWcLA1s=7Dl^yJ7MuIoY40Qa6Zq!qp1ApE zCnO%;v$K7U8m#P|u;#WQA>1OvPpO=fbrkC#kT4BJ(6-s;R~-fx6p-?<7(i0_q!inW zj`#`5X?Uoj!ENgJxUo}m5-pBJ94wp|>phPt6tLGcv2X$b7rSWUu^79cBQnKqCct;V z{R8L`5BUOwR;L3dMo`W$Hmd?O;1b|$fLNS*zul?n%0qmu#%~+$fCRqE?DqW#=fPJE zWpz{Rl7%BM0HlJy2fol$+d4+2+tp!7wHp&3ZyoW7HA%`8;x3KJFse+LII_y44NG4_ zUf9tJBh2grkB4@PlX2spo1*i`&u64vq>NTd6iQl>*29tk8NeRvv&fT8=bD4ne&{e? zL3u!uF7i$L`IEOJ5RRyV%kr3`XmOT_f@U_ntPJX?IrL??U!;iHHj`cdGvV9GlkC)Q zfU3CUC{L}(m4@dIenhb?mIWaD^zM$D{{o7yN!Y6=>6GaYL5l4tbqNu7*J zG0;fa-HOS7W~Tqsv(k||D!d{jELb0EJ^6uQwtyZ37C+1lW`(*x4ig+sD7u1Icr3Ed z$^cO<=`0I%SnOPp|5N3*m#sV2{cTM#KmjZ4kHp|T|ByI%Vn}6yYYhcq09~9PSXe|-l(g?#+UWQ$U zZ>WVDBnr!yHDmfNPz+24=yE7=)w>MHK^CnSoGYK~kFd45sMP`UQ`b&yH8Fik zB`QozijJ;b3d(IN_dF#F>4inUprXnIr-fSo){Q^4DDFqP4U|88qI7}tWh;sjrf0J3 zLQp5(o)rE2gxc zn(i>zLiC-{l(3L%3xmSU_<{^ASQ#l$bWgb?xo+ru{7ySwsAnh?3c5rzrz5AX3~a;( zYDiS_hl?+!289Cphngq)Qu&fn3R8F(aP}seqH-Y|+P8&NFuDeMg zX?QdVE9s(PHbnfN<`*LjDnB>z0GxHkX5&msobc8R>CEXnY}w*RKP1z2D`G?mq2K++ zmsh$gl%M3&bC3t{!(d>=UuhUBU{w2RH#*c?*N27f4O0BOO#57I#-RH5NS7_@4uq!R zA{JdMD)UR9xHIfCym0ZIzPA~G1;2T;$>DAeb3$RGOA`TY3`w?NiX}yu?O--$xzELs zelv`%PHmde$XPcn>o?ZTypg1`!yWVbS{R~lKEs>iIns}Ky{T3MYww%;TnbJh5iZF` zilzuHknslWz-(x3;0PoMUj5&y4>d&Ml?$O_lGGaW?&=UTt zXuCY!l^gyeYgy!GHtbI?)(+pRw44~DQ|%a%32btbv;_p}C@fbj{$;$J=DEyX|HKyt z<$TBDQ2&YfHjOV7c5#M*ava%moQ|;|H~!SW(AeqB&pTlVfw+)a(MyuYL!P3A=DvzLkbCC4jA=|H@>#@%G?Q&CX4Q9G?fPW8+^P`)JgsOtTf2^+@kuH` zrFvzyqmmWy>g7s7amg3D7~2MK4>=F&p1%ffiztGw&+e;fZv80|p!o2O-A*;;8Ctp@ zSA!0WaOw?|aFV)X(H3a564b(F{$88S1x4Nj?+9kgKT!11NO=P&?s?H40=NE$2L1MB z%T0aTNQ|L4l$!1sl#Dq)v&r1$D+&<%81&UAuofqhR?PZm;I{c?Rxnl@ulH@;N1XMU z$w1)l3NIELN(ZscY#Nuw(&sx*`|>L3uf6#KLPFOO$zu0uLdcp)wGLTn{; zt|j(~f-}0=>=LsU-^YR5Qxe<-u(_?TI`YmxK6fD=h6e^$&?pvgf5F+;cG&H22l|kanel;|Q=V|Y%%E~I9)HZ6${~n!N>JMB3#3pR zO$#Ox!*w|RMKJ$*Bk{m+JOL;_G1HL8Hc!|GMl#`RFh@4BTOsI|C9uiZYANxWpcku#a5@6ga<#0!gCW_sGv5Gy zJ`j=B&o6r`yO6vsSHr_4^oi|}2w4jqg#?3{llhTX%2qe6aBj2V(3@wVY=$>% zhC#kq^)DIq;Bvnoy$qel=;!*F8bmu@UgBdYu5Ch6>P1wPPNiejvzF`6R8O(B<BGdF+N=FjAmB&p#cDW%;J|C^vn9Fj0NuF04$wkI;h0CS*jgj;I9jPQp*Kw`T8b zuCcq&Teu<&YK5+#BJ5X?_GiU^O}K$wS;Nnv<F5b$|31_TcV#f!0siJX;3LR$L-^`dX0?#*K6^HF-|aXS3z6J711L z+ju?eH%$(Pe@}6RKj{KccQeefWxFtXo2YZ+*3z4eQ+jG>dDFQ~EjMuQMU@6pt{J!H zQE7CQB8md7>0zHUL;+bi_4tavu)|TMoXS2Nz~tldWMdhsZm;nRmukFpFGj}MlV~RP z@zo|pKGA~G!H_;NVh1Q%w!qqMs~c~uqs;R)tJf6q=bFK@MXMYFy~PkZTZ|69X34fO zwfqrD=VbycQtsGmRXj-#ic?!n_8TKJoOfp~O;9y(V~Pw?0YL`mt&M5R5v@spJ(wtg zQ#?de0y8y@6`=*sA)d1@4U-aphGXxWg@jlQ{&9>8Eg{NG(Is;OjwIAHTnp&dRo}nj z#RK}bvpHjiRk6Uqnk*uEn6BlwS)X;Ca6?H^%ZX6qKs=141_XoU74;v#HaKWv0s&u( zw&p8&y2>Q^6oj>c%7eCI%cZ%?^z3{?0wCm<+uvB#q3ngxT}~j5eQ?MK<%5yLW~$d> z*mv?M#yaM1$|$wwl(xYy@J&RwJ+f2;t|Fe}zSIyWM(I_94#Sq6Xqby8m?}UNA9$Uv z`RYrA_sU9m+Fwh5xSD&glbr-t#IO01dIC?CS9eCkg+CDy9FpOkG0#wp(9|NQz8utQ zOjxf3gM&R~kX$%&uI=|A5BcTcG`!d4q~)MNP3x$i*vG|4sdF+oMluHPt+z!Y*e#6e zVXGNGY@qLWUUrY-WD!QkMsD^QfiGTe5=QW8^_w9^E_!?+xn*$S)<>(v7PB_)6kM3)Q}9aT@O7@c91Rq)X}$N|)5U7=oX^ds|z2 znp<-@#-UgL;`wPkuA;zx3S(El2Hgg`7ocwq_op}C;O`-Ts;xn3Er~L)UUW-sGVST> zVg!^fH~QczGr^C2Y_rezEP!h)xsXU;W^nEKVubT6O)z|)K)~b9iubbFEcWAGDR}lpmAfhA5?vt?-nZnB&`VzW=(Xve8huYj z@_cVKG&#vqM48tG0S!s@ep$w&Fjt}uctK+VI(ixh^#||P<2V1wQt6WSJ32q@9R~4& zy-GHW4&X_;-lFOV=gE_6B7efbGFG;#eMOZ*7>Ads_!K#5&V^-R*T^X!5q3PqeAMSv z^vDuhSlxW@HqDnx4gyK7v%J5^ZkE1yA(3zI+DEakBJ0i$Nl_!yC}T#~&3q&3o*X+x z%f3h`aDGM9wz-#6kEvFTPGh^nZsz2*1I+5v9ehft<)E=xUSF|%4DbxcK7cz$9>mz@ z&eKGzx(JY9t8MLW$(SC6Mpj?FvvZk`q5;!)G2uNkY}*(*-d*~U_W*hf$7<;ZUFP=5 z9`E~Rc`$Yy3FUYLJaOKeln3UD22?bc==j1OVe01eFDMb*gG^RWocMJa$!4_S`bQM^@)cfm2{;CUi27U!Y$yS=tBK- z7<%FZDkrLZ=$A?W#?g(7{w>c2?}pm{tETGz&3ISK5`A4x)h?q|5|N5un20dNDjsund-KJLsn5|mw9r%{o=jf)g8yELU(x4X%7t-fXe;qwynp}X>gNVIEpf>9=e2}yntgJn z6<79Ol4~UXMR+*<09m_(O1t~VjP*{E>_#xcf7KD@+WRG^-~}alWm#JTQ)i}t z@=^KdX%#iEpYC3F;_4zW;M*g6QFh z34%T@jR!NEjDlu18k@xYjHs)a(dEyf!$)yO>=| zCGKf{j?iDXlugV7pw0ci)`}CP7u3Y0(ak*R^|0s^D@tGwHbi!5PXO4%65g>1-c@A1 zG-$oLfc}n}U>dOPV6a5lPg4>#P~K@H$D8Oh4JAngS4uE7tzXm4B9dnQU!w$gh z&hn*cPS1N-tKDmLqv>cQg6X)+=y&oy=OT>ld^yqvmoFbK6KPV_j+V|aK#64)2y7`2 zN^5adoyU3|;A};hD%JY1=2KKz!i3h2$1!UGDKdY>>6{uv2H;OS+l8O|>xSsno+Z48 zuCP8}YS2Pnc$KL!7zP*9_)5RILmtMGo}9F+>eFtZ?~zaDP&xb^S5=u|{&JjDQaiUg z8TuzMvCB)8at{FimtPx8MNkBaYG(EV<^5F?IDax3C{LqLC?hrL#Yvqdc&VxU0He8u zca%?udQ5_dyf8`FTfYq4Laa)=X-e)fXp+|^ID}5&fRnKcvR2Lza*{v}p3wr>K0soKW2DhA!`VeFmFf-Tz2;_-mFlUqCa18E1CU4 zUzj@7%vZ1-`vj^WOF5n+Z>N2*Wwkf7fqZlI5CIo88I(yx?bkATN`>cn<{5;9WmjN@ z9asWHJwdeUZE*yXGCXN)oeKMtCX1#RN`m!J5M$WpeHxz`5Z%!eLegf0@BCx%o~J>C z{m6^YTf4lDY&DjHwZo{k4f~pEwtjVI60}MrNeQ0n81uVPxnD%vzF|X#no5lz7WxfZ z5W$y|+pLHY#Lql9Gf6>5NyPPO&Y)^gQK66X9Z1ywNx3Fw57-Qh5+-{`oGl(spgc;W z^(SLmp8x|vWJieGse}VD#)FqbNWv&rY>LpNWOWbM@fa6_#EOVE;buUgB}Y3()6`(C zBmw1;-AT_TS{n9waxqjMBUHPXSP88UgEY(>sjEoB3>%MPNvAiT8YeS~)8Q2-sg~53 zsMravWs&f4_jSV$*-3V~9}!p4dWpS&i049kdVv}Mg3qe375tgVN4<{;YQ*?=FiF7d zgp;i#Bc{JcN%RT%pZ|rky{HA={&KEtVg8#d?*GV~%}uQB|F6V(E2YnNg8?S=#si9| zV^vr`Z^$1A*byTXt;4$Gl8%AxfFhYzFd?x`j$gCo!c-yxXL+#NHo;f%&`j}u{v0S+ zUA4P`1Z1?|{34&qq=Feq(zBs#6UKh8seiZI7qi@g3quC$Eyr3-h zu)b%8FrKlopMoo&3z~=*l62Xi5=g>=Ef}V=sD97y-m!(86nOhV4H>^WG-t^2ebm!W zUwCAVz`jRBCrdyV%r?p!7Bk!@wK%=A%VU@+n_3Tv*iy9+v5r5%1cT=04zutH%7t8q zDwgz$%24PLbTX40o^_a)IVuBn6u5n@adC@BUCry?PMVU#1G($+tJXjT+i!#kEkXiZ zyPx?p{I?FS%T=R-kRPyF{Q#`8CDQ(ixKLpu2 zBvSafg70+iTiRe|7ILc0ap#HC$k1W$@E^{o>1W{nM7+BbB8A6xpcR? zdtKk;(NN|LJ|E|}17@%PB@d`{_1~f+7y!T;#sAzX))r3A|F=`7G;I2II(CBnam9&i+}|Uyng_?T@QTBLL#({(Myq#yCwHF-UHj zurZHg$cusi4~_`lWD^nNE8YR@-M&r#$|F;-t$;WSUbLuHgCIX-7{$=0_8iIKK}gdO zL@@K89io86_YBT6Mnql2RA-@K5kvAb4D=8zxyz6R|mm~C)xI6bG zrz4QY6NF|e9o%_sb%&vK1^&e>mBp!qvo!R6)$9Q6e3LN+=#+EbcJF*5r8*p3eXC$c zRKNuKvP#xJ=S7Zn)G%EOSQ*STzh+uO2B%;VU-G(u$$pXpSMpuzV;T*IaYeuHc$2Za zts!00o@9Iex*Ud3_yoVc!)UMjocmI6*+XMgHOQynIp^pn|<`QUZ7 zX>YzziAEh@-7NPGnACGa)xDdM$IlMY?jG(0;hS>}T|rF*5~La#Bhw@Ym%qb0Sy9?(oE^N5&l02g9TS?% z#Rp%FSN~2-Txz#8(3EJfFOs!n<0BU$qm+yfSQgT!S#tY#cdvbdx!ZyG~BgoT;c=u%w`^)#JPFB4A|!;j7v+|T{2#e z`UN3XA1*l|WmCjHwiJBzv%(HrU8U+rahsop+u0 zgOF-<(3;CmOzuR1AacD~v%w70PlE?HhdKk)$+!sX*9?7WOdk?Or*yNLPuAAX!oN0_ z4Z)jJEBUr7hua2igT8uM`0Ol@s^-x25J*HvHGeci-&9HVPqE>Uh$hl3b6WPJUY&xM zuZx<#Zy?ZEGCZD}Rgrd3E5o6Dv#q@ed%UrpWrxR) z-ndu~e%TTChJvw<{#sQ3!U@Y77y^aHcBdy5*r}JeE^Jmr|7g z8lRM}ruFa@ZBk48&nYGew7%T8EE9&G!nbA5D5q^$ry4b7l%zt-pk3SEw!Q)JimDw_ z8#1Xo^0e>Lv>o=;Sgkt(8FY3Q-fbq0ZA&Y&o>8iBNo+kL8Rk}c*Vm}vq#0njxZrF8 zS(+pnY6+V1Ler*;1}S9}lVLC|lEHi_2ugV4ZB-k3=xjwhgm~!lFZNX*oxvPSnLK_? z)bt_4_7kq2=gF4=;z)E_=ZhHnW15!Fh|$st)hpRQAxq9=9%t8n1Cx}9PN7G|R%oux zh6?zTGqDkte3~mZ%Gt20pyhh1H>od*I?J9^D*03U>&V@_T?EY*6Ak(7b`!OYYoB5x zGPhEvn-x3X-HOz0#7V zUA45Imz>|RU1o`kH8lD@)h>JJ4F;H|{#6JGEo7;|p?A?o?O-Ad*Jq5$~)7lohnKC0otulEq^*TD5ZNqub${^L?Jq@riF z!GPfRQaam-jve(k`-*h3ZmtG9ND%*DDN`^`E;e^Vv4mrz^<&44G52?35v@rbq~y`Z zN5@AtUCjj-I#gOfRCtEApyW0fiS8fh5sw+HMnp>y#ga)d4rQ;A^%F_u;Q3C*Xdp=m z1&MqZi*rL-ro2vuJd71UgL1S^MgeZ<5lq@>!014>q|F}6frKq)&oB;GYrpVUP?KCP z1%dYoR;4PS$vpCDwzQm(JXPoL15S`wHg&L&+^Pf|io^s-(&aqFGE+|ZE(>WgS5{v~ zfHUrxL_yi>-q{F=8NFr&6#XUEw<{($kpllV%J0|#(CuY+C%<<1)3#%2tNVP$eWAsd zyh+&pJ2;oH=Zip#Q=@ejtV#AWi`{nA=Mmq~1d>>zm(TXQ{GhQWkCyK{94?YDSu*mj zCk&y|3pg*wLabXKvtCR8a;eKj#?AWZp!>3VB^d^6B7u2EF4Gr#BP3s znDQU`xu4LY7UvHN=L`Hes{S7LgrZT%u=Pr!ToMVdtHpz}KmzZLnRFub2x|!iqs8rm zMQE2+EfNHJMqSZHMn76DVbREXU$rH3S{&cnWI-@5B z>_&&;&mryBVPv8Prn9j)>$-5HY8(md)+FCQd7x}+}q1VUd6pHoLX zD;`kpczV4Y1CJWvi3f|+g{B#{v%9VU=E%96)^M@Jl|w~;Tks5rD(S>*`Hkmna|OUY ziRA*g0sDQ2>ObR(-L1OIo7UmBMPIBm=iC<&5nGc-I0u-Sd?7@pQ=9`d7%dfy<(F@eKtfi&r@p zb&b_1jGuu2`NqG1B}ZZQyEdMFvEcvH+Hf-Yk89kPs!n940E*9xn)YR+P#mk)dmTa3 z7A9e-uEeC|WaVedKK?-Ml6fjI~tx)P+gYAZi3Cv*tlb^=o>4EVubDM-q=UWe``I zehrlwoHRISohtzR6Jx>?kDY2B<84jL zy{anpylcNoP>S!NYpXd$=Ng;11zhMQJ1Pd%F`PACYcvW;tP({3!0vK*(M>?&9B;LJ zGdl^UItqajdL!Q z{+y$V?Fsr48-0PCc>{RE7u*EI0?TdbH$bG<=(4fwCtj?-ksl7?i3Jn(iE!`OW8?7* z3j@`q>6d_6&RtLD%pc_X&3puGcDhg>`_ISGmZcb{{jJ$q{;I_S^3sPs8B&TeDD41y zkSRwTKD_G)L0ucIn7oasT=BUg=Z*gA4LtH!Il24XoK*gon__9XqRv1`O#yZBbfD7Z z;s#E=my1nP`<{Bt+Sec>2;st+n zvi2^4P1>&)Is~;VIl4G9Hohq9x(Uw3prke_^VEZdk^9w5UhlsNKRXAn+U>f0Cq1`N zQZMj1B)iZGhPse@H`Dvy=)?Im@Bg0!lK%~)2i1^NO#G!TGyJ|t|EK(X#)%}VBkqX*(M@-eIZ#SM@R1)0|XwE=1hx+v`+NT1n(S3@PJhMC_2bukt<8! zBqRmZbq+=9J{LIUf-Hl1vS692lxE;{Nsww&4JmS}LQZw+#U?fU7H+{8iwlXRtE#(f z;0skG)n}E=#;xs%I?IZDq~{fTXP6ok*NzQoikK_M{dama&IVg@z+(8)Ta zS;$P50yLbdeErePd)Wl0>3qkZRB(Yk`pCGYFS2-f*8YKmT&6g`Cq)&BJUMN1bf&^*TYX22 z^Z8vtZQI7T&%p7neS-E@&nH&?6 z!5Qy))%N`h_YeVGljAM&O@yQ93G1&<*tpG^xoRcxaSkk#al3qlg7-@MjGF8T9N5%g zfSAU3Q=|mvg*M&X=n<$sQ>tI_CnQ)^Bf6#a!zCJ*|B@rvukmsF!5`G#5B zRUd8t{igmxA7DszyK&4+cxfyuK-Lo4=&uwkj*Q;d{;9lPi$f=9r(jbq47=@lxsf=R z?W{gF0ULLH=37_SB;Js{tNiGi`f-7211Fw#D|2h|xm6SVo34KCNvz|q=T#h4Qc6@E zHKzGZpZJ4RZzrJ(=xRAe{OVU-V-i+kXd>dzDK0?(tRe%YZ-C#YdNfnAam{Q=lUwy7 zbnpRlew|M(%L?-`hmXBC4!BnIG$8YggD(EL%SAvp5=UHZ%g+$_xcARzDrp^?S}9dHgquyp4&Couh}IwTZ$1 zixAK`bNt`i*-q(fFh3j!yu<`&BBrExK8{Q&Bzx3|OVaS}Y9IJ-Y<%TFiucgC2s8(< zjkQ!qWwL9}Ov2H&K&#!hXbOAFR;1YJ)amInu%oHnz85pcRp_hP>FUX7&J-_y)Xb#x zN@wpk9>14C{e9yo8P1_y+IuJ&0{_FpORxi=mkGRG^1LBY0>!AW#*y{Fn^_=PqG zVQ&jML&0^tz!~Va`zRfA!tbNS$0di*B&8Rsmgg?F~WZ}>|oHS?>|ufhTLTG8RN zUF$_8F?`V(aPr}^wcRgp&vMFiLryl=0sjc8%=tpKv;s6%X|_XWL|I<7tay`zW3X~j ztV3nZ&$F-~^x?>X*0ABTy;Bcp4JB(I<~(?9>yVa{7WrL$o|y?!J=n;PzT0}#w}3BT z?AKDxOm2?ldmQc*4Lvo!I@8%g@v9Wzd~BgFELTkjW0)%k%TFdtl(~qp)xb?H%3) zUiDk}If9Wix$P`ppT@HVLN2>znZw7|<;j+-nP`ue(`&LrYnjj&k=4=>8UFN1jLjSv zTQ^O85^P~*qmrw6`xIu_u<$U$iBnK!1uE)g%qPTy>XyLGPLG%z_^C02kXTTo*%u-f z?CivWTJ7wq2fkq1)B6A+7hgeG;kP|V@rSUc?`z4*#Y?WDpliKCYy#Pg0=q0d7D{KY zL91gFi-pZ4h`S1>13VO^Q#RVdDPpAPGcS7OXH)FJsM{Ql_BThiv<13#-`i}`lxG4a z);siGPDpe%W5&Y4->ZJ{BlL9mZcnx*_zYuCQv`8~;zH6LqX%(nSo=y`tWY$4GawHl z$7pD33*etPE3Q_U4CnoJ4xeb1AgRm9HjF94UJ_!>CT}IVMYe;6H;iwc z?`Cl-*y|a2NpDGEANy}q(d=b{_D+Q&Ea%$=LunzGDd(aV@Z&(u(SMLlA;a(sf9+Er zMIw{JSm zU|6+QMLlF_FQR}$wiqPM*AB^~-W;d3`VEbb$Ii2l2O-*dMN zg($$)Xt5KCITKcsu8Y-yCSDksjL*3l(04+!w>kdukH?rWJL?CY2Js2-Ug+vw@5Ug> zQ*GrL2^jwkNezzaTZG7{7`l`v7vp6H*ok1oGF3{!qU$ zh?)=X%xueiKbjZ+G-N`@`-6>rjTgy}*WIbk)h#9Dm3S*V!W4Oz6u_qZOz4F@s7Pz0 zYAp}$;)FtP8_p=#X3Z|v*_`zi%;IX&?Q>zPq@D>ReMk9ZSgMguKK^GI=}{jm#E40g zWCTHls!i*X?l$GH%&?2gIRk$Bq4y_>zHV+y0)56gR@2F8fP&Rr*GSxZ855ARP9RLZ z=@zli?gL$f8*~RzA4EYIr{LOh{T>+0Tm^oRl$gW)FFFB7p+i5+W{&Lke&@3C9`-4K zM^$g3Gn(WzlNV$rvRitTN-I;r{oFV-gu8U{EpWi9?qFcoC7 zi7izpsu`RFT!6|3LiE9(g?qq6Axj(m;R2LF6 zOI0PZ84U_zsgD$8^~J*}Q_dBJ69bCYoa(qb&xscW)AyS{_A1^U?k=zMwy6M7Ni-z# zbj}KJK5^m)A#`dE^>B#@`8EcgGbl8s%rFB&s4eTbpquBMy8b`>w}!R{-xcwf74ee= zg>m034YX*;8GzY!!1&32-2n+zAexoyX^V04{+n5y7r+A!9*C2x;7wWe!}Tt?e_AC^ zs-vcIp}8$M4Np2u)kdphwl~CXljXAp7P6D`c<2d%ujQ*@rc?WEKBk$Z(T}g~qDnRd zlzlJCA}UUC!cs9hB!+1bO#dB_{3&iKqHeS0X9vw>Zx(ljz9L5}b~y%(D%n*x24_Gw z8-Ssjj73)=m12bs+$gybxo#7f9t^TLk)jalfR`!qFMeoLrlOfg>g$w4(YEhN*kQa9 z_GdJn?rlNvKyIHR@r&a-;4)bDib{SOM_!b=R*GOM#_Fw3LMOA@>_E9)mEOg*#&Jg# zbVFB}S3mNfSj6r%2TeT)BnWrn4=yti1r$x?zzeLWXX>AXe55S!2C_(1+Qhr@D4=9ic7)6ibUMZi zKqHZOjV5K{(2G#5(t3@X!f1jYw=TGJ)i$hZF(#*uI;ryGd?FPBP{7&zM~9;1>u3hy zs$ad6OubHqt!g;9bs{j8VRtgTo1CFqsctl&$(@o9Z*Q;ICjM@PPLp74_QW*?vlV(5 zdb*}5DGqz(MKv3`Lv1YCl?~lh&!(C*yNOyNSfE^$H6+}mwp-Du2gysWQwML`Y|kd(7WNUYupT$<1kTHHv@jr=u9PEqLq~QK@eJZH-Y;z%El?sR5e?-0 z)5-l3@s8Xd0gO4Hcz0M-F4N>pd+$6L;M_1(xf;E*rKFJzDjC(Xpw6YB*KS^}lklbY zSwnaV3-1x^UuMzu zW2L^LnBw|TywYnGI;@DU-_ci)o7&K?p|YR59X!GQBPd4-IRE}RsspNn8HLSsw0dXy z8b)Iu)pqVRERJl;gmrq6!|BT5Sbxcy(qyoSf;0A;)oUJ=@ksY?s(a?qN>lY+lS_SN zrGl`F+zvl9Zh5sfYfPSseXz~Y1gqIrmk0NH9Qj7)myU-Y7`$79lpQuhK2Q?pk8NPu z5Sj6%vhdVXoYE2i>!WrAtVa=znAGa03IvekM|A`lNo0$bOxGx(mX$oYm3hY>kANg{ z1#vesP(15xplCp2kYz=Vj6Z-rBWg@O@C!d4nB~4me?umpgvY_D=}T@SFYTH4dM@J6 zr~o263GRj&inBdJPznUjgQn_q0dHPZzOt*KI$XFExTZ zjlaN+>3gu~^!#aq45?%hitetK-BL!OrW4-TS!QbD;oh8P_~#17z;*VIG#%f<28F3> zOv+I(WiFqs~DDT z*^MBE_#~3X65x*LigatdGKofAkTMw@{V^YTO&yy_k=zaV+} zdAUgs3(NWk+o-uv=%fV%+X@vogc4ZB33ziN z`TsO|)~}p9yXMKE%mR+b9Nm^_d9|kl|7Ct@2#xJh`PCTJ#w-_XgJ$$;@jh? zeON7Nd@Ri$gWX<(d5ilceqX7Z_WHA_rFK%svA(Lm)6bX3d#1L!qje8J+|nG@o!Psb zzOe1%%1>*+cmQ+QH*0C+U6c~IzIE#_tYI^L4RIdm8b8qM?HnU({1BKcEdPLl183Ol zLX~sM6<9Bw*f*hl{+GO}W}~Bg`Cr*b*>8^$|7Qfk|4xe3`wzOEiMxS~y|u|N$2D1A zMrz;}(Rfpf;V8+N)It>^)Vm(ED+t>4muA7(kX^=g;If4EXM3s?55+m=VY}1g22_y1 zRr-_yw9r6{v1%Ey3T+|;fnB{LXb<)>ek%Zq3VROE*jlnIqtZL)0P?7#htCMOq6&t) z?WLo`7qt&A(lB=HF8%e-XK)lLry+H)I8+l_i$z5i6Bo-z783gg#BHJ?Sh+`4cs(#J zss;#E;BQder#?&bpn)=&m^|9LOUD-$nOY+lD0-Y60P3^Z`jKbzKiw+&RH?bX>^{2b zKX%OTSmNB!jGbUQkdFaZ%hO&xp+W3~os^W#a@3@Qj3A0BWlc$fwXCjaeznwD+a$s^ z^R8WEmQ-}#)pPc27t>Sv{ZevU&*3!_UF4qnuY50@ZSlzp>r*x#a%{QoEj({Dv&V!Lby7+`|k zctQ8ZGv@M?2$c;lAVy`CGjFl zAw-(Uz7%<>{GUNJ!!*DD^Z7iVnd|d$zH`sH_dfUDbDwh@xANEMYv)+$Y^2p$p6Tzz zzw=zyX7da17dtT{bB+A$8#YZd!aCUS`&HsaRf|BNu`` zw!X)=-_`)s^cZNn{Mb~5JdtAO21ND#-H3g3dz{+OQ#9}L)k#ZHcrTPrD{>bV-KC=c zR78^}xv2DrpH5Y4p|ygkE)zdL*%%77x2G?V`SfBpweDNjSdNVBI;GNx^M~oz{W(Jp z#PMd`0u?v@a)@(yGqFa-NN-A`EYMi{m}yc0E@8#UD^~Qb&bX;VQ8(mPZ_n;9F5?)fOaVC#GPFI#cx`q3j>BlpiL)dw;xX{o@k5g;SlejLTZTXApf9uTX zR>_;DXqW@6@`(@x62N8^FFP9_uWOFJb|_})HPA%HTWSH~ElF|c*;R%QsmXa5;}gA# z1`CGdvt+wG@LKBfZNC(`7q?uvC^lbp`eGA;>bmS|%YWWH&uF*k!*v=7?v(C4Z6AcD zkd^D?|?F{EX`Fzh%4?sYAb$WjWe1y zCs0EXTKI(0xS)n*y>Q-QN%GyTdio!ng*gpg`~faZ6QP@r>f%k{Txz>)se!H~*-&`~ zILqgor5yLrzw_1nk zpZM9jWX6U>aB()$Ix#SSq_5tc%wOr)_ptI-&TZEpHJcf1 z(E|~V*%tpgl!m7dfnI(VD9+qN*Fhfw*0}=2Bs)iUEELhc4z4-c1E~pqIlPvQr;rY> z;C_fzx4KAGq^fMLDy;0p$u=>^kHjoJ>XUzEF+_}-jNRy$>LZs?HAy-#+FmB-y$|HTbE8YQ4_v2Ao#vi$>iKkfe!v8F%FIG z$O)1@5!K2%4=;8}+nC!TGnL$Vd`~#(lHuHUwU#y2c-3?^wpRoj*vRqSBMSpJj4zrK zwq1Vpo))R-EH_v-oJcHfCF ze9KJOsNZD|YO8T0(fm0#r|PZbyAcL8atmz4cM7FtErRo7*3~@_1v7{pDMN_!3!YQe z21?{L_~W`8U+thnEP3r`uh!z`bVNxveCigvNV@;fxH&q=-$mY#5Rxt2>V^pa4q2j6`> zedU5)?~4I={SC?SgOAj5en|0%`X`$R=I$k4J>4>ExNy64T0QbpYAm)PY)YwVgfL%*Z<@JCT%6jJQ`IY7m<8T0Uv~Z{p%CF5^Jg`l?)Rz?sLQ= zFDKD*{}huqR&QDQzBxHRfre44aKm*~s+l^Uy*G<*V=%G_+A?QT5D&T9zO$_*QVx?0 z4O9ppE!)up5qDT<4@0k*R*+lTLLDaCI^j`^DMDOtL`4hdPA2Y#>bVQcdbeDThPmh6 zk5<2b$^o}{deC%QrUN!%@+Hkr{cU_qS#dcRB*0wU;f=&&7`4VnMqZW=9dpV7ri0FK z$I@Ion(|=lXR41B*{;+=T4ES`t_^!UT!z;Z&eF>aR;RlaGwu=cJdb>OzCy{91s%2#4kz7sJ}s+{8S>GRRxV>E-1wIaX* zGYj~(9_i1=A7uSA#4y^fTAOy%mlRAT_OuF^^7sRO@gu|Kkx`-r2&N94({9-{qL}Y^ zIgzX7MN5KsRm=0(Icr6a=C-wHsq_mkj_ao9ms+)>zqyXY0()Gvsb5s&*AtWar-Qu7 zI{!Bvbp2wp2Vt3r_Zz`s3F*C$P!X4tyl*{iJaX_ zJ@xVvGA8w45bB%rUd>r8F$9xJ-O48#r^p3*HWXQvq=$Wdp|Z_2{5skp&J^6t)~1#X zhSdr30g_v)wAP?&{>}bQrt44l*n`-_a_h9%SX2=O@He9i1TaVPamr*7)`)v;0v|$@ zbPyMUn8Uk9L=5dX|Fh0EPJT$74m% z;TSni&vMx@sKMLJ0fb*f_#@wk z>+*AJGx9|XSI(SKR=Z#RcV^g&;OLU7b2+%ubR*6)d!tX{U!YtCIJ96_xkq$MHgC2_XAuzsQ%iTUW~;wOBf zcHK*7zCTWqk^dAZa7EH3b)n;ZQQ2w9Touy*9fzKoPQRKk+eP|}p#fX@J<>fJyz@FQ zXS_gmB5vd0q&cut0$j)wAg4pN321=t$omX%BcFsq$i0$lLJ)|fo1-@b0<3=6z&h8( z)W8P7-@!USAa*NzI}+r>^*|1-sE+EnaUh}`@c+o0{HGis3G+2acQ-%)1akNt0s+f` zxy}m;pBEJ4+LwVO6M?v6YgE`R@kR}Rk~?ry{FH(6D>Uy2lZ2e7z~2q*xC!Gm{o68y z`JApGl~|4$RofKbKNp@%j%DL4i}pE8K52?=gDbIbaGbScJ-JTR`ze+E@=J}{ku1l> z-ppI~KuR~h{ab1^r0Wm(ka2MT-mg8Nk4#-v0Q>7A*k2tYi^yMty!N$20s`+Ba6q)6 zwa;6=Fz!K|pP!yEv6q2B`M21Es=eu^kIyCmfyjW-gyNT2q?}(n0>}XO9`0Vq6z8Y} zP_$odLe&5ct_|rtQu5C-DgkN8fUR={Dbyg4pApiS$Ym|^0t5F7z}^4h2idU?l4BpG zcpDFNGY;agR~}Lh%SS2eyK(;Cfq~H)IP6J=MyUm6RRn@e+T4|+L!+;fz;6|94hM49(F~oR;C5)0OLCDB@kdhR7DK|+S*4P zBpq;faj~=U2HNknqol*iIB1?cFv%a3#rn(zKT-n(3OB?#ViG(-RMKG#%s)`Bdt;{{ zcooXJfG*b{aCcQLJI>pEfl8?6m?ks-KvgvI?`VNWTy*EI0PIFNvIv;RnbA$`}t{g=fbhB zhwK=S++PO22hO-T0Xq6^?+`$EA-ZsDJ2x8#S1T`P;iJ_!wC`H@VKx3>?*ayHjnBXW z%K)G;{i+6%)Rcl&4eRUPb`T(OlcSsc&*OEdH3J>Wry9o2n=q-@wF7KY0}PsAG~PlQ z8t;G~%yLij5}5N`|%npy$0prHz-P(T-%n!lX+> z59}=T3_7T+?~l<~`(p!IJ{?Rr`!wu0^o%E{xaR8PxL-TOhfWMkME|GQiRclzQHgOa z$B8KMxiP^qPm*J`9D3AWR9ITuaoCX)rg zojr8kF()M^9P~YQ9J-eY6_<&NgW5`v4F;3X2@}+hh!U#-(LE`sP(J+s4MqA@Fi|6g zuv5{U5~x%|0xVPvE(%QaG8lF``sO<-edz=iI;O37O!%fk?0EE@TvR+0@p1g`eO*kf z$qno*^c_=F7Awhd){&hNOu!X7Dy;1ZecuHYs7Zzf{C}`m;xV!QwDLt|Nl>D(P*$Lr z>NI76Y0mVgmGMsw7??BN_{sUbPDTTKdxW^mHHgzEaR@jp|DHW#S~jBs+J6H6RE5JO d)4I&jB@+Sg@;TBu$__dWOq-RdfLSf*{{W)JaJB#d diff --git a/python/lib/py4j-0.10.6-src.zip b/python/lib/py4j-0.10.6-src.zip new file mode 100644 index 0000000000000000000000000000000000000000..2f8edcc0c0b886669460642aa650a1b642367439 GIT binary patch literal 80352 zcmafZ1B@`;wq@J4ZM*wx+qP}nwr$(CZQC}!wmtuSnLBspB`;ISE+lnMcBQhk0v&QGOsvq7S?-^UssrKY{S?;d*)&wieEMdUW<4|230^y}AYwy4HW6Nl6u= zoaJ5-LR@-QR$A^vl6rDZB|J`!nwFAQGA2%Ke42Kgo=QPnnre2Al2#%n?2d5qNx#vg zLS!W4-9uaZ{$8F;zKm8{9bO_C`oE>MI*Aq63km=LhxLD@WoTezWpC%`{QpU7MO`Xx ziw&XoQw?Flz@#5B5@*v;0sWj>WS)p39@G;8D1fR(qEen@k{AUt3$zi z;{dihzth9IJM2krlN>HW3%ZU@f)zh#x}qz1I%iKBe-(PUFCT?dhT0aD|MeuiGxy4R>IH0`W_ZMf4)j+w+$g z&DpJ|DXf0EK#%-2Xw7m&$!gLmp+z{0+|j)9=N-Q-kysDt#|9;^#5;m^kYh*zh&4#| z4df7CH$H;4FfyEB7=l^*M|dV=3w8Kc42583#*=jtV%zL*ueu&U2Oo z&)ds{ur9liY+7?~5wuIlOP?FV65|eTeC@3ZEeknNdgMEe>Yr-~rnGgV#w^%xwuzI{{u3=pNbvv`QwzoNZN_kuvWq3t?@y$ z8MEXqc4>Y2dgo}^sFQb_jN``W;U}$r=9)Y3S(74`h~W{;EmYZKnY^q{u~4q{4VQf~ z79KOih$5^^(v5a|aY3h~Q(t8JFvuJMhg%~`X$_cd@_nb7q3Pt3mWyq?5X=1Cj-+d>H*(Nq>;;*OGB2~ z6}$*}Tnw0Z$!e_-f+JI8IeHLiZWRdIV|J<#VaX-l9)ZP_)z(gWnlBX7X-&o)PRc&< zKOTvBXrbPmD>-L= z`cZbtSy-iQOV||H$!Y_4TiFHsY&F^43qL=mKdxUEPMtU`d(E{95DdDFx%`kp<6>h6 z#S4|1euLb|Jb~|)-B;LQ%_k-gCC-eYUZluQgEXN2=B+z*FU4ld<*LphdGKbhOaGWD zJ*gL?wLTtS-$re89X)Dv*)w3Y=s!~ySv#&?=eaJ+n$M`ks-Q7b-^}BK#+ELLM*q4? zzks;u+q55ct(g4>p8kby<~YP%=coVxzPbPaDF22hBWnv2TW2Q|M^_Wae`CSF!0uo1 z-@xu#`y%P6HRZRq4`{)gA;nzTHCwqvc|~jGVl{f}!l+g2NTZH@q<}!kx?S`GNk;m;nclK2jw~ zLSc`rMZB#`+SWu@jR>*IxhIv>J77;TLtm|dR+R}ET4PTt10Lz#=z$QCuf$}BB%@tC zwatRA;3~_2F&aaf+Te*<7ur#?_x}v1l{r2iA|ziF@?!8en1 z0+yUUz;*&d6itxsh^nKNH%F=JwVFeo={PWkG99)Ej#hIr`S{v^z7Coe;@DNGtt95y z_)cN9NK%NtQ!>4}8q1?h5+aHZ`(SP9qpoAG zp|pO?h_NlcHeZ^L?@WKo@Xbs#0MhQ~q!?NYyOrcBY0sUy6RXp(QPdlD0?Ed>-U7(u zFB51^`qfQV+$imO-#`mYU4RnMoxcI>XkFPQ%(Bm#W;-?yH)njvkwS zb9KgUYiobmy1eW{%z_!8F=R9wi)_mqCoE8czP@cTC(LyM2ypy#DnqSW@i#M#x|%?EQ7L~}E!QUlDq)X4OqO7?x5 z9sCQHB{GA$z}DKWKFje3nu5Bv1+45-_3_2QJ*$^sIYjTKZZ~4;-}~~q*MQ$GjcOuX zC^TyH{mNsx$Nb4$j6J{;SX>b6lrm~!&7X~;@V+H~ZXsNRS;Ha*t1VXq#P^*s%|wt$ z6jU%>%TLBW&Kt)Yd~#`+FzrrSDP9@Xsh-AgS8)uRU={?oBCedNqePLlNcqXYzi$i7 z%2}`ZKc@EncP;Tu#(PgSNT<9E#W>*dMSoZQ^s-Rwp5Kd-h z-XWRydEzVVYG!ZK+8jO5_s~>wf+FX$_m29Zu+IP3U}LytElYKIfk-2acnT}62p3uw zv+t9RQ|gpS$?JU7w*C>bke}#Rn1cG=88$+ygmnt{ei-`8s9^}}LbWE5gb4F~l-KwB zCTSRa<462Kdc<>S(4K39~X>J1BqRJETL>by{2Hkqof4uyPm) z24-Q)&~AJc56Z?}XSFQG=bP#Vi7e;ErG}p2mu;LG@Dww*L zCwP;<%9KE8qz+f3AyB#6w(TMnJ|Vz?d9)(SU(op?8ead#DMy{;eYsSc;1F#x*lJ!C z2CR~rVu}{{HT-vnv;YQ6^+2Mb)9I^oz?f7b z17Iqfu6OUnRtXUWj{WcV6cbBNHwjTYWK4sh3atk=sFHQg>8KWmWO+&(HL#RK3nFJ6 z3(tj#_Yo-2Fa3HCuSEm&)CgMyfi%WHX0c?aqR$A)zb(S!ERUswrmB=4hb5?~Hkl0b;BN8)z0 zHg^Qs;|E>fZn)76JSB~DVuNgTmUfB#IqMeh z6vxngU2+c`kU4|#I`@W-iepJl3|@sN7P8zX-Pnj;akhHD2ZWwKWKU96v#dWF^cItU z%aS4PI>Gk_?AwvoAcbdU;83eb{q;5{LZ+Yez-lM#BO77aSw>%gSVP*LiSa=?`OhR6 z>Yx32=Jyj2=2ch8f%usJ%=l9_%I0$VyUwJ_TtFHV5AYsaL}fn%T zb7qzq5c42X49xd&DK~U-tqf0iQ@q`hQLK|iRzYzP!`ts*_GvTSFZ zH{BpjK;lBcWc4L{l?r4WXN~BNrG=KdD5fR;-mGYN8(J3Gr zt5B8|oMeIkWsEorBX$M%VcnsBgpR%0G_M0cSuf;tj*0v3B?XpXt zG}z7a+yG4#ehaCMCx%&G9LInwL!GSXQk3jY&)B|SKZEyM4ed6!dF|xK!SDgXrIBo9 zT$|HJa>vLydtmJHo^b?$>XSxd-4@FB}`L(=h^|tk}qSaat#V(=3aGd`WDp zy0f^*mHCrN0Ws>ML5{1X-EdZE7(sRG!86)6_!Cbtp+4iIeshH22(O&+#jA5)w4=V? zyBT#9VR+shhEWFwY1d8Ak#Cqfwy9yiYP(7%>IY9t;gjw){ukOIK1A&rr=9yx zBsm7GZiWRfdF^V7YaB!)d?i{Oem8?C2WirE z-_9scgUE^|FZ6M_n;rS}e#^bD;d4~B3jXFGye&AQSKut`<)c85S2i8nXB11FW%WU! zt|b#aI^_B|26Bohn`1G&kfxc{->1sIgOeH^uiU+YoJ(k>C+#*C->J)Afj^JXH1e#~ zTz~$kh4P+q44JH6tVhl<%&#yU@7t2mi=3FL2%5R@+;yr$s%pz{lWkhmD1D^J&+ww8 z-I6|{nD->^BDl%8SZM`#NI~_}IGeHAy!@LwVmS^6>nGe z3@cH_wsy`JZEuF&Tr7PzHLt*P{jASW1vXE8wMy5yfaF2L3{|XTd8<`hOQ7sfNA<{5k8_!YBj_7H`zuvV;uW7hTrtL_|DYLr2KQc#5WKD=BoA-g zFUPrd_}TR|9w&o8?OW(yB~dn~3p+=Hh~_l$graf_`UqbS;y>9%st6lM|RTj6uJGEp(Lhp}dmhT`Jz+OIkLg+f=*N8J$NAG`k?&t0&Rz^nd zPb|DWycx3q)@fAre`I=m-_s1n?e3(zt=;1Y`w+@PW+T39t&}UUZVfxE>mB{zw;Px4 zy;WQr6QfL#VYgzpi;5??@05ncs5S{rsao)2gM8KG&c0K~g^=9Z@yHYTqS)=T9Xs zvQi3j^>TD0jxz0zmngwhHUj{aoV@J-bSKE2Z)4UL8$>)8-%+zil)oOgez9fGUzBh! zB%(jAiHaXO9qZMibWBFJ4rd7?Gf6KY&T%)JR9RI9)RU>;2BEiy8mbh?YGxVJU&nl= zU}A!j$0kk9*4EV4?pI8_nY*}oj%m_?WQW6OC3JU?C0Xq1vgktW6;e<2clxCDd|zJ^ z#0FFM8IhF*3Wx>pY>t=Q34QyCe4QVh?R3Xit_kWO#N?1PSJk6+jBf}dA} zRCKBqRnN8Kg?>HHebxtOauS0t>aYxYoeUMsbpa19Tc0Z)Dr*FelymGA%mF$n#|wwN z=Y)g$@MQJ%^hVR0va^rXnHA+m*ds+*d_C+u-Rf4|UAumr?|g~7z_U@I^^IO!I%E!3 zGC?dBJW8|fCMcdH9X&d^U=MD)4>nz^V%q_(8n&o-)kD6-5^j$93@_nE?d~jm1#gG) z3?J~+8#$P})Wy~PLInnuDs4+!t)<2`43iKPp*L3-%q0nl&dpQD4 zVr)getD`+z5c($5Cs+&HJ$7viXN9o?zdj0fC+%UVAvxq78gZ=aIG1qh@c5g3@9%@6 zMk$G;^VXa-e&3)Hrl=>6~IE<&={1?0W&q7cV3WYMmHlZ?7wYpO$4TK0Us<66 zD!x|x=1%QiBAA$W_}eit8GG$AF-~5(S{+V-abENe)~KXw1K=B<%wv&uz+(qMwDjPz zfWOB5r8Xu}Gx|QiF0hk2SmCY^(#&GUOkV36XK9|xrqcMr1SSV$v*Y{y*0Wx6AAP+| zYQRKvJ{`IS1)7P6bL8uygFE_r=C;2Wlw-rkr-}^s<|WVDf>FdlmDCXkfHi^`{mIt- z{FdgAav&!}m4{g34Hx|2u`q!gKY=Vik;FDBCq(=$e(Y2bUq*<4Z1xoI1@_?%CX^S# zxr{KN-0kY2YxQ$c~tyfThY__XGrb&DzM|)YA(yI%w+a+YdZ#( z@9T%a?XiH8r~6tSsfJ5y=+MY_7{ptNMXx`c!s@fSz99Rd1+-zb+gYoQ-13DCSr0(swB40%H{K6nKqsta;wwM zB0BVw5CYXJE6ltzcbHf~h9aXS@;qy2AJgf!(k8&Uo{f?zQm3=AILmZECaOgR_A&T% zLUzUH&$s64jjze4Ekn_{L{5SIZY+|niknyt-37dQ&~;kVCA+YOoBg)h*yUCMcf#%R z=cv7gw^s(W_ciZ ziELhk>M-5fV}wPqe>P^Ddf8nqo-VGmwtJh(##=-2wlrO#6Yg&Q zXQFv(Dt~n=+|(;lEmvhdpk1FN^ZZ{dvov^Z*ez-0uNQx?jM^68nb(v)MP^@Mz<=TY zU)Go^*?s~87yuv<@*f59|FFh(HueVpM?3r<_~I7Tnb;);gx)hs(H*HyD#S~oU@A}b zf)HRNpnze(8Y1!x>z=I5ifij7^dnGr#!*+_EL-&W^MZ!H_tUYXIR`p|6ijqC>T>Cy zLm7swnLsy3Ev6x4U3(}0ZlsC>Mf35JK*_soBQQwYbz*V?fhIAOyLWvy#eSyBfC!k{ zLQbea)G7!S>UO&NQ25I=xrT(1Rg;O7Z9$x~lnRtl#ycCzlm^IhqG(EXD%Eq)df}-; zCJZl81QuyJ=ZFf6vX6Wp=_ZBoznPoydpHOnapf$uB&S4a^{sFV2!;KBU;^-~PH3yw z5a`hFJbR)1ed&Ze++FZVX?*Tn-e~p`c)Ue`?EwXE)cr$71Lkt4p&rZDleX#Y4rsOC z0WzMM+gA)$lMVc!Dx82C$Foic;;UvKSBAy(Z3aCh>xj)P(jz0dig#Ch=T+=%LW)!C zzKx8<(?xAi+)f9Kg)LulFs{vUj`ZUxN8qHei*lZ7u~*y|^$*40|N3H|5E9Gz*V=V5 z277Ko)JoWgMxY2GktT$~u9YPHu3-MYl9+POo~Yh?x5jdgTQCxLFc>hoq)`#SrCU>T z3CMclF?zer-eH0IRn!ZDZ(B=w%2o3qbz}_fbPw_P+0)EccJ6Nm=QRIHZuQjGI8yfc zM4w#c60a!Y%Z0ZniPg=*Wl|{OS0wzzk#gj3v^WcLmjIYGX7L$uo_MqUlIL(9>QOy+ zK>RATPvBdw4nXrC)YUyVOY|hqvA#wC#EiP-jB3%|f@$Je`mxXVdS%$9n}$YO+fr_h zadYV20oZ|_p}*tlpoeT<-5la8h;Q~UqDTQ_4$*Glzeq2{NCP2m3)>|6 zb{(!`9dazIA49j3u-VL=lssbJ=p);~<`wtDv0lzAiL&4mP?=!Pa*JneH`cqd`VJ4p zC_WW1K`AT>a`LN@R@$#Kc7l{Idn~K?YCFyUNy78KbO5@-cM16rxWoQWef`Tcp!87` zy!20fEg}E_ApX0)Of75;tSvnMBgy;^u7PW{ov1BQ1iv*sMs0jcbgg32c_NN;ohbnV z3&I}WltV2kVk@UAorbQ+qu{UGobPzjjfb6%wJ^a#axa#=zV{iJ53tldT0%9@v@DpP z?jDVwW0HxH5s?Q`)EfSJ)FyfT2~sS5`yo=~FJjAaP-ZZeOi_uUwPH;wnP2v8_2pxP zbQ*E!jAA~C4NS5LdaR5QpyXh7>7h92w}X*a`a}|=nwsR<{Y(q3f__i%zD8Kd(HOak z<;u{NN^6@chZZ1Rc1bOB(lR7_QVvklK^p&>7^IpJm3i^Nk{XJ7>K112PtB9pdy^)w zce)vs9r=!X){ZKfM3RC6$$Q+cOQIC;3n{d-eiX@NI!MM9UfB;Tl#D ze-*E8{iwgYT?^`kq8Z4!Il#coxrUN19xVZrsaui7e~=Al9zH!X-SOK{MPLHj<5XIK zT2Unjv!m0SO;NI+8~9U}IUrIz%y>m-U3Q^LKiA1z9PWZo)-AQu12XNXppo{HRz}K0*);O zph^{MDEO(Q4C(>VF7bq9CbzLrOkms^mjX9JQmva{58$@p@&nb-`rk%mnSu^^UPFuA}Jk2SHU+sCt1nZO;JHqxEK^rXX)% zVSfIzB1Ur{n3GP>ppnTyD?JO|kl6C*$ zLXb&c`d?T%#wEQV{ES5PLAKLv0*_l%uXu5T-6AlRv4TZ=S7d+IoIW<%0?Y}{(azb7 zJWfJ~)gT0#8Y}AwbR`J#Ag47e%uF*Lv1>xXBTio1?-UsV(FjN;d3gwNq08wDPq^8T zSc473#czC0?W;&Em1uQqxq}B;=wrRE4BZAN?>YuWDWph7?qCROlG=_S7c|)Zy@|I3 zE+ZCmxQK0mWr>u!stw+sc|IdN*G(4lD-R0!V*laS3FKasDe7kB)7E*|@N6%9XU z^mJ90I{Cy?h6E#%-fuD=8?<5GQlMfWi$6*KL*=_rc6;UH^|yBf_s@gdi%-_@&l9TG zaUiZ&=b*HHLT)IJ{fR~9?|<{|+%}|ns=xvO0Eq$sVElV^urzQr&@-~Lwl**@PwBM)bb!17sU4vUPClTMJd_zr(4rZClsb zCQB+r>U76UQYPG?w@ozQhfkWvpbYhEUztOh?J4Bhcqo|p{(Z&Z`n0aa4#H3Q>5z!W1)wA)Csb}b5HAFY2Fl2B(#IE zXk=@Oan0JlPYrw%Xi~aw#08;uc7U9vhiGqS+7E@}17Hm}bx)n4cdt-e#5RB z$2@GONPYgX1DM!=A#1pTZ}R&I59H@JL{mYItp|XR{`!DslA|*FWCU^1aZecY+9F+R za03z3MIHM1+VA-K1tb!#z|?@9&BD~x%_^KGe}bv1F=PXIVWO-ng88R?^ntF2Y)^?u z1ylz?R0T2z3?=G>YqLb~f$>WECj-*smfFJ$`6G>2jX>uXiWxZ+K6x@p8pLOw?v*x+ z`&ipdLl`Hsfxun5;2~ukpd|pOX}k)?#$T#fI-B*0cP=~zi_A| z7(B&mB4T5bQGKu|sxwL2lK@O70}0`{e&bH?Iegcw>|9tGtasyb<8qH>_GR4XT!h)9xO)w-enyzW0DaR9|N5DI!O z&sz~X5y`SPkLV=S`QYNBw^d;!10@$ovQ&x-O}&?b>JmjHcSDrM=y7KQQb20+Ket7( zjla{D)`Xr}f!p~Dk1Z)A_Ul%(03^>8ZzcsSMWzQrj5-C#R|N`xeq#dz%Ln94_6R?^ z6_j9iP6`0uxVi?*8SZrZN|^fZ&~;Tey<%Jw#G)96=@8j_>JUbVvkQvpjs7ABj4`23 zfVD(WqDTGJTxRJ-F8}BU{3C?$0^9#4G^DC_$C09yJ{|ufR(+L9eIrjC2^Fb+x&bPUHE(%vVSoybxsasJ|7Ai*`YCY7))oTGFazTEg3%fB7!Lr)D*RVzhKWW zsDo@C`#tSzOav%TG+AF|0^r61Bv2++6o3khwjr6= zu+t_A$(zGg2-56U)f~@v*PW2*Y58RN&dO9yhaRTSzq1jJTW&5t5@SI_&b>EYMwmZY zk(Grx;IVvcJ166XQf;x>w_O99z11AhRxaRLRck!9K2S^G!Md3Md!Vwtg_Mg}54w*Pr3p3}y#nCS12~LxFY= z4I7qvXcO7^@OpI^d}5! zw4!L#`vEczEn6u>JyACk(aGev`CJl|A)4oax*|C+Kwm_@3xq&+HCm2hO6&BVi0;se zqCLNS5>)M!shC~kCjw?pseaUxk|1Sm_2XmyoYsFNah7bcqTPiVoA5^%Og2iBza+o(ceM%ENcj z)Kn2);Ssx;nb=vM$eEZK@XME+oK@P5oVft!Zv>}2q8w(r`cvr#r;>k zYar73#j=a9Z8Bc}yl6;N(NsdoaYZ*`2Wv2Xnh&Ku1TD?{?^=I_r#CW1hRsNJ-me%_ z1i2{}$kh=11&1a35xksSEO?xeF9;N>YkD3P&$s%{F~Q*%SErd!4saA#YP|poj=!dg z+VRvRScT09M9{~FntlS6wJ>N4#nfFsE8ds>A@LJqX{G4j0|%X2aQNL$3!!{OYj`aG zWiVsY76AL99E4C5wRp!@^R_^ctyHoW8T=Ul&ixl$kvrHqib9HNQ$olJ^n-!QHqmhu z9+w}Gu2yD+2f7m-ed{m!dSPD-BZwpnze=h{WpEq;N*Zf*8g)L9x`GmbAmel5HX1C4(7tciiCf} z>&UdHzv82O;fEjjt?j7VU$NOx^p+@Z7Y{%%9m-}filb^48B{`|LLJI&8OZsd9r_W0 zO(V*fYfK$uwB`D7%QgPixX0p;mKMt}{0XD@ZVS`}@BoHU$Tu6>U74&jY!mxw!H1)Z zZtVP=xY>N11Ux+Z_EgJCcS?_t+EO`}x>N6M4zC>H#5b1`I9{M!`Rym`#W%zZtqou* zY4pR-w!iyhuY_F@2FZ{*V$9v|y0tO}QiaSea<3n+%qkpkE%_L?F{+7!@MIW4E&A5k zVcIbhCK}ju|0)Rc4)Axy}V+TtX)i;s(m4nwy$rqAxm5FA|4DXKQqtHMVXm;H$wgtbtM`w}w4es~`EAXwq_1mGEWPABXYRpB33mwlG#kd#t+< zdG-FnBF7>lx_s@`yOb-X82q^8`jIS^D8$?Miyq`e2SgQEAf{pvfiTk-c6t5zJP zu-Y$E1u46oUs^b)YeaN;RAQ3?a8oUfzr8N$sn6zWhdRJL#Q`r`C+yeQSt0BymRf0D zZySZjE5_K9y4TG!CG@_n)PDDmcSz~&_ysVda!mmiWP4KwqmFMsVlQRnLknmnj0qdS zpSNevdvjO#XWc(mG?6p#!9I{f?8``R8+2V`a5)D_s;-Nnk9=|+>tj|#)Q zM&b2!8rk?9$|5&I%Ltwc2frt$KFA#0PPUASx;2yA8T}_!=B{9NlGb%uQaK7R4)S3| zc_kD32`NNqnoKM%h8}5Z)-h=@q(X7AXR2Z=HlXd8>oU&^%9^w0NVb!nW#U|rcCs|S z3s$UC&}hH9Y-;W1@O^lwL4SGUshhfAOqY4wY7erebY|Y%-Ob2D?UGIux%F*p0GVS#1Zt;^8KweJAqgtvpq*K6X5<)C(?Jql zE7Lr*aqqXCIeKqh>jVB|vRh-}qt;E?D|zK_b!RlCh%|cKwT|zw8$()rSngJ}VxvtP z#+<+RG{xtm=bo>P>Nv{u5HG3m0o3Ry=%YF;{o+r?l36*?-qjf~=BFLwNqrFgh|NT(*6uKlx_c%~+G3 z3%&*~wd;2M(xyBIyo|l$<`0W4D;K;s6P7k!%&zvXm-FWt9b3hn-+xG)e_3mJYA4Ce z! zD$#2jTmYU@p|5=n$e{|JMjUczXuAtDjoDLq*mf5arEniAPA66Gb^NCNmA$Wn5_ zsnS=iEafNB_>NP|rq`6%b#i6F#)#{w8iWY4ue=5JEVX}+fa=zO&FDcPkR7=m(M5b8OY zfK^A&nI1HqbkrG>0q0@o%ppwfP4e1ibRfJ;dChGTA+!io%&W zf!Sq=$4REv393fQ01br3gIQdmkDAeoFGD$~JiFXyB?oZ!|5?!4M9&peF-;>D{%F}|=u)yiOsV&tPrVmH_*vI@v1 zg3x!56aSFhigdY&HmJU96J5=)VddU|9GtOc4a&JN^bG}^l`$8s5QefN6#jt*_Pdpv z;GdjKQ(n*;!DQDbSiZo%@YDMI+vCn1ixyq&KJSgbNHFr>l`!48*WZ`5Rfpe`h35+f zPEOw5OL(xm_V1?+zu?ciy`@$kFAoRSV6{2lx1Z13ouRz2I{Mk7sj8?v-xn=zH({@y zpA!bG-0Wxi+Jd1Ag}UF5!>4w?kfU=|+vMGai;JkD-MEE(zh|_!YPT>tzGgZ<-ybrp zFLy`GXF6HA&z^&ckw1(*&mvh{@O8BcyNW6UpD&sz*R|$#_j!>&l&Y2AfXz!(Du79G zp{b4Zk&{$YPWEA8o3{G7{TugY`Dufbx)d-Fn-}0b86-x3IxfNi5s_mt*30Lg`WJem zkn`&&AB|;V{a74n#fS-`-l4QTRzz-b%l2oMB9Jzu(*ZqOSrM>-0QXHJt$IluOEXRf z`+lum)TL1&%bYxlfQ=4!YdXfS_SpxiH4-FZ4t>UUC581j{0G8#pQ6dT=;Qs zk^6ov3Lol{KCaG|Wscb}-P6OS36nmd1bSUjJ4;^AAQFMemGgqDkQ zz>h-~x2;V708BJCezX=6qOxr9DU_p)ItkK%gh7s)17OTVS*M7xyD_+9o+H5Ld{{vC zs|k7G0LuF@p`vM}8Wt9?k|6-Bp)*u_NU;d07PD#(hxeldddY{1(w|s1@|&}9c(^$7 z9xDTT0w1RwJNF|we;{DRVxZ3eP={53r2nFaCt9qy<_DYjPkgk%9#zO$BF!wm0J{3i zXUs)?F}J8lRJ6VjMf8izApM~)C<>Y|xUKYK`73tz?G_I2=}%LrriQ+jmZu>;>YbUo z>L#IXuTlvC0AJn+6Rm;h?p6ubtp#OC=`59Sg5z$Q3^7pF#R2+B*&<8-H;Asn!X8oQ zxenY1gs1dYabnrnw5mAr5yU36ItbGrGnq00n;LK`?vaK>t$Sly`sMhi`61Q z*9TB<5~w*$0aKrj6itE+Ebo!QSb_%+M(_&9(#O;BfZG+61TIB_v=gWX%LnGzAj}4V zD&Qms2W#DT+t1!jA24~fkIf-2zbVD!j~K<*)sNW|do1Z4yHY*> zU6FoAh^A@V-%#nx$E&5eYro}Cc)(#AlNOQGS+uk})T7WsM#oRA)aSZ6UT9jKH~m>l zO;ff%0#FBP{8+a{Rk%fBRr=(RJV;}jKm;GMXq3_l{NL`Clof~8jBoB;(Y8h|YbzJT z{3(O9$;h6gfaKs>DNfyoBl!vUn$g;?h?WqZ2#ERg@xA;GQuzhVVqD`3 zhas51{QBX{d+|W}is2aQBZu9p5A|{^_}x_S+x5}`8qmrd@LeS4?Ujcl+#$lO_l#}?if zH1xi@yfAaJ4d}37YxP(*IfNX6N0rXEe#ACI^$mIS2(2lt3V#Ze!8mJHP_Ju}l*DDP6Hax(>OAHlcp5mMRjKRJnrQ(0%@)9j7c0 z4pI#ei7h9}19J5sRQ#8qM)VrQL2w-`$T4fKI;F;5!ohd_^pipuP2=GU#Qu>5PzBvr z*`U{`>lx;xAmzy|UAI0B`LNiN0WJlyO1b(xL6qce*a8Wn_J-m#vN!3M+WSXN4l6tW z_sIxNM1(OzulZS^hRDD87$aJLKy1ch?Jx)`diw`Q<3rOv)$4Ow8-OOT*5R0x*DTGV zCVLauCPQdY8tP>bTM%)_l2jD@0dG;M6dIMEbM^FFw1?URXAb51`@f6$kUX-P+rtQi zQ@h)V4lo*RE7NtL&?X$g3LN1Html6@gNLs{xr5-(qs${!$-2 zkKM>6y2Dr-q&_YsQMS+VT(e1;uGUS5yl-L)E*MtDEEZow)=YVipgFw05a6NpDO~#T zK=#E{;yzBUJGmGZLO6(ZfkIRF@enlQRS`_0Q?5tUs>5jlS-) zLCYLd<@M43+p6H(HNS351NF`T>lUAB0R#>}2Moz%iRVjh6`(X29A(tq%wKyMY!06;3J0DWe!gm5~4?Wt*V;Huf0Ne_Um26YPJRTkMY92ywM zKjp^F44rbke`KnU)u|dO$I3C<;@FP?hR6`VKQf$x#DZxf(eFbMzrk`;>TvO32faj1 zrBfn!0^Kq^BfW(72p1CpLvyy`v45AGYqCu-)K;KsRyl8M(jaP9@STtQA)7jyVMG^# z!8^0aQ>`px!f3)0>VYE7?YzdydI;Vo35y9S<;$joVK*`kI)}K*NwEZQkBZJ4-m_Vs z3Uy7e7TjzbHqxvof%|ISRY3Wo3SO?Efx>~w3@EB9u9Hq^ANgtamURx&a+I$t9;VztpKu; zWTZ(nlXzN#!p=z5S(*|m#LHmM?L;?&i`I|UygZK`#jmE!DNEl?chq>`Lck`oLVH6I zdI&U?;FkXR5E+_c^QoksE9j^^>=A<9r1)j1Q_?(}IN9v}2GBtOE{GOof>ahM?Im>W z>;?t9dsG!7YWGDBWrKs4)lXNzZbY7eH6o|0S`VbUZ~G(=w79BdE!QS$(*atenHC>0 zj>O|N1X2+k?HZ0sQOl_Wxop;15vmAMy3YH$F zfqIE-Wu+IPB_ib1&&%oQg~HRvVbzZ*Rw=NNL+++Q$VV9)R4F~MMjbkeS}ZC-ZYS*q zKwvnI8LZi-joj=$PA}>11zEjpaxcAp3lc9F5bh2;YYdE1n3{v!-|3d%oTEQb3-*w- z8thqu?orsqg77V%ObvBH_cmg;v3+XsmWo$NrTSrVfg|IBQ)q^JAjYv_H}O#bj~e+2 zKs0fRHl{3SU~?=-H+VLMfj>@kBrfFXR%{P3X>eK5phZS9;(cn;%UUjSW;6g z$JAmKj}vN>2kO;0OTDwvg^J-(@0D+>GofIaQ96fHT`_ZquO7qhF*2ZXgjrRjWKu&a zv^>r0A~L8-wYSX@5cg)o)wlO0;6Y*(@2)}j9PN+?|8Svn3>)_vUetn@)4?|(tPM1+ zhsV4XJl{dvn4M~8>PWlADm$~(L`B|lqe5l2%(6QZoby$X!h`@`z>T=N&Q@u)`EoU3 zJvc=D6XW5N)7-yxoP2B0o2N8?N+|EIO=N~``yCMalE9K%ZUx&YKVPLAy?%TWk($}Vc;wp+6B(R>0rYxcCvMl*dZJc<{z0^|M`q5=bmAqY*z&q78miz z2j#f&x{8m`P*#p=tM(23v`W~49KdN%lsSpeDsQNnAaI+8D}g)G%A~h>Ra5r4Hz)Bs%u)l=T$qjMf4 z2Irx???)idJ|`>m=q=@qBKT$yEdCHn5e=a6d8c;FBI1MA_EZa`#>tdJJ?A7{L60J= zle&uAm8yyul892L0t&+|w3v#KYd#YC7_{O%T2mbPr&&ySM%PMatAx)1jpVAzn47~H z1^`{sGeIU}Z%Dw8$u@RB-FS@L>=__k%`(Y94dM56{BuV-d!ex&3rcAx%U{EK&F#pS zC;ijo?VFMO>Q$fiB6n^@IvK)3=lvaTS88fR+mU1ngNwFQP2IwC4N*BaOM!1v%{!Z4+}pSg9$@_8x*lJK69Jmp8i21%c%z5bNmt*@>|v*R=C0}M^8J>PJGDu!*9ROY{*Aoa z^)6)i&|`UPox!JL`Mj4*7|ni3A4H4Axt|zV9Ss?5#Y_!XThm|LQ4z~kS^1Mu)B1*j z+x1k$9RK)uy9}wi1#V1V{0xWdTwUXF!?--rVqVOGxG{2BF5wjcE;+C}89@W=X-8N0hc}pX0n^dx& zWJRIMJWHpv11g*BZ*xw*0nuk{t__n_=BJGKFWyUb+WI@bH92cZXGgh-@fleDqYEgD zBD+J4;J&zbW69Hp>3Se8TwgrKF(T@tkqIBo1!%|IC{8^R$9&^iR+U=jw3ETMjG@F#K4ij>ouB>-80i5@aIfWgNd-A zzHspRdKg_Dzq|UnsiWISsLz_Zcg-Xwe&`_p70B)PRt6W*>B;%4{uRMC+S%LZ1P4)< zxEc)UfHer9VOv&A;;~zcW=G7sZYlM_T|8l5DPs`&IK&LMU{D%7OT>qB7IxzVE10ZV z1UpiL9=am#E;1Pd1E{ArYrXS8m_h<|bNvJ9FXej7dLxyYL^G~j*^;3ip5`&)411Tp zNGU>ji12mOOQI~zW-|zCPxW`!H9Xkol;ysAGWXt*;W=@DjQ>PHh*J@azl-DVx@PpQ z9kyd;eX>xK8}E{-o%=3nb@K*I4?E~BS-XPitb&}bc@icnzEyZN!2EduDzn7~((6rA zo-`m|3M%7PI&?`f&j73mpgbO(tBNYa7R@x7mF6?n5=KfbGof-PfFp5uNW7 z=w8OWP=>~t52v03eP+)+;eKV6>7(eDTHP%CSvPl=_QU=C{ii0i!|+)uYbG^6;=Q$m zK>fGDvzp`3Zw#*+u8!||8RYOOrxPcC@xY6Y)b!UG5^wx%+}Ete*5e$fhhp4rH=&#L zTtR!t2S?0m+BRuz3!Zegfk)|rWu|AJpyh!zFPjDYaMLEu{Q=Zef%oDk0$7pT- zn4h9;;!yP0I`lPGeK&=mW4g7{Tjz@{1)gE5p|6V6t3$CUK;)c2t5#Xaj70{4&ehbZWuXAI{eB zkp5tX_xlA-1RjE+G8`q4@KlG%A)hiN7VsF4tiU5dG?`0`j>HE3r5HTAY7%pZE*UqZ z3L(?+Ghvg}clpQ)e>ul<)Q{T@IAIOKJFP-xM_RKnN?c2@?eRN3Dxlz!B8S!iKd|SV z*lzHQ5Kz}AmMdLp(2UpfKr_xRIyLU~E+HV zQ6z_JcJbm(zjtpn5`Tb$us(jU?Wz?1Y*bsd!FImb(An2U+k5su$hH3FdsviZHx&#i z__tDP@APOec#u|M(EoRu)O9&kv^%uk8?TWuU(rCo^J<%$%=oiWAT~YRX<#}bO1aG~ z-(r(q2=Gpen4X6;0j7)&?#{Nr{`}TT(#TWl(X>7r_k`A3RCOQgvDJiC{rk7qaJ2X- z>O&VAgdZVAEkp{dAfkc-MgQI2w>GzNBnf`^ujqj`!hjcor3vrF26yOPnxbWnMN&&r z_Bb>Opb0d|wm<+z1EP68|M#nWRb{;zASKzJ-9gM)BvF<1%F4=jBC&dp@I-T>U75bt zRDrx?2WTJ>BE=~MQn^y_V2_DU?M%%fP$5WHkDQoZLiW#1!E03 z4S=~|TX#XXtDwD`vHf#oc}gxl==e~vm!aFdY@|@2G2&=^`XnXd=!leTUpl% zM}{K-JH@1$8qNXG7zVMTuM?rbsmH$YqRE5U0(9|4ebNFSZ_ry%c*@(^{FVeAj4hG; zeMbedE?^d==Q-RH7QI=yW3xomH#gO!E|J-0x`pY0#^$2aV{RSK`)vFur?GRC439T57B?| zyrG6*)x*?g=8r@0?i~JP1H{@Ohwh5oK&FI@3bK8_9o!D;nL77`d^u5P=vKkUh{0<+ z>AfD<5V#Nm&UPCJVVGP9=yI|R6cwxI04&!t?5lcV9N-YMVXiO{8T{B)=wUqq`yk{y z6a!fmBTS4a;RD&U*In=N!6~{+7&_8zzJ(G{AkK(U+#kjagaIBKFDrTQiQ0WQt((>6 z`x?yto)?x~O5q=k*4`Za&3#mSFHX5*SS4JqmDrmb3K42-2&oxa5kizi186^jEzQW} ztl_#D#7F7P#?03H-8PTC@{^rkziCBIyDI}oX4lQobK5A4@X5A9k3@~g&bBj4Ow?aU z(o|*fmtQ!S0H{0HmQaIn-!I{YOWGBANa(;YEr1ISrZ36X3>rJs=FlrUpIz2h;L%VP zBW39xwcvGh9DPX=wHsvJ?4zgfuFkgEFO**f=?gq~_$_VlJgpHr^V>(?8QDUXN2;e*Lg%r`s;bKvbxI= z#q^SAJXCMcfpiU<&rZ9Bxm(qj6xO^V?9eN(r?P+|^lu zScP%F^ZDELg4~80jMlxgS~vn(qbV6wgctr63FB(LlzVv85_!uYEikhNb5!*h4Hhu0 z3WkKwdI2{J$Ga=8YqX<~CbJPGp>kzH+tpjwCYo7OOfh)((i}9(VCh}OTNyBkP0GHn)~IOvj)d3|jW<-%!$rAFH|3-XG6&$4U^g9QjW;v%Wau#F z3a*Zgns4e!#d$!zk(A?0*K=;PiU+hbYcM@P5WWk+Q^EH_y0B0Wo#M^%t&nEkUlva= zZn)uCZVZtEPGyX4IOXL^Y5w5j=(JehS95bU{FK1ferF>l>_P^~w-Y%dd6;(Pi3+$i z3M;bVN3NeG@X*!%l4wM*ofR3j!3yB=Zj8@h{Snz>&BA=qIi2!Zam09FRC0CMK5 zP%=}7I0Vm!%o3VfN=qVk!1Xa6@g9L zgJ<^a3TB9}V;bBNoZn}4bDd(aG#hak@Or_LK)%o$>R%oXDYb6zo+Gto7Vq=M0~GgW zf7*EB2|xIssEt5@gtIKk)YUO0m`m?UZ~bNl;f#VuGNz5Ojo(f;8yLi}Ka^DfyZUdkq#fdoP%`rTr94QWPl7Z{rD2|pP-J%4 z0`Ster|{-?<5XZCn5mXc-NLX${MyH2UhHKw-`bxuO*CQA&4(#M@*(*Ypo-o-0LQG& zuUn{56eE?3BiAXEZP*T+$zU)DJD)r|DNfN^=?bb8U~Gc)0c6Ichav_UfWiS<9`0iI zKIya_3WPv{{MlaPZJ4%a>zkW9tghsRBg=5&U<=j(Dg*{}iqLJZ*UDumizmxVddLh!fF zi*Pq$i&|IBCVTj+-pEp{DcALEwfp#Kx`!TE+eNz2tGDH5F3WL+#U(tK646Q|oJsVC z42D9xr^m|R<4oM<5uN)R3@4j9QCrz9?-)&s~g3OYceHN6v!fnymw)3+F~ zOBWIh+xWOCFG<&;VF9}%nyLBh@ub51Fbr=QbOm0< z4eVfDav&$=I<&mh;NzGj+KmulICMf9jI9o~lNF-35cA?rJ1qGitFHudamsb#cgV?j zCml0g;i7gd$hTFsP+1^FY1P1P~m6A(%1rq=49F_WJZCE|X;9(iL$yjRTGBY05!d#NeqDELU z9IR?*ypdybWS?87DSzTt1lND7dpU^qyB8ykj!^f$2v`A^V^d*LNJ^3jB~;Q?I}PHZ zaOplO^<&MRSi+Wa?Uv!0t0Qmx(!$iX@#Hvg*D3pmjc;R6m zWt=I2)+!w4O@7 zcjF@NUYchMecHL=fE4gfDU<>GLN^GHvv>M-mpnO46uAn#s6sshBoJd+ew2JwY!wgO zVwJbrA`M4taS-`L`5Sz9RCHn4+=B2#O(w#a9&OXrnsR?96dn}6(9CV{qK6eDI(%$J zhyM}ir3&HJq%#|vrsl`E8NHOYkQzbm=Qx`o?o9Hjh7V`$(Y5Kj0<#y78tDp&!UZ5h zJtKU0%Ti283mXvgj!LsU|A5UodT6h_`0ahD*nL09-wbmKZvJK?1hHy39;GD z!q(9V)6S4BaERoti~sF=cY85Qore_Yg!e3w8iexQ+f3T+j~}l`5H%LE>`TY58~wA* z;mb6wKUi8|+q{6*8ME+udU|r|=udD@+L(F1DlI%Q1IyStqK;}X?|3EnCw0i7W>Ly- z%}5;%2#U(oU2na^lVc-qA3t*=IwyCZ|KP|q;8mR#n_|*37In%Hj43x+M~@j&$Xmne z=_QMBX`U+c*j1O8pb|0!-|W?uV>?`)orqCZL+(+%Cr|MYhf^AENO&RUr)gA_O+9xH zC$=`?(+d`AE zoq#JJFVG-I(A$M!__W*YA~TP;0a^D@#W_esUujOF7xNpol$LVZwAy8qXuhK&dc{!* z>HDieTL4Zy(#L%7dxt-Mj_{7G*y-6mD9ZI{pu+>@6>c!1@7V(^G2r zC4h%-l^(k9z(R6W^S?k%OGswn&_7*|!MS+i_ftXYOHFngOQLr>2r+9F!GBwI{4GQq-d5#Xb0r|{V!{yU z%l_rf#iT55m;Ep{cvtO)zu_&>-i0TowzrL9hP93Wm2wHS6a2jig|7=oP-j;*Hw&ND z$i|P4;C76XS3j+&V;pAK-Ry|Tlw>yxyQQVOYb0s(3$iJkY@Ju-)vEPrsa0V;6ScuD%`<*@nVk_c2YoDhp_ zra*PXZSEwGiZk>UupWaK)g{D@a*jngKu)vB2!!RYXki*w7{{NQK@`nNAln)Q5b_l8 zx6&u4;_x1*Pj*UV4eTNMWTiybpj!3G_q7fMW+9n@I|)%Jx?R1Yq;LW&qI(x9zQ4MiOmQ_kNO(&h72ji~S0rO7ykBw=K28r?OBz?oC;`@W^O)U%19T%5cN{QM z?4zfi_wIG=o45J?i+U-0rMF;Uc#~KWr|nRzSZG(iW!QhyScbZ~aI_xvEC z^t4l_A6`G*c@rSTH)x=4IEd4?$xtnalBz)%?u3XX_=nCY+iU=fM`K<_#U-dPS2&YF zc74lw(d~i4PNziOaPYFFc$=MD5I9xB90NrtD+y`D4}!IHlphIiE$}4fo*wKU9DaA; zypUBTi`WHOk`p?Y;`XoKy9oD%>jVR&iY<_DWG$ck6eUJXX%V>fPtnY78%` zB=c2JFw?6`fNe+?5+W{tB5otlO7r6JB3c3p{I%vYL}n17&raHs|1q!$dQlTO)c`_Y=0qw%p0y2%#r{lK0{B2#jc~-(tkDMG&$4 ziRK|C6 zE-Kx-IZY|itY`t#tUwaYjh$X;u*@(%$Q9BEM)AbF;Jp7SH0O5ueWYeM4@htADUO1G zEmQ~@re8RcKm+**(ktVZAV)^23IRGiD!`keYa911nt8z$!P1rz1I|`ZGjp@C`T(Ss zdU0soY>L(GyjZN43ov|#b9&m7gD$#oc~E!T{AUkS!9dm=tRGmU@U zwxR;?FQGNvGiO&{tXDMRC(*f)jmv*MB7H_Te4-NSgf2LE!*_j;dGX~>DhE$oQwRXY z5f~YDMf8vX!ODQcEl4m^{stOKVD2m@r0R!Gr!Gms?R?q5A(20qZnz@aCk+!=z4RJ)u6rNYH^baR?z~x$XiZ;*Bvd&yH!_EB;@6-dcC!G zcpMA^L-B3wknsA+n=nd`j0T)g)>Q?APppgo*w-cIon(UK?CG&`UdFMXtVDpEvo8{m zBSccLVfcX$8bkAn*gyugu@rOVlxLr~nUTlBzyoU?+t7CbLu8@eHmTrH0Hi|OP@HVJw295k z=M}lm^*wOx0HJy+t4)nYR4Q6Mr%JH`%+?GLagm%v{GhGTkMC|qUKZg}WbS@Z4Xy@O ztuqpLThn9kqnX}dTF$NpbhtrPv4jHk@~7D*FAFSvxeG1dORK0o8#XGzV%LUYJF(?J^pWdTeIeFRA2bU#?yuiq;G4S=K{uz$` zx?F})XGk_JHHca3ZM;y2|5M#ihgzZY^N)MgRLgISu)Mck&7p$xv`Xfh!GFOBZ|(^M@L$iKjMl7o1kg~q=3P{ z(a*Mr`de2`+VIx~tp|KoBo97iApyGC@&>zlsoYL*jOR#YArKYD7L!PaNg1f@`Y$Tc zpDz`@f7$Vr#=H|bwK0F~DUN}flG>Ca|9KRpLWaM~UFMAmOK0RXz;?htv<~<|%3pK6 z-rWrANtjsb!)SsHsuB1$)g!mCIy0_)V2H@@@YyGsw)DeB^&!*j*J=7ZFi8340kXA` z;_?8K%7}FH*eB}4NbOsA#zmPb&Rq%`Ow1uyNGBx(j)YLc6WRtMw0&T;!K>qMk0C$B zCcS}7E3F#v5a&139;l-Tt+%WtFk|e0yZ8Jci7%Q8^77h+R9c^AYFk3kWd|^%!qUPX$sF+mU=W=&KPVotfWC%0w^AQjn_8 zt5@ljGjgv4>UH@I5#N4Szap?%cDZ;u*6k)=jGoTPB{xsKAJ&96{Rz`_`V(zQ zn#kYJtS1974fd`*B&(Qc*)2uxk$kI&1;kW0$E~_sLUC|{D%?b&it0Vp&W+5AI;&|YO3cC3Y2&#tRwO|dZn9H&|# zo7Q1*i(RnnyG@-4=n8W79P)`=PODTQ!AnedeYKt%whR-X2p^cCi@fo>YX5>VJNURt z>G>gukYr5TExLqbuix}GzQy-7j=5H_hldB>Z|BC7rSg&R-L|yXPu}!?2WOuTC=0h6 z(=Z>68Y_G241*0K-Io4qe@wULJoYBE$FN*;`hEHF>CPL22yGl>#@d)ctfz0R{qc8U zI<-pU9;ljO)RJ7?!rr_+k03tQo=zO)|&-ReTqlB>V)+FIjAjm@jq} zbNXPv*mU2CxIsD0jnb_1Rqtcb1|uYALg^@gJK%e~sK;DXBRet0?n$|!Oyk^Fo{EFF zEvMoTUXx0O8tMjRvay0?Ng?fvx%tiTovYglb3ie-^4*bUDN7Zp`dJZy7Nrz0u{tkd zX>$vxMid&$kX7kv296mC%p#ZI=vzY_4cc1NVr(}q!(OS}1_Olnk^lx%$xTA5@`i+D z_nnN)WbRGojA3m%576e=as5#VMbHrpEhj9FFrU_E(SfJ2Nq|U+4cPUFBXTLIja63| z{6;s?pA$JrnjfK1E2VyN$KoArpidlGb<|{jZiediMu@kbmkU9$EfqM`%X-Q{QnO__ zsUFWSFD;Uz`lWVhXoKh!3L7>Y)al~m~*;A9UN&*r%)vmPl;GBSy=WX4nvuvPXQ-?D7i3K3u zhwNeQh@N5z9BH#63g8#zto_ZwFCd%Xl8V%^dS6N&$W0T}@ZZN6Gum_YZ&{Hx4?9A$ z1Ow%6wdNhXi>H5yoo3X1bU%hQY4f-Ut_s;N!>wpR=|D<4p6jX(x7AW1Tr)Q;9m0c^ z5$*}(By48mu1tKTe$0wiX!_wQTR z5c6vYZ;ydu16U%erfyfzXcVllCP=ot~sGcO69X`8H zJct*mth6q=qsk?SFx8_L5aV<)CU8SKyE8)N&48ny&g+}%sGd&iCN5;3DfdAypH#$2 zq+xdxL`z=cbtdgKMDpB&quR`hSKDr7eXt6`J)fj}j2GKDblDH@mUmxm4 zpkvRIE@uU}fO9inVdO6-Q{2Njm$%70US@!Y3gWFo^9JTEw719ns>f-E z7`x>>+6rbQ$0bfQzQXNA>x~RQt<1?8_I$ox8o-bcjbB!{$)ZfMDRL7UhzucU<>tK; zbreG#aFBN-vzwdlzxHW}X5!1UZLF#0`7(CKEzfng?Z|^I$kIaar-H``m@4&@>~_ZM zW;MTohpQ8mb;IogWxeT(NmG#_Lr?YEP(vuIM9#&HVYpm_YSyd`JLhmwvR1R>^-VcJ zPe{1=*ufelK^Z%($17K5RrzPb2P`}@N&$jy;SK|yV^}@!5fhtD4!OzutMT zXZ25z2!-l>KtjPu-K^@FSv~+D&ogmKA}0vX8w);Nv>DE5-ROmkPXqQU9lUyys%gI$ zc}5v#8-^h>X6Z5g$O+kck=sQti(k)XFCgH1_D>HCt2Pu*pw?cIuMe^~zojP+vZB_r zDstp$^WBMKPLlj(V7c1BZlX7X*I)Sss>$|zih!9>m>gg$Jo|p}*~D;oGqATcIii4< z1p6cIL)6R2ff%82$UzxJm&moHg!j70O_1WnMFQ~K;w==1=Ia+^dXcf*JOJ$>J0Ns$ zfAza-`=>Y3==Ib7pWlqwbI+jHWIe_NGb8K>zCY&cWmKy>GZf7modsmS{?lX~Mgr4%>fVCH8GMhHm@B%ly~fkLf_33uoSffav<3Qlb@3 zcxV}~RI%1G)1#r9h%0sLJmsJ*+^hFbTKXK2hsONPz;E*9D87)vM0Y^U83Fq?;zdL? z7Nk&_ccfVIu}Kzu?3$79(7SeeNBjg8K5H_MSB7H_4>PQ%x~=#N4w=E4BOM?ywfZmV z@Lyf_Ecz1H<%*`Iwk8m*GDeR~*))%Rk9g|@YG(NFd!y?AB3#T#wdpV`un=PvRa7SM zt|#Q*dkcrLjP^eq$1Uz+12wm|$s3tuz=#bKs)CcAaTdJP4yeF+EmQ3pRzUEcJf(Tvu z#oMY{WZe6FMl(e{hU%)c& z5rlk4M=n@Dq2)*|cLzWTK}q2O1)BJqXATz;2OQ>BRJGdR<8<9XZFs*x&NjO@2q(I+ zhc{~~UT+#FwQmW+)>f@2w8}9733bfLt@Os<+ARl%7^DE^eOP?Ip3GfKdE{(yFEJuw z^9jVT({HT0hpHS*U+Y`^dvjF2rKgq@61+;&*+pui*@6QHwGFw_KFp{A@QSo^4i2uu ztn=csDDFRM5%+OcL{hyeRR+EfHk7OJfIDbPN6N}(#h1LEy_>(4iI8&w zTC3}u8pm1EH>>%)FtX(xuhhNAIx#`UryV|8EkhtkFOyg}JiB^>yO+zV>J>qxVla?8 zvy>JsOvwTmfcZ)3{NI&E7+#bsP^Nid7Cgm=XZzvG2)S}hfHR5_3?sZ8|vS`X!Ii_mUK1UpH%`O-Rf3a~pq z=HIN#@?|}R1YA^I5@QGQ8mrD#wfg&KU(e?qgBy2B+i-Me{@90~)W2Vql-D96;TM?v z*|fm;SD7DR;Z3<3Uz>TAS0!Kv8GdG~8tPBiSF_TLSNZmz^Eq`Q$BdN}3c=tFxf)>G z?9eaA=F9HiKkLN18)b`?MK^9rpnm zroHzdL85RdJ2&x0&4EiODe)_{80hZs6@0a&H)eHHAw%!{W5@Pp2Wp5BmN*xHG_q=~ zT0)@1b@5mUb{dWut`MavPT%Oj1u4!joYrIO>#c9Lz0^z<{tGXwswV?vVp^_e<7=yu z3Al--Pman@;1OAxF|Kzzoj%=q`x_9V`C`@VzJB9BpOd;8&S!?juWi(ZgwG5c^ER$m ziI@xB!U$JPuO_9T-Lv}tRzvxo=@wE!nhN|FJm74_1a19dIS0mWc(t6b7bzz41@1%; z;UxwZ zN0x%_w=e?$Wd*kZ2YW^#i@*>QEyo8O1|X|*L{<$NP~k?jLz)@WATjpM7VtZry6a>Fy91+IbiJYAjUk!?6~4~KsRMrz_&Jm+QsK!8v4!!bfLA zx|@DLw=19V!VNb2VvuNRxb{8^iy)j`R1UF_+kq0DK7rv(DzHKG_0_duaohp1t3ml7 z_RW7Hb+!NEC!?^?4k??KJ!NzL!$ilT?nD;O21PV4uHofgI#N>zra}NySCAw~Eq`FGzbRz}*~0)%QEPCt`yNE5-`;`^#JyQYl-kBS#_gRxRw z3`*QJb6_&f`%_~&HpmGGln)aMDoTjYkmrRD`4Pd*cRs1m$ygj2tsT1|i^38$j41sY z>2m6&Jf>7|cUG@CF4w?O1BOkia!M}T1AmQESix>_3M&aMyLB@ZW0Ps(*KpJgKsp$D ztt=_0V1|U{+6C_fmqOcPlsbS7Ka@bjLShmCL<-W~L59ZltlJxCno(X3(HnJ_-5mxH zgk{hP?fM!G`wp9-lBfA!fNq3R1jT`)f4V$OU;cb{vj6Sy?ELg#??`VexD*;}8S1LM^W(xe`@oK$bKA~Q``#=5S&;DP&@H)U2$<6YIHG2p@fjT%D zE&}en6=aVXx;qZg^rog-GbzEQPhHc{=>v4$fyc&817%4%d+a(x?lMc8_j_e?E!)s8 z_vUm`^{qSNMe)m%Uw3|a`fJZG7L(X*4+wjY*!$Zx6HG+dMR7gQu3z~1IgZM^i%N1j zLg+$6u?q(OjWoVq6zj!*`nbw%cg<~rAG%a2lpNO074e`y;}Qbq+9mEdb7{neZ$V}) zzn%mm!A+JxfRn9%{0V9y(QsBxOddC^pG>$rScqny9Km~LRjQ~(9Ra8jkyX(UKWVj~ zq_8!&6n(+-tC8FBumv5FggOb<#Xv`>-?YDA4B!kzT=PTMOKy*v>FM>nS!G3rYe1d{Rc$DyxOY*c62<-%h4&%^J^*Nk2Ot4IDu^~36nh#}C_d>_bZm}? z$%&E#!pOmzftgh!4e-OndOfIy#^Z142~GZc@`A_CYa1VMI(FybW?9UF_XbbRY;Dp9 z#d{3sV7vd8B!Fr6xy~8Aq7Em_=^H}Z8AMHm=-z+_*`xvvDe-FhFH)};PMVt1v7-Bb zI^9E`n}3Cy?c$UqFss`UxK~EZlDTgZl&h zlXgV3GlQ(q`2S?HjZ8LP$ss7LK_UA|1VlVSTUP@gs=1Xq$^z@t$codwX+MW@Px{3$ z&4kc=1O6kNYs-3Z+jzD4qG!QuAY~5sNgsyocaPEAm!Kqs%=I1IF3Sb)-G-@$ z2O#@iPpTmFz?<)Ev388H+p(MSZ3SxfZfAMXNq1FmKp1yzzsb)znlH{IJ<=w^X%}<3 z9GQbz76CpHP^lxba$J1=F8^E=cjdTj>5(%=tTwM zO};m|Y_@dWLLS(PR(TH~vfoU&Z%ed!Q8F8SG6DXk_aiM8FS-(V8Yy$oqV%Lb^8cg_ZPtr4=v* z-H1CHmi#w;4`!fmX4_FD0^#=@|@ z)K}q-E#}Q?>w6m-Zq*}4iZ`#yK-_A?*4nC3NvDAz2B8b>hFjmCKUuyZdM(A0ikx!d z$$eN>oXCi(D?qFl)v_D!0YAG6+ zPj|fFp#A-|&YCW@>Lt@}`o3mAAX^%GA>tJ^j6?(3EVa!QGh{}DCsx+sy9%#FOv_y6 zj#$Hmn2m}8a3}2cdarjLKYbH{m|y%zHQF8wVRNYKDQ&{U&KD}zSWgD1LHEnkl=coj ztij`+eV%VGe$NbM^IAB>7#u7ihYJ0506`85vxr0<75SvAy<@$5-e+6Kd+)^yb`sG^ zPtssH4+x`x(@_|_q1??p#l9j;uU(<@;Y`@H`KM zp%-maN{i8nVU$9CJN;Z1l=jR3k5r+X865glB|{8K*tIW5t6o0ZVq7?G7k531h3CW< z@Ld#jbT5rVSfS7W480D(t~tAsfc;1r7u`R|6pzWCSwiAOl#(8S^gLu}Ls(wtq}&t? z83{@O!Hor}001XLr6v=rEGM>+AD;|M8Gj{=sP_$XY5qXr#>N_4rJYNQJpmKsX>cPI zo_n4MUF4CX?Fv)(s^=|FK=cCTv7R|jeG8ohD`Wz8WlK&4c|Pkgh9zg|PuojN_VCw+IJfo=$l(<9DHI`@j3eCviZLX zQkGa%4n9UPt%^P!c45U#{y{+f83>`BNYCo;o5D|ZdW4@LUfR=s?Jxnf5BEeKb^(Hp ze^Qy_{LLeG*hJpQ&K*?0gx*O_bT5X|3`%p2%@xMq{00x0K5rv|wmOHfphBXoVbEk$ zQ2m4kw&3TwQf2C_BXTA~z#ewFqvw-Po2G#vFs>WQGT0=P7@u)jj+Otiz_=z%lr*Uf z2MqTUs*g7!J!r|H^z@ZL8M(1W>@$Otk_3_%0v?2YV`vh@NB08PV+#udVtB^N`vj{B zM!V>(wpm{QQ^bMglX|>zkSv0YkkS|`GwK&3;G*UC9{j;ppuqice`1~3sPEQ}#1_7U zMj#m>i5P91al`uU$Djh=Q06nN+S%qJitF7ie)-k^^Q`)LMIz}Fcimxy_8G!B4#O3U zt07lS_pJnb6DYgL7$bp3uqlxPyUvSYo0pS`-^+Hg**_`QuaR6KHG)Bey#tVJTa>O_ zwr$(CZQHhX*|zPfUADc;wr$(SF5NnPUq|25FFHCSXGTUw##)&fbB#H2{A2!KimBnO z17FP?LId~Rcuw0#l_SGMWwXGm7)2e_3v(SQ%qI=EqFT)-q*jk_6lw6NFQ$? zOkBpCzx}pobBEjUI8+4Wm?_f8(rPo1G9;-N89^G^RHoE~-L<#M8*Mg7uGB(=x?nRyOr9345_k0gBiQh|0ZW2R3k*Y9MG; zGNBwOLs11J*?_Xb5+?Y7*Z**Ff4sH38DaP^D)pzQQ|{T@ znC%E{9t{%0#>Mw<`Tuk*wAj|tBtqBz?=bLYZ0jFW`TGuEQ2F*68{5=>WAx=$l%C)-mVixa)jnWPhW3|D1+Xf;&SYd-D!U-R_YYk5V zf{b-{{6lS|dFKI%&bU;XDGBQ_Z>T^`g0q0xWJ{(d!IlYv5&==k1*7TVv+$E6-z~Bi z7Ghr{^S;Pp0#^iR)>)bpAq~OoMwSvn&4$$-F_zEhiqV!aUhT2Qqil&$rq6gmo4#QR z;;~GVK6SE1Op!^#kZv`vc(%Qd<;59LR;hlHRtIH}b^^^@pq}X8(LQ&rEZYD(sVeF+ z#LI-f5&_Mx@3AF-x}lT(BUQp0*W6`czT*~f+)3%t0ZVHhVLZIN+|5;VTfpT_(sY(k ziNQ2n@zhYq`GRI^yNnks5#cA*FK>LNx#EEqX+@lCBLTmw($T>jh`@kkxqIKB^GC7XnyA%asy zcgbbiz5%rJJLCwdXz(1x7LxIDT{)=ufR(_b72lGK_Xe&3)+2tS8Z^_&_cC(b%XcGw z0j33>w3uqvde(Yv%idJ8V@!>KdU(2ioxJJKd~tj$3xN?sdK3uwjp+u>JsGg1_Wk!)6X@A~XfItV`jb_P=AiHt-neV^0fD-+&_q^%M;P$za zldXentX{)8yaA?t@w`Pt>~{UbQ>(Y5p?>XhRyTG;xRqT=TH1iEy9=k?$~5Hi zJt%8ifmU}9v*d^KX{POf<#rzX_h@>t#dCoYr&9yK(EM75*G@3IugV+F_M(JcF>bL% ztaTvK+D1D`@B8<5RN??!H1p-uxNs1u-R9fGU5ZwO|M+UhFMB%OpHg~@UK}{-2iziy zTWQx@+u{NkM;9>F=C-}`+ltN_GNx#sd0Rlb?Vp9Sbmea&?UgRhGO23s5ayld_iLL_ zdGN)B4_7Xr2s365D7r6b-A8t&Ab{lecOGgs_xm<~<#-Olb|@4g^}qnN3Q?=mzx>4T zIQHZ8M3L7|Bkb(Fue^v+?wQDK*!3lKKqCazc@)n45kwF=8A~?-RD%H!&FJk<)DPH+YukiXt6=Ke&{YWug?cH&?$6csDEs#a4lHx5 zg6q<*9h8D>5Q#e5p$Z@ZvHyDRc9#S46VCc}q5Ra{&H zj8+4q8m>*JTQhe6GF*j8oD^)Ykz7>5Ji+aG1ee8D(`3DwHRJV}CjJC`sNdjVj&PEu zV|EehBZ&zC^}=!7>F@wMo?clz^APT^*yvu-lOaKG6&n=J?Bu2VtNr+b3zXDA)j3;* zO7u!twyWsNJns4O$QxNBN;x1Lot9Dy3N=wnnD!MqW8}MIPT6g zdFzLiuub~I`XP30rRvlIX?48V1S+XM*WoN_A&Hr-qSuA0c7&_=0PjKB2R(W4HH~fa z7-*{AS?-!+YXx9di_yXCew=MYjk=#Kw~4rs-MsWniL8QC7WVjlGo-GwNUoI>6;djy z^}?}B`YaTCwYTkheVj+fOlSI*uswRCg_<6x~dNaRB9mro#xsh=8JeMO6~Zs=zm1ONWm0saUBEy3%jK)`@G#eFTdLWsBTj zXmvaSFVU^sL80!_%aR!PbUC#nvmj2k=niC=Ru;O*V;Tm-kWGD-lH%7F*Z%S%Pzmjw zw&=KIYl{4{$Wg&<4M&#yx?z1UD$6xw z4ex--$tBQsHT;b`XjQdaD^aQiTSI0`te~O(wl7Y(BqBYBh;<-jTclj35mG;qKhc;b zEsqLdb0v}lFO-vR%y$4#?59I;y|b*4!rKymyf#He#G0&w{4 zNfaMZZ3HS3B49!Ryvf3PyAfj7M9bjE0KRuQkv}wm6a=9qwnOb5 z{fhJDJK+wcq+!%XM?Vfi=`69!P#V;DLMysU?I|8I7wOTQ4)u`089XHFKgVhGE0pLP3B|ZZIqI%%x7E8|-mJHb@Bs&rhVG%OV#gjTHUZyJkfsv*pnz@dTF04T;~qj@pWyaZh6JlrH~59riQ3`Xy~ zQ~^6DXB4T^;VxQ4wBYGWCbI&BA%Ka@&Azw8``R`%j#KSV{iU_ELf&R(qM$;Z`3qP| zH=9s9a}BBz-}9aUXQqR2Dezz#!U1v}NgMHK{c093)U_!m`{ zO2|&S0E;n=;P+LbN%Y)R1ZM9<F^r#I2CPPdY>`-T=2z;RW;X!^G{*;{ z;&1e_{g^&Q&6t1}`iV$dg)DHW=$Oww{%n!MQKS?VqNYTqrwlgioZHY5JnBjqEj zVr>FTmy0tCtobDlgDFRV@woxk5~(EoM=_!Oe*9R(_|h;D;cfw1t-3`u&UZ8+jtPN? zIVLvG6dMq=*(7RsjTg&a*vVBMJ{Nz#(g@=9RSNshUGlabAF~Jyn0V7ykASL> zB13n=kepo3Ujk_M*_nxh$4Hh4@3@j{1gI;_TdhL^DL4_ zLzX@KpJhuv)_UAh2q=-}X3@P$;CAi)J^b$$0B=4Z7A|OScy2z)B0Kz7zEyf^jpE># zC_XkLlXXVp1ID1Rd1pgF2wUSv=9PrMNFVdc#Tg{M{eJJKQRdymqvDUo+n4V2rzp=>Ok<2{MMVB@$7E@r)r3%VzJofZ+je+1%JPA|3@RXnWRppCjks^*ip z{Z06g?n0}D1<_i?bse1idhIP4@@fW>9cw7zgv}yJ+%PHY!&>Y7DD|#JK@eHy6OV^c zY6-}jP@-QuQAltX<&x5N1Z-Yb)TCN6FEk>b=>hs9Le00|A0C_Fh)gDJ(OYRT)YP;h zLI^rHS39Wgjzm{TCvW>2fW-v#C0cU0h;-MNDf1BiqhEZ+Uj`t??xG+|$p5bETt5*! zy>&Z%^;K^=)0*->`868N*r0JcQ|*9EL7QA#(Kr;S?vSXemXVxBb(NTV>iMU2pxnYx zK4ZZom!dm>GO1|$tT_gx*g0%)AsijJ zLTFOY8WRf7C@>b7KBJF8Ut8Wg59OpoK!C#>;9a{n+RD-Poe;I+H(iF*iZe^{2{Zv# zs79`p7R^$ecz%7J%3u4tKepM2sMG!Qag3<5C(wq$D4W_Y>BZDp?3-Re(v5Y~|Zk3j> zIfGX12`On=R_&yJD9)(e;V8gK`IeByh4(A%3b{!(=ASiE)7Mg6qOjajU({$Lh`=w< zZbc%2he#jC3vv^j^hpj^hl=2ULr?N&JoTuJ_po%Bw@eyN4)ql;WXZow-f=;YN#UJ6 z>Z-3OsM>m}1`tjixG*YS)z|(eh|%=O&(G|eL>?cIM-B0H=~d?5n8&W{Imi80L4P=b zxAA`kq~qWhMmD|sKClWug2;AT5(2;1c3B(Rr;zlD(K13$Q`eFCP~qJv8S~MB4*B>z zLvqZD^P$!LrA1VOsT8%Fk3mDS0*W6X=x_&;jc=WQPvQ46E{kEf(rI|^#{h8mF@rH< z3gy2g7>-}RLTY0w2*Qw??~0Ihk9mEBD_U7ix^9ezXqpAbIJA-j5EDuUR@j`WAS=TG zO8`jsVJ$Ri>u*S>GKw;hi(h@~w0LY!Dt|6a?c`u{@4&P4dsII9R@dR^KwVeffVZ4X zme}lL=*9HufOEUPT?*v_9hC~s7fGoz)N)Bx!RPY?b?Ptt%)%fmJ$HN}T77=~1(L;; zjd;tiOiFIBk(MX@|H38w!@Ly!gKygm3IG6s2mk>8pM?h}OYXRjbR!ZLlG9 zU8p0BiE43d<9SlRyv`EI)(d0;$T~s+(bFojH6oQKs<3^%_y{MUm~xMyRnw(PPOoIK z56n3|?;nF!x)3iJ4)AsQJ?*1wQReb`J(t=I%Q3he#=?sECB)Q4GgH3p%1+ubaFYI1S`sDJ&UFzLypB_AG16{LTHjFGqifzhxA2I0I=5SG>N^)P;rN^ zGp;Ux@wcE~mt|Ql;KuA8I&b<`UxXR0W6SaT>8-SfO}AOMafKI9V9hn;@`vZxjJvA1 zG*e>Wt&u&wG4YDEAqFH18gy-F&NX-FHXA)4eHxTgF-2J9Q*3woJiK3DlWveL*f$}q zQk=(XPi%Y4N_BllI|Nzg-D0;kz=b!Q70rlFq1xa;2nD5SNdQ=AVhT2+9#j~{Gy&xC zdAjUJDNKXCR;LdEa9n*`W#EC?iVB;1*2tMdruNq`8E-bc_7l;8%;dnB6lTIjNPwFK1xIEi>GIHSDJToJ!@9eO6tQ>38Psdx#Wc+D?>oFA|xN^taE|gNyo|P3QPP6g;PW5mg*H;2NHM!sbM1*13 z#e>sje~5|lVthi<%t`J=urHaNdV&`dY!ZM1AdV(ORsi-k`CN=Rr{K9P&v;D~c&8sDvpH8ur#x1P+x^QO#l?Y# zhF(5yQ$hUlWsPk6X+}D-S7@#BjtE)KA{|Y_4m)5l)>H0#QW^UnmO_a84$)-J3gtLJ zZ|JmcSW+);gbhcYvi$=uZdY6#fwNbwb==^>5?*!bdnPnFR`QV({^roBe&ULxvLWf4 zGC&$=S4zP+8z;h+EirKVI8WEi)rbRPE^4|Nx~ZDEctj~puq`L2vWh=GdCFfWaQB4G zf61TYOJ{R7c(w?{hd)IRcgljuA7aTu#`Ma*%(`a%#;T@R;05Y2yFMCf^2(QpBdIpm z0QFBzYZauVZ&*#H3Y&=3Z{&hj>$;Y=zK4$~vfhC~lZ53<+JhTv2Oao*QgRfI5}j3D zB&Es!5<0Qxe4aM@{C!<0wFXTr@>RfzIjym|?o_B&o9L#qoYf`^F z9!{F2yEw#kwO#IrSTZ2Y;^o5MKmB9m&>40qW<}1#@ zaF>vLA0uPPm+t(ygGR-^?!If^J7!yOq_I8Sj+?sTjH|W5BXZj|< z8ThqdMr>RmQW{I#M*&?vRoP)9TWJH@Om~Ti{gYZy{lr;LY0bp@`u?&%al&noY~KN+ zl}GkNqslvQeQu5y4!Iqe`Ba_);+oJQ@H)ZB}e9Iwk2U!u_W+ z@K$U1{!k`Xu8Yh!J%_!wk&}z=yX+z52u^_vSk*O)?`e3ZbEDxERI*dkrig!KV{7Os zce8EMj>98WcTH#WG3{mOSl=-6^!UORQn$Tg9GV!!n9G}U&ZJ<@F7l}&62UE=`U~)% z=nn<(&+Q*r&VM}l{~n<=w%AD}W`d@AsI4P7aobaM3kGt=GRNegRhOZ*nm+W~T; z5lW{bg9tboPM%%M)CsRfzVgCbyqC@hJczn^1O|``g1gk?=)8{-?qY_xsBsDVIuN~Y zZnjv*r_N+h+5m1eYyA)i%BM;sO{rNY?9O1?!de89E@a3>>Wc-$_5S<@Cd5OqjY5B7 z%e82UL8x`YOmW<}wW}3(wh~My4|~^yvmLP(M5ueapJe74$j6XY20h(+U8liFY$_tm z_JV;{=88v?j+sZXBt#25B{(lYRbk@8aIt@nx7B>{Rb!OaeV|t$!sSe%jMEkA+o$I9 z4_hqOo5+s>i&7<_l#z0gW3yFCBN%Kr3KJKBG>C(W22v$M#1NGT5fK~zxUtEUX0#mc z#ST=1Izw>E6Lgalv31wV)Rs_loM%neg&Ot6R5%+JZnH?2i=Y!}7U&#z?b@>&%7XjC zZf%t?ts69MuaKVt#=OcZujVQ`A?zfo1TE)IvK5@LK&mq)!21Hp79c{~NMH#pMAW@e zDByxiMKy^Mr_juGzZkzFj~|YlmWH7V4U4@0@!cVMiN9NIMw>v8igPN{@~=x63MPe< zHC13_z*sc8z5c+=o10~U7k_L;RIDi@Zla3<<8t*X)xSeS<7H9sK!%;Kh>C4(R_|Gr zNtc9l;RhBXi*_;#r0LvyAHi6uHfUGN9Cjb*wR7p-k^>%?h{HTUOaVP3k+nq1sK_;b z?+{q?VwG9SA=ORS0Z{MWm^twRZRlCCP4c83|8v`Jmlf6?kGbh_&k;e!g67Ce2&s9T zQxEwy&d|kt0OEt1D4vTgh8j{L6%>FS2FP6xemPVEVgHb2Ieb^2Y~s_G>^dHC8;V_P zYr065JvRAAoWuQmxW%iQr`%>I;Nw7=79>;!HU1X{1ZP5 zUb(FJ?%0F@i_KX}RVOBhEe2Tqz$a_?olQ?xM*Hw;KghKncw0`w#~Ve>0r2pxa`nie z`7o!K1w3uvs!vCFT)0P=e|^KPXIl@*wzT_xL1VFBD|JY>^<8RWgUQ>yOR5+{>612- z_i)<`d#dU*azp2){kV0oZlVY?n)C!65yL;L;=uOK((a1Pb&JGp@5aWz_PH6`EbRMS zqjTwf8a+hsq4VZ;$EK5;x4U!Qi@yDtzQMwC1N%<8(Gh}DO1_#2H=>jTUo_ogiZMwF zA&)WH0U^DQ1fp}mC=b`rIf?WwOLGTgQy7K*SFJ_}dGz@ZaWJ^FB$3?j_&@xM=tOZ= zs)d1nclJ*{_rmy%7vYsz*jf2^ZEC`Lt_RKof_8{g?K42gIa+`eJ|8;j|Q)6<@!D@Dc2=)jL;^`Mm zDM6H5acyx4YnhwnQR^P96?PvR5~0KIFf(y7yBRJ6XBIAZ<8E)X^x~v&{Q~}X{4B(# zkohNktNg@Jynn}NBM)a&eJ5u}OFMJ@ek2=7C&qJ#uAmxhr>bC@d#QGECzrk}*pL{2SEl{F1#qZY3|3He3~QM@Ey zNj3y8h*$U?iy<3YLS?KL{T)-9NSF>u_d$-yozxrp)4&M7268(Cwa~&AOck+Pj+RcUsGtp>P?#+>QLXewa*x|T>3GOcjNKSXYd0jAKem2}i%>#a$05l&I1wf6Rp^i#DhcCmsDy8?qISqs_TEtM^A@Nohp%>=D z^SVOS!==qVC5E1^EG&9Z7Geuo?)*aW!^k{dgx^k&fiGEHQHrJG+=iOoTR6Eg8^ZIt z5UWHZxlHes;{_haxec~l$gta!mJXSVA1nmR_3lU3zuVIo+WNymer?1zp4H7Oxk;PB zktCI|R-EcyKSl}7-CO&+5n5i0T6<zLTk=tEuC^rqlm` z+q>eicG}=b-1(%gyI0L|uB;;=Z*)uHwa*#(g>;g*ifV1f#G0g}h(be{4?zd0HK)7f zw`0d50YpNoD3_hlg{6vz1i^y!c@5TU{Cv8DS1&Ow+TcF!o6h6+h`;LSoHd5ZLZ0cj zT%XQZFJ)R1m8zN^44sR+wk#_0sqRrFnmHCu4SE|XxU#79sP;zlr2WWrKxm}Eau_k2 zN$n%zurQSrCx`-_YcbTwor@a(E)K(Yay%B}!DLd~-*L+|*^j!yqb%&&H7A{+LJJIj zEPndq`;}QcAPe{a`pJ_f*-uPsa!&-`_!bYziW)I*h;&!piQAw*wb|0jz1J9ZV$R>0 zD^pv$Et{>+FR!Q9m7~|w;rZR@ahL~qILnmyOR{b%NgZ92tTF#BlDdI^iE8rEc{;#w z6VSg&sVS2p`al!IC(}q9@b;cG&e-FR)WNf0D#$iQ!PV1@yq>In&Zj$D7hhH)gfB7LzZ^nC1y4uI1!(12P1+Qn%-;b?V6Vq;S$`5sI zd)7gkVY#9}YyrAsWa3tHOROjg5vxbOSB(nNDyxd2$cc}a$cLWC3P8Iw>_H!)vKtK| z$gLZ5Qxh0CA$P=2ItSlh-M|-WP+^u_%@EHM9$1w;K6t~C@EaoYd=)XEo_Wjj!qSn( z&tIQj&zF&1U72+!>Nk2iGOLG?pD#xjD>nD{fdRVQJ#Bvs(`Ke-#)hUqJoPuzzu5V+ zf9vb;F#C&C+VRjTN~~G#G5npi?|V86XlG%lbEo;KVHoE60Br%-kGas#!p>3WQEvlo z>(MOqI$ti(8<$_bwBxo;VXk*%!!Zm++Xpp8?*CD2 zhEIoU_QaP=9N+9z@B{P)s#u~|Y60-e2_ti$cngNNlhP~a^78812BEZfMo9#Q zD-SR_n6ntP&_ooG#A22*f?!kzCG{U8LtS-ZX5rC52nFph8D0how>{Kh6ui`rc*CKf zw3ntxx2~5mWTXZgY1E6fkOY9iB9MVRNRQ1gcPpXZ6CpJvurzVX=%ryP1-fckv)ird z{^@Bq+2-vxE@ zLHWKj_x%*8olz~c#eRYD>h3u3Qc<=`}s9*v#!eliGVr_&p1fw`V88S^LOrzDx z&Zvu#RCYp>2)C}1i=IM6(_H8QDlDJS10u5B&NN1s0+c}8D6AG62S)2w0ajRDEw-F9 zT{v$7K8D;T^oHbRsC4NRuhHqS|IR&*QGnRHQ8BYNkYa|(J+S2wHeRVI!=RX%T$|Za zaoKJ8YvMErnQ|$C^KE=|u^^l8kd;wh#~bWD)uF)@*6ZZH59Y!^yG)?`nc}qu@{^r> z;2_q7!_2(Q*Ch($SS!%H@UfuWpEkWaCFk_DO>89AZl1$y$MTJ0n(g#iNa8SPTivBV zH24F@n?x&qI)Q0T+A>B)v;(^y=~sHbA^1eNV8K0EL3$g=T{6I*0GBOR=|(_Pp-J88 z3P-%*+y#S+dfCnMRa05Gbws@X=|lCL`R9 zQt*+@Nqd|zM%=rHZOBSZoD)tCYw;|rz53lE1Rc%^JHldebwq+a631O4Uabc%th)q&w^RefcCaUAaAWG0zy}YW%j@NM?4}Y>QQIC(3!?~7nFZP2&W1(}BU-HNF*G7R?^@Tynje#h!%M;- zHtgK=`EMbJtUh?}DYgZbND*-8nw|Tn)S{1c45N96=Ww6FI?NY|35sMnxYrNNgIMhT zW|nnZ7%b>V`QTRSz*d~yJ@?i1DV$>{B*GvI%)s;^DtB;3s-OG^dupbD*yL@GaGo<~w^nKVv~(8~yi--b8e5 zx`b-CkrSryV3mU`0jcny+@3>@a3VW;xdNgUZ(N8Sr56iIsmm}B@wY&irn>(hL3~`N zIZFnNP4S!UIYPIL50EDQiDvI#UmMhfsgb2o;c!YV6h$wa`Yuo^@^=vZUm&1y*ZRUM zzQ{m$*VJ(@H=OGMcT$X%ylY|n7g7%RgQJ@_FaEX*H@8|Z5r5;>4$s$6y@Ey-C;OPW zw1dT`$GwmcR03Iw2(MeFv;`Dh53 z2Yu}dqT@L?a+D;lGBAy(WUvmAW%N=lGrfzPQvLzDc|FM5Wp9$u?GL0)Nzm;J;KqaJ z?B6{{TIA~xE=3ICt^QlOUB{CaM44!LNj%$!#}U)~?*IMQnhCFKt#j?;Y~`@F%gkUl z{jq4Fimmq1QZ<)zy6(qvQIN`X6Na1$mcg)=k~aJsFxOwZ1cEvh0eOaEj5apLLt71` z3`f#s1&m(~FPM=0-QwXD+9GH~@v*&ZThIKS+0K$;T$w<8&#hmC^J(227N&WD0ye#| ziN`MaEaR}F{|gJk!#uUKo*@1hClqh z_&d-}s||+x>u_8sL~TjNe<@Vh6Oh%2 zTrY&Q|NOj~D%Y-r_zyoi!S?lYdJH$$*Gi)CK^yFfVARoVg04tFh&KC8OXCU#mZmQP z)H2##%O-NIjg8YlMZ?8WF2=7vH3DrpF>N=?=2+ZCHCqZyS*I8Gv-o9HM&fBeRF78i zsYdlIY=ARdN^i!)=tPqK7=3abdSlBwmr^&zUD$5p4DVf<`&(#(xz`;n$9-X&T2@1r zj{uQZvJ8CAfzRAJQv$UCA^de;#{gc)lP@l(3t}$n?2bOVjo)qCewR}1IlA3}<@)HE=*AW?trdlepzK0sG8HOxMFSCHRFTaBBk>iK;;{k(0dlTJ~3 zwJbIVgH~$uSD0Kke_gi4wknAF3AD>VMbSo3Hv8WJ9V`@m-Z6HB%BTb?GQTHp9Jn%I zq_>$_fCS+1!95~E?L9MT?kqaWX6Po5(Zp=5yf{&O9-Ppg0rS^nbDg@ zmL=^GRE2v!Z8u|eIf9`eon~9*m3USF3E9wvwzpqNS65&%3=Bj9xx9y7tjk7`^**;d zyp>a|uxa|7?9*f*%l1Ycpt9sQ`4a@l($b6uAK`mnxy6Cp^s1C$4G2TE+ zNyrlVmO2^D?kY^<2(JS|%%Qn!pu=o%84h)M^usKx4>r3KxAl?hIvS)g)*GUOq)n+n z&3pzZ0Dz?^xM!D}-h*trkrtwL?7x(0r>4>!0FAa0#~TL`yuu-f&kH3dtDY%OiV7?c zKh3bH1R1)7E`^k@Rc_Y^D`#t%Dm&M%PzMY0 zw>?$;N^`|7VY`5zbrYVbFCV!D(ONtvuMp3aNxa<@D72}QqF>>uomQ|RBmWPI*+*fy z(8@wex&EM8q!t0X0dy34k^=8R{n$d1-_V9vI^ViFt*z1~#7@c7s3;T+6_QW+s#84a^N7j;cM!D_JGQC>;b zMciX1#(N?M*QZ>qAmp0$3s~ot(MXiSJJDo-IJ-O;R+Dme7WRlm0L34%iAFs;DJ?Kt zinfmmT&4U!M!|5QnSTz{jlh!M^1&u zYzFIy4GRMY^?q_fOup?Lfq+5O@{OcCqt$CSe<-ARj@pU`~xWfgW9G+nu={2OoN8NcV{3ROX~CKJ7x0#Sn(e&Tb#Fa~Tbe z^M?u)iO;hZqfS8NP4$WcZ}*2ggGNyVS;q{S#HM~Afkjs@jK*fFJP7P+!K*f#u23cy zb8YHw)5hxWyyd$VGViQ0(mETm=Uld*CavLuY zJ?sv(KG}Qkh&;)XsKS)-P-`!)#LQ3k17|THJfhUj4KT?D(`M z(IJXW0){DQZ>@urTFttxL(NbzzyF))LSs|Y`~GvG^`im+VEy|-v$XqBTwH8_YUQRT z|BK+Vr6KL~(`MEEpl+}Uonq*5KE}*P(lSpaOC!IJGY<_$l*VM6XaF!8`Q;-3Kmv%6 zm?9<1y%cTAgVwRb7ogf_5yPaV8Xu9!hLhqVogu*Ky(M5mrcEFZ6Iu4VwPOcwAyrZE zNK%AD>7eY?y}YZ6DjyMNmANk07>Drfj~e&y#c5VsELyf-%gB$hJjh*%c^x_Kc`@4= znoM}4uzJXV+6@mT@fBJLmt)7z9b_``n2(!NkFj%T>8qrX0ceyf-B;HX)x*MU2E-;h zjeU^2fu&mr=Aj78qioS>GyZ$nnR3J}c9gdEi+%oL~My~tw z<=0Ic$g=J`xS3{CNXSXW;I+ax1t(yU0QfG|%amTQsnr|CkXuJOk5cQJMZ!-o_sTwFK(blIxXzY2W>p zBcu#4lP4$wz?_x^lYChQQo^6CJ~L*H*EwcV+7vRnziZu9kF|l#_$t?)Lb0^8g zwv!M32-C0D4DKB%3;&Y0c2*0yw4-v>s$A&ZsPffA!%gEc#FsPHcS%mF9P31^dP>6& z-LrW>8f{-o&XE~*l_hb-BdagHtf%B~ox9`UVLiHB{Qib{RSbu$HtIB*~4icv0a^quEp+LR< z-SSPY)_nL84Yyc6txWrY_pn+2sejJFTkYGn^*Hlw?=^Mv?{youJwPi02mqiA0ssK_ z-{Hi{(A7}i(9zM*<6oNF|FCe!sL9%|3m|m8t5Mv}54N}D}oocILJ<2bs^ z@C7k}qon7oLOoLuF6GKk=8Q8Bztyl53r%ZgqW)p3Na)F<@Qj~_`5yQ8ofEe+M zM+M?U0i{>z7OWPYd0+!-GK-VyJ7Z@IT9&uzp%pZC*N?FIHNuj~FPf)#nK>Yp(>)C8 zm2QU|jJrw}*DOF2sOgRUEV@1jZ>K@z_6ZxPbtKMd7Z94?pn^n;fgbym&!Ro?o9*y% zz9e(JP7`?+%(Yn&%Vl3QR|@%N?ql#k8P+*7U?~f^ajvdcU#!?s-wlf(iSIs(MOfp= zKNp|s;jtA2k%ov&0nDZF7U*^ukjya23+0R+4VvGolb`$ox}X)7#v`{DcPv)Q!yDih zoDotsr2`s`=FAxFdYD#%=l(n!tNH2uheJSKezI^!H|2F{fIfyZ3Zs zUnf&*uPHO+G+`xKQvVDT4m#?iQ>$DnoYgUM4g;P8+5X*cI6|$nqCRNx7RrB?kl>n1 z7t^$6yfQ@kdgb-{V1o;nJ4i*ZKlqxos#bgkxbj5!G48B0x9yy>&0ZnPY|^CHvi_;QGi~{w(+V|H(^pDh{NG zWFt}5e|pnyIsSjL@_)>4V?!Goqo3}ue@*dxLxq%R6C;0(6j_yrqgib;VG)UK zl<^o3KOo?braYjC@lNv`g!GoffJ>z7j|v8(B##Z`P~g_Z`gI=x{zbHt!3+5d#3Vfs zYQg`Vv5-y@iNAeZkbW^K=+ZoqjtQeVs#btPKIr!Z_!`jwou^xvNr-oeRZueylDw88>KC$LZaht+*?@Kg*M?)`+ZmmxpA9%sGSo!FeHyAM%xQg_L0lN86eZLwQs0+MSY zH&K651|1Qd>xD8W2ZuV~71thi3_iqOfN&%@Jg~nbhByGQTGl~y+sxwC0`9LKZv~^GjkJeWkQ7{ zh_?G3XmZ8@30ho@sJ!3u)78FsQL}X*#I)Mjp141hI&(+jgE>qwgPd+3T;*b)_YQSm z&&CWeuBVnLNw`{qS(59UBS>*Z2hc8kHMp72Wa{5u3H*T_`9Gu!5cBC z)?jgLqCMa;EaxWlN&)uG46f^0w5Sy^E&=(E2m}daSGY`ea zsIki91#-OgWjZ4_1s8Pfb54pS#uS0bW5wn9!X&X262HQMo=gSo@mVczPfPc1CBI7x zb(^aDwq*(56l*B=Ig7WuJZ+p4r}+qP}n z&aCv4wry0}wr$(isonjh-_v_vSO15I6>CO}Id0N>8}KhsY?ln07Ro>Gk}5YAR6hY+ zvdN-9ChP@oF_;X=W#zB3*pbQQ?Od{zfS57TG?L6_x~Tkvd6i1Gml#^IAEuYlKfNij zBOyo}zyhC)7^m<&IgKjP{Q<=pWFbR7hSiB#y%^k_az7 zmwC7ChL^!4=#&%Ve9s(VRp*Wdy<+N>P7N$QFSBE(c`>)1Xb9g|7pB_8Fqe>&=%>Qz zVnee~sUXTC0O@Z6^rb<(mSey!K6i1A;2(!>DD`G$na2`kN)Pe_<@1u9K{Zk+3u~)p zJ*arZIs0Vr9Po7VJ-95^4=3GuPC!7ZK$RSM-vtI2nP-fv;=DuF+ti5W^pp?y#I1aR zcaw0j(e>+_@T2)pEkC5cChx2rlqj-9w#&6wuXL7{TBX}eKcJP{O_*Ipni((`_o@|Z zhUuuB@=747x1SN5_es*0?Re@o{u+<4OS(@4r6FWNl#Dne_D(uEsN+x|WlcdhY)~ls zz}T_Cs$jx3ieyoNX-Q+%=Xmj;YJH-1BgS14Y?A5yCLx2^CCu_`iGQDZvUqj07;bSg zxEVFYBfWZxSy*)c@dXv~xh@(6c@6u;2`0gXEAwTQiq1sEMbY-+2+8aPttmTmxOCG((3r7EyZaB$UJ~5zSeYXeojdCtskqKh*AzIcNB$o5;MO=Q?<_icVi&vv-zvOC%(^ZC) z-mlMJi+jQ}5mJQtYo}d(s&QPj=cM;cp@gpg@)6mr6aX2-6dXEV@KM3-LTFM-!%!Ou zqcpOgEXi=+JHStLp?88xhEk$dN8s6;%1ONAxusGLRbSiLRnI+wP5miC8%0Ctwbqf< zzg(!@wN8DkUH@EUanuA$Lb@w08>+8z@*7EgiEz(&M3LXJ|h2wwGFH&FU7i_Yt> z*ASUu)T-%tDsHc?evDJ(DGH2f`-6oxN4`loJt?t?&+3k zSX1hN!n7oGV}-c}7MLH14`@jlZcO+53{^|0E9L~^jN?(sR!fM)5l=S!kFaE5rK{Sg z;q<$mq7WyvY4&)a<%~<~R40~GOWY8av3;@pvvE$>%vo9Hr&=+yF`dfMk~;L3brWA{ z-}n9{4BU^d_zDh2GW44_Zfgn&s)I8@*aG z-m}e@)<(CtFo*G;uwY${@rJ&>ipP{CojOO&B~IJ<=Luj;%}RyI}cs?4|V$@(aT{y5le=sRlo ziL;9b1*eG+&6y-T#7-zLyol{yx%-( zJfZME;QyHkSFts;SN~lxF(U&3;rut5(8TgTDa`*N5ngF%%U`k~`OVbU^-CAjsBHmt z;YLD633O#UQ)MrN!w^l7CJpWmmS0^q{B(DdN@tk_7KC=QAVy%#_&5{bIQ3D7IBr*{ z0S!l}!DGNQv&woxq$z?2F|-LBsyQuF^~w=LITfzi$X}3XYRp zFo4epTKwg*?}((+egG?0)LGEW6geHgcpfjhf=)OI@BL)^7c-GMPUdsL#GMLFk!2_a z2ZEg>m_QKVV|53mn~oWth)~^d&+Q;i&gXfVM8W5Ip0soIX#zBbhNk3OjOJ6b*22vx z(zc7b*!DsWnN`&1Esn^{$UiH(b9b2z3s8Zm5CCKZxTR)${Q!(C_elg?Dv;`%`NQ{gY7j9?P+&ZL0ae-pfu2HaxV(Fr%vM5@#6|qYcdm>_X zJcY5~GrM5~7CJpW8*}b+OJ{t)#7XD{&lS^HTM(*VV&YGTH<84)vq|?3A*;ThlVPlR z|7Q@b+;yRWmJJjfi@dFyGG^8LWw596P(eBJDqPZS*D&2NUSkF;46-Ykp_*}0i!GN> z1%#`W?{s2PPfhW(iYU80#Z*E#Kd}mUj7=8!Q4Q`uG5I%3@P^&F>cS4hmwO0bqIWPQ zP&u-KKEmO6;O3##0*=W===!nj`Z3pLX0^l4^6G+MYrOLp*c+nO>5rJljhMur`@jhs zM%(I}<)!&xYm>m#VbqV)kxBTRHq6>8V<>O^kwuHC2HlEPtaP$YSZ@kXW ztNYqagHNj#FW4`eJyn9Kd9TKEE%P{it>=!?dhU%cdr6Onuu*yN zGASx1jO{y|u$3P-o6S4#ztr5zGV=`s-Jxdf6D?0B@m>^Gw2l&xaAnzK@HR097s8)e zNo;!`J}NkV4g~5moR%76)Fhf+5>%1Fb7E{0EqT@*@zlw4JdnW6Lcdr1l@`Z$1P;FS zT;X#b)XL^vU&l^!FfT~E5T;c3ROAK{v>z7gvE#&ly4+;YQHttd_m07@NpFs(&s=hF z61S5;5Y`J8D7P^A{eN`=4Sk~=1}IFZDL#XA`v0$@)_>Kqpn(2MGUk^Ctcphn1VmQ> z1Vr%Pd@RfjT}<8oe^*HVgLCYX*V1`ItYPPgdU&=H=2QU+SgNJa%d0s`X6^o5OVO>t zv`vAKv&bMAC|H>3A@14zl3?RH9hf0F8p(BV2FHGxJTNEcI;Z>ibdP+^t6OxqZ}7>y zVR+SWqH6O-^DrX~jimQZH4bYkLL<$KN^el?%#!Y!v#y`h4y!ku@!> zFKhLgB{C^~oz3winlMIkzhf4-nU}pJ|3FlCKc=8p+Q`Z|0)E>G*8?AK#BNvIasm2_ zH~&nVzeL}72CtZYD5>p$n;F}w3>vS)c|=-J@1YiF|H3BAFN6Dz)t*tlb5%E_FHP(h z6p^jxbZ*hD;9!2C!Ds03@7eEsUl*>wmx+68eLGisM?(fbMbrlKDT+*5Wq*yh^ksDA zZ_)IF(Ou#O1 zz`jFb`3rUT#R^SjvKd@bUQO#m1fM=$X`%+`#`Lsw_rbrbCf2`6@}aLJ@MLs%?l$=? z%I5i9N*)<}$pZb|>~Osyyj(A6R4caoGTA>xQ++w}AAW%gnX$B#(mj!X9%O-JqsaXU z;J?*5A0d3hO?(=-bG9)roXOJ}ExUSGynHpbf>Z{dYD3`)$PUXZfQ&uKuM21C%+LQk zei;n48`eV#gwsbvc-bM)Uq3O&NX5mNXUJEETpY)TC2e8GUl6r9Fo1wYN|?uczHIdSXT_27gU0wwmxjaEKBB@ed>zKZeyKL|)qj$|9= z`R)DApIg;b*eGo^Xy&_3e=QcD!>!sm#^lDq!5Qcq3hx6?0$U?lmh0Ks`HTB++$Y8I z@Ed>Y&e{n^@Ic|t<-^xR>n(8voWh*7@qQSegr}YvC&30g8jGM&j6n_ zPslO>7OZ9FOyI@E5FbbCIKdQy-}txio?ZT^KQ>tn`&%S*ju`(G-pUG>t`OT8jSIO! zecUqI`X-4x+B~I1|J}kZ!2P)_Ye)|2I8jpZk=fm@><4O#kXk!#YlRZ$A1lYshsMIk zo*F?{RO6H+plq`j{Fq~?#SH})KzdK9dwOXPQMbZem91x{K zBL77N4XsoQ?E~=7LKt9t=5S6Rkc@cpJr+4~Lqo`ek-<8v>m5#zq0G6A%{f416>WnRtVVhq&?Z#4KE|+<& zJ;f%-Nc<6;jilymOj7sp{1_n$$bdA`)o7ZRZEkoUs-mEJj!f~(v83>dup{xTFbl1#bu(7EREGm@yrHksV|lxcpW9QfRR|qm7 zX1>E5(86hDc`uxfyfc47C=C!G^V^bZCyKscvLPbb_Y#kFaSqdlwM6Xj*dZDm6Nd*X z7q2a(c8?Yvd7(y?cX1`CUlmOnD7y!T%;SOZ>v)ki^Xc4tg8$ zPBj!N=Hu49hS-{3KW!pzAd?#(vk5IntG>i`z93N88eq`eRYJwqLLC@%z&HKNfQJD~ zcIdB-tgdvF&J(aSFo$F~*T@#K$K^>=Vfo9GcfOAt=MXuWOP^{Mv?Vct;^d9mY-2ap zGQ@6w{mC4{lVHwrKfw}->7ds&x~jG>SlA5t6D7r->1c?c#ffn@E83N`zoK1b`+$y zcv$+^wJuc1tioWDHe{)9|AXO=7=E;h{A8V-M3?AYKs5u}$HDuMjnA$3zDN6WikF@P znto^xL_VZbXZt)AdYKA_x{qdP(ZboMWEYi@gt&!7D2Ur)1aTPTzXy0MZqfD0u_vb-6_%SxbL_h2xop=NiK>^Jw(#E=elbc=E7$}@1!zYLhX*h?{)C+R zTz+KRP=Q6VkuTz3`v;J5{_FCK!Aj{sMOiZebD3W8)bD;$T+62|B|Urk#1T1uo#|$j zpisQIa>>)3KrYeg1`hC7lLOQ2nSYqlY|+O{^r*dFD0^wUtzkJH1xPYce5i8l`<+K{ zQzP|4QBIxjy>xjZ6<7D-Yq1FvKf=w=b6a%TMCR*8UqNhUVP++e5a2}?s8egpcJ)TO z!Y@i_Z`B!nw+(|eg`i`=xxj-=jutv2$E=B16&%GEOOF3KOmB>@maUnjC9?L=ml}AZ z$v2%`Crw>p(7Gkq2sdCBV|z1X7FmYTXHsEM##BuEnWf3>b4>)=d}8Lz;S+4HY{-yg zC?$#_V@yfx>Dhl=7s>P|o?G_UXMwV8g8GDF{)QRxFYQezph`4GI}2X##(491B%LAyHz27LzY z57X@jvlNu@m)}&#$tu4fK`SsTCvZoCIjTI}KKYvHBF_JdFvhVr^5Z0=Ri^7!Kr$8U zv{cF6OkNs01fOYU*4TO`r@-}>LZ%C!e5q2G1;hA<#dK7G3;>uCIifArs zb5&?HmM@KxdY=^@y?(D|OgdC7ICfXO|Oh!1*(sF($uojeg7}3X-G948j7Zrp`)z58#jXz{>rsSlI;q z>t|wCR|Flhlcei!hIV7w7rYYfWd{5YNyd}0<|^NyhL|^*D(&VuFNvqOG95hPDw1#{ z#pRS9jzAK@*s`qhb|a=^ra7Tnt(0{d1dA0+IHjeyXs`=oh`af5G{w@ueFmKn*Kkk} zH=KgvW*pH~Q116n{y!rcn;!N}w%!$Ho1&MNU~V+Y)7#y9K@S&v&SpFUQS#=-KRx?+ zVU>e-cnde}dX(x&j!Fs8Rfh;lmo=Vg$Dd-ROVt)x#Lsp`olxzf7fmo#DWZEhH?NHO zky?)w$Rlbl1Au3w^aVZ;lsZp|{pt0g==UFeG&ojS3DrW9_$&9PYuM}aq%4HacgB>Dnn-Qa^# zXn5DjU&XGw<&@5FtIe2kw{-g2*{hp^{_g_;JYyBO_U93>tWI{+Rro> ztZs;fq`}d4U})39rO|vY%@i?Tqk^+&!6aKM!tn*q~_%mD@0h_DNy2Ck`rPiaJy@V)wzi^ z-P=OcEiTF5$hHQh>_yhd?ra4-j>JWV_svoW)|BPY^%jZsE_BF|`4EC04kIL;2aBNN z_UvKmQvl%#6#d3>u@}Z528%cftl_#1)wgvs7fv2?n$u$_YeyMjp4g~Kk>a|!RK?bv zMcX;d#Fw68qLyzJ^H?^Cd#u#RgX85*U{Etvduo<+mG49rM6Rms{+_WtJR_hUvD;rf zBiTzW%fE}z+$X?!4FK_T!P(|t&nZE%p#E+u&kirzC;$X7gKGVS_mLJ5+w=|kR{bY^ zn;xqjGivKc=MOQP6P4@ge6rvzwK&6<4Yy{HCKz___(f9pWl8|={LZdk^P#&!S*PAo^i0O+?F{mOW%$9oBU18t+d z%}pmIapJc9!;2_BV*NPbi=(L6+n42c6K8)n0y3sQs8d$N{Y<;tp?tlc$6AwcSqevy zuPc7u3jsluckc8uvEZDQ?FGW4%@|nZb+icQ_m-*Cd6wGCG#q@b5X&YM!-XZreLHjy z`fLt^jz;g-Jj2IVcIt+j{YHb{qu!G$3>Lku#Uk1y;BhRJ7sf^Xyi$5HdtS3%@pj4g zbayx#BA^i)UgCFaF{^|l&y}|{8k=+a@IpKMdV`E(=Th{aZGC*T%PRXs<{fS1&O1`0 zbW;7lB)-rEEC`>AL`8)fVV?Il{)`1t*RnC;t9^6{K3;B|0h1b+G|@OxjY87~8G7}j zH%wDNgJ>o9Y!BJ#Gev(sOf!Nh$nMS6-Y<84#&a8_v1I+|Ko=3k_Ja=skHpNxL+hyw zs!CKq?#%?|O&_qfrHzjRcT?c(SOX!6PNUxHU!b)d4qrXNJD(i6gUxf9VR0I~`McRc zCzKI(pi>7-CL?}xRWFOHvUySs#wXDFdHw==z=;a@#kerBYoDe?#7DOshUZRJvt_$+ zHZ0@5<-^WU=eJ&}XQI;^0{<|gpTfwyiwLXj7(MtSdcIT{W7A3P+JDImex8iXS7R6+KM*Y5(HN$xK&F~2 zkveULMj5e>k&;we@r0z&ut#g_GxSlA3u!=`J7u;aURBgB(q{2nnMMbT>htA%v<|~; zwpls4ndEphFl>!>=v|zrH`;tmZ(qN*eQl)r!h{q%r zH%6ISo%4J>W6CyRJ#2Y7YPl9$`vpgv`HiJ%+s^J!kgY}6!YWsidj)|zz_nc0nln|4 zZNftu*{R&JF`)Il+_XUzj;yPzcw2#O3~Im`H-dZ3bqebb07kq8giaIe+&A?x^!H%hK*e>-5b2k*>pTP2b6iFFni0im5%vCaX)mc63tO(8%*KP^MEjuOxc z&99C@#U5P{pr6OB(LuX;`&T4NFUM76G`|Q3^-JPuB4kwyqY&OZgiM(ZKdJM~MY9aO zAb<0Hzea|5>~j?OT~cn4XL7Kdre#DCI_W}$5_4395`^`h#1)9G)qv^AMkAY>qfY!g zHl}ZD1l_iX(7{QdGtF*J=sKc6lNg-5WqsYM3A97Sd}LUB)UZ?ue$A|uIm8BiJ^T?q zU~@tDqC;;v56EvKA#+(f%fI2QG={z&kA{$9LaQ_#hF_mqQ#p~w2W)_7(sQ^MQYV-G zdrg^k%^}yQca9)W*-6L)(C?LgzDV4|>SV~@w$Mh^BH@d8GUS-xLhgMXzUq_rNGklM z!+=d?xjI{lchxc4IDWC*usbV}^H8#tut%WsG;%4G!9f3cTmMt%>4&fGP7%7$8M+@A zhSzu51U0VHJxG^`kPt)pcRn2@_BKkf9z!0$l?mc01CFFDZ0H_3seJ=QHQ$-1?8vce zn3c37MSH%~QlR=2x*bQJVCWfwxoaP!sy7t zZ!)^8|4&{&=MHMGO&TMWX^C#Xqa>7Uo538jDY135A&)b-C8q$2CHfA{9kzL=6ymJF z7R8+d(@Q;&v7Z$Zq()Zw2FqozCROKup~()tE^xR_+S%W8?-h8xom-KJeE) zY~recJ3^F7o{B}UU&2|0t#XP6E$1FFCnBFjWjF+VxCU>My>qnonO>$+7}=)DQ?bnc zblyIpB)<%$omxez5Rm{OgHkkK$UfHS1CT$f=$B_NaXNt<-`~d;i7a{hJI4gBP9fYu zn;iob+wPE=M0(_&3BvUnWw{9kN)r8E&BYzodTQd`@ee}ec)#oVqD!g3y$5bKC(r{V zrG&5tW-{gYen~Gj*)BgzGKfn>5pa@kZ0pLqMRn0)&{4+*lhFKRp#mC6;F?jmO1|)> z<1nIO!S%=h8m5|<_g;A^Zx>T`XAifMJE#_;e8f3Th=5cW6W`o0fa&zq@nnY@Z6_X9 zzcLnu^e4H6)!Nxc-`!oh>Se2ME@$iT82?ppj^f~KXx8xpZ2(}{r*zpm3C zV>giKP@z7)#@&4ivkiJZ^BJDMshWhI3~9N ztI};GS-j8kw8zNd$hM?pay}Y?ML$RqGDqT6WH8n_FV;DkL)z#{EP%L=p^>59cw+eO zs-SwU-bXB%bL5a_Jnl&T?h?aj>XeO4l_e@+R|wrRU_w0-#qwr$s4~oJ>XN&Y*NYqL ze-EszugYuQ?wNjQ7$TZuq#+!48|C)>gyn8c)n!VZy!0a`{7B@V^}NE5KbAM=ej6=; zuOHtqg1%{gJj0UMu*?WSm*Kab=dTM_VQfP4m!+zP8*n9Kf5Cn1xc6}TMJ$FN#Wi`S z=@&>G-g$I#&6{2ri(Lf=IXqF6GIjf;giTGm_|Cqu=1_1>CDk^jW}04! zEfH0j$iU8F&yef;vt|jMd`bLZa6RBFy5OBmgcGVm4LXOv+GZN9dLG;kPAn0Iu)Hb0 zY$0?B(Z$P;;QU}|eKth0Kf!mSA+Rb zY?gEi6_T9Lvbu+|X`gwrSvnb=NtW|Hl{vO zgd8g)fv`PK>xcv?f&%4T0pBixup^gV+~*g34tl-- zc7pWw@)EE}Yk%A3#Wj#HT^B-~URNp6!%RirIrDp4rK>BUD5$fXCrFQtRL zA2E&4g*z#1i4$>Ix=N1H8Vc*A)EQ2~|BPY&(~e;uw8T|>%?4c^LiP|AZ@v1IJ$B4` zKrS@bgwQJ_T6a`s!l?#AeaA6%=Xz#DG`6iPH!j@GKqkRzGa-QNK>)n-*8ZypyD5I# zaY#LA6fKOGs)-V#W=8ygwyU1U-dp&F%NsO|!_qzn%znEQx?8uF1UHVBKul1{PY=w$ zI$qhG?xLKqyYy~ME#xEWweohW9Bdj9YS*`M>;gxG_B^?9UyJH&z8UP`HNg#qt&z%h zRj2p7S}SlWN~>c@XYpB+n!C~Gi$rQ%-6U4bG-bo)SWM8{yO73~!|YMk>H(OfJa>!O zf$$yScU}2t(YXrc)CD&Ni&)`qFyfBd^Mlt%-PuEvpzvP}H8-SWYlBg#~j3gOGrpfCQ-LIT^S$f+HJREKU_U-SwcYxZg+0X3{RkDC9ex& z`Ej8Aw5_?&54ORoN3MCf&UcegDZZPAB}i;Cv}1)ApLgX!-9ENzYwQ~9pZ3g__d(K< zyYrEenL^IgXHQPJNWTrX%cWQzD7##toP(I^ksBo}UU>Al z)$9RBa4exyoLIT~GVT5N^}`(DfX{7_v5NjL9Sjn+(XPCrrgHUBj@YCRjV02H zkFTqHL;tsA5kE;j4B*9@N3PQ#GwA&BvvPJ6fFV?xhL34*>Juho@za^cRAl3tHR)>G%x~=LfJamcH3E6 zYu)jq~dv~iaMBUAN*gWI_hA=~`!ymSPi zoed?l>eSz^7LqU8iR2TM;a({skBWONq-^EDjA#O?1IxJv%w5RWIZl`VO4piHm7b|r z!%=RurWvfepH|Jk+bxMLEbHoN%@dE}Tw&FoC!6flQGRrYVX?%NlG@}8iPJ~oc;Ion zKChU9HdgX5N;QMoK-avl-cH?I>o)3^PuIS(&8Ryr*mlKRo7^u_NnD~Wi%WE@*xI*- z*60#nV4pV3=LeD1S~3>;aJMhwZ@}) zsc;aWR>_T@vA{+NPk! zlin&g>VN9nMSVeNn#Qz9z8F1JV-}SUJ{}csw>HS_%YFZ0t`FZH7BQ?YvTb|q8z7~u zs{GOp>mZ14YjCBQdp0SFrfiv^AWN6M`?RXB82^bVvAw7V-TzY$+QI!_b<_Wa^;?+Q zIQ$3F|35HFODTPJ8w_w^H=Z!WU6i5*d7^!NjD$^7@i) zO5Cz7xGlp=vh(s$pE6QX?a^lEv2i? zJp%|oTR&}F`qlGG#xoona8rDbeVCwjnqbc1q2K&*z2h)BuVl%588i{Yx*QRTnsk-9 z{6E;pdQR(J^G|LUx(ampBlYb27d}M+*q<}l<3%7V=DLHrgBh;#TEbS9wJEHOjjj^m z8F_k`xEAO*!P~ysdqIBaL_F1S^>SQ!MX9Xt5WdtjX4jpPdTxbKLj{c>dq$%CadG2h zP7YeKt6lf%-*#aEr-cuAi6ue;8nIcqql0S%U95Us0)joa;$TxOAV=`+g@rk56y+KE}Kw z2p^?BT_YFo0VvW}*-bnQ3TP<>x&4VEz4T){MBwI|cUpGYpWP<89vyrJO&>uZhM8p@PZ zjD9{6l1-FR%l}%EuSggqyx`y8-s(ja9C66qz=Ka31WE*L_v6CrBPi^C^gMd?9As5P zW#?JTtic$qUV0%w()?b_W6(LKDMbpSTrUb?N~W00q9KeSfj}Ku1KPWY8RQJEyFGnE z$cYggrZ!PNyY{)78Y3T)v6il{y`x32=a}Q}tnB2wZFd)40s>&9rtQN!K%|4n{Wdg_ z+lTRJDdEIyMV@2CNsLr-G@9;G|waiE1(on{DK)%pH1_#y2iN<~iC zr7f)mEM=1g?|^C~VXU%|(sPteUDR<(+EVTu?8q!pgf;oo(_JFr(xw45mPf^hc)Ytv z)Q1r=hZiWs?jxQ+o#Lc!U!+fJW0qWIp{F|^f{FFsl~x?AFTLr4Rd3PX+mhc@r6c%b z6#aTPn4{A%_3MPGB7-RLD{oHFiyeg=uC?yO0U{`(>tF|t`Yu{#e=vtE7m7O+FuJ1d z^inQ%5C=a-QZV0F(T|YN91{pN|4z@&IfzLIj=c(9oL_4dBYR>HFK%KYtM`9^lGcz_ z?11dHH9|&M<0&}g_kywnLu8wM{mhgBQ&Ns^1TQ)&Byum}!XjYr?t0v6HeN8Vz2N_# zenp<>&#_KA<^UU2QVM?}lC*&&6l2;YZ<;(KuzPf@i0SN9PLXq!(brH5Sm%JfP{5GJ z+-Ks-3q9u*CA0{z<`85mUcgkO&7c8H7f9^r-~IKTRYzby+M8-G+`Pf!7LoAr&<03Y zBgB2Vr_3!(jLS0G$=Ub2CWMl*7^Y zHB@J^HOe6?oT;E*FO&yKWK%`_$>G1Neza;)p5unqdMzJDfI4@4M&Dl;S_3gsj)*)j zEWDUgqts2GiOUV7*8(V3RW<2x8t=G6jC%ICAKT8^oLfbnETFGOV993x z!jT#Law0!{Dtc*HF6thxqc-x|Uv!tVg)Clw;iH$pqnC%) z4Nhuhkk3Oq;D+#9_-CI3(ldtBlC2PTdh~#eyZTAe@@vz%oFeW<{ws+bX ze%4d9ekGoEybCI{cqU1nY3koL3f1S0A1rOySJ~pszPxt^6N7(U)|M+2SI$un+`^bz zCN$}!!K=VYTRN$#s8R)dor%$9!s)YPyIC!>2`3UF&db{W>EjeB{RUEp?gE4LLa)1%vR(fYnD-TFu^h@*#&4WfP0 zlIPd*7p8mWhOTT?cl$%Rp_&RL%U(O)K6>HIusbeB`Ks3|h{2V5+Sh1*P6)HJgK(KD z+C95j4indwH2^ZRVVc3ptMS?l?$#~YM*cc-6CP@nwst{3yP*_QS2rIfz#BK-n3c~x z{utOUe?M49hGtvALajWaw&@i)tvm6z_F{I%?8jF2c_~^z-lBQN3&7!`c8G3@?2Lka z#7|c!tdu2(f#@MDPNQjQjC@Mr%6z;-*+(Zh_NJUId1sizRg;?gX6FPCE-3q4Uepi1 z!sv2LXnJ(l_v>tI)*C#PKXYVoxSvYVSM+gmmdO{qR$9Eeq{z5mC|*Z=Zx$n2vTUHt3wVnhG;?6$4p|I>fbklK#@!aohijJ{#NG?l7iX1}y3 z+3c`@goJlMsa3<~aXY!J(SfrX#QM+lnxuBe6`QQYZYtNw$K96uiSq$REFB9JC}$FU zM-aF-x5=6XPgKE60jqWSKfPp=rZD)EFKW`a)I^43x0P&VptIVdl-0m`iN-jthbrc5 zK|c`+Y_h6#fIGIL%Pb2-JlHg#-Cwz+_%F(5ii6G2*Yp8KqVKtg@UnePC_1YG*kepJ zGn|eC_4=D6u^23Tao=8PnJTyrWh^u%xFxyQn-Dl)@SV>x7Px~)FH^KmxzUBR-I zTcq?OkuCK-AFKg~?~-+x-?1@o+qNwN*T)usIuUQ>$qsoT~m|7{Qf_1 z=KtD13)u3Mw*U6e6YPJ#f1FMKpZ*88)OMmX1^#JlYB4XPh2j^eeL0|-rT{5IXMzi* zg)~VVlBlk&9oF-oFL}60mIydu&v;GF1fvJzl(1(`E z*y7Is`~CBL1o;`%Fd>&ED`?S;ITS*ulTMa2)o-HtOhzQ%a7;4^8}G_m9?76<&$~@) zcC~!vaZSy}`zvEaSb~RLsvlvEt3ev;rpE;aoz-gJes()c4>nB%x<8iZpqFtWGvO&5 z)+Z~`WNw!_>n6_Bt?SX|cboPGyLH&G5X|I(mCHog&7c|T(KpxQtwW&AN5vq_5^fzg z`w3bm(YywjJ%rk{3Sk@IdSCbBt@g(jlc$5;3_#c2v^EZFPDzOz_Z=9#>)qR*V;$u- z9aqFv-t{b)*F;h!LJHV;7}RGji=~oTH$EBli?O(Tqz_k#yKbvicwniH+MLSo-eSLD zY;t0eis~kY^5Yl45@jUYNjKdna^rt}V-$YT_BVAgmeZ7X$fdZV}XXJv^7B2 ze&`ypJG7pV>G!%i;Psxs|Dcuy%N>HA%_mRL{-Nai-s>T~3Y<;FQubhDA{jOZYSSgW zO8Mr7H8DB)nFHZ`4I|0f)0n2+`ofS*2H?0|IbJRt|_II_0P!9@GlzwckLY)8|VL*L{qG${a=a3 zZ>5%SF?=+wgiNpl1e}G&wUJeD*`vPE&;ll?B-Vxt_1`Ei+?$mf1q|_)tm1)BqbQ!w zY4>Zm0R~JQ0|cf5v^VhMLl*MV5ej~YzL_c~Y1{3Df7O*F{?vX!qyis1mwzfIYIUr< zOW<<#T8T0$IGBH`$Gx>PQW9g5)AZ{X>Lh3yG;U*V;@g!>6o9E0UF4$Z%ue_pxK0PGIaKAD(;>EsX6nlQ%}S?oDrt{% ztruLUyLJ=n&Ac$*?ua<&Pk94PvQoz>V^XkaS5R;I?itX9b^Qg`INqs`(GklW$+2R0 zyH2O9+Vu+YWMrs&l&in?lniDmxw7t;H_mcCNPalm+H!Q`YHKMoz8sjmfCs|W4{@{% zlx^;YPuQ8goWDP=hNj-aQ)?c`3{2^pUAGNH($Zn{+DTa_EGjk5OP*~8(_duQIlX5- zOH#gb--KhrU?6F9x=k?W<-BF2WR{!RG`$w9W1mUpSzFLsH(OMTo^@LB$z|i@xJ>?9 z_fhvZOd2fo0f$z%8N|{@md2q0WvQZ%V#mV8bJ6)ZxaQY$bLa-|i}@p|r&DxieoIf_>7ACCt`j5(w+k~b>@Cj*oPcKX`8N@i!Dc-Gg_0>> zNxd`L>*=z=7z6NDij+Sw)GS1XkHoj%e0`w~&@ zoVEY1SYfr>9Zz*8?ci|g&T;0XdydoK7**~kGTw9beVvsT)@DBZ^tNvBYXd``Z)nZKi2XoVA%Qrce9Gr)N14-?^ zqaq(?Cnw~;t-L;VBMqq*hP?CB)?7Fgr%nlB{F&?#&mVgpc=AdxVEB27|;9-LmG^xqUv{O@^F|GR{d7$4zvKGI^Zyz{ zVryz^@8qd(V`})Hf{*{w8hjI*}}rwU!pfft#av_FOD=?_wkTlKK?2W&)zA5FIFxG z_m-P`WiYqTHg4}`(zPa@IrAe#)ZHw_T$W)C`bN%1(+~acWID>uaB@&**?e0k3ZSHK ziN}}X7I&Dzlv`=ir`~;`yWWo0*@)4{oR(N{7$N9_mdBHpph@9%ZL=R;Pb)b#mvDxLl2RfJ5*Rw5Ou}1YOE3=T{({; zZz4>apH?yY79{IS-&Qz&xK&&(m+`5l_4#(v)q1>fz#ljDa4?sN-cY*wZPt(Jv1%53 z(~RWGm~(P;A6J66-bcEM0Oe{|AK9bb*cM%=wr#NT6b`pOBk901F2NAIz2~S^uKZ2P z3Q;PxRyV`&PAhUdm_O^dK)#CnE>)=e7$^RvPV3_2<@WjgNdJ8&{4=n--}q+6;K2?D zUeRKEeR&yk$67E+JgMRuT}=@_Vu2sUhD>1X%#5i&sO!i2cM6Ag(Zh-JqkHFUdeGb( z;fLv1t^%r~th!~1j$r1b#yA>*tvkh{qJ7Tzv>dxh!9%Iik2iJ3s1CIDvPU z-wy!k@D#um)4?MyuIR;>w+q%-Sj;z%+FUPvSK+kyD_QB-R&g?ZM}#FTx3M03n>-DC zP}rH_&7XNJ8LnTqWJB^dIVHn6@Pkzy?fxJwb8f(y;w@? z414TEfp=RPW%8V1frZkomPj5fhqHc(56ZZ|(G3hp{B-ig;XS+Gt-s%pM9ozmUi-5{ZWbhmVOBi%@ebSmB5DBWGsEhXLEA>6}PKhW>~ z^xng>pXdB>c;DG;*38~Bv(|d)6D8yX$Tr1`w&%dx=eR7)7Ul+5$6kr_uXQP-S_%?H zPRs(onLB49nq21noW^QeYl#WG8yXrgjO1|~#BD_iep`r>tcWysA1)DeNC|V+IsxaL zJo2r!A>ADNt86HtJ8}U|T;XZ58f>~Qr({_77xjL)LnMM=bOr=dt=d?;hLF7|2$$Cs zB+ehV$25ST=Gc#6(BUkKOuC+@H<>E6!;7n|amIZn1Zuy zqstW0@=NrT^HnoPYM^-DPosi*`FerrcO+p2sPnbq!$G;Hsy%Dy5~>p1ZK;HM*2zXb z{fVjKU5Lmm#Y@JI_1?xT_vk!KQ4C}L2uaXgAV~d*IxF$Zq$u^4HzobNY-0FCg%ONw zZ%eR|xGhesd4ug{yZHLbgtV@lwAh)I1c1Td#i{O4v1y!Mu&dv`G7V(zwG`%aO0aNi z_>wHsB<6Hscd%CBkM?XZPsg-W9(~zHjT+3+1_|D z4m1Qf-x$=LQsR)%6&u-!`@1=sdRP<_W_KDDX03Nnp9)8% zN7STZ*DEC0V~2$fz}c$Crz=z~TvXrDxfN#ie_E~O^H{gL?^ilnTU%L@$W;Jh6t9rw zS2fR|y(VuWk=!m#z)13-Lw?D7aynB zV#$BGG=9&vB2BFGt>%Jo-x(ZsuXseiyF{#5?JMw}iQZXcOzBvR))45lHQL>z>5JaV zfzpgth1o_5yCZWfJq$Rw=ExS?{a^#2G5taIe&lH0THeo(-hct4onE9zLaKg)x(N)rX+>f_{W-lxd zc7%22&ysUVtT$l=Zzm2U?!%oDY6^5vx>6o^yHX`|+k#`hs&P}Yvv@@8eo1_f0ZcPx zT&u_}tv^wX=_&fMtd}V-Qotilg#nv9c%%sAkSw zRBdoTcMjbYl^TUAPUyV0v_1`_M$Q6I1Da$Q8$aTjCzRYQ45hFWr(J6gI+pGvTL=}h zI|!DOLU#%X%PqGJR8EgP84Cd}M8;R@D1)0%SbiQXd*2L;xa-xjDlB^&-RP?trg-so zq`5Ow>}9S>BE+sdQClMG&hjbfnz6zMs`20zUbQ1H*k^sLP&Lpu6ydBe;kr=xq6(1n z0p^XtAK3s8F?&{NzZ9qqbb5;i6rUxuFJ0Bw<{eoaCsSC@tY5(Xt#YektA@zQw6n_s zcCEn6<=P=0$b_p=s2FAJ>RSI@6(4#vD6Hq218wCTy=#h&n#!kMpAZuozGE9lNKw$R zv0f--2bR;?X#FlcIQKQfSZlJ?+v}?WbKM~VYr0*b%EmbRQ7v&|v z$Q$$q`*Z!%qIQhHq-gC7rF6Cl7l8yY$NhY~gQODXWOf~=zis;gZR^P>RFcE6$>jyL z4ZT8b6T6Oo5ke~~Gi9Db82>h412BR|@&$OcRU{@Di+T!lFOB_m`Lq{FA*NU^u)|jh zyH;GcwPfba2ffV(%jT1acxrP~818!!_Osz#ccc&U!K*f)QVW#6#Ek3qeu_A?R>qvT zlsONi3t?**RC^AA;9ST?w$2~ON?J-)#fiydxTD9*?qVW%5MQX<>3d|0Cq4&DIyG)H zp@4TIwk>^+8geb=U%AQVHD9mAgPM6=m_jh8jdpO(;qR1>L1d%P5Nqs{-5DENMnKX-1t*of?aYM0_A~?aI?_FiWK)WfCdf@^mOF!^-g@Ujmq^w=exql_E3prwfgQs$fb?tyaI@%)D>|xQPe5Hv>_Gmx>P>959M6FQLoRpl@Hterv4>7 z`Kr~Rj|TaOTc_vb$_~xjUs(1wST|OB-IU5tc)PW|Obdnz5G~!RwGiaHyVWMpy-b(> zVt-!t0unFhCQ{>`uj61%IA~ET_J~SV&AGOT#DomDd^^~3w!enGW-596^u1N}J+td>oq71_kA$Ppy4_xH?Z zS>PYg2)c7O9T1LE6{qsr%*Dy#UX27RO=-uaZY}O!O8`>{2HY3NcP$@I!gwbEslDcq z8j@w^M8(qF8&?V(fHZ%ep$z&4?h6~^BEiS%i@nc!QDxgSW^68by@FEso;1fc;rixm z(bH<=az~x#3F5a_iWyqUP^GEfShze8EY9a7@%$R#70`}WOrc&_J8`Wb8Ql8spN>=%i zatt&tMxs$j*O9{ZzupK&l-S@EIUw00l|zlh0iENQRuVekAG_PZ$J00~g>vaP=TKQs z7_6~<6Cqn7>i4!SFN#@tyo9|46D?c0Yorc2zRyQ&sl~?BEoP`FsuC2PSX8!-3;m1G z*)UA*0MJa1ExZLzu}Y>;LcaW)sV*EX-j~LXtpr?dFA*e~%=GtdcG;gf5PN9VqYT~3 z(!7W6?0){a|69-pHINlG7E474bU0P$G!y|*^4d!K_h0KvqdjqPNvA^j0ztN*KDWNM z;78}JqpyNLl#$Kiw(T{VY|bvuw56d)Jk2S2tCv_N8vsLH@R9=GG3PcjG*4r0Adw!w zA3}&)E^?u?$d`G*Oh3&|L9XqUD3wJLDR&Sypd1T?Z}c+(pA$qdIAkboda!Q@#qcs} zLKPcVuYeOehG4lGHYdaGt|n7zod$~>d+lT8k8U$RVp%EMuEr>gC_~9$nOx8mj15Sa1l$WNZ4zqlTBcY9yOR;}w8ncp=g%DBl6b2V|CtT%4B;LoDK z1Y?59xk(yEa!;?nMJj;H;5uIznb)-J z3upu|DlqPfNU@E^uJF1Im}#1C^N1Ccr63Hc20s(Z>x9&)$wbOJRjbKDw4JS19gK}- z6{UJ?SyfP`1{V+LORDGNe!kjL3*7pHs2hFRYq-mpJcBeuXd4S@Y!2w!lWRlPqvh;Y z*9*?P7#R1HdCjj+H(XhlvNo(Bh1ryr(eao(5j5e=T#BvoUv!t0Y`s-Gl_GZ~h!JUfSi?DHyDWhK|ite+L z>)Rg6HZYwjpMt`3SGq{N3}?okbt^=}I;4 zQjVx zvEhcM>Nyvp{dWF0LCWmx7-a0tY}9nlomVCJ!co!N2Q;cx(GVHYuZ@h1jj_#9?U$Oe z>0a1XZQ#rBc*`c2<7=(#a)rko_GHf;UmN?LzjGX*Q`sD;%b+ytP4;y-gv?y@JINb~ zGn&=g!$1%MKH4h-?(#!VDIZO@h-0J=hQicN80!^e10DqV>OA5XiGRkBmyGuH!zwQk z*uD-9#D#&azi5#QD}CCY+=$e>CN#FjTu}M#sD`u{h=e4$=lk^A0~Wze)!02Se8&gT z{rF-KtD;f8A?>vWrrPr;Uk`ubky(o)>Y zfo!cKi0{s4KK+9G&8Y1y-Y(=I4Xx;!kp{@I(rS}A(;FUfG|wEdz{0~OP0lyz=iw+uu4Q0Ep{2R_8a(I#g>Za4z2E+uFTM5L8N^>kwfA70sMC@6A9pX|Gz~ zH$uDlzAn=Nn=Mgxq(}>N5|JV!+y0F-Hns&h8_q{W>BO=72>-Tti9m2OQldMfGpDsd zN;GvAWqmugt*O<$6mCLZ@0pa_%1iKhB zbkXmX($Yhsf^=0PT`bbmL0u=p)DQ%8Qzq0Z<=4w5^9v8k8K`eFJ&1eE(@NyHiBi|@ z3T&0%!9WTY?@aHUCG|o7Ooib4PM1||gJZ`~xCb!lu);|y zuqM~&C-~>{sG3EdVbo_jF_Mvm9|%^zqDY;GUjVbI5WuZ|P&xSH+4p;6-ZEUu959Cl z9NeJ^P+$azb<@*<33FiXvYH$SYp!(|-nBR=w{3criawuC6ToOJ#rGi0aHW!rMTj77 zIHsUgK{%XdrO8e)&GVo(pH|4{^|Hl8gZU;76VaWU$u-I`BfZHI`b|(eyRF{D`cqbn zla6Ip%cQ_8h#F3hh`DX?%ViA9dwEwTmz!2DdcFAs$!-Vr;P81PySqxPzW5dXgT9uF z(`?iI-kP27(4^3O!dGl`-K90}|O;iK34OE*jl|Hr4;HpR037|wD zF4j*~xBjrt8&hL|=*n*$^zFTtvIc~}Hg2exY<2Ey=z}|A!mOg&s|OZSmhG}Wd%H~u z(49M`tyR7Mo`*1Nw@d1{UBb<}NmMD;BogGBW%qVadPufo!?vQ~yDyxU@LUL;#YpEPy4xmjh%TWLc!!^f@Q6v8(319&CPNmNIw!QzOud);nW@RA%T9lhrw>i$tA0T?D!W~)25(`vQzX-O!AYvjk6BdGEckLG$a2BoTgHE$x!|`3%bFjuP5v?rq>}43r~hte z&?ya%x(>SZut7=H;pwTTU`6(H5Vim{YxF1D+!UtoTE0@n+a?IiRCEwcUH}pB+_Qt~atY4#`>Q(insX zgcxd7=WS5kA>L_?yc0NE9H)F0AGi)i3=U9~#qzO2WI);U__!LC_u=y z*qL!bd7v07@L@CYRgs0TMQ-Eat#s?5VzyucA0wDKp>VJ3VQi#0K6G^zy0!VWwG=0K zH|!;Y%Q4G4UU5@=!5LCc(>l!MT?p|7;Zejr6O;izQ`S4wV{n`AZx z$!Nz?l-hbd0fsHOVwaJ^72z3u1;|~cI%5k640DnxKcGDowB4!z3M9S5q)!I-Y_7Rz z-=#f(%xc6aNU4C=bme6;u?kglKoxP(&b|=JRuQ6{`24Yx4o3L}(>KMP;mCro#adx_ z{Z1;=TE$AK^se>hCHYgw_l)z;xbH6H!61)n8GwBcT>}bVT5Ex7Kn1KYcd+R)vX@$B zbl!tW_3ORb7MaFeZ{`|(0YrnxA`aR8A=8h9WZ4YTpeC`cp6Iz$AmN(Y-J^1!@qiNf zPN6hkSpeZ;{BS@(SWh3{ZwxJr9;OmQfS!z@rII2v5NHOVF8;&yLv8%VxEc6D(Wr+& z`o%FaN5*U}IV}N`YYw9`^DDy&1x8BpPyZwIr!suS&Ds|Y}OCyPrYQ6KUrPDDM@p4%R zWhE8HAl&0T(d*e_Yvetm)IzAGHYy?ixTiv0#{GBVcMH}TgviZDxcA{Wttbtmu(Tq< zZLw)YqLBws@)vutd1i9U^o)7UI z<^q90bkh4Je0D{WTCOO}AYIbwa1tS?oLX;(JZjKFQ}KwfFEaDp@Cyk;pM2NlA>i&R zx>0>ItC-Xkp{2qSIOE8mElhM->L8I6PQ?p_i92=SiYp+BR)g$PtqgJgwij)i9oIhg zD|C}B_t*7{BF%ob*Y?N0HL+Gv-rdI2?Ly>tNqJ-t(ffr(I-8gwBikT~bCxSd!tl_pneEWtc}vYq_=roYA}jOt zmB?QlODh?^W`;m8~5ctB-lB7uOW+u|46zJ+TX#9H|=}zbfxMs#U*g zD>IWQcb&Bra`PGbr7}&#Way#$>2u}9A$_1Mx=_neU^ZM)p~0h z5EN*-AaNI!d9qkGi1!d;8H2k-@fKj98lK&yqQP2HJdU!U>g)xY5?cW`uhf$7E^A)+ zWkpnb?5!O*pMW4@?z6nbtHq2qAVQ9tw|v+2R+zKb8{K>y_4Cr1`{I;FZ?Vg~(`fy3 z)5NOTL3EFPY{v4V@~BnnY#5cBxZ0h@-dnwd!<%Yg+by&g(m^d2tGRb@FQ{o7F9^#P zP{@+IeJO}t_)}dR4+lSeiC0KJSaiYSV1nO6=*V~D`_|Y^*mDDqA6umx=Qv*mx^Pg# z#(kND$BOKskLjJ0h`)(m%j{8xf4sm4t|yS3yg7;&<@FR;t9xg_S1Bg)?u~}B%n_p*WSTaxTPTueJMu3U z@$qws7$bugnM{WAL0i5Oi<^2IgCbYeDJdLZ*i+Ft0=I-1BQEgBjgqx1PDSUK)$%@ZO~N`m#unzLoNT zN!low6rK}u@a>W#K>~VmngR{U>gTEa#qU!zC~sE85y`!V?0f~ohs09s+b_9s+1!gu zz+vM=m7ydM#`dhr2Z5NZnP?Q8^Wa5pp%}Bbt(<}WyxVtyg8^EUsde#_{ znxI|7>1#1+8t#?|pjMR88EzKnI25|1j%6Jo}CLM*8dK#Zb6#=D2a z)_V93S$)&db zE3Ho=C}pyCUb`XVsGLi5Yvnjw`6PMS(C%>E*|Wo~;v`h>mc7N*5-h*ZO9_(|<#V-? zb!oA0T6G3C({C`kWi7ukY=qJ)4T@xMn5xUek@V3P1zmgUdIo-K)e$fX47nTc(-wHxb3An8NNhl-T`yxezG2(nym7jRj|}lEoJ3C zKS&#yi=nQ`PT{iD`0n(5I;Ft+36IGl*)%GA%zG1S@j2+0)xLEFHin>iS$NJdI4BM4 z8Xc6r*HCyO$kjZ4qZbWXubOU&DZQZX{rkOE;OmamXI$e=M{h zzU&3FHP1!mC{!$%46$ii8IbHbTp7q~)xB)horFMF%C;O|wwl!8@H67@0OuHFWXiI` z0N1h2#PEt$9rSL-PYxZkC*|pX<^tq!r@OT%^)faQE{cZI=F^J9P&LCh& z@sIze96hyol8~FrtE_OEJTqLIbwIL9J;Zw@~z%csf z1|5nw*r0lReW-*HDm=bzZ(p!sv}&H|R7Ne;TE0ZZK_mEHiL`)wuQ`Zd4Tp&3@>MUT zcSv-@lrdw^t|49XsG}e;HF#)9!ixh*9nY?Wg*7F|s9XI8{QE1>Xac~|my|DslRod~ z2&xrJI{qu4kDe#zyxJ_ItPK~Jyp`iGtw)>}iU+<92h232-+HTxlTcQSN=|xo>(OoLj(#_sRF>ULnZhVy-Ja^%*QgLL$rf%zQ=*XDE=rHo zfe2rEVnnp#=;*{-D37sC*R9PQvvk#$eoT-_KEi#C#IzD-bXOE4CQ-|(8sz!5>A1kx z+muXCD$md+q1`pO&(&EL1U4Glb|qY^mZ5rPp0rGH5%sxrdRjDtNCq)PC%8;P;6c2a zw%k(aWnzD?LrDOoMiZAb9O6Zfgf@~Qs{R)n{}uk5%bk^X{adFm8WBEJxw>7-I`j)d zr+UD|F?p=ZTawRoIpvU8mE!lZJ-5ZlRKmF+LPfJ#Hw6^?@HEhi=LqZoi)cii`k?bl zfd$vMENyDELw3E9Xh(TB`%?qk*@Bj@T4VGx6a~+kCbVRg847P7ao$o{|1sWVaJU@q zg`N?xD?+P-#t~*x)O)Ak`|!|R?WZd`Ph{4nD!`tXGazj9hZXPtVKLXJ+zZHOL~Pul zffy8WL?fy2pnP^2K@LA`zI=)d20|B#l2efvp{Sm@+5BN#oKCi=CLcaPXz-Zx08Yf% zhK~y}6`a*ot?m2z`t2vJa>t~-))U7dNlsV%VUI60exw5YnIRLW-2A>|k-m&4}1+*9P*R*mO5+bqH%);~Ob8rvS@OnC<93fLX z2#yd?Wd2CdBM;P!IT|@EQS3ytGWM_3L3}(S1X~EJCCL-Ra`w_vdjKaK~d5`p<<`*IP;jYkI#oXdEQ z252Ch4xK`gu*mBQQ=*U9?R4!J+iFT>r*v_$nuJ6|90gh~$SHQGlImj=m7(0-0sBZo z0nFUHG0ykQmPLxm>9hOqg(A_7RlKsibDMPK(s8}~wk*DmcibO2;Ca9`l5d;AC}-m7 zXYx12iz}ZsIjPyP;*-qA9V)XEDb}5tXDC+K5O(yi%c2H#I&0(}C%`lZN=5I*v5tk{ zq@=}DEd>Q5a^!VU6 z0eT2jyCZ&ojU=@mGw1so=xNLbUn6loe{qang#Po;!@(3vKuz`p0^*P8nb&DseXr*U zuX6)=oxFMG@Lg{ISdRb>=>Q*U)irTD1X+Cy!1sp>;qUc`qak49(Aw(nO>c~V9yFD%jd=4P}5q0u!)CbWr8*4lJ-$Z>dtO+hj{U$&*)Kd!x2=z%@7XT@+f84gd zg(*Nj^o^k%z>a?-|H29n$kZ`5y_kUQs3^cP68;G?D&F77)|NK9|FrfOm=9r{3fGx_ zV*t!iL?9rfCz$qxe_{TkEc|ycBskn2Y%_`d?M?|3?0WIUeflZys)rGXPA%#PXv{ z9K8Jt*~ZS=-df+<;&0T4GW)-9G{IT|%SwQwK>=i&5uRAXB>FGZzn(sSgZ{53{`cVp z`xod#qxblY;{O7DICX#Xi$V4$;Dalmd0&D*47rC3{Yfhep920rhFXbrujf7>UBV3E zM}I&_c+yxsLo0m~OI@UcfHr9a{Lb8u__Il7;){KEep`rxU+_hT0I$|o%U z49X`A|Et^l$5v03jUO{~)I4GMN$2?2c6zvjr&6bn@lhI|;QulHb)O2K{+i-vwaLd6 zC(S=o{BwZ+PYeoxhnjVh{6o-B$_(|adK=KHp`)6c7KBxRWqNi~rkBKx;e@*mfRLSo#J&m+@ z%#?-pUzq+i?&9|tp9Ts%W>mrWFO2^lHt>6tPY;qGQx;(TO!7a#r)>^W^E)!(;d)%BS!@o+|&h=Ktq?@acl?F@6rs hU-*xgcL30TeIEe_EK-1g$N+yGfcJL@+K0FH{{dOjytV)U literal 0 HcmV?d00001 diff --git a/python/setup.py b/python/setup.py index 2644d3e79dea1..cfc83c68e3df5 100644 --- a/python/setup.py +++ b/python/setup.py @@ -194,7 +194,7 @@ def _supports_symlinks(): 'pyspark.examples.src.main.python': ['*.py', '*/*.py']}, scripts=scripts, license='http://www.apache.org/licenses/LICENSE-2.0', - install_requires=['py4j==0.10.4'], + install_requires=['py4j==0.10.6'], setup_requires=['pypandoc'], extras_require={ 'ml': ['numpy>=1.7'], diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index e5131e636dc04..1dd0715918042 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1124,7 +1124,7 @@ private[spark] class Client( val pyArchivesFile = new File(pyLibPath, "pyspark.zip") require(pyArchivesFile.exists(), s"$pyArchivesFile not found; cannot run pyspark application in YARN mode.") - val py4jFile = new File(pyLibPath, "py4j-0.10.4-src.zip") + val py4jFile = new File(pyLibPath, "py4j-0.10.6-src.zip") require(py4jFile.exists(), s"$py4jFile not found; cannot run pyspark application in YARN mode.") Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 59adb7e22d185..fc78bc488b116 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -249,7 +249,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { // needed locations. val sparkHome = sys.props("spark.test.home") val pythonPath = Seq( - s"$sparkHome/python/lib/py4j-0.10.4-src.zip", + s"$sparkHome/python/lib/py4j-0.10.6-src.zip", s"$sparkHome/python") val extraEnvVars = Map( "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index f2d9e6b568a9b..bac154e10ae62 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -28,6 +28,6 @@ export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}" # Add the PySpark classes to the PYTHONPATH: if [ -z "${PYSPARK_PYTHONPATH_SET}" ]; then export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" - export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.4-src.zip:${PYTHONPATH}" + export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:${PYTHONPATH}" export PYSPARK_PYTHONPATH_SET=1 fi From ab866f117378e64dba483ead51b769ae7be31d4d Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 5 Jul 2017 18:26:28 -0700 Subject: [PATCH 100/779] [SPARK-21248][SS] The clean up codes in StreamExecution should not be interrupted ## What changes were proposed in this pull request? This PR uses `runUninterruptibly` to avoid that the clean up codes in StreamExecution is interrupted. It also removes an optimization in `runUninterruptibly` to make sure this method never throw `InterruptedException`. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #18461 from zsxwing/SPARK-21248. --- .../org/apache/spark/util/UninterruptibleThread.scala | 10 +--------- .../apache/spark/util/UninterruptibleThreadSuite.scala | 5 ++--- .../sql/execution/streaming/StreamExecution.scala | 6 +++++- 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala index 27922b31949b6..6a58ec142dd7f 100644 --- a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala +++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala @@ -55,9 +55,6 @@ private[spark] class UninterruptibleThread( * Run `f` uninterruptibly in `this` thread. The thread won't be interrupted before returning * from `f`. * - * If this method finds that `interrupt` is called before calling `f` and it's not inside another - * `runUninterruptibly`, it will throw `InterruptedException`. - * * Note: this method should be called only in `this` thread. */ def runUninterruptibly[T](f: => T): T = { @@ -73,12 +70,7 @@ private[spark] class UninterruptibleThread( uninterruptibleLock.synchronized { // Clear the interrupted status if it's set. - if (Thread.interrupted() || shouldInterruptThread) { - shouldInterruptThread = false - // Since it's interrupted, we don't need to run `f` which may be a long computation. - // Throw InterruptedException as we don't have a T to return. - throw new InterruptedException() - } + shouldInterruptThread = Thread.interrupted() || shouldInterruptThread uninterruptible = true } try { diff --git a/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala index 39b31f8ddeaba..6a190f63ac9d0 100644 --- a/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala @@ -68,7 +68,6 @@ class UninterruptibleThreadSuite extends SparkFunSuite { Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS) try { runUninterruptibly { - assert(false, "Should not reach here") } } catch { case _: InterruptedException => hasInterruptedException = true @@ -80,8 +79,8 @@ class UninterruptibleThreadSuite extends SparkFunSuite { t.interrupt() interruptLatch.countDown() t.join() - assert(hasInterruptedException === true) - assert(interruptStatusBeforeExit === false) + assert(hasInterruptedException === false) + assert(interruptStatusBeforeExit === true) } test("nested runUninterruptibly") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index d5f8d2acba92b..10c42a7338e85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -357,7 +357,11 @@ class StreamExecution( if (!NonFatal(e)) { throw e } - } finally { + } finally microBatchThread.runUninterruptibly { + // The whole `finally` block must run inside `runUninterruptibly` to avoid being interrupted + // when a query is stopped by the user. We need to make sure the following codes finish + // otherwise it may throw `InterruptedException` to `UncaughtExceptionHandler` (SPARK-21248). + // Release latches to unblock the user codes since exception can happen in any place and we // may not get a chance to release them startLatch.countDown() From 75b168fd30bb9a52ae223b6f1df73da4b1316f2e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 6 Jul 2017 14:18:50 +0800 Subject: [PATCH 101/779] [SPARK-21308][SQL] Remove SQLConf parameters from the optimizer ### What changes were proposed in this pull request? This PR removes SQLConf parameters from the optimizer rules ### How was this patch tested? The existing test cases Author: gatorsmile Closes #18533 from gatorsmile/rmSQLConfOptimizer. --- .../optimizer/CostBasedJoinReorder.scala | 7 ++-- .../sql/catalyst/optimizer/Optimizer.scala | 36 +++++++++---------- .../optimizer/StarSchemaDetection.scala | 4 ++- .../sql/catalyst/optimizer/expressions.scala | 14 ++++---- .../spark/sql/catalyst/optimizer/joins.scala | 6 ++-- .../BinaryComparisonSimplificationSuite.scala | 2 +- .../BooleanSimplificationSuite.scala | 2 +- .../optimizer/CombiningLimitsSuite.scala | 2 +- .../optimizer/ConstantFoldingSuite.scala | 2 +- .../optimizer/DecimalAggregatesSuite.scala | 2 +- .../optimizer/EliminateMapObjectsSuite.scala | 2 +- .../optimizer/JoinOptimizationSuite.scala | 2 +- .../catalyst/optimizer/JoinReorderSuite.scala | 27 +++++++++++--- .../optimizer/LimitPushdownSuite.scala | 2 +- .../optimizer/OptimizeCodegenSuite.scala | 2 +- .../catalyst/optimizer/OptimizeInSuite.scala | 24 +++++++------ .../StarJoinCostBasedReorderSuite.scala | 36 ++++++++++++++----- .../optimizer/StarJoinReorderSuite.scala | 25 ++++++++++--- .../optimizer/complexTypesSuite.scala | 2 +- .../spark/sql/catalyst/plans/PlanTest.scala | 4 +-- .../execution/OptimizeMetadataOnlyQuery.scala | 8 ++--- .../spark/sql/execution/SparkOptimizer.scala | 6 ++-- .../internal/BaseSessionStateBuilder.scala | 2 +- 23 files changed, 137 insertions(+), 82 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 3a7543e2141e9..db7baf6e9bc7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -32,7 +32,10 @@ import org.apache.spark.sql.internal.SQLConf * We may have several join reorder algorithms in the future. This class is the entry of these * algorithms, and chooses which one to use. */ -case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { +object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { + + private def conf = SQLConf.get + def apply(plan: LogicalPlan): LogicalPlan = { if (!conf.cboEnabled || !conf.joinReorderEnabled) { plan @@ -379,7 +382,7 @@ object JoinReorderDPFilters extends PredicateHelper { if (conf.joinReorderDPStarFilter) { // Compute the tables in a star-schema relationship. - val starJoin = StarSchemaDetection(conf).findStarJoins(items, conditions.toSeq) + val starJoin = StarSchemaDetection.findStarJoins(items, conditions.toSeq) val nonStarJoin = items.filterNot(starJoin.contains(_)) if (starJoin.nonEmpty && nonStarJoin.nonEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 946fa7bae0199..d82af94dbffb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -34,10 +34,10 @@ import org.apache.spark.sql.types._ * Abstract class all optimizers should inherit of, contains the standard batches (extending * Optimizers can override this. */ -abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) +abstract class Optimizer(sessionCatalog: SessionCatalog) extends RuleExecutor[LogicalPlan] { - protected val fixedPoint = FixedPoint(conf.optimizerMaxIterations) + protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations) def batches: Seq[Batch] = { Batch("Eliminate Distinct", Once, EliminateDistinct) :: @@ -77,11 +77,11 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) Batch("Operator Optimizations", fixedPoint, Seq( // Operator push down PushProjectionThroughUnion, - ReorderJoin(conf), + ReorderJoin, EliminateOuterJoin, PushPredicateThroughJoin, PushDownPredicate, - LimitPushDown(conf), + LimitPushDown, ColumnPruning, InferFiltersFromConstraints, // Operator combine @@ -92,10 +92,10 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) CombineLimits, CombineUnions, // Constant folding and strength reduction - NullPropagation(conf), + NullPropagation, ConstantPropagation, FoldablePropagation, - OptimizeIn(conf), + OptimizeIn, ConstantFolding, ReorderAssociativeOperator, LikeSimplification, @@ -117,11 +117,11 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) CombineConcats) ++ extendedOperatorOptimizationRules: _*) :: Batch("Check Cartesian Products", Once, - CheckCartesianProducts(conf)) :: + CheckCartesianProducts) :: Batch("Join Reorder", Once, - CostBasedJoinReorder(conf)) :: + CostBasedJoinReorder) :: Batch("Decimal Optimizations", fixedPoint, - DecimalAggregates(conf)) :: + DecimalAggregates) :: Batch("Object Expressions Optimization", fixedPoint, EliminateMapObjects, CombineTypedFilters) :: @@ -129,7 +129,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) ConvertToLocalRelation, PropagateEmptyRelation) :: Batch("OptimizeCodegen", Once, - OptimizeCodegen(conf)) :: + OptimizeCodegen) :: Batch("RewriteSubquery", Once, RewritePredicateSubquery, CollapseProject) :: Nil @@ -178,8 +178,7 @@ class SimpleTestOptimizer extends Optimizer( new SessionCatalog( new InMemoryCatalog, EmptyFunctionRegistry, - new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)), - new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true))) /** * Remove redundant aliases from a query plan. A redundant alias is an alias that does not change @@ -288,7 +287,7 @@ object RemoveRedundantProject extends Rule[LogicalPlan] { /** * Pushes down [[LocalLimit]] beneath UNION ALL and beneath the streamed inputs of outer joins. */ -case class LimitPushDown(conf: SQLConf) extends Rule[LogicalPlan] { +object LimitPushDown extends Rule[LogicalPlan] { private def stripGlobalLimitIfPresent(plan: LogicalPlan): LogicalPlan = { plan match { @@ -1077,8 +1076,7 @@ object CombineLimits extends Rule[LogicalPlan] { * the join between R and S is not a cartesian product and therefore should be allowed. * The predicate R.r = S.s is not recognized as a join condition until the ReorderJoin rule. */ -case class CheckCartesianProducts(conf: SQLConf) - extends Rule[LogicalPlan] with PredicateHelper { +object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { /** * Check if a join is a cartesian product. Returns true if * there are no join conditions involving references from both left and right. @@ -1090,7 +1088,7 @@ case class CheckCartesianProducts(conf: SQLConf) } def apply(plan: LogicalPlan): LogicalPlan = - if (conf.crossJoinEnabled) { + if (SQLConf.get.crossJoinEnabled) { plan } else plan transform { case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, condition) @@ -1112,7 +1110,7 @@ case class CheckCartesianProducts(conf: SQLConf) * This uses the same rules for increasing the precision and scale of the output as * [[org.apache.spark.sql.catalyst.analysis.DecimalPrecision]]. */ -case class DecimalAggregates(conf: SQLConf) extends Rule[LogicalPlan] { +object DecimalAggregates extends Rule[LogicalPlan] { import Decimal.MAX_LONG_DIGITS /** Maximum number of decimal digits representable precisely in a Double */ @@ -1130,7 +1128,7 @@ case class DecimalAggregates(conf: SQLConf) extends Rule[LogicalPlan] { we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e)))) Cast( Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), - DecimalType(prec + 4, scale + 4), Option(conf.sessionLocalTimeZone)) + DecimalType(prec + 4, scale + 4), Option(SQLConf.get.sessionLocalTimeZone)) case _ => we } @@ -1142,7 +1140,7 @@ case class DecimalAggregates(conf: SQLConf) extends Rule[LogicalPlan] { val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e))) Cast( Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), - DecimalType(prec + 4, scale + 4), Option(conf.sessionLocalTimeZone)) + DecimalType(prec + 4, scale + 4), Option(SQLConf.get.sessionLocalTimeZone)) case _ => ae } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala index ca729127e7d1d..1f20b7661489e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala @@ -28,7 +28,9 @@ import org.apache.spark.sql.internal.SQLConf /** * Encapsulates star-schema detection logic. */ -case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { +object StarSchemaDetection extends PredicateHelper { + + private def conf = SQLConf.get /** * Star schema consists of one or more fact tables referencing a number of dimension diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 66b8ca62e5e4c..6c83f4790004f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -173,12 +173,12 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { * 2. Replaces [[In (value, seq[Literal])]] with optimized version * [[InSet (value, HashSet[Literal])]] which is much faster. */ -case class OptimizeIn(conf: SQLConf) extends Rule[LogicalPlan] { +object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq - if (newList.size > conf.optimizerInSetConversionThreshold) { + if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) InSet(v, HashSet() ++ hSet) } else if (newList.size < list.size) { @@ -414,7 +414,7 @@ object LikeSimplification extends Rule[LogicalPlan] { * equivalent [[Literal]] values. This rule is more specific with * Null value propagation from bottom to top of the expression tree. */ -case class NullPropagation(conf: SQLConf) extends Rule[LogicalPlan] { +object NullPropagation extends Rule[LogicalPlan] { private def isNullLiteral(e: Expression): Boolean = e match { case Literal(null, _) => true case _ => false @@ -423,9 +423,9 @@ case class NullPropagation(conf: SQLConf) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) => - Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone)) + Cast(Literal(0L), e.dataType, Option(SQLConf.get.sessionLocalTimeZone)) case e @ AggregateExpression(Count(exprs), _, _, _) if exprs.forall(isNullLiteral) => - Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone)) + Cast(Literal(0L), e.dataType, Option(SQLConf.get.sessionLocalTimeZone)) case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) => // This rule should be only triggered when isDistinct field is false. ae.copy(aggregateFunction = Count(Literal(1))) @@ -552,14 +552,14 @@ object FoldablePropagation extends Rule[LogicalPlan] { /** * Optimizes expressions by replacing according to CodeGen configuration. */ -case class OptimizeCodegen(conf: SQLConf) extends Rule[LogicalPlan] { +object OptimizeCodegen extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e: CaseWhen if canCodegen(e) => e.toCodegen() } private def canCodegen(e: CaseWhen): Boolean = { val numBranches = e.branches.size + e.elseValue.size - numBranches <= conf.maxCaseBranchesForCodegen + numBranches <= SQLConf.get.maxCaseBranchesForCodegen } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index bb97e2c808b9f..edbeaf273fd6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.internal.SQLConf * * If star schema detection is enabled, reorder the star join plans based on heuristics. */ -case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { +object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { /** * Join a list of plans together and push down the conditions into them. * @@ -87,8 +87,8 @@ case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHe def apply(plan: LogicalPlan): LogicalPlan = plan transform { case ExtractFiltersAndInnerJoins(input, conditions) if input.size > 2 && conditions.nonEmpty => - if (conf.starSchemaDetection && !conf.cboEnabled) { - val starJoinPlan = StarSchemaDetection(conf).reorderStarJoins(input, conditions) + if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) { + val starJoinPlan = StarSchemaDetection.reorderStarJoins(input, conditions) if (starJoinPlan.nonEmpty) { val rest = input.filterNot(starJoinPlan.contains(_)) createOrderedJoin(starJoinPlan ++ rest, conditions) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala index 2a04bd588dc1d..a313681eeb8f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -33,7 +33,7 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("Constant Folding", FixedPoint(50), - NullPropagation(conf), + NullPropagation, ConstantFolding, BooleanSimplification, SimplifyBinaryComparison, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index c6345b60b744b..56399f4831a6f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -35,7 +35,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("Constant Folding", FixedPoint(50), - NullPropagation(conf), + NullPropagation, ConstantFolding, BooleanSimplification, PruneFilters) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index ac71887c16f96..87ad81db11b64 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -32,7 +32,7 @@ class CombiningLimitsSuite extends PlanTest { Batch("Combine Limit", FixedPoint(10), CombineLimits) :: Batch("Constant Folding", FixedPoint(10), - NullPropagation(conf), + NullPropagation, ConstantFolding, BooleanSimplification, SimplifyConditionals) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 25c592b9c1dde..641c89873dcc4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -33,7 +33,7 @@ class ConstantFoldingSuite extends PlanTest { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("ConstantFolding", Once, - OptimizeIn(conf), + OptimizeIn, ConstantFolding, BooleanSimplification) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala index cc4fb3a244a98..711294ed61928 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala @@ -29,7 +29,7 @@ class DecimalAggregatesSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Decimal Optimizations", FixedPoint(100), - DecimalAggregates(conf)) :: Nil + DecimalAggregates) :: Nil } val testRelation = LocalRelation('a.decimal(2, 1), 'b.decimal(12, 1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala index d4f37e2a5e877..157472c2293f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala @@ -31,7 +31,7 @@ class EliminateMapObjectsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = { Batch("EliminateMapObjects", FixedPoint(50), - NullPropagation(conf), + NullPropagation, SimplifyCasts, EliminateMapObjects) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index a6584aa5fbba7..2f30a78f03211 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -37,7 +37,7 @@ class JoinOptimizationSuite extends PlanTest { CombineFilters, PushDownPredicate, BooleanSimplification, - ReorderJoin(conf), + ReorderJoin, PushPredicateThroughJoin, ColumnPruning, CollapseProject) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index 71db4e2e0ec4d..2fb587d50a4cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -24,25 +24,42 @@ import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{CBO_ENABLED, JOIN_REORDER_ENABLED} class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { - override val conf = new SQLConf().copy(CBO_ENABLED -> true, JOIN_REORDER_ENABLED -> true) - object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Operator Optimizations", FixedPoint(100), CombineFilters, PushDownPredicate, - ReorderJoin(conf), + ReorderJoin, PushPredicateThroughJoin, ColumnPruning, CollapseProject) :: Batch("Join Reorder", Once, - CostBasedJoinReorder(conf)) :: Nil + CostBasedJoinReorder) :: Nil + } + + var originalConfCBOEnabled = false + var originalConfJoinReorderEnabled = false + + override def beforeAll(): Unit = { + super.beforeAll() + originalConfCBOEnabled = conf.cboEnabled + originalConfJoinReorderEnabled = conf.joinReorderEnabled + conf.setConf(CBO_ENABLED, true) + conf.setConf(JOIN_REORDER_ENABLED, true) + } + + override def afterAll(): Unit = { + try { + conf.setConf(CBO_ENABLED, originalConfCBOEnabled) + conf.setConf(JOIN_REORDER_ENABLED, originalConfJoinReorderEnabled) + } finally { + super.afterAll() + } } /** Set up tables and columns for testing */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index d8302dfc9462d..f50e2e86516f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -32,7 +32,7 @@ class LimitPushdownSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubqueryAliases) :: Batch("Limit pushdown", FixedPoint(100), - LimitPushDown(conf), + LimitPushDown, CombineLimits, ConstantFolding, BooleanSimplification) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala index 9dc6738ba04b3..b71067c0af3a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules._ class OptimizeCodegenSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen(conf)) :: Nil + val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen) :: Nil } protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index d8937321ecb98..6a77580b29a21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -34,10 +34,10 @@ class OptimizeInSuite extends PlanTest { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("ConstantFolding", FixedPoint(10), - NullPropagation(conf), + NullPropagation, ConstantFolding, BooleanSimplification, - OptimizeIn(conf)) :: Nil + OptimizeIn) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -159,16 +159,20 @@ class OptimizeInSuite extends PlanTest { .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), Literal(3)))) .analyze - val notOptimizedPlan = OptimizeIn(conf)(plan) - comparePlans(notOptimizedPlan, plan) + withSQLConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "10") { + val notOptimizedPlan = OptimizeIn(plan) + comparePlans(notOptimizedPlan, plan) + } // Reduce the threshold to turning into InSet. - val optimizedPlan = OptimizeIn(conf.copy(OPTIMIZER_INSET_CONVERSION_THRESHOLD -> 2))(plan) - optimizedPlan match { - case Filter(cond, _) - if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getHSet().size == 3 => - // pass - case _ => fail("Unexpected result for OptimizedIn") + withSQLConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "2") { + val optimizedPlan = OptimizeIn(plan) + optimizedPlan match { + case Filter(cond, _) + if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getHSet().size == 3 => + // pass + case _ => fail("Unexpected result for OptimizedIn") + } } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala index a23d6266b2840..ada6e2a43ea0f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala @@ -24,28 +24,46 @@ import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBase { - override val conf = new SQLConf().copy( - CBO_ENABLED -> true, - JOIN_REORDER_ENABLED -> true, - JOIN_REORDER_DP_STAR_FILTER -> true) - object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Operator Optimizations", FixedPoint(100), CombineFilters, PushDownPredicate, - ReorderJoin(conf), + ReorderJoin, PushPredicateThroughJoin, ColumnPruning, CollapseProject) :: - Batch("Join Reorder", Once, - CostBasedJoinReorder(conf)) :: Nil + Batch("Join Reorder", Once, + CostBasedJoinReorder) :: Nil + } + + var originalConfCBOEnabled = false + var originalConfJoinReorderEnabled = false + var originalConfJoinReorderDPStarFilter = false + + override def beforeAll(): Unit = { + super.beforeAll() + originalConfCBOEnabled = conf.cboEnabled + originalConfJoinReorderEnabled = conf.joinReorderEnabled + originalConfJoinReorderDPStarFilter = conf.joinReorderDPStarFilter + conf.setConf(CBO_ENABLED, true) + conf.setConf(JOIN_REORDER_ENABLED, true) + conf.setConf(JOIN_REORDER_DP_STAR_FILTER, true) + } + + override def afterAll(): Unit = { + try { + conf.setConf(CBO_ENABLED, originalConfCBOEnabled) + conf.setConf(JOIN_REORDER_ENABLED, originalConfJoinReorderEnabled) + conf.setConf(JOIN_REORDER_DP_STAR_FILTER, originalConfJoinReorderDPStarFilter) + } finally { + super.afterAll() + } } private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala index 605c01b7220d1..777c5637201ed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala @@ -24,19 +24,36 @@ import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, STARSCHEMA_DETECTION} +import org.apache.spark.sql.internal.SQLConf._ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { - override val conf = new SQLConf().copy(CASE_SENSITIVE -> true, STARSCHEMA_DETECTION -> true) + var originalConfStarSchemaDetection = false + var originalConfCBOEnabled = true + + override def beforeAll(): Unit = { + super.beforeAll() + originalConfStarSchemaDetection = conf.starSchemaDetection + originalConfCBOEnabled = conf.cboEnabled + conf.setConf(STARSCHEMA_DETECTION, true) + conf.setConf(CBO_ENABLED, false) + } + + override def afterAll(): Unit = { + try { + conf.setConf(STARSCHEMA_DETECTION, originalConfStarSchemaDetection) + conf.setConf(CBO_ENABLED, originalConfCBOEnabled) + } finally { + super.afterAll() + } + } object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Operator Optimizations", FixedPoint(100), CombineFilters, PushDownPredicate, - ReorderJoin(conf), + ReorderJoin, PushPredicateThroughJoin, ColumnPruning, CollapseProject) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 0a18858350e1f..3634accf1ec21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -37,7 +37,7 @@ class ComplexTypesSuite extends PlanTest{ Batch("collapse projections", FixedPoint(10), CollapseProject) :: Batch("Constant Folding", FixedPoint(10), - NullPropagation(conf), + NullPropagation, ConstantFolding, BooleanSimplification, SimplifyConditionals, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index e9679d3361509..5389bf3389da4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -31,8 +31,8 @@ import org.apache.spark.sql.internal.SQLConf */ trait PlanTest extends SparkFunSuite with PredicateHelper { - // TODO(gatorsmile): remove this from PlanTest and all the analyzer/optimizer rules - protected val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true) + // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules + protected def conf = SQLConf.get /** * Since attribute references are given globally unique ids during analysis, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index 3c046ce494285..5cfad9126986b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -38,12 +38,10 @@ import org.apache.spark.sql.internal.SQLConf * 3. aggregate function on partition columns which have same result w or w/o DISTINCT keyword. * e.g. SELECT col1, Max(col2) FROM tbl GROUP BY col1. */ -case class OptimizeMetadataOnlyQuery( - catalog: SessionCatalog, - conf: SQLConf) extends Rule[LogicalPlan] { +case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - if (!conf.optimizerMetadataOnly) { + if (!SQLConf.get.optimizerMetadataOnly) { return plan } @@ -106,7 +104,7 @@ case class OptimizeMetadataOnlyQuery( val caseInsensitiveProperties = CaseInsensitiveMap(relation.tableMeta.storage.properties) val timeZoneId = caseInsensitiveProperties.get(DateTimeUtils.TIMEZONE_OPTION) - .getOrElse(conf.sessionLocalTimeZone) + .getOrElse(SQLConf.get.sessionLocalTimeZone) val partitionData = catalog.listPartitions(relation.tableMeta.identifier).map { p => InternalRow.fromSeq(partAttrs.map { attr => Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 1de4f508b89a0..00ff4c8ac310b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -22,16 +22,14 @@ import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate -import org.apache.spark.sql.internal.SQLConf class SparkOptimizer( catalog: SessionCatalog, - conf: SQLConf, experimentalMethods: ExperimentalMethods) - extends Optimizer(catalog, conf) { + extends Optimizer(catalog) { override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ - Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog, conf)) :+ + Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ postHocOptimizationBatches :+ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 9d0148117fadf..72d0ddc62303a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -208,7 +208,7 @@ abstract class BaseSessionStateBuilder( * Note: this depends on the `conf`, `catalog` and `experimentalMethods` fields. */ protected def optimizer: Optimizer = { - new SparkOptimizer(catalog, conf, experimentalMethods) { + new SparkOptimizer(catalog, experimentalMethods) { override def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = super.extendedOperatorOptimizationRules ++ customOperatorOptimizationRules } From 14a3bb3a008c302aac908d7deaf0942a98c63be7 Mon Sep 17 00:00:00 2001 From: Sumedh Wale Date: Thu, 6 Jul 2017 14:47:22 +0800 Subject: [PATCH 102/779] [SPARK-21312][SQL] correct offsetInBytes in UnsafeRow.writeToStream ## What changes were proposed in this pull request? Corrects offsetInBytes calculation in UnsafeRow.writeToStream. Known failures include writes to some DataSources that have own SparkPlan implementations and cause EXCHANGE in writes. ## How was this patch tested? Extended UnsafeRowSuite.writeToStream to include an UnsafeRow over byte array having non-zero offset. Author: Sumedh Wale Closes #18535 from sumwale/SPARK-21312. --- .../spark/sql/catalyst/expressions/UnsafeRow.java | 2 +- .../scala/org/apache/spark/sql/UnsafeRowSuite.scala | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 86de90984ca00..56994fafe064b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -550,7 +550,7 @@ public void copyFrom(UnsafeRow row) { */ public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOException { if (baseObject instanceof byte[]) { - int offsetInByteArray = (int) (Platform.BYTE_ARRAY_OFFSET - baseOffset); + int offsetInByteArray = (int) (baseOffset - Platform.BYTE_ARRAY_OFFSET); out.write((byte[]) baseObject, offsetInByteArray, sizeInBytes); } else { int dataRemaining = sizeInBytes; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index a32763db054f3..a5f904c621e6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -101,9 +101,22 @@ class UnsafeRowSuite extends SparkFunSuite { MemoryAllocator.UNSAFE.free(offheapRowPage) } } + val (bytesFromArrayBackedRowWithOffset, field0StringFromArrayBackedRowWithOffset) = { + val baos = new ByteArrayOutputStream() + val numBytes = arrayBackedUnsafeRow.getSizeInBytes + val bytesWithOffset = new Array[Byte](numBytes + 100) + System.arraycopy(arrayBackedUnsafeRow.getBaseObject.asInstanceOf[Array[Byte]], 0, + bytesWithOffset, 100, numBytes) + val arrayBackedRow = new UnsafeRow(arrayBackedUnsafeRow.numFields()) + arrayBackedRow.pointTo(bytesWithOffset, Platform.BYTE_ARRAY_OFFSET + 100, numBytes) + arrayBackedRow.writeToStream(baos, null) + (baos.toByteArray, arrayBackedRow.getString(0)) + } assert(bytesFromArrayBackedRow === bytesFromOffheapRow) assert(field0StringFromArrayBackedRow === field0StringFromOffheapRow) + assert(bytesFromArrayBackedRow === bytesFromArrayBackedRowWithOffset) + assert(field0StringFromArrayBackedRow === field0StringFromArrayBackedRowWithOffset) } test("calling getDouble() and getFloat() on null columns") { From 60043f22458668ac7ecba94fa78953f23a6bdcec Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 6 Jul 2017 00:20:26 -0700 Subject: [PATCH 103/779] [SS][MINOR] Fix flaky test in DatastreamReaderWriterSuite. temp checkpoint dir should be deleted ## What changes were proposed in this pull request? Stopping query while it is being initialized can throw interrupt exception, in which case temporary checkpoint directories will not be deleted, and the test will fail. Author: Tathagata Das Closes #18442 from tdas/DatastreamReaderWriterSuite-fix. --- .../spark/sql/streaming/test/DataStreamReaderWriterSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 3de0ae67a3892..e8a6202b8adce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -641,6 +641,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { test("temp checkpoint dir should be deleted if a query is stopped without errors") { import testImplicits._ val query = MemoryStream[Int].toDS.writeStream.format("console").start() + query.processAllAvailable() val checkpointDir = new Path( query.asInstanceOf[StreamingQueryWrapper].streamingQuery.resolvedCheckpointRoot) val fs = checkpointDir.getFileSystem(spark.sessionState.newHadoopConf()) From 5800144a54f5c0180ccf67392f32c3e8a51119b1 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 6 Jul 2017 15:32:49 +0800 Subject: [PATCH 104/779] [SPARK-21012][SUBMIT] Add glob support for resources adding to Spark Current "--jars (spark.jars)", "--files (spark.files)", "--py-files (spark.submit.pyFiles)" and "--archives (spark.yarn.dist.archives)" only support non-glob path. This is OK for most of the cases, but when user requires to add more jars, files into Spark, it is too verbose to list one by one. So here propose to add glob path support for resources. Also improving the code of downloading resources. ## How was this patch tested? UT added, also verified manually in local cluster. Author: jerryshao Closes #18235 from jerryshao/SPARK-21012. --- .../org/apache/spark/deploy/SparkSubmit.scala | 166 ++++++++++++++---- .../spark/deploy/SparkSubmitArguments.scala | 2 +- .../spark/deploy/SparkSubmitSuite.scala | 68 ++++++- docs/configuration.md | 6 +- 4 files changed, 196 insertions(+), 46 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index d13fb4193970b..abde04062c4b1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -17,17 +17,21 @@ package org.apache.spark.deploy -import java.io.{File, IOException} +import java.io._ import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException} import java.net.URL import java.nio.file.Files -import java.security.PrivilegedExceptionAction +import java.security.{KeyStore, PrivilegedExceptionAction} +import java.security.cert.X509Certificate import java.text.ParseException +import javax.net.ssl._ import scala.annotation.tailrec import scala.collection.mutable.{ArrayBuffer, HashMap, Map} import scala.util.Properties +import com.google.common.io.ByteStreams +import org.apache.commons.io.FileUtils import org.apache.commons.lang3.StringUtils import org.apache.hadoop.conf.{Configuration => HadoopConfiguration} import org.apache.hadoop.fs.{FileSystem, Path} @@ -310,33 +314,33 @@ object SparkSubmit extends CommandLineUtils { RPackageUtils.checkAndBuildRPackage(args.jars, printStream, args.verbose) } - // In client mode, download remote files. - if (deployMode == CLIENT) { - val hadoopConf = new HadoopConfiguration() - args.primaryResource = Option(args.primaryResource).map(downloadFile(_, hadoopConf)).orNull - args.jars = Option(args.jars).map(downloadFileList(_, hadoopConf)).orNull - args.pyFiles = Option(args.pyFiles).map(downloadFileList(_, hadoopConf)).orNull - args.files = Option(args.files).map(downloadFileList(_, hadoopConf)).orNull - } - - // Require all python files to be local, so we can add them to the PYTHONPATH - // In YARN cluster mode, python files are distributed as regular files, which can be non-local. - // In Mesos cluster mode, non-local python files are automatically downloaded by Mesos. - if (args.isPython && !isYarnCluster && !isMesosCluster) { - if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { - printErrorAndExit(s"Only local python files are supported: ${args.primaryResource}") + val hadoopConf = new HadoopConfiguration() + val targetDir = Files.createTempDirectory("tmp").toFile + // scalastyle:off runtimeaddshutdownhook + Runtime.getRuntime.addShutdownHook(new Thread() { + override def run(): Unit = { + FileUtils.deleteQuietly(targetDir) } - val nonLocalPyFiles = Utils.nonLocalPaths(args.pyFiles).mkString(",") - if (nonLocalPyFiles.nonEmpty) { - printErrorAndExit(s"Only local additional python files are supported: $nonLocalPyFiles") - } - } + }) + // scalastyle:on runtimeaddshutdownhook - // Require all R files to be local - if (args.isR && !isYarnCluster && !isMesosCluster) { - if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { - printErrorAndExit(s"Only local R files are supported: ${args.primaryResource}") - } + // Resolve glob path for different resources. + args.jars = Option(args.jars).map(resolveGlobPaths(_, hadoopConf)).orNull + args.files = Option(args.files).map(resolveGlobPaths(_, hadoopConf)).orNull + args.pyFiles = Option(args.pyFiles).map(resolveGlobPaths(_, hadoopConf)).orNull + args.archives = Option(args.archives).map(resolveGlobPaths(_, hadoopConf)).orNull + + // In client mode, download remote files. + if (deployMode == CLIENT) { + args.primaryResource = Option(args.primaryResource).map { + downloadFile(_, targetDir, args.sparkProperties, hadoopConf) + }.orNull + args.jars = Option(args.jars).map { + downloadFileList(_, targetDir, args.sparkProperties, hadoopConf) + }.orNull + args.pyFiles = Option(args.pyFiles).map { + downloadFileList(_, targetDir, args.sparkProperties, hadoopConf) + }.orNull } // The following modes are not supported or applicable @@ -841,36 +845,132 @@ object SparkSubmit extends CommandLineUtils { * Download a list of remote files to temp local files. If the file is local, the original file * will be returned. * @param fileList A comma separated file list. + * @param targetDir A temporary directory for which downloaded files + * @param sparkProperties Spark properties * @return A comma separated local files list. */ private[deploy] def downloadFileList( fileList: String, + targetDir: File, + sparkProperties: Map[String, String], hadoopConf: HadoopConfiguration): String = { require(fileList != null, "fileList cannot be null.") - fileList.split(",").map(downloadFile(_, hadoopConf)).mkString(",") + fileList.split(",") + .map(downloadFile(_, targetDir, sparkProperties, hadoopConf)) + .mkString(",") } /** * Download a file from the remote to a local temporary directory. If the input path points to * a local path, returns it with no operation. + * @param path A file path from where the files will be downloaded. + * @param targetDir A temporary directory for which downloaded files + * @param sparkProperties Spark properties + * @return A comma separated local files list. */ - private[deploy] def downloadFile(path: String, hadoopConf: HadoopConfiguration): String = { + private[deploy] def downloadFile( + path: String, + targetDir: File, + sparkProperties: Map[String, String], + hadoopConf: HadoopConfiguration): String = { require(path != null, "path cannot be null.") val uri = Utils.resolveURI(path) uri.getScheme match { - case "file" | "local" => - path + case "file" | "local" => path + case "http" | "https" | "ftp" => + val uc = uri.toURL.openConnection() + uc match { + case https: HttpsURLConnection => + val trustStore = sparkProperties.get("spark.ssl.fs.trustStore") + .orElse(sparkProperties.get("spark.ssl.trustStore")) + val trustStorePwd = sparkProperties.get("spark.ssl.fs.trustStorePassword") + .orElse(sparkProperties.get("spark.ssl.trustStorePassword")) + .map(_.toCharArray) + .orNull + val protocol = sparkProperties.get("spark.ssl.fs.protocol") + .orElse(sparkProperties.get("spark.ssl.protocol")) + if (protocol.isEmpty) { + printErrorAndExit("spark ssl protocol is required when enabling SSL connection.") + } + + val trustStoreManagers = trustStore.map { t => + var input: InputStream = null + try { + input = new FileInputStream(new File(t)) + val ks = KeyStore.getInstance(KeyStore.getDefaultType) + ks.load(input, trustStorePwd) + val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm) + tmf.init(ks) + tmf.getTrustManagers + } finally { + if (input != null) { + input.close() + input = null + } + } + }.getOrElse { + Array({ + new X509TrustManager { + override def getAcceptedIssuers: Array[X509Certificate] = null + override def checkClientTrusted( + x509Certificates: Array[X509Certificate], s: String) {} + override def checkServerTrusted( + x509Certificates: Array[X509Certificate], s: String) {} + }: TrustManager + }) + } + val sslContext = SSLContext.getInstance(protocol.get) + sslContext.init(null, trustStoreManagers, null) + https.setSSLSocketFactory(sslContext.getSocketFactory) + https.setHostnameVerifier(new HostnameVerifier { + override def verify(s: String, sslSession: SSLSession): Boolean = false + }) + + case _ => + } + uc.setConnectTimeout(60 * 1000) + uc.setReadTimeout(60 * 1000) + uc.connect() + val in = uc.getInputStream + val fileName = new Path(uri).getName + val tempFile = new File(targetDir, fileName) + val out = new FileOutputStream(tempFile) + // scalastyle:off println + printStream.println(s"Downloading ${uri.toString} to ${tempFile.getAbsolutePath}.") + // scalastyle:on println + try { + ByteStreams.copy(in, out) + } finally { + in.close() + out.close() + } + tempFile.toURI.toString case _ => val fs = FileSystem.get(uri, hadoopConf) - val tmpFile = new File(Files.createTempDirectory("tmp").toFile, uri.getPath) + val tmpFile = new File(targetDir, new Path(uri).getName) // scalastyle:off println printStream.println(s"Downloading ${uri.toString} to ${tmpFile.getAbsolutePath}.") // scalastyle:on println fs.copyToLocalFile(new Path(uri), new Path(tmpFile.getAbsolutePath)) - Utils.resolveURI(tmpFile.getAbsolutePath).toString + tmpFile.toURI.toString } } + + private def resolveGlobPaths(paths: String, hadoopConf: HadoopConfiguration): String = { + require(paths != null, "paths cannot be null.") + paths.split(",").map(_.trim).filter(_.nonEmpty).flatMap { path => + val uri = Utils.resolveURI(path) + uri.getScheme match { + case "local" | "http" | "https" | "ftp" => Array(path) + case _ => + val fs = FileSystem.get(uri, hadoopConf) + Option(fs.globStatus(new Path(uri))).map { status => + status.filter(_.isFile).map(_.getPath.toUri.toString) + }.getOrElse(Array(path)) + } + }.mkString(",") + } } /** Provides utility functions to be used inside SparkSubmit. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 7800d3d624e3e..fd1521193fdee 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -520,7 +520,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | (Default: client). | --class CLASS_NAME Your application's main class (for Java / Scala apps). | --name NAME A name of your application. - | --jars JARS Comma-separated list of local jars to include on the driver + | --jars JARS Comma-separated list of jars to include on the driver | and executor classpaths. | --packages Comma-separated list of maven coordinates of jars to include | on the driver and executor classpaths. Will search the local diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index b089357e7b868..97357cdbb6083 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -20,12 +20,14 @@ package org.apache.spark.deploy import java.io._ import java.net.URI import java.nio.charset.StandardCharsets +import java.nio.file.Files +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.io.Source import com.google.common.io.ByteStreams -import org.apache.commons.io.{FilenameUtils, FileUtils} +import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfterEach, Matchers} @@ -42,7 +44,6 @@ import org.apache.spark.TestUtils.JavaSourceFromString import org.apache.spark.scheduler.EventLoggingListener import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils} - trait TestPrematureExit { suite: SparkFunSuite => @@ -726,6 +727,47 @@ class SparkSubmitSuite Utils.unionFileLists(None, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar")) Utils.unionFileLists(Option("/tmp/a.jar"), None) should be (Set("/tmp/a.jar")) } + + test("support glob path") { + val tmpJarDir = Utils.createTempDir() + val jar1 = TestUtils.createJarWithFiles(Map("test.resource" -> "1"), tmpJarDir) + val jar2 = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpJarDir) + + val tmpFileDir = Utils.createTempDir() + val file1 = File.createTempFile("tmpFile1", "", tmpFileDir) + val file2 = File.createTempFile("tmpFile2", "", tmpFileDir) + + val tmpPyFileDir = Utils.createTempDir() + val pyFile1 = File.createTempFile("tmpPy1", ".py", tmpPyFileDir) + val pyFile2 = File.createTempFile("tmpPy2", ".egg", tmpPyFileDir) + + val tmpArchiveDir = Utils.createTempDir() + val archive1 = File.createTempFile("archive1", ".zip", tmpArchiveDir) + val archive2 = File.createTempFile("archive2", ".zip", tmpArchiveDir) + + val args = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--jars", s"${tmpJarDir.getAbsolutePath}/*.jar", + "--files", s"${tmpFileDir.getAbsolutePath}/tmpFile*", + "--py-files", s"${tmpPyFileDir.getAbsolutePath}/tmpPy*", + "--archives", s"${tmpArchiveDir.getAbsolutePath}/*.zip", + jar2.toString) + + val appArgs = new SparkSubmitArguments(args) + val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 + sysProps("spark.yarn.dist.jars").split(",").toSet should be + (Set(jar1.toURI.toString, jar2.toURI.toString)) + sysProps("spark.yarn.dist.files").split(",").toSet should be + (Set(file1.toURI.toString, file2.toURI.toString)) + sysProps("spark.submit.pyFiles").split(",").toSet should be + (Set(pyFile1.getAbsolutePath, pyFile2.getAbsolutePath)) + sysProps("spark.yarn.dist.archives").split(",").toSet should be + (Set(archive1.toURI.toString, archive2.toURI.toString)) + } + // scalastyle:on println private def checkDownloadedFile(sourcePath: String, outputPath: String): Unit = { @@ -738,7 +780,7 @@ class SparkSubmitSuite assert(outputUri.getScheme === "file") // The path and filename are preserved. - assert(outputUri.getPath.endsWith(sourceUri.getPath)) + assert(outputUri.getPath.endsWith(new Path(sourceUri).getName)) assert(FileUtils.readFileToString(new File(outputUri.getPath)) === FileUtils.readFileToString(new File(sourceUri.getPath))) } @@ -752,25 +794,29 @@ class SparkSubmitSuite test("downloadFile - invalid url") { intercept[IOException] { - SparkSubmit.downloadFile("abc:/my/file", new Configuration()) + SparkSubmit.downloadFile( + "abc:/my/file", Utils.createTempDir(), mutable.Map.empty, new Configuration()) } } test("downloadFile - file doesn't exist") { val hadoopConf = new Configuration() + val tmpDir = Utils.createTempDir() // Set s3a implementation to local file system for testing. hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") // Disable file system impl cache to make sure the test file system is picked up. hadoopConf.set("fs.s3a.impl.disable.cache", "true") intercept[FileNotFoundException] { - SparkSubmit.downloadFile("s3a:/no/such/file", hadoopConf) + SparkSubmit.downloadFile("s3a:/no/such/file", tmpDir, mutable.Map.empty, hadoopConf) } } test("downloadFile does not download local file") { // empty path is considered as local file. - assert(SparkSubmit.downloadFile("", new Configuration()) === "") - assert(SparkSubmit.downloadFile("/local/file", new Configuration()) === "/local/file") + val tmpDir = Files.createTempDirectory("tmp").toFile + assert(SparkSubmit.downloadFile("", tmpDir, mutable.Map.empty, new Configuration()) === "") + assert(SparkSubmit.downloadFile("/local/file", tmpDir, mutable.Map.empty, + new Configuration()) === "/local/file") } test("download one file to local") { @@ -779,12 +825,14 @@ class SparkSubmitSuite val content = "hello, world" FileUtils.write(jarFile, content) val hadoopConf = new Configuration() + val tmpDir = Files.createTempDirectory("tmp").toFile // Set s3a implementation to local file system for testing. hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") // Disable file system impl cache to make sure the test file system is picked up. hadoopConf.set("fs.s3a.impl.disable.cache", "true") val sourcePath = s"s3a://${jarFile.getAbsolutePath}" - val outputPath = SparkSubmit.downloadFile(sourcePath, hadoopConf) + val outputPath = + SparkSubmit.downloadFile(sourcePath, tmpDir, mutable.Map.empty, hadoopConf) checkDownloadedFile(sourcePath, outputPath) deleteTempOutputFile(outputPath) } @@ -795,12 +843,14 @@ class SparkSubmitSuite val content = "hello, world" FileUtils.write(jarFile, content) val hadoopConf = new Configuration() + val tmpDir = Files.createTempDirectory("tmp").toFile // Set s3a implementation to local file system for testing. hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") // Disable file system impl cache to make sure the test file system is picked up. hadoopConf.set("fs.s3a.impl.disable.cache", "true") val sourcePaths = Seq("/local/file", s"s3a://${jarFile.getAbsolutePath}") - val outputPaths = SparkSubmit.downloadFileList(sourcePaths.mkString(","), hadoopConf).split(",") + val outputPaths = SparkSubmit.downloadFileList( + sourcePaths.mkString(","), tmpDir, mutable.Map.empty, hadoopConf).split(",") assert(outputPaths.length === sourcePaths.length) sourcePaths.zip(outputPaths).foreach { case (sourcePath, outputPath) => diff --git a/docs/configuration.md b/docs/configuration.md index c785a664c67b1..7dc23e441a7ba 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -422,21 +422,21 @@ Apart from these, the following properties are also available, and may be useful spark.files - Comma-separated list of files to be placed in the working directory of each executor. + Comma-separated list of files to be placed in the working directory of each executor. Globs are allowed. spark.submit.pyFiles - Comma-separated list of .zip, .egg, or .py files to place on the PYTHONPATH for Python apps. + Comma-separated list of .zip, .egg, or .py files to place on the PYTHONPATH for Python apps. Globs are allowed. spark.jars - Comma-separated list of local jars to include on the driver and executor classpaths. + Comma-separated list of jars to include on the driver and executor classpaths. Globs are allowed. From 6ff05a66fe83e721063efe5c28d2ffeb850fecc7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 6 Jul 2017 15:47:09 +0800 Subject: [PATCH 105/779] [SPARK-20703][SQL] Associate metrics with data writes onto DataFrameWriter operations ## What changes were proposed in this pull request? Right now in the UI, after SPARK-20213, we can show the operations to write data out. However, there is no way to associate metrics with data writes. We should show relative metrics on the operations. #### Supported commands This change supports updating metrics for file-based data writing operations, including `InsertIntoHadoopFsRelationCommand`, `InsertIntoHiveTable`. Supported metrics: * number of written files * number of dynamic partitions * total bytes of written data * total number of output rows * average writing data out time (ms) * (TODO) min/med/max number of output rows per file/partition * (TODO) min/med/max bytes of written data per file/partition #### Commands not supported `InsertIntoDataSourceCommand`, `SaveIntoDataSourceCommand`: The two commands uses DataSource APIs to write data out, i.e., the logic of writing data out is delegated to the DataSource implementations, such as `InsertableRelation.insert` and `CreatableRelationProvider.createRelation`. So we can't obtain metrics from delegated methods for now. `CreateHiveTableAsSelectCommand`, `CreateDataSourceTableAsSelectCommand` : The two commands invokes other commands to write data out. The invoked commands can even write to non file-based data source. We leave them as future TODO. #### How to update metrics of writing files out A `RunnableCommand` which wants to update metrics, needs to override its `metrics` and provide the metrics data structure to `ExecutedCommandExec`. The metrics are prepared during the execution of `FileFormatWriter`. The callback function passed to `FileFormatWriter` will accept the metrics and update accordingly. There is a metrics updating function in `RunnableCommand`. In runtime, the function will be bound to the spark context and `metrics` of `ExecutedCommandExec` and pass to `FileFormatWriter`. ## How was this patch tested? Updated unit tests. Author: Liang-Chi Hsieh Closes #18159 from viirya/SPARK-20703-2. --- .../scala/org/apache/spark/util/Utils.scala | 9 ++ .../command/DataWritingCommand.scala | 75 ++++++++++ .../sql/execution/command/commands.scala | 12 ++ .../datasources/FileFormatWriter.scala | 121 ++++++++++++--- .../InsertIntoHadoopFsRelationCommand.scala | 18 ++- .../sql/sources/PartitionedWriteSuite.scala | 21 +-- .../hive/execution/InsertIntoHiveTable.scala | 8 +- .../sql/hive/execution/SQLMetricsSuite.scala | 139 ++++++++++++++++++ 8 files changed, 362 insertions(+), 41 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 26f61e25da4d3..b4caf68f0afaa 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1002,6 +1002,15 @@ private[spark] object Utils extends Logging { } } + /** + * Lists files recursively. + */ + def recursiveList(f: File): Array[File] = { + require(f.isDirectory) + val current = f.listFiles + current ++ current.filter(_.isDirectory).flatMap(recursiveList) + } + /** * Delete a file or directory and its contents recursively. * Don't follow directories if they are symlinks. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala new file mode 100644 index 0000000000000..0c381a2c02986 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.SparkContext +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.datasources.ExecutedWriteSummary +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} + +/** + * A special `RunnableCommand` which writes data out and updates metrics. + */ +trait DataWritingCommand extends RunnableCommand { + + override lazy val metrics: Map[String, SQLMetric] = { + val sparkContext = SparkContext.getActive.get + Map( + "avgTime" -> SQLMetrics.createMetric(sparkContext, "average writing time (ms)"), + "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"), + "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "bytes of written output"), + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numParts" -> SQLMetrics.createMetric(sparkContext, "number of dynamic part") + ) + } + + /** + * Callback function that update metrics collected from the writing operation. + */ + protected def updateWritingMetrics(writeSummaries: Seq[ExecutedWriteSummary]): Unit = { + val sparkContext = SparkContext.getActive.get + var numPartitions = 0 + var numFiles = 0 + var totalNumBytes: Long = 0L + var totalNumOutput: Long = 0L + var totalWritingTime: Long = 0L + + writeSummaries.foreach { summary => + numPartitions += summary.updatedPartitions.size + numFiles += summary.numOutputFile + totalNumBytes += summary.numOutputBytes + totalNumOutput += summary.numOutputRows + totalWritingTime += summary.totalWritingTime + } + + val avgWritingTime = if (numFiles > 0) { + (totalWritingTime / numFiles).toLong + } else { + 0L + } + + metrics("avgTime").add(avgWritingTime) + metrics("numFiles").add(numFiles) + metrics("numOutputBytes").add(totalNumBytes) + metrics("numOutputRows").add(totalNumOutput) + metrics("numParts").add(numPartitions) + + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toList) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 81bc93e7ebcf4..7cd4baef89e75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.{logical, QueryPlan} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.debug._ +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types._ @@ -37,6 +38,11 @@ import org.apache.spark.sql.types._ * wrapped in `ExecutedCommand` during execution. */ trait RunnableCommand extends logical.Command { + + // The map used to record the metrics of running the command. This will be passed to + // `ExecutedCommand` during query planning. + lazy val metrics: Map[String, SQLMetric] = Map.empty + def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { throw new NotImplementedError } @@ -49,8 +55,14 @@ trait RunnableCommand extends logical.Command { /** * A physical operator that executes the run method of a `RunnableCommand` and * saves the result to prevent multiple executions. + * + * @param cmd the `RunnableCommand` this operator will run. + * @param children the children physical plans ran by the `RunnableCommand`. */ case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) extends SparkPlan { + + override lazy val metrics: Map[String, SQLMetric] = cmd.metrics + /** * A concrete command should override this lazy field to wrap up any side effects caused by the * command or any other computation that should be evaluated exactly once. The value of this field diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 0daffa93b4747..64866630623ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -22,7 +22,7 @@ import java.util.{Date, UUID} import scala.collection.mutable import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl @@ -82,7 +82,7 @@ object FileFormatWriter extends Logging { } /** The result of a successful write task. */ - private case class WriteTaskResult(commitMsg: TaskCommitMessage, updatedPartitions: Set[String]) + private case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary) /** * Basic work flow of this command is: @@ -104,7 +104,7 @@ object FileFormatWriter extends Logging { hadoopConf: Configuration, partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], - refreshFunction: (Seq[TablePartitionSpec]) => Unit, + refreshFunction: (Seq[ExecutedWriteSummary]) => Unit, options: Map[String, String]): Unit = { val job = Job.getInstance(hadoopConf) @@ -196,12 +196,10 @@ object FileFormatWriter extends Logging { }) val commitMsgs = ret.map(_.commitMsg) - val updatedPartitions = ret.flatMap(_.updatedPartitions) - .distinct.map(PartitioningUtils.parsePathFragment) committer.commitJob(job, commitMsgs) logInfo(s"Job ${job.getJobID} committed.") - refreshFunction(updatedPartitions) + refreshFunction(ret.map(_.summary)) } catch { case cause: Throwable => logError(s"Aborting job ${job.getJobID}.", cause) committer.abortJob(job) @@ -247,9 +245,9 @@ object FileFormatWriter extends Logging { try { Utils.tryWithSafeFinallyAndFailureCallbacks(block = { // Execute the task to write rows out and commit the task. - val outputPartitions = writeTask.execute(iterator) + val summary = writeTask.execute(iterator) writeTask.releaseResources() - WriteTaskResult(committer.commitTask(taskAttemptContext), outputPartitions) + WriteTaskResult(committer.commitTask(taskAttemptContext), summary) })(catchBlock = { // If there is an error, release resource and then abort the task try { @@ -273,12 +271,36 @@ object FileFormatWriter extends Logging { * automatically trigger task aborts. */ private trait ExecuteWriteTask { + /** - * Writes data out to files, and then returns the list of partition strings written out. - * The list of partitions is sent back to the driver and used to update the catalog. + * The data structures used to measure metrics during writing. */ - def execute(iterator: Iterator[InternalRow]): Set[String] + protected var totalWritingTime: Long = 0L + protected var timeOnCurrentFile: Long = 0L + protected var numOutputRows: Long = 0L + protected var numOutputBytes: Long = 0L + + /** + * Writes data out to files, and then returns the summary of relative information which + * includes the list of partition strings written out. The list of partitions is sent back + * to the driver and used to update the catalog. Other information will be sent back to the + * driver too and used to update the metrics in UI. + */ + def execute(iterator: Iterator[InternalRow]): ExecutedWriteSummary def releaseResources(): Unit + + /** + * A helper function used to determine the size in bytes of a written file. + */ + protected def getFileSize(conf: Configuration, filePath: String): Long = { + if (filePath != null) { + val path = new Path(filePath) + val fs = path.getFileSystem(conf) + fs.getFileStatus(path).getLen() + } else { + 0L + } + } } /** Writes data to a single directory (used for non-dynamic-partition writes). */ @@ -288,24 +310,26 @@ object FileFormatWriter extends Logging { committer: FileCommitProtocol) extends ExecuteWriteTask { private[this] var currentWriter: OutputWriter = _ + private[this] var currentPath: String = _ private def newOutputWriter(fileCounter: Int): Unit = { val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext) - val tmpFilePath = committer.newTaskTempFile( + currentPath = committer.newTaskTempFile( taskAttemptContext, None, f"-c$fileCounter%03d" + ext) currentWriter = description.outputWriterFactory.newInstance( - path = tmpFilePath, + path = currentPath, dataSchema = description.dataColumns.toStructType, context = taskAttemptContext) } - override def execute(iter: Iterator[InternalRow]): Set[String] = { + override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { var fileCounter = 0 var recordsInFile: Long = 0L newOutputWriter(fileCounter) + while (iter.hasNext) { if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { fileCounter += 1 @@ -314,21 +338,35 @@ object FileFormatWriter extends Logging { recordsInFile = 0 releaseResources() + numOutputRows += recordsInFile newOutputWriter(fileCounter) } val internalRow = iter.next() + val startTime = System.nanoTime() currentWriter.write(internalRow) + timeOnCurrentFile += (System.nanoTime() - startTime) recordsInFile += 1 } releaseResources() - Set.empty + numOutputRows += recordsInFile + + ExecutedWriteSummary( + updatedPartitions = Set.empty, + numOutputFile = fileCounter + 1, + numOutputBytes = numOutputBytes, + numOutputRows = numOutputRows, + totalWritingTime = totalWritingTime) } override def releaseResources(): Unit = { if (currentWriter != null) { try { + val startTime = System.nanoTime() currentWriter.close() + totalWritingTime += (timeOnCurrentFile + System.nanoTime() - startTime) / 1000 / 1000 + timeOnCurrentFile = 0 + numOutputBytes += getFileSize(taskAttemptContext.getConfiguration, currentPath) } finally { currentWriter = null } @@ -348,6 +386,8 @@ object FileFormatWriter extends Logging { // currentWriter is initialized whenever we see a new key private var currentWriter: OutputWriter = _ + private var currentPath: String = _ + /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */ private def partitionPathExpression: Seq[Expression] = { desc.partitionColumns.zipWithIndex.flatMap { case (c, i) => @@ -403,19 +443,19 @@ object FileFormatWriter extends Logging { case _ => None } - val path = if (customPath.isDefined) { + currentPath = if (customPath.isDefined) { committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) } else { committer.newTaskTempFile(taskAttemptContext, partDir, ext) } currentWriter = desc.outputWriterFactory.newInstance( - path = path, + path = currentPath, dataSchema = desc.dataColumns.toStructType, context = taskAttemptContext) } - override def execute(iter: Iterator[InternalRow]): Set[String] = { + override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { val getPartitionColsAndBucketId = UnsafeProjection.create( desc.partitionColumns ++ desc.bucketIdExpression, desc.allColumns) @@ -429,15 +469,22 @@ object FileFormatWriter extends Logging { // If anything below fails, we should abort the task. var recordsInFile: Long = 0L var fileCounter = 0 + var totalFileCounter = 0 var currentPartColsAndBucketId: UnsafeRow = null val updatedPartitions = mutable.Set[String]() + for (row <- iter) { val nextPartColsAndBucketId = getPartitionColsAndBucketId(row) if (currentPartColsAndBucketId != nextPartColsAndBucketId) { + if (currentPartColsAndBucketId != null) { + totalFileCounter += (fileCounter + 1) + } + // See a new partition or bucket - write to a new partition dir (or a new bucket file). currentPartColsAndBucketId = nextPartColsAndBucketId.copy() logDebug(s"Writing partition: $currentPartColsAndBucketId") + numOutputRows += recordsInFile recordsInFile = 0 fileCounter = 0 @@ -447,6 +494,8 @@ object FileFormatWriter extends Logging { recordsInFile >= desc.maxRecordsPerFile) { // Exceeded the threshold in terms of the number of records per file. // Create a new file by increasing the file counter. + + numOutputRows += recordsInFile recordsInFile = 0 fileCounter += 1 assert(fileCounter < MAX_FILE_COUNTER, @@ -455,18 +504,33 @@ object FileFormatWriter extends Logging { releaseResources() newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions) } - + val startTime = System.nanoTime() currentWriter.write(getOutputRow(row)) + timeOnCurrentFile += (System.nanoTime() - startTime) recordsInFile += 1 } + if (currentPartColsAndBucketId != null) { + totalFileCounter += (fileCounter + 1) + } releaseResources() - updatedPartitions.toSet + numOutputRows += recordsInFile + + ExecutedWriteSummary( + updatedPartitions = updatedPartitions.toSet, + numOutputFile = totalFileCounter, + numOutputBytes = numOutputBytes, + numOutputRows = numOutputRows, + totalWritingTime = totalWritingTime) } override def releaseResources(): Unit = { if (currentWriter != null) { try { + val startTime = System.nanoTime() currentWriter.close() + totalWritingTime += (timeOnCurrentFile + System.nanoTime() - startTime) / 1000 / 1000 + timeOnCurrentFile = 0 + numOutputBytes += getFileSize(taskAttemptContext.getConfiguration, currentPath) } finally { currentWriter = null } @@ -474,3 +538,20 @@ object FileFormatWriter extends Logging { } } } + +/** + * Wrapper class for the metrics of writing data out. + * + * @param updatedPartitions the partitions updated during writing data out. Only valid + * for dynamic partition. + * @param numOutputFile the total number of files. + * @param numOutputRows the number of output rows. + * @param numOutputBytes the bytes of output data. + * @param totalWritingTime the total writing time in ms. + */ +case class ExecutedWriteSummary( + updatedPartitions: Set[String], + numOutputFile: Int, + numOutputRows: Long, + numOutputBytes: Long, + totalWritingTime: Long) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index ab26f2affbce5..0031567d3d288 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -21,6 +21,7 @@ import java.io.IOException import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.SparkContext import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition} @@ -29,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} /** * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. @@ -53,7 +55,7 @@ case class InsertIntoHadoopFsRelationCommand( mode: SaveMode, catalogTable: Option[CatalogTable], fileIndex: Option[FileIndex]) - extends RunnableCommand { + extends DataWritingCommand { import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName override def children: Seq[LogicalPlan] = query :: Nil @@ -123,8 +125,16 @@ case class InsertIntoHadoopFsRelationCommand( if (doInsertion) { - // Callback for updating metastore partition metadata after the insertion job completes. - def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = { + // Callback for updating metric and metastore partition metadata + // after the insertion job completes. + def refreshCallback(summary: Seq[ExecutedWriteSummary]): Unit = { + val updatedPartitions = summary.flatMap(_.updatedPartitions) + .distinct.map(PartitioningUtils.parsePathFragment) + + // Updating metrics. + updateWritingMetrics(summary) + + // Updating metastore partition metadata. if (partitionsTrackedByCatalog) { val newPartitions = updatedPartitions.toSet -- initialMatchingPartitions if (newPartitions.nonEmpty) { @@ -154,7 +164,7 @@ case class InsertIntoHadoopFsRelationCommand( hadoopConf = hadoopConf, partitionColumns = partitionColumns, bucketSpec = bucketSpec, - refreshFunction = refreshPartitionsCallback, + refreshFunction = refreshCallback, options = options) // refresh cached files in FileIndex diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index a2f3afe3ce236..6f998aa60faf5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -91,15 +91,15 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempDir { f => spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", 1).mode("overwrite").parquet(f.getAbsolutePath) - assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) + assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", 2).mode("overwrite").parquet(f.getAbsolutePath) - assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2) + assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2) spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", -1).mode("overwrite").parquet(f.getAbsolutePath) - assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1) + assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1) } } @@ -111,7 +111,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { .option("maxRecordsPerFile", 1) .mode("overwrite") .parquet(f.getAbsolutePath) - assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) + assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) } } @@ -138,14 +138,14 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val df = Seq((1, ts)).toDF("i", "ts") withTempPath { f => df.write.partitionBy("ts").parquet(f.getAbsolutePath) - val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) checkPartitionValues(files.head, "2016-12-01 00:00:00") } withTempPath { f => df.write.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .partitionBy("ts").parquet(f.getAbsolutePath) - val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) // use timeZone option "GMT" to format partition value. checkPartitionValues(files.head, "2016-12-01 08:00:00") @@ -153,18 +153,11 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempPath { f => withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { df.write.partitionBy("ts").parquet(f.getAbsolutePath) - val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) // if there isn't timeZone option, then use session local timezone. checkPartitionValues(files.head, "2016-12-01 08:00:00") } } } - - /** Lists files recursively. */ - private def recursiveList(f: File): Array[File] = { - require(f.isDirectory) - val current = f.listFiles - current ++ current.filter(_.isDirectory).flatMap(recursiveList) - } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 223d375232393..cd263e8b6df8e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -31,14 +31,16 @@ import org.apache.hadoop.hive.ql.exec.TaskRunner import org.apache.hadoop.hive.ql.ErrorMsg import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.spark.SparkContext import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.command.{CommandUtils, RunnableCommand} +import org.apache.spark.sql.execution.command.{CommandUtils, DataWritingCommand} import org.apache.spark.sql.execution.datasources.FileFormatWriter +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive.client.{HiveClientImpl, HiveVersion} @@ -80,7 +82,7 @@ case class InsertIntoHiveTable( partition: Map[String, Option[String]], query: LogicalPlan, overwrite: Boolean, - ifPartitionNotExists: Boolean) extends RunnableCommand { + ifPartitionNotExists: Boolean) extends DataWritingCommand { override def children: Seq[LogicalPlan] = query :: Nil @@ -354,7 +356,7 @@ case class InsertIntoHiveTable( hadoopConf = hadoopConf, partitionColumns = partitionAttributes, bucketSpec = None, - refreshFunction = _ => (), + refreshFunction = updateWritingMetrics, options = Map.empty) if (partition.nonEmpty) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala new file mode 100644 index 0000000000000..1ef1988d4c605 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import java.io.File + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.Utils + +class SQLMetricsSuite extends SQLTestUtils with TestHiveSingleton { + import spark.implicits._ + + /** + * Get execution metrics for the SQL execution and verify metrics values. + * + * @param metricsValues the expected metric values (numFiles, numPartitions, numOutputRows). + * @param func the function can produce execution id after running. + */ + private def verifyWriteDataMetrics(metricsValues: Seq[Int])(func: => Unit): Unit = { + val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet + // Run the given function to trigger query execution. + func + spark.sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = + spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size == 1) + val executionId = executionIds.head + + val executionData = spark.sharedState.listener.getExecution(executionId).get + val executedNode = executionData.physicalPlanGraph.nodes.head + + val metricsNames = Seq( + "number of written files", + "number of dynamic part", + "number of output rows") + + val metrics = spark.sharedState.listener.getExecutionMetrics(executionId) + + metricsNames.zip(metricsValues).foreach { case (metricsName, expected) => + val sqlMetric = executedNode.metrics.find(_.name == metricsName) + assert(sqlMetric.isDefined) + val accumulatorId = sqlMetric.get.accumulatorId + val metricValue = metrics(accumulatorId).replaceAll(",", "").toInt + assert(metricValue == expected) + } + + val totalNumBytesMetric = executedNode.metrics.find(_.name == "bytes of written output").get + val totalNumBytes = metrics(totalNumBytesMetric.accumulatorId).replaceAll(",", "").toInt + assert(totalNumBytes > 0) + val writingTimeMetric = executedNode.metrics.find(_.name == "average writing time (ms)").get + val writingTime = metrics(writingTimeMetric.accumulatorId).replaceAll(",", "").toInt + assert(writingTime >= 0) + } + + private def testMetricsNonDynamicPartition( + dataFormat: String, + tableName: String): Unit = { + withTable(tableName) { + Seq((1, 2)).toDF("i", "j") + .write.format(dataFormat).mode("overwrite").saveAsTable(tableName) + + val tableLocation = + new File(spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).location) + + // 2 files, 100 rows, 0 dynamic partition. + verifyWriteDataMetrics(Seq(2, 0, 100)) { + (0 until 100).map(i => (i, i + 1)).toDF("i", "j").repartition(2) + .write.format(dataFormat).mode("overwrite").insertInto(tableName) + } + assert(Utils.recursiveList(tableLocation).count(_.getName.startsWith("part-")) == 2) + } + } + + private def testMetricsDynamicPartition( + provider: String, + dataFormat: String, + tableName: String): Unit = { + withTempPath { dir => + spark.sql( + s""" + |CREATE TABLE $tableName(a int, b int) + |USING $provider + |PARTITIONED BY(a) + |LOCATION '${dir.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + val df = spark.range(start = 0, end = 40, step = 1, numPartitions = 1) + .selectExpr("id a", "id b") + + // 40 files, 80 rows, 40 dynamic partitions. + verifyWriteDataMetrics(Seq(40, 40, 80)) { + df.union(df).repartition(2, $"a") + .write + .format(dataFormat) + .mode("overwrite") + .insertInto(tableName) + } + assert(Utils.recursiveList(dir).count(_.getName.startsWith("part-")) == 40) + } + } + + test("writing data out metrics: parquet") { + testMetricsNonDynamicPartition("parquet", "t1") + } + + test("writing data out metrics with dynamic partition: parquet") { + testMetricsDynamicPartition("parquet", "parquet", "t1") + } + + test("writing data out metrics: hive") { + testMetricsNonDynamicPartition("hive", "t1") + } + + test("writing data out metrics dynamic partition: hive") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + testMetricsDynamicPartition("hive", "hive", "t1") + } + } +} From b8e4d567a7d6c2ff277700d4e7707e57e87c7808 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Thu, 6 Jul 2017 16:00:31 +0800 Subject: [PATCH 106/779] [SPARK-21324][TEST] Improve statistics test suites ## What changes were proposed in this pull request? 1. move `StatisticsCollectionTestBase` to a separate file. 2. move some test cases to `StatisticsCollectionSuite` so that `hive/StatisticsSuite` only keeps tests that need hive support. 3. clear up some test cases. ## How was this patch tested? Existing tests. Author: wangzhenhua Author: Zhenhua Wang Closes #18545 from wzhfy/cleanStatSuites. --- .../spark/sql/StatisticsCollectionSuite.scala | 193 +++--------------- .../sql/StatisticsCollectionTestBase.scala | 192 +++++++++++++++++ .../spark/sql/hive/StatisticsSuite.scala | 124 +++-------- 3 files changed, 258 insertions(+), 251 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index d9392de37a815..843ced7f0e697 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -17,19 +17,12 @@ package org.apache.spark.sql -import java.{lang => jl} -import java.sql.{Date, Timestamp} - import scala.collection.mutable -import scala.util.Random import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} -import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.ArrayData import org.apache.spark.sql.types._ @@ -58,6 +51,37 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } + test("analyzing views is not supported") { + def assertAnalyzeUnsupported(analyzeCommand: String): Unit = { + val err = intercept[AnalysisException] { + sql(analyzeCommand) + } + assert(err.message.contains("ANALYZE TABLE is not supported")) + } + + val tableName = "tbl" + withTable(tableName) { + spark.range(10).write.saveAsTable(tableName) + val viewName = "view" + withView(viewName) { + sql(s"CREATE VIEW $viewName AS SELECT * FROM $tableName") + assertAnalyzeUnsupported(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") + assertAnalyzeUnsupported(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") + } + } + } + + test("statistics collection of a table with zero column") { + val table_no_cols = "table_no_cols" + withTable(table_no_cols) { + val rddNoCols = sparkContext.parallelize(1 to 10).map(_ => Row.empty) + val dfNoCols = spark.createDataFrame(rddNoCols, StructType(Seq.empty)) + dfNoCols.write.format("json").saveAsTable(table_no_cols) + sql(s"ANALYZE TABLE $table_no_cols COMPUTE STATISTICS") + checkTableStats(table_no_cols, hasSizeInBytes = true, expectedRowCounts = Some(10)) + } + } + test("analyze column command - unsupported types and invalid columns") { val tableName = "column_stats_test1" withTable(tableName) { @@ -239,154 +263,3 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } - - -/** - * The base for test cases that we want to include in both the hive module (for verifying behavior - * when using the Hive external catalog) as well as in the sql/core module. - */ -abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils { - import testImplicits._ - - private val dec1 = new java.math.BigDecimal("1.000000000000000000") - private val dec2 = new java.math.BigDecimal("8.000000000000000000") - private val d1 = Date.valueOf("2016-05-08") - private val d2 = Date.valueOf("2016-05-09") - private val t1 = Timestamp.valueOf("2016-05-08 00:00:01") - private val t2 = Timestamp.valueOf("2016-05-09 00:00:02") - - /** - * Define a very simple 3 row table used for testing column serialization. - * Note: last column is seq[int] which doesn't support stats collection. - */ - protected val data = Seq[ - (jl.Boolean, jl.Byte, jl.Short, jl.Integer, jl.Long, - jl.Double, jl.Float, java.math.BigDecimal, - String, Array[Byte], Date, Timestamp, - Seq[Int])]( - (false, 1.toByte, 1.toShort, 1, 1L, 1.0, 1.0f, dec1, "s1", "b1".getBytes, d1, t1, null), - (true, 2.toByte, 3.toShort, 4, 5L, 6.0, 7.0f, dec2, "ss9", "bb0".getBytes, d2, t2, null), - (null, null, null, null, null, null, null, null, null, null, null, null, null) - ) - - /** A mapping from column to the stats collected. */ - protected val stats = mutable.LinkedHashMap( - "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1), - "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1), - "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2), - "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4), - "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8), - "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8), - "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4), - "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16), - "cstring" -> ColumnStat(2, None, None, 1, 3, 3), - "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), - "cdate" -> ColumnStat(2, Some(DateTimeUtils.fromJavaDate(d1)), - Some(DateTimeUtils.fromJavaDate(d2)), 1, 4, 4), - "ctimestamp" -> ColumnStat(2, Some(DateTimeUtils.fromJavaTimestamp(t1)), - Some(DateTimeUtils.fromJavaTimestamp(t2)), 1, 8, 8) - ) - - private val randomName = new Random(31) - - def checkTableStats( - tableName: String, - hasSizeInBytes: Boolean, - expectedRowCounts: Option[Int]): Option[CatalogStatistics] = { - val stats = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).stats - if (hasSizeInBytes || expectedRowCounts.nonEmpty) { - assert(stats.isDefined) - assert(stats.get.sizeInBytes >= 0) - assert(stats.get.rowCount === expectedRowCounts) - } else { - assert(stats.isEmpty) - } - - stats - } - - /** - * Compute column stats for the given DataFrame and compare it with colStats. - */ - def checkColStats( - df: DataFrame, - colStats: mutable.LinkedHashMap[String, ColumnStat]): Unit = { - val tableName = "column_stats_test_" + randomName.nextInt(1000) - withTable(tableName) { - df.write.saveAsTable(tableName) - - // Collect statistics - sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + - colStats.keys.mkString(", ")) - - // Validate statistics - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) - assert(table.stats.isDefined) - assert(table.stats.get.colStats.size == colStats.size) - - colStats.foreach { case (k, v) => - withClue(s"column $k") { - assert(table.stats.get.colStats(k) == v) - } - } - } - } - - // This test will be run twice: with and without Hive support - test("SPARK-18856: non-empty partitioned table should not report zero size") { - withTable("ds_tbl", "hive_tbl") { - spark.range(100).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("ds_tbl") - val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats - assert(stats.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") - - if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { - sql("CREATE TABLE hive_tbl(i int) PARTITIONED BY (j int)") - sql("INSERT INTO hive_tbl PARTITION(j=1) SELECT 1") - val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats - assert(stats2.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") - } - } - } - - // This test will be run twice: with and without Hive support - test("conversion from CatalogStatistics to Statistics") { - withTable("ds_tbl", "hive_tbl") { - // Test data source table - checkStatsConversion(tableName = "ds_tbl", isDatasourceTable = true) - // Test hive serde table - if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { - checkStatsConversion(tableName = "hive_tbl", isDatasourceTable = false) - } - } - } - - private def checkStatsConversion(tableName: String, isDatasourceTable: Boolean): Unit = { - // Create an empty table and run analyze command on it. - val createTableSql = if (isDatasourceTable) { - s"CREATE TABLE $tableName (c1 INT, c2 STRING) USING PARQUET" - } else { - s"CREATE TABLE $tableName (c1 INT, c2 STRING)" - } - sql(createTableSql) - // Analyze only one column. - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1") - val (relation, catalogTable) = spark.table(tableName).queryExecution.analyzed.collect { - case catalogRel: CatalogRelation => (catalogRel, catalogRel.tableMeta) - case logicalRel: LogicalRelation => (logicalRel, logicalRel.catalogTable.get) - }.head - val emptyColStat = ColumnStat(0, None, None, 0, 4, 4) - // Check catalog statistics - assert(catalogTable.stats.isDefined) - assert(catalogTable.stats.get.sizeInBytes == 0) - assert(catalogTable.stats.get.rowCount == Some(0)) - assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat)) - - // Check relation statistics - assert(relation.stats.sizeInBytes == 0) - assert(relation.stats.rowCount == Some(0)) - assert(relation.stats.attributeStats.size == 1) - val (attribute, colStat) = relation.stats.attributeStats.head - assert(attribute.name == "c1") - assert(colStat == emptyColStat) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala new file mode 100644 index 0000000000000..41569762d3c59 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.{lang => jl} +import java.sql.{Date, Timestamp} + +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics, CatalogTable} +import org.apache.spark.sql.catalyst.plans.logical.ColumnStat +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.internal.StaticSQLConf +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.Decimal + + +/** + * The base for statistics test cases that we want to include in both the hive module (for + * verifying behavior when using the Hive external catalog) as well as in the sql/core module. + */ +abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils { + import testImplicits._ + + private val dec1 = new java.math.BigDecimal("1.000000000000000000") + private val dec2 = new java.math.BigDecimal("8.000000000000000000") + private val d1 = Date.valueOf("2016-05-08") + private val d2 = Date.valueOf("2016-05-09") + private val t1 = Timestamp.valueOf("2016-05-08 00:00:01") + private val t2 = Timestamp.valueOf("2016-05-09 00:00:02") + + /** + * Define a very simple 3 row table used for testing column serialization. + * Note: last column is seq[int] which doesn't support stats collection. + */ + protected val data = Seq[ + (jl.Boolean, jl.Byte, jl.Short, jl.Integer, jl.Long, + jl.Double, jl.Float, java.math.BigDecimal, + String, Array[Byte], Date, Timestamp, + Seq[Int])]( + (false, 1.toByte, 1.toShort, 1, 1L, 1.0, 1.0f, dec1, "s1", "b1".getBytes, d1, t1, null), + (true, 2.toByte, 3.toShort, 4, 5L, 6.0, 7.0f, dec2, "ss9", "bb0".getBytes, d2, t2, null), + (null, null, null, null, null, null, null, null, null, null, null, null, null) + ) + + /** A mapping from column to the stats collected. */ + protected val stats = mutable.LinkedHashMap( + "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1), + "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1), + "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2), + "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4), + "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8), + "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8), + "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4), + "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16), + "cstring" -> ColumnStat(2, None, None, 1, 3, 3), + "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), + "cdate" -> ColumnStat(2, Some(DateTimeUtils.fromJavaDate(d1)), + Some(DateTimeUtils.fromJavaDate(d2)), 1, 4, 4), + "ctimestamp" -> ColumnStat(2, Some(DateTimeUtils.fromJavaTimestamp(t1)), + Some(DateTimeUtils.fromJavaTimestamp(t2)), 1, 8, 8) + ) + + private val randomName = new Random(31) + + def getCatalogTable(tableName: String): CatalogTable = { + spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + } + + def getCatalogStatistics(tableName: String): CatalogStatistics = { + getCatalogTable(tableName).stats.get + } + + def checkTableStats( + tableName: String, + hasSizeInBytes: Boolean, + expectedRowCounts: Option[Int]): Option[CatalogStatistics] = { + val stats = getCatalogTable(tableName).stats + if (hasSizeInBytes || expectedRowCounts.nonEmpty) { + assert(stats.isDefined) + assert(stats.get.sizeInBytes >= 0) + assert(stats.get.rowCount === expectedRowCounts) + } else { + assert(stats.isEmpty) + } + + stats + } + + /** + * Compute column stats for the given DataFrame and compare it with colStats. + */ + def checkColStats( + df: DataFrame, + colStats: mutable.LinkedHashMap[String, ColumnStat]): Unit = { + val tableName = "column_stats_test_" + randomName.nextInt(1000) + withTable(tableName) { + df.write.saveAsTable(tableName) + + // Collect statistics + sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + + colStats.keys.mkString(", ")) + + // Validate statistics + val table = getCatalogTable(tableName) + assert(table.stats.isDefined) + assert(table.stats.get.colStats.size == colStats.size) + + colStats.foreach { case (k, v) => + withClue(s"column $k") { + assert(table.stats.get.colStats(k) == v) + } + } + } + } + + // This test will be run twice: with and without Hive support + test("SPARK-18856: non-empty partitioned table should not report zero size") { + withTable("ds_tbl", "hive_tbl") { + spark.range(100).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("ds_tbl") + val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats + assert(stats.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") + + if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { + sql("CREATE TABLE hive_tbl(i int) PARTITIONED BY (j int)") + sql("INSERT INTO hive_tbl PARTITION(j=1) SELECT 1") + val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats + assert(stats2.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") + } + } + } + + // This test will be run twice: with and without Hive support + test("conversion from CatalogStatistics to Statistics") { + withTable("ds_tbl", "hive_tbl") { + // Test data source table + checkStatsConversion(tableName = "ds_tbl", isDatasourceTable = true) + // Test hive serde table + if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { + checkStatsConversion(tableName = "hive_tbl", isDatasourceTable = false) + } + } + } + + private def checkStatsConversion(tableName: String, isDatasourceTable: Boolean): Unit = { + // Create an empty table and run analyze command on it. + val createTableSql = if (isDatasourceTable) { + s"CREATE TABLE $tableName (c1 INT, c2 STRING) USING PARQUET" + } else { + s"CREATE TABLE $tableName (c1 INT, c2 STRING)" + } + sql(createTableSql) + // Analyze only one column. + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1") + val (relation, catalogTable) = spark.table(tableName).queryExecution.analyzed.collect { + case catalogRel: CatalogRelation => (catalogRel, catalogRel.tableMeta) + case logicalRel: LogicalRelation => (logicalRel, logicalRel.catalogTable.get) + }.head + val emptyColStat = ColumnStat(0, None, None, 0, 4, 4) + // Check catalog statistics + assert(catalogTable.stats.isDefined) + assert(catalogTable.stats.get.sizeInBytes == 0) + assert(catalogTable.stats.get.rowCount == Some(0)) + assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat)) + + // Check relation statistics + assert(relation.stats.sizeInBytes == 0) + assert(relation.stats.rowCount == Some(0)) + assert(relation.stats.attributeStats.size == 1) + val (attribute, colStat) = relation.stats.attributeStats.head + assert(attribute.name == "c1") + assert(colStat == emptyColStat) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index c601038a2b0af..e00fa64e9f2ce 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -25,7 +25,7 @@ import scala.util.matching.Regex import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics, CatalogTable} +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -33,7 +33,6 @@ import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.HiveExternalCatalog._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types._ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { @@ -82,58 +81,42 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes // Non-partitioned table - sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() - sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() - sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() + val nonPartTable = "non_part_table" + withTable(nonPartTable) { + sql(s"CREATE TABLE $nonPartTable (key STRING, value STRING)") + sql(s"INSERT INTO TABLE $nonPartTable SELECT * FROM src") + sql(s"INSERT INTO TABLE $nonPartTable SELECT * FROM src") - sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") + sql(s"ANALYZE TABLE $nonPartTable COMPUTE STATISTICS noscan") - assert(queryTotalSize("analyzeTable") === BigInt(11624)) - - sql("DROP TABLE analyzeTable").collect() + assert(queryTotalSize(nonPartTable) === BigInt(11624)) + } // Partitioned table - sql( - """ - |CREATE TABLE analyzeTable_part (key STRING, value STRING) PARTITIONED BY (ds STRING) - """.stripMargin).collect() - sql( - """ - |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-01') - |SELECT * FROM src - """.stripMargin).collect() - sql( - """ - |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-02') - |SELECT * FROM src - """.stripMargin).collect() - sql( - """ - |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-03') - |SELECT * FROM src - """.stripMargin).collect() + val partTable = "part_table" + withTable(partTable) { + sql(s"CREATE TABLE $partTable (key STRING, value STRING) PARTITIONED BY (ds STRING)") + sql(s"INSERT INTO TABLE $partTable PARTITION (ds='2010-01-01') SELECT * FROM src") + sql(s"INSERT INTO TABLE $partTable PARTITION (ds='2010-01-02') SELECT * FROM src") + sql(s"INSERT INTO TABLE $partTable PARTITION (ds='2010-01-03') SELECT * FROM src") - assert(queryTotalSize("analyzeTable_part") === spark.sessionState.conf.defaultSizeInBytes) + assert(queryTotalSize(partTable) === spark.sessionState.conf.defaultSizeInBytes) - sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") + sql(s"ANALYZE TABLE $partTable COMPUTE STATISTICS noscan") - assert(queryTotalSize("analyzeTable_part") === BigInt(17436)) - - sql("DROP TABLE analyzeTable_part").collect() + assert(queryTotalSize(partTable) === BigInt(17436)) + } // Try to analyze a temp table - sql("""SELECT * FROM src""").createOrReplaceTempView("tempTable") - intercept[AnalysisException] { - sql("ANALYZE TABLE tempTable COMPUTE STATISTICS") + withView("tempTable") { + sql("""SELECT * FROM src""").createOrReplaceTempView("tempTable") + intercept[AnalysisException] { + sql("ANALYZE TABLE tempTable COMPUTE STATISTICS") + } } - spark.sessionState.catalog.dropTable( - TableIdentifier("tempTable"), ignoreIfNotExists = true, purge = false) } test("SPARK-21079 - analyze table with location different than that of individual partitions") { - def queryTotalSize(tableName: String): BigInt = - spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes - val tableName = "analyzeTable_part" withTable(tableName) { withTempPath { path => @@ -148,15 +131,12 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") - assert(queryTotalSize(tableName) === BigInt(17436)) + assert(getCatalogStatistics(tableName).sizeInBytes === BigInt(17436)) } } } test("SPARK-21079 - analyze partitioned table with only a subset of partitions visible") { - def queryTotalSize(tableName: String): BigInt = - spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes - val sourceTableName = "analyzeTable_part" val tableName = "analyzeTable_part_vis" withTable(sourceTableName, tableName) { @@ -188,39 +168,19 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // Register only one of the partitions found on disk val ds = partitionDates.head - sql(s"ALTER TABLE $tableName ADD PARTITION (ds='$ds')").collect() + sql(s"ALTER TABLE $tableName ADD PARTITION (ds='$ds')") // Analyze original table - expect 3 partitions sql(s"ANALYZE TABLE $sourceTableName COMPUTE STATISTICS noscan") - assert(queryTotalSize(sourceTableName) === BigInt(3 * 5812)) + assert(getCatalogStatistics(sourceTableName).sizeInBytes === BigInt(3 * 5812)) // Analyze partial-copy table - expect only 1 partition sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") - assert(queryTotalSize(tableName) === BigInt(5812)) + assert(getCatalogStatistics(tableName).sizeInBytes === BigInt(5812)) } } } - test("analyzing views is not supported") { - def assertAnalyzeUnsupported(analyzeCommand: String): Unit = { - val err = intercept[AnalysisException] { - sql(analyzeCommand) - } - assert(err.message.contains("ANALYZE TABLE is not supported")) - } - - val tableName = "tbl" - withTable(tableName) { - spark.range(10).write.saveAsTable(tableName) - val viewName = "view" - withView(viewName) { - sql(s"CREATE VIEW $viewName AS SELECT * FROM $tableName") - assertAnalyzeUnsupported(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") - assertAnalyzeUnsupported(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") - } - } - } - test("test table-level statistics for hive tables created in HiveExternalCatalog") { val textTable = "textTable" withTable(textTable) { @@ -290,8 +250,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto if (analyzedByHive) hiveClient.runSqlHive(s"ANALYZE TABLE $tabName COMPUTE STATISTICS") val describeResult1 = hiveClient.runSqlHive(s"DESCRIBE FORMATTED $tabName") - val tableMetadata = - spark.sessionState.catalog.getTableMetadata(TableIdentifier(tabName)).properties + val tableMetadata = getCatalogTable(tabName).properties // statistics info is not contained in the metadata of the original table assert(Seq(StatsSetupConst.COLUMN_STATS_ACCURATE, StatsSetupConst.NUM_FILES, @@ -327,8 +286,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val tabName = "tab1" withTable(tabName) { createNonPartitionedTable(tabName, analyzedByHive = false, analyzedBySpark = false) - checkTableStats( - tabName, hasSizeInBytes = true, expectedRowCounts = None) + checkTableStats(tabName, hasSizeInBytes = true, expectedRowCounts = None) // ALTER TABLE SET TBLPROPERTIES invalidates some contents of Hive specific statistics // This is triggered by the Hive alterTable API @@ -370,10 +328,6 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } test("alter table should not have the side effect to store statistics in Spark side") { - def getCatalogTable(tableName: String): CatalogTable = { - spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) - } - val table = "alter_table_side_effect" withTable(table) { sql(s"CREATE TABLE $table (i string, j string)") @@ -637,12 +591,12 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // the default value for `spark.sql.hive.convertMetastoreParquet` is true, here we just set it // for robustness - withSQLConf("spark.sql.hive.convertMetastoreParquet" -> "true") { + withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "true") { checkTableStats(parquetTable, hasSizeInBytes = false, expectedRowCounts = None) sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") checkTableStats(parquetTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) } - withSQLConf("spark.sql.hive.convertMetastoreOrc" -> "true") { + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") { // We still can get tableSize from Hive before Analyze checkTableStats(orcTable, hasSizeInBytes = true, expectedRowCounts = None) sql(s"ANALYZE TABLE $orcTable COMPUTE STATISTICS") @@ -759,8 +713,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val parquetTable = "parquetTable" withTable(parquetTable) { sql(createTableCmd) - val catalogTable = spark.sessionState.catalog.getTableMetadata( - TableIdentifier(parquetTable)) + val catalogTable = getCatalogTable(parquetTable) assert(DDLUtils.isDatasourceTable(catalogTable)) // Add a filter to avoid creating too many partitions @@ -795,17 +748,6 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto "partitioned data source table", "CREATE TABLE parquetTable (key STRING, value STRING) USING PARQUET PARTITIONED BY (key)") - test("statistics collection of a table with zero column") { - val table_no_cols = "table_no_cols" - withTable(table_no_cols) { - val rddNoCols = sparkContext.parallelize(1 to 10).map(_ => Row.empty) - val dfNoCols = spark.createDataFrame(rddNoCols, StructType(Seq.empty)) - dfNoCols.write.format("json").saveAsTable(table_no_cols) - sql(s"ANALYZE TABLE $table_no_cols COMPUTE STATISTICS") - checkTableStats(table_no_cols, hasSizeInBytes = true, expectedRowCounts = Some(10)) - } - } - /** Used to test refreshing cached metadata once table stats are updated. */ private def getStatsBeforeAfterUpdate(isAnalyzeColumns: Boolean) : (CatalogStatistics, CatalogStatistics) = { From d540dfbff33aa2f8571e0de149dfa3f4e7321113 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 6 Jul 2017 19:12:15 +0800 Subject: [PATCH 107/779] [SPARK-21273][SQL][FOLLOW-UP] Add missing test cases back and revise code style ## What changes were proposed in this pull request? Add missing test cases back and revise code style Follow up the previous PR: https://github.com/apache/spark/pull/18479 ## How was this patch tested? Unit test Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Wang Gengliang Closes #18548 from gengliangwang/stat_propagation_revise. --- .../plans/logical/LogicalPlanVisitor.scala | 2 +- .../BasicStatsEstimationSuite.scala | 45 +++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala index b23045810a4f6..2652f6d72730c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala @@ -38,10 +38,10 @@ trait LogicalPlanVisitor[T] { case p: Range => visitRange(p) case p: Repartition => visitRepartition(p) case p: RepartitionByExpression => visitRepartitionByExpr(p) + case p: ResolvedHint => visitHint(p) case p: Sample => visitSample(p) case p: ScriptTransformation => visitScriptTransform(p) case p: Union => visitUnion(p) - case p: ResolvedHint => visitHint(p) case p: LogicalPlan => default(p) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 5fd21a06a109d..913be6d1ff07f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -78,6 +78,37 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { checkStats(globalLimit, stats) } + test("sample estimation") { + val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan) + checkStats(sample, Statistics(sizeInBytes = 60, rowCount = Some(5))) + + // Child doesn't have rowCount in stats + val childStats = Statistics(sizeInBytes = 120) + val childPlan = DummyLogicalPlan(childStats, childStats) + val sample2 = + Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan) + checkStats(sample2, Statistics(sizeInBytes = 14)) + } + + test("estimate statistics when the conf changes") { + val expectedDefaultStats = + Statistics( + sizeInBytes = 40, + rowCount = Some(10), + attributeStats = AttributeMap(Seq( + AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4)))) + val expectedCboStats = + Statistics( + sizeInBytes = 4, + rowCount = Some(1), + attributeStats = AttributeMap(Seq( + AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4)))) + + val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) + checkStats( + plan, expectedStatsCboOn = expectedCboStats, expectedStatsCboOff = expectedDefaultStats) + } + /** Check estimated stats when cbo is turned on/off. */ private def checkStats( plan: LogicalPlan, @@ -99,3 +130,17 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { private def checkStats(plan: LogicalPlan, expectedStats: Statistics): Unit = checkStats(plan, expectedStats, expectedStats) } + +/** + * This class is used for unit-testing the cbo switch, it mimics a logical plan which computes + * a simple statistics or a cbo estimated statistics based on the conf. + */ +private case class DummyLogicalPlan( + defaultStats: Statistics, + cboStats: Statistics) + extends LeafNode { + + override def output: Seq[Attribute] = Nil + + override def computeStats(): Statistics = if (conf.cboEnabled) cboStats else defaultStats +} From 565e7a8d4ae7879ee704fb94ae9b3da31e202d7e Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Thu, 6 Jul 2017 19:49:34 +0800 Subject: [PATCH 108/779] [SPARK-20950][CORE] add a new config to diskWriteBufferSize which is hard coded before ## What changes were proposed in this pull request? This PR Improvement in two: 1.With spark.shuffle.spill.diskWriteBufferSize configure diskWriteBufferSize of ShuffleExternalSorter. when change the size of the diskWriteBufferSize to test `forceSorterToSpill` The average performance of running 10 times is as follows:(their unit is MS). ``` diskWriteBufferSize: 1M 512K 256K 128K 64K 32K 16K 8K 4K --------------------------------------------------------------------------------------- RecordSize = 2.5M 742 722 694 686 667 668 671 669 683 RecordSize = 1M 294 293 292 287 283 285 281 279 285 ``` 2.Remove outputBufferSizeInBytes and inputBufferSizeInBytes to initialize in mergeSpillsWithFileStream function. ## How was this patch tested? The unit test. Author: caoxuewen Closes #18174 from heary-cao/buffersize. --- .../shuffle/sort/ShuffleExternalSorter.java | 11 +++++--- .../shuffle/sort/UnsafeShuffleWriter.java | 14 +++++++--- .../unsafe/sort/UnsafeSorterSpillWriter.java | 24 ++++++++++------- .../spark/internal/config/package.scala | 27 +++++++++++++++++++ 4 files changed, 60 insertions(+), 16 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index c33d1e33f030f..338faaadb33d4 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -43,6 +43,7 @@ import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.Utils; +import org.apache.spark.internal.config.package$; /** * An external sorter that is specialized for sort-based shuffle. @@ -82,6 +83,9 @@ final class ShuffleExternalSorter extends MemoryConsumer { /** The buffer size to use when writing spills using DiskBlockObjectWriter */ private final int fileBufferSizeBytes; + /** The buffer size to use when writing the sorted records to an on-disk file */ + private final int diskWriteBufferSize; + /** * Memory pages that hold the records being sorted. The pages in this list are freed when * spilling, although in principle we could recycle these pages across spills (on the other hand, @@ -116,13 +120,14 @@ final class ShuffleExternalSorter extends MemoryConsumer { this.taskContext = taskContext; this.numPartitions = numPartitions; // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided - this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.fileBufferSizeBytes = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; this.numElementsForSpillThreshold = conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", 1024 * 1024 * 1024); this.writeMetrics = writeMetrics; this.inMemSorter = new ShuffleInMemorySorter( this, initialSize, conf.getBoolean("spark.shuffle.sort.useRadixSort", true)); this.peakMemoryUsedBytes = getMemoryUsage(); + this.diskWriteBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE()); } /** @@ -155,7 +160,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException { // be an API to directly transfer bytes from managed memory to the disk writer, we buffer // data through a byte array. This array does not need to be large enough to hold a single // record; - final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; + final byte[] writeBuffer = new byte[diskWriteBufferSize]; // Because this output will be read during shuffle, its compression codec must be controlled by // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use @@ -195,7 +200,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException { int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage); long recordReadPosition = recordOffsetInPage + 4; // skip over record length while (dataRemaining > 0) { - final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining); + final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining); Platform.copyMemory( recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); writer.write(writeBuffer, 0, toTransfer); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 34c179990214f..1b578491b81d7 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -55,6 +55,7 @@ import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; import org.apache.spark.util.Utils; +import org.apache.spark.internal.config.package$; @Private public class UnsafeShuffleWriter extends ShuffleWriter { @@ -65,6 +66,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { @VisibleForTesting static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096; + static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024; private final BlockManager blockManager; private final IndexShuffleBlockResolver shuffleBlockResolver; @@ -78,6 +80,8 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final SparkConf sparkConf; private final boolean transferToEnabled; private final int initialSortBufferSize; + private final int inputBufferSizeInBytes; + private final int outputBufferSizeInBytes; @Nullable private MapStatus mapStatus; @Nullable private ShuffleExternalSorter sorter; @@ -140,6 +144,10 @@ public UnsafeShuffleWriter( this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); this.initialSortBufferSize = sparkConf.getInt("spark.shuffle.sort.initialBufferSize", DEFAULT_INITIAL_SORT_BUFFER_SIZE); + this.inputBufferSizeInBytes = + (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; + this.outputBufferSizeInBytes = + (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; open(); } @@ -209,7 +217,7 @@ private void open() throws IOException { partitioner.numPartitions(), sparkConf, writeMetrics); - serBuffer = new MyByteArrayOutputStream(1024 * 1024); + serBuffer = new MyByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE); serOutputStream = serializer.serializeStream(serBuffer); } @@ -360,12 +368,10 @@ private long[] mergeSpillsWithFileStream( final OutputStream bos = new BufferedOutputStream( new FileOutputStream(outputFile), - (int) sparkConf.getSizeAsKb("spark.shuffle.unsafe.file.output.buffer", "32k") * 1024); + outputBufferSizeInBytes); // Use a counting output stream to avoid having to close the underlying file and ask // the file system for its size after each partition is written. final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(bos); - final int inputBufferSizeInBytes = - (int) sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; boolean threwException = true; try { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index 164b9d70b79d7..f9b5493755443 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -20,9 +20,10 @@ import java.io.File; import java.io.IOException; -import org.apache.spark.serializer.SerializerManager; import scala.Tuple2; +import org.apache.spark.SparkConf; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.storage.BlockId; @@ -30,6 +31,7 @@ import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempLocalBlockId; import org.apache.spark.unsafe.Platform; +import org.apache.spark.internal.config.package$; /** * Spills a list of sorted records to disk. Spill files have the following format: @@ -38,12 +40,16 @@ */ public final class UnsafeSorterSpillWriter { - static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; + private final SparkConf conf = new SparkConf(); + + /** The buffer size to use when writing the sorted records to an on-disk file */ + private final int diskWriteBufferSize = + (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE()); // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to // be an API to directly transfer bytes from managed memory to the disk writer, we buffer // data through a byte array. - private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; + private byte[] writeBuffer = new byte[diskWriteBufferSize]; private final File file; private final BlockId blockId; @@ -114,7 +120,7 @@ public void write( writeIntToBuffer(recordLength, 0); writeLongToBuffer(keyPrefix, 4); int dataRemaining = recordLength; - int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; // space used by prefix + len + int freeSpaceInWriteBuffer = diskWriteBufferSize - 4 - 8; // space used by prefix + len long recordReadPosition = baseOffset; while (dataRemaining > 0) { final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining); @@ -122,15 +128,15 @@ public void write( baseObject, recordReadPosition, writeBuffer, - Platform.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), + Platform.BYTE_ARRAY_OFFSET + (diskWriteBufferSize - freeSpaceInWriteBuffer), toTransfer); - writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer); + writer.write(writeBuffer, 0, (diskWriteBufferSize - freeSpaceInWriteBuffer) + toTransfer); recordReadPosition += toTransfer; dataRemaining -= toTransfer; - freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE; + freeSpaceInWriteBuffer = diskWriteBufferSize; } - if (freeSpaceInWriteBuffer < DISK_WRITE_BUFFER_SIZE) { - writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer)); + if (freeSpaceInWriteBuffer < diskWriteBufferSize) { + writer.write(writeBuffer, 0, (diskWriteBufferSize - freeSpaceInWriteBuffer)); } writer.recordWritten(); } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 8dee0d970c4c6..a629810bf093a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -336,4 +336,31 @@ package object config { "spark.") .booleanConf .createWithDefault(false) + + private[spark] val SHUFFLE_FILE_BUFFER_SIZE = + ConfigBuilder("spark.shuffle.file.buffer") + .doc("Size of the in-memory buffer for each shuffle file output stream. " + + "These buffers reduce the number of disk seeks and system calls made " + + "in creating intermediate shuffle files.") + .bytesConf(ByteUnit.KiB) + .checkValue(v => v > 0 && v <= Int.MaxValue / 1024, + s"The file buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.") + .createWithDefaultString("32k") + + private[spark] val SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE = + ConfigBuilder("spark.shuffle.unsafe.file.output.buffer") + .doc("The file system for this buffer size after each partition " + + "is written in unsafe shuffle writer.") + .bytesConf(ByteUnit.KiB) + .checkValue(v => v > 0 && v <= Int.MaxValue / 1024, + s"The buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.") + .createWithDefaultString("32k") + + private[spark] val SHUFFLE_DISK_WRITE_BUFFER_SIZE = + ConfigBuilder("spark.shuffle.spill.diskWriteBufferSize") + .doc("The buffer size to use when writing the sorted records to an on-disk file.") + .bytesConf(ByteUnit.BYTE) + .checkValue(v => v > 0 && v <= Int.MaxValue, + s"The buffer size must be greater than 0 and less than ${Int.MaxValue}.") + .createWithDefault(1024 * 1024) } From 26ac085debb54d0104762d1cd4187cdf73f301ba Mon Sep 17 00:00:00 2001 From: Bogdan Raducanu Date: Fri, 7 Jul 2017 01:04:57 +0800 Subject: [PATCH 109/779] [SPARK-21228][SQL] InSet incorrect handling of structs ## What changes were proposed in this pull request? When data type is struct, InSet now uses TypeUtils.getInterpretedOrdering (similar to EqualTo) to build a TreeSet. In other cases it will use a HashSet as before (which should be faster). Similarly, In.eval uses Ordering.equiv instead of equals. ## How was this patch tested? New test in SQLQuerySuite. Author: Bogdan Raducanu Closes #18455 from bogdanrdc/SPARK-21228. --- .../sql/catalyst/expressions/predicates.scala | 57 ++++++++++++------- .../catalyst/expressions/PredicateSuite.scala | 31 +++++----- .../catalyst/optimizer/OptimizeInSuite.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 22 +++++++ 4 files changed, 78 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f3fe58caa6fe2..7bf10f199f1c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.immutable.TreeSet + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -175,20 +176,23 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { |[${sub.output.map(_.dataType.catalogString).mkString(", ")}]. """.stripMargin) } else { - TypeCheckResult.TypeCheckSuccess + TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") } } case _ => - if (list.exists(l => l.dataType != value.dataType)) { - TypeCheckResult.TypeCheckFailure("Arguments must be same type") + val mismatchOpt = list.find(l => l.dataType != value.dataType) + if (mismatchOpt.isDefined) { + TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " + + s"${value.dataType} != ${mismatchOpt.get.dataType}") } else { - TypeCheckResult.TypeCheckSuccess + TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") } } } override def children: Seq[Expression] = value +: list lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal]) + private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType) override def nullable: Boolean = children.exists(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -203,10 +207,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { var hasNull = false list.foreach { e => val v = e.eval(input) - if (v == evaluatedValue) { - return true - } else if (v == null) { + if (v == null) { hasNull = true + } else if (ordering.equiv(v, evaluatedValue)) { + return true } } if (hasNull) { @@ -265,7 +269,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with override def nullable: Boolean = child.nullable || hasNull protected override def nullSafeEval(value: Any): Any = { - if (hset.contains(value)) { + if (set.contains(value)) { true } else if (hasNull) { null @@ -274,27 +278,40 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } } - def getHSet(): Set[Any] = hset + @transient private[this] lazy val set = child.dataType match { + case _: AtomicType => hset + case _: NullType => hset + case _ => + // for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows + TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset + } + + def getSet(): Set[Any] = set override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val setName = classOf[Set[Any]].getName val InSetName = classOf[InSet].getName val childGen = child.genCode(ctx) ctx.references += this - val hsetTerm = ctx.freshName("hset") - val hasNullTerm = ctx.freshName("hasNull") - ctx.addMutableState(setName, hsetTerm, - s"$hsetTerm = (($InSetName)references[${ctx.references.size - 1}]).getHSet();") - ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);") + val setTerm = ctx.freshName("set") + val setNull = if (hasNull) { + s""" + |if (!${ev.value}) { + | ${ev.isNull} = true; + |} + """.stripMargin + } else { + "" + } + ctx.addMutableState(setName, setTerm, + s"$setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet();") ev.copy(code = s""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; boolean ${ev.value} = false; if (!${ev.isNull}) { - ${ev.value} = $hsetTerm.contains(${childGen.value}); - if (!${ev.value} && $hasNullTerm) { - ${ev.isNull} = true; - } + ${ev.value} = $setTerm.contains(${childGen.value}); + $setNull } """) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 6fe295c3dd936..ef510a95ef446 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -35,7 +35,8 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test(s"3VL $name") { truthTable.foreach { case (l, r, answer) => - val expr = op(NonFoldableLiteral(l, BooleanType), NonFoldableLiteral(r, BooleanType)) + val expr = op(NonFoldableLiteral.create(l, BooleanType), + NonFoldableLiteral.create(r, BooleanType)) checkEvaluation(expr, answer) } } @@ -72,7 +73,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (false, true) :: (null, null) :: Nil notTrueTable.foreach { case (v, answer) => - checkEvaluation(Not(NonFoldableLiteral(v, BooleanType)), answer) + checkEvaluation(Not(NonFoldableLiteral.create(v, BooleanType)), answer) } checkConsistencyBetweenInterpretedAndCodegen(Not, BooleanType) } @@ -120,22 +121,26 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (null, null, null) :: Nil) test("IN") { - checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq(Literal(1), Literal(2))), null) - checkEvaluation(In(NonFoldableLiteral(null, IntegerType), - Seq(NonFoldableLiteral(null, IntegerType))), null) - checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq.empty), null) + checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1), + Literal(2))), null) + checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), + Seq(NonFoldableLiteral.create(null, IntegerType))), null) + checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null) checkEvaluation(In(Literal(1), Seq.empty), false) - checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral(null, IntegerType))), null) - checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), true) - checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), null) + checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null) + checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), + true) + checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), + null) checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) checkEvaluation( - And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))), + And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), + Literal(2)))), true) - val ns = NonFoldableLiteral(null, StringType) + val ns = NonFoldableLiteral.create(null, StringType) checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null) checkEvaluation(In(ns, Seq(ns)), null) checkEvaluation(In(Literal("a"), Seq(ns)), null) @@ -155,7 +160,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { case _ => value } } - val input = inputData.map(NonFoldableLiteral(_, t)) + val input = inputData.map(NonFoldableLiteral.create(_, t)) val expected = if (inputData(0) == null) { null } else if (inputData.slice(1, 10).contains(inputData(0))) { @@ -279,7 +284,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("BinaryComparison: null test") { // Use -1 (default value for codegen) which can trigger some weird bugs, e.g. SPARK-14757 val normalInt = Literal(-1) - val nullInt = NonFoldableLiteral(null, IntegerType) + val nullInt = NonFoldableLiteral.create(null, IntegerType) def nullTest(op: (Expression, Expression) => Expression): Unit = { checkEvaluation(op(normalInt, nullInt), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 6a77580b29a21..28bf7b6f84341 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -169,7 +169,7 @@ class OptimizeInSuite extends PlanTest { val optimizedPlan = OptimizeIn(plan) optimizedPlan match { case Filter(cond, _) - if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getHSet().size == 3 => + if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getSet().size == 3 => // pass case _ => fail("Unexpected result for OptimizedIn") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 68f61cfab6d2f..5171aaebc9907 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2616,4 +2616,26 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException](sql("SELECT nvl(1, 2, 3)")) assert(e.message.contains("Invalid number of arguments")) } + + test("SPARK-21228: InSet incorrect handling of structs") { + withTempView("A") { + // reduce this from the default of 10 so the repro query text is not too long + withSQLConf((SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "3")) { + // a relation that has 1 column of struct type with values (1,1), ..., (9, 9) + spark.range(1, 10).selectExpr("named_struct('a', id, 'b', id) as a") + .createOrReplaceTempView("A") + val df = sql( + """ + |SELECT * from + | (SELECT MIN(a) as minA FROM A) AA -- this Aggregate will return UnsafeRows + | -- the IN will become InSet with a Set of GenericInternalRows + | -- a GenericInternalRow is never equal to an UnsafeRow so the query would + | -- returns 0 results, which is incorrect + | WHERE minA IN (NAMED_STRUCT('a', 1L, 'b', 1L), NAMED_STRUCT('a', 2L, 'b', 2L), + | NAMED_STRUCT('a', 3L, 'b', 3L)) + """.stripMargin) + checkAnswer(df, Row(Row(1, 1))) + } + } + } } From 48e44b24a7663142176102ac4c6bf4242f103804 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 7 Jul 2017 01:07:45 +0800 Subject: [PATCH 110/779] [SPARK-21204][SQL] Add support for Scala Set collection types in serialization ## What changes were proposed in this pull request? Currently we can't produce a `Dataset` containing `Set` in SparkSQL. This PR tries to support serialization/deserialization of `Set`. Because there's no corresponding internal data type in SparkSQL for a `Set`, the most proper choice for serializing a set should be an array. ## How was this patch tested? Added unit tests. Author: Liang-Chi Hsieh Closes #18416 from viirya/SPARK-21204. --- .../spark/sql/catalyst/ScalaReflection.scala | 28 +++++++++++++++-- .../expressions/objects/objects.scala | 5 +-- .../org/apache/spark/sql/SQLImplicits.scala | 10 ++++++ .../spark/sql/DataFrameAggregateSuite.scala | 10 ++++++ .../spark/sql/DatasetPrimitiveSuite.scala | 31 +++++++++++++++++++ 5 files changed, 79 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 814f2c10b9097..4d5401f30d392 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -309,7 +309,10 @@ object ScalaReflection extends ScalaReflection { Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false) } - case t if t <:< localTypeOf[Seq[_]] => + // We serialize a `Set` to Catalyst array. When we deserialize a Catalyst array + // to a `Set`, if there are duplicated elements, the elements will be de-duplicated. + case t if t <:< localTypeOf[Seq[_]] || + t <:< localTypeOf[scala.collection.Set[_]] => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) @@ -327,8 +330,10 @@ object ScalaReflection extends ScalaReflection { } val companion = t.normalize.typeSymbol.companionSymbol.typeSignature - val cls = companion.declaration(newTermName("newBuilder")) match { - case NoSymbol => classOf[Seq[_]] + val cls = companion.member(newTermName("newBuilder")) match { + case NoSymbol if t <:< localTypeOf[Seq[_]] => classOf[Seq[_]] + case NoSymbol if t <:< localTypeOf[scala.collection.Set[_]] => + classOf[scala.collection.Set[_]] case _ => mirror.runtimeClass(t.typeSymbol.asClass) } UnresolvedMapObjects(mapFunction, getPath, Some(cls)) @@ -502,6 +507,19 @@ object ScalaReflection extends ScalaReflection { serializerFor(_, valueType, valuePath, seenTypeSet), valueNullable = !valueType.typeSymbol.asClass.isPrimitive) + case t if t <:< localTypeOf[scala.collection.Set[_]] => + val TypeRef(_, _, Seq(elementType)) = t + + // There's no corresponding Catalyst type for `Set`, we serialize a `Set` to Catalyst array. + // Note that the property of `Set` is only kept when manipulating the data as domain object. + val newInput = + Invoke( + inputObject, + "toSeq", + ObjectType(classOf[Seq[_]])) + + toCatalystArray(newInput, elementType) + case t if t <:< localTypeOf[String] => StaticInvoke( classOf[UTF8String], @@ -713,6 +731,10 @@ object ScalaReflection extends ScalaReflection { val Schema(valueDataType, valueNullable) = schemaFor(valueType) Schema(MapType(schemaFor(keyType).dataType, valueDataType, valueContainsNull = valueNullable), nullable = true) + case t if t <:< localTypeOf[Set[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true) case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true) case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 24c06d8b14b54..9b28a18035b1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -627,8 +627,9 @@ case class MapObjects private( val (initCollection, addElement, getResult): (String, String => String, String) = customCollectionCls match { - case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) => - // Scala sequence + case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) || + classOf[scala.collection.Set[_]].isAssignableFrom(cls) => + // Scala sequence or set val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()" val builder = ctx.freshName("collectionBuilder") ( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 86574e2f71d92..05db292bd41b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -171,6 +171,16 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits { /** @since 2.3.0 */ implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder() + /** + * Notice that we serialize `Set` to Catalyst array. The set property is only kept when + * manipulating the domain objects. The serialization format doesn't keep the set property. + * When we have a Catalyst array which contains duplicated elements and convert it to + * `Dataset[Set[T]]` by using the encoder, the elements will be de-duplicated. + * + * @since 2.3.0 + */ + implicit def newSetEncoder[T <: Set[_] : TypeTag]: Encoder[T] = ExpressionEncoder() + // Arrays /** @since 1.6.1 */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 5db354d79bb6e..b52d50b195bcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -460,6 +460,16 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { df.select(collect_set($"a"), collect_set($"b")), Seq(Row(Seq(1, 2, 3), Seq(2, 4))) ) + + checkDataset( + df.select(collect_set($"a").as("aSet")).as[Set[Int]], + Set(1, 2, 3)) + checkDataset( + df.select(collect_set($"b").as("bSet")).as[Set[Int]], + Set(2, 4)) + checkDataset( + df.select(collect_set($"a"), collect_set($"b")).as[(Set[Int], Set[Int])], + Seq(Set(1, 2, 3) -> Set(2, 4)): _*) } test("collect functions structs") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index a6847dcfbffc4..f62f9e23db66d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.immutable.{HashSet => HSet} import scala.collection.immutable.Queue import scala.collection.mutable.{LinkedHashMap => LHMap} import scala.collection.mutable.ArrayBuffer @@ -342,6 +343,31 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4)))) } + test("arbitrary sets") { + checkDataset(Seq(Set(1, 2, 3, 4)).toDS(), Set(1, 2, 3, 4)) + checkDataset(Seq(Set(1.toLong, 2.toLong)).toDS(), Set(1.toLong, 2.toLong)) + checkDataset(Seq(Set(1.toDouble, 2.toDouble)).toDS(), Set(1.toDouble, 2.toDouble)) + checkDataset(Seq(Set(1.toFloat, 2.toFloat)).toDS(), Set(1.toFloat, 2.toFloat)) + checkDataset(Seq(Set(1.toByte, 2.toByte)).toDS(), Set(1.toByte, 2.toByte)) + checkDataset(Seq(Set(1.toShort, 2.toShort)).toDS(), Set(1.toShort, 2.toShort)) + checkDataset(Seq(Set(true, false)).toDS(), Set(true, false)) + checkDataset(Seq(Set("test1", "test2")).toDS(), Set("test1", "test2")) + checkDataset(Seq(Set(Tuple1(1), Tuple1(2))).toDS(), Set(Tuple1(1), Tuple1(2))) + + checkDataset(Seq(HSet(1, 2)).toDS(), HSet(1, 2)) + checkDataset(Seq(HSet(1.toLong, 2.toLong)).toDS(), HSet(1.toLong, 2.toLong)) + checkDataset(Seq(HSet(1.toDouble, 2.toDouble)).toDS(), HSet(1.toDouble, 2.toDouble)) + checkDataset(Seq(HSet(1.toFloat, 2.toFloat)).toDS(), HSet(1.toFloat, 2.toFloat)) + checkDataset(Seq(HSet(1.toByte, 2.toByte)).toDS(), HSet(1.toByte, 2.toByte)) + checkDataset(Seq(HSet(1.toShort, 2.toShort)).toDS(), HSet(1.toShort, 2.toShort)) + checkDataset(Seq(HSet(true, false)).toDS(), HSet(true, false)) + checkDataset(Seq(HSet("test1", "test2")).toDS(), HSet("test1", "test2")) + checkDataset(Seq(HSet(Tuple1(1), Tuple1(2))).toDS(), HSet(Tuple1(1), Tuple1(2))) + + checkDataset(Seq(Seq(Some(1), None), Seq(Some(2))).toDF("c").as[Set[Integer]], + Seq(Set[Integer](1, null), Set[Integer](2)): _*) + } + test("nested sequences") { checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1))) checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1))) @@ -352,6 +378,11 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { checkDataset(Seq(LHMap(Map(1 -> 2) -> 3)).toDS(), LHMap(Map(1 -> 2) -> 3)) } + test("nested set") { + checkDataset(Seq(Set(HSet(1, 2), HSet(3, 4))).toDS(), Set(HSet(1, 2), HSet(3, 4))) + checkDataset(Seq(HSet(Set(1, 2), Set(3, 4))).toDS(), HSet(Set(1, 2), Set(3, 4))) + } + test("package objects") { import packageobject._ checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1)) From bf66335acab3c0c188f6c378eb8aa6948a259cb2 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 6 Jul 2017 13:58:27 -0700 Subject: [PATCH 111/779] [SPARK-21323][SQL] Rename plans.logical.statsEstimation.Range to ValueInterval ## What changes were proposed in this pull request? Rename org.apache.spark.sql.catalyst.plans.logical.statsEstimation.Range to ValueInterval. The current naming is identical to logical operator "range". Refactoring it to ValueInterval is more accurate. ## How was this patch tested? unit test Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Wang Gengliang Closes #18549 from gengliangwang/ValueInterval. --- .../statsEstimation/FilterEstimation.scala | 36 ++++++++-------- .../statsEstimation/JoinEstimation.scala | 14 +++---- .../{Range.scala => ValueInterval.scala} | 41 ++++++++++--------- 3 files changed, 48 insertions(+), 43 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/{Range.scala => ValueInterval.scala} (65%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 5a3bee7b9e449..e13db85c7a76e 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -316,8 +316,8 @@ case class FilterEstimation(plan: Filter) extends Logging { // decide if the value is in [min, max] of the column. // We currently don't store min/max for binary/string type. // Hence, we assume it is in boundary for binary/string type. - val statsRange = Range(colStat.min, colStat.max, attr.dataType) - if (statsRange.contains(literal)) { + val statsInterval = ValueInterval(colStat.min, colStat.max, attr.dataType) + if (statsInterval.contains(literal)) { if (update) { // We update ColumnStat structure after apply this equality predicate: // Set distinctCount to 1, nullCount to 0, and min/max values (if exist) to the literal @@ -388,9 +388,10 @@ case class FilterEstimation(plan: Filter) extends Logging { // use [min, max] to filter the original hSet dataType match { case _: NumericType | BooleanType | DateType | TimestampType => - val statsRange = Range(colStat.min, colStat.max, dataType).asInstanceOf[NumericRange] + val statsInterval = + ValueInterval(colStat.min, colStat.max, dataType).asInstanceOf[NumericValueInterval] val validQuerySet = hSet.filter { v => - v != null && statsRange.contains(Literal(v, dataType)) + v != null && statsInterval.contains(Literal(v, dataType)) } if (validQuerySet.isEmpty) { @@ -440,12 +441,13 @@ case class FilterEstimation(plan: Filter) extends Logging { update: Boolean): Option[BigDecimal] = { val colStat = colStatsMap(attr) - val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange] - val max = statsRange.max.toBigDecimal - val min = statsRange.min.toBigDecimal + val statsInterval = + ValueInterval(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericValueInterval] + val max = statsInterval.max.toBigDecimal + val min = statsInterval.min.toBigDecimal val ndv = BigDecimal(colStat.distinctCount) - // determine the overlapping degree between predicate range and column's range + // determine the overlapping degree between predicate interval and column's interval val numericLiteral = if (literal.dataType == BooleanType) { if (literal.value.asInstanceOf[Boolean]) BigDecimal(1) else BigDecimal(0) } else { @@ -566,18 +568,18 @@ case class FilterEstimation(plan: Filter) extends Logging { } val colStatLeft = colStatsMap(attrLeft) - val statsRangeLeft = Range(colStatLeft.min, colStatLeft.max, attrLeft.dataType) - .asInstanceOf[NumericRange] - val maxLeft = statsRangeLeft.max - val minLeft = statsRangeLeft.min + val statsIntervalLeft = ValueInterval(colStatLeft.min, colStatLeft.max, attrLeft.dataType) + .asInstanceOf[NumericValueInterval] + val maxLeft = statsIntervalLeft.max + val minLeft = statsIntervalLeft.min val colStatRight = colStatsMap(attrRight) - val statsRangeRight = Range(colStatRight.min, colStatRight.max, attrRight.dataType) - .asInstanceOf[NumericRange] - val maxRight = statsRangeRight.max - val minRight = statsRangeRight.min + val statsIntervalRight = ValueInterval(colStatRight.min, colStatRight.max, attrRight.dataType) + .asInstanceOf[NumericValueInterval] + val maxRight = statsIntervalRight.max + val minRight = statsIntervalRight.min - // determine the overlapping degree between predicate range and column's range + // determine the overlapping degree between predicate interval and column's interval val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0) val (noOverlap: Boolean, completeOverlap: Boolean) = op match { // Left < Right or Left <= Right diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index f48196997a24d..dcbe36da91dfc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -175,9 +175,9 @@ case class InnerOuterEstimation(join: Join) extends Logging { // Check if the two sides are disjoint val leftKeyStats = leftStats.attributeStats(leftKey) val rightKeyStats = rightStats.attributeStats(rightKey) - val lRange = Range(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) - val rRange = Range(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) - if (Range.isIntersected(lRange, rRange)) { + val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) + val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) + if (ValueInterval.isIntersected(lInterval, rInterval)) { // Get the largest ndv among pairs of join keys val maxNdv = leftKeyStats.distinctCount.max(rightKeyStats.distinctCount) if (maxNdv > ndvDenom) ndvDenom = maxNdv @@ -239,16 +239,16 @@ case class InnerOuterEstimation(join: Join) extends Logging { joinKeyPairs.foreach { case (leftKey, rightKey) => val leftKeyStats = leftStats.attributeStats(leftKey) val rightKeyStats = rightStats.attributeStats(rightKey) - val lRange = Range(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) - val rRange = Range(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) + val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) + val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) // When we reach here, join selectivity is not zero, so each pair of join keys should be // intersected. - assert(Range.isIntersected(lRange, rRange)) + assert(ValueInterval.isIntersected(lInterval, rInterval)) // Update intersected column stats assert(leftKey.dataType.sameType(rightKey.dataType)) val newNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount) - val (newMin, newMax) = Range.intersect(lRange, rRange, leftKey.dataType) + val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType) val newMaxLen = math.min(leftKeyStats.maxLen, rightKeyStats.maxLen) val newAvgLen = (leftKeyStats.avgLen + rightKeyStats.avgLen) / 2 val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala similarity index 65% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala index 4ac5ba5689f82..0caaf796a3b68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala @@ -22,12 +22,12 @@ import org.apache.spark.sql.types._ /** Value range of a column. */ -trait Range { +trait ValueInterval { def contains(l: Literal): Boolean } -/** For simplicity we use decimal to unify operations of numeric ranges. */ -case class NumericRange(min: Decimal, max: Decimal) extends Range { +/** For simplicity we use decimal to unify operations of numeric intervals. */ +case class NumericValueInterval(min: Decimal, max: Decimal) extends ValueInterval { override def contains(l: Literal): Boolean = { val lit = EstimationUtils.toDecimal(l.value, l.dataType) min <= lit && max >= lit @@ -38,46 +38,49 @@ case class NumericRange(min: Decimal, max: Decimal) extends Range { * This version of Spark does not have min/max for binary/string types, we define their default * behaviors by this class. */ -class DefaultRange extends Range { +class DefaultValueInterval extends ValueInterval { override def contains(l: Literal): Boolean = true } /** This is for columns with only null values. */ -class NullRange extends Range { +class NullValueInterval extends ValueInterval { override def contains(l: Literal): Boolean = false } -object Range { - def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match { - case StringType | BinaryType => new DefaultRange() - case _ if min.isEmpty || max.isEmpty => new NullRange() +object ValueInterval { + def apply( + min: Option[Any], + max: Option[Any], + dataType: DataType): ValueInterval = dataType match { + case StringType | BinaryType => new DefaultValueInterval() + case _ if min.isEmpty || max.isEmpty => new NullValueInterval() case _ => - NumericRange( + NumericValueInterval( min = EstimationUtils.toDecimal(min.get, dataType), max = EstimationUtils.toDecimal(max.get, dataType)) } - def isIntersected(r1: Range, r2: Range): Boolean = (r1, r2) match { - case (_, _: DefaultRange) | (_: DefaultRange, _) => - // The DefaultRange represents string/binary types which do not have max/min stats, + def isIntersected(r1: ValueInterval, r2: ValueInterval): Boolean = (r1, r2) match { + case (_, _: DefaultValueInterval) | (_: DefaultValueInterval, _) => + // The DefaultValueInterval represents string/binary types which do not have max/min stats, // we assume they are intersected to be conservative on estimation true - case (_, _: NullRange) | (_: NullRange, _) => + case (_, _: NullValueInterval) | (_: NullValueInterval, _) => false - case (n1: NumericRange, n2: NumericRange) => + case (n1: NumericValueInterval, n2: NumericValueInterval) => n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0 } /** - * Intersected results of two ranges. This is only for two overlapped ranges. + * Intersected results of two intervals. This is only for two overlapped intervals. * The outputs are the intersected min/max values. */ - def intersect(r1: Range, r2: Range, dt: DataType): (Option[Any], Option[Any]) = { + def intersect(r1: ValueInterval, r2: ValueInterval, dt: DataType): (Option[Any], Option[Any]) = { (r1, r2) match { - case (_, _: DefaultRange) | (_: DefaultRange, _) => + case (_, _: DefaultValueInterval) | (_: DefaultValueInterval, _) => // binary/string types don't support intersecting. (None, None) - case (n1: NumericRange, n2: NumericRange) => + case (n1: NumericValueInterval, n2: NumericValueInterval) => // Choose the maximum of two min values, and the minimum of two max values. val newMin = if (n1.min <= n2.min) n2.min else n1.min val newMax = if (n1.max <= n2.max) n1.max else n2.max From 0217dfd26f89133f146197359b556c9bf5aca172 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 6 Jul 2017 17:28:20 -0700 Subject: [PATCH 112/779] [SPARK-21267][SS][DOCS] Update Structured Streaming Documentation ## What changes were proposed in this pull request? Few changes to the Structured Streaming documentation - Clarify that the entire stream input table is not materialized - Add information for Ganglia - Add Kafka Sink to the main docs - Removed a couple of leftover experimental tags - Added more associated reading material and talk videos. In addition, https://github.com/apache/spark/pull/16856 broke the link to the RDD programming guide in several places while renaming the page. This PR fixes those sameeragarwal cloud-fan. - Added a redirection to avoid breaking internal and possible external links. - Removed unnecessary redirection pages that were there since the separate scala, java, and python programming guides were merged together in 2013 or 2014. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Tathagata Das Closes #18485 from tdas/SPARK-21267. --- docs/_layouts/global.html | 7 +- docs/index.md | 13 +- docs/java-programming-guide.md | 7 - docs/programming-guide.md | 7 + docs/python-programming-guide.md | 7 - docs/rdd-programming-guide.md | 2 +- docs/scala-programming-guide.md | 7 - docs/sql-programming-guide.md | 16 +- .../structured-streaming-programming-guide.md | 172 +++++++++++++++--- .../scala/org/apache/spark/sql/Dataset.scala | 3 - 10 files changed, 169 insertions(+), 72 deletions(-) delete mode 100644 docs/java-programming-guide.md create mode 100644 docs/programming-guide.md delete mode 100644 docs/python-programming-guide.md delete mode 100644 docs/scala-programming-guide.md diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index c00d0db63cd10..570483c0b04ea 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -69,11 +69,10 @@ Programming Guides
+### Reporting Metrics using Dropwizard +Spark supports reporting metrics using the [Dropwizard Library](monitoring.html#metrics). To enable metrics of Structured Streaming queries to be reported as well, you have to explicitly enable the configuration `spark.sql.streaming.metricsEnabled` in the SparkSession. + +
+
+{% highlight scala %} +spark.conf.set("spark.sql.streaming.metricsEnabled", "true") +// or +spark.sql("SET spark.sql.streaming.metricsEnabled=true") +{% endhighlight %} +
+
+{% highlight java %} +spark.conf().set("spark.sql.streaming.metricsEnabled", "true"); +// or +spark.sql("SET spark.sql.streaming.metricsEnabled=true"); +{% endhighlight %} +
+
+{% highlight python %} +spark.conf.set("spark.sql.streaming.metricsEnabled", "true") +# or +spark.sql("SET spark.sql.streaming.metricsEnabled=true") +{% endhighlight %} +
+
+{% highlight r %} +sql("SET spark.sql.streaming.metricsEnabled=true") +{% endhighlight %} +
+
+ + +All queries started in the SparkSession after this configuration has been enabled will report metrics through Dropwizard to whatever [sinks](monitoring.html#metrics) have been configured (e.g. Ganglia, Graphite, JMX, etc.). + ## Recovering from Failures with Checkpointing In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries). @@ -1971,8 +2082,23 @@ write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "pat -# Where to go from here -- Examples: See and run the -[Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/sql/streaming)/[Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/sql/streaming)/[Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/sql/streaming)/[R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r/streaming) -examples. +# Additional Information + +**Further Reading** + +- See and run the + [Scala]({{site.SPARK_GITHUB_URL}}/tree/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming)/[Java]({{site.SPARK_GITHUB_URL}}/tree/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming)/[Python]({{site.SPARK_GITHUB_URL}}/tree/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming)/[R]({{site.SPARK_GITHUB_URL}}/tree/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/r/streaming) + examples. + - [Instructions](index.html#running-the-examples-and-shell) on how to run Spark examples +- Read about integrating with Kafka in the [Structured Streaming Kafka Integration Guide](structured-streaming-kafka-integration.html) +- Read more details about using DataFrames/Datasets in the [Spark SQL Programming Guide](sql-programming-guide.html) +- Third-party Blog Posts + - [Real-time Streaming ETL with Structured Streaming in Apache Spark 2.1 (Databricks Blog)](https://databricks.com/blog/2017/01/19/real-time-streaming-etl-structured-streaming-apache-spark-2-1.html) + - [Real-Time End-to-End Integration with Apache Kafka in Apache Spark’s Structured Streaming (Databricks Blog)](https://databricks.com/blog/2017/04/04/real-time-end-to-end-integration-with-apache-kafka-in-apache-sparks-structured-streaming.html) + - [Event-time Aggregation and Watermarking in Apache Spark’s Structured Streaming (Databricks Blog)](https://databricks.com/blog/2017/05/08/event-time-aggregation-watermarking-apache-sparks-structured-streaming.html) + +**Talks** + +- Spark Summit 2017 Talk - [Easy, Scalable, Fault-tolerant Stream Processing with Structured Streaming in Apache Spark](https://spark-summit.org/2017/events/easy-scalable-fault-tolerant-stream-processing-with-structured-streaming-in-apache-spark/) - Spark Summit 2016 Talk - [A Deep Dive into Structured Streaming](https://spark-summit.org/2016/events/a-deep-dive-into-structured-streaming/) + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 7be4aa1ca9562..b1638a2180b07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -520,7 +520,6 @@ class Dataset[T] private[sql]( * @group streaming * @since 2.0.0 */ - @Experimental @InterfaceStability.Evolving def isStreaming: Boolean = logicalPlan.isStreaming @@ -581,7 +580,6 @@ class Dataset[T] private[sql]( } /** - * :: Experimental :: * Defines an event time watermark for this [[Dataset]]. A watermark tracks a point in time * before which we assume no more late data is going to arrive. * @@ -605,7 +603,6 @@ class Dataset[T] private[sql]( * @group streaming * @since 2.1.0 */ - @Experimental @InterfaceStability.Evolving // We only accept an existing column name, not a derived column here as a watermark that is // defined on a derived column cannot referenced elsewhere in the plan. From 40c7add3a4c811202d1fa2be9606aca08df81266 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 7 Jul 2017 08:44:31 +0800 Subject: [PATCH 113/779] [SPARK-20946][SQL] Do not update conf for existing SparkContext in SparkSession.getOrCreate ## What changes were proposed in this pull request? SparkContext is shared by all sessions, we should not update its conf for only one session. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #18536 from cloud-fan/config. --- .../spark/ml/recommendation/ALSSuite.scala | 4 +--- .../apache/spark/ml/tree/impl/TreeTests.scala | 2 -- .../org/apache/spark/sql/SparkSession.scala | 19 +++++++------------ .../spark/sql/SparkSessionBuilderSuite.scala | 8 +++----- 4 files changed, 11 insertions(+), 22 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 3094f52ba1bc5..b57fc8d21ab34 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -818,15 +818,13 @@ class ALSCleanerSuite extends SparkFunSuite { FileUtils.listFiles(localDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet try { conf.set("spark.local.dir", localDir.getAbsolutePath) - val sc = new SparkContext("local[2]", "test", conf) + val sc = new SparkContext("local[2]", "ALSCleanerSuite", conf) try { sc.setCheckpointDir(checkpointDir.getAbsolutePath) // Generate test data val (training, _) = ALSSuite.genImplicitTestData(sc, 20, 5, 1, 0.2, 0) // Implicitly test the cleaning of parents during ALS training val spark = SparkSession.builder - .master("local[2]") - .appName("ALSCleanerSuite") .sparkContext(sc) .getOrCreate() import spark.implicits._ diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index 92a236928e90b..b6894b30b0c2b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -43,8 +43,6 @@ private[ml] object TreeTests extends SparkFunSuite { categoricalFeatures: Map[Int, Int], numClasses: Int): DataFrame = { val spark = SparkSession.builder() - .master("local[2]") - .appName("TreeTests") .sparkContext(data.sparkContext) .getOrCreate() import spark.implicits._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 0ddcd2111aa58..6dfe8a66baa9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -867,7 +867,7 @@ object SparkSession { * * @since 2.2.0 */ - def withExtensions(f: SparkSessionExtensions => Unit): Builder = { + def withExtensions(f: SparkSessionExtensions => Unit): Builder = synchronized { f(extensions) this } @@ -912,21 +912,16 @@ object SparkSession { // No active nor global default session. Create a new one. val sparkContext = userSuppliedContext.getOrElse { - // set app name if not given - val randomAppName = java.util.UUID.randomUUID().toString val sparkConf = new SparkConf() options.foreach { case (k, v) => sparkConf.set(k, v) } + + // set a random app name if not given. if (!sparkConf.contains("spark.app.name")) { - sparkConf.setAppName(randomAppName) - } - val sc = SparkContext.getOrCreate(sparkConf) - // maybe this is an existing SparkContext, update its SparkConf which maybe used - // by SparkSession - options.foreach { case (k, v) => sc.conf.set(k, v) } - if (!sc.conf.contains("spark.app.name")) { - sc.conf.setAppName(randomAppName) + sparkConf.setAppName(java.util.UUID.randomUUID().toString) } - sc + + SparkContext.getOrCreate(sparkConf) + // Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions. } // Initialize extensions if the user has defined a configurator class. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index cdac6827082c4..770e15629c839 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -102,11 +102,9 @@ class SparkSessionBuilderSuite extends SparkFunSuite { assert(session.conf.get("key1") == "value1") assert(session.conf.get("key2") == "value2") assert(session.sparkContext == sparkContext2) - assert(session.sparkContext.conf.get("key1") == "value1") - // If the created sparkContext is not passed through the Builder's API sparkContext, - // the conf of this sparkContext will also contain the conf set through the API config. - assert(session.sparkContext.conf.get("key2") == "value2") - assert(session.sparkContext.conf.get("spark.app.name") == "test") + // We won't update conf for existing `SparkContext` + assert(!sparkContext2.conf.contains("key2")) + assert(sparkContext2.conf.get("key1") == "value1") session.stop() } From e5bb26174d3336e07dd670eec4fd2137df346163 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Thu, 6 Jul 2017 18:11:41 -0700 Subject: [PATCH 114/779] [SPARK-21329][SS] Make EventTimeWatermarkExec explicitly UnaryExecNode ## What changes were proposed in this pull request? Making EventTimeWatermarkExec explicitly UnaryExecNode /cc tdas zsxwing ## How was this patch tested? Local build. Author: Jacek Laskowski Closes #18509 from jaceklaskowski/EventTimeWatermarkExec-UnaryExecNode. --- .../sql/execution/streaming/EventTimeWatermarkExec.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala index 25cf609fc336e..87e5b78550423 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.types.MetadataBuilder import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.AccumulatorV2 @@ -81,7 +81,7 @@ class EventTimeStatsAccum(protected var currentStats: EventTimeStats = EventTime case class EventTimeWatermarkExec( eventTime: Attribute, delay: CalendarInterval, - child: SparkPlan) extends SparkPlan { + child: SparkPlan) extends UnaryExecNode { val eventTimeStats = new EventTimeStatsAccum() val delayMs = EventTimeWatermark.getDelayMs(delay) @@ -117,6 +117,4 @@ case class EventTimeWatermarkExec( a } } - - override def children: Seq[SparkPlan] = child :: Nil } From d451b7f43d559aa1efd7ac3d1cbec5249f3a7a24 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 7 Jul 2017 12:24:03 +0800 Subject: [PATCH 115/779] [SPARK-21326][SPARK-21066][ML] Use TextFileFormat in LibSVMFileFormat and allow multiple input paths for determining numFeatures ## What changes were proposed in this pull request? This is related with [SPARK-19918](https://issues.apache.org/jira/browse/SPARK-19918) and [SPARK-18362](https://issues.apache.org/jira/browse/SPARK-18362). This PR proposes to use `TextFileFormat` and allow multiple input paths (but with a warning) when determining the number of features in LibSVM data source via an extra scan. There are three points here: - The main advantage of this change should be to remove file-listing bottlenecks in driver side. - Another advantage is ones from using `FileScanRDD`. For example, I guess we can use `spark.sql.files.ignoreCorruptFiles` option when determining the number of features. - We can unify the schema inference code path in text based data sources. This is also a preparation for [SPARK-21289](https://issues.apache.org/jira/browse/SPARK-21289). ## How was this patch tested? Unit tests in `LibSVMRelationSuite`. Closes #18288 Author: hyukjinkwon Closes #18556 from HyukjinKwon/libsvm-schema. --- .../ml/source/libsvm/LibSVMRelation.scala | 26 +++++++++---------- .../org/apache/spark/mllib/util/MLUtils.scala | 25 ++++++++++++++++-- .../source/libsvm/LibSVMRelationSuite.scala | 17 +++++++++--- 3 files changed, 49 insertions(+), 19 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index f68847a664b69..dec118330aec6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} +import org.apache.spark.internal.Logging import org.apache.spark.TaskContext import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vectors, VectorUDT} @@ -66,7 +67,10 @@ private[libsvm] class LibSVMOutputWriter( /** @see [[LibSVMDataSource]] for public documentation. */ // If this is moved or renamed, please update DataSource's backwardCompatibilityMap. -private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister { +private[libsvm] class LibSVMFileFormat + extends TextBasedFileFormat + with DataSourceRegister + with Logging { override def shortName(): String = "libsvm" @@ -89,18 +93,14 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour files: Seq[FileStatus]): Option[StructType] = { val libSVMOptions = new LibSVMOptions(options) val numFeatures: Int = libSVMOptions.numFeatures.getOrElse { - // Infers number of features if the user doesn't specify (a valid) one. - val dataFiles = files.filterNot(_.getPath.getName startsWith "_") - val path = if (dataFiles.length == 1) { - dataFiles.head.getPath.toUri.toString - } else if (dataFiles.isEmpty) { - throw new IOException("No input path specified for libsvm data") - } else { - throw new IOException("Multiple input paths are not supported for libsvm data.") - } - - val sc = sparkSession.sparkContext - val parsed = MLUtils.parseLibSVMFile(sc, path, sc.defaultParallelism) + require(files.nonEmpty, "No input path specified for libsvm data") + logWarning( + "'numFeatures' option not specified, determining the number of features by going " + + "though the input. If you know the number in advance, please specify it via " + + "'numFeatures' option to avoid the extra scan.") + + val paths = files.map(_.getPath.toUri.toString) + val parsed = MLUtils.parseLibSVMFile(sparkSession, paths) MLUtils.computeNumFeatures(parsed) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 4fdad05973969..14af8b5c73870 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -28,8 +28,10 @@ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} -import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.datasources.text.TextFileFormat +import org.apache.spark.sql.functions._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.BernoulliCellSampler @@ -102,6 +104,25 @@ object MLUtils extends Logging { .map(parseLibSVMRecord) } + private[spark] def parseLibSVMFile( + sparkSession: SparkSession, paths: Seq[String]): RDD[(Double, Array[Int], Array[Double])] = { + val lines = sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + className = classOf[TextFileFormat].getName + ).resolveRelation(checkFilesExist = false)) + .select("value") + + import lines.sqlContext.implicits._ + + lines.select(trim($"value").as("line")) + .filter(not((length($"line") === 0).or($"line".startsWith("#")))) + .as[String] + .rdd + .map(MLUtils.parseLibSVMRecord) + } + private[spark] def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = { val items = line.split(' ') val label = items.head.toDouble diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index e164d279f3f02..a67e49d54e148 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -35,15 +35,22 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { override def beforeAll(): Unit = { super.beforeAll() - val lines = + val lines0 = """ |1 1:1.0 3:2.0 5:3.0 |0 + """.stripMargin + val lines1 = + """ |0 2:4.0 4:5.0 6:6.0 """.stripMargin val dir = Utils.createDirectory(tempDir.getCanonicalPath, "data") - val file = new File(dir, "part-00000") - Files.write(lines, file, StandardCharsets.UTF_8) + val succ = new File(dir, "_SUCCESS") + val file0 = new File(dir, "part-00000") + val file1 = new File(dir, "part-00001") + Files.write("", succ, StandardCharsets.UTF_8) + Files.write(lines0, file0, StandardCharsets.UTF_8) + Files.write(lines1, file1, StandardCharsets.UTF_8) path = dir.toURI.toString } @@ -145,7 +152,9 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { test("create libsvmTable table without schema and path") { try { - val e = intercept[IOException](spark.sql("CREATE TABLE libsvmTable USING libsvm")) + val e = intercept[IllegalArgumentException] { + spark.sql("CREATE TABLE libsvmTable USING libsvm") + } assert(e.getMessage.contains("No input path specified for libsvm data")) } finally { spark.sql("DROP TABLE IF EXISTS libsvmTable") From 53c2eb59b2cc557081f6a252748dc38511601b0d Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 7 Jul 2017 14:05:22 +0900 Subject: [PATCH 116/779] [SPARK-21327][SQL][PYSPARK] ArrayConstructor should handle an array of typecode 'l' as long rather than int in Python 2. ## What changes were proposed in this pull request? Currently `ArrayConstructor` handles an array of typecode `'l'` as `int` when converting Python object in Python 2 into Java object, so if the value is larger than `Integer.MAX_VALUE` or smaller than `Integer.MIN_VALUE` then the overflow occurs. ```python import array data = [Row(longarray=array.array('l', [-9223372036854775808, 0, 9223372036854775807]))] df = spark.createDataFrame(data) df.show(truncate=False) ``` ``` +----------+ |longarray | +----------+ |[0, 0, -1]| +----------+ ``` This should be: ``` +----------------------------------------------+ |longarray | +----------------------------------------------+ |[-9223372036854775808, 0, 9223372036854775807]| +----------------------------------------------+ ``` ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #18553 from ueshin/issues/SPARK-21327. --- .../scala/org/apache/spark/api/python/SerDeUtil.scala | 10 ++++++++++ python/pyspark/sql/tests.py | 6 ++++++ 2 files changed, 16 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index 6e4eab4b805c1..42f67e8dbe865 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -73,6 +73,16 @@ private[spark] object SerDeUtil extends Logging { // This must be ISO 8859-1 / Latin 1, not UTF-8, to interoperate correctly val data = args(1).asInstanceOf[String].getBytes(StandardCharsets.ISO_8859_1) construct(typecode, machineCodes(typecode), data) + } else if (args.length == 2 && args(0) == "l") { + // On Python 2, an array of typecode 'l' should be handled as long rather than int. + val values = args(1).asInstanceOf[JArrayList[_]] + val result = new Array[Long](values.size) + var i = 0 + while (i < values.size) { + result(i) = values.get(i).asInstanceOf[Number].longValue() + i += 1 + } + result } else { super.construct(args) } diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c0e3b8d132396..9db2f40474f70 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2342,6 +2342,12 @@ def test_to_pandas(self): self.assertEquals(types[2], np.bool) self.assertEquals(types[3], np.float32) + def test_create_dataframe_from_array_of_long(self): + import array + data = [Row(longarray=array.array('l', [-9223372036854775808, 0, 9223372036854775807]))] + df = self.spark.createDataFrame(data) + self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807])) + class HiveSparkSubmitTests(SparkSubmitTests): From c09b31eb8fa83d5463a045c9278f5874ae505a8e Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 7 Jul 2017 13:09:32 +0800 Subject: [PATCH 117/779] [SPARK-21217][SQL] Support ColumnVector.Array.toArray() ## What changes were proposed in this pull request? This PR implements bulk-copy for `ColumnVector.Array.toArray()` methods (e.g. `toIntArray()`) in `ColumnVector.Array` by using `System.arrayCopy()` or `Platform.copyMemory()`. Before this PR, when one of these method is called, the generic method in `ArrayData` is called. It is not fast since element-wise copy is performed. This PR can improve performance of a benchmark program by 1.9x and 3.2x. Without this PR ``` OpenJDK 64-Bit Server VM 1.8.0_131-8u131-b11-0ubuntu1.16.04.2-b11 on Linux 4.4.0-66-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Int Array Best/Avg Time(ms) Rate(M/s) Per Row(ns) ------------------------------------------------------------------------------------------------ ON_HEAP 586 / 628 14.3 69.9 OFF_HEAP 893 / 902 9.4 106.5 ``` With this PR ``` OpenJDK 64-Bit Server VM 1.8.0_131-8u131-b11-0ubuntu1.16.04.2-b11 on Linux 4.4.0-66-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Int Array Best/Avg Time(ms) Rate(M/s) Per Row(ns) ------------------------------------------------------------------------------------------------ ON_HEAP 306 / 331 27.4 36.4 OFF_HEAP 282 / 287 29.8 33.6 ``` Source program ``` (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val len = 8 * 1024 * 1024 val column = ColumnVector.allocate(len * 2, new ArrayType(IntegerType, false), memMode) val data = column.arrayData var i = 0 while (i < len) { data.putInt(i, i) i += 1 } column.putArray(0, 0, len) val benchmark = new Benchmark("Int Array", len, minNumIters = 20) benchmark.addCase(s"$memMode") { iter => var i = 0 while (i < 50) { column.getArray(0).toIntArray i += 1 } } benchmark.run }} ``` ## How was this patch tested? Added test suite Author: Kazuaki Ishizaki Closes #18425 from kiszk/SPARK-21217. --- .../execution/vectorized/ColumnVector.java | 56 ++++++++++++++++++ .../vectorized/OffHeapColumnVector.java | 58 +++++++++++++++++++ .../vectorized/OnHeapColumnVector.java | 58 +++++++++++++++++++ .../vectorized/ColumnarBatchSuite.scala | 49 ++++++++++++++++ 4 files changed, 221 insertions(+) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 24260a60197f2..0c027f80d48cc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -100,6 +100,27 @@ public ArrayData copy() { throw new UnsupportedOperationException(); } + @Override + public boolean[] toBooleanArray() { return data.getBooleans(offset, length); } + + @Override + public byte[] toByteArray() { return data.getBytes(offset, length); } + + @Override + public short[] toShortArray() { return data.getShorts(offset, length); } + + @Override + public int[] toIntArray() { return data.getInts(offset, length); } + + @Override + public long[] toLongArray() { return data.getLongs(offset, length); } + + @Override + public float[] toFloatArray() { return data.getFloats(offset, length); } + + @Override + public double[] toDoubleArray() { return data.getDoubles(offset, length); } + // TODO: this is extremely expensive. @Override public Object[] array() { @@ -366,6 +387,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract boolean getBoolean(int rowId); + /** + * Gets values from [rowId, rowId + count) + */ + public abstract boolean[] getBooleans(int rowId, int count); + /** * Sets the value at rowId to `value`. */ @@ -386,6 +412,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract byte getByte(int rowId); + /** + * Gets values from [rowId, rowId + count) + */ + public abstract byte[] getBytes(int rowId, int count); + /** * Sets the value at rowId to `value`. */ @@ -406,6 +437,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract short getShort(int rowId); + /** + * Gets values from [rowId, rowId + count) + */ + public abstract short[] getShorts(int rowId, int count); + /** * Sets the value at rowId to `value`. */ @@ -432,6 +468,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract int getInt(int rowId); + /** + * Gets values from [rowId, rowId + count) + */ + public abstract int[] getInts(int rowId, int count); + /** * Returns the dictionary Id for rowId. * This should only be called when the ColumnVector is dictionaryIds. @@ -465,6 +506,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract long getLong(int rowId); + /** + * Gets values from [rowId, rowId + count) + */ + public abstract long[] getLongs(int rowId, int count); + /** * Sets the value at rowId to `value`. */ @@ -491,6 +537,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract float getFloat(int rowId); + /** + * Gets values from [rowId, rowId + count) + */ + public abstract float[] getFloats(int rowId, int count); + /** * Sets the value at rowId to `value`. */ @@ -517,6 +568,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract double getDouble(int rowId); + /** + * Gets values from [rowId, rowId + count) + */ + public abstract double[] getDoubles(int rowId, int count); + /** * Puts a byte array that already exists in this column. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index a7d3744d00e91..2d1f3da8e7463 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -134,6 +134,16 @@ public void putBooleans(int rowId, int count, boolean value) { @Override public boolean getBoolean(int rowId) { return Platform.getByte(null, data + rowId) == 1; } + @Override + public boolean[] getBooleans(int rowId, int count) { + assert(dictionary == null); + boolean[] array = new boolean[count]; + for (int i = 0; i < count; ++i) { + array[i] = (Platform.getByte(null, data + rowId + i) == 1); + } + return array; + } + // // APIs dealing with Bytes // @@ -165,6 +175,14 @@ public byte getByte(int rowId) { } } + @Override + public byte[] getBytes(int rowId, int count) { + assert(dictionary == null); + byte[] array = new byte[count]; + Platform.copyMemory(null, data + rowId, array, Platform.BYTE_ARRAY_OFFSET, count); + return array; + } + // // APIs dealing with shorts // @@ -197,6 +215,14 @@ public short getShort(int rowId) { } } + @Override + public short[] getShorts(int rowId, int count) { + assert(dictionary == null); + short[] array = new short[count]; + Platform.copyMemory(null, data + rowId * 2, array, Platform.SHORT_ARRAY_OFFSET, count * 2); + return array; + } + // // APIs dealing with ints // @@ -244,6 +270,14 @@ public int getInt(int rowId) { } } + @Override + public int[] getInts(int rowId, int count) { + assert(dictionary == null); + int[] array = new int[count]; + Platform.copyMemory(null, data + rowId * 4, array, Platform.INT_ARRAY_OFFSET, count * 4); + return array; + } + /** * Returns the dictionary Id for rowId. * This should only be called when the ColumnVector is dictionaryIds. @@ -302,6 +336,14 @@ public long getLong(int rowId) { } } + @Override + public long[] getLongs(int rowId, int count) { + assert(dictionary == null); + long[] array = new long[count]; + Platform.copyMemory(null, data + rowId * 8, array, Platform.LONG_ARRAY_OFFSET, count * 8); + return array; + } + // // APIs dealing with floats // @@ -348,6 +390,14 @@ public float getFloat(int rowId) { } } + @Override + public float[] getFloats(int rowId, int count) { + assert(dictionary == null); + float[] array = new float[count]; + Platform.copyMemory(null, data + rowId * 4, array, Platform.FLOAT_ARRAY_OFFSET, count * 4); + return array; + } + // // APIs dealing with doubles @@ -395,6 +445,14 @@ public double getDouble(int rowId) { } } + @Override + public double[] getDoubles(int rowId, int count) { + assert(dictionary == null); + double[] array = new double[count]; + Platform.copyMemory(null, data + rowId * 8, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8); + return array; + } + // // APIs dealing with Arrays. // diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 94ed32294cfae..506434364be48 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -130,6 +130,16 @@ public boolean getBoolean(int rowId) { return byteData[rowId] == 1; } + @Override + public boolean[] getBooleans(int rowId, int count) { + assert(dictionary == null); + boolean[] array = new boolean[count]; + for (int i = 0; i < count; ++i) { + array[i] = (byteData[rowId + i] == 1); + } + return array; + } + // // @@ -162,6 +172,14 @@ public byte getByte(int rowId) { } } + @Override + public byte[] getBytes(int rowId, int count) { + assert(dictionary == null); + byte[] array = new byte[count]; + System.arraycopy(byteData, rowId, array, 0, count); + return array; + } + // // APIs dealing with Shorts // @@ -192,6 +210,14 @@ public short getShort(int rowId) { } } + @Override + public short[] getShorts(int rowId, int count) { + assert(dictionary == null); + short[] array = new short[count]; + System.arraycopy(shortData, rowId, array, 0, count); + return array; + } + // // APIs dealing with Ints @@ -234,6 +260,14 @@ public int getInt(int rowId) { } } + @Override + public int[] getInts(int rowId, int count) { + assert(dictionary == null); + int[] array = new int[count]; + System.arraycopy(intData, rowId, array, 0, count); + return array; + } + /** * Returns the dictionary Id for rowId. * This should only be called when the ColumnVector is dictionaryIds. @@ -286,6 +320,14 @@ public long getLong(int rowId) { } } + @Override + public long[] getLongs(int rowId, int count) { + assert(dictionary == null); + long[] array = new long[count]; + System.arraycopy(longData, rowId, array, 0, count); + return array; + } + // // APIs dealing with floats // @@ -325,6 +367,14 @@ public float getFloat(int rowId) { } } + @Override + public float[] getFloats(int rowId, int count) { + assert(dictionary == null); + float[] array = new float[count]; + System.arraycopy(floatData, rowId, array, 0, count); + return array; + } + // // APIs dealing with doubles // @@ -366,6 +416,14 @@ public double getDouble(int rowId) { } } + @Override + public double[] getDoubles(int rowId, int count) { + assert(dictionary == null); + double[] array = new double[count]; + System.arraycopy(doubleData, rowId, array, 0, count); + return array; + } + // // APIs dealing with Arrays // diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 80d41577dcf2d..ccf7aa7022a2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -709,6 +709,55 @@ class ColumnarBatchSuite extends SparkFunSuite { }} } + test("toArray for primitive types") { + // (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + (MemoryMode.ON_HEAP :: Nil).foreach { memMode => { + val len = 4 + + val columnBool = ColumnVector.allocate(len, new ArrayType(BooleanType, false), memMode) + val boolArray = Array(false, true, false, true) + boolArray.zipWithIndex.map { case (v, i) => columnBool.arrayData.putBoolean(i, v) } + columnBool.putArray(0, 0, len) + assert(columnBool.getArray(0).toBooleanArray === boolArray) + + val columnByte = ColumnVector.allocate(len, new ArrayType(ByteType, false), memMode) + val byteArray = Array[Byte](0, 1, 2, 3) + byteArray.zipWithIndex.map { case (v, i) => columnByte.arrayData.putByte(i, v) } + columnByte.putArray(0, 0, len) + assert(columnByte.getArray(0).toByteArray === byteArray) + + val columnShort = ColumnVector.allocate(len, new ArrayType(ShortType, false), memMode) + val shortArray = Array[Short](0, 1, 2, 3) + shortArray.zipWithIndex.map { case (v, i) => columnShort.arrayData.putShort(i, v) } + columnShort.putArray(0, 0, len) + assert(columnShort.getArray(0).toShortArray === shortArray) + + val columnInt = ColumnVector.allocate(len, new ArrayType(IntegerType, false), memMode) + val intArray = Array(0, 1, 2, 3) + intArray.zipWithIndex.map { case (v, i) => columnInt.arrayData.putInt(i, v) } + columnInt.putArray(0, 0, len) + assert(columnInt.getArray(0).toIntArray === intArray) + + val columnLong = ColumnVector.allocate(len, new ArrayType(LongType, false), memMode) + val longArray = Array[Long](0, 1, 2, 3) + longArray.zipWithIndex.map { case (v, i) => columnLong.arrayData.putLong(i, v) } + columnLong.putArray(0, 0, len) + assert(columnLong.getArray(0).toLongArray === longArray) + + val columnFloat = ColumnVector.allocate(len, new ArrayType(FloatType, false), memMode) + val floatArray = Array(0.0F, 1.1F, 2.2F, 3.3F) + floatArray.zipWithIndex.map { case (v, i) => columnFloat.arrayData.putFloat(i, v) } + columnFloat.putArray(0, 0, len) + assert(columnFloat.getArray(0).toFloatArray === floatArray) + + val columnDouble = ColumnVector.allocate(len, new ArrayType(DoubleType, false), memMode) + val doubleArray = Array(0.0, 1.1, 2.2, 3.3) + doubleArray.zipWithIndex.map { case (v, i) => columnDouble.arrayData.putDouble(i, v) } + columnDouble.putArray(0, 0, len) + assert(columnDouble.getArray(0).toDoubleArray === doubleArray) + }} + } + test("Struct Column") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val schema = new StructType().add("int", IntegerType).add("double", DoubleType) From 5df99bd364561c6f4c02308149ba5eb71f89247e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 7 Jul 2017 13:12:20 +0800 Subject: [PATCH 118/779] [SPARK-20703][SQL][FOLLOW-UP] Associate metrics with data writes onto DataFrameWriter operations ## What changes were proposed in this pull request? Remove time metrics since it seems no way to measure it in non per-row tracking. ## How was this patch tested? Existing tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh Closes #18558 from viirya/SPARK-20703-followup. --- .../command/DataWritingCommand.scala | 10 --------- .../datasources/FileFormatWriter.scala | 22 +++---------------- .../sql/hive/execution/SQLMetricsSuite.scala | 3 --- 3 files changed, 3 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index 0c381a2c02986..700f7f81dc8a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -30,7 +30,6 @@ trait DataWritingCommand extends RunnableCommand { override lazy val metrics: Map[String, SQLMetric] = { val sparkContext = SparkContext.getActive.get Map( - "avgTime" -> SQLMetrics.createMetric(sparkContext, "average writing time (ms)"), "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"), "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "bytes of written output"), "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), @@ -47,23 +46,14 @@ trait DataWritingCommand extends RunnableCommand { var numFiles = 0 var totalNumBytes: Long = 0L var totalNumOutput: Long = 0L - var totalWritingTime: Long = 0L writeSummaries.foreach { summary => numPartitions += summary.updatedPartitions.size numFiles += summary.numOutputFile totalNumBytes += summary.numOutputBytes totalNumOutput += summary.numOutputRows - totalWritingTime += summary.totalWritingTime } - val avgWritingTime = if (numFiles > 0) { - (totalWritingTime / numFiles).toLong - } else { - 0L - } - - metrics("avgTime").add(avgWritingTime) metrics("numFiles").add(numFiles) metrics("numOutputBytes").add(totalNumBytes) metrics("numOutputRows").add(totalNumOutput) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 64866630623ab..9eb9eae699e94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -275,8 +275,6 @@ object FileFormatWriter extends Logging { /** * The data structures used to measure metrics during writing. */ - protected var totalWritingTime: Long = 0L - protected var timeOnCurrentFile: Long = 0L protected var numOutputRows: Long = 0L protected var numOutputBytes: Long = 0L @@ -343,9 +341,7 @@ object FileFormatWriter extends Logging { } val internalRow = iter.next() - val startTime = System.nanoTime() currentWriter.write(internalRow) - timeOnCurrentFile += (System.nanoTime() - startTime) recordsInFile += 1 } releaseResources() @@ -355,17 +351,13 @@ object FileFormatWriter extends Logging { updatedPartitions = Set.empty, numOutputFile = fileCounter + 1, numOutputBytes = numOutputBytes, - numOutputRows = numOutputRows, - totalWritingTime = totalWritingTime) + numOutputRows = numOutputRows) } override def releaseResources(): Unit = { if (currentWriter != null) { try { - val startTime = System.nanoTime() currentWriter.close() - totalWritingTime += (timeOnCurrentFile + System.nanoTime() - startTime) / 1000 / 1000 - timeOnCurrentFile = 0 numOutputBytes += getFileSize(taskAttemptContext.getConfiguration, currentPath) } finally { currentWriter = null @@ -504,9 +496,7 @@ object FileFormatWriter extends Logging { releaseResources() newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions) } - val startTime = System.nanoTime() currentWriter.write(getOutputRow(row)) - timeOnCurrentFile += (System.nanoTime() - startTime) recordsInFile += 1 } if (currentPartColsAndBucketId != null) { @@ -519,17 +509,13 @@ object FileFormatWriter extends Logging { updatedPartitions = updatedPartitions.toSet, numOutputFile = totalFileCounter, numOutputBytes = numOutputBytes, - numOutputRows = numOutputRows, - totalWritingTime = totalWritingTime) + numOutputRows = numOutputRows) } override def releaseResources(): Unit = { if (currentWriter != null) { try { - val startTime = System.nanoTime() currentWriter.close() - totalWritingTime += (timeOnCurrentFile + System.nanoTime() - startTime) / 1000 / 1000 - timeOnCurrentFile = 0 numOutputBytes += getFileSize(taskAttemptContext.getConfiguration, currentPath) } finally { currentWriter = null @@ -547,11 +533,9 @@ object FileFormatWriter extends Logging { * @param numOutputFile the total number of files. * @param numOutputRows the number of output rows. * @param numOutputBytes the bytes of output data. - * @param totalWritingTime the total writing time in ms. */ case class ExecutedWriteSummary( updatedPartitions: Set[String], numOutputFile: Int, numOutputRows: Long, - numOutputBytes: Long, - totalWritingTime: Long) + numOutputBytes: Long) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala index 1ef1988d4c605..24c038587d1d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala @@ -65,9 +65,6 @@ class SQLMetricsSuite extends SQLTestUtils with TestHiveSingleton { val totalNumBytesMetric = executedNode.metrics.find(_.name == "bytes of written output").get val totalNumBytes = metrics(totalNumBytesMetric.accumulatorId).replaceAll(",", "").toInt assert(totalNumBytes > 0) - val writingTimeMetric = executedNode.metrics.find(_.name == "average writing time (ms)").get - val writingTime = metrics(writingTimeMetric.accumulatorId).replaceAll(",", "").toInt - assert(writingTime >= 0) } private def testMetricsNonDynamicPartition( From 7fcbb9b57f5eba8b14bf7d86ebaa08a8ee937cd2 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Fri, 7 Jul 2017 08:31:30 +0100 Subject: [PATCH 119/779] [SPARK-21313][SS] ConsoleSink's string representation ## What changes were proposed in this pull request? Add `toString` with options for `ConsoleSink` so it shows nicely in query progress. **BEFORE** ``` "sink" : { "description" : "org.apache.spark.sql.execution.streaming.ConsoleSink4b340441" } ``` **AFTER** ``` "sink" : { "description" : "ConsoleSink[numRows=10, truncate=false]" } ``` /cc zsxwing tdas ## How was this patch tested? Local build Author: Jacek Laskowski Closes #18539 from jaceklaskowski/SPARK-21313-ConsoleSink-toString. --- .../org/apache/spark/sql/execution/streaming/ForeachSink.scala | 2 ++ .../org/apache/spark/sql/execution/streaming/console.scala | 2 ++ 2 files changed, 4 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala index de09fb568d2a6..2cc54107f8b83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala @@ -63,4 +63,6 @@ class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Seria } } } + + override def toString(): String = "ForeachSink" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 3baea6376069f..1c9284e252bd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -52,6 +52,8 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging { data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) .show(numRowsToShow, isTruncated) } + + override def toString(): String = s"ConsoleSink[numRows=$numRowsToShow, truncate=$isTruncated]" } case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) From 56536e9992ac4ea771758463962e49bba410e896 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Fri, 7 Jul 2017 18:32:01 +0800 Subject: [PATCH 120/779] [SPARK-21285][ML] VectorAssembler reports the column name of unsupported data type MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? add the column name in the exception which is raised by unsupported data type. ## How was this patch tested? + [x] pass all tests. Author: Yan Facai (颜发才) Closes #18523 from facaiy/ENH/vectorassembler_add_col. --- .../apache/spark/ml/feature/VectorAssembler.scala | 15 +++++++++------ .../spark/ml/feature/VectorAssemblerSuite.scala | 5 ++++- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index ca900536bc7b8..73f27d1a423d9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -113,12 +113,15 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) override def transformSchema(schema: StructType): StructType = { val inputColNames = $(inputCols) val outputColName = $(outputCol) - val inputDataTypes = inputColNames.map(name => schema(name).dataType) - inputDataTypes.foreach { - case _: NumericType | BooleanType => - case t if t.isInstanceOf[VectorUDT] => - case other => - throw new IllegalArgumentException(s"Data type $other is not supported.") + val incorrectColumns = inputColNames.flatMap { name => + schema(name).dataType match { + case _: NumericType | BooleanType => None + case t if t.isInstanceOf[VectorUDT] => None + case other => Some(s"Data type $other of column $name is not supported.") + } + } + if (incorrectColumns.nonEmpty) { + throw new IllegalArgumentException(incorrectColumns.mkString("\n")) } if (schema.fieldNames.contains(outputColName)) { throw new IllegalArgumentException(s"Output column $outputColName already exists.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 46cced3a9a6e5..6aef1c6837025 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -79,7 +79,10 @@ class VectorAssemblerSuite val thrown = intercept[IllegalArgumentException] { assembler.transform(df) } - assert(thrown.getMessage contains "Data type StringType is not supported") + assert(thrown.getMessage contains + "Data type StringType of column a is not supported.\n" + + "Data type StringType of column b is not supported.\n" + + "Data type StringType of column c is not supported.") } test("ML attributes") { From fef081309fc28efe8e136f363d85d7ccd9466e61 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 7 Jul 2017 20:04:30 +0800 Subject: [PATCH 121/779] [SPARK-21335][SQL] support un-aliased subquery ## What changes were proposed in this pull request? un-aliased subquery is supported by Spark SQL for a long time. Its semantic was not well defined and had confusing behaviors, and it's not a standard SQL syntax, so we disallowed it in https://issues.apache.org/jira/browse/SPARK-20690 . However, this is a breaking change, and we do have existing queries using un-aliased subquery. We should add the support back and fix its semantic. This PR fixes the un-aliased subquery by assigning a default alias name. After this PR, there is no syntax change from branch 2.2 to master, but we invalid a weird use case: `SELECT v.i from (SELECT i FROM v)`. Now this query will throw analysis exception because users should not be able to use the qualifier inside a subquery. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #18559 from cloud-fan/sub-query. --- .../sql/catalyst/parser/AstBuilder.scala | 16 ++-- .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../sql/catalyst/parser/PlanParserSuite.scala | 13 --- .../resources/sql-tests/inputs/group-by.sql | 2 +- .../test/resources/sql-tests/inputs/limit.sql | 2 +- .../sql-tests/inputs/string-functions.sql | 2 +- .../in-subquery/in-set-operations.sql | 2 +- .../negative-cases/invalid-correlation.sql | 2 +- .../scalar-subquery-predicate.sql | 2 +- .../test/resources/sql-tests/inputs/union.sql | 4 +- .../results/columnresolution-negative.sql.out | 16 ++-- .../sql-tests/results/group-by.sql.out | 2 +- .../resources/sql-tests/results/limit.sql.out | 2 +- .../results/string-functions.sql.out | 6 +- .../in-subquery/in-set-operations.sql.out | 2 +- .../invalid-correlation.sql.out | 2 +- .../scalar-subquery-predicate.sql.out | 2 +- .../results/subquery/subquery-in-from.sql.out | 20 +---- .../resources/sql-tests/results/union.sql.out | 4 +- .../apache/spark/sql/CachedTableSuite.scala | 82 ++++++------------- .../org/apache/spark/sql/SQLQuerySuite.scala | 13 +++ .../org/apache/spark/sql/SubquerySuite.scala | 8 +- 22 files changed, 83 insertions(+), 123 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index b6a4686bb9ec9..4d725904bc9b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -751,15 +751,17 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * hooks. */ override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) { - // The unaliased subqueries in the FROM clause are disallowed. Instead of rejecting it in - // parser rules, we handle it here in order to provide better error message. - if (ctx.strictIdentifier == null) { - throw new ParseException("The unaliased subqueries in the FROM clause are not supported.", - ctx) + val alias = if (ctx.strictIdentifier == null) { + // For un-aliased subqueries, use a default alias name that is not likely to conflict with + // normal subquery names, so that parent operators can only access the columns in subquery by + // unqualified names. Users can still use this special qualifier to access columns if they + // know it, but that's not recommended. + "__auto_generated_subquery_name" + } else { + ctx.strictIdentifier.getText } - aliasPlan(ctx.strictIdentifier, - plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample)) + SubqueryAlias(alias, plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 8649603b1a9f5..9b440cd99f994 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -253,7 +253,7 @@ abstract class LogicalPlan // More than one match. case ambiguousReferences => - val referenceNames = ambiguousReferences.map(_._1).mkString(", ") + val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ") throw new AnalysisException( s"Reference '$name' is ambiguous, could be: $referenceNames.") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 5b2573fa4d601..6dad097041a15 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -450,19 +450,6 @@ class PlanParserSuite extends AnalysisTest { | (select id from t0)) as u_1 """.stripMargin, plan.union(plan).union(plan).as("u_1").select('id)) - - } - - test("aliased subquery") { - val errMsg = "The unaliased subqueries in the FROM clause are not supported" - - assertEqual("select a from (select id as a from t0) tt", - table("t0").select('id.as("a")).as("tt").select('a)) - intercept("select a from (select id as a from t0)", errMsg) - - assertEqual("from (select id as a from t0) tt select a", - table("t0").select('id.as("a")).as("tt").select('a)) - intercept("from (select id as a from t0) select a", errMsg) } test("scalar sub-query") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index bc2120727dac2..1e1384549a410 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -34,7 +34,7 @@ SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a), AVG(a), VARIANCE(a), STDDEV(a), FROM testData; -- Aggregate with foldable input and multiple distinct groups. -SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) t GROUP BY a; +SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; -- Aliases in SELECT could be used in GROUP BY SELECT a AS k, COUNT(b) FROM testData GROUP BY k; diff --git a/sql/core/src/test/resources/sql-tests/inputs/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/limit.sql index df555bdc1976d..f21912a042716 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/limit.sql @@ -21,7 +21,7 @@ SELECT * FROM testdata LIMIT true; SELECT * FROM testdata LIMIT 'a'; -- limit within a subquery -SELECT * FROM (SELECT * FROM range(10) LIMIT 5) t WHERE id > 3; +SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3; -- limit ALL SELECT * FROM testdata WHERE key < 3 LIMIT ALL; diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 20c0390664037..c95f4817b7ce0 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -7,7 +7,7 @@ select 'a' || 'b' || 'c'; -- Check if catalyst combine nested `Concat`s EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col -FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)) t; +FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)); -- replace function select replace('abc', 'b', '123'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql index 42f84e9748713..5c371d2305ac8 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql @@ -394,7 +394,7 @@ FROM (SELECT * FROM t1)) t4 WHERE t4.t2b IN (SELECT Min(t3b) FROM t3 - WHERE t4.t2a = t3a)) T; + WHERE t4.t2a = t3a)); -- UNION, UNION ALL, UNION DISTINCT, INTERSECT and EXCEPT for NOT IN -- TC 01.12 diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql index f3f0c7622ccdb..e22cade936792 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql @@ -23,7 +23,7 @@ AND t2b = (SELECT max(avg) FROM (SELECT t2b, avg(t2b) avg FROM t2 WHERE t2a = t1.t1b - ) T + ) ) ; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql index dbe8d76d2f117..fb0d07fbdace7 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql @@ -19,7 +19,7 @@ AND c.cv = (SELECT max(avg) FROM (SELECT c1.cv, avg(c1.cv) avg FROM c c1 WHERE c1.ck = p.pk - GROUP BY c1.cv) T); + GROUP BY c1.cv)); create temporary view t1 as select * from values ('val1a', 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 00:00:00.000', date '2014-04-04'), diff --git a/sql/core/src/test/resources/sql-tests/inputs/union.sql b/sql/core/src/test/resources/sql-tests/inputs/union.sql index 63bc044535e4d..e57d69eaad033 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/union.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/union.sql @@ -5,7 +5,7 @@ CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (1.0, 1), (2.0, 4) tbl(c1, c2); SELECT * FROM (SELECT * FROM t1 UNION ALL - SELECT * FROM t1) T; + SELECT * FROM t1); -- Type Coerced Union SELECT * @@ -13,7 +13,7 @@ FROM (SELECT * FROM t1 UNION ALL SELECT * FROM t2 UNION ALL - SELECT * FROM t2) T; + SELECT * FROM t2); -- Regression test for SPARK-18622 SELECT a diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out index 9e60e592c2bd1..b5a4f5c2bf654 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out @@ -72,7 +72,7 @@ SELECT i1 FROM t1, mydb1.t1 struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 9 @@ -81,7 +81,7 @@ SELECT t1.i1 FROM t1, mydb1.t1 struct<> -- !query 9 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 10 @@ -99,7 +99,7 @@ SELECT i1 FROM t1, mydb2.t1 struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 12 @@ -108,7 +108,7 @@ SELECT t1.i1 FROM t1, mydb2.t1 struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 13 @@ -125,7 +125,7 @@ SELECT i1 FROM t1, mydb1.t1 struct<> -- !query 14 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 15 @@ -134,7 +134,7 @@ SELECT t1.i1 FROM t1, mydb1.t1 struct<> -- !query 15 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 16 @@ -143,7 +143,7 @@ SELECT i1 FROM t1, mydb2.t1 struct<> -- !query 16 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 17 @@ -152,7 +152,7 @@ SELECT t1.i1 FROM t1, mydb2.t1 struct<> -- !query 17 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 -- !query 18 diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index e23ebd4e822fa..986bb01c13fe4 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -134,7 +134,7 @@ struct -- !query 14 output diff --git a/sql/core/src/test/resources/sql-tests/results/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/limit.sql.out index afdd6df2a5714..146abe6cbd058 100644 --- a/sql/core/src/test/resources/sql-tests/results/limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/limit.sql.out @@ -93,7 +93,7 @@ The limit expression must be integer type, but got string; -- !query 10 -SELECT * FROM (SELECT * FROM range(10) LIMIT 5) t WHERE id > 3 +SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3 -- !query 10 schema struct -- !query 10 output diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 52eb554edf89e..b0ae9d775d968 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -30,20 +30,20 @@ abc -- !query 3 EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col -FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)) t +FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)) -- !query 3 schema struct -- !query 3 output == Parsed Logical Plan == 'Project [concat(concat(concat('col1, 'col2), 'col3), 'col4) AS col#x] -+- 'SubqueryAlias t ++- 'SubqueryAlias __auto_generated_subquery_name +- 'Project ['id AS col1#x, 'id AS col2#x, 'id AS col3#x, 'id AS col4#x] +- 'UnresolvedTableValuedFunction range, [10] == Analyzed Logical Plan == col: string Project [concat(concat(concat(cast(col1#xL as string), cast(col2#xL as string)), cast(col3#xL as string)), cast(col4#xL as string)) AS col#x] -+- SubqueryAlias t ++- SubqueryAlias __auto_generated_subquery_name +- Project [id#xL AS col1#xL, id#xL AS col2#xL, id#xL AS col3#xL, id#xL AS col4#xL] +- Range (0, 10, step=1, splits=None) diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out index 5780f49648ec7..e06f9206d3401 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out @@ -496,7 +496,7 @@ FROM (SELECT * FROM t1)) t4 WHERE t4.t2b IN (SELECT Min(t3b) FROM t3 - WHERE t4.t2a = t3a)) T + WHERE t4.t2a = t3a)) -- !query 13 schema struct -- !query 13 output diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out index ca3930b33e06d..e4b1a2dbc675c 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out @@ -40,7 +40,7 @@ AND t2b = (SELECT max(avg) FROM (SELECT t2b, avg(t2b) avg FROM t2 WHERE t2a = t1.t1b - ) T + ) ) -- !query 3 schema struct<> diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out index 1d5dddca76a17..8b29300e71f90 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out @@ -39,7 +39,7 @@ AND c.cv = (SELECT max(avg) FROM (SELECT c1.cv, avg(c1.cv) avg FROM c c1 WHERE c1.ck = p.pk - GROUP BY c1.cv) T) + GROUP BY c1.cv)) -- !query 3 schema struct -- !query 3 output diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/subquery-in-from.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/subquery-in-from.sql.out index 14553557d1ffc..50370df349168 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/subquery-in-from.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/subquery-in-from.sql.out @@ -37,26 +37,14 @@ struct -- !query 4 SELECT * FROM (SELECT * FROM testData) WHERE key = 1 -- !query 4 schema -struct<> +struct -- !query 4 output -org.apache.spark.sql.catalyst.parser.ParseException - -The unaliased subqueries in the FROM clause are not supported.(line 1, pos 14) - -== SQL == -SELECT * FROM (SELECT * FROM testData) WHERE key = 1 ---------------^^^ +1 1 -- !query 5 FROM (SELECT * FROM testData WHERE key = 1) SELECT * -- !query 5 schema -struct<> +struct -- !query 5 output -org.apache.spark.sql.catalyst.parser.ParseException - -The unaliased subqueries in the FROM clause are not supported.(line 1, pos 5) - -== SQL == -FROM (SELECT * FROM testData WHERE key = 1) SELECT * ------^^^ +1 1 diff --git a/sql/core/src/test/resources/sql-tests/results/union.sql.out b/sql/core/src/test/resources/sql-tests/results/union.sql.out index 865b3aed65d70..d123b7fdbe0cf 100644 --- a/sql/core/src/test/resources/sql-tests/results/union.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/union.sql.out @@ -22,7 +22,7 @@ struct<> SELECT * FROM (SELECT * FROM t1 UNION ALL - SELECT * FROM t1) T + SELECT * FROM t1) -- !query 2 schema struct -- !query 2 output @@ -38,7 +38,7 @@ FROM (SELECT * FROM t1 UNION ALL SELECT * FROM t2 UNION ALL - SELECT * FROM t2) T + SELECT * FROM t2) -- !query 3 schema struct -- !query 3 output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 506cc2548e260..3e4f619431599 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -631,13 +631,13 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val ds2 = sql( """ - |SELECT * FROM (SELECT max(c1) as c1 FROM t1 GROUP BY c1) tt + |SELECT * FROM (SELECT c1, max(c1) FROM t1 GROUP BY c1) |WHERE - |tt.c1 = (SELECT max(c1) FROM t2 GROUP BY c1) + |c1 = (SELECT max(c1) FROM t2 GROUP BY c1) |OR |EXISTS (SELECT c1 FROM t3) |OR - |tt.c1 IN (SELECT c1 FROM t4) + |c1 IN (SELECT c1 FROM t4) """.stripMargin) assert(getNumInMemoryRelations(ds2) == 4) } @@ -683,20 +683,15 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext Seq(1).toDF("c1").createOrReplaceTempView("t1") Seq(2).toDF("c1").createOrReplaceTempView("t2") - sql( + val sql1 = """ |SELECT * FROM t1 |WHERE |NOT EXISTS (SELECT * FROM t2) - """.stripMargin).cache() + """.stripMargin + sql(sql1).cache() - val cachedDs = - sql( - """ - |SELECT * FROM t1 - |WHERE - |NOT EXISTS (SELECT * FROM t2) - """.stripMargin) + val cachedDs = sql(sql1) assert(getNumInMemoryRelations(cachedDs) == 1) // Additional predicate in the subquery plan should cause a cache miss @@ -717,20 +712,15 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext Seq(1).toDF("c1").createOrReplaceTempView("t2") // Simple correlated predicate in subquery - sql( + val sqlText = """ |SELECT * FROM t1 |WHERE |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1) - """.stripMargin).cache() + """.stripMargin + sql(sqlText).cache() - val cachedDs = - sql( - """ - |SELECT * FROM t1 - |WHERE - |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1) - """.stripMargin) + val cachedDs = sql(sqlText) assert(getNumInMemoryRelations(cachedDs) == 1) } } @@ -741,22 +731,16 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext spark.catalog.cacheTable("t1") // underlying table t1 is cached as well as the query that refers to it. - val ds = - sql( + val sqlText = """ |SELECT * FROM t1 |WHERE |NOT EXISTS (SELECT * FROM t1) - """.stripMargin) + """.stripMargin + val ds = sql(sqlText) assert(getNumInMemoryRelations(ds) == 2) - val cachedDs = - sql( - """ - |SELECT * FROM t1 - |WHERE - |NOT EXISTS (SELECT * FROM t1) - """.stripMargin).cache() + val cachedDs = sql(sqlText).cache() assert(getNumInMemoryTablesRecursively(cachedDs.queryExecution.sparkPlan) == 3) } } @@ -769,45 +753,31 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext Seq(1).toDF("c1").createOrReplaceTempView("t4") // Nested predicate subquery - sql( + val sql1 = """ |SELECT * FROM t1 |WHERE |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1)) - """.stripMargin).cache() + """.stripMargin + sql(sql1).cache() - val cachedDs = - sql( - """ - |SELECT * FROM t1 - |WHERE - |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1)) - """.stripMargin) + val cachedDs = sql(sql1) assert(getNumInMemoryRelations(cachedDs) == 1) // Scalar subquery and predicate subquery - sql( + val sql2 = """ - |SELECT * FROM (SELECT max(c1) as c1 FROM t1 GROUP BY c1) tt + |SELECT * FROM (SELECT c1, max(c1) FROM t1 GROUP BY c1) |WHERE - |tt.c1 = (SELECT max(c1) FROM t2 GROUP BY c1) + |c1 = (SELECT max(c1) FROM t2 GROUP BY c1) |OR |EXISTS (SELECT c1 FROM t3) |OR - |tt.c1 IN (SELECT c1 FROM t4) - """.stripMargin).cache() + |c1 IN (SELECT c1 FROM t4) + """.stripMargin + sql(sql2).cache() - val cachedDs2 = - sql( - """ - |SELECT * FROM (SELECT max(c1) as c1 FROM t1 GROUP BY c1) tt - |WHERE - |tt.c1 = (SELECT max(c1) FROM t2 GROUP BY c1) - |OR - |EXISTS (SELECT c1 FROM t3) - |OR - |tt.c1 IN (SELECT c1 FROM t4) - """.stripMargin) + val cachedDs2 = sql(sql2) assert(getNumInMemoryRelations(cachedDs2) == 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5171aaebc9907..472ff7385b194 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2638,4 +2638,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-21335: support un-aliased subquery") { + withTempView("v") { + Seq(1 -> "a").toDF("i", "j").createOrReplaceTempView("v") + checkAnswer(sql("SELECT i from (SELECT i FROM v)"), Row(1)) + + val e = intercept[AnalysisException](sql("SELECT v.i from (SELECT i FROM v)")) + assert(e.message == + "cannot resolve '`v.i`' given input columns: [__auto_generated_subquery_name.i]") + + checkAnswer(sql("SELECT __auto_generated_subquery_name.i from (SELECT i FROM v)"), Row(1)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index c0a3b5add313a..7bcb419e8df6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -112,7 +112,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { | with t4 as (select 1 as d, 3 as e) | select * from t4 cross join t2 where t2.b = t4.d | ) - | select a from (select 1 as a union all select 2 as a) t + | select a from (select 1 as a union all select 2 as a) | where a = (select max(d) from t3) """.stripMargin), Array(Row(1)) @@ -606,8 +606,8 @@ class SubquerySuite extends QueryTest with SharedSQLContext { | select cntPlusOne + 1 as cntPlusTwo from ( | select cnt + 1 as cntPlusOne from ( | select sum(r.c) s, count(*) cnt from r where l.a = r.c having cnt = 0 - | ) t1 - | ) t2 + | ) + | ) |) = 2""".stripMargin), Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) } @@ -655,7 +655,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { """ | select c1 from onerow t1 | where exists (select 1 - | from (select 1 as c1 from onerow t2 LIMIT 1) t2 + | from (select c1 from onerow t2 LIMIT 1) t2 | where t1.c1=t2.c1)""".stripMargin), Row(1) :: Nil) } From fbbe37ed416f2ca9d8fc713a135b335b8a0247bf Mon Sep 17 00:00:00 2001 From: CodingCat Date: Fri, 7 Jul 2017 20:10:24 +0800 Subject: [PATCH 122/779] [SPARK-19358][CORE] LiveListenerBus shall log the event name when dropping them due to a fully filled queue ## What changes were proposed in this pull request? Some dropped event will make the whole application behaves unexpectedly, e.g. some UI problem...we shall log the dropped event name to facilitate the debugging ## How was this patch tested? Existing tests Author: CodingCat Closes #16697 from CodingCat/SPARK-19358. --- .../main/scala/org/apache/spark/scheduler/LiveListenerBus.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index f0887e090b956..0dd63d4392800 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -232,6 +232,7 @@ private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus { "This likely means one of the SparkListeners is too slow and cannot keep up with " + "the rate at which tasks are being started by the scheduler.") } + logTrace(s"Dropping event $event") } } From a0fe32a219253f0abe9d67cf178c73daf5f6fcc1 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Fri, 7 Jul 2017 15:39:29 -0700 Subject: [PATCH 123/779] [SPARK-21336] Revise rand comparison in BatchEvalPythonExecSuite ## What changes were proposed in this pull request? Revise rand comparison in BatchEvalPythonExecSuite In BatchEvalPythonExecSuite, there are two cases using the case "rand() > 3" Rand() generates a random value in [0, 1), it is wired to be compared with 3, use 0.3 instead ## How was this patch tested? unit test Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Wang Gengliang Closes #18560 from gengliangwang/revise_BatchEvalPythonExecSuite. --- .../spark/sql/execution/python/BatchEvalPythonExecSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 80ef4eb75ca53..bbd9484271a3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -65,7 +65,7 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { test("Python UDF: no push down on non-deterministic") { val df = Seq(("Hello", 4)).toDF("a", "b") - .where("b > 4 and dummyPythonUDF(a) and rand() > 3") + .where("b > 4 and dummyPythonUDF(a) and rand() > 0.3") val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { case f @ FilterExec( And(_: AttributeReference, _: GreaterThan), @@ -77,7 +77,7 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { test("Python UDF: no push down on predicates starting from the first non-deterministic") { val df = Seq(("Hello", 4)).toDF("a", "b") - .where("dummyPythonUDF(a) and rand() > 3 and b > 4") + .where("dummyPythonUDF(a) and rand() > 0.3 and b > 4") val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { case f @ FilterExec(And(_: And, _: GreaterThan), InputAdapter(_: BatchEvalPythonExec)) => f } From e1a172c201d68406faa53b113518b10c879f1ff6 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Sat, 8 Jul 2017 13:47:41 +0800 Subject: [PATCH 124/779] [SPARK-21100][SQL] Add summary method as alternative to describe that gives quartiles similar to Pandas ## What changes were proposed in this pull request? Adds method `summary` that allows user to specify which statistics and percentiles to calculate. By default it include the existing statistics from `describe` and quartiles (25th, 50th, and 75th percentiles) similar to Pandas. Also changes the implementation of `describe` to delegate to `summary`. ## How was this patch tested? additional unit test Author: Andrew Ray Closes #18307 from aray/SPARK-21100. --- .../scala/org/apache/spark/sql/Dataset.scala | 113 +++++++++++------- .../sql/execution/stat/StatFunctions.scala | 98 ++++++++++++++- .../org/apache/spark/sql/DataFrameSuite.scala | 112 +++++++++++++---- 3 files changed, 258 insertions(+), 65 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b1638a2180b07..5326b45b50a8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -38,18 +38,18 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.CatalogRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} -import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython +import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -224,7 +224,7 @@ class Dataset[T] private[sql]( } } - private def aggregatableColumns: Seq[Expression] = { + private[sql] def aggregatableColumns: Seq[Expression] = { schema.fields .filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType]) .map { n => @@ -2161,9 +2161,9 @@ class Dataset[T] private[sql]( } /** - * Computes statistics for numeric and string columns, including count, mean, stddev, min, and - * max. If no columns are given, this function computes statistics for all numerical or string - * columns. + * Computes basic statistics for numeric and string columns, including count, mean, stddev, min, + * and max. If no columns are given, this function computes statistics for all numerical or + * string columns. * * This function is meant for exploratory data analysis, as we make no guarantee about the * backward compatibility of the schema of the resulting Dataset. If you want to @@ -2181,46 +2181,79 @@ class Dataset[T] private[sql]( * // max 92.0 192.0 * }}} * + * Use [[summary]] for expanded statistics and control over which statistics to compute. + * + * @param cols Columns to compute statistics on. + * * @group action * @since 1.6.0 */ @scala.annotation.varargs - def describe(cols: String*): DataFrame = withPlan { - - // The list of summary statistics to compute, in the form of expressions. - val statistics = List[(String, Expression => Expression)]( - "count" -> ((child: Expression) => Count(child).toAggregateExpression()), - "mean" -> ((child: Expression) => Average(child).toAggregateExpression()), - "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()), - "min" -> ((child: Expression) => Min(child).toAggregateExpression()), - "max" -> ((child: Expression) => Max(child).toAggregateExpression())) - - val outputCols = - (if (cols.isEmpty) aggregatableColumns.map(usePrettyExpression(_).sql) else cols).toList - - val ret: Seq[Row] = if (outputCols.nonEmpty) { - val aggExprs = statistics.flatMap { case (_, colToAgg) => - outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) - } - - val row = groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq - - // Pivot the data so each summary is one row - row.grouped(outputCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) => - Row(statistic :: aggregation.toList: _*) - } - } else { - // If there are no output columns, just output a single column that contains the stats. - statistics.map { case (name, _) => Row(name) } - } - - // All columns are string type - val schema = StructType( - StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes - // `toArray` forces materialization to make the seq serializable - LocalRelation.fromExternalRows(schema, ret.toArray.toSeq) + def describe(cols: String*): DataFrame = { + val selected = if (cols.isEmpty) this else select(cols.head, cols.tail: _*) + selected.summary("count", "mean", "stddev", "min", "max") } + /** + * Computes specified statistics for numeric and string columns. Available statistics are: + * + * - count + * - mean + * - stddev + * - min + * - max + * - arbitrary approximate percentiles specified as a percentage (eg, 75%) + * + * If no statistics are given, this function computes count, mean, stddev, min, + * approximate quartiles (percentiles at 25%, 50%, and 75%), and max. + * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting Dataset. If you want to + * programmatically compute summary statistics, use the `agg` function instead. + * + * {{{ + * ds.summary().show() + * + * // output: + * // summary age height + * // count 10.0 10.0 + * // mean 53.3 178.05 + * // stddev 11.6 15.7 + * // min 18.0 163.0 + * // 25% 24.0 176.0 + * // 50% 24.0 176.0 + * // 75% 32.0 180.0 + * // max 92.0 192.0 + * }}} + * + * {{{ + * ds.summary("count", "min", "25%", "75%", "max").show() + * + * // output: + * // summary age height + * // count 10.0 10.0 + * // min 18.0 163.0 + * // 25% 24.0 176.0 + * // 75% 32.0 180.0 + * // max 92.0 192.0 + * }}} + * + * To do a summary for specific columns first select them: + * + * {{{ + * ds.select("age", "height").summary().show() + * }}} + * + * See also [[describe]] for basic statistics. + * + * @param statistics Statistics from above list to be computed. + * + * @group action + * @since 2.3.0 + */ + @scala.annotation.varargs + def summary(statistics: String*): DataFrame = StatFunctions.summary(this, statistics.toSeq) + /** * Returns the first `n` rows. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 1debad03c93fa..436e18fdb5ff5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.execution.stat import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.expressions.{Cast, GenericInternalRow} +import org.apache.spark.sql.catalyst.expressions.{Cast, CreateArray, Expression, GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.catalyst.util.QuantileSummaries +import org.apache.spark.sql.catalyst.util.{usePrettyExpression, QuantileSummaries} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -220,4 +221,97 @@ object StatFunctions extends Logging { Dataset.ofRows(df.sparkSession, LocalRelation(schema.toAttributes, table)).na.fill(0.0) } + + /** Calculate selected summary statistics for a dataset */ + def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = { + + val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max") + val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics + + val hasPercentiles = selectedStatistics.exists(_.endsWith("%")) + val (percentiles, percentileNames, remainingAggregates) = if (hasPercentiles) { + val (pStrings, rest) = selectedStatistics.partition(a => a.endsWith("%")) + val percentiles = pStrings.map { p => + try { + p.stripSuffix("%").toDouble / 100.0 + } catch { + case e: NumberFormatException => + throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e) + } + } + require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") + (percentiles, pStrings, rest) + } else { + (Seq(), Seq(), selectedStatistics) + } + + + // The list of summary statistics to compute, in the form of expressions. + val availableStatistics = Map[String, Expression => Expression]( + "count" -> ((child: Expression) => Count(child).toAggregateExpression()), + "mean" -> ((child: Expression) => Average(child).toAggregateExpression()), + "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()), + "min" -> ((child: Expression) => Min(child).toAggregateExpression()), + "max" -> ((child: Expression) => Max(child).toAggregateExpression())) + + val statisticFns = remainingAggregates.map { agg => + require(availableStatistics.contains(agg), s"$agg is not a recognised statistic") + agg -> availableStatistics(agg) + } + + def percentileAgg(child: Expression): Expression = + new ApproximatePercentile(child, CreateArray(percentiles.map(Literal(_)))) + .toAggregateExpression() + + val outputCols = ds.aggregatableColumns.map(usePrettyExpression(_).sql).toList + + val ret: Seq[Row] = if (outputCols.nonEmpty) { + var aggExprs = statisticFns.toList.flatMap { case (_, colToAgg) => + outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) + } + if (hasPercentiles) { + aggExprs = outputCols.map(c => Column(percentileAgg(Column(c).expr)).as(c)) ++ aggExprs + } + + val row = ds.groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq + + // Pivot the data so each summary is one row + val grouped: Seq[Seq[Any]] = row.grouped(outputCols.size).toSeq + + val basicStats = if (hasPercentiles) grouped.tail else grouped + + val rows = basicStats.zip(statisticFns).map { case (aggregation, (statistic, _)) => + Row(statistic :: aggregation.toList: _*) + } + + if (hasPercentiles) { + def nullSafeString(x: Any) = if (x == null) null else x.toString + val percentileRows = grouped.head + .map { + case a: Seq[Any] => a + case _ => Seq.fill(percentiles.length)(null: Any) + } + .transpose + .zip(percentileNames) + .map { case (values: Seq[Any], name) => + Row(name :: values.map(nullSafeString).toList: _*) + } + (rows ++ percentileRows) + .sortWith((left, right) => + selectedStatistics.indexOf(left(0)) < selectedStatistics.indexOf(right(0))) + } else { + rows + } + } else { + // If there are no output columns, just output a single column that contains the stats. + selectedStatistics.map(Row(_)) + } + + // All columns are string type + val schema = StructType( + StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes + // `toArray` forces materialization to make the seq serializable + Dataset.ofRows(ds.sparkSession, LocalRelation.fromExternalRows(schema, ret.toArray.toSeq)) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 9ea9951c24ef1..2c7051bf431c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -28,8 +28,7 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Project, Union} -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} import org.apache.spark.sql.execution.{FilterExec, QueryExecution} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} @@ -663,13 +662,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol")) } - test("describe") { - val describeTestData = Seq( - ("Bob", 16, 176), - ("Alice", 32, 164), - ("David", 60, 192), - ("Amy", 24, 180)).toDF("name", "age", "height") + private lazy val person2: DataFrame = Seq( + ("Bob", 16, 176), + ("Alice", 32, 164), + ("David", 60, 192), + ("Amy", 24, 180)).toDF("name", "age", "height") + test("describe") { val describeResult = Seq( Row("count", "4", "4", "4"), Row("mean", null, "33.0", "178.0"), @@ -686,32 +685,99 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) - val describeTwoCols = describeTestData.describe("name", "age", "height") - assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "name", "age", "height")) - checkAnswer(describeTwoCols, describeResult) - // All aggregate value should have been cast to string - describeTwoCols.collect().foreach { row => - assert(row.get(2).isInstanceOf[String], "expected string but found " + row.get(2).getClass) - assert(row.get(3).isInstanceOf[String], "expected string but found " + row.get(3).getClass) - } - - val describeAllCols = describeTestData.describe() + val describeAllCols = person2.describe() assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height")) checkAnswer(describeAllCols, describeResult) + // All aggregate value should have been cast to string + describeAllCols.collect().foreach { row => + row.toSeq.foreach { value => + if (value != null) { + assert(value.isInstanceOf[String], "expected string but found " + value.getClass) + } + } + } - val describeOneCol = describeTestData.describe("age") + val describeOneCol = person2.describe("age") assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d)} ) - val describeNoCol = describeTestData.select("name").describe() - assert(getSchemaAsSeq(describeNoCol) === Seq("summary", "name")) - checkAnswer(describeNoCol, describeResult.map { case Row(s, n, _, _) => Row(s, n)} ) + val describeNoCol = person2.select().describe() + assert(getSchemaAsSeq(describeNoCol) === Seq("summary")) + checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _, _) => Row(s)} ) - val emptyDescription = describeTestData.limit(0).describe() + val emptyDescription = person2.limit(0).describe() assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height")) checkAnswer(emptyDescription, emptyDescribeResult) } + test("summary") { + val summaryResult = Seq( + Row("count", "4", "4", "4"), + Row("mean", null, "33.0", "178.0"), + Row("stddev", null, "19.148542155126762", "11.547005383792516"), + Row("min", "Alice", "16", "164"), + Row("25%", null, "24.0", "176.0"), + Row("50%", null, "24.0", "176.0"), + Row("75%", null, "32.0", "180.0"), + Row("max", "David", "60", "192")) + + val emptySummaryResult = Seq( + Row("count", "0", "0", "0"), + Row("mean", null, null, null), + Row("stddev", null, null, null), + Row("min", null, null, null), + Row("25%", null, null, null), + Row("50%", null, null, null), + Row("75%", null, null, null), + Row("max", null, null, null)) + + def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) + + val summaryAllCols = person2.summary() + + assert(getSchemaAsSeq(summaryAllCols) === Seq("summary", "name", "age", "height")) + checkAnswer(summaryAllCols, summaryResult) + // All aggregate value should have been cast to string + summaryAllCols.collect().foreach { row => + row.toSeq.foreach { value => + if (value != null) { + assert(value.isInstanceOf[String], "expected string but found " + value.getClass) + } + } + } + + val summaryOneCol = person2.select("age").summary() + assert(getSchemaAsSeq(summaryOneCol) === Seq("summary", "age")) + checkAnswer(summaryOneCol, summaryResult.map { case Row(s, _, d, _) => Row(s, d)} ) + + val summaryNoCol = person2.select().summary() + assert(getSchemaAsSeq(summaryNoCol) === Seq("summary")) + checkAnswer(summaryNoCol, summaryResult.map { case Row(s, _, _, _) => Row(s)} ) + + val emptyDescription = person2.limit(0).summary() + assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height")) + checkAnswer(emptyDescription, emptySummaryResult) + } + + test("summary advanced") { + val stats = Array("count", "50.01%", "max", "mean", "min", "25%") + val orderMatters = person2.summary(stats: _*) + assert(orderMatters.collect().map(_.getString(0)) === stats) + + val onlyPercentiles = person2.summary("0.1%", "99.9%") + assert(onlyPercentiles.count() === 2) + + val fooE = intercept[IllegalArgumentException] { + person2.summary("foo") + } + assert(fooE.getMessage === "requirement failed: foo is not a recognised statistic") + + val parseE = intercept[IllegalArgumentException] { + person2.summary("foo%") + } + assert(parseE.getMessage === "Unable to parse foo% as a percentile") + } + test("apply on query results (SPARK-5462)") { val df = testData.sparkSession.sql("select key from testData") checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq) From 7896e7b99d95d28800f5644bd36b3990cf0ef8c4 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 7 Jul 2017 23:05:38 -0700 Subject: [PATCH 125/779] [SPARK-21281][SQL] Use string types by default if array and map have no argument ## What changes were proposed in this pull request? This pr modified code to use string types by default if `array` and `map` in functions have no argument. This behaviour is the same with Hive one; ``` hive> CREATE TEMPORARY TABLE t1 AS SELECT map(); hive> DESCRIBE t1; _c0 map hive> CREATE TEMPORARY TABLE t2 AS SELECT array(); hive> DESCRIBE t2; _c0 array ``` ## How was this patch tested? Added tests in `DataFrameFunctionsSuite`. Author: Takeshi Yamamuro Closes #18516 from maropu/SPARK-21281. --- .../sql/catalyst/expressions/arithmetic.scala | 10 +++--- .../expressions/complexTypeCreator.scala | 35 ++++++++++-------- .../spark/sql/catalyst/expressions/hash.scala | 5 +-- .../expressions/nullExpressions.scala | 7 ++-- .../ExpressionTypeCheckingSuite.scala | 4 +-- .../org/apache/spark/sql/functions.scala | 10 ++---- .../spark/sql/DataFrameFunctionsSuite.scala | 36 +++++++++++++++++++ 7 files changed, 74 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index ec6e6ba0f091b..423bf66a24d1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -527,13 +527,14 @@ case class Least(children: Seq[Expression]) extends Expression { override def checkInputDataTypes(): TypeCheckResult = { if (children.length <= 1) { - TypeCheckResult.TypeCheckFailure(s"LEAST requires at least 2 arguments") + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName requires at least two arguments") } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).") } else { - TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName) + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") } } @@ -592,13 +593,14 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def checkInputDataTypes(): TypeCheckResult = { if (children.length <= 1) { - TypeCheckResult.TypeCheckFailure(s"GREATEST requires at least 2 arguments") + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName requires at least two arguments") } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).") } else { - TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName) + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 98c4cbee38dee..d9eeb5358ef79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -41,12 +41,13 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array") + override def checkInputDataTypes(): TypeCheckResult = { + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") + } override def dataType: ArrayType = { ArrayType( - children.headOption.map(_.dataType).getOrElse(NullType), + children.headOption.map(_.dataType).getOrElse(StringType), containsNull = children.exists(_.nullable)) } @@ -93,7 +94,7 @@ private [sql] object GenArrayData { if (!ctx.isPrimitiveType(elementType)) { val genericArrayClass = classOf[GenericArrayData].getName ctx.addMutableState("Object[]", arrayName, - s"$arrayName = new Object[${numElements}];") + s"$arrayName = new Object[$numElements];") val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { @@ -119,7 +120,7 @@ private [sql] object GenArrayData { UnsafeArrayData.calculateHeaderPortionInBytes(numElements) + ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET - ctx.addMutableState("UnsafeArrayData", arrayDataName, ""); + ctx.addMutableState("UnsafeArrayData", arrayDataName, "") val primitiveValueTypeName = ctx.primitiveTypeName(elementType) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => @@ -169,13 +170,16 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def checkInputDataTypes(): TypeCheckResult = { if (children.size % 2 != 0) { - TypeCheckResult.TypeCheckFailure(s"$prettyName expects a positive even number of arguments.") + TypeCheckResult.TypeCheckFailure( + s"$prettyName expects a positive even number of arguments.") } else if (keys.map(_.dataType).distinct.length > 1) { - TypeCheckResult.TypeCheckFailure("The given keys of function map should all be the same " + - "type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]")) + TypeCheckResult.TypeCheckFailure( + "The given keys of function map should all be the same type, but they are " + + keys.map(_.dataType.simpleString).mkString("[", ", ", "]")) } else if (values.map(_.dataType).distinct.length > 1) { - TypeCheckResult.TypeCheckFailure("The given values of function map should all be the same " + - "type, but they are " + values.map(_.dataType.simpleString).mkString("[", ", ", "]")) + TypeCheckResult.TypeCheckFailure( + "The given values of function map should all be the same type, but they are " + + values.map(_.dataType.simpleString).mkString("[", ", ", "]")) } else { TypeCheckResult.TypeCheckSuccess } @@ -183,8 +187,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def dataType: DataType = { MapType( - keyType = keys.headOption.map(_.dataType).getOrElse(NullType), - valueType = values.headOption.map(_.dataType).getOrElse(NullType), + keyType = keys.headOption.map(_.dataType).getOrElse(StringType), + valueType = values.headOption.map(_.dataType).getOrElse(StringType), valueContainsNull = values.exists(_.nullable)) } @@ -292,14 +296,17 @@ trait CreateNamedStructLike extends Expression { } override def checkInputDataTypes(): TypeCheckResult = { - if (children.size % 2 != 0) { + if (children.length < 1) { + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName requires at least one argument") + } else if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") } else { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( "Only foldable StringType expressions are allowed to appear at odd position, got:" + - s" ${invalidNames.mkString(",")}") + s" ${invalidNames.mkString(",")}") } else if (!names.contains(null)) { TypeCheckResult.TypeCheckSuccess } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index ffd0e64d86cff..2476fc962a6fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -247,8 +247,9 @@ abstract class HashExpression[E] extends Expression { override def nullable: Boolean = false override def checkInputDataTypes(): TypeCheckResult = { - if (children.isEmpty) { - TypeCheckResult.TypeCheckFailure("function hash requires at least one argument") + if (children.length < 1) { + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName requires at least one argument") } else { TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 0866b8d791e01..1b625141d56ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -52,10 +52,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) override def checkInputDataTypes(): TypeCheckResult = { - if (children == Nil) { - TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty") + if (children.length < 1) { + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName requires at least one argument") } else { - TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function coalesce") + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 30459f173ab52..30725773a37b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -155,7 +155,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { "input to function array should all be the same type") assertError(Coalesce(Seq('intField, 'booleanField)), "input to function coalesce should all be the same type") - assertError(Coalesce(Nil), "input to function coalesce cannot be empty") + assertError(Coalesce(Nil), "function coalesce requires at least one argument") assertError(new Murmur3Hash(Nil), "function hash requires at least one argument") assertError(Explode('intField), "input to function explode should be array or map type") @@ -207,7 +207,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { test("check types for Greatest/Least") { for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { - assertError(operator(Seq('booleanField)), "requires at least 2 arguments") + assertError(operator(Seq('booleanField)), "requires at least two arguments") assertError(operator(Seq('intField, 'stringField)), "should all have the same type") assertError(operator(Seq('mapField, 'mapField)), "does not support ordering") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3c67960d13e09..1263071a3ffd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1565,10 +1565,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def greatest(exprs: Column*): Column = withExpr { - require(exprs.length > 1, "greatest requires at least 2 arguments.") - Greatest(exprs.map(_.expr)) - } + def greatest(exprs: Column*): Column = withExpr { Greatest(exprs.map(_.expr)) } /** * Returns the greatest value of the list of column names, skipping null values. @@ -1672,10 +1669,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def least(exprs: Column*): Column = withExpr { - require(exprs.length > 1, "least requires at least 2 arguments.") - Least(exprs.map(_.expr)) - } + def least(exprs: Column*): Column = withExpr { Least(exprs.map(_.expr)) } /** * Returns the least value of the list of column names, skipping null values. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 0e9a2c6cf7dec..0681b9cbeb1d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -448,6 +448,42 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { rand(Random.nextLong()), randn(Random.nextLong()) ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) } + + test("SPARK-21281 use string types by default if array and map have no argument") { + val ds = spark.range(1) + var expectedSchema = new StructType() + .add("x", ArrayType(StringType, containsNull = false), nullable = false) + assert(ds.select(array().as("x")).schema == expectedSchema) + expectedSchema = new StructType() + .add("x", MapType(StringType, StringType, valueContainsNull = false), nullable = false) + assert(ds.select(map().as("x")).schema == expectedSchema) + } + + test("SPARK-21281 fails if functions have no argument") { + val df = Seq(1).toDF("a") + + val funcsMustHaveAtLeastOneArg = + ("coalesce", (df: DataFrame) => df.select(coalesce())) :: + ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) :: + ("named_struct", (df: DataFrame) => df.select(struct())) :: + ("named_struct", (df: DataFrame) => df.selectExpr("named_struct()")) :: + ("hash", (df: DataFrame) => df.select(hash())) :: + ("hash", (df: DataFrame) => df.selectExpr("hash()")) :: Nil + funcsMustHaveAtLeastOneArg.foreach { case (name, func) => + val errMsg = intercept[AnalysisException] { func(df) }.getMessage + assert(errMsg.contains(s"input to function $name requires at least one argument")) + } + + val funcsMustHaveAtLeastTwoArgs = + ("greatest", (df: DataFrame) => df.select(greatest())) :: + ("greatest", (df: DataFrame) => df.selectExpr("greatest()")) :: + ("least", (df: DataFrame) => df.select(least())) :: + ("least", (df: DataFrame) => df.selectExpr("least()")) :: Nil + funcsMustHaveAtLeastTwoArgs.foreach { case (name, func) => + val errMsg = intercept[AnalysisException] { func(df) }.getMessage + assert(errMsg.contains(s"input to function $name requires at least two arguments")) + } + } } object DataFrameFunctionsSuite { From 9760c15acbcf755dd5b13597ceb333576f806ecf Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sat, 8 Jul 2017 14:20:09 +0800 Subject: [PATCH 126/779] [SPARK-20379][CORE] Allow SSL config to reference env variables. This change exposes the internal code path in SparkConf that allows configs to be read with variable substitution applied, and uses that new method in SSLOptions so that SSL configs can reference other variables, and more importantly, environment variables, providing a secure way to provide passwords to Spark when using SSL. The approach is a little bit hacky, but is the smallest change possible. Otherwise, the concept of "namespaced configs" would have to be added to the config system, which would create a lot of noise for not much gain at this point. Tested with added unit tests, and on a real cluster with SSL enabled. Author: Marcelo Vanzin Closes #18394 from vanzin/SPARK-20379.try2. --- .../scala/org/apache/spark/SSLOptions.scala | 20 +++++++++---------- .../scala/org/apache/spark/SparkConf.scala | 5 +++++ .../org/apache/spark/SSLOptionsSuite.scala | 16 +++++++++++++++ 3 files changed, 31 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 29163e7f30546..f86fd20e59190 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -167,39 +167,39 @@ private[spark] object SSLOptions extends Logging { def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = { val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled)) - val port = conf.getOption(s"$ns.port").map(_.toInt) + val port = conf.getWithSubstitution(s"$ns.port").map(_.toInt) port.foreach { p => require(p >= 0, "Port number must be a non-negative value.") } - val keyStore = conf.getOption(s"$ns.keyStore").map(new File(_)) + val keyStore = conf.getWithSubstitution(s"$ns.keyStore").map(new File(_)) .orElse(defaults.flatMap(_.keyStore)) - val keyStorePassword = conf.getOption(s"$ns.keyStorePassword") + val keyStorePassword = conf.getWithSubstitution(s"$ns.keyStorePassword") .orElse(defaults.flatMap(_.keyStorePassword)) - val keyPassword = conf.getOption(s"$ns.keyPassword") + val keyPassword = conf.getWithSubstitution(s"$ns.keyPassword") .orElse(defaults.flatMap(_.keyPassword)) - val keyStoreType = conf.getOption(s"$ns.keyStoreType") + val keyStoreType = conf.getWithSubstitution(s"$ns.keyStoreType") .orElse(defaults.flatMap(_.keyStoreType)) val needClientAuth = conf.getBoolean(s"$ns.needClientAuth", defaultValue = defaults.exists(_.needClientAuth)) - val trustStore = conf.getOption(s"$ns.trustStore").map(new File(_)) + val trustStore = conf.getWithSubstitution(s"$ns.trustStore").map(new File(_)) .orElse(defaults.flatMap(_.trustStore)) - val trustStorePassword = conf.getOption(s"$ns.trustStorePassword") + val trustStorePassword = conf.getWithSubstitution(s"$ns.trustStorePassword") .orElse(defaults.flatMap(_.trustStorePassword)) - val trustStoreType = conf.getOption(s"$ns.trustStoreType") + val trustStoreType = conf.getWithSubstitution(s"$ns.trustStoreType") .orElse(defaults.flatMap(_.trustStoreType)) - val protocol = conf.getOption(s"$ns.protocol") + val protocol = conf.getWithSubstitution(s"$ns.protocol") .orElse(defaults.flatMap(_.protocol)) - val enabledAlgorithms = conf.getOption(s"$ns.enabledAlgorithms") + val enabledAlgorithms = conf.getWithSubstitution(s"$ns.enabledAlgorithms") .map(_.split(",").map(_.trim).filter(_.nonEmpty).toSet) .orElse(defaults.map(_.enabledAlgorithms)) .getOrElse(Set.empty) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index de2f475c6895f..715cfdcc8f4ef 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -373,6 +373,11 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) } + /** Get an optional value, applying variable substitution. */ + private[spark] def getWithSubstitution(key: String): Option[String] = { + getOption(key).map(reader.substitute(_)) + } + /** Get all parameters as a list of pairs */ def getAll: Array[(String, String)] = { settings.entrySet().asScala.map(x => (x.getKey, x.getValue)).toArray diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 6fc7cea6ee94a..8eabc2b3cb958 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -22,6 +22,8 @@ import javax.net.ssl.SSLContext import org.scalatest.BeforeAndAfterAll +import org.apache.spark.util.SparkConfWithEnv + class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { test("test resolving property file as spark conf ") { @@ -133,4 +135,18 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(opts.enabledAlgorithms === Set("ABC", "DEF")) } + test("variable substitution") { + val conf = new SparkConfWithEnv(Map( + "ENV1" -> "val1", + "ENV2" -> "val2")) + + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ssl.keyStore", "${env:ENV1}") + conf.set("spark.ssl.trustStore", "${env:ENV2}") + + val opts = SSLOptions.parse(conf, "spark.ssl", defaults = None) + assert(opts.keyStore === Some(new File("val1"))) + assert(opts.trustStore === Some(new File("val2"))) + } + } From d0bfc6733521709e453d643582df2bdd68f28de7 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Fri, 7 Jul 2017 23:33:12 -0700 Subject: [PATCH 127/779] [SPARK-21069][SS][DOCS] Add rate source to programming guide. ## What changes were proposed in this pull request? SPARK-20979 added a new structured streaming source: Rate source. This patch adds the corresponding documentation to programming guide. ## How was this patch tested? Tested by running jekyll locally. Author: Prashant Sharma Author: Prashant Sharma Closes #18562 from ScrapCodes/spark-21069/rate-source-docs. --- docs/structured-streaming-programming-guide.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 3bc377c9a38b5..8f64faadc32dc 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -499,6 +499,8 @@ There are a few built-in sources. - **Socket source (for testing)** - Reads UTF8 text data from a socket connection. The listening server socket is at the driver. Note that this should be used only for testing as this does not provide end-to-end fault-tolerance guarantees. + - **Rate source (for testing)** - Generates data at the specified number of rows per second, each output row contains a `timestamp` and `value`. Where `timestamp` is a `Timestamp` type containing the time of message dispatch, and `value` is of `Long` type containing the message count, starting from 0 as the first row. This source is intended for testing and benchmarking. + Some sources are not fault-tolerant because they do not guarantee that data can be replayed using checkpointed offsets after a failure. See the earlier section on [fault-tolerance semantics](#fault-tolerance-semantics). @@ -546,6 +548,19 @@ Here are the details of all the sources in Spark. No + + Rate Source + + rowsPerSecond (e.g. 100, default: 1): How many rows should be generated per second.

+ rampUpTime (e.g. 5s, default: 0s): How long to ramp up before the generating speed becomes rowsPerSecond. Using finer granularities than seconds will be truncated to integer seconds.

+ numPartitions (e.g. 10, default: Spark's default parallelism): The partition number for the generated rows.

+ + The source will try its best to reach rowsPerSecond, but the query may be resource constrained, and numPartitions can be tweaked to help reach the desired speed. + + Yes + + + Kafka Source From a7b46c627b5d2461257f337139a29f23350e0c77 Mon Sep 17 00:00:00 2001 From: wangmiao1981 Date: Fri, 7 Jul 2017 23:51:32 -0700 Subject: [PATCH 128/779] [SPARK-20307][SPARKR] SparkR: pass on setHandleInvalid to spark.mllib functions that use StringIndexer ## What changes were proposed in this pull request? For randomForest classifier, if test data contains unseen labels, it will throw an error. The StringIndexer already has the handleInvalid logic. The patch add a new method to set the underlying StringIndexer handleInvalid logic. This patch should also apply to other classifiers. This PR focuses on the main logic and randomForest classifier. I will do follow-up PR for other classifiers. ## How was this patch tested? Add a new unit test based on the error case in the JIRA. Author: wangmiao1981 Closes #18496 from wangmiao1981/handle. --- R/pkg/R/mllib_tree.R | 11 ++++++-- R/pkg/tests/fulltests/test_mllib_tree.R | 17 +++++++++++++ .../apache/spark/ml/feature/RFormula.scala | 25 +++++++++++++++++++ .../r/RandomForestClassificationWrapper.scala | 4 ++- .../spark/ml/feature/StringIndexerSuite.scala | 2 +- 5 files changed, 55 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 2f1220a752783..75b1a74ee8c7c 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -374,6 +374,10 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in classification model. +#' Supported options: "skip" (filter out rows with invalid data), +#' "error" (throw an error), "keep" (put invalid data in a special additional +#' bucket, at index numLabels). Default is "error". #' @param ... additional arguments passed to the method. #' @aliases spark.randomForest,SparkDataFrame,formula-method #' @return \code{spark.randomForest} returns a fitted Random Forest model. @@ -409,7 +413,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL, featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, - maxMemoryInMB = 256, cacheNodeIds = FALSE) { + maxMemoryInMB = 256, cacheNodeIds = FALSE, + handleInvalid = c("error", "keep", "skip")) { type <- match.arg(type) formula <- paste(deparse(formula), collapse = "") if (!is.null(seed)) { @@ -430,6 +435,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo new("RandomForestRegressionModel", jobj = jobj) }, classification = { + handleInvalid <- match.arg(handleInvalid) if (is.null(impurity)) impurity <- "gini" impurity <- match.arg(impurity, c("gini", "entropy")) jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper", @@ -439,7 +445,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo as.numeric(minInfoGain), as.integer(checkpointInterval), as.character(featureSubsetStrategy), seed, as.numeric(subsamplingRate), - as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + as.integer(maxMemoryInMB), as.logical(cacheNodeIds), + handleInvalid) new("RandomForestClassificationModel", jobj = jobj) } ) diff --git a/R/pkg/tests/fulltests/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R index 9b3fc8d270b25..66a0693a59a52 100644 --- a/R/pkg/tests/fulltests/test_mllib_tree.R +++ b/R/pkg/tests/fulltests/test_mllib_tree.R @@ -212,6 +212,23 @@ test_that("spark.randomForest", { expect_equal(length(grep("1.0", predictions)), 50) expect_equal(length(grep("2.0", predictions)), 50) + # Test unseen labels + data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE), + someString = base::sample(c("this", "that"), 10, replace = TRUE), + stringsAsFactors = FALSE) + trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) + traindf <- as.DataFrame(data[trainidxs, ]) + testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) + model <- spark.randomForest(traindf, clicked ~ ., type = "classification", + maxDepth = 10, maxBins = 10, numTrees = 10) + predictions <- predict(model, testdf) + expect_error(collect(predictions)) + model <- spark.randomForest(traindf, clicked ~ ., type = "classification", + maxDepth = 10, maxBins = 10, numTrees = 10, + handleInvalid = "skip") + predictions <- predict(model, testdf) + expect_equal(class(collect(predictions)$clicked[1]), "character") + # spark.randomForest classification can work on libsvm data if (windows_with_hadoop()) { data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 4b44878784c90..61aa6463bb6da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -132,6 +132,30 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) @Since("1.5.0") def getFormula: String = $(formula) + /** + * Param for how to handle invalid data (unseen labels or NULL values). + * Options are 'skip' (filter out rows with invalid data), + * 'error' (throw an error), or 'keep' (put invalid data in a special additional + * bucket, at index numLabels). + * Default: "error" + * @group param + */ + @Since("2.3.0") + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle " + + "invalid data (unseen labels or NULL values). " + + "Options are 'skip' (filter out rows with invalid data), error (throw an error), " + + "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", + ParamValidators.inArray(StringIndexer.supportedHandleInvalids)) + setDefault(handleInvalid, StringIndexer.ERROR_INVALID) + + /** @group setParam */ + @Since("2.3.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + /** @group getParam */ + @Since("2.3.0") + def getHandleInvalid: String = $(handleInvalid) + /** @group setParam */ @Since("1.5.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) @@ -197,6 +221,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) .setInputCol(term) .setOutputCol(indexCol) .setStringOrderType($(stringIndexerOrderType)) + .setHandleInvalid($(handleInvalid)) prefixesToRewrite(indexCol + "_") = term + "_" (term, indexCol) case _ => diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala index 8a83d4e980f7b..132345fb9a6d9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -78,11 +78,13 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC seed: String, subsamplingRate: Double, maxMemoryInMB: Int, - cacheNodeIds: Boolean): RandomForestClassifierWrapper = { + cacheNodeIds: Boolean, + handleInvalid: String): RandomForestClassifierWrapper = { val rFormula = new RFormula() .setFormula(formula) .setForceIndexLabel(true) + .setHandleInvalid(handleInvalid) checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 806a92760c8b6..027b1fbc6657c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col From f5f02d213d3151f58070e113d64fcded4f5d401e Mon Sep 17 00:00:00 2001 From: Michael Patterson Date: Fri, 7 Jul 2017 23:59:34 -0700 Subject: [PATCH 129/779] [SPARK-20456][DOCS] Add examples for functions collection for pyspark ## What changes were proposed in this pull request? This adds documentation to many functions in pyspark.sql.functions.py: `upper`, `lower`, `reverse`, `unix_timestamp`, `from_unixtime`, `rand`, `randn`, `collect_list`, `collect_set`, `lit` Add units to the trigonometry functions. Renames columns in datetime examples to be more informative. Adds links between some functions. ## How was this patch tested? `./dev/lint-python` `python python/pyspark/sql/functions.py` `./python/run-tests.py --module pyspark-sql` Author: Michael Patterson Closes #17865 from map222/spark-20456. --- R/pkg/R/functions.R | 11 +- python/pyspark/sql/functions.py | 166 +++++++++++------- .../org/apache/spark/sql/functions.scala | 14 +- 3 files changed, 119 insertions(+), 72 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index c529d83060f50..f28d26a51baa0 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -336,7 +336,8 @@ setMethod("asin", }) #' @details -#' \code{atan}: Computes the tangent inverse of the given value. +#' \code{atan}: Computes the tangent inverse of the given value; the returned angle is in the range +#' -pi/2 through pi/2. #' #' @rdname column_math_functions #' @export @@ -599,7 +600,7 @@ setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr }) #' @details -#' \code{cos}: Computes the cosine of the given value. +#' \code{cos}: Computes the cosine of the given value. Units in radians. #' #' @rdname column_math_functions #' @aliases cos cos,Column-method @@ -1407,7 +1408,7 @@ setMethod("sign", signature(x = "Column"), }) #' @details -#' \code{sin}: Computes the sine of the given value. +#' \code{sin}: Computes the sine of the given value. Units in radians. #' #' @rdname column_math_functions #' @aliases sin sin,Column-method @@ -1597,7 +1598,7 @@ setMethod("sumDistinct", }) #' @details -#' \code{tan}: Computes the tangent of the given value. +#' \code{tan}: Computes the tangent of the given value. Units in radians. #' #' @rdname column_math_functions #' @aliases tan tan,Column-method @@ -1896,7 +1897,7 @@ setMethod("year", #' @details #' \code{atan2}: Returns the angle theta from the conversion of rectangular coordinates -#' (x, y) to polar coordinates (r, theta). +#' (x, y) to polar coordinates (r, theta). Units in radians. #' #' @rdname column_math_functions #' @aliases atan2 atan2,Column-method diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3416c4b118a07..5d8ded83f667d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -67,9 +67,14 @@ def _(): _.__doc__ = 'Window function: ' + doc return _ +_lit_doc = """ + Creates a :class:`Column` of literal value. + >>> df.select(lit(5).alias('height')).withColumn('spark_user', lit(True)).take(1) + [Row(height=5, spark_user=True)] + """ _functions = { - 'lit': 'Creates a :class:`Column` of literal value.', + 'lit': _lit_doc, 'col': 'Returns a :class:`Column` based on the given column name.', 'column': 'Returns a :class:`Column` based on the given column name.', 'asc': 'Returns a sort expression based on the ascending order of the given column name.', @@ -95,10 +100,13 @@ def _(): '0.0 through pi.', 'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' + '-pi/2 through pi/2.', - 'atan': 'Computes the tangent inverse of the given value.', + 'atan': 'Computes the tangent inverse of the given value; the returned angle is in the range' + + '-pi/2 through pi/2', 'cbrt': 'Computes the cube-root of the given value.', 'ceil': 'Computes the ceiling of the given value.', - 'cos': 'Computes the cosine of the given value.', + 'cos': """Computes the cosine of the given value. + + :param col: :class:`DoubleType` column, units in radians.""", 'cosh': 'Computes the hyperbolic cosine of the given value.', 'exp': 'Computes the exponential of the given value.', 'expm1': 'Computes the exponential of the given value minus one.', @@ -109,15 +117,33 @@ def _(): 'rint': 'Returns the double value that is closest in value to the argument and' + ' is equal to a mathematical integer.', 'signum': 'Computes the signum of the given value.', - 'sin': 'Computes the sine of the given value.', + 'sin': """Computes the sine of the given value. + + :param col: :class:`DoubleType` column, units in radians.""", 'sinh': 'Computes the hyperbolic sine of the given value.', - 'tan': 'Computes the tangent of the given value.', + 'tan': """Computes the tangent of the given value. + + :param col: :class:`DoubleType` column, units in radians.""", 'tanh': 'Computes the hyperbolic tangent of the given value.', - 'toDegrees': '.. note:: Deprecated in 2.1, use degrees instead.', - 'toRadians': '.. note:: Deprecated in 2.1, use radians instead.', + 'toDegrees': '.. note:: Deprecated in 2.1, use :func:`degrees` instead.', + 'toRadians': '.. note:: Deprecated in 2.1, use :func:`radians` instead.', 'bitwiseNOT': 'Computes bitwise not.', } +_collect_list_doc = """ + Aggregate function: returns a list of objects with duplicates. + + >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',)) + >>> df2.agg(collect_list('age')).collect() + [Row(collect_list(age)=[2, 5, 5])] + """ +_collect_set_doc = """ + Aggregate function: returns a set of objects with duplicate elements eliminated. + + >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',)) + >>> df2.agg(collect_set('age')).collect() + [Row(collect_set(age)=[5, 2])] + """ _functions_1_6 = { # unary math functions 'stddev': 'Aggregate function: returns the unbiased sample standard deviation of' + @@ -131,9 +157,8 @@ def _(): 'var_pop': 'Aggregate function: returns the population variance of the values in a group.', 'skewness': 'Aggregate function: returns the skewness of the values in a group.', 'kurtosis': 'Aggregate function: returns the kurtosis of the values in a group.', - 'collect_list': 'Aggregate function: returns a list of objects with duplicates.', - 'collect_set': 'Aggregate function: returns a set of objects with duplicate elements' + - ' eliminated.', + 'collect_list': _collect_list_doc, + 'collect_set': _collect_set_doc } _functions_2_1 = { @@ -147,7 +172,7 @@ def _(): # math functions that take two arguments as input _binary_mathfunctions = { 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + - 'polar coordinates (r, theta).', + 'polar coordinates (r, theta). Units in radians.', 'hypot': 'Computes ``sqrt(a^2 + b^2)`` without intermediate overflow or underflow.', 'pow': 'Returns the value of the first argument raised to the power of the second argument.', } @@ -200,17 +225,20 @@ def _(): @since(1.3) def approxCountDistinct(col, rsd=None): """ - .. note:: Deprecated in 2.1, use approx_count_distinct instead. + .. note:: Deprecated in 2.1, use :func:`approx_count_distinct` instead. """ return approx_count_distinct(col, rsd) @since(2.1) def approx_count_distinct(col, rsd=None): - """Returns a new :class:`Column` for approximate distinct count of ``col``. + """Aggregate function: returns a new :class:`Column` for approximate distinct count of column `col`. - >>> df.agg(approx_count_distinct(df.age).alias('c')).collect() - [Row(c=2)] + :param rsd: maximum estimation error allowed (default = 0.05). For rsd < 0.01, it is more + efficient to use :func:`countDistinct` + + >>> df.agg(approx_count_distinct(df.age).alias('distinct_ages')).collect() + [Row(distinct_ages=2)] """ sc = SparkContext._active_spark_context if rsd is None: @@ -267,8 +295,7 @@ def coalesce(*cols): @since(1.6) def corr(col1, col2): - """Returns a new :class:`Column` for the Pearson Correlation Coefficient for ``col1`` - and ``col2``. + """Returns a new :class:`Column` for the Pearson Correlation Coefficient for ``col1`` and ``col2``. >>> a = range(20) >>> b = [2 * x for x in range(20)] @@ -282,8 +309,7 @@ def corr(col1, col2): @since(2.0) def covar_pop(col1, col2): - """Returns a new :class:`Column` for the population covariance of ``col1`` - and ``col2``. + """Returns a new :class:`Column` for the population covariance of ``col1`` and ``col2``. >>> a = [1] * 10 >>> b = [1] * 10 @@ -297,8 +323,7 @@ def covar_pop(col1, col2): @since(2.0) def covar_samp(col1, col2): - """Returns a new :class:`Column` for the sample covariance of ``col1`` - and ``col2``. + """Returns a new :class:`Column` for the sample covariance of ``col1`` and ``col2``. >>> a = [1] * 10 >>> b = [1] * 10 @@ -450,7 +475,7 @@ def monotonically_increasing_id(): def nanvl(col1, col2): """Returns col1 if it is not NaN, or col2 if col1 is NaN. - Both inputs should be floating point columns (DoubleType or FloatType). + Both inputs should be floating point columns (:class:`DoubleType` or :class:`FloatType`). >>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) >>> df.select(nanvl("a", "b").alias("r1"), nanvl(df.a, df.b).alias("r2")).collect() @@ -460,10 +485,15 @@ def nanvl(col1, col2): return Column(sc._jvm.functions.nanvl(_to_java_column(col1), _to_java_column(col2))) +@ignore_unicode_prefix @since(1.4) def rand(seed=None): """Generates a random column with independent and identically distributed (i.i.d.) samples from U[0.0, 1.0]. + + >>> df.withColumn('rand', rand(seed=42) * 3).collect() + [Row(age=2, name=u'Alice', rand=1.1568609015300986), + Row(age=5, name=u'Bob', rand=1.403379671529166)] """ sc = SparkContext._active_spark_context if seed is not None: @@ -473,10 +503,15 @@ def rand(seed=None): return Column(jc) +@ignore_unicode_prefix @since(1.4) def randn(seed=None): """Generates a column with independent and identically distributed (i.i.d.) samples from the standard normal distribution. + + >>> df.withColumn('randn', randn(seed=42)).collect() + [Row(age=2, name=u'Alice', randn=-0.7556247885860078), + Row(age=5, name=u'Bob', randn=-0.0861619008451133)] """ sc = SparkContext._active_spark_context if seed is not None: @@ -760,7 +795,7 @@ def ntile(n): @since(1.5) def current_date(): """ - Returns the current date as a date column. + Returns the current date as a :class:`DateType` column. """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.current_date()) @@ -768,7 +803,7 @@ def current_date(): def current_timestamp(): """ - Returns the current timestamp as a timestamp column. + Returns the current timestamp as a :class:`TimestampType` column. """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.current_timestamp()) @@ -787,8 +822,8 @@ def date_format(date, format): .. note:: Use when ever possible specialized functions like `year`. These benefit from a specialized implementation. - >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(date_format('a', 'MM/dd/yyy').alias('date')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(date_format('dt', 'MM/dd/yyy').alias('date')).collect() [Row(date=u'04/08/2015')] """ sc = SparkContext._active_spark_context @@ -800,8 +835,8 @@ def year(col): """ Extract the year of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(year('a').alias('year')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(year('dt').alias('year')).collect() [Row(year=2015)] """ sc = SparkContext._active_spark_context @@ -813,8 +848,8 @@ def quarter(col): """ Extract the quarter of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(quarter('a').alias('quarter')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(quarter('dt').alias('quarter')).collect() [Row(quarter=2)] """ sc = SparkContext._active_spark_context @@ -826,8 +861,8 @@ def month(col): """ Extract the month of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(month('a').alias('month')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(month('dt').alias('month')).collect() [Row(month=4)] """ sc = SparkContext._active_spark_context @@ -839,8 +874,8 @@ def dayofmonth(col): """ Extract the day of the month of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(dayofmonth('a').alias('day')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(dayofmonth('dt').alias('day')).collect() [Row(day=8)] """ sc = SparkContext._active_spark_context @@ -852,8 +887,8 @@ def dayofyear(col): """ Extract the day of the year of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(dayofyear('a').alias('day')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(dayofyear('dt').alias('day')).collect() [Row(day=98)] """ sc = SparkContext._active_spark_context @@ -865,8 +900,8 @@ def hour(col): """ Extract the hours of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a']) - >>> df.select(hour('a').alias('hour')).collect() + >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts']) + >>> df.select(hour('ts').alias('hour')).collect() [Row(hour=13)] """ sc = SparkContext._active_spark_context @@ -878,8 +913,8 @@ def minute(col): """ Extract the minutes of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a']) - >>> df.select(minute('a').alias('minute')).collect() + >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts']) + >>> df.select(minute('ts').alias('minute')).collect() [Row(minute=8)] """ sc = SparkContext._active_spark_context @@ -891,8 +926,8 @@ def second(col): """ Extract the seconds of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a']) - >>> df.select(second('a').alias('second')).collect() + >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts']) + >>> df.select(second('ts').alias('second')).collect() [Row(second=15)] """ sc = SparkContext._active_spark_context @@ -904,8 +939,8 @@ def weekofyear(col): """ Extract the week number of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(weekofyear(df.a).alias('week')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(weekofyear(df.dt).alias('week')).collect() [Row(week=15)] """ sc = SparkContext._active_spark_context @@ -917,9 +952,9 @@ def date_add(start, days): """ Returns the date that is `days` days after `start` - >>> df = spark.createDataFrame([('2015-04-08',)], ['d']) - >>> df.select(date_add(df.d, 1).alias('d')).collect() - [Row(d=datetime.date(2015, 4, 9))] + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(date_add(df.dt, 1).alias('next_date')).collect() + [Row(next_date=datetime.date(2015, 4, 9))] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.date_add(_to_java_column(start), days)) @@ -930,9 +965,9 @@ def date_sub(start, days): """ Returns the date that is `days` days before `start` - >>> df = spark.createDataFrame([('2015-04-08',)], ['d']) - >>> df.select(date_sub(df.d, 1).alias('d')).collect() - [Row(d=datetime.date(2015, 4, 7))] + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(date_sub(df.dt, 1).alias('prev_date')).collect() + [Row(prev_date=datetime.date(2015, 4, 7))] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.date_sub(_to_java_column(start), days)) @@ -956,9 +991,9 @@ def add_months(start, months): """ Returns the date that is `months` months after `start` - >>> df = spark.createDataFrame([('2015-04-08',)], ['d']) - >>> df.select(add_months(df.d, 1).alias('d')).collect() - [Row(d=datetime.date(2015, 5, 8))] + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(add_months(df.dt, 1).alias('next_month')).collect() + [Row(next_month=datetime.date(2015, 5, 8))] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.add_months(_to_java_column(start), months)) @@ -969,8 +1004,8 @@ def months_between(date1, date2): """ Returns the number of months between date1 and date2. - >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd']) - >>> df.select(months_between(df.t, df.d).alias('months')).collect() + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2']) + >>> df.select(months_between(df.date1, df.date2).alias('months')).collect() [Row(months=3.9495967...)] """ sc = SparkContext._active_spark_context @@ -1073,12 +1108,17 @@ def last_day(date): return Column(sc._jvm.functions.last_day(_to_java_column(date))) +@ignore_unicode_prefix @since(1.5) def from_unixtime(timestamp, format="yyyy-MM-dd HH:mm:ss"): """ Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string representing the timestamp of that moment in the current system time zone in the given format. + + >>> time_df = spark.createDataFrame([(1428476400,)], ['unix_time']) + >>> time_df.select(from_unixtime('unix_time').alias('ts')).collect() + [Row(ts=u'2015-04-08 00:00:00')] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.from_unixtime(_to_java_column(timestamp), format)) @@ -1092,6 +1132,10 @@ def unix_timestamp(timestamp=None, format='yyyy-MM-dd HH:mm:ss'): locale, return null if fail. if `timestamp` is None, then it returns current timestamp. + + >>> time_df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> time_df.select(unix_timestamp('dt', 'yyyy-MM-dd').alias('unix_time')).collect() + [Row(unix_time=1428476400)] """ sc = SparkContext._active_spark_context if timestamp is None: @@ -1106,8 +1150,8 @@ def from_utc_timestamp(timestamp, tz): that corresponds to the same time of day in the given timezone. >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(from_utc_timestamp(df.t, "PST").alias('t')).collect() - [Row(t=datetime.datetime(1997, 2, 28, 2, 30))] + >>> df.select(from_utc_timestamp(df.t, "PST").alias('local_time')).collect() + [Row(local_time=datetime.datetime(1997, 2, 28, 2, 30))] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.from_utc_timestamp(_to_java_column(timestamp), tz)) @@ -1119,9 +1163,9 @@ def to_utc_timestamp(timestamp, tz): Given a timestamp, which corresponds to a certain time of day in the given timezone, returns another timestamp that corresponds to the same time of day in UTC. - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(to_utc_timestamp(df.t, "PST").alias('t')).collect() - [Row(t=datetime.datetime(1997, 2, 28, 18, 30))] + >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['ts']) + >>> df.select(to_utc_timestamp(df.ts, "PST").alias('utc_time')).collect() + [Row(utc_time=datetime.datetime(1997, 2, 28, 18, 30))] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz)) @@ -2095,7 +2139,7 @@ def _test(): sc = spark.sparkContext globs['sc'] = sc globs['spark'] = spark - globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() + globs['df'] = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) (failure_count, test_count) = doctest.testmod( pyspark.sql.functions, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 1263071a3ffd5..a5e4a444f33be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1321,7 +1321,8 @@ object functions { def asin(columnName: String): Column = asin(Column(columnName)) /** - * Computes the tangent inverse of the given value. + * Computes the tangent inverse of the given column; the returned angle is in the range + * -pi/2 through pi/2 * * @group math_funcs * @since 1.4.0 @@ -1329,7 +1330,8 @@ object functions { def atan(e: Column): Column = withExpr { Atan(e.expr) } /** - * Computes the tangent inverse of the given column. + * Computes the tangent inverse of the given column; the returned angle is in the range + * -pi/2 through pi/2 * * @group math_funcs * @since 1.4.0 @@ -1338,7 +1340,7 @@ object functions { /** * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * polar coordinates (r, theta). Units in radians. * * @group math_funcs * @since 1.4.0 @@ -1470,7 +1472,7 @@ object functions { } /** - * Computes the cosine of the given value. + * Computes the cosine of the given value. Units in radians. * * @group math_funcs * @since 1.4.0 @@ -1937,7 +1939,7 @@ object functions { def signum(columnName: String): Column = signum(Column(columnName)) /** - * Computes the sine of the given value. + * Computes the sine of the given value. Units in radians. * * @group math_funcs * @since 1.4.0 @@ -1969,7 +1971,7 @@ object functions { def sinh(columnName: String): Column = sinh(Column(columnName)) /** - * Computes the tangent of the given value. + * Computes the tangent of the given value. Units in radians. * * @group math_funcs * @since 1.4.0 From 01f183e8497d4931f1fe5c69ff16fe84b1e41492 Mon Sep 17 00:00:00 2001 From: Joachim Hereth Date: Sat, 8 Jul 2017 08:32:45 +0100 Subject: [PATCH 130/779] Mesos doc fixes ## What changes were proposed in this pull request? Some link fixes for the documentation [Running Spark on Mesos](https://spark.apache.org/docs/latest/running-on-mesos.html): * Updated Link to Mesos Frameworks (Projects built on top of Mesos) * Update Link to Mesos binaries from Mesosphere (former link was redirected to dcos install page) ## How was this patch tested? Documentation was built and changed page manually/visually inspected. No code was changed, hence no dev tests. Since these changes are rather trivial I did not open a new JIRA ticket. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Joachim Hereth Closes #18564 from daten-kieker/mesos_doc_fixes. --- docs/running-on-mesos.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index ec130c1db8f5f..7401b63e022c1 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -10,7 +10,7 @@ Spark can run on hardware clusters managed by [Apache Mesos](http://mesos.apache The advantages of deploying Spark with Mesos include: - dynamic partitioning between Spark and other - [frameworks](https://mesos.apache.org/documentation/latest/mesos-frameworks/) + [frameworks](https://mesos.apache.org/documentation/latest/frameworks/) - scalable partitioning between multiple instances of Spark # How it Works @@ -61,7 +61,7 @@ third party projects publish binary releases that may be helpful in setting Meso One of those is Mesosphere. To install Mesos using the binary releases provided by Mesosphere: -1. Download Mesos installation package from [downloads page](http://mesosphere.io/downloads/) +1. Download Mesos installation package from [downloads page](https://open.mesosphere.com/downloads/mesos/) 2. Follow their instructions for installation and configuration The Mesosphere installation documents suggest setting up ZooKeeper to handle Mesos master failover, From 330bf5c99825afb6129577a34e6bed8b221a98cc Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Sat, 8 Jul 2017 08:34:51 +0100 Subject: [PATCH 131/779] [SPARK-20609][MLLIB][TEST] manually cleared 'spark.local.dir' before/after a test in ALSCleanerSuite ## What changes were proposed in this pull request? This PR is similar to #17869. Once` 'spark.local.dir'` is set. Unless this is manually cleared before/after a test. it could return the same directory even if this property is configured. and add before/after for each likewise in ALSCleanerSuite. ## How was this patch tested? existing test. Author: caoxuewen Closes #18537 from heary-cao/ALSCleanerSuite. --- .../spark/ml/recommendation/ALSSuite.scala | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index b57fc8d21ab34..0a0fea255c7f3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -29,6 +29,7 @@ import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.commons.io.FileUtils import org.apache.commons.io.filefilter.TrueFileFilter +import org.scalatest.BeforeAndAfterEach import org.apache.spark._ import org.apache.spark.internal.Logging @@ -777,7 +778,20 @@ class ALSSuite } } -class ALSCleanerSuite extends SparkFunSuite { +class ALSCleanerSuite extends SparkFunSuite with BeforeAndAfterEach { + override def beforeEach(): Unit = { + super.beforeEach() + // Once `Utils.getOrCreateLocalRootDirs` is called, it is cached in `Utils.localRootDirs`. + // Unless this is manually cleared before and after a test, it returns the same directory + // set before even if 'spark.local.dir' is configured afterwards. + Utils.clearLocalRootDirs() + } + + override def afterEach(): Unit = { + Utils.clearLocalRootDirs() + super.afterEach() + } + test("ALS shuffle cleanup standalone") { val conf = new SparkConf() val localDir = Utils.createTempDir() From 0b8dd2d08460f3e6eb578727d2c336b6f11959e7 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 8 Jul 2017 20:16:47 +0800 Subject: [PATCH 132/779] [SPARK-21345][SQL][TEST][TEST-MAVEN] SparkSessionBuilderSuite should clean up stopped sessions. ## What changes were proposed in this pull request? `SparkSessionBuilderSuite` should clean up stopped sessions. Otherwise, it leaves behind some stopped `SparkContext`s interfereing with other test suites using `ShardSQLContext`. Recently, master branch fails consequtively. - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/ ## How was this patch tested? Pass the Jenkins with a updated suite. Author: Dongjoon Hyun Closes #18567 from dongjoon-hyun/SPARK-SESSION. --- .../spark/sql/SparkSessionBuilderSuite.scala | 46 ++++++++----------- 1 file changed, 18 insertions(+), 28 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index 770e15629c839..c0301f2ce2d66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -17,50 +17,49 @@ package org.apache.spark.sql +import org.scalatest.BeforeAndAfterEach + import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.internal.SQLConf /** * Test cases for the builder pattern of [[SparkSession]]. */ -class SparkSessionBuilderSuite extends SparkFunSuite { +class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach { - private var initialSession: SparkSession = _ + override def afterEach(): Unit = { + // This suite should not interfere with the other test suites. + SparkSession.getActiveSession.foreach(_.stop()) + SparkSession.clearActiveSession() + SparkSession.getDefaultSession.foreach(_.stop()) + SparkSession.clearDefaultSession() + } - private lazy val sparkContext: SparkContext = { - initialSession = SparkSession.builder() + test("create with config options and propagate them to SparkContext and SparkSession") { + val session = SparkSession.builder() .master("local") .config("spark.ui.enabled", value = false) .config("some-config", "v2") .getOrCreate() - initialSession.sparkContext - } - - test("create with config options and propagate them to SparkContext and SparkSession") { - // Creating a new session with config - this works by just calling the lazy val - sparkContext - assert(initialSession.sparkContext.conf.get("some-config") == "v2") - assert(initialSession.conf.get("some-config") == "v2") - SparkSession.clearDefaultSession() + assert(session.sparkContext.conf.get("some-config") == "v2") + assert(session.conf.get("some-config") == "v2") } test("use global default session") { - val session = SparkSession.builder().getOrCreate() + val session = SparkSession.builder().master("local").getOrCreate() assert(SparkSession.builder().getOrCreate() == session) - SparkSession.clearDefaultSession() } test("config options are propagated to existing SparkSession") { - val session1 = SparkSession.builder().config("spark-config1", "a").getOrCreate() + val session1 = SparkSession.builder().master("local").config("spark-config1", "a").getOrCreate() assert(session1.conf.get("spark-config1") == "a") val session2 = SparkSession.builder().config("spark-config1", "b").getOrCreate() assert(session1 == session2) assert(session1.conf.get("spark-config1") == "b") - SparkSession.clearDefaultSession() } test("use session from active thread session and propagate config options") { - val defaultSession = SparkSession.builder().getOrCreate() + val defaultSession = SparkSession.builder().master("local").getOrCreate() val activeSession = defaultSession.newSession() SparkSession.setActiveSession(activeSession) val session = SparkSession.builder().config("spark-config2", "a").getOrCreate() @@ -73,16 +72,14 @@ class SparkSessionBuilderSuite extends SparkFunSuite { SparkSession.clearActiveSession() assert(SparkSession.builder().getOrCreate() == defaultSession) - SparkSession.clearDefaultSession() } test("create a new session if the default session has been stopped") { - val defaultSession = SparkSession.builder().getOrCreate() + val defaultSession = SparkSession.builder().master("local").getOrCreate() SparkSession.setDefaultSession(defaultSession) defaultSession.stop() val newSession = SparkSession.builder().master("local").getOrCreate() assert(newSession != defaultSession) - newSession.stop() } test("create a new session if the active thread session has been stopped") { @@ -91,11 +88,9 @@ class SparkSessionBuilderSuite extends SparkFunSuite { activeSession.stop() val newSession = SparkSession.builder().master("local").getOrCreate() assert(newSession != activeSession) - newSession.stop() } test("create SparkContext first then SparkSession") { - sparkContext.stop() val conf = new SparkConf().setAppName("test").setMaster("local").set("key1", "value1") val sparkContext2 = new SparkContext(conf) val session = SparkSession.builder().config("key2", "value2").getOrCreate() @@ -105,11 +100,9 @@ class SparkSessionBuilderSuite extends SparkFunSuite { // We won't update conf for existing `SparkContext` assert(!sparkContext2.conf.contains("key2")) assert(sparkContext2.conf.get("key1") == "value1") - session.stop() } test("create SparkContext first then pass context to SparkSession") { - sparkContext.stop() val conf = new SparkConf().setAppName("test").setMaster("local").set("key1", "value1") val newSC = new SparkContext(conf) val session = SparkSession.builder().sparkContext(newSC).config("key2", "value2").getOrCreate() @@ -121,14 +114,12 @@ class SparkSessionBuilderSuite extends SparkFunSuite { // the conf of this sparkContext will not contain the conf set through the API config. assert(!session.sparkContext.conf.contains("key2")) assert(session.sparkContext.conf.get("spark.app.name") == "test") - session.stop() } test("SPARK-15887: hive-site.xml should be loaded") { val session = SparkSession.builder().master("local").getOrCreate() assert(session.sessionState.newHadoopConf().get("hive.in.test") == "true") assert(session.sparkContext.hadoopConfiguration.get("hive.in.test") == "true") - session.stop() } test("SPARK-15991: Set global Hadoop conf") { @@ -140,7 +131,6 @@ class SparkSessionBuilderSuite extends SparkFunSuite { assert(session.sessionState.newHadoopConf().get(mySpecialKey) == mySpecialValue) } finally { session.sparkContext.hadoopConfiguration.unset(mySpecialKey) - session.stop() } } } From 9fccc3627fa41d32fbae6dbbb9bd1521e43eb4f0 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Sat, 8 Jul 2017 20:44:12 +0800 Subject: [PATCH 133/779] [SPARK-21083][SQL] Store zero size and row count when analyzing empty table ## What changes were proposed in this pull request? We should be able to store zero size and row count after analyzing empty table. This pr also enhances the test cases for re-analyzing tables. ## How was this patch tested? Added a new test case and enhanced some test cases. Author: Zhenhua Wang Closes #18292 from wzhfy/analyzeNewColumn. --- .../command/AnalyzeTableCommand.scala | 5 +- .../spark/sql/StatisticsCollectionSuite.scala | 13 +++++ .../spark/sql/hive/StatisticsSuite.scala | 52 ++++++++++++++++--- 3 files changed, 59 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 42e2a9ca5c4e2..cba147c35dd99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableType} -import org.apache.spark.sql.execution.SQLExecution /** @@ -40,10 +39,10 @@ case class AnalyzeTableCommand( } val newTotalSize = CommandUtils.calculateTotalSize(sessionState, tableMeta) - val oldTotalSize = tableMeta.stats.map(_.sizeInBytes.toLong).getOrElse(0L) + val oldTotalSize = tableMeta.stats.map(_.sizeInBytes.toLong).getOrElse(-1L) val oldRowCount = tableMeta.stats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L) var newStats: Option[CatalogStatistics] = None - if (newTotalSize > 0 && newTotalSize != oldTotalSize) { + if (newTotalSize >= 0 && newTotalSize != oldTotalSize) { newStats = Some(CatalogStatistics(sizeInBytes = newTotalSize)) } // We only set rowCount when noscan is false, because otherwise: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 843ced7f0e697..b80bd80e93e8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -82,6 +82,19 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } + test("analyze empty table") { + val table = "emptyTable" + withTable(table) { + sql(s"CREATE TABLE $table (key STRING, value STRING) USING PARQUET") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS noscan") + val fetchedStats1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetchedStats1.get.sizeInBytes == 0) + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") + val fetchedStats2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetchedStats2.get.sizeInBytes == 0) + } + } + test("analyze column command - unsupported types and invalid columns") { val tableName = "column_stats_test1" withTable(tableName) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index e00fa64e9f2ce..84bcea30d61a6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -26,6 +26,7 @@ import scala.util.matching.Regex import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} +import org.apache.spark.sql.catalyst.plans.logical.ColumnStat import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -210,27 +211,62 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } - test("test elimination of the influences of the old stats") { + test("keep existing row count in stats with noscan if table is not changed") { val textTable = "textTable" withTable(textTable) { - sql(s"CREATE TABLE $textTable (key STRING, value STRING) STORED AS TEXTFILE") + sql(s"CREATE TABLE $textTable (key STRING, value STRING)") sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS") val fetchedStats1 = checkTableStats(textTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") - // when the total size is not changed, the old row count is kept + // when the table is not changed, total size is the same, and the old row count is kept val fetchedStats2 = checkTableStats(textTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) assert(fetchedStats1 == fetchedStats2) + } + } - sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") - sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") - // update total size and remove the old and invalid row count + test("keep existing column stats if table is not changed") { + val table = "update_col_stats_table" + withTable(table) { + sql(s"CREATE TABLE $table (c1 INT, c2 STRING, c3 DOUBLE)") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") + val fetchedStats0 = + checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetchedStats0.get.colStats == Map("c1" -> ColumnStat(0, None, None, 0, 4, 4))) + + // Insert new data and analyze: have the latest column stats. + sql(s"INSERT INTO TABLE $table SELECT 1, 'a', 10.0") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") + val fetchedStats1 = + checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get + assert(fetchedStats1.colStats == Map( + "c1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, + avgLen = 4, maxLen = 4))) + + // Analyze another column: since the table is not changed, the precious column stats are kept. + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2") + val fetchedStats2 = + checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get + assert(fetchedStats2.colStats == Map( + "c1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, + avgLen = 4, maxLen = 4), + "c2" -> ColumnStat(distinctCount = 1, min = None, max = None, nullCount = 0, + avgLen = 1, maxLen = 1))) + + // Insert new data and analyze: stale column stats are removed and newly collected column + // stats are added. + sql(s"INSERT INTO TABLE $table SELECT 2, 'b', 20.0") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1, c3") val fetchedStats3 = - checkTableStats(textTable, hasSizeInBytes = true, expectedRowCounts = None) - assert(fetchedStats3.get.sizeInBytes > fetchedStats2.get.sizeInBytes) + checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)).get + assert(fetchedStats3.colStats == Map( + "c1" -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, + avgLen = 4, maxLen = 4), + "c3" -> ColumnStat(distinctCount = 2, min = Some(10.0), max = Some(20.0), nullCount = 0, + avgLen = 8, maxLen = 8))) } } From 9131bdb7e12bcfb2cb699b3438f554604e28aaa8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sun, 9 Jul 2017 00:24:54 +0800 Subject: [PATCH 134/779] [SPARK-20342][CORE] Update task accumulators before sending task end event. This makes sures that listeners get updated task information; otherwise it's possible to write incomplete task information into event logs, for example, making the information in a replayed UI inconsistent with the original application. Added a new unit test to try to detect the problem, but it's not guaranteed to fail since it's a race; but it fails pretty reliably for me without the scheduler changes. Author: Marcelo Vanzin Closes #18393 from vanzin/SPARK-20342.try2. --- .../apache/spark/scheduler/DAGScheduler.scala | 70 ++++++++++++------- .../spark/scheduler/DAGSchedulerSuite.scala | 32 ++++++++- 2 files changed, 75 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 3422a5f204b12..89b4cab88109d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1122,6 +1122,25 @@ class DAGScheduler( } } + private def postTaskEnd(event: CompletionEvent): Unit = { + val taskMetrics: TaskMetrics = + if (event.accumUpdates.nonEmpty) { + try { + TaskMetrics.fromAccumulators(event.accumUpdates) + } catch { + case NonFatal(e) => + val taskId = event.taskInfo.taskId + logError(s"Error when attempting to reconstruct metrics for task $taskId", e) + null + } + } else { + null + } + + listenerBus.post(SparkListenerTaskEnd(event.task.stageId, event.task.stageAttemptId, + Utils.getFormattedClassName(event.task), event.reason, event.taskInfo, taskMetrics)) + } + /** * Responds to a task finishing. This is called inside the event loop so it assumes that it can * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. @@ -1138,34 +1157,36 @@ class DAGScheduler( event.taskInfo.attemptNumber, // this is a task attempt number event.reason) - // Reconstruct task metrics. Note: this may be null if the task has failed. - val taskMetrics: TaskMetrics = - if (event.accumUpdates.nonEmpty) { - try { - TaskMetrics.fromAccumulators(event.accumUpdates) - } catch { - case NonFatal(e) => - logError(s"Error when attempting to reconstruct metrics for task $taskId", e) - null - } - } else { - null - } - - // The stage may have already finished when we get this event -- eg. maybe it was a - // speculative task. It is important that we send the TaskEnd event in any case, so listeners - // are properly notified and can chose to handle it. For instance, some listeners are - // doing their own accounting and if they don't get the task end event they think - // tasks are still running when they really aren't. - listenerBus.post(SparkListenerTaskEnd( - stageId, task.stageAttemptId, taskType, event.reason, event.taskInfo, taskMetrics)) - if (!stageIdToStage.contains(task.stageId)) { + // The stage may have already finished when we get this event -- eg. maybe it was a + // speculative task. It is important that we send the TaskEnd event in any case, so listeners + // are properly notified and can chose to handle it. For instance, some listeners are + // doing their own accounting and if they don't get the task end event they think + // tasks are still running when they really aren't. + postTaskEnd(event) + // Skip all the actions if the stage has been cancelled. return } val stage = stageIdToStage(task.stageId) + + // Make sure the task's accumulators are updated before any other processing happens, so that + // we can post a task end event before any jobs or stages are updated. The accumulators are + // only updated in certain cases. + event.reason match { + case Success => + stage match { + case rs: ResultStage if rs.activeJob.isEmpty => + // Ignore update if task's job has finished. + case _ => + updateAccumulators(event) + } + case _: ExceptionFailure => updateAccumulators(event) + case _ => + } + postTaskEnd(event) + event.reason match { case Success => task match { @@ -1176,7 +1197,6 @@ class DAGScheduler( resultStage.activeJob match { case Some(job) => if (!job.finished(rt.outputId)) { - updateAccumulators(event) job.finished(rt.outputId) = true job.numFinished += 1 // If the whole job has finished, remove it @@ -1203,7 +1223,6 @@ class DAGScheduler( case smt: ShuffleMapTask => val shuffleStage = stage.asInstanceOf[ShuffleMapStage] - updateAccumulators(event) val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId logDebug("ShuffleMapTask finished on " + execId) @@ -1374,8 +1393,7 @@ class DAGScheduler( // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits case exceptionFailure: ExceptionFailure => - // Tasks failed with exceptions might still have accumulator updates. - updateAccumulators(event) + // Nothing left to do, already handled above for accumulator updates. case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 453be26ed8d0c..3b5df657d45cf 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.util.Properties -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} import scala.annotation.meta.param import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} @@ -2346,6 +2346,36 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou (Success, 1))) } + test("task end event should have updated accumulators (SPARK-20342)") { + val tasks = 10 + + val accumId = new AtomicLong() + val foundCount = new AtomicLong() + val listener = new SparkListener() { + override def onTaskEnd(event: SparkListenerTaskEnd): Unit = { + event.taskInfo.accumulables.find(_.id == accumId.get).foreach { _ => + foundCount.incrementAndGet() + } + } + } + sc.addSparkListener(listener) + + // Try a few times in a loop to make sure. This is not guaranteed to fail when the bug exists, + // but it should at least make the test flaky. If the bug is fixed, this should always pass. + (1 to 10).foreach { i => + foundCount.set(0L) + + val accum = sc.longAccumulator(s"accum$i") + accumId.set(accum.id) + + sc.parallelize(1 to tasks, tasks).foreach { _ => + accum.add(1L) + } + sc.listenerBus.waitUntilEmpty(1000) + assert(foundCount.get() === tasks) + } + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. From 062c336d06a0bd4e740a18d2349e03e311509243 Mon Sep 17 00:00:00 2001 From: jinxing Date: Sun, 9 Jul 2017 00:27:58 +0800 Subject: [PATCH 135/779] [SPARK-21343] Refine the document for spark.reducer.maxReqSizeShuffleToMem. ## What changes were proposed in this pull request? In current code, reducer can break the old shuffle service when `spark.reducer.maxReqSizeShuffleToMem` is enabled. Let's refine document. Author: jinxing Closes #18566 from jinxing64/SPARK-21343. --- .../org/apache/spark/internal/config/package.scala | 6 ++++-- docs/configuration.md | 10 ++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index a629810bf093a..512d539ee9c38 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -323,9 +323,11 @@ package object config { private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM = ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem") - .internal() .doc("The blocks of a shuffle request will be fetched to disk when size of the request is " + - "above this threshold. This is to avoid a giant request takes too much memory.") + "above this threshold. This is to avoid a giant request takes too much memory. We can " + + "enable this config by setting a specific value(e.g. 200m). Note that this config can " + + "be enabled only when the shuffle shuffle service is newer than Spark-2.2 or the shuffle" + + " service is disabled.") .bytesConf(ByteUnit.BYTE) .createWithDefault(Long.MaxValue) diff --git a/docs/configuration.md b/docs/configuration.md index 7dc23e441a7ba..6ca84240c1247 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -528,6 +528,16 @@ Apart from these, the following properties are also available, and may be useful By allowing it to limit the number of fetch requests, this scenario can be mitigated. + + spark.reducer.maxReqSizeShuffleToMem + Long.MaxValue + + The blocks of a shuffle request will be fetched to disk when size of the request is above + this threshold. This is to avoid a giant request takes too much memory. We can enable this + config by setting a specific value(e.g. 200m). Note that this config can be enabled only when + the shuffle shuffle service is newer than Spark-2.2 or the shuffle service is disabled. + + spark.shuffle.compress true From c3712b77a915a5cb12b6c0204bc5bd6037aad1f5 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Sat, 8 Jul 2017 11:56:19 -0700 Subject: [PATCH 136/779] [SPARK-21307][REVERT][SQL] Remove SQLConf parameters from the parser-related classes ## What changes were proposed in this pull request? Since we do not set active sessions when parsing the plan, we are unable to correctly use SQLConf.get to find the correct active session. Since https://github.com/apache/spark/pull/18531 breaks the build, I plan to revert it at first. ## How was this patch tested? The existing test cases Author: Xiao Li Closes #18568 from gatorsmile/revert18531. --- .../sql/catalyst/catalog/SessionCatalog.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 6 +- .../sql/catalyst/parser/ParseDriver.scala | 8 +- .../parser/ExpressionParserSuite.scala | 167 +++++++++--------- .../spark/sql/execution/SparkSqlParser.scala | 11 +- .../org/apache/spark/sql/functions.scala | 3 +- .../internal/BaseSessionStateBuilder.scala | 2 +- .../sql/internal/VariableSubstitution.scala | 4 +- .../sql/execution/SparkSqlParserSuite.scala | 10 +- .../execution/command/DDLCommandSuite.scala | 4 +- .../internal/VariableSubstitutionSuite.scala | 31 ++-- 11 files changed, 127 insertions(+), 121 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 336d3d65d0dd0..c40d5f6031a21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -74,7 +74,7 @@ class SessionCatalog( functionRegistry, conf, new Configuration(), - CatalystSqlParser, + new CatalystSqlParser(conf), DummyFunctionResourceLoader) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 4d725904bc9b9..a616b0f773f38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -45,9 +45,11 @@ import org.apache.spark.util.random.RandomSampler * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or * TableIdentifier. */ -class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { +class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging { import ParserUtils._ + def this() = this(new SQLConf()) + protected def typedVisit[T](ctx: ParseTree): T = { ctx.accept(this).asInstanceOf[T] } @@ -1457,7 +1459,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Special characters can be escaped by using Hive/C-style escaping. */ private def createString(ctx: StringLiteralContext): String = { - if (SQLConf.get.escapedStringLiterals) { + if (conf.escapedStringLiterals) { ctx.STRING().asScala.map(stringWithoutUnescape).mkString } else { ctx.STRING().asScala.map(string).mkString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 7e1fcfefc64a5..09598ffe770c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} /** @@ -121,8 +122,13 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { /** * Concrete SQL parser for Catalyst-only SQL statements. */ +class CatalystSqlParser(conf: SQLConf) extends AbstractSqlParser { + val astBuilder = new AstBuilder(conf) +} + +/** For test-only. */ object CatalystSqlParser extends AbstractSqlParser { - val astBuilder = new AstBuilder + val astBuilder = new AstBuilder(new SQLConf()) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index ac7325257a15a..45f9f72dccc45 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -167,12 +167,12 @@ class ExpressionParserSuite extends PlanTest { } test("like expressions with ESCAPED_STRING_LITERALS = true") { - val parser = CatalystSqlParser - withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> "true") { - assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser) - assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser) - assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser) - } + val conf = new SQLConf() + conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, "true") + val parser = new CatalystSqlParser(conf) + assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser) + assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser) + assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser) } test("is null expressions") { @@ -435,85 +435,86 @@ class ExpressionParserSuite extends PlanTest { } test("strings") { - val parser = CatalystSqlParser Seq(true, false).foreach { escape => - withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> escape.toString) { - // tests that have same result whatever the conf is - // Single Strings. - assertEqual("\"hello\"", "hello", parser) - assertEqual("'hello'", "hello", parser) - - // Multi-Strings. - assertEqual("\"hello\" 'world'", "helloworld", parser) - assertEqual("'hello' \" \" 'world'", "hello world", parser) - - // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a - // regular '%'; to get the correct result you need to add another escaped '\'. - // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? - assertEqual("'pattern%'", "pattern%", parser) - assertEqual("'no-pattern\\%'", "no-pattern\\%", parser) - - // tests that have different result regarding the conf - if (escape) { - // When SQLConf.ESCAPED_STRING_LITERALS is enabled, string literal parsing fallbacks to - // Spark 1.6 behavior. - - // 'LIKE' string literals. - assertEqual("'pattern\\\\%'", "pattern\\\\%", parser) - assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser) - - // Escaped characters. - // Unescape string literal "'\\0'" for ASCII NUL (X'00') doesn't work - // when ESCAPED_STRING_LITERALS is enabled. - // It is parsed literally. - assertEqual("'\\0'", "\\0", parser) - - // Note: Single quote follows 1.6 parsing behavior when ESCAPED_STRING_LITERALS is - // enabled. - val e = intercept[ParseException](parser.parseExpression("'\''")) - assert(e.message.contains("extraneous input '''")) - - // The unescape special characters (e.g., "\\t") for 2.0+ don't work - // when ESCAPED_STRING_LITERALS is enabled. They are parsed literally. - assertEqual("'\\\"'", "\\\"", parser) // Double quote - assertEqual("'\\b'", "\\b", parser) // Backspace - assertEqual("'\\n'", "\\n", parser) // Newline - assertEqual("'\\r'", "\\r", parser) // Carriage return - assertEqual("'\\t'", "\\t", parser) // Tab character - - // The unescape Octals for 2.0+ don't work when ESCAPED_STRING_LITERALS is enabled. - // They are parsed literally. - assertEqual("'\\110\\145\\154\\154\\157\\041'", "\\110\\145\\154\\154\\157\\041", parser) - // The unescape Unicode for 2.0+ doesn't work when ESCAPED_STRING_LITERALS is enabled. - // They are parsed literally. - assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", - "\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029", parser) - } else { - // Default behavior - - // 'LIKE' string literals. - assertEqual("'pattern\\\\%'", "pattern\\%", parser) - assertEqual("'pattern\\\\\\%'", "pattern\\\\%", parser) - - // Escaped characters. - // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html - assertEqual("'\\0'", "\u0000", parser) // ASCII NUL (X'00') - assertEqual("'\\''", "\'", parser) // Single quote - assertEqual("'\\\"'", "\"", parser) // Double quote - assertEqual("'\\b'", "\b", parser) // Backspace - assertEqual("'\\n'", "\n", parser) // Newline - assertEqual("'\\r'", "\r", parser) // Carriage return - assertEqual("'\\t'", "\t", parser) // Tab character - assertEqual("'\\Z'", "\u001A", parser) // ASCII 26 - CTRL + Z (EOF on windows) - - // Octals - assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!", parser) - - // Unicode - assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)", - parser) - } + val conf = new SQLConf() + conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, escape.toString) + val parser = new CatalystSqlParser(conf) + + // tests that have same result whatever the conf is + // Single Strings. + assertEqual("\"hello\"", "hello", parser) + assertEqual("'hello'", "hello", parser) + + // Multi-Strings. + assertEqual("\"hello\" 'world'", "helloworld", parser) + assertEqual("'hello' \" \" 'world'", "hello world", parser) + + // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a + // regular '%'; to get the correct result you need to add another escaped '\'. + // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? + assertEqual("'pattern%'", "pattern%", parser) + assertEqual("'no-pattern\\%'", "no-pattern\\%", parser) + + // tests that have different result regarding the conf + if (escape) { + // When SQLConf.ESCAPED_STRING_LITERALS is enabled, string literal parsing fallbacks to + // Spark 1.6 behavior. + + // 'LIKE' string literals. + assertEqual("'pattern\\\\%'", "pattern\\\\%", parser) + assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser) + + // Escaped characters. + // Unescape string literal "'\\0'" for ASCII NUL (X'00') doesn't work + // when ESCAPED_STRING_LITERALS is enabled. + // It is parsed literally. + assertEqual("'\\0'", "\\0", parser) + + // Note: Single quote follows 1.6 parsing behavior when ESCAPED_STRING_LITERALS is enabled. + val e = intercept[ParseException](parser.parseExpression("'\''")) + assert(e.message.contains("extraneous input '''")) + + // The unescape special characters (e.g., "\\t") for 2.0+ don't work + // when ESCAPED_STRING_LITERALS is enabled. They are parsed literally. + assertEqual("'\\\"'", "\\\"", parser) // Double quote + assertEqual("'\\b'", "\\b", parser) // Backspace + assertEqual("'\\n'", "\\n", parser) // Newline + assertEqual("'\\r'", "\\r", parser) // Carriage return + assertEqual("'\\t'", "\\t", parser) // Tab character + + // The unescape Octals for 2.0+ don't work when ESCAPED_STRING_LITERALS is enabled. + // They are parsed literally. + assertEqual("'\\110\\145\\154\\154\\157\\041'", "\\110\\145\\154\\154\\157\\041", parser) + // The unescape Unicode for 2.0+ doesn't work when ESCAPED_STRING_LITERALS is enabled. + // They are parsed literally. + assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", + "\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029", parser) + } else { + // Default behavior + + // 'LIKE' string literals. + assertEqual("'pattern\\\\%'", "pattern\\%", parser) + assertEqual("'pattern\\\\\\%'", "pattern\\\\%", parser) + + // Escaped characters. + // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html + assertEqual("'\\0'", "\u0000", parser) // ASCII NUL (X'00') + assertEqual("'\\''", "\'", parser) // Single quote + assertEqual("'\\\"'", "\"", parser) // Double quote + assertEqual("'\\b'", "\b", parser) // Backspace + assertEqual("'\\n'", "\n", parser) // Newline + assertEqual("'\\r'", "\r", parser) // Carriage return + assertEqual("'\\t'", "\t", parser) // Tab character + assertEqual("'\\Z'", "\u001A", parser) // ASCII 26 - CTRL + Z (EOF on windows) + + // Octals + assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!", parser) + + // Unicode + assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)", + parser) } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 618d027d8dc07..2f8e416e7df1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -39,11 +39,10 @@ import org.apache.spark.sql.types.StructType /** * Concrete parser for Spark SQL statements. */ -class SparkSqlParser extends AbstractSqlParser { +class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser { + val astBuilder = new SparkSqlAstBuilder(conf) - val astBuilder = new SparkSqlAstBuilder - - private val substitutor = new VariableSubstitution + private val substitutor = new VariableSubstitution(conf) protected override def parse[T](command: String)(toResult: SqlBaseParser => T): T = { super.parse(substitutor.substitute(command))(toResult) @@ -53,11 +52,9 @@ class SparkSqlParser extends AbstractSqlParser { /** * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier. */ -class SparkSqlAstBuilder extends AstBuilder { +class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { import org.apache.spark.sql.catalyst.parser.ParserUtils._ - private def conf: SQLConf = SQLConf.get - /** * Create a [[SetCommand]] logical plan. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a5e4a444f33be..0c7b483f5c836 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -1275,7 +1276,7 @@ object functions { */ def expr(expr: String): Column = { val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse { - new SparkSqlParser + new SparkSqlParser(new SQLConf) } Column(parser.parseExpression(expr)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 72d0ddc62303a..267f76217df84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -114,7 +114,7 @@ abstract class BaseSessionStateBuilder( * Note: this depends on the `conf` field. */ protected lazy val sqlParser: ParserInterface = { - extensions.buildParser(session, new SparkSqlParser) + extensions.buildParser(session, new SparkSqlParser(conf)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala index 2b9c574aaaf0c..4e7c813be9922 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala @@ -25,9 +25,7 @@ import org.apache.spark.internal.config._ * * Variable substitution is controlled by `SQLConf.variableSubstituteEnabled`. */ -class VariableSubstitution { - - private def conf = SQLConf.get +class VariableSubstitution(conf: SQLConf) { private val provider = new ConfigProvider { override def get(key: String): Option[String] = Option(conf.getConfString(key, "")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 2e29fa43f73d9..d238c76fbeeff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -37,7 +37,8 @@ import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType */ class SparkSqlParserSuite extends AnalysisTest { - private lazy val parser = new SparkSqlParser + val newConf = new SQLConf + private lazy val parser = new SparkSqlParser(newConf) /** * Normalizes plans: @@ -284,7 +285,6 @@ class SparkSqlParserSuite extends AnalysisTest { } test("query organization") { - val conf = SQLConf.get // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows val baseSql = "select * from t" val basePlan = @@ -293,20 +293,20 @@ class SparkSqlParserSuite extends AnalysisTest { assertEqual(s"$baseSql distribute by a, b", RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil, basePlan, - numPartitions = conf.numShufflePartitions)) + numPartitions = newConf.numShufflePartitions)) assertEqual(s"$baseSql distribute by a sort by b", Sort(SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, global = false, RepartitionByExpression(UnresolvedAttribute("a") :: Nil, basePlan, - numPartitions = conf.numShufflePartitions))) + numPartitions = newConf.numShufflePartitions))) assertEqual(s"$baseSql cluster by a, b", Sort(SortOrder(UnresolvedAttribute("a"), Ascending) :: SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, global = false, RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil, basePlan, - numPartitions = conf.numShufflePartitions))) + numPartitions = newConf.numShufflePartitions))) } test("pipeline concatenation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 750574830381f..5643c58d9f847 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -29,13 +29,13 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.datasources.CreateTable -import org.apache.spark.sql.internal.HiveSerDe +import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} // TODO: merge this with DDLSuite (SPARK-14441) class DDLCommandSuite extends PlanTest { - private lazy val parser = new SparkSqlParser + private lazy val parser = new SparkSqlParser(new SQLConf) private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = { val e = intercept[ParseException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala index c5e5b70e21335..d5a946aeaac31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/VariableSubstitutionSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.AnalysisException -class VariableSubstitutionSuite extends SparkFunSuite with PlanTest { +class VariableSubstitutionSuite extends SparkFunSuite { - private lazy val sub = new VariableSubstitution + private lazy val conf = new SQLConf + private lazy val sub = new VariableSubstitution(conf) test("system property") { System.setProperty("varSubSuite.var", "abcd") @@ -34,26 +35,26 @@ class VariableSubstitutionSuite extends SparkFunSuite with PlanTest { } test("Spark configuration variable") { - withSQLConf("some-random-string-abcd" -> "1234abcd") { - assert(sub.substitute("${hiveconf:some-random-string-abcd}") == "1234abcd") - assert(sub.substitute("${sparkconf:some-random-string-abcd}") == "1234abcd") - assert(sub.substitute("${spark:some-random-string-abcd}") == "1234abcd") - assert(sub.substitute("${some-random-string-abcd}") == "1234abcd") - } + conf.setConfString("some-random-string-abcd", "1234abcd") + assert(sub.substitute("${hiveconf:some-random-string-abcd}") == "1234abcd") + assert(sub.substitute("${sparkconf:some-random-string-abcd}") == "1234abcd") + assert(sub.substitute("${spark:some-random-string-abcd}") == "1234abcd") + assert(sub.substitute("${some-random-string-abcd}") == "1234abcd") } test("multiple substitutes") { val q = "select ${bar} ${foo} ${doo} this is great" - withSQLConf("bar" -> "1", "foo" -> "2", "doo" -> "3") { - assert(sub.substitute(q) == "select 1 2 3 this is great") - } + conf.setConfString("bar", "1") + conf.setConfString("foo", "2") + conf.setConfString("doo", "3") + assert(sub.substitute(q) == "select 1 2 3 this is great") } test("test nested substitutes") { val q = "select ${bar} ${foo} this is great" - withSQLConf("bar" -> "1", "foo" -> "${bar}") { - assert(sub.substitute(q) == "select 1 1 this is great") - } + conf.setConfString("bar", "1") + conf.setConfString("foo", "${bar}") + assert(sub.substitute(q) == "select 1 1 this is great") } } From 08e0d033b40946b4ef5741a7aa1e7ba0bd48c6fb Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 8 Jul 2017 14:24:37 -0700 Subject: [PATCH 137/779] [SPARK-21093][R] Terminate R's worker processes in the parent of R's daemon to prevent a leak ## What changes were proposed in this pull request? This is a retry for #18320. This PR was reverted due to unexpected test failures with -10 error code. I was unable to reproduce in MacOS, CentOS and Ubuntu but only in Jenkins. So, the tests proceeded to verify this and revert the past try here - https://github.com/apache/spark/pull/18456 This new approach was tested in https://github.com/apache/spark/pull/18463. **Test results**: - With the part of suspicious change in the past try (https://github.com/apache/spark/pull/18463/commits/466325d3fd353668583f3bde38ae490d9db0b189) Tests ran 4 times and 2 times passed and 2 time failed. - Without the part of suspicious change in the past try (https://github.com/apache/spark/pull/18463/commits/466325d3fd353668583f3bde38ae490d9db0b189) Tests ran 5 times and they all passed. - With this new approach (https://github.com/apache/spark/pull/18463/commits/0a7589c09f53dfc2094497d8d3e59d6407569417) Tests ran 5 times and they all passed. It looks the cause is as below (see https://github.com/apache/spark/pull/18463/commits/466325d3fd353668583f3bde38ae490d9db0b189): ```diff + exitCode <- 1 ... + data <- parallel:::readChild(child) + if (is.raw(data)) { + if (unserialize(data) == exitCode) { ... + } + } ... - parallel:::mcexit(0L) + parallel:::mcexit(0L, send = exitCode) ``` Two possibilities I think - `parallel:::mcexit(.. , send = exitCode)` https://stat.ethz.ch/R-manual/R-devel/library/parallel/html/mcfork.html > It sends send to the master (unless NULL) and then shuts down the child process. However, it looks possible that the parent attemps to terminate the child right after getting our custom exit code. So, the child gets terminated between "send" and "shuts down", failing to exit properly. - A bug between `parallel:::mcexit(..., send = ...)` and `parallel:::readChild`. **Proposal**: To resolve this, I simply decided to avoid both possibilities with this new approach here (https://github.com/apache/spark/pull/18465/commits/9ff89a7859cb9f427fc774f33c3521c7d962b723). To support this idea, I explained with some quotation of the documentation as below: https://stat.ethz.ch/R-manual/R-devel/library/parallel/html/mcfork.html > `readChild` and `readChildren` return a raw vector with a "pid" attribute if data were available, an integer vector of length one with the process ID if a child terminated or `NULL` if the child no longer exists (no children at all for `readChildren`). `readChild` returns "an integer vector of length one with the process ID if a child terminated" so we can check if it is `integer` and the same selected "process ID". I believe this makes sure that the children are exited. In case that children happen to send any data manually to parent (which is why we introduced the suspicious part of the change (https://github.com/apache/spark/pull/18463/commits/466325d3fd353668583f3bde38ae490d9db0b189)), this should be raw bytes and will be discarded (and then will try to read the next and check if it is `integer` in the next loop). ## How was this patch tested? Manual tests and Jenkins tests. Author: hyukjinkwon Closes #18465 from HyukjinKwon/SPARK-21093-retry-1. --- R/pkg/inst/worker/daemon.R | 51 +++++++++++++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 3 deletions(-) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index 3a318b71ea06d..2e31dc5f728cd 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -30,8 +30,50 @@ port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout) +# Waits indefinitely for a socket connecion by default. +selectTimeout <- NULL + while (TRUE) { - ready <- socketSelect(list(inputCon)) + ready <- socketSelect(list(inputCon), timeout = selectTimeout) + + # Note that the children should be terminated in the parent. If each child terminates + # itself, it appears that the resource is not released properly, that causes an unexpected + # termination of this daemon due to, for example, running out of file descriptors + # (see SPARK-21093). Therefore, the current implementation tries to retrieve children + # that are exited (but not terminated) and then sends a kill signal to terminate them properly + # in the parent. + # + # There are two paths that it attempts to send a signal to terminate the children in the parent. + # + # 1. Every second if any socket connection is not available and if there are child workers + # running. + # 2. Right after a socket connection is available. + # + # In other words, the parent attempts to send the signal to the children every second if + # any worker is running or right before launching other worker children from the following + # new socket connection. + + # The process IDs of exited children are returned below. + children <- parallel:::selectChildren(timeout = 0) + + if (is.integer(children)) { + lapply(children, function(child) { + # This should be the PIDs of exited children. Otherwise, this returns raw bytes if any data + # was sent from this child. In this case, we discard it. + pid <- parallel:::readChild(child) + if (is.integer(pid)) { + # This checks if the data from this child is the same pid of this selected child. + if (child == pid) { + # If so, we terminate this child. + tools::pskill(child, tools::SIGUSR1) + } + } + }) + } else if (is.null(children)) { + # If it is NULL, there are no children. Waits indefinitely for a socket connecion. + selectTimeout <- NULL + } + if (ready) { port <- SparkR:::readInt(inputCon) # There is a small chance that it could be interrupted by signal, retry one time @@ -44,12 +86,15 @@ while (TRUE) { } p <- parallel:::mcfork() if (inherits(p, "masterProcess")) { + # Reach here because this is a child process. close(inputCon) Sys.setenv(SPARKR_WORKER_PORT = port) try(source(script)) - # Set SIGUSR1 so that child can exit - tools::pskill(Sys.getpid(), tools::SIGUSR1) + # Note that this mcexit does not fully terminate this child. parallel:::mcexit(0L) + } else { + # Forking succeeded and we need to check if they finished their jobs every second. + selectTimeout <- 1 } } } From 680b33f16694b7c460235b11b8c265bc304f795a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 9 Jul 2017 16:30:35 -0700 Subject: [PATCH 138/779] [SPARK-18016][SQL][FOLLOWUP] merge declareAddedFunctions, initNestedClasses and declareNestedClasses ## What changes were proposed in this pull request? These 3 methods have to be used together, so it makes more sense to merge them into one method and then the caller side only need to call one method. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #18579 from cloud-fan/minor. --- .../expressions/codegen/CodeGenerator.scala | 29 +++++++------------ .../codegen/GenerateMutableProjection.scala | 5 +--- .../codegen/GenerateOrdering.scala | 5 +--- .../codegen/GeneratePredicate.scala | 5 +--- .../codegen/GenerateSafeProjection.scala | 5 +--- .../codegen/GenerateUnsafeProjection.scala | 5 +--- .../sql/execution/WholeStageCodegenExec.scala | 5 +--- .../columnar/GenerateColumnAccessor.scala | 5 +--- 8 files changed, 18 insertions(+), 46 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index b15bf2ca7c116..7cf9daf628608 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -302,29 +302,20 @@ class CodegenContext { } /** - * Instantiates all nested, private sub-classes as objects to the `OuterClass` + * Declares all function code. If the added functions are too many, split them into nested + * sub-classes to avoid hitting Java compiler constant pool limitation. */ - private[sql] def initNestedClasses(): String = { + def declareAddedFunctions(): String = { + val inlinedFunctions = classFunctions(outerClassName).values + // Nested, private sub-classes have no mutable state (though they do reference the outer class' // mutable state), so we declare and initialize them inline to the OuterClass. - classes.filter(_._1 != outerClassName).map { + val initNestedClasses = classes.filter(_._1 != outerClassName).map { case (className, classInstance) => s"private $className $classInstance = new $className();" - }.mkString("\n") - } - - /** - * Declares all function code that should be inlined to the `OuterClass`. - */ - private[sql] def declareAddedFunctions(): String = { - classFunctions(outerClassName).values.mkString("\n") - } + } - /** - * Declares all nested, private sub-classes and the function code that should be inlined to them. - */ - private[sql] def declareNestedClasses(): String = { - classFunctions.filterKeys(_ != outerClassName).map { + val declareNestedClasses = classFunctions.filterKeys(_ != outerClassName).map { case (className, functions) => s""" |private class $className { @@ -332,7 +323,9 @@ class CodegenContext { |} """.stripMargin } - }.mkString("\n") + + (inlinedFunctions ++ initNestedClasses ++ declareNestedClasses).mkString("\n") + } final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 635766835029b..3768dcde00a4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -115,8 +115,6 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP ${ctx.initPartition()} } - ${ctx.declareAddedFunctions()} - public ${classOf[BaseMutableProjection].getName} target(InternalRow row) { mutableRow = row; return this; @@ -136,8 +134,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP return mutableRow; } - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} + ${ctx.declareAddedFunctions()} } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index a31943255b995..4e47895985209 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -173,15 +173,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR ${ctx.initMutableStates()} } - ${ctx.declareAddedFunctions()} - public int compare(InternalRow a, InternalRow b) { $comparisons return 0; } - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} + ${ctx.declareAddedFunctions()} }""" val code = CodeFormatter.stripOverlappingComments( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index b400783bb5e55..e35b9dda6c017 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -66,15 +66,12 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] { ${ctx.initPartition()} } - ${ctx.declareAddedFunctions()} - public boolean eval(InternalRow ${ctx.INPUT_ROW}) { ${eval.code} return !${eval.isNull} && ${eval.value}; } - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} + ${ctx.declareAddedFunctions()} }""" val code = CodeFormatter.stripOverlappingComments( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index dd0419d2286d1..192701a829686 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -175,16 +175,13 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${ctx.initPartition()} } - ${ctx.declareAddedFunctions()} - public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allExpressions return mutableRow; } - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} + ${ctx.declareAddedFunctions()} } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 6be69d119bf8a..f2a66efc98e71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -391,8 +391,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${ctx.initPartition()} } - ${ctx.declareAddedFunctions()} - // Scala.Function1 need this public java.lang.Object apply(java.lang.Object row) { return apply((InternalRow) row); @@ -403,8 +401,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro return ${eval.value}; } - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} + ${ctx.declareAddedFunctions()} } """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 0bd28e36135c8..1007a7d55691b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -352,14 +352,11 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co ${ctx.initPartition()} } - ${ctx.declareAddedFunctions()} - protected void processNext() throws java.io.IOException { ${code.trim} } - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} + ${ctx.declareAddedFunctions()} } """.trim diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index fc977f2fd5530..da34643281911 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -192,8 +192,6 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera this.columnIndexes = columnIndexes; } - ${ctx.declareAddedFunctions()} - public boolean hasNext() { if (currentRow < numRowsInBatch) { return true; @@ -222,8 +220,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera return unsafeRow; } - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} + ${ctx.declareAddedFunctions()} }""" val code = CodeFormatter.stripOverlappingComments( From 457dc9ccbf8404fef6c1ebf8f82e59e4ba480a0e Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 10 Jul 2017 11:22:28 +0800 Subject: [PATCH 139/779] [MINOR][DOC] Improve the docs about how to correctly set configurations ## What changes were proposed in this pull request? Spark provides several ways to set configurations, either from configuration file, or from `spark-submit` command line options, or programmatically through `SparkConf` class. It may confuses beginners why some configurations set through `SparkConf` cannot take affect. So here add some docs to address this problems and let beginners know how to correctly set configurations. ## How was this patch tested? N/A Author: jerryshao Closes #18552 from jerryshao/improve-doc. --- docs/configuration.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 6ca84240c1247..91b5befd1b1eb 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -95,6 +95,13 @@ in the `spark-defaults.conf` file. A few configuration keys have been renamed si versions of Spark; in such cases, the older key names are still accepted, but take lower precedence than any instance of the newer key. +Spark properties mainly can be divided into two kinds: one is related to deploy, like +"spark.driver.memory", "spark.executor.instances", this kind of properties may not be affected when +setting programmatically through `SparkConf` in runtime, or the behavior is depending on which +cluster manager and deploy mode you choose, so it would be suggested to set through configuration +file or `spark-submit` command line options; another is mainly related to Spark runtime control, +like "spark.task.maxFailures", this kind of properties can be set in either way. + ## Viewing Spark Properties The application web UI at `http://:4040` lists Spark properties in the "Environment" tab. From 0e80ecae300f3e2033419b2d98da8bf092c105bb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 9 Jul 2017 22:53:27 -0700 Subject: [PATCH 140/779] [SPARK-21100][SQL][FOLLOWUP] cleanup code and add more comments for Dataset.summary ## What changes were proposed in this pull request? Some code cleanup and adding comments to make the code more readable. Changed the way to generate result rows, to be more clear. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #18570 from cloud-fan/summary. --- .../scala/org/apache/spark/sql/Dataset.scala | 9 -- .../sql/execution/stat/StatFunctions.scala | 129 ++++++++---------- .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- 3 files changed, 56 insertions(+), 84 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 5326b45b50a8b..dfb51192c69bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -224,15 +224,6 @@ class Dataset[T] private[sql]( } } - private[sql] def aggregatableColumns: Seq[Expression] = { - schema.fields - .filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType]) - .map { n => - queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver) - .get - } - } - /** * Compose the string representing rows for output * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 436e18fdb5ff5..a75cfb3600225 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.execution.stat +import java.util.Locale + import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.expressions.{Cast, CreateArray, Expression, GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, GenericInternalRow, GetArrayItem, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.catalyst.util.{usePrettyExpression, QuantileSummaries} +import org.apache.spark.sql.catalyst.util.QuantileSummaries import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -228,90 +231,68 @@ object StatFunctions extends Logging { val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max") val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics - val hasPercentiles = selectedStatistics.exists(_.endsWith("%")) - val (percentiles, percentileNames, remainingAggregates) = if (hasPercentiles) { - val (pStrings, rest) = selectedStatistics.partition(a => a.endsWith("%")) - val percentiles = pStrings.map { p => - try { - p.stripSuffix("%").toDouble / 100.0 - } catch { - case e: NumberFormatException => - throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e) - } + val percentiles = selectedStatistics.filter(a => a.endsWith("%")).map { p => + try { + p.stripSuffix("%").toDouble / 100.0 + } catch { + case e: NumberFormatException => + throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e) } - require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") - (percentiles, pStrings, rest) - } else { - (Seq(), Seq(), selectedStatistics) } + require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") - - // The list of summary statistics to compute, in the form of expressions. - val availableStatistics = Map[String, Expression => Expression]( - "count" -> ((child: Expression) => Count(child).toAggregateExpression()), - "mean" -> ((child: Expression) => Average(child).toAggregateExpression()), - "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()), - "min" -> ((child: Expression) => Min(child).toAggregateExpression()), - "max" -> ((child: Expression) => Max(child).toAggregateExpression())) - - val statisticFns = remainingAggregates.map { agg => - require(availableStatistics.contains(agg), s"$agg is not a recognised statistic") - agg -> availableStatistics(agg) - } - - def percentileAgg(child: Expression): Expression = - new ApproximatePercentile(child, CreateArray(percentiles.map(Literal(_)))) - .toAggregateExpression() - - val outputCols = ds.aggregatableColumns.map(usePrettyExpression(_).sql).toList - - val ret: Seq[Row] = if (outputCols.nonEmpty) { - var aggExprs = statisticFns.toList.flatMap { case (_, colToAgg) => - outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) - } - if (hasPercentiles) { - aggExprs = outputCols.map(c => Column(percentileAgg(Column(c).expr)).as(c)) ++ aggExprs + var percentileIndex = 0 + val statisticFns = selectedStatistics.map { stats => + if (stats.endsWith("%")) { + val index = percentileIndex + percentileIndex += 1 + (child: Expression) => + GetArrayItem( + new ApproximatePercentile(child, Literal.create(percentiles)).toAggregateExpression(), + Literal(index)) + } else { + stats.toLowerCase(Locale.ROOT) match { + case "count" => (child: Expression) => Count(child).toAggregateExpression() + case "mean" => (child: Expression) => Average(child).toAggregateExpression() + case "stddev" => (child: Expression) => StddevSamp(child).toAggregateExpression() + case "min" => (child: Expression) => Min(child).toAggregateExpression() + case "max" => (child: Expression) => Max(child).toAggregateExpression() + case _ => throw new IllegalArgumentException(s"$stats is not a recognised statistic") + } } + } - val row = ds.groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq + val selectedCols = ds.logicalPlan.output + .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType]) - // Pivot the data so each summary is one row - val grouped: Seq[Seq[Any]] = row.grouped(outputCols.size).toSeq + val aggExprs = statisticFns.flatMap { func => + selectedCols.map(c => Column(Cast(func(c), StringType)).as(c.name)) + } - val basicStats = if (hasPercentiles) grouped.tail else grouped + // If there is no selected columns, we don't need to run this aggregate, so make it a lazy val. + lazy val aggResult = ds.select(aggExprs: _*).queryExecution.toRdd.collect().head - val rows = basicStats.zip(statisticFns).map { case (aggregation, (statistic, _)) => - Row(statistic :: aggregation.toList: _*) - } + // We will have one row for each selected statistic in the result. + val result = Array.fill[InternalRow](selectedStatistics.length) { + // each row has the statistic name, and statistic values of each selected column. + new GenericInternalRow(selectedCols.length + 1) + } - if (hasPercentiles) { - def nullSafeString(x: Any) = if (x == null) null else x.toString - val percentileRows = grouped.head - .map { - case a: Seq[Any] => a - case _ => Seq.fill(percentiles.length)(null: Any) - } - .transpose - .zip(percentileNames) - .map { case (values: Seq[Any], name) => - Row(name :: values.map(nullSafeString).toList: _*) - } - (rows ++ percentileRows) - .sortWith((left, right) => - selectedStatistics.indexOf(left(0)) < selectedStatistics.indexOf(right(0))) - } else { - rows + var rowIndex = 0 + while (rowIndex < result.length) { + val statsName = selectedStatistics(rowIndex) + result(rowIndex).update(0, UTF8String.fromString(statsName)) + for (colIndex <- selectedCols.indices) { + val statsValue = aggResult.getUTF8String(rowIndex * selectedCols.length + colIndex) + result(rowIndex).update(colIndex + 1, statsValue) } - } else { - // If there are no output columns, just output a single column that contains the stats. - selectedStatistics.map(Row(_)) + rowIndex += 1 } // All columns are string type - val schema = StructType( - StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes - // `toArray` forces materialization to make the seq serializable - Dataset.ofRows(ds.sparkSession, LocalRelation.fromExternalRows(schema, ret.toArray.toSeq)) - } + val output = AttributeReference("summary", StringType)() +: + selectedCols.map(c => AttributeReference(c.name, StringType)()) + Dataset.ofRows(ds.sparkSession, LocalRelation(output, result)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 2c7051bf431c3..b2219b4eb8c17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -770,7 +770,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val fooE = intercept[IllegalArgumentException] { person2.summary("foo") } - assert(fooE.getMessage === "requirement failed: foo is not a recognised statistic") + assert(fooE.getMessage === "foo is not a recognised statistic") val parseE = intercept[IllegalArgumentException] { person2.summary("foo%") From 96d58f285bc98d4c2484150eefe7447db4784a86 Mon Sep 17 00:00:00 2001 From: Eric Vandenberg Date: Mon, 10 Jul 2017 14:40:20 +0800 Subject: [PATCH 141/779] [SPARK-21219][CORE] Task retry occurs on same executor due to race condition with blacklisting ## What changes were proposed in this pull request? There's a race condition in the current TaskSetManager where a failed task is added for retry (addPendingTask), and can asynchronously be assigned to an executor *prior* to the blacklist state (updateBlacklistForFailedTask), the result is the task might re-execute on the same executor. This is particularly problematic if the executor is shutting down since the retry task immediately becomes a lost task (ExecutorLostFailure). Another side effect is that the actual failure reason gets obscured by the retry task which never actually executed. There are sample logs showing the issue in the https://issues.apache.org/jira/browse/SPARK-21219 The fix is to change the ordering of the addPendingTask and updatingBlackListForFailedTask calls in TaskSetManager.handleFailedTask ## How was this patch tested? Implemented a unit test that verifies the task is black listed before it is added to the pending task. Ran the unit test without the fix and it fails. Ran the unit test with the fix and it passes. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Eric Vandenberg Closes #18427 from ericvandenbergfb/blacklistFix. --- .../spark/scheduler/TaskSetManager.scala | 21 ++++----- .../spark/scheduler/TaskSetManagerSuite.scala | 44 ++++++++++++++++++- 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 02d374dc37cd5..3968fb7e6356d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -198,7 +198,7 @@ private[spark] class TaskSetManager( private[scheduler] var emittedTaskSizeWarning = false /** Add a task to all the pending-task lists that it should be on. */ - private def addPendingTask(index: Int) { + private[spark] def addPendingTask(index: Int) { for (loc <- tasks(index).preferredLocations) { loc match { case e: ExecutorCacheTaskLocation => @@ -832,15 +832,6 @@ private[spark] class TaskSetManager( sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info) - if (successful(index)) { - logInfo(s"Task ${info.id} in stage ${taskSet.id} (TID $tid) failed, but the task will not" + - s" be re-executed (either because the task failed with a shuffle data fetch failure," + - s" so the previous stage needs to be re-run, or because a different copy of the task" + - s" has already succeeded).") - } else { - addPendingTask(index) - } - if (!isZombie && reason.countTowardsTaskFailures) { taskSetBlacklistHelperOpt.foreach(_.updateBlacklistForFailedTask( info.host, info.executorId, index)) @@ -854,6 +845,16 @@ private[spark] class TaskSetManager( return } } + + if (successful(index)) { + logInfo(s"Task ${info.id} in stage ${taskSet.id} (TID $tid) failed, but the task will not" + + s" be re-executed (either because the task failed with a shuffle data fetch failure," + + s" so the previous stage needs to be re-run, or because a different copy of the task" + + s" has already succeeded).") + } else { + addPendingTask(index) + } + maybeFinishTaskSet() } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 80fb674725814..e46900e4e5049 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.mockito.Matchers.{any, anyInt, anyString} -import org.mockito.Mockito.{mock, never, spy, verify, when} +import org.mockito.Mockito.{mock, never, spy, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer @@ -1172,6 +1172,48 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(blacklistTracker.isNodeBlacklisted("host1")) } + test("update blacklist before adding pending task to avoid race condition") { + // When a task fails, it should apply the blacklist policy prior to + // retrying the task otherwise there's a race condition where run on + // the same executor that it was intended to be black listed from. + val conf = new SparkConf(). + set(config.BLACKLIST_ENABLED, true) + + // Create a task with two executors. + sc = new SparkContext("local", "test", conf) + val exec = "executor1" + val host = "host1" + val exec2 = "executor2" + val host2 = "host2" + sched = new FakeTaskScheduler(sc, (exec, host), (exec2, host2)) + val taskSet = FakeTask.createTaskSet(1) + + val clock = new ManualClock + val mockListenerBus = mock(classOf[LiveListenerBus]) + val blacklistTracker = new BlacklistTracker(mockListenerBus, conf, None, clock) + val taskSetManager = new TaskSetManager(sched, taskSet, 1, Some(blacklistTracker)) + val taskSetManagerSpy = spy(taskSetManager) + + val taskDesc = taskSetManagerSpy.resourceOffer(exec, host, TaskLocality.ANY) + + // Assert the task has been black listed on the executor it was last executed on. + when(taskSetManagerSpy.addPendingTask(anyInt())).thenAnswer( + new Answer[Unit] { + override def answer(invocationOnMock: InvocationOnMock): Unit = { + val task = invocationOnMock.getArgumentAt(0, classOf[Int]) + assert(taskSetManager.taskSetBlacklistHelperOpt.get. + isExecutorBlacklistedForTask(exec, task)) + } + } + ) + + // Simulate a fake exception + val e = new ExceptionFailure("a", "b", Array(), "c", None) + taskSetManagerSpy.handleFailedTask(taskDesc.get.taskId, TaskState.FAILED, e) + + verify(taskSetManagerSpy, times(1)).addPendingTask(anyInt()) + } + private def createTaskResult( id: Int, accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = { From c444d10868c808f4ae43becd5506bf944d9c2e9b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 10 Jul 2017 07:46:47 +0100 Subject: [PATCH 142/779] [MINOR][DOC] Remove obsolete `ec2-scripts.md` ## What changes were proposed in this pull request? Since this document became obsolete, we had better remove this for Apache Spark 2.3.0. The original document is removed via SPARK-12735 on January 2016, and currently it's just redirection page. The only reference in Apache Spark website will go directly to the destination in https://github.com/apache/spark-website/pull/54. ## How was this patch tested? N/A. This is a removal of documentation. Author: Dongjoon Hyun Closes #18578 from dongjoon-hyun/SPARK-REMOVE-EC2. --- docs/ec2-scripts.md | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 docs/ec2-scripts.md diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md deleted file mode 100644 index 6cd39dbed055d..0000000000000 --- a/docs/ec2-scripts.md +++ /dev/null @@ -1,7 +0,0 @@ ---- -layout: global -title: Running Spark on EC2 -redirect: https://github.com/amplab/spark-ec2#readme ---- - -This document has been superseded and replaced by documentation at https://github.com/amplab/spark-ec2#readme From 647963a26a2d4468ebd9b68111ebe68bee501fde Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 10 Jul 2017 15:58:34 +0800 Subject: [PATCH 143/779] [SPARK-20460][SQL] Make it more consistent to handle column name duplication ## What changes were proposed in this pull request? This pr made it more consistent to handle column name duplication. In the current master, error handling is different when hitting column name duplication: ``` // json scala> val schema = StructType(StructField("a", IntegerType) :: StructField("a", IntegerType) :: Nil) scala> Seq("""{"a":1, "a":1}"""""").toDF().coalesce(1).write.mode("overwrite").text("/tmp/data") scala> spark.read.format("json").schema(schema).load("/tmp/data").show org.apache.spark.sql.AnalysisException: Reference 'a' is ambiguous, could be: a#12, a#13.; at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolve(LogicalPlan.scala:287) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolve(LogicalPlan.scala:181) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan$$anonfun$resolve$1.apply(LogicalPlan.scala:153) scala> spark.read.format("json").load("/tmp/data").show org.apache.spark.sql.AnalysisException: Duplicate column(s) : "a" found, cannot save to JSON format; at org.apache.spark.sql.execution.datasources.json.JsonDataSource.checkConstraints(JsonDataSource.scala:81) at org.apache.spark.sql.execution.datasources.json.JsonDataSource.inferSchema(JsonDataSource.scala:63) at org.apache.spark.sql.execution.datasources.json.JsonFileFormat.inferSchema(JsonFileFormat.scala:57) at org.apache.spark.sql.execution.datasources.DataSource$$anonfun$7.apply(DataSource.scala:176) at org.apache.spark.sql.execution.datasources.DataSource$$anonfun$7.apply(DataSource.scala:176) // csv scala> val schema = StructType(StructField("a", IntegerType) :: StructField("a", IntegerType) :: Nil) scala> Seq("a,a", "1,1").toDF().coalesce(1).write.mode("overwrite").text("/tmp/data") scala> spark.read.format("csv").schema(schema).option("header", false).load("/tmp/data").show org.apache.spark.sql.AnalysisException: Reference 'a' is ambiguous, could be: a#41, a#42.; at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolve(LogicalPlan.scala:287) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolve(LogicalPlan.scala:181) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan$$anonfun$resolve$1.apply(LogicalPlan.scala:153) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan$$anonfun$resolve$1.apply(LogicalPlan.scala:152) // If `inferSchema` is true, a CSV format is duplicate-safe (See SPARK-16896) scala> spark.read.format("csv").option("header", true).load("/tmp/data").show +---+---+ | a0| a1| +---+---+ | 1| 1| +---+---+ // parquet scala> val schema = StructType(StructField("a", IntegerType) :: StructField("a", IntegerType) :: Nil) scala> Seq((1, 1)).toDF("a", "b").coalesce(1).write.mode("overwrite").parquet("/tmp/data") scala> spark.read.format("parquet").schema(schema).option("header", false).load("/tmp/data").show org.apache.spark.sql.AnalysisException: Reference 'a' is ambiguous, could be: a#110, a#111.; at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolve(LogicalPlan.scala:287) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolve(LogicalPlan.scala:181) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan$$anonfun$resolve$1.apply(LogicalPlan.scala:153) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan$$anonfun$resolve$1.apply(LogicalPlan.scala:152) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) ``` When this patch applied, the results change to; ``` // json scala> val schema = StructType(StructField("a", IntegerType) :: StructField("a", IntegerType) :: Nil) scala> Seq("""{"a":1, "a":1}"""""").toDF().coalesce(1).write.mode("overwrite").text("/tmp/data") scala> spark.read.format("json").schema(schema).load("/tmp/data").show org.apache.spark.sql.AnalysisException: Found duplicate column(s) in datasource: "a"; at org.apache.spark.sql.util.SchemaUtils$.checkColumnNameDuplication(SchemaUtil.scala:47) at org.apache.spark.sql.util.SchemaUtils$.checkSchemaColumnNameDuplication(SchemaUtil.scala:33) at org.apache.spark.sql.execution.datasources.DataSource.getOrInferFileFormatSchema(DataSource.scala:186) at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:368) scala> spark.read.format("json").load("/tmp/data").show org.apache.spark.sql.AnalysisException: Found duplicate column(s) in datasource: "a"; at org.apache.spark.sql.util.SchemaUtils$.checkColumnNameDuplication(SchemaUtil.scala:47) at org.apache.spark.sql.util.SchemaUtils$.checkSchemaColumnNameDuplication(SchemaUtil.scala:33) at org.apache.spark.sql.execution.datasources.DataSource.getOrInferFileFormatSchema(DataSource.scala:186) at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:368) at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:178) at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:156) // csv scala> val schema = StructType(StructField("a", IntegerType) :: StructField("a", IntegerType) :: Nil) scala> Seq("a,a", "1,1").toDF().coalesce(1).write.mode("overwrite").text("/tmp/data") scala> spark.read.format("csv").schema(schema).option("header", false).load("/tmp/data").show org.apache.spark.sql.AnalysisException: Found duplicate column(s) in datasource: "a"; at org.apache.spark.sql.util.SchemaUtils$.checkColumnNameDuplication(SchemaUtil.scala:47) at org.apache.spark.sql.util.SchemaUtils$.checkSchemaColumnNameDuplication(SchemaUtil.scala:33) at org.apache.spark.sql.execution.datasources.DataSource.getOrInferFileFormatSchema(DataSource.scala:186) at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:368) at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:178) scala> spark.read.format("csv").option("header", true).load("/tmp/data").show +---+---+ | a0| a1| +---+---+ | 1| 1| +---+---+ // parquet scala> val schema = StructType(StructField("a", IntegerType) :: StructField("a", IntegerType) :: Nil) scala> Seq((1, 1)).toDF("a", "b").coalesce(1).write.mode("overwrite").parquet("/tmp/data") scala> spark.read.format("parquet").schema(schema).option("header", false).load("/tmp/data").show org.apache.spark.sql.AnalysisException: Found duplicate column(s) in datasource: "a"; at org.apache.spark.sql.util.SchemaUtils$.checkColumnNameDuplication(SchemaUtil.scala:47) at org.apache.spark.sql.util.SchemaUtils$.checkSchemaColumnNameDuplication(SchemaUtil.scala:33) at org.apache.spark.sql.execution.datasources.DataSource.getOrInferFileFormatSchema(DataSource.scala:186) at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:368) ``` ## How was this patch tested? Added tests in `DataFrameReaderWriterSuite` and `SQLQueryTestSuite`. Author: Takeshi Yamamuro Closes #17758 from maropu/SPARK-20460. --- .../sql/catalyst/catalog/SessionCatalog.scala | 16 +--- .../apache/spark/sql/util/SchemaUtils.scala | 58 +++++++++--- .../spark/sql/util/SchemaUtilsSuite.scala | 83 +++++++++++++++++ .../command/createDataSourceTables.scala | 2 - .../spark/sql/execution/command/tables.scala | 8 +- .../spark/sql/execution/command/views.scala | 9 +- .../execution/datasources/DataSource.scala | 43 +++++++-- .../InsertIntoHadoopFsRelationCommand.scala | 14 ++- .../datasources/PartitioningUtils.scala | 10 +-- .../datasources/jdbc/JdbcUtils.scala | 11 +-- .../datasources/json/JsonDataSource.scala | 15 +--- .../sql/execution/datasources/rules.scala | 36 ++++---- .../org/apache/spark/sql/DataFrameSuite.scala | 4 +- .../sql/execution/command/DDLSuite.scala | 56 ++++++++---- .../spark/sql/jdbc/JDBCWriteSuite.scala | 4 +- .../sql/sources/ResolvedDataSourceSuite.scala | 5 +- .../sql/streaming/FileStreamSinkSuite.scala | 37 ++++++++ .../sql/test/DataFrameReaderWriterSuite.scala | 88 +++++++++++++++++++ .../sql/hive/execution/HiveDDLSuite.scala | 2 +- 19 files changed, 382 insertions(+), 119 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c40d5f6031a21..b44d2ee69e1d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.types.StructType object SessionCatalog { val DEFAULT_DATABASE = "default" @@ -188,19 +188,6 @@ class SessionCatalog( } } - private def checkDuplication(fields: Seq[StructField]): Unit = { - val columnNames = if (conf.caseSensitiveAnalysis) { - fields.map(_.name) - } else { - fields.map(_.name.toLowerCase) - } - if (columnNames.distinct.length != columnNames.length) { - val duplicateColumns = columnNames.groupBy(identity).collect { - case (x, ys) if ys.length > 1 => x - } - throw new AnalysisException(s"Found duplicate column(s): ${duplicateColumns.mkString(", ")}") - } - } // ---------------------------------------------------------------------------- // Databases // ---------------------------------------------------------------------------- @@ -353,7 +340,6 @@ class SessionCatalog( val tableIdentifier = TableIdentifier(table, Some(db)) requireDbExists(db) requireTableExists(tableIdentifier) - checkDuplication(newSchema) val catalogTable = externalCatalog.getTable(db, table) val oldSchema = catalogTable.schema diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala index e881685ce6262..41ca270095ffb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.util -import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.types.StructType /** @@ -25,29 +27,63 @@ import org.apache.spark.internal.Logging * * TODO: Merge this file with [[org.apache.spark.ml.util.SchemaUtils]]. */ -private[spark] object SchemaUtils extends Logging { +private[spark] object SchemaUtils { /** - * Checks if input column names have duplicate identifiers. Prints a warning message if + * Checks if an input schema has duplicate column names. This throws an exception if the + * duplication exists. + * + * @param schema schema to check + * @param colType column type name, used in an exception message + * @param caseSensitiveAnalysis whether duplication checks should be case sensitive or not + */ + def checkSchemaColumnNameDuplication( + schema: StructType, colType: String, caseSensitiveAnalysis: Boolean = false): Unit = { + checkColumnNameDuplication(schema.map(_.name), colType, caseSensitiveAnalysis) + } + + // Returns true if a given resolver is case-sensitive + private def isCaseSensitiveAnalysis(resolver: Resolver): Boolean = { + if (resolver == caseSensitiveResolution) { + true + } else if (resolver == caseInsensitiveResolution) { + false + } else { + sys.error("A resolver to check if two identifiers are equal must be " + + "`caseSensitiveResolution` or `caseInsensitiveResolution` in o.a.s.sql.catalyst.") + } + } + + /** + * Checks if input column names have duplicate identifiers. This throws an exception if * the duplication exists. * * @param columnNames column names to check - * @param colType column type name, used in a warning message + * @param colType column type name, used in an exception message + * @param resolver resolver used to determine if two identifiers are equal + */ + def checkColumnNameDuplication( + columnNames: Seq[String], colType: String, resolver: Resolver): Unit = { + checkColumnNameDuplication(columnNames, colType, isCaseSensitiveAnalysis(resolver)) + } + + /** + * Checks if input column names have duplicate identifiers. This throws an exception if + * the duplication exists. + * + * @param columnNames column names to check + * @param colType column type name, used in an exception message * @param caseSensitiveAnalysis whether duplication checks should be case sensitive or not */ def checkColumnNameDuplication( columnNames: Seq[String], colType: String, caseSensitiveAnalysis: Boolean): Unit = { - val names = if (caseSensitiveAnalysis) { - columnNames - } else { - columnNames.map(_.toLowerCase) - } + val names = if (caseSensitiveAnalysis) columnNames else columnNames.map(_.toLowerCase) if (names.distinct.length != names.length) { val duplicateColumns = names.groupBy(identity).collect { case (x, ys) if ys.length > 1 => s"`$x`" } - logWarning(s"Found duplicate column(s) $colType: ${duplicateColumns.mkString(", ")}. " + - "You might need to assign different column names.") + throw new AnalysisException( + s"Found duplicate column(s) $colType: ${duplicateColumns.mkString(", ")}") } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala new file mode 100644 index 0000000000000..a25be2fe61dbd --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.types.StructType + +class SchemaUtilsSuite extends SparkFunSuite { + + private def resolver(caseSensitiveAnalysis: Boolean): Resolver = { + if (caseSensitiveAnalysis) { + caseSensitiveResolution + } else { + caseInsensitiveResolution + } + } + + Seq((true, ("a", "a"), ("b", "b")), (false, ("a", "A"), ("b", "B"))).foreach { + case (caseSensitive, (a0, a1), (b0, b1)) => + + val testType = if (caseSensitive) "case-sensitive" else "case-insensitive" + test(s"Check column name duplication in $testType cases") { + def checkExceptionCases(schemaStr: String, duplicatedColumns: Seq[String]): Unit = { + val expectedErrorMsg = "Found duplicate column(s) in SchemaUtilsSuite: " + + duplicatedColumns.map(c => s"`${c.toLowerCase}`").mkString(", ") + val schema = StructType.fromDDL(schemaStr) + var msg = intercept[AnalysisException] { + SchemaUtils.checkSchemaColumnNameDuplication( + schema, "in SchemaUtilsSuite", caseSensitiveAnalysis = caseSensitive) + }.getMessage + assert(msg.contains(expectedErrorMsg)) + msg = intercept[AnalysisException] { + SchemaUtils.checkColumnNameDuplication( + schema.map(_.name), "in SchemaUtilsSuite", resolver(caseSensitive)) + }.getMessage + assert(msg.contains(expectedErrorMsg)) + msg = intercept[AnalysisException] { + SchemaUtils.checkColumnNameDuplication( + schema.map(_.name), "in SchemaUtilsSuite", caseSensitiveAnalysis = caseSensitive) + }.getMessage + assert(msg.contains(expectedErrorMsg)) + } + + checkExceptionCases(s"$a0 INT, b INT, $a1 INT", a0 :: Nil) + checkExceptionCases(s"$a0 INT, b INT, $a1 INT, $a0 INT", a0 :: Nil) + checkExceptionCases(s"$a0 INT, $b0 INT, $a1 INT, $a0 INT, $b1 INT", b0 :: a0 :: Nil) + } + } + + test("Check no exception thrown for valid schemas") { + def checkNoExceptionCases(schemaStr: String, caseSensitive: Boolean): Unit = { + val schema = StructType.fromDDL(schemaStr) + SchemaUtils.checkSchemaColumnNameDuplication( + schema, "in SchemaUtilsSuite", caseSensitiveAnalysis = caseSensitive) + SchemaUtils.checkColumnNameDuplication( + schema.map(_.name), "in SchemaUtilsSuite", resolver(caseSensitive)) + SchemaUtils.checkColumnNameDuplication( + schema.map(_.name), "in SchemaUtilsSuite", caseSensitiveAnalysis = caseSensitive) + } + + checkNoExceptionCases("a INT, b INT, c INT", caseSensitive = true) + checkNoExceptionCases("Aa INT, b INT, aA INT", caseSensitive = true) + + checkNoExceptionCases("a INT, b INT, c INT", caseSensitive = false) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 729bd39d821c9..04b2534ca5eb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.command import java.net.URI -import org.apache.hadoop.fs.Path - import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 8ded1060f7bf0..fa50d12722411 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -20,13 +20,11 @@ package org.apache.spark.sql.execution.command import java.io.File import java.net.URI import java.nio.file.FileSystems -import java.util.Date import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import scala.util.Try -import org.apache.commons.lang3.StringEscapeUtils import org.apache.hadoop.fs.Path import org.apache.spark.sql.{AnalysisException, Row, SparkSession} @@ -42,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.Utils /** @@ -202,6 +201,11 @@ case class AlterTableAddColumnsCommand( // make sure any partition columns are at the end of the fields val reorderedSchema = catalogTable.dataSchema ++ columns ++ catalogTable.partitionSchema + + SchemaUtils.checkColumnNameDuplication( + reorderedSchema.map(_.name), "in the table definition of " + table.identifier, + conf.caseSensitiveAnalysis) + catalog.alterTableSchema( table, catalogTable.schema.copy(fields = reorderedSchema.toArray)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index a6d56ca91a3ee..ffdfd527fa701 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, SubqueryExpression} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View} import org.apache.spark.sql.types.MetadataBuilder +import org.apache.spark.sql.util.SchemaUtils /** @@ -355,15 +356,15 @@ object ViewHelper { properties: Map[String, String], session: SparkSession, analyzedPlan: LogicalPlan): Map[String, String] = { + val queryOutput = analyzedPlan.schema.fieldNames + // Generate the query column names, throw an AnalysisException if there exists duplicate column // names. - val queryOutput = analyzedPlan.schema.fieldNames - assert(queryOutput.distinct.size == queryOutput.size, - s"The view output ${queryOutput.mkString("(", ",", ")")} contains duplicate column name.") + SchemaUtils.checkColumnNameDuplication( + queryOutput, "in the view definition", session.sessionState.conf.resolver) // Generate the view default database name. val viewDefaultDatabase = session.sessionState.catalog.getCurrentDatabase - removeQueryColumnNames(properties) ++ generateViewDefaultDatabase(viewDefaultDatabase) ++ generateQueryColumnNames(queryOutput) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 75e530607570f..d36a04f1fff8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -87,6 +87,14 @@ case class DataSource( lazy val providingClass: Class[_] = DataSource.lookupDataSource(className) lazy val sourceInfo: SourceInfo = sourceSchema() private val caseInsensitiveOptions = CaseInsensitiveMap(options) + private val equality = sparkSession.sessionState.conf.resolver + + bucketSpec.map { bucket => + SchemaUtils.checkColumnNameDuplication( + bucket.bucketColumnNames, "in the bucket definition", equality) + SchemaUtils.checkColumnNameDuplication( + bucket.sortColumnNames, "in the sort definition", equality) + } /** * Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer @@ -132,7 +140,6 @@ case class DataSource( // Try to infer partitioning, because no DataSource in the read path provides the partitioning // columns properly unless it is a Hive DataSource val resolved = tempFileIndex.partitionSchema.map { partitionField => - val equality = sparkSession.sessionState.conf.resolver // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( partitionField) @@ -146,7 +153,6 @@ case class DataSource( inferredPartitions } else { val partitionFields = partitionColumns.map { partitionColumn => - val equality = sparkSession.sessionState.conf.resolver userSpecifiedSchema.flatMap(_.find(c => equality(c.name, partitionColumn))).orElse { val inferredPartitions = tempFileIndex.partitionSchema val inferredOpt = inferredPartitions.find(p => equality(p.name, partitionColumn)) @@ -172,7 +178,6 @@ case class DataSource( } val dataSchema = userSpecifiedSchema.map { schema => - val equality = sparkSession.sessionState.conf.resolver StructType(schema.filterNot(f => partitionSchema.exists(p => equality(p.name, f.name)))) }.orElse { format.inferSchema( @@ -184,9 +189,18 @@ case class DataSource( s"Unable to infer schema for $format. It must be specified manually.") } - SchemaUtils.checkColumnNameDuplication( - (dataSchema ++ partitionSchema).map(_.name), "in the data schema and the partition schema", - sparkSession.sessionState.conf.caseSensitiveAnalysis) + // We just print a waring message if the data schema and partition schema have the duplicate + // columns. This is because we allow users to do so in the previous Spark releases and + // we have the existing tests for the cases (e.g., `ParquetHadoopFsRelationSuite`). + // See SPARK-18108 and SPARK-21144 for related discussions. + try { + SchemaUtils.checkColumnNameDuplication( + (dataSchema ++ partitionSchema).map(_.name), + "in the data schema and the partition schema", + equality) + } catch { + case e: AnalysisException => logWarning(e.getMessage) + } (dataSchema, partitionSchema) } @@ -391,6 +405,23 @@ case class DataSource( s"$className is not a valid Spark SQL Data Source.") } + relation match { + case hs: HadoopFsRelation => + SchemaUtils.checkColumnNameDuplication( + hs.dataSchema.map(_.name), + "in the data schema", + equality) + SchemaUtils.checkColumnNameDuplication( + hs.partitionSchema.map(_.name), + "in the partition schema", + equality) + case _ => + SchemaUtils.checkColumnNameDuplication( + relation.schema.map(_.name), + "in the data schema", + equality) + } + relation } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 0031567d3d288..c1bcfb8610783 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -21,7 +21,6 @@ import java.io.IOException import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.SparkContext import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition} @@ -30,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.util.SchemaUtils /** * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. @@ -64,13 +63,10 @@ case class InsertIntoHadoopFsRelationCommand( assert(children.length == 1) // Most formats don't do well with duplicate columns, so lets not allow that - if (query.schema.fieldNames.length != query.schema.fieldNames.distinct.length) { - val duplicateColumns = query.schema.fieldNames.groupBy(identity).collect { - case (x, ys) if ys.length > 1 => "\"" + x + "\"" - }.mkString(", ") - throw new AnalysisException(s"Duplicate column(s): $duplicateColumns found, " + - "cannot save to file.") - } + SchemaUtils.checkSchemaColumnNameDuplication( + query.schema, + s"when inserting into $outputPath", + sparkSession.sessionState.conf.caseSensitiveAnalysis) val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options) val fs = outputPath.getFileSystem(hadoopConf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index f61c673baaa58..92358da6d6c67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.SchemaUtils // TODO: We should tighten up visibility of the classes here once we clean up Hive coupling. @@ -301,13 +302,8 @@ object PartitioningUtils { normalizedKey -> value } - if (normalizedPartSpec.map(_._1).distinct.length != normalizedPartSpec.length) { - val duplicateColumns = normalizedPartSpec.map(_._1).groupBy(identity).collect { - case (x, ys) if ys.length > 1 => x - } - throw new AnalysisException(s"Found duplicated columns in partition specification: " + - duplicateColumns.mkString(", ")) - } + SchemaUtils.checkColumnNameDuplication( + normalizedPartSpec.map(_._1), "in the partition schema", resolver) normalizedPartSpec.toMap } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 55b2539c13381..bbe9024f13a44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.NextIterator @@ -749,14 +750,8 @@ object JdbcUtils extends Logging { val nameEquality = df.sparkSession.sessionState.conf.resolver // checks duplicate columns in the user specified column types. - userSchema.fieldNames.foreach { col => - val duplicatesCols = userSchema.fieldNames.filter(nameEquality(_, col)) - if (duplicatesCols.size >= 2) { - throw new AnalysisException( - "Found duplicate column(s) in createTableColumnTypes option value: " + - duplicatesCols.mkString(", ")) - } - } + SchemaUtils.checkColumnNameDuplication( + userSchema.map(_.name), "in the createTableColumnTypes option value", nameEquality) // checks if user specified column names exist in the DataFrame schema userSchema.fieldNames.foreach { col => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 5a92a71d19e78..8b7c2709afde1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -59,9 +59,7 @@ abstract class JsonDataSource extends Serializable { inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): Option[StructType] = { if (inputPaths.nonEmpty) { - val jsonSchema = infer(sparkSession, inputPaths, parsedOptions) - checkConstraints(jsonSchema) - Some(jsonSchema) + Some(infer(sparkSession, inputPaths, parsedOptions)) } else { None } @@ -71,17 +69,6 @@ abstract class JsonDataSource extends Serializable { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType - - /** Constraints to be imposed on schema to be stored. */ - private def checkConstraints(schema: StructType): Unit = { - if (schema.fieldNames.length != schema.fieldNames.distinct.length) { - val duplicateColumns = schema.fieldNames.groupBy(identity).collect { - case (x, ys) if ys.length > 1 => "\"" + x + "\"" - }.mkString(", ") - throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + - s"cannot save to JSON format") - } - } } object JsonDataSource { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 3f4a78580f1eb..41d40aa926fbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.InsertableRelation import org.apache.spark.sql.types.{AtomicType, StructType} +import org.apache.spark.sql.util.SchemaUtils /** * Try to replaces [[UnresolvedRelation]]s if the plan is for direct query on files. @@ -222,12 +223,10 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi } private def normalizeCatalogTable(schema: StructType, table: CatalogTable): CatalogTable = { - val columnNames = if (sparkSession.sessionState.conf.caseSensitiveAnalysis) { - schema.map(_.name) - } else { - schema.map(_.name.toLowerCase) - } - checkDuplication(columnNames, "table definition of " + table.identifier) + SchemaUtils.checkSchemaColumnNameDuplication( + schema, + "in the table definition of " + table.identifier, + sparkSession.sessionState.conf.caseSensitiveAnalysis) val normalizedPartCols = normalizePartitionColumns(schema, table) val normalizedBucketSpec = normalizeBucketSpec(schema, table) @@ -253,7 +252,10 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi partCols = table.partitionColumnNames, resolver = sparkSession.sessionState.conf.resolver) - checkDuplication(normalizedPartitionCols, "partition") + SchemaUtils.checkColumnNameDuplication( + normalizedPartitionCols, + "in the partition schema", + sparkSession.sessionState.conf.resolver) if (schema.nonEmpty && normalizedPartitionCols.length == schema.length) { if (DDLUtils.isHiveTable(table)) { @@ -283,8 +285,15 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi tableCols = schema.map(_.name), bucketSpec = bucketSpec, resolver = sparkSession.sessionState.conf.resolver) - checkDuplication(normalizedBucketSpec.bucketColumnNames, "bucket") - checkDuplication(normalizedBucketSpec.sortColumnNames, "sort") + + SchemaUtils.checkColumnNameDuplication( + normalizedBucketSpec.bucketColumnNames, + "in the bucket definition", + sparkSession.sessionState.conf.resolver) + SchemaUtils.checkColumnNameDuplication( + normalizedBucketSpec.sortColumnNames, + "in the sort definition", + sparkSession.sessionState.conf.resolver) normalizedBucketSpec.sortColumnNames.map(schema(_)).map(_.dataType).foreach { case dt if RowOrdering.isOrderable(dt) => // OK @@ -297,15 +306,6 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi } } - private def checkDuplication(colNames: Seq[String], colType: String): Unit = { - if (colNames.distinct.length != colNames.length) { - val duplicateColumns = colNames.groupBy(identity).collect { - case (x, ys) if ys.length > 1 => x - } - failAnalysis(s"Found duplicate column(s) in $colType: ${duplicateColumns.mkString(", ")}") - } - } - private def failAnalysis(msg: String) = throw new AnalysisException(msg) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b2219b4eb8c17..a5a2e1c38d300 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1189,7 +1189,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Seq((1, 2, 3), (2, 3, 4), (3, 4, 5)).toDF("column1", "column2", "column1") .write.format("parquet").save("temp") } - assert(e.getMessage.contains("Duplicate column(s)")) + assert(e.getMessage.contains("Found duplicate column(s) when inserting into")) assert(e.getMessage.contains("column1")) assert(!e.getMessage.contains("column2")) @@ -1199,7 +1199,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .toDF("column1", "column2", "column3", "column1", "column3") .write.format("json").save("temp") } - assert(f.getMessage.contains("Duplicate column(s)")) + assert(f.getMessage.contains("Found duplicate column(s) when inserting into")) assert(f.getMessage.contains("column1")) assert(f.getMessage.contains("column3")) assert(!f.getMessage.contains("column2")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 5c40d8bb4b1ef..5c0a6aa724bf0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -436,16 +436,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("create table - duplicate column names in the table definition") { - val e = intercept[AnalysisException] { - sql("CREATE TABLE tbl(a int, a string) USING json") - } - assert(e.message == "Found duplicate column(s) in table definition of `tbl`: a") - - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { - val e2 = intercept[AnalysisException] { - sql("CREATE TABLE tbl(a int, A string) USING json") + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + val errMsg = intercept[AnalysisException] { + sql(s"CREATE TABLE t($c0 INT, $c1 INT) USING parquet") + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the table definition of `t`")) } - assert(e2.message == "Found duplicate column(s) in table definition of `tbl`: a") } } @@ -466,17 +463,33 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("create table - column repeated in partition columns") { - val e = intercept[AnalysisException] { - sql("CREATE TABLE tbl(a int) USING json PARTITIONED BY (a, a)") + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + val errMsg = intercept[AnalysisException] { + sql(s"CREATE TABLE t($c0 INT) USING parquet PARTITIONED BY ($c0, $c1)") + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the partition schema")) + } } - assert(e.message == "Found duplicate column(s) in partition: a") } - test("create table - column repeated in bucket columns") { - val e = intercept[AnalysisException] { - sql("CREATE TABLE tbl(a int) USING json CLUSTERED BY (a, a) INTO 4 BUCKETS") + test("create table - column repeated in bucket/sort columns") { + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + var errMsg = intercept[AnalysisException] { + sql(s"CREATE TABLE t($c0 INT) USING parquet CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS") + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the bucket definition")) + + errMsg = intercept[AnalysisException] { + sql(s""" + |CREATE TABLE t($c0 INT, col INT) USING parquet CLUSTERED BY (col) + | SORTED BY ($c0, $c1) INTO 2 BUCKETS + """.stripMargin) + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the sort definition")) + } } - assert(e.message == "Found duplicate column(s) in bucket: a") } test("Refresh table after changing the data source table partitioning") { @@ -528,6 +541,17 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("create view - duplicate column names in the view definition") { + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + val errMsg = intercept[AnalysisException] { + sql(s"CREATE VIEW t AS SELECT * FROM VALUES (1, 1) AS t($c0, $c1)") + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the view definition")) + } + } + } + test("Alter/Describe Database") { val catalog = spark.sessionState.catalog val databaseNames = Seq("db1", "`database`") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 92f50a095f19b..2334d5ae32dc3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.jdbc -import java.sql.{Date, DriverManager, Timestamp} +import java.sql.DriverManager import java.util.Properties import scala.collection.JavaConverters.propertiesAsScalaMapConverter @@ -479,7 +479,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { .jdbc(url1, "TEST.USERDBTYPETEST", properties) }.getMessage() assert(msg.contains( - "Found duplicate column(s) in createTableColumnTypes option value: name, NaMe")) + "Found duplicate column(s) in the createTableColumnTypes option value: `name`")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 0f97fd78d2ffb..308c5079c44bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -21,11 +21,12 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.test.SharedSQLContext -class ResolvedDataSourceSuite extends SparkFunSuite { +class ResolvedDataSourceSuite extends SparkFunSuite with SharedSQLContext { private def getProvidingClass(name: String): Class[_] = DataSource( - sparkSession = null, + sparkSession = spark, className = name, options = Map(DateTimeUtils.TIMEZONE_OPTION -> DateTimeUtils.defaultTimeZone().getID) ).providingClass diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index bb6a27803bb20..6676099d426ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -352,4 +353,40 @@ class FileStreamSinkSuite extends StreamTest { assertAncestorIsNotMetadataDirectory(s"/a/b/c") assertAncestorIsNotMetadataDirectory(s"/a/b/c/${FileStreamSink.metadataDir}extra") } + + test("SPARK-20460 Check name duplication in schema") { + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + val inputData = MemoryStream[(Int, Int)] + val df = inputData.toDF() + + val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath + val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath + + var query: StreamingQuery = null + try { + query = + df.writeStream + .option("checkpointLocation", checkpointDir) + .format("json") + .start(outputDir) + + inputData.addData((1, 1)) + + failAfter(streamingTimeout) { + query.processAllAvailable() + } + } finally { + if (query != null) { + query.stop() + } + } + + val errorMsg = intercept[AnalysisException] { + spark.read.schema(s"$c0 INT, $c1 INT").json(outputDir).as[(Int, Int)] + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) in the data schema: ")) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 306aecb5bbc86..569bac156b531 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -687,4 +688,91 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be testRead(spark.read.schema(userSchemaString).text(dir, dir), data ++ data, userSchema) testRead(spark.read.schema(userSchemaString).text(Seq(dir, dir): _*), data ++ data, userSchema) } + + test("SPARK-20460 Check name duplication in buckets") { + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + var errorMsg = intercept[AnalysisException] { + Seq((1, 1)).toDF("col", c0).write.bucketBy(2, c0, c1).saveAsTable("t") + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) in the bucket definition")) + + errorMsg = intercept[AnalysisException] { + Seq((1, 1)).toDF("col", c0).write.bucketBy(2, "col").sortBy(c0, c1).saveAsTable("t") + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) in the sort definition")) + } + } + } + + test("SPARK-20460 Check name duplication in schema") { + def checkWriteDataColumnDuplication( + format: String, colName0: String, colName1: String, tempDir: File): Unit = { + val errorMsg = intercept[AnalysisException] { + Seq((1, 1)).toDF(colName0, colName1).write.format(format).mode("overwrite") + .save(tempDir.getAbsolutePath) + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) when inserting into")) + } + + def checkReadUserSpecifiedDataColumnDuplication( + df: DataFrame, format: String, colName0: String, colName1: String, tempDir: File): Unit = { + val testDir = Utils.createTempDir(tempDir.getAbsolutePath) + df.write.format(format).mode("overwrite").save(testDir.getAbsolutePath) + val errorMsg = intercept[AnalysisException] { + spark.read.format(format).schema(s"$colName0 INT, $colName1 INT") + .load(testDir.getAbsolutePath) + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) in the data schema:")) + } + + def checkReadPartitionColumnDuplication( + format: String, colName0: String, colName1: String, tempDir: File): Unit = { + val testDir = Utils.createTempDir(tempDir.getAbsolutePath) + Seq(1).toDF("col").write.format(format).mode("overwrite") + .save(s"${testDir.getAbsolutePath}/$colName0=1/$colName1=1") + val errorMsg = intercept[AnalysisException] { + spark.read.format(format).load(testDir.getAbsolutePath) + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) in the partition schema:")) + } + + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + withTempDir { src => + // Check CSV format + checkWriteDataColumnDuplication("csv", c0, c1, src) + checkReadUserSpecifiedDataColumnDuplication( + Seq((1, 1)).toDF("c0", "c1"), "csv", c0, c1, src) + // If `inferSchema` is true, a CSV format is duplicate-safe (See SPARK-16896) + var testDir = Utils.createTempDir(src.getAbsolutePath) + Seq("a,a", "1,1").toDF().coalesce(1).write.mode("overwrite").text(testDir.getAbsolutePath) + val df = spark.read.format("csv").option("inferSchema", true).option("header", true) + .load(testDir.getAbsolutePath) + checkAnswer(df, Row(1, 1)) + checkReadPartitionColumnDuplication("csv", c0, c1, src) + + // Check JSON format + checkWriteDataColumnDuplication("json", c0, c1, src) + checkReadUserSpecifiedDataColumnDuplication( + Seq((1, 1)).toDF("c0", "c1"), "json", c0, c1, src) + // Inferred schema cases + testDir = Utils.createTempDir(src.getAbsolutePath) + Seq(s"""{"$c0":3, "$c1":5}""").toDF().write.mode("overwrite") + .text(testDir.getAbsolutePath) + val errorMsg = intercept[AnalysisException] { + spark.read.format("json").option("inferSchema", true).load(testDir.getAbsolutePath) + }.getMessage + assert(errorMsg.contains("Found duplicate column(s) in the data schema:")) + checkReadPartitionColumnDuplication("json", c0, c1, src) + + // Check Parquet format + checkWriteDataColumnDuplication("parquet", c0, c1, src) + checkReadUserSpecifiedDataColumnDuplication( + Seq((1, 1)).toDF("c0", "c1"), "parquet", c0, c1, src) + checkReadPartitionColumnDuplication("parquet", c0, c1, src) + } + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 31fa3d2447467..12daf3af11abe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -345,7 +345,7 @@ class HiveDDLSuite val e = intercept[AnalysisException] { sql("CREATE TABLE tbl(a int) PARTITIONED BY (a string)") } - assert(e.message == "Found duplicate column(s) in table definition of `default`.`tbl`: a") + assert(e.message == "Found duplicate column(s) in the table definition of `default`.`tbl`: `a`") } test("add/drop partition with location - managed table") { From 6a06c4b03c4dd86241fb9d11b4360371488f0e53 Mon Sep 17 00:00:00 2001 From: jinxing Date: Mon, 10 Jul 2017 21:06:58 +0800 Subject: [PATCH 144/779] [SPARK-21342] Fix DownloadCallback to work well with RetryingBlockFetcher. ## What changes were proposed in this pull request? When `RetryingBlockFetcher` retries fetching blocks. There could be two `DownloadCallback`s download the same content to the same target file. It could cause `ShuffleBlockFetcherIterator` reading a partial result. This pr proposes to create and delete the tmp files in `OneForOneBlockFetcher` Author: jinxing Author: Shixiong Zhu Closes #18565 from jinxing64/SPARK-21342. --- .../shuffle/ExternalShuffleClient.java | 7 ++-- .../shuffle/OneForOneBlockFetcher.java | 34 +++++++++++------- .../spark/network/shuffle/ShuffleClient.java | 13 +++++-- .../shuffle/TempShuffleFileManager.java | 36 +++++++++++++++++++ .../network/sasl/SaslIntegrationSuite.java | 2 +- .../shuffle/OneForOneBlockFetcherSuite.java | 2 +- .../spark/network/BlockTransferService.scala | 8 ++--- .../netty/NettyBlockTransferService.scala | 9 +++-- .../storage/ShuffleBlockFetcherIterator.scala | 28 ++++++++++----- .../spark/storage/BlockManagerSuite.scala | 4 +-- .../ShuffleBlockFetcherIteratorSuite.scala | 10 +++--- 11 files changed, 108 insertions(+), 45 deletions(-) create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 6ac9302517ee0..31bd24e5038b2 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -17,7 +17,6 @@ package org.apache.spark.network.shuffle; -import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; import java.util.List; @@ -91,15 +90,15 @@ public void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener, - File[] shuffleFiles) { + TempShuffleFileManager tempShuffleFileManager) { checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = (blockIds1, listener1) -> { TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1, conf, - shuffleFiles).start(); + new OneForOneBlockFetcher(client, appId, execId, + blockIds1, listener1, conf, tempShuffleFileManager).start(); }; int maxRetries = conf.maxIORetries(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index d46ce2e0e6b78..2f160d12af22b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -57,11 +57,21 @@ public class OneForOneBlockFetcher { private final String[] blockIds; private final BlockFetchingListener listener; private final ChunkReceivedCallback chunkCallback; - private TransportConf transportConf = null; - private File[] shuffleFiles = null; + private final TransportConf transportConf; + private final TempShuffleFileManager tempShuffleFileManager; private StreamHandle streamHandle = null; + public OneForOneBlockFetcher( + TransportClient client, + String appId, + String execId, + String[] blockIds, + BlockFetchingListener listener, + TransportConf transportConf) { + this(client, appId, execId, blockIds, listener, transportConf, null); + } + public OneForOneBlockFetcher( TransportClient client, String appId, @@ -69,18 +79,14 @@ public OneForOneBlockFetcher( String[] blockIds, BlockFetchingListener listener, TransportConf transportConf, - File[] shuffleFiles) { + TempShuffleFileManager tempShuffleFileManager) { this.client = client; this.openMessage = new OpenBlocks(appId, execId, blockIds); this.blockIds = blockIds; this.listener = listener; this.chunkCallback = new ChunkCallback(); this.transportConf = transportConf; - if (shuffleFiles != null) { - this.shuffleFiles = shuffleFiles; - assert this.shuffleFiles.length == blockIds.length: - "Number of shuffle files should equal to blocks"; - } + this.tempShuffleFileManager = tempShuffleFileManager; } /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ @@ -119,9 +125,9 @@ public void onSuccess(ByteBuffer response) { // Immediately request all chunks -- we expect that the total size of the request is // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. for (int i = 0; i < streamHandle.numChunks; i++) { - if (shuffleFiles != null) { + if (tempShuffleFileManager != null) { client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), - new DownloadCallback(shuffleFiles[i], i)); + new DownloadCallback(i)); } else { client.fetchChunk(streamHandle.streamId, i, chunkCallback); } @@ -157,8 +163,8 @@ private class DownloadCallback implements StreamCallback { private File targetFile = null; private int chunkIndex; - DownloadCallback(File targetFile, int chunkIndex) throws IOException { - this.targetFile = targetFile; + DownloadCallback(int chunkIndex) throws IOException { + this.targetFile = tempShuffleFileManager.createTempShuffleFile(); this.channel = Channels.newChannel(new FileOutputStream(targetFile)); this.chunkIndex = chunkIndex; } @@ -174,6 +180,9 @@ public void onComplete(String streamId) throws IOException { ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, targetFile.length()); listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); + if (!tempShuffleFileManager.registerTempShuffleFileToClean(targetFile)) { + targetFile.delete(); + } } @Override @@ -182,6 +191,7 @@ public void onFailure(String streamId, Throwable cause) throws IOException { // On receipt of a failure, fail every block from chunkIndex onwards. String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length); failRemainingBlocks(remainingBlockIds, cause); + targetFile.delete(); } } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index 978ff5a2a8699..9e77bee7f9ee6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -18,7 +18,6 @@ package org.apache.spark.network.shuffle; import java.io.Closeable; -import java.io.File; /** Provides an interface for reading shuffle files, either from an Executor or external service. */ public abstract class ShuffleClient implements Closeable { @@ -35,6 +34,16 @@ public void init(String appId) { } * Note that this API takes a sequence so the implementation can batch requests, and does not * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as * the data of a block is fetched, rather than waiting for all blocks to be fetched. + * + * @param host the host of the remote node. + * @param port the port of the remote node. + * @param execId the executor id. + * @param blockIds block ids to fetch. + * @param listener the listener to receive block fetching status. + * @param tempShuffleFileManager TempShuffleFileManager to create and clean temp shuffle files. + * If it's not null, the remote blocks will be streamed + * into temp shuffle files to reduce the memory usage, otherwise, + * they will be kept in memory. */ public abstract void fetchBlocks( String host, @@ -42,5 +51,5 @@ public abstract void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener, - File[] shuffleFiles); + TempShuffleFileManager tempShuffleFileManager); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java new file mode 100644 index 0000000000000..84a5ed6a276bd --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.File; + +/** + * A manager to create temp shuffle block files to reduce the memory usage and also clean temp + * files when they won't be used any more. + */ +public interface TempShuffleFileManager { + + /** Create a temp shuffle block file. */ + File createTempShuffleFile(); + + /** + * Register a temp shuffle file to clean up when it won't be used any more. Return whether the + * file is registered successfully. If `false`, the caller should clean up the file by itself. + */ + boolean registerTempShuffleFileToClean(File file); +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 8110f1e004c73..02e6eb3a4467e 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -204,7 +204,7 @@ public void onBlockFetchFailure(String blockId, Throwable t) { String[] blockIds = { "shuffle_0_1_2", "shuffle_0_3_4" }; OneForOneBlockFetcher fetcher = - new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf, null); + new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf); fetcher.start(); blockFetchLatch.await(); checkSecurityException(exception.get()); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index 61d82214e7d30..dc947a619bf02 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -131,7 +131,7 @@ private static BlockFetchingListener fetchBlocks(LinkedHashMap { diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 6860214c7fe39..fe5fd2da039bb 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -17,7 +17,7 @@ package org.apache.spark.network -import java.io.{Closeable, File} +import java.io.Closeable import java.nio.ByteBuffer import scala.concurrent.{Future, Promise} @@ -26,7 +26,7 @@ import scala.reflect.ClassTag import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.ThreadUtils @@ -68,7 +68,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo execId: String, blockIds: Array[String], listener: BlockFetchingListener, - shuffleFiles: Array[File]): Unit + tempShuffleFileManager: TempShuffleFileManager): Unit /** * Upload a single block to a remote node, available only after [[init]] is invoked. @@ -101,7 +101,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo ret.flip() result.success(new NioManagedBuffer(ret)) } - }, shuffleFiles = null) + }, tempShuffleFileManager = null) ThreadUtils.awaitResult(result.future, Duration.Inf) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b13a9c681e543..30ff93897f98a 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,7 +17,6 @@ package org.apache.spark.network.netty -import java.io.File import java.nio.ByteBuffer import scala.collection.JavaConverters._ @@ -30,7 +29,7 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.server._ -import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher} +import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempShuffleFileManager} import org.apache.spark.network.shuffle.protocol.UploadBlock import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer @@ -90,14 +89,14 @@ private[spark] class NettyBlockTransferService( execId: String, blockIds: Array[String], listener: BlockFetchingListener, - shuffleFiles: Array[File]): Unit = { + tempShuffleFileManager: TempShuffleFileManager): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { val client = clientFactory.createClient(host, port) - new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener, - transportConf, shuffleFiles).start() + new OneForOneBlockFetcher(client, appId, execId, blockIds, listener, + transportConf, tempShuffleFileManager).start() } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index a10f1feadd0af..81d822dc8a98f 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -28,7 +28,7 @@ import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBufferOutputStream @@ -66,7 +66,7 @@ final class ShuffleBlockFetcherIterator( maxReqsInFlight: Int, maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean) - extends Iterator[(BlockId, InputStream)] with Logging { + extends Iterator[(BlockId, InputStream)] with TempShuffleFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -135,7 +135,8 @@ final class ShuffleBlockFetcherIterator( * A set to store the files used for shuffling remote huge blocks. Files in this set will be * deleted when cleanup. This is a layer of defensiveness against disk file leaks. */ - val shuffleFilesSet = mutable.HashSet[File]() + @GuardedBy("this") + private[this] val shuffleFilesSet = mutable.HashSet[File]() initialize() @@ -149,6 +150,19 @@ final class ShuffleBlockFetcherIterator( currentResult = null } + override def createTempShuffleFile(): File = { + blockManager.diskBlockManager.createTempLocalBlock()._2 + } + + override def registerTempShuffleFileToClean(file: File): Boolean = synchronized { + if (isZombie) { + false + } else { + shuffleFilesSet += file + true + } + } + /** * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. */ @@ -176,7 +190,7 @@ final class ShuffleBlockFetcherIterator( } shuffleFilesSet.foreach { file => if (!file.delete()) { - logInfo("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath()); + logWarning("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath()) } } } @@ -221,12 +235,8 @@ final class ShuffleBlockFetcherIterator( // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch // the data and write it to file directly. if (req.size > maxReqSizeShuffleToMem) { - val shuffleFiles = blockIds.map { _ => - blockManager.diskBlockManager.createTempLocalBlock()._2 - }.toArray - shuffleFilesSet ++= shuffleFiles shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, shuffleFiles) + blockFetchingListener, this) } else { shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, blockFetchingListener, null) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 086adccea954c..755a61a438a6a 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -45,7 +45,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap} -import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus @@ -1382,7 +1382,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE execId: String, blockIds: Array[String], listener: BlockFetchingListener, - shuffleFiles: Array[File]): Unit = { + tempShuffleFileManager: TempShuffleFileManager): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 559b3faab8fd2..6a70cedf769b8 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -33,7 +33,7 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.network.shuffle.{BlockFetchingListener, TempShuffleFileManager} import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils @@ -432,12 +432,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val remoteBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) val transfer = mock(classOf[BlockTransferService]) - var shuffleFiles: Array[File] = null + var tempShuffleFileManager: TempShuffleFileManager = null when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - shuffleFiles = invocation.getArguments()(5).asInstanceOf[Array[File]] + tempShuffleFileManager = invocation.getArguments()(5).asInstanceOf[TempShuffleFileManager] Future { listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) @@ -466,13 +466,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. - assert(shuffleFiles === null) + assert(tempShuffleFileManager == null) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. - assert(shuffleFiles != null) + assert(tempShuffleFileManager != null) } } From 18b3b00ecfde6c694fb6fee4f4d07d04e3d08ccf Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Mon, 10 Jul 2017 09:26:42 -0700 Subject: [PATCH 145/779] [SPARK-21272] SortMergeJoin LeftAnti does not update numOutputRows ## What changes were proposed in this pull request? Updating numOutputRows metric was missing from one return path of LeftAnti SortMergeJoin. ## How was this patch tested? Non-zero output rows manually seen in metrics. Author: Juliusz Sompolski Closes #18494 from juliuszsompolski/SPARK-21272. --- .../sql/execution/joins/SortMergeJoinExec.scala | 1 + .../spark/sql/execution/metric/SQLMetricsSuite.scala | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 8445c26eeee58..639b8e00c121b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -290,6 +290,7 @@ case class SortMergeJoinExec( currentLeftRow = smjScanner.getStreamedRow val currentRightMatches = smjScanner.getBufferedMatches if (currentRightMatches == null || currentRightMatches.length == 0) { + numOutputRows += 1 return true } var found = false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index cb3405b2fe19b..2911cbbeee479 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -483,6 +483,18 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } } + test("SortMergeJoin(left-anti) metrics") { + val anti = testData2.filter("a > 2") + withTempView("antiData") { + anti.createOrReplaceTempView("antiData") + val df = spark.sql( + "SELECT * FROM testData2 ANTI JOIN antiData ON testData2.a = antiData.a") + testSparkPlanMetrics(df, 1, Map( + 0L -> ("SortMergeJoin", Map("number of output rows" -> 4L))) + ) + } + } + test("save metrics") { withTempPath { file => // person creates a temporary view. get the DF before listing previous execution IDs From 2bfd5accdce2ae31feeeddf213a019cf8ec97663 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 10 Jul 2017 10:40:03 -0700 Subject: [PATCH 146/779] [SPARK-21266][R][PYTHON] Support schema a DDL-formatted string in dapply/gapply/from_json ## What changes were proposed in this pull request? This PR supports schema in a DDL formatted string for `from_json` in R/Python and `dapply` and `gapply` in R, which are commonly used and/or consistent with Scala APIs. Additionally, this PR exposes `structType` in R to allow working around in other possible corner cases. **Python** `from_json` ```python from pyspark.sql.functions import from_json data = [(1, '''{"a": 1}''')] df = spark.createDataFrame(data, ("key", "value")) df.select(from_json(df.value, "a INT").alias("json")).show() ``` **R** `from_json` ```R df <- sql("SELECT named_struct('name', 'Bob') as people") df <- mutate(df, people_json = to_json(df$people)) head(select(df, from_json(df$people_json, "name STRING"))) ``` `structType.character` ```R structType("a STRING, b INT") ``` `dapply` ```R dapply(createDataFrame(list(list(1.0)), "a"), function(x) {x}, "a DOUBLE") ``` `gapply` ```R gapply(createDataFrame(list(list(1.0)), "a"), "a", function(key, x) { x }, "a DOUBLE") ``` ## How was this patch tested? Doc tests for `from_json` in Python and unit tests `test_sparkSQL.R` in R. Author: hyukjinkwon Closes #18498 from HyukjinKwon/SPARK-21266. --- R/pkg/NAMESPACE | 2 + R/pkg/R/DataFrame.R | 36 ++++- R/pkg/R/functions.R | 12 +- R/pkg/R/group.R | 3 + R/pkg/R/schema.R | 29 +++- R/pkg/tests/fulltests/test_sparkSQL.R | 136 ++++++++++-------- python/pyspark/sql/functions.py | 11 +- .../org/apache/spark/sql/functions.scala | 7 +- 8 files changed, 160 insertions(+), 76 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index b7fdae58de459..232f5cf31f319 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -429,6 +429,7 @@ export("structField", "structField.character", "print.structField", "structType", + "structType.character", "structType.jobj", "structType.structField", "print.structType") @@ -465,5 +466,6 @@ S3method(print, summary.GBTRegressionModel) S3method(print, summary.GBTClassificationModel) S3method(structField, character) S3method(structField, jobj) +S3method(structType, character) S3method(structType, jobj) S3method(structType, structField) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 3b9d42d6e7158..e7a166c3014c1 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1391,6 +1391,10 @@ setMethod("summarize", }) dapplyInternal <- function(x, func, schema) { + if (is.character(schema)) { + schema <- structType(schema) + } + packageNamesArr <- serialize(.sparkREnv[[".packages"]], connection = NULL) @@ -1408,6 +1412,8 @@ dapplyInternal <- function(x, func, schema) { dataFrame(sdf) } +setClassUnion("characterOrstructType", c("character", "structType")) + #' dapply #' #' Apply a function to each partition of a SparkDataFrame. @@ -1418,10 +1424,11 @@ dapplyInternal <- function(x, func, schema) { #' to each partition will be passed. #' The output of func should be a R data.frame. #' @param schema The schema of the resulting SparkDataFrame after the function is applied. -#' It must match the output of func. +#' It must match the output of func. Since Spark 2.3, the DDL-formatted string +#' is also supported for the schema. #' @family SparkDataFrame functions #' @rdname dapply -#' @aliases dapply,SparkDataFrame,function,structType-method +#' @aliases dapply,SparkDataFrame,function,characterOrstructType-method #' @name dapply #' @seealso \link{dapplyCollect} #' @export @@ -1444,6 +1451,17 @@ dapplyInternal <- function(x, func, schema) { #' y <- cbind(y, y[1] + 1L) #' }, #' schema) +#' +#' # The schema also can be specified in a DDL-formatted string. +#' schema <- "a INT, d DOUBLE, c STRING, d INT" +#' df1 <- dapply( +#' df, +#' function(x) { +#' y <- x[x[1] > 1, ] +#' y <- cbind(y, y[1] + 1L) +#' }, +#' schema) +#' #' collect(df1) #' # the result #' # a b c d @@ -1452,7 +1470,7 @@ dapplyInternal <- function(x, func, schema) { #' } #' @note dapply since 2.0.0 setMethod("dapply", - signature(x = "SparkDataFrame", func = "function", schema = "structType"), + signature(x = "SparkDataFrame", func = "function", schema = "characterOrstructType"), function(x, func, schema) { dapplyInternal(x, func, schema) }) @@ -1522,6 +1540,7 @@ setMethod("dapplyCollect", #' @param schema the schema of the resulting SparkDataFrame after the function is applied. #' The schema must match to output of \code{func}. It has to be defined for each #' output column with preferred output column name and corresponding data type. +#' Since Spark 2.3, the DDL-formatted string is also supported for the schema. #' @return A SparkDataFrame. #' @family SparkDataFrame functions #' @aliases gapply,SparkDataFrame-method @@ -1541,7 +1560,7 @@ setMethod("dapplyCollect", #' #' Here our output contains three columns, the key which is a combination of two #' columns with data types integer and string and the mean which is a double. -#' schema <- structType(structField("a", "integer"), structField("c", "string"), +#' schema <- structType(structField("a", "integer"), structField("c", "string"), #' structField("avg", "double")) #' result <- gapply( #' df, @@ -1550,6 +1569,15 @@ setMethod("dapplyCollect", #' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) #' }, schema) #' +#' The schema also can be specified in a DDL-formatted string. +#' schema <- "a INT, c STRING, avg DOUBLE" +#' result <- gapply( +#' df, +#' c("a", "c"), +#' function(key, x) { +#' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) +#' }, schema) +#' #' We can also group the data and afterwards call gapply on GroupedData. #' For Example: #' gdf <- group_by(df, "a", "c") diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index f28d26a51baa0..86507f13f038d 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2174,8 +2174,9 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' #' @rdname column_collection_functions #' @param schema a structType object to use as the schema to use when parsing the JSON string. +#' Since Spark 2.3, the DDL-formatted string is also supported for the schema. #' @param as.json.array indicating if input string is JSON array of objects or a single object. -#' @aliases from_json from_json,Column,structType-method +#' @aliases from_json from_json,Column,characterOrstructType-method #' @export #' @examples #' @@ -2188,10 +2189,15 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' df2 <- sql("SELECT named_struct('name', 'Bob') as people") #' df2 <- mutate(df2, people_json = to_json(df2$people)) #' schema <- structType(structField("name", "string")) -#' head(select(df2, from_json(df2$people_json, schema)))} +#' head(select(df2, from_json(df2$people_json, schema))) +#' head(select(df2, from_json(df2$people_json, "name STRING")))} #' @note from_json since 2.2.0 -setMethod("from_json", signature(x = "Column", schema = "structType"), +setMethod("from_json", signature(x = "Column", schema = "characterOrstructType"), function(x, schema, as.json.array = FALSE, ...) { + if (is.character(schema)) { + schema <- structType(schema) + } + if (as.json.array) { jschema <- callJStatic("org.apache.spark.sql.types.DataTypes", "createArrayType", diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 17f5283abead1..0a7be0e993975 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -233,6 +233,9 @@ setMethod("gapplyCollect", }) gapplyInternal <- function(x, func, schema) { + if (is.character(schema)) { + schema <- structType(schema) + } packageNamesArr <- serialize(.sparkREnv[[".packages"]], connection = NULL) broadcastArr <- lapply(ls(.broadcastNames), diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index cb5bdb90175bf..d1ed6833d5d02 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -23,18 +23,24 @@ #' Create a structType object that contains the metadata for a SparkDataFrame. Intended for #' use with createDataFrame and toDF. #' -#' @param x a structField object (created with the field() function) +#' @param x a structField object (created with the \code{structField} method). Since Spark 2.3, +#' this can be a DDL-formatted string, which is a comma separated list of field +#' definitions, e.g., "a INT, b STRING". #' @param ... additional structField objects #' @return a structType object #' @rdname structType #' @export #' @examples #'\dontrun{ -#' schema <- structType(structField("a", "integer"), structField("c", "string"), +#' schema <- structType(structField("a", "integer"), structField("c", "string"), #' structField("avg", "double")) #' df1 <- gapply(df, list("a", "c"), #' function(key, x) { y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) }, #' schema) +#' schema <- structType("a INT, c STRING, avg DOUBLE") +#' df1 <- gapply(df, list("a", "c"), +#' function(key, x) { y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) }, +#' schema) #' } #' @note structType since 1.4.0 structType <- function(x, ...) { @@ -68,6 +74,23 @@ structType.structField <- function(x, ...) { structType(stObj) } +#' @rdname structType +#' @method structType character +#' @export +structType.character <- function(x, ...) { + if (!is.character(x)) { + stop("schema must be a DDL-formatted string.") + } + if (length(list(...)) > 0) { + stop("multiple DDL-formatted strings are not supported") + } + + stObj <- handledCallJStatic("org.apache.spark.sql.types.StructType", + "fromDDL", + x) + structType(stObj) +} + #' Print a Spark StructType. #' #' This function prints the contents of a StructType returned from the @@ -102,7 +125,7 @@ print.structType <- function(x, ...) { #' field1 <- structField("a", "integer") #' field2 <- structField("c", "string") #' field3 <- structField("avg", "double") -#' schema <- structType(field1, field2, field3) +#' schema <- structType(field1, field2, field3) #' df1 <- gapply(df, list("a", "c"), #' function(key, x) { y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) }, #' schema) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index a2bcb5aefe16d..77052d4a28345 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -146,6 +146,13 @@ test_that("structType and structField", { expect_is(testSchema, "structType") expect_is(testSchema$fields()[[2]], "structField") expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") + + testSchema <- structType("a STRING, b INT") + expect_is(testSchema, "structType") + expect_is(testSchema$fields()[[2]], "structField") + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") + + expect_error(structType("A stri"), "DataType stri is not supported.") }) test_that("structField type strings", { @@ -1480,13 +1487,15 @@ test_that("column functions", { j <- collect(select(df, alias(to_json(df$info), "json"))) expect_equal(j[order(j$json), ][1], "{\"age\":16,\"height\":176.5}") df <- as.DataFrame(j) - schema <- structType(structField("age", "integer"), - structField("height", "double")) - s <- collect(select(df, alias(from_json(df$json, schema), "structcol"))) - expect_equal(ncol(s), 1) - expect_equal(nrow(s), 3) - expect_is(s[[1]][[1]], "struct") - expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 } ))) + schemas <- list(structType(structField("age", "integer"), structField("height", "double")), + "age INT, height DOUBLE") + for (schema in schemas) { + s <- collect(select(df, alias(from_json(df$json, schema), "structcol"))) + expect_equal(ncol(s), 1) + expect_equal(nrow(s), 3) + expect_is(s[[1]][[1]], "struct") + expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 } ))) + } # passing option df <- as.DataFrame(list(list("col" = "{\"date\":\"21/10/2014\"}"))) @@ -1504,14 +1513,15 @@ test_that("column functions", { # check if array type in string is correctly supported. jsonArr <- "[{\"name\":\"Bob\"}, {\"name\":\"Alice\"}]" df <- as.DataFrame(list(list("people" = jsonArr))) - schema <- structType(structField("name", "string")) - arr <- collect(select(df, alias(from_json(df$people, schema, as.json.array = TRUE), "arrcol"))) - expect_equal(ncol(arr), 1) - expect_equal(nrow(arr), 1) - expect_is(arr[[1]][[1]], "list") - expect_equal(length(arr$arrcol[[1]]), 2) - expect_equal(arr$arrcol[[1]][[1]]$name, "Bob") - expect_equal(arr$arrcol[[1]][[2]]$name, "Alice") + for (schema in list(structType(structField("name", "string")), "name STRING")) { + arr <- collect(select(df, alias(from_json(df$people, schema, as.json.array = TRUE), "arrcol"))) + expect_equal(ncol(arr), 1) + expect_equal(nrow(arr), 1) + expect_is(arr[[1]][[1]], "list") + expect_equal(length(arr$arrcol[[1]]), 2) + expect_equal(arr$arrcol[[1]][[1]]$name, "Bob") + expect_equal(arr$arrcol[[1]][[2]]$name, "Alice") + } # Test create_array() and create_map() df <- as.DataFrame(data.frame( @@ -2885,30 +2895,33 @@ test_that("dapply() and dapplyCollect() on a DataFrame", { expect_identical(ldf, result) # Filter and add a column - schema <- structType(structField("a", "integer"), structField("b", "double"), - structField("c", "string"), structField("d", "integer")) - df1 <- dapply( - df, - function(x) { - y <- x[x$a > 1, ] - y <- cbind(y, y$a + 1L) - }, - schema) - result <- collect(df1) - expected <- ldf[ldf$a > 1, ] - expected$d <- expected$a + 1L - rownames(expected) <- NULL - expect_identical(expected, result) - - result <- dapplyCollect( - df, - function(x) { - y <- x[x$a > 1, ] - y <- cbind(y, y$a + 1L) - }) - expected1 <- expected - names(expected1) <- names(result) - expect_identical(expected1, result) + schemas <- list(structType(structField("a", "integer"), structField("b", "double"), + structField("c", "string"), structField("d", "integer")), + "a INT, b DOUBLE, c STRING, d INT") + for (schema in schemas) { + df1 <- dapply( + df, + function(x) { + y <- x[x$a > 1, ] + y <- cbind(y, y$a + 1L) + }, + schema) + result <- collect(df1) + expected <- ldf[ldf$a > 1, ] + expected$d <- expected$a + 1L + rownames(expected) <- NULL + expect_identical(expected, result) + + result <- dapplyCollect( + df, + function(x) { + y <- x[x$a > 1, ] + y <- cbind(y, y$a + 1L) + }) + expected1 <- expected + names(expected1) <- names(result) + expect_identical(expected1, result) + } # Remove the added column df2 <- dapply( @@ -3020,29 +3033,32 @@ test_that("gapply() and gapplyCollect() on a DataFrame", { # Computes the sum of second column by grouping on the first and third columns # and checks if the sum is larger than 2 - schema <- structType(structField("a", "integer"), structField("e", "boolean")) - df2 <- gapply( - df, - c(df$"a", df$"c"), - function(key, x) { - y <- data.frame(key[1], sum(x$b) > 2) - }, - schema) - actual <- collect(df2)$e - expected <- c(TRUE, TRUE) - expect_identical(actual, expected) - - df2Collect <- gapplyCollect( - df, - c(df$"a", df$"c"), - function(key, x) { - y <- data.frame(key[1], sum(x$b) > 2) - colnames(y) <- c("a", "e") - y - }) - actual <- df2Collect$e + schemas <- list(structType(structField("a", "integer"), structField("e", "boolean")), + "a INT, e BOOLEAN") + for (schema in schemas) { + df2 <- gapply( + df, + c(df$"a", df$"c"), + function(key, x) { + y <- data.frame(key[1], sum(x$b) > 2) + }, + schema) + actual <- collect(df2)$e + expected <- c(TRUE, TRUE) expect_identical(actual, expected) + df2Collect <- gapplyCollect( + df, + c(df$"a", df$"c"), + function(key, x) { + y <- data.frame(key[1], sum(x$b) > 2) + colnames(y) <- c("a", "e") + y + }) + actual <- df2Collect$e + expect_identical(actual, expected) + } + # Computes the arithmetic mean of the second column by grouping # on the first and third columns. Output the groupping value and the average. schema <- structType(structField("a", "integer"), structField("c", "string"), diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5d8ded83f667d..f3e7d033e97cf 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1883,15 +1883,20 @@ def from_json(col, schema, options={}): string. :param col: string column in json format - :param schema: a StructType or ArrayType of StructType to use when parsing the json column + :param schema: a StructType or ArrayType of StructType to use when parsing the json column. :param options: options to control parsing. accepts the same options as the json datasource + .. note:: Since Spark 2.3, the DDL-formatted string or a JSON format string is also + supported for ``schema``. + >>> from pyspark.sql.types import * >>> data = [(1, '''{"a": 1}''')] >>> schema = StructType([StructField("a", IntegerType())]) >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(from_json(df.value, schema).alias("json")).collect() [Row(json=Row(a=1))] + >>> df.select(from_json(df.value, "a INT").alias("json")).collect() + [Row(json=Row(a=1))] >>> data = [(1, '''[{"a": 1}]''')] >>> schema = ArrayType(StructType([StructField("a", IntegerType())])) >>> df = spark.createDataFrame(data, ("key", "value")) @@ -1900,7 +1905,9 @@ def from_json(col, schema, options={}): """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.from_json(_to_java_column(col), schema.json(), options) + if isinstance(schema, DataType): + schema = schema.json() + jc = sc._jvm.functions.from_json(_to_java_column(col), schema, options) return Column(jc) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0c7b483f5c836..ebdeb42b0bfb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2114,7 +2114,7 @@ object functions { * Calculates the hash code of given columns, and returns the result as an int column. * * @group misc_funcs - * @since 2.0 + * @since 2.0.0 */ @scala.annotation.varargs def hash(cols: Column*): Column = withExpr { @@ -3074,9 +3074,8 @@ object functions { * string. * * @param e a string column containing JSON data. - * @param schema the schema to use when parsing the json string as a json string. In Spark 2.1, - * the user-provided schema has to be in JSON format. Since Spark 2.2, the DDL - * format is also supported for the schema. + * @param schema the schema to use when parsing the json string as a json string, it could be a + * JSON format string or a DDL-formatted string. * * @group collection_funcs * @since 2.3.0 From d03aebbe6508ba441dc87f9546f27aeb27553d77 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 10 Jul 2017 15:21:03 -0700 Subject: [PATCH 147/779] [SPARK-13534][PYSPARK] Using Apache Arrow to increase performance of DataFrame.toPandas ## What changes were proposed in this pull request? Integrate Apache Arrow with Spark to increase performance of `DataFrame.toPandas`. This has been done by using Arrow to convert data partitions on the executor JVM to Arrow payload byte arrays where they are then served to the Python process. The Python DataFrame can then collect the Arrow payloads where they are combined and converted to a Pandas DataFrame. Data types except complex, date, timestamp, and decimal are currently supported, otherwise an `UnsupportedOperation` exception is thrown. Additions to Spark include a Scala package private method `Dataset.toArrowPayload` that will convert data partitions in the executor JVM to `ArrowPayload`s as byte arrays so they can be easily served. A package private class/object `ArrowConverters` that provide data type mappings and conversion routines. In Python, a private method `DataFrame._collectAsArrow` is added to collect Arrow payloads and a SQLConf "spark.sql.execution.arrow.enable" can be used in `toPandas()` to enable using Arrow (uses the old conversion by default). ## How was this patch tested? Added a new test suite `ArrowConvertersSuite` that will run tests on conversion of Datasets to Arrow payloads for supported types. The suite will generate a Dataset and matching Arrow JSON data, then the dataset is converted to an Arrow payload and finally validated against the JSON data. This will ensure that the schema and data has been converted correctly. Added PySpark tests to verify the `toPandas` method is producing equal DataFrames with and without pyarrow. A roundtrip test to ensure the pandas DataFrame produced by pyspark is equal to a one made directly with pandas. Author: Bryan Cutler Author: Li Jin Author: Li Jin Author: Wes McKinney Closes #18459 from BryanCutler/toPandas_with_arrow-SPARK-13534. --- bin/pyspark | 2 +- dev/deps/spark-deps-hadoop-2.6 | 5 + dev/deps/spark-deps-hadoop-2.7 | 5 + pom.xml | 20 + python/pyspark/serializers.py | 17 + python/pyspark/sql/dataframe.py | 48 +- python/pyspark/sql/tests.py | 78 +- .../apache/spark/sql/internal/SQLConf.scala | 22 + sql/core/pom.xml | 4 + .../scala/org/apache/spark/sql/Dataset.scala | 20 + .../sql/execution/arrow/ArrowConverters.scala | 429 ++++++ .../arrow/ArrowConvertersSuite.scala | 1222 +++++++++++++++++ 12 files changed, 1859 insertions(+), 13 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala diff --git a/bin/pyspark b/bin/pyspark index d3b512eeb1209..dd286277c1fc1 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - exec "$PYSPARK_DRIVER_PYTHON" -m "$1" + exec "$PYSPARK_DRIVER_PYTHON" -m "$@" exit fi diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index c1325318d52fa..1a6515be51cff 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -13,6 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar +arrow-format-0.4.0.jar +arrow-memory-0.4.0.jar +arrow-vector-0.4.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -55,6 +58,7 @@ datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar +flatbuffers-1.2.0-3f79e055.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -77,6 +81,7 @@ hadoop-yarn-server-web-proxy-2.6.5.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar +hppc-0.7.1.jar htrace-core-3.0.4.jar httpclient-4.5.2.jar httpcore-4.4.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index ac5abd21807b6..09e5a4288ca50 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -13,6 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar +arrow-format-0.4.0.jar +arrow-memory-0.4.0.jar +arrow-vector-0.4.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -55,6 +58,7 @@ datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar +flatbuffers-1.2.0-3f79e055.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -77,6 +81,7 @@ hadoop-yarn-server-web-proxy-2.7.3.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar +hppc-0.7.1.jar htrace-core-3.1.0-incubating.jar httpclient-4.5.2.jar httpcore-4.4.4.jar diff --git a/pom.xml b/pom.xml index 5f524079495c0..f124ba45007b7 100644 --- a/pom.xml +++ b/pom.xml @@ -181,6 +181,7 @@ 2.6 1.8 1.0.0 + 0.4.0 ${java.home} @@ -1878,6 +1879,25 @@ paranamer ${paranamer.version} + + org.apache.arrow + arrow-vector + ${arrow.version} + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-databind + + + io.netty + netty-handler + + + diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ea5e00e9eeef5..d5c2a7518b18f 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -182,6 +182,23 @@ def loads(self, obj): raise NotImplementedError +class ArrowSerializer(FramedSerializer): + """ + Serializes an Arrow stream. + """ + + def dumps(self, obj): + raise NotImplementedError + + def loads(self, obj): + import pyarrow as pa + reader = pa.RecordBatchFileReader(pa.BufferReader(obj)) + return reader.read_all() + + def __repr__(self): + return "ArrowSerializer" + + class BatchedSerializer(Serializer): """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 27a6dad8917d3..944739bcd2078 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -29,7 +29,8 @@ from pyspark import copy_func, since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer +from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \ + UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import _parse_datatype_json_string @@ -1710,7 +1711,8 @@ def toDF(self, *cols): @since(1.3) def toPandas(self): - """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. + """ + Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. This is only available if Pandas is installed and available. @@ -1723,18 +1725,42 @@ def toPandas(self): 1 5 Bob """ import pandas as pd + if self.sql_ctx.getConf("spark.sql.execution.arrow.enable", "false").lower() == "true": + try: + import pyarrow + tables = self._collectAsArrow() + if tables: + table = pyarrow.concat_tables(tables) + return table.to_pandas() + else: + return pd.DataFrame.from_records([], columns=self.columns) + except ImportError as e: + msg = "note: pyarrow must be installed and available on calling Python process " \ + "if using spark.sql.execution.arrow.enable=true" + raise ImportError("%s\n%s" % (e.message, msg)) + else: + dtype = {} + for field in self.schema: + pandas_type = _to_corrected_pandas_type(field.dataType) + if pandas_type is not None: + dtype[field.name] = pandas_type - dtype = {} - for field in self.schema: - pandas_type = _to_corrected_pandas_type(field.dataType) - if pandas_type is not None: - dtype[field.name] = pandas_type + pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) - pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) + for f, t in dtype.items(): + pdf[f] = pdf[f].astype(t, copy=False) + return pdf - for f, t in dtype.items(): - pdf[f] = pdf[f].astype(t, copy=False) - return pdf + def _collectAsArrow(self): + """ + Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed + and available. + + .. note:: Experimental. + """ + with SCCallSiteSync(self._sc) as css: + port = self._jdf.collectAsArrowToPython() + return list(_load_from_socket(port, ArrowSerializer())) ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9db2f40474f70..bd8477e35f37a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -58,12 +58,21 @@ from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier -from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests +from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException +_have_arrow = False +try: + import pyarrow + _have_arrow = True +except: + # No Arrow, but that's okay, we'll skip those tests + pass + + class UTCOffsetTimezone(datetime.tzinfo): """ Specifies timezone in UTC offset @@ -2843,6 +2852,73 @@ def __init__(self, **kwargs): _make_type_verifier(data_type, nullable=False)(obj) +@unittest.skipIf(not _have_arrow, "Arrow not installed") +class ArrowTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + cls.spark.conf.set("spark.sql.execution.arrow.enable", "true") + cls.schema = StructType([ + StructField("1_str_t", StringType(), True), + StructField("2_int_t", IntegerType(), True), + StructField("3_long_t", LongType(), True), + StructField("4_float_t", FloatType(), True), + StructField("5_double_t", DoubleType(), True)]) + cls.data = [("a", 1, 10, 0.2, 2.0), + ("b", 2, 20, 0.4, 4.0), + ("c", 3, 30, 0.8, 6.0)] + + def assertFramesEqual(self, df_with_arrow, df_without): + msg = ("DataFrame from Arrow is not equal" + + ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) + + ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) + self.assertTrue(df_without.equals(df_with_arrow), msg=msg) + + def test_unsupported_datatype(self): + schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)]) + df = self.spark.createDataFrame([([1, 2, 3],)], schema=schema) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: df.toPandas()) + + def test_null_conversion(self): + df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + + self.data) + pdf = df_null.toPandas() + null_counts = pdf.isnull().sum().tolist() + self.assertTrue(all([c == 1 for c in null_counts])) + + def test_toPandas_arrow_toggle(self): + df = self.spark.createDataFrame(self.data, schema=self.schema) + self.spark.conf.set("spark.sql.execution.arrow.enable", "false") + pdf = df.toPandas() + self.spark.conf.set("spark.sql.execution.arrow.enable", "true") + pdf_arrow = df.toPandas() + self.assertFramesEqual(pdf_arrow, pdf) + + def test_pandas_round_trip(self): + import pandas as pd + import numpy as np + data_dict = {} + for j, name in enumerate(self.schema.names): + data_dict[name] = [self.data[i][j] for i in range(len(self.data))] + # need to convert these to numpy types first + data_dict["2_int_t"] = np.int32(data_dict["2_int_t"]) + data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) + pdf = pd.DataFrame(data=data_dict) + df = self.spark.createDataFrame(self.data, schema=self.schema) + pdf_arrow = df.toPandas() + self.assertFramesEqual(pdf_arrow, pdf) + + def test_filtered_frame(self): + df = self.spark.range(3).toDF("i") + pdf = df.filter("i < 0").toPandas() + self.assertEqual(len(pdf.columns), 1) + self.assertEqual(pdf.columns[0], "i") + self.assertTrue(pdf.empty) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 25152f3e32d6b..643587a6eb09d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -855,6 +855,24 @@ object SQLConf { .intConf .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + val ARROW_EXECUTION_ENABLE = + buildConf("spark.sql.execution.arrow.enable") + .internal() + .doc("Make use of Apache Arrow for columnar data transfers. Currently available " + + "for use with pyspark.sql.DataFrame.toPandas with the following data types: " + + "StringType, BinaryType, BooleanType, DoubleType, FloatType, ByteType, IntegerType, " + + "LongType, ShortType") + .booleanConf + .createWithDefault(false) + + val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH = + buildConf("spark.sql.execution.arrow.maxRecordsPerBatch") + .internal() + .doc("When using Apache Arrow, limit the maximum number of records that can be written " + + "to a single ArrowRecordBatch in memory. If set to zero or negative there is no limit.") + .intConf + .createWithDefault(10000) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1115,6 +1133,10 @@ class SQLConf extends Serializable with Logging { def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO) + def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE) + + def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 1bc34a6b069d9..661c31ded7148 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -103,6 +103,10 @@ jackson-databind ${fasterxml.jackson.version} + + org.apache.arrow + arrow-vector + org.apache.xbean xbean-asm5-shaded diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index dfb51192c69bc..a7773831df075 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -46,6 +46,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython @@ -2907,6 +2908,16 @@ class Dataset[T] private[sql]( } } + /** + * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + */ + private[sql] def collectAsArrowToPython(): Int = { + withNewExecutionId { + val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable) + PythonRDD.serveIterator(iter, "serve-Arrow") + } + } + private[sql] def toPythonIterator(): Int = { withNewExecutionId { PythonRDD.toLocalIteratorAndServe(javaToPython.rdd) @@ -2988,4 +2999,13 @@ class Dataset[T] private[sql]( Dataset(sparkSession, logicalPlan) } } + + /** Convert to an RDD of ArrowPayload byte arrays */ + private[sql] def toArrowPayload: RDD[ArrowPayload] = { + val schemaCaptured = this.schema + val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch + queryExecution.toRdd.mapPartitionsInternal { iter => + ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala new file mode 100644 index 0000000000000..6af5c73422377 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -0,0 +1,429 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.arrow + +import java.io.ByteArrayOutputStream +import java.nio.channels.Channels + +import scala.collection.JavaConverters._ + +import io.netty.buffer.ArrowBuf +import org.apache.arrow.memory.{BufferAllocator, RootAllocator} +import org.apache.arrow.vector._ +import org.apache.arrow.vector.BaseValueVector.BaseMutator +import org.apache.arrow.vector.file._ +import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} +import org.apache.arrow.vector.types.FloatingPointPrecision +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + + +/** + * Store Arrow data in a form that can be serialized by Spark and served to a Python process. + */ +private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable { + + /** + * Convert the ArrowPayload to an ArrowRecordBatch. + */ + def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = { + ArrowConverters.byteArrayToBatch(payload, allocator) + } + + /** + * Get the ArrowPayload as a type that can be served to Python. + */ + def asPythonSerializable: Array[Byte] = payload +} + +private[sql] object ArrowPayload { + + /** + * Create an ArrowPayload from an ArrowRecordBatch and Spark schema. + */ + def apply( + batch: ArrowRecordBatch, + schema: StructType, + allocator: BufferAllocator): ArrowPayload = { + new ArrowPayload(ArrowConverters.batchToByteArray(batch, schema, allocator)) + } +} + +private[sql] object ArrowConverters { + + /** + * Map a Spark DataType to ArrowType. + */ + private[arrow] def sparkTypeToArrowType(dataType: DataType): ArrowType = { + dataType match { + case BooleanType => ArrowType.Bool.INSTANCE + case ShortType => new ArrowType.Int(8 * ShortType.defaultSize, true) + case IntegerType => new ArrowType.Int(8 * IntegerType.defaultSize, true) + case LongType => new ArrowType.Int(8 * LongType.defaultSize, true) + case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case ByteType => new ArrowType.Int(8, true) + case StringType => ArrowType.Utf8.INSTANCE + case BinaryType => ArrowType.Binary.INSTANCE + case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") + } + } + + /** + * Convert a Spark Dataset schema to Arrow schema. + */ + private[arrow] def schemaToArrowSchema(schema: StructType): Schema = { + val arrowFields = schema.fields.map { f => + new Field(f.name, f.nullable, sparkTypeToArrowType(f.dataType), List.empty[Field].asJava) + } + new Schema(arrowFields.toList.asJava) + } + + /** + * Maps Iterator from InternalRow to ArrowPayload. Limit ArrowRecordBatch size in ArrowPayload + * by setting maxRecordsPerBatch or use 0 to fully consume rowIter. + */ + private[sql] def toPayloadIterator( + rowIter: Iterator[InternalRow], + schema: StructType, + maxRecordsPerBatch: Int): Iterator[ArrowPayload] = { + new Iterator[ArrowPayload] { + private val _allocator = new RootAllocator(Long.MaxValue) + private var _nextPayload = if (rowIter.nonEmpty) convert() else null + + override def hasNext: Boolean = _nextPayload != null + + override def next(): ArrowPayload = { + val obj = _nextPayload + if (hasNext) { + if (rowIter.hasNext) { + _nextPayload = convert() + } else { + _allocator.close() + _nextPayload = null + } + } + obj + } + + private def convert(): ArrowPayload = { + val batch = internalRowIterToArrowBatch(rowIter, schema, _allocator, maxRecordsPerBatch) + ArrowPayload(batch, schema, _allocator) + } + } + } + + /** + * Iterate over InternalRows and write to an ArrowRecordBatch, stopping when rowIter is consumed + * or the number of records in the batch equals maxRecordsInBatch. If maxRecordsPerBatch is 0, + * then rowIter will be fully consumed. + */ + private def internalRowIterToArrowBatch( + rowIter: Iterator[InternalRow], + schema: StructType, + allocator: BufferAllocator, + maxRecordsPerBatch: Int = 0): ArrowRecordBatch = { + + val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) => + ColumnWriter(field.dataType, ordinal, allocator).init() + } + + val writerLength = columnWriters.length + var recordsInBatch = 0 + while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || recordsInBatch < maxRecordsPerBatch)) { + val row = rowIter.next() + var i = 0 + while (i < writerLength) { + columnWriters(i).write(row) + i += 1 + } + recordsInBatch += 1 + } + + val (fieldNodes, bufferArrays) = columnWriters.map(_.finish()).unzip + val buffers = bufferArrays.flatten + + val rowLength = if (fieldNodes.nonEmpty) fieldNodes.head.getLength else 0 + val recordBatch = new ArrowRecordBatch(rowLength, + fieldNodes.toList.asJava, buffers.toList.asJava) + + buffers.foreach(_.release()) + recordBatch + } + + /** + * Convert an ArrowRecordBatch to a byte array and close batch to release resources. Once closed, + * the batch can no longer be used. + */ + private[arrow] def batchToByteArray( + batch: ArrowRecordBatch, + schema: StructType, + allocator: BufferAllocator): Array[Byte] = { + val arrowSchema = ArrowConverters.schemaToArrowSchema(schema) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val out = new ByteArrayOutputStream() + val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) + + // Write a batch to byte stream, ensure the batch, allocator and writer are closed + Utils.tryWithSafeFinally { + val loader = new VectorLoader(root) + loader.load(batch) + writer.writeBatch() // writeBatch can throw IOException + } { + batch.close() + root.close() + writer.close() + } + out.toByteArray + } + + /** + * Convert a byte array to an ArrowRecordBatch. + */ + private[arrow] def byteArrayToBatch( + batchBytes: Array[Byte], + allocator: BufferAllocator): ArrowRecordBatch = { + val in = new ByteArrayReadableSeekableByteChannel(batchBytes) + val reader = new ArrowFileReader(in, allocator) + + // Read a batch from a byte stream, ensure the reader is closed + Utils.tryWithSafeFinally { + val root = reader.getVectorSchemaRoot // throws IOException + val unloader = new VectorUnloader(root) + reader.loadNextBatch() // throws IOException + unloader.getRecordBatch + } { + reader.close() + } + } +} + +/** + * Interface for writing InternalRows to Arrow Buffers. + */ +private[arrow] trait ColumnWriter { + def init(): this.type + def write(row: InternalRow): Unit + + /** + * Clear the column writer and return the ArrowFieldNode and ArrowBuf. + * This should be called only once after all the data is written. + */ + def finish(): (ArrowFieldNode, Array[ArrowBuf]) +} + +/** + * Base class for flat arrow column writer, i.e., column without children. + */ +private[arrow] abstract class PrimitiveColumnWriter(val ordinal: Int) + extends ColumnWriter { + + def getFieldType(dtype: ArrowType): FieldType = FieldType.nullable(dtype) + + def valueVector: BaseDataValueVector + def valueMutator: BaseMutator + + def setNull(): Unit + def setValue(row: InternalRow): Unit + + protected var count = 0 + protected var nullCount = 0 + + override def init(): this.type = { + valueVector.allocateNew() + this + } + + override def write(row: InternalRow): Unit = { + if (row.isNullAt(ordinal)) { + setNull() + nullCount += 1 + } else { + setValue(row) + } + count += 1 + } + + override def finish(): (ArrowFieldNode, Array[ArrowBuf]) = { + valueMutator.setValueCount(count) + val fieldNode = new ArrowFieldNode(count, nullCount) + val valueBuffers = valueVector.getBuffers(true) + (fieldNode, valueBuffers) + } +} + +private[arrow] class BooleanColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableBitVector + = new NullableBitVector("BooleanValue", getFieldType(dtype), allocator) + override val valueMutator: NullableBitVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, if (row.getBoolean(ordinal)) 1 else 0 ) +} + +private[arrow] class ShortColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableSmallIntVector + = new NullableSmallIntVector("ShortValue", getFieldType(dtype: ArrowType), allocator) + override val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getShort(ordinal)) +} + +private[arrow] class IntegerColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableIntVector + = new NullableIntVector("IntValue", getFieldType(dtype), allocator) + override val valueMutator: NullableIntVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getInt(ordinal)) +} + +private[arrow] class LongColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableBigIntVector + = new NullableBigIntVector("LongValue", getFieldType(dtype), allocator) + override val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getLong(ordinal)) +} + +private[arrow] class FloatColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableFloat4Vector + = new NullableFloat4Vector("FloatValue", getFieldType(dtype), allocator) + override val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getFloat(ordinal)) +} + +private[arrow] class DoubleColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableFloat8Vector + = new NullableFloat8Vector("DoubleValue", getFieldType(dtype), allocator) + override val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getDouble(ordinal)) +} + +private[arrow] class ByteColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableUInt1Vector + = new NullableUInt1Vector("ByteValue", getFieldType(dtype), allocator) + override val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getByte(ordinal)) +} + +private[arrow] class UTF8StringColumnWriter( + dtype: ArrowType, + ordinal: Int, + allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableVarCharVector + = new NullableVarCharVector("UTF8StringValue", getFieldType(dtype), allocator) + override val valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit = { + val str = row.getUTF8String(ordinal) + valueMutator.setSafe(count, str.getByteBuffer, 0, str.numBytes) + } +} + +private[arrow] class BinaryColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableVarBinaryVector + = new NullableVarBinaryVector("BinaryValue", getFieldType(dtype), allocator) + override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit = { + val bytes = row.getBinary(ordinal) + valueMutator.setSafe(count, bytes, 0, bytes.length) + } +} + +private[arrow] class DateColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableDateDayVector + = new NullableDateDayVector("DateValue", getFieldType(dtype), allocator) + override val valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit = { + valueMutator.setSafe(count, row.getInt(ordinal)) + } +} + +private[arrow] class TimeStampColumnWriter( + dtype: ArrowType, + ordinal: Int, + allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableTimeStampMicroVector + = new NullableTimeStampMicroVector("TimeStampValue", getFieldType(dtype), allocator) + override val valueMutator: NullableTimeStampMicroVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit = { + valueMutator.setSafe(count, row.getLong(ordinal)) + } +} + +private[arrow] object ColumnWriter { + + /** + * Create an Arrow ColumnWriter given the type and ordinal of row. + */ + def apply(dataType: DataType, ordinal: Int, allocator: BufferAllocator): ColumnWriter = { + val dtype = ArrowConverters.sparkTypeToArrowType(dataType) + dataType match { + case BooleanType => new BooleanColumnWriter(dtype, ordinal, allocator) + case ShortType => new ShortColumnWriter(dtype, ordinal, allocator) + case IntegerType => new IntegerColumnWriter(dtype, ordinal, allocator) + case LongType => new LongColumnWriter(dtype, ordinal, allocator) + case FloatType => new FloatColumnWriter(dtype, ordinal, allocator) + case DoubleType => new DoubleColumnWriter(dtype, ordinal, allocator) + case ByteType => new ByteColumnWriter(dtype, ordinal, allocator) + case StringType => new UTF8StringColumnWriter(dtype, ordinal, allocator) + case BinaryType => new BinaryColumnWriter(dtype, ordinal, allocator) + case DateType => new DateColumnWriter(dtype, ordinal, allocator) + case TimestampType => new TimeStampColumnWriter(dtype, ordinal, allocator) + case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala new file mode 100644 index 0000000000000..159328cc0d958 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -0,0 +1,1222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.arrow + +import java.io.File +import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} +import java.text.SimpleDateFormat +import java.util.Locale + +import com.google.common.io.Files +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} +import org.apache.arrow.vector.file.json.JsonFileReader +import org.apache.arrow.vector.util.Validator +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkException +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{BinaryType, StructField, StructType} +import org.apache.spark.util.Utils + + +class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { + import testImplicits._ + + private var tempDataPath: String = _ + + override def beforeAll(): Unit = { + super.beforeAll() + tempDataPath = Utils.createTempDir(namePrefix = "arrow").getAbsolutePath + } + + test("collect to arrow record batch") { + val indexData = (1 to 6).toDF("i") + val arrowPayloads = indexData.toArrowPayload.collect() + assert(arrowPayloads.nonEmpty) + assert(arrowPayloads.length == indexData.rdd.getNumPartitions) + val allocator = new RootAllocator(Long.MaxValue) + val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val rowCount = arrowRecordBatches.map(_.getLength).sum + assert(rowCount === indexData.count()) + arrowRecordBatches.foreach(batch => assert(batch.getNodes.size() > 0)) + arrowRecordBatches.foreach(_.close()) + allocator.close() + } + + test("short conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_s", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 16 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 16 + | } ] + | } + | }, { + | "name" : "b_s", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 16 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 16 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_s", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 32767, -32768 ] + | }, { + | "name" : "b_s", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -32768 ] + | } ] + | } ] + |} + """.stripMargin + + val a_s = List[Short](1, -1, 2, -2, 32767, -32768) + val b_s = List[Option[Short]](Some(1), None, None, Some(-2), None, Some(-32768)) + val df = a_s.zip(b_s).toDF("a_s", "b_s") + + collectAndValidate(df, json, "integer-16bit.json") + } + + test("int conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] + | }, { + | "name" : "b_i", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] + | } ] + | } ] + |} + """.stripMargin + + val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) + val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) + val df = a_i.zip(b_i).toDF("a_i", "b_i") + + collectAndValidate(df, json, "integer-32bit.json") + } + + test("long conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_l", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 64 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "b_l", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 64 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_l", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 9223372036854775807, -9223372036854775808 ] + | }, { + | "name" : "b_l", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -9223372036854775808 ] + | } ] + | } ] + |} + """.stripMargin + + val a_l = List[Long](1, -1, 2, -2, 9223372036854775807L, -9223372036854775808L) + val b_l = List[Option[Long]](Some(1), None, None, Some(-2), None, Some(-9223372036854775808L)) + val df = a_l.zip(b_l).toDF("a_l", "b_l") + + collectAndValidate(df, json, "integer-64bit.json") + } + + test("float conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_f", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_f", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_f", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0 ] + | }, { + | "name" : "b_f", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ] + | } ] + | } ] + |} + """.stripMargin + + val a_f = List(1.0f, 2.0f, 0.01f, 200.0f, 0.0001f, 20000.0f) + val b_f = List[Option[Float]](Some(1.1f), None, None, Some(2.2f), None, Some(3.3f)) + val df = a_f.zip(b_f).toDF("a_f", "b_f") + + collectAndValidate(df, json, "floating_point-single_precision.json") + } + + test("double conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "b_d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_d", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 0.01, 200.0, 1.0E-4, 20000.0 ] + | }, { + | "name" : "b_d", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ] + | } ] + | } ] + |} + """.stripMargin + + val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0) + val b_d = List[Option[Double]](Some(1.1), None, None, Some(2.2), None, Some(3.3)) + val df = a_d.zip(b_d).toDF("a_d", "b_d") + + collectAndValidate(df, json, "floating_point-double_precision.json") + } + + test("index conversion") { + val data = List[Int](1, 2, 3, 4, 5, 6) + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | } ] + | } ] + |} + """.stripMargin + val df = data.toDF("i") + + collectAndValidate(df, json, "indexData-ints.json") + } + + test("mixed numeric type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 16 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 16 + | } ] + | } + | }, { + | "name" : "b", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "c", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "e", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 64 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | }, { + | "name" : "b", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ] + | }, { + | "name" : "c", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | }, { + | "name" : "d", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ] + | }, { + | "name" : "e", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | } ] + | } ] + |} + """.stripMargin + + val data = List(1, 2, 3, 4, 5, 6) + val data_tuples = for (d <- data) yield { + (d.toShort, d.toFloat, d.toInt, d.toDouble, d.toLong) + } + val df = data_tuples.toDF("a", "b", "c", "d", "e") + + collectAndValidate(df, json, "mixed_numeric_types.json") + } + + test("string type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "upper_case", + | "type" : { + | "name" : "utf8" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | }, { + | "name" : "lower_case", + | "type" : { + | "name" : "utf8" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | }, { + | "name" : "null_str", + | "type" : { + | "name" : "utf8" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "upper_case", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "OFFSET" : [ 0, 1, 2, 3 ], + | "DATA" : [ "A", "B", "C" ] + | }, { + | "name" : "lower_case", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "OFFSET" : [ 0, 1, 2, 3 ], + | "DATA" : [ "a", "b", "c" ] + | }, { + | "name" : "null_str", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 0 ], + | "OFFSET" : [ 0, 2, 5, 5 ], + | "DATA" : [ "ab", "CDE", "" ] + | } ] + | } ] + |} + """.stripMargin + + val upperCase = Seq("A", "B", "C") + val lowerCase = Seq("a", "b", "c") + val nullStr = Seq("ab", "CDE", null) + val df = (upperCase, lowerCase, nullStr).zipped.toList + .toDF("upper_case", "lower_case", "null_str") + + collectAndValidate(df, json, "stringData.json") + } + + test("boolean type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_bool", + | "type" : { + | "name" : "bool" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 1 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "a_bool", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ true, true, false, true ] + | } ] + | } ] + |} + """.stripMargin + val df = Seq(true, true, false, true).toDF("a_bool") + collectAndValidate(df, json, "boolData.json") + } + + test("byte type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_byte", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 8 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "a_byte", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 64, 127 ] + | } ] + | } ] + |} + | + """.stripMargin + val df = List[Byte](1.toByte, (-1).toByte, 64.toByte, Byte.MaxValue).toDF("a_byte") + collectAndValidate(df, json, "byteData.json") + } + + test("binary type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_binary", + | "type" : { + | "name" : "binary" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a_binary", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "OFFSET" : [ 0, 3, 4, 6 ], + | "DATA" : [ "616263", "64", "6566" ] + | } ] + | } ] + |} + """.stripMargin + + val data = Seq("abc", "d", "ef") + val rdd = sparkContext.parallelize(data.map(s => Row(s.getBytes("utf-8")))) + val df = spark.createDataFrame(rdd, StructType(Seq(StructField("a_binary", BinaryType)))) + + collectAndValidate(df, json, "binaryData.json") + } + + test("floating-point NaN") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "NaN_f", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "NaN_d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 2, + | "columns" : [ { + | "name" : "NaN_f", + | "count" : 2, + | "VALIDITY" : [ 1, 1 ], + | "DATA" : [ 1.2000000476837158, "NaN" ] + | }, { + | "name" : "NaN_d", + | "count" : 2, + | "VALIDITY" : [ 1, 1 ], + | "DATA" : [ "NaN", 1.2 ] + | } ] + | } ] + |} + """.stripMargin + + val fnan = Seq(1.2F, Float.NaN) + val dnan = Seq(Double.NaN, 1.2) + val df = fnan.zip(dnan).toDF("NaN_f", "NaN_d") + + collectAndValidate(df, json, "nanData-floating_point.json") + } + + test("partitioned DataFrame") { + val json1 = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 1, 1, 2 ] + | }, { + | "name" : "b", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 1, 2, 1 ] + | } ] + | } ] + |} + """.stripMargin + val json2 = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 2, 3, 3 ] + | }, { + | "name" : "b", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 2, 1, 2 ] + | } ] + | } ] + |} + """.stripMargin + + val arrowPayloads = testData2.toArrowPayload.collect() + // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload + assert(arrowPayloads.length === 2) + val schema = testData2.schema + + val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json") + val tempFile2 = new File(tempDataPath, "testData2-ints-part2.json") + Files.write(json1, tempFile1, StandardCharsets.UTF_8) + Files.write(json2, tempFile2, StandardCharsets.UTF_8) + + validateConversion(schema, arrowPayloads(0), tempFile1) + validateConversion(schema, arrowPayloads(1), tempFile2) + } + + test("empty frame collect") { + val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect() + assert(arrowPayload.isEmpty) + + val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i") + val filteredArrowPayload = filteredDF.filter("i < 0").toArrowPayload.collect() + assert(filteredArrowPayload.isEmpty) + } + + test("empty partition collect") { + val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i") + val arrowPayloads = emptyPart.toArrowPayload.collect() + assert(arrowPayloads.length === 1) + val allocator = new RootAllocator(Long.MaxValue) + val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + assert(arrowRecordBatches.head.getLength == 1) + arrowRecordBatches.foreach(_.close()) + allocator.close() + } + + test("max records in batch conf") { + val totalRecords = 10 + val maxRecordsPerBatch = 3 + spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) + val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") + val arrowPayloads = df.toArrowPayload.collect() + val allocator = new RootAllocator(Long.MaxValue) + val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + var recordCount = 0 + arrowRecordBatches.foreach { batch => + assert(batch.getLength > 0) + assert(batch.getLength <= maxRecordsPerBatch) + recordCount += batch.getLength + batch.close() + } + assert(recordCount == totalRecords) + allocator.close() + spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") + } + + testQuietly("unsupported types") { + def runUnsupported(block: => Unit): Unit = { + val msg = intercept[SparkException] { + block + } + assert(msg.getMessage.contains("Unsupported data type")) + assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) + } + + runUnsupported { decimalData.toArrowPayload.collect() } + runUnsupported { arrayData.toDF().toArrowPayload.collect() } + runUnsupported { mapData.toDF().toArrowPayload.collect() } + runUnsupported { complexData.toArrowPayload.collect() } + + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) + val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) + val d2 = new Date(sdf.parse("2016-05-09 13:10:15.000 UTC").getTime) + runUnsupported { Seq(d1, d2).toDF("date").toArrowPayload.collect() } + + val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) + val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) + runUnsupported { Seq(ts1, ts2).toDF("timestamp").toArrowPayload.collect() } + } + + test("test Arrow Validator") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] + | }, { + | "name" : "b_i", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] + | } ] + | } ] + |} + """.stripMargin + val json_diff_col_order = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "b_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "a_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] + | }, { + | "name" : "b_i", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] + | } ] + | } ] + |} + """.stripMargin + + val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) + val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) + val df = a_i.zip(b_i).toDF("a_i", "b_i") + + // Different schema + intercept[IllegalArgumentException] { + collectAndValidate(df, json_diff_col_order, "validator_diff_schema.json") + } + + // Different values + intercept[IllegalArgumentException] { + collectAndValidate(df.sort($"a_i".desc), json, "validator_diff_values.json") + } + } + + /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ + private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = { + // NOTE: coalesce to single partition because can only load 1 batch in validator + val arrowPayload = df.coalesce(1).toArrowPayload.collect().head + val tempFile = new File(tempDataPath, file) + Files.write(json, tempFile, StandardCharsets.UTF_8) + validateConversion(df.schema, arrowPayload, tempFile) + } + + private def validateConversion( + sparkSchema: StructType, + arrowPayload: ArrowPayload, + jsonFile: File): Unit = { + val allocator = new RootAllocator(Long.MaxValue) + val jsonReader = new JsonFileReader(jsonFile, allocator) + + val arrowSchema = ArrowConverters.schemaToArrowSchema(sparkSchema) + val jsonSchema = jsonReader.start() + Validator.compareSchemas(arrowSchema, jsonSchema) + + val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator) + val vectorLoader = new VectorLoader(arrowRoot) + val arrowRecordBatch = arrowPayload.loadBatch(allocator) + vectorLoader.load(arrowRecordBatch) + val jsonRoot = jsonReader.read() + Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) + + jsonRoot.close() + jsonReader.close() + arrowRecordBatch.close() + arrowRoot.close() + allocator.close() + } +} From c3713fde86204bf3f027483914ff9e60e7aad261 Mon Sep 17 00:00:00 2001 From: chie8842 Date: Mon, 10 Jul 2017 18:56:54 -0700 Subject: [PATCH 148/779] [SPARK-21358][EXAMPLES] Argument of repartitionandsortwithinpartitions at pyspark ## What changes were proposed in this pull request? At example of repartitionAndSortWithinPartitions at rdd.py, third argument should be True or False. I proposed fix of example code. ## How was this patch tested? * I rename test_repartitionAndSortWithinPartitions to test_repartitionAndSortWIthinPartitions_asc to specify boolean argument. * I added test_repartitionAndSortWithinPartitions_desc to test False pattern at third argument. (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: chie8842 Closes #18586 from chie8842/SPARK-21358. --- python/pyspark/rdd.py | 2 +- python/pyspark/tests.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7dfa17f68a943..3325b65f8b600 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -608,7 +608,7 @@ def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=p sort records by their keys. >>> rdd = sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)]) - >>> rdd2 = rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, 2) + >>> rdd2 = rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, True) >>> rdd2.glom().collect() [[(0, 5), (0, 8), (2, 6)], [(1, 3), (3, 8), (3, 8)]] """ diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index bb13de563cdd4..73ab442dfd791 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1019,14 +1019,22 @@ def test_histogram(self): self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1)) self.assertRaises(TypeError, lambda: rdd.histogram(2)) - def test_repartitionAndSortWithinPartitions(self): + def test_repartitionAndSortWithinPartitions_asc(self): rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) - repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2) + repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, True) partitions = repartitioned.glom().collect() self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)]) self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)]) + def test_repartitionAndSortWithinPartitions_desc(self): + rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) + + repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, False) + partitions = repartitioned.glom().collect() + self.assertEqual(partitions[0], [(2, 6), (0, 5), (0, 8)]) + self.assertEqual(partitions[1], [(3, 8), (3, 8), (1, 3)]) + def test_repartition_no_skewed(self): num_partitions = 20 a = self.sc.parallelize(range(int(1000)), 2) From a2bec6c92a063f4a8e9ed75a9f3f06808485b6d7 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 10 Jul 2017 20:16:29 -0700 Subject: [PATCH 149/779] [SPARK-21043][SQL] Add unionByName in Dataset ## What changes were proposed in this pull request? This pr added `unionByName` in `DataSet`. Here is how to use: ``` val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2") val df2 = Seq((4, 5, 6)).toDF("col1", "col2", "col0") df1.unionByName(df2).show // output: // +----+----+----+ // |col0|col1|col2| // +----+----+----+ // | 1| 2| 3| // | 6| 4| 5| // +----+----+----+ ``` ## How was this patch tested? Added tests in `DataFrameSuite`. Author: Takeshi Yamamuro Closes #18300 from maropu/SPARK-21043-2. --- .../scala/org/apache/spark/sql/Dataset.scala | 60 +++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 87 +++++++++++++++++++ 2 files changed, 147 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a7773831df075..7f3ae05411516 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -53,6 +53,7 @@ import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.storage.StorageLevel import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.Utils @@ -1734,6 +1735,65 @@ class Dataset[T] private[sql]( CombineUnions(Union(logicalPlan, other.logicalPlan)) } + /** + * Returns a new Dataset containing union of rows in this Dataset and another Dataset. + * + * This is different from both `UNION ALL` and `UNION DISTINCT` in SQL. To do a SQL-style set + * union (that does deduplication of elements), use this function followed by a [[distinct]]. + * + * The difference between this function and [[union]] is that this function + * resolves columns by name (not by position): + * + * {{{ + * val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2") + * val df2 = Seq((4, 5, 6)).toDF("col1", "col2", "col0") + * df1.unionByName(df2).show + * + * // output: + * // +----+----+----+ + * // |col0|col1|col2| + * // +----+----+----+ + * // | 1| 2| 3| + * // | 6| 4| 5| + * // +----+----+----+ + * }}} + * + * @group typedrel + * @since 2.3.0 + */ + def unionByName(other: Dataset[T]): Dataset[T] = withSetOperator { + // Check column name duplication + val resolver = sparkSession.sessionState.analyzer.resolver + val leftOutputAttrs = logicalPlan.output + val rightOutputAttrs = other.logicalPlan.output + + SchemaUtils.checkColumnNameDuplication( + leftOutputAttrs.map(_.name), + "in the left attributes", + sparkSession.sessionState.conf.caseSensitiveAnalysis) + SchemaUtils.checkColumnNameDuplication( + rightOutputAttrs.map(_.name), + "in the right attributes", + sparkSession.sessionState.conf.caseSensitiveAnalysis) + + // Builds a project list for `other` based on `logicalPlan` output names + val rightProjectList = leftOutputAttrs.map { lattr => + rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }.getOrElse { + throw new AnalysisException( + s"""Cannot resolve column name "${lattr.name}" among """ + + s"""(${rightOutputAttrs.map(_.name).mkString(", ")})""") + } + } + + // Delegates failure checks to `CheckAnalysis` + val notFoundAttrs = rightOutputAttrs.diff(rightProjectList) + val rightChild = Project(rightProjectList ++ notFoundAttrs, other.logicalPlan) + + // This breaks caching, but it's usually ok because it addresses a very specific use case: + // using union to union many files or partitions. + CombineUnions(Union(logicalPlan, rightChild)) + } + /** * Returns a new Dataset containing rows only in both this Dataset and another Dataset. * This is equivalent to `INTERSECT` in SQL. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index a5a2e1c38d300..5ae27032e0e94 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -111,6 +111,93 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ) } + test("union by name") { + var df1 = Seq((1, 2, 3)).toDF("a", "b", "c") + var df2 = Seq((3, 1, 2)).toDF("c", "a", "b") + val df3 = Seq((2, 3, 1)).toDF("b", "c", "a") + val unionDf = df1.unionByName(df2.unionByName(df3)) + checkAnswer(unionDf, + Row(1, 2, 3) :: Row(1, 2, 3) :: Row(1, 2, 3) :: Nil + ) + + // Check if adjacent unions are combined into a single one + assert(unionDf.queryExecution.optimizedPlan.collect { case u: Union => true }.size == 1) + + // Check failure cases + df1 = Seq((1, 2)).toDF("a", "c") + df2 = Seq((3, 4, 5)).toDF("a", "b", "c") + var errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains( + "Union can only be performed on tables with the same number of columns, " + + "but the first table has 2 columns and the second table has 3 columns")) + + df1 = Seq((1, 2, 3)).toDF("a", "b", "c") + df2 = Seq((4, 5, 6)).toDF("a", "c", "d") + errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("""Cannot resolve column name "b" among (a, c, d)""")) + } + + test("union by name - type coercion") { + var df1 = Seq((1, "a")).toDF("c0", "c1") + var df2 = Seq((3, 1L)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(1L, "a") :: Row(1L, "3") :: Nil) + + df1 = Seq((1, 1.0)).toDF("c0", "c1") + df2 = Seq((8L, 3.0)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(1.0, 1.0) :: Row(3.0, 8.0) :: Nil) + + df1 = Seq((2.0f, 7.4)).toDF("c0", "c1") + df2 = Seq(("a", 4.0)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(2.0, "7.4") :: Row(4.0, "a") :: Nil) + + df1 = Seq((1, "a", 3.0)).toDF("c0", "c1", "c2") + df2 = Seq((1.2, 2, "bc")).toDF("c2", "c0", "c1") + val df3 = Seq(("def", 1.2, 3)).toDF("c1", "c2", "c0") + checkAnswer(df1.unionByName(df2.unionByName(df3)), + Row(1, "a", 3.0) :: Row(2, "bc", 1.2) :: Row(3, "def", 1.2) :: Nil + ) + } + + test("union by name - check case sensitivity") { + def checkCaseSensitiveTest(): Unit = { + val df1 = Seq((1, 2, 3)).toDF("ab", "cd", "ef") + val df2 = Seq((4, 5, 6)).toDF("cd", "ef", "AB") + checkAnswer(df1.unionByName(df2), Row(1, 2, 3) :: Row(6, 4, 5) :: Nil) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val errMsg2 = intercept[AnalysisException] { + checkCaseSensitiveTest() + }.getMessage + assert(errMsg2.contains("""Cannot resolve column name "ab" among (cd, ef, AB)""")) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkCaseSensitiveTest() + } + } + + test("union by name - check name duplication") { + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + var df1 = Seq((1, 1)).toDF(c0, c1) + var df2 = Seq((1, 1)).toDF("c0", "c1") + var errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the left attributes:")) + df1 = Seq((1, 1)).toDF("c0", "c1") + df2 = Seq((1, 1)).toDF(c0, c1) + errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the right attributes:")) + } + } + } + test("empty data frame") { assert(spark.emptyDataFrame.columns.toSeq === Seq.empty[String]) assert(spark.emptyDataFrame.count() === 0) From 1471ee7af5a9952b60cf8c56d60cb6a7ec46cc69 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 11 Jul 2017 11:19:59 +0800 Subject: [PATCH 150/779] [SPARK-21350][SQL] Fix the error message when the number of arguments is wrong when invoking a UDF ### What changes were proposed in this pull request? Users get a very confusing error when users specify a wrong number of parameters. ```Scala val df = spark.emptyDataFrame spark.udf.register("foo", (_: String).length) df.selectExpr("foo(2, 3, 4)") ``` ``` org.apache.spark.sql.UDFSuite$$anonfun$9$$anonfun$apply$mcV$sp$12 cannot be cast to scala.Function3 java.lang.ClassCastException: org.apache.spark.sql.UDFSuite$$anonfun$9$$anonfun$apply$mcV$sp$12 cannot be cast to scala.Function3 at org.apache.spark.sql.catalyst.expressions.ScalaUDF.(ScalaUDF.scala:109) ``` This PR is to capture the exception and issue an error message that is consistent with what we did for built-in functions. After the fix, the error message is improved to ``` Invalid number of arguments for function foo; line 1 pos 0 org.apache.spark.sql.AnalysisException: Invalid number of arguments for function foo; line 1 pos 0 at org.apache.spark.sql.catalyst.analysis.SimpleFunctionRegistry.lookupFunction(FunctionRegistry.scala:119) ``` ### How was this patch tested? Added a test case Author: gatorsmile Closes #18574 from gatorsmile/statsCheck. --- .../apache/spark/sql/UDFRegistration.scala | 412 +++++++++++++----- .../org/apache/spark/sql/JavaUDFSuite.java | 8 + .../scala/org/apache/spark/sql/UDFSuite.scala | 13 +- 3 files changed, 331 insertions(+), 102 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 8bdc0221888d0..c4d0adb5236f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -111,7 +111,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try($inputTypes).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == $x) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: $x; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) }""") @@ -123,16 +128,20 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs, Any]]" val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") println(s""" - |/** - | * Register a user-defined function with ${i} arguments. - | * @since 1.3.0 - | */ - |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType): Unit = { - | val func = f$anyCast.call($anyParams) - | functionRegistry.createOrReplaceTempFunction( - | name, - | (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) - |}""".stripMargin) + |/** + | * Register a user-defined function with ${i} arguments. + | * @since 1.3.0 + | */ + |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType): Unit = { + | val func = f$anyCast.call($anyParams) + |def builder(e: Seq[Expression]) = if (e.length == $i) { + | ScalaUDF(func, returnType, e) + |} else { + | throw new AnalysisException("Invalid number of arguments for function " + name + + | ". Expected: $i; Found: " + e.length) + |} + |functionRegistry.createOrReplaceTempFunction(name, builder) + |}""".stripMargin) } */ @@ -144,7 +153,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 0) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 0; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -157,7 +171,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 1) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 1; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -170,7 +189,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 2) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 2; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -183,7 +207,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 3) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 3; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -196,7 +225,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 4) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 4; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -209,7 +243,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 5) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 5; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -222,7 +261,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 6) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 6; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -235,7 +279,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 7) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 7; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -248,7 +297,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 8) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 8; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -261,7 +315,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 9) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 9; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -274,7 +333,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 10) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 10; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -287,7 +351,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 11) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 11; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -300,7 +369,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 12) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 12; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -313,7 +387,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 13) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 13; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -326,7 +405,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 14) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 14; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -339,7 +423,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 15) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 15; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -352,7 +441,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 16) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 16; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -365,7 +459,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 17) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 17; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -378,7 +477,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 18) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 18; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -391,7 +495,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 19) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 19; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -404,7 +513,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 20) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 20; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -417,7 +531,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 21) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 21; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -430,7 +549,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).toOption - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + def builder(e: Seq[Expression]) = if (e.length == 22) { + ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 22; Found: " + e.length) + } functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -531,9 +655,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 1) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 1; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -542,9 +670,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 2) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 2; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -553,9 +685,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 3) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 3; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -564,9 +700,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 4) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 4; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -575,9 +715,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 5) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 5; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -586,9 +730,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 6) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 6; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -597,9 +745,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 7) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 7; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -608,9 +760,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 8) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 8; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -619,9 +775,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 9) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 9; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -630,9 +790,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 10) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 10; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -641,9 +805,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 11) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 11; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -652,9 +820,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 12) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 12; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -663,9 +835,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 13) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 13; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -674,9 +850,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 14) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 14; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -685,9 +865,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 15) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 15; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -696,9 +880,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 16) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 16; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -707,9 +895,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 17) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 17; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -718,9 +910,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 18) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 18; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -729,9 +925,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 19) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 19; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -740,9 +940,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 20) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 20; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -751,9 +955,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 21) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 21; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } /** @@ -762,9 +970,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.createOrReplaceTempFunction( - name, - (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) + def builder(e: Seq[Expression]) = if (e.length == 22) { + ScalaUDF(func, returnType, e) + } else { + throw new AnalysisException("Invalid number of arguments for function " + name + + ". Expected: 22; Found: " + e.length) + } + functionRegistry.createOrReplaceTempFunction(name, builder) } // scalastyle:on line.size.limit diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 250fa674d8ecc..4fb2988f24d26 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -25,6 +25,7 @@ import org.junit.Before; import org.junit.Test; +import org.apache.spark.sql.AnalysisException; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.api.java.UDF2; @@ -105,4 +106,11 @@ public void udf4Test() { } Assert.assertEquals(55, sum); } + + @SuppressWarnings("unchecked") + @Test(expected = AnalysisException.class) + public void udf5Test() { + spark.udf().register("inc", (Long i) -> i + 1, DataTypes.LongType); + List results = spark.sql("SELECT inc(1, 5)").collectAsList(); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index b4f744b193ada..335b882ace92a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -71,12 +71,21 @@ class UDFSuite extends QueryTest with SharedSQLContext { } } - test("error reporting for incorrect number of arguments") { + test("error reporting for incorrect number of arguments - builtin function") { val df = spark.emptyDataFrame val e = intercept[AnalysisException] { df.selectExpr("substr('abcd', 2, 3, 4)") } - assert(e.getMessage.contains("arguments")) + assert(e.getMessage.contains("Invalid number of arguments for function substr")) + } + + test("error reporting for incorrect number of arguments - udf") { + val df = spark.emptyDataFrame + val e = intercept[AnalysisException] { + spark.udf.register("foo", (_: String).length) + df.selectExpr("foo(2, 3, 4)") + } + assert(e.getMessage.contains("Invalid number of arguments for function foo")) } test("error reporting for undefined functions") { From 833eab2c9bd273ee9577fbf9e480d3e3a4b7d203 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 11 Jul 2017 11:26:17 +0800 Subject: [PATCH 151/779] [SPARK-21369][CORE] Don't use Scala Tuple2 in common/network-* ## What changes were proposed in this pull request? Remove all usages of Scala Tuple2 from common/network-* projects. Otherwise, Yarn users cannot use `spark.reducer.maxReqSizeShuffleToMem`. ## How was this patch tested? Jenkins. Author: Shixiong Zhu Closes #18593 from zsxwing/SPARK-21369. --- common/network-common/pom.xml | 3 ++- .../client/TransportResponseHandler.java | 20 +++++++++---------- .../server/OneForOneStreamManager.java | 17 +++++----------- common/network-shuffle/pom.xml | 1 + common/network-yarn/pom.xml | 1 + 5 files changed, 19 insertions(+), 23 deletions(-) diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 066970f24205f..0254d0cefc368 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -90,7 +90,8 @@ org.apache.spark spark-tags_${scala.binary.version} - + test + + runtime log4j @@ -1859,9 +1859,9 @@ ${antlr4.version} - ${jline.groupid} + jline jline - ${jline.version} + 2.12.1 org.apache.commons @@ -1933,6 +1933,7 @@ --> org.jboss.netty org.codehaus.groovy + *:*_2.10 true @@ -1987,6 +1988,8 @@ -unchecked -deprecation -feature + -explaintypes + -Yno-adapted-args -Xms1024m @@ -2585,44 +2588,6 @@ - - scala-2.10 - - scala-2.10 - - - 2.10.6 - 2.10 - ${scala.version} - org.scala-lang - - - - - org.apache.maven.plugins - maven-enforcer-plugin - - - enforce-versions - - enforce - - - - - - *:*_2.11 - - - - - - - - - - - test-java-home @@ -2633,16 +2598,18 @@ + scala-2.11 - - !scala-2.10 - + + + + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 89b0c7a3ab7b0..41f3a0451aa8a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -87,19 +87,11 @@ object SparkBuild extends PomBuild { val projectsMap: Map[String, Seq[Setting[_]]] = Map.empty override val profiles = { - val profiles = Properties.envOrNone("SBT_MAVEN_PROFILES") match { + Properties.envOrNone("SBT_MAVEN_PROFILES") match { case None => Seq("sbt") case Some(v) => v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq } - - if (System.getProperty("scala-2.10") == "") { - // To activate scala-2.10 profile, replace empty property value to non-empty value - // in the same way as Maven which handles -Dname as -Dname=true before executes build process. - // see: https://github.com/apache/maven/blob/maven-3.0.4/maven-embedder/src/main/java/org/apache/maven/cli/MavenCli.java#L1082 - System.setProperty("scala-2.10", "true") - } - profiles } Properties.envOrNone("SBT_MAVEN_PROPERTIES") match { @@ -234,9 +226,7 @@ object SparkBuild extends PomBuild { }, javacJVMVersion := "1.8", - // SBT Scala 2.10 build still doesn't support Java 8, because scalac 2.10 doesn't, but, - // it also doesn't touch Java 8 code and it's OK to emit Java 7 bytecode in this case - scalacJVMVersion := (if (System.getProperty("scala-2.10") == "true") "1.7" else "1.8"), + scalacJVMVersion := "1.8", javacOptions in Compile ++= Seq( "-encoding", "UTF-8", @@ -477,7 +467,6 @@ object OldDeps { def oldDepsSettings() = Defaults.coreDefaultSettings ++ Seq( name := "old-deps", - scalaVersion := "2.10.5", libraryDependencies := allPreviousArtifactKeys.value.flatten ) } @@ -756,13 +745,7 @@ object CopyDependencies { object TestSettings { import BuildCommons._ - private val scalaBinaryVersion = - if (System.getProperty("scala-2.10") == "true") { - "2.10" - } else { - "2.11" - } - + private val scalaBinaryVersion = "2.11" lazy val settings = Seq ( // Fork new JVMs for tests and set Java options for those fork := true, diff --git a/python/run-tests.py b/python/run-tests.py index b2e50435bb192..afd3d29a0ff90 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -54,7 +54,8 @@ def print_red(text): LOGGER = logging.getLogger() # Find out where the assembly jars are located. -for scala in ["2.11", "2.10"]: +# Later, add back 2.12 to this list: +for scala in ["2.11"]: build_dir = os.path.join(SPARK_HOME, "assembly", "target", "scala-" + scala) if os.path.isdir(build_dir): SPARK_DIST_CLASSPATH = os.path.join(build_dir, "jars", "*") diff --git a/repl/pom.xml b/repl/pom.xml index 6d133a3cfff7d..51eb9b60dd54a 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -32,8 +32,8 @@ repl - scala-2.10/src/main/scala - scala-2.10/src/test/scala + scala-2.11/src/main/scala + scala-2.11/src/test/scala @@ -71,7 +71,7 @@ ${scala.version} - ${jline.groupid} + jline jline @@ -170,23 +170,17 @@ + + + diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala deleted file mode 100644 index fba321be91886..0000000000000 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.repl - -import org.apache.spark.internal.Logging - -object Main extends Logging { - - initializeLogIfNecessary(true) - Signaling.cancelOnInterrupt() - - private var _interp: SparkILoop = _ - - def interp = _interp - - def interp_=(i: SparkILoop) { _interp = i } - - def main(args: Array[String]) { - _interp = new SparkILoop - _interp.process(args) - } -} diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala deleted file mode 100644 index be9b79021d2a8..0000000000000 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.repl - -import scala.tools.nsc.{CompilerCommand, Settings} - -import org.apache.spark.annotation.DeveloperApi - -/** - * Command class enabling Spark-specific command line options (provided by - * org.apache.spark.repl.SparkRunnerSettings). - * - * @example new SparkCommandLine(Nil).settings - * - * @param args The list of command line arguments - * @param settings The underlying settings to associate with this set of - * command-line options - */ -@DeveloperApi -class SparkCommandLine(args: List[String], override val settings: Settings) - extends CompilerCommand(args, settings) { - def this(args: List[String], error: String => Unit) { - this(args, new SparkRunnerSettings(error)) - } - - def this(args: List[String]) { - // scalastyle:off println - this(args, str => Console.println("Error: " + str)) - // scalastyle:on println - } -} diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala deleted file mode 100644 index 2b5d56a895902..0000000000000 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala +++ /dev/null @@ -1,114 +0,0 @@ -// scalastyle:off - -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Paul Phillips - */ - -package org.apache.spark.repl - -import scala.tools.nsc._ -import scala.tools.nsc.interpreter._ - -import scala.reflect.internal.util.BatchSourceFile -import scala.tools.nsc.ast.parser.Tokens.EOF - -import org.apache.spark.internal.Logging - -private[repl] trait SparkExprTyper extends Logging { - val repl: SparkIMain - - import repl._ - import global.{ reporter => _, Import => _, _ } - import definitions._ - import syntaxAnalyzer.{ UnitParser, UnitScanner, token2name } - import naming.freshInternalVarName - - object codeParser extends { val global: repl.global.type = repl.global } with CodeHandlers[Tree] { - def applyRule[T](code: String, rule: UnitParser => T): T = { - reporter.reset() - val scanner = newUnitParser(code) - val result = rule(scanner) - - if (!reporter.hasErrors) - scanner.accept(EOF) - - result - } - - def defns(code: String) = stmts(code) collect { case x: DefTree => x } - def expr(code: String) = applyRule(code, _.expr()) - def stmts(code: String) = applyRule(code, _.templateStats()) - def stmt(code: String) = stmts(code).last // guaranteed nonempty - } - - /** Parse a line into a sequence of trees. Returns None if the input is incomplete. */ - def parse(line: String): Option[List[Tree]] = debugging(s"""parse("$line")""") { - var isIncomplete = false - reporter.withIncompleteHandler((_, _) => isIncomplete = true) { - val trees = codeParser.stmts(line) - if (reporter.hasErrors) { - Some(Nil) - } else if (isIncomplete) { - None - } else { - Some(trees) - } - } - } - // def parsesAsExpr(line: String) = { - // import codeParser._ - // (opt expr line).isDefined - // } - - def symbolOfLine(code: String): Symbol = { - def asExpr(): Symbol = { - val name = freshInternalVarName() - // Typing it with a lazy val would give us the right type, but runs - // into compiler bugs with things like existentials, so we compile it - // behind a def and strip the NullaryMethodType which wraps the expr. - val line = "def " + name + " = {\n" + code + "\n}" - - interpretSynthetic(line) match { - case IR.Success => - val sym0 = symbolOfTerm(name) - // drop NullaryMethodType - val sym = sym0.cloneSymbol setInfo afterTyper(sym0.info.finalResultType) - if (sym.info.typeSymbol eq UnitClass) NoSymbol else sym - case _ => NoSymbol - } - } - def asDefn(): Symbol = { - val old = repl.definedSymbolList.toSet - - interpretSynthetic(code) match { - case IR.Success => - repl.definedSymbolList filterNot old match { - case Nil => NoSymbol - case sym :: Nil => sym - case syms => NoSymbol.newOverloaded(NoPrefix, syms) - } - case _ => NoSymbol - } - } - beQuietDuring(asExpr()) orElse beQuietDuring(asDefn()) - } - - private var typeOfExpressionDepth = 0 - def typeOfExpression(expr: String, silent: Boolean = true): Type = { - if (typeOfExpressionDepth > 2) { - logDebug("Terminating typeOfExpression recursion for expression: " + expr) - return NoType - } - typeOfExpressionDepth += 1 - // Don't presently have a good way to suppress undesirable success output - // while letting errors through, so it is first trying it silently: if there - // is an error, and errors are desired, then it re-evaluates non-silently - // to induce the error message. - try beSilentDuring(symbolOfLine(expr).tpe) match { - case NoType if !silent => symbolOfLine(expr).tpe // generate error - case tpe => tpe - } - finally typeOfExpressionDepth -= 1 - } -} diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala deleted file mode 100644 index 955be17a73b85..0000000000000 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package scala.tools.nsc - -import org.apache.spark.annotation.DeveloperApi - -// NOTE: Forced to be public (and in scala.tools.nsc package) to access the -// settings "explicitParentLoader" method - -/** - * Provides exposure for the explicitParentLoader method on settings instances. - */ -@DeveloperApi -object SparkHelper { - /** - * Retrieves the explicit parent loader for the provided settings. - * - * @param settings The settings whose explicit parent loader to retrieve - * - * @return The Optional classloader representing the explicit parent loader - */ - @DeveloperApi - def explicitParentLoader(settings: Settings) = settings.explicitParentLoader -} diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala deleted file mode 100644 index b7237a6ce822f..0000000000000 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ /dev/null @@ -1,1145 +0,0 @@ -// scalastyle:off - -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Alexander Spoon - */ - -package org.apache.spark.repl - - -import java.net.URL - -import scala.reflect.io.AbstractFile -import scala.tools.nsc._ -import scala.tools.nsc.backend.JavaPlatform -import scala.tools.nsc.interpreter._ -import scala.tools.nsc.interpreter.{Results => IR} -import Predef.{println => _, _} -import java.io.{BufferedReader, FileReader} -import java.net.URI -import java.util.concurrent.locks.ReentrantLock -import scala.sys.process.Process -import scala.tools.nsc.interpreter.session._ -import scala.util.Properties.{jdkHome, javaVersion} -import scala.tools.util.{Javap} -import scala.annotation.tailrec -import scala.collection.mutable.ListBuffer -import scala.concurrent.ops -import scala.tools.nsc.util._ -import scala.tools.nsc.interpreter._ -import scala.tools.nsc.io.{File, Directory} -import scala.reflect.NameTransformer._ -import scala.tools.nsc.util.ScalaClassLoader._ -import scala.tools.util._ -import scala.language.{implicitConversions, existentials, postfixOps} -import scala.reflect.{ClassTag, classTag} -import scala.tools.reflect.StdRuntimeTags._ - -import java.lang.{Class => jClass} -import scala.reflect.api.{Mirror, TypeCreator, Universe => ApiUniverse} - -import org.apache.spark.SparkConf -import org.apache.spark.SparkContext -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession -import org.apache.spark.util.Utils - -/** The Scala interactive shell. It provides a read-eval-print loop - * around the Interpreter class. - * After instantiation, clients should call the main() method. - * - * If no in0 is specified, then input will come from the console, and - * the class will attempt to provide input editing feature such as - * input history. - * - * @author Moez A. Abdel-Gawad - * @author Lex Spoon - * @version 1.2 - */ -@DeveloperApi -class SparkILoop( - private val in0: Option[BufferedReader], - protected val out: JPrintWriter, - val master: Option[String] -) extends AnyRef with LoopCommands with SparkILoopInit with Logging { - def this(in0: BufferedReader, out: JPrintWriter, master: String) = this(Some(in0), out, Some(master)) - def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out, None) - def this() = this(None, new JPrintWriter(Console.out, true), None) - - private var in: InteractiveReader = _ // the input stream from which commands come - - // NOTE: Exposed in package for testing - private[repl] var settings: Settings = _ - - private[repl] var intp: SparkIMain = _ - - @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp - @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: SparkIMain): Unit = intp = i - - /** Having inherited the difficult "var-ness" of the repl instance, - * I'm trying to work around it by moving operations into a class from - * which it will appear a stable prefix. - */ - private def onIntp[T](f: SparkIMain => T): T = f(intp) - - class IMainOps[T <: SparkIMain](val intp: T) { - import intp._ - import global._ - - def printAfterTyper(msg: => String) = - intp.reporter printMessage afterTyper(msg) - - /** Strip NullaryMethodType artifacts. */ - private def replInfo(sym: Symbol) = { - sym.info match { - case NullaryMethodType(restpe) if sym.isAccessor => restpe - case info => info - } - } - def echoTypeStructure(sym: Symbol) = - printAfterTyper("" + deconstruct.show(replInfo(sym))) - - def echoTypeSignature(sym: Symbol, verbose: Boolean) = { - if (verbose) SparkILoop.this.echo("// Type signature") - printAfterTyper("" + replInfo(sym)) - - if (verbose) { - SparkILoop.this.echo("\n// Internal Type structure") - echoTypeStructure(sym) - } - } - } - implicit def stabilizeIMain(intp: SparkIMain) = new IMainOps[intp.type](intp) - - /** TODO - - * -n normalize - * -l label with case class parameter names - * -c complete - leave nothing out - */ - private def typeCommandInternal(expr: String, verbose: Boolean): Result = { - onIntp { intp => - val sym = intp.symbolOfLine(expr) - if (sym.exists) intp.echoTypeSignature(sym, verbose) - else "" - } - } - - // NOTE: Must be public for visibility - @DeveloperApi - var sparkContext: SparkContext = _ - - override def echoCommandMessage(msg: String) { - intp.reporter printMessage msg - } - - // def isAsync = !settings.Yreplsync.value - private[repl] def isAsync = false - // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals]) - private def history = in.history - - /** The context class loader at the time this object was created */ - protected val originalClassLoader = Utils.getContextOrSparkClassLoader - - // classpath entries added via :cp - private var addedClasspath: String = "" - - /** A reverse list of commands to replay if the user requests a :replay */ - private var replayCommandStack: List[String] = Nil - - /** A list of commands to replay if the user requests a :replay */ - private def replayCommands = replayCommandStack.reverse - - /** Record a command for replay should the user request a :replay */ - private def addReplay(cmd: String) = replayCommandStack ::= cmd - - private def savingReplayStack[T](body: => T): T = { - val saved = replayCommandStack - try body - finally replayCommandStack = saved - } - private def savingReader[T](body: => T): T = { - val saved = in - try body - finally in = saved - } - - - private def sparkCleanUp() { - echo("Stopping spark context.") - intp.beQuietDuring { - command("sc.stop()") - } - } - /** Close the interpreter and set the var to null. */ - private def closeInterpreter() { - if (intp ne null) { - sparkCleanUp() - intp.close() - intp = null - } - } - - class SparkILoopInterpreter extends SparkIMain(settings, out) { - outer => - - override private[repl] lazy val formatting = new Formatting { - def prompt = SparkILoop.this.prompt - } - override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader) - } - - /** - * Constructs a new interpreter. - */ - protected def createInterpreter() { - require(settings != null) - - if (addedClasspath != "") settings.classpath.append(addedClasspath) - val addedJars = - if (Utils.isWindows) { - // Strip any URI scheme prefix so we can add the correct path to the classpath - // e.g. file:/C:/my/path.jar -> C:/my/path.jar - getAddedJars().map { jar => new URI(jar).getPath.stripPrefix("/") } - } else { - // We need new URI(jar).getPath here for the case that `jar` includes encoded white space (%20). - getAddedJars().map { jar => new URI(jar).getPath } - } - // work around for Scala bug - val totalClassPath = addedJars.foldLeft( - settings.classpath.value)((l, r) => ClassPath.join(l, r)) - this.settings.classpath.value = totalClassPath - - intp = new SparkILoopInterpreter - } - - /** print a friendly help message */ - private def helpCommand(line: String): Result = { - if (line == "") helpSummary() - else uniqueCommand(line) match { - case Some(lc) => echo("\n" + lc.longHelp) - case _ => ambiguousError(line) - } - } - private def helpSummary() = { - val usageWidth = commands map (_.usageMsg.length) max - val formatStr = "%-" + usageWidth + "s %s %s" - - echo("All commands can be abbreviated, e.g. :he instead of :help.") - echo("Those marked with a * have more detailed help, e.g. :help imports.\n") - - commands foreach { cmd => - val star = if (cmd.hasLongHelp) "*" else " " - echo(formatStr.format(cmd.usageMsg, star, cmd.help)) - } - } - private def ambiguousError(cmd: String): Result = { - matchingCommands(cmd) match { - case Nil => echo(cmd + ": no such command. Type :help for help.") - case xs => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?") - } - Result(true, None) - } - private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd) - private def uniqueCommand(cmd: String): Option[LoopCommand] = { - // this lets us add commands willy-nilly and only requires enough command to disambiguate - matchingCommands(cmd) match { - case List(x) => Some(x) - // exact match OK even if otherwise appears ambiguous - case xs => xs find (_.name == cmd) - } - } - private var fallbackMode = false - - private def toggleFallbackMode() { - val old = fallbackMode - fallbackMode = !old - System.setProperty("spark.repl.fallback", fallbackMode.toString) - echo(s""" - |Switched ${if (old) "off" else "on"} fallback mode without restarting. - | If you have defined classes in the repl, it would - |be good to redefine them incase you plan to use them. If you still run - |into issues it would be good to restart the repl and turn on `:fallback` - |mode as first command. - """.stripMargin) - } - - /** Show the history */ - private lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { - override def usage = "[num]" - def defaultLines = 20 - - def apply(line: String): Result = { - if (history eq NoHistory) - return "No history available." - - val xs = words(line) - val current = history.index - val count = try xs.head.toInt catch { case _: Exception => defaultLines } - val lines = history.asStrings takeRight count - val offset = current - lines.size + 1 - - for ((line, index) <- lines.zipWithIndex) - echo("%3d %s".format(index + offset, line)) - } - } - - // When you know you are most likely breaking into the middle - // of a line being typed. This softens the blow. - private[repl] def echoAndRefresh(msg: String) = { - echo("\n" + msg) - in.redrawLine() - } - private[repl] def echo(msg: String) = { - out println msg - out.flush() - } - private def echoNoNL(msg: String) = { - out print msg - out.flush() - } - - /** Search the history */ - private def searchHistory(_cmdline: String) { - val cmdline = _cmdline.toLowerCase - val offset = history.index - history.size + 1 - - for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline) - echo("%d %s".format(index + offset, line)) - } - - private var currentPrompt = Properties.shellPromptString - - /** - * Sets the prompt string used by the REPL. - * - * @param prompt The new prompt string - */ - @DeveloperApi - def setPrompt(prompt: String) = currentPrompt = prompt - - /** - * Represents the current prompt string used by the REPL. - * - * @return The current prompt string - */ - @DeveloperApi - def prompt = currentPrompt - - import LoopCommand.{ cmd, nullary } - - /** Standard commands */ - private lazy val standardCommands = List( - cmd("cp", "", "add a jar or directory to the classpath", addClasspath), - cmd("help", "[command]", "print this summary or command-specific help", helpCommand), - historyCommand, - cmd("h?", "", "search the history", searchHistory), - cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand), - cmd("implicits", "[-v]", "show the implicits in scope", implicitsCommand), - cmd("javap", "", "disassemble a file or class name", javapCommand), - cmd("load", "", "load and interpret a Scala file", loadCommand), - nullary("paste", "enter paste mode: all input up to ctrl-D compiled together", pasteCommand), -// nullary("power", "enable power user mode", powerCmd), - nullary("quit", "exit the repl", () => Result(false, None)), - nullary("replay", "reset execution and replay all previous commands", replay), - nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand), - shCommand, - nullary("silent", "disable/enable automatic printing of results", verbosity), - nullary("fallback", """ - |disable/enable advanced repl changes, these fix some issues but may introduce others. - |This mode will be removed once these fixes stablize""".stripMargin, toggleFallbackMode), - cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand), - nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) - ) - - /** Power user commands */ - private lazy val powerCommands: List[LoopCommand] = List( - // cmd("phase", "", "set the implicit phase for power commands", phaseCommand) - ) - - // private def dumpCommand(): Result = { - // echo("" + power) - // history.asStrings takeRight 30 foreach echo - // in.redrawLine() - // } - // private def valsCommand(): Result = power.valsDescription - - private val typeTransforms = List( - "scala.collection.immutable." -> "immutable.", - "scala.collection.mutable." -> "mutable.", - "scala.collection.generic." -> "generic.", - "java.lang." -> "jl.", - "scala.runtime." -> "runtime." - ) - - private def importsCommand(line: String): Result = { - val tokens = words(line) - val handlers = intp.languageWildcardHandlers ++ intp.importHandlers - val isVerbose = tokens contains "-v" - - handlers.filterNot(_.importedSymbols.isEmpty).zipWithIndex foreach { - case (handler, idx) => - val (types, terms) = handler.importedSymbols partition (_.name.isTypeName) - val imps = handler.implicitSymbols - val found = tokens filter (handler importsSymbolNamed _) - val typeMsg = if (types.isEmpty) "" else types.size + " types" - val termMsg = if (terms.isEmpty) "" else terms.size + " terms" - val implicitMsg = if (imps.isEmpty) "" else imps.size + " are implicit" - val foundMsg = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "") - val statsMsg = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")") - - intp.reporter.printMessage("%2d) %-30s %s%s".format( - idx + 1, - handler.importString, - statsMsg, - foundMsg - )) - } - } - - private def implicitsCommand(line: String): Result = onIntp { intp => - import intp._ - import global._ - - def p(x: Any) = intp.reporter.printMessage("" + x) - - // If an argument is given, only show a source with that - // in its name somewhere. - val args = line split "\\s+" - val filtered = intp.implicitSymbolsBySource filter { - case (source, syms) => - (args contains "-v") || { - if (line == "") (source.fullName.toString != "scala.Predef") - else (args exists (source.name.toString contains _)) - } - } - - if (filtered.isEmpty) - return "No implicits have been imported other than those in Predef." - - filtered foreach { - case (source, syms) => - p("/* " + syms.size + " implicit members imported from " + source.fullName + " */") - - // This groups the members by where the symbol is defined - val byOwner = syms groupBy (_.owner) - val sortedOwners = byOwner.toList sortBy { case (owner, _) => afterTyper(source.info.baseClasses indexOf owner) } - - sortedOwners foreach { - case (owner, members) => - // Within each owner, we cluster results based on the final result type - // if there are more than a couple, and sort each cluster based on name. - // This is really just trying to make the 100 or so implicits imported - // by default into something readable. - val memberGroups: List[List[Symbol]] = { - val groups = members groupBy (_.tpe.finalResultType) toList - val (big, small) = groups partition (_._2.size > 3) - val xss = ( - (big sortBy (_._1.toString) map (_._2)) :+ - (small flatMap (_._2)) - ) - - xss map (xs => xs sortBy (_.name.toString)) - } - - val ownerMessage = if (owner == source) " defined in " else " inherited from " - p(" /* " + members.size + ownerMessage + owner.fullName + " */") - - memberGroups foreach { group => - group foreach (s => p(" " + intp.symbolDefString(s))) - p("") - } - } - p("") - } - } - - private def findToolsJar() = { - val jdkPath = Directory(jdkHome) - val jar = jdkPath / "lib" / "tools.jar" toFile; - - if (jar isFile) - Some(jar) - else if (jdkPath.isDirectory) - jdkPath.deepFiles find (_.name == "tools.jar") - else None - } - private def addToolsJarToLoader() = { - val cl = findToolsJar match { - case Some(tools) => ScalaClassLoader.fromURLs(Seq(tools.toURL), intp.classLoader) - case _ => intp.classLoader - } - if (Javap.isAvailable(cl)) { - logDebug(":javap available.") - cl - } - else { - logDebug(":javap unavailable: no tools.jar at " + jdkHome) - intp.classLoader - } - } - - private def newJavap() = new JavapClass(addToolsJarToLoader(), new SparkIMain.ReplStrippingWriter(intp)) { - override def tryClass(path: String): Array[Byte] = { - val hd :: rest = path split '.' toList; - // If there are dots in the name, the first segment is the - // key to finding it. - if (rest.nonEmpty) { - intp optFlatName hd match { - case Some(flat) => - val clazz = flat :: rest mkString NAME_JOIN_STRING - val bytes = super.tryClass(clazz) - if (bytes.nonEmpty) bytes - else super.tryClass(clazz + MODULE_SUFFIX_STRING) - case _ => super.tryClass(path) - } - } - else { - // Look for Foo first, then Foo$, but if Foo$ is given explicitly, - // we have to drop the $ to find object Foo, then tack it back onto - // the end of the flattened name. - def className = intp flatName path - def moduleName = (intp flatName path.stripSuffix(MODULE_SUFFIX_STRING)) + MODULE_SUFFIX_STRING - - val bytes = super.tryClass(className) - if (bytes.nonEmpty) bytes - else super.tryClass(moduleName) - } - } - } - // private lazy val javap = substituteAndLog[Javap]("javap", NoJavap)(newJavap()) - private lazy val javap = - try newJavap() - catch { case _: Exception => null } - - // Still todo: modules. - private def typeCommand(line0: String): Result = { - line0.trim match { - case "" => ":type [-v] " - case s if s startsWith "-v " => typeCommandInternal(s stripPrefix "-v " trim, true) - case s => typeCommandInternal(s, false) - } - } - - private def warningsCommand(): Result = { - if (intp.lastWarnings.isEmpty) - "Can't find any cached warnings." - else - intp.lastWarnings foreach { case (pos, msg) => intp.reporter.warning(pos, msg) } - } - - private def javapCommand(line: String): Result = { - if (javap == null) - ":javap unavailable, no tools.jar at %s. Set JDK_HOME.".format(jdkHome) - else if (javaVersion startsWith "1.7") - ":javap not yet working with java 1.7" - else if (line == "") - ":javap [-lcsvp] [path1 path2 ...]" - else - javap(words(line)) foreach { res => - if (res.isError) return "Failed: " + res.value - else res.show() - } - } - - private def wrapCommand(line: String): Result = { - def failMsg = "Argument to :wrap must be the name of a method with signature [T](=> T): T" - onIntp { intp => - import intp._ - import global._ - - words(line) match { - case Nil => - intp.executionWrapper match { - case "" => "No execution wrapper is set." - case s => "Current execution wrapper: " + s - } - case "clear" :: Nil => - intp.executionWrapper match { - case "" => "No execution wrapper is set." - case s => intp.clearExecutionWrapper() ; "Cleared execution wrapper." - } - case wrapper :: Nil => - intp.typeOfExpression(wrapper) match { - case PolyType(List(targ), MethodType(List(arg), restpe)) => - intp setExecutionWrapper intp.pathToTerm(wrapper) - "Set wrapper to '" + wrapper + "'" - case tp => - failMsg + "\nFound: " - } - case _ => failMsg - } - } - } - - private def pathToPhaseWrapper = intp.pathToTerm("$r") + ".phased.atCurrent" - // private def phaseCommand(name: String): Result = { - // val phased: Phased = power.phased - // import phased.NoPhaseName - - // if (name == "clear") { - // phased.set(NoPhaseName) - // intp.clearExecutionWrapper() - // "Cleared active phase." - // } - // else if (name == "") phased.get match { - // case NoPhaseName => "Usage: :phase (e.g. typer, erasure.next, erasure+3)" - // case ph => "Active phase is '%s'. (To clear, :phase clear)".format(phased.get) - // } - // else { - // val what = phased.parse(name) - // if (what.isEmpty || !phased.set(what)) - // "'" + name + "' does not appear to represent a valid phase." - // else { - // intp.setExecutionWrapper(pathToPhaseWrapper) - // val activeMessage = - // if (what.toString.length == name.length) "" + what - // else "%s (%s)".format(what, name) - - // "Active phase is now: " + activeMessage - // } - // } - // } - - /** - * Provides a list of available commands. - * - * @return The list of commands - */ - @DeveloperApi - def commands: List[LoopCommand] = standardCommands /*++ ( - if (isReplPower) powerCommands else Nil - )*/ - - private val replayQuestionMessage = - """|That entry seems to have slain the compiler. Shall I replay - |your session? I can re-run each line except the last one. - |[y/n] - """.trim.stripMargin - - private def crashRecovery(ex: Throwable): Boolean = { - echo(ex.toString) - ex match { - case _: NoSuchMethodError | _: NoClassDefFoundError => - echo("\nUnrecoverable error.") - throw ex - case _ => - def fn(): Boolean = - try in.readYesOrNo(replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) - catch { case _: RuntimeException => false } - - if (fn()) replay() - else echo("\nAbandoning crashed session.") - } - true - } - - /** The main read-eval-print loop for the repl. It calls - * command() for each line of input, and stops when - * command() returns false. - */ - private def loop() { - def readOneLine() = { - out.flush() - in readLine prompt - } - // return false if repl should exit - def processLine(line: String): Boolean = { - if (isAsync) { - if (!awaitInitialized()) return false - runThunks() - } - if (line eq null) false // assume null means EOF - else command(line) match { - case Result(false, _) => false - case Result(_, Some(finalLine)) => addReplay(finalLine) ; true - case _ => true - } - } - def innerLoop() { - val shouldContinue = try { - processLine(readOneLine()) - } catch {case t: Throwable => crashRecovery(t)} - if (shouldContinue) - innerLoop() - } - innerLoop() - } - - /** interpret all lines from a specified file */ - private def interpretAllFrom(file: File) { - savingReader { - savingReplayStack { - file applyReader { reader => - in = SimpleReader(reader, out, false) - echo("Loading " + file + "...") - loop() - } - } - } - } - - /** create a new interpreter and replay the given commands */ - private def replay() { - reset() - if (replayCommandStack.isEmpty) - echo("Nothing to replay.") - else for (cmd <- replayCommands) { - echo("Replaying: " + cmd) // flush because maybe cmd will have its own output - command(cmd) - echo("") - } - } - private def resetCommand() { - echo("Resetting repl state.") - if (replayCommandStack.nonEmpty) { - echo("Forgetting this session history:\n") - replayCommands foreach echo - echo("") - replayCommandStack = Nil - } - if (intp.namedDefinedTerms.nonEmpty) - echo("Forgetting all expression results and named terms: " + intp.namedDefinedTerms.mkString(", ")) - if (intp.definedTypes.nonEmpty) - echo("Forgetting defined types: " + intp.definedTypes.mkString(", ")) - - reset() - } - - private def reset() { - intp.reset() - // unleashAndSetPhase() - } - - /** fork a shell and run a command */ - private lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { - override def usage = "" - def apply(line: String): Result = line match { - case "" => showUsage() - case _ => - val toRun = classOf[ProcessResult].getName + "(" + string2codeQuoted(line) + ")" - intp interpret toRun - () - } - } - - private def withFile(filename: String)(action: File => Unit) { - val f = File(filename) - - if (f.exists) action(f) - else echo("That file does not exist") - } - - private def loadCommand(arg: String) = { - var shouldReplay: Option[String] = None - withFile(arg)(f => { - interpretAllFrom(f) - shouldReplay = Some(":load " + arg) - }) - Result(true, shouldReplay) - } - - private def addAllClasspath(args: Seq[String]): Unit = { - var added = false - var totalClasspath = "" - for (arg <- args) { - val f = File(arg).normalize - if (f.exists) { - added = true - addedClasspath = ClassPath.join(addedClasspath, f.path) - totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath) - intp.addUrlsToClassPath(f.toURI.toURL) - sparkContext.addJar(f.toURI.toURL.getPath) - } - } - } - - private def addClasspath(arg: String): Unit = { - val f = File(arg).normalize - if (f.exists) { - addedClasspath = ClassPath.join(addedClasspath, f.path) - intp.addUrlsToClassPath(f.toURI.toURL) - sparkContext.addJar(f.toURI.toURL.getPath) - echo("Added '%s'. Your new classpath is:\n\"%s\"".format(f.path, intp.global.classPath.asClasspathString)) - } - else echo("The path '" + f + "' doesn't seem to exist.") - } - - - private def powerCmd(): Result = { - if (isReplPower) "Already in power mode." - else enablePowerMode(false) - } - - private[repl] def enablePowerMode(isDuringInit: Boolean) = { - // replProps.power setValue true - // unleashAndSetPhase() - // asyncEcho(isDuringInit, power.banner) - } - // private def unleashAndSetPhase() { -// if (isReplPower) { -// // power.unleash() -// // Set the phase to "typer" -// intp beSilentDuring phaseCommand("typer") -// } -// } - - private def asyncEcho(async: Boolean, msg: => String) { - if (async) asyncMessage(msg) - else echo(msg) - } - - private def verbosity() = { - // val old = intp.printResults - // intp.printResults = !old - // echo("Switched " + (if (old) "off" else "on") + " result printing.") - } - - /** - * Run one command submitted by the user. Two values are returned: - * (1) whether to keep running, (2) the line to record for replay, - * if any. - */ - private[repl] def command(line: String): Result = { - if (line startsWith ":") { - val cmd = line.tail takeWhile (x => !x.isWhitespace) - uniqueCommand(cmd) match { - case Some(lc) => lc(line.tail stripPrefix cmd dropWhile (_.isWhitespace)) - case _ => ambiguousError(cmd) - } - } - else if (intp.global == null) Result(false, None) // Notice failure to create compiler - else Result(true, interpretStartingWith(line)) - } - - private def readWhile(cond: String => Boolean) = { - Iterator continually in.readLine("") takeWhile (x => x != null && cond(x)) - } - - private def pasteCommand(): Result = { - echo("// Entering paste mode (ctrl-D to finish)\n") - val code = readWhile(_ => true) mkString "\n" - echo("\n// Exiting paste mode, now interpreting.\n") - intp interpret code - () - } - - private object paste extends Pasted { - val ContinueString = " | " - val PromptString = "scala> " - - def interpret(line: String): Unit = { - echo(line.trim) - intp interpret line - echo("") - } - - def transcript(start: String) = { - echo("\n// Detected repl transcript paste: ctrl-D to finish.\n") - apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim)) - } - } - import paste.{ ContinueString, PromptString } - - /** - * Interpret expressions starting with the first line. - * Read lines until a complete compilation unit is available - * or until a syntax error has been seen. If a full unit is - * read, go ahead and interpret it. Return the full string - * to be recorded for replay, if any. - */ - private def interpretStartingWith(code: String): Option[String] = { - // signal completion non-completion input has been received - in.completion.resetVerbosity() - - def reallyInterpret = { - val reallyResult = intp.interpret(code) - (reallyResult, reallyResult match { - case IR.Error => None - case IR.Success => Some(code) - case IR.Incomplete => - if (in.interactive && code.endsWith("\n\n")) { - echo("You typed two blank lines. Starting a new command.") - None - } - else in.readLine(ContinueString) match { - case null => - // we know compilation is going to fail since we're at EOF and the - // parser thinks the input is still incomplete, but since this is - // a file being read non-interactively we want to fail. So we send - // it straight to the compiler for the nice error message. - intp.compileString(code) - None - - case line => interpretStartingWith(code + "\n" + line) - } - }) - } - - /** Here we place ourselves between the user and the interpreter and examine - * the input they are ostensibly submitting. We intervene in several cases: - * - * 1) If the line starts with "scala> " it is assumed to be an interpreter paste. - * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation - * on the previous result. - * 3) If the Completion object's execute returns Some(_), we inject that value - * and avoid the interpreter, as it's likely not valid scala code. - */ - if (code == "") None - else if (!paste.running && code.trim.startsWith(PromptString)) { - paste.transcript(code) - None - } - else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") { - interpretStartingWith(intp.mostRecentVar + code) - } - else if (code.trim startsWith "//") { - // line comment, do nothing - None - } - else - reallyInterpret._2 - } - - // runs :load `file` on any files passed via -i - private def loadFiles(settings: Settings) = settings match { - case settings: SparkRunnerSettings => - for (filename <- settings.loadfiles.value) { - val cmd = ":load " + filename - command(cmd) - addReplay(cmd) - echo("") - } - case _ => - } - - /** Tries to create a JLineReader, falling back to SimpleReader: - * unless settings or properties are such that it should start - * with SimpleReader. - */ - private def chooseReader(settings: Settings): InteractiveReader = { - if (settings.Xnojline.value || Properties.isEmacsShell) - SimpleReader() - else try new SparkJLineReader( - if (settings.noCompletion.value) NoCompletion - else new SparkJLineCompletion(intp) - ) - catch { - case ex @ (_: Exception | _: NoClassDefFoundError) => - echo("Failed to created SparkJLineReader: " + ex + "\nFalling back to SimpleReader.") - SimpleReader() - } - } - - private val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe - private val m = u.runtimeMirror(Utils.getSparkClassLoader) - private def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] = - u.TypeTag[T]( - m, - new TypeCreator { - def apply[U <: ApiUniverse with Singleton](m: Mirror[U]): U # Type = - m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type] - }) - - private def process(settings: Settings): Boolean = savingContextLoader { - this.settings = settings - createInterpreter() - - // sets in to some kind of reader depending on environmental cues - in = in0 match { - case Some(reader) => SimpleReader(reader, out, true) - case None => - // some post-initialization - chooseReader(settings) match { - case x: SparkJLineReader => addThunk(x.consoleReader.postInit) ; x - case x => x - } - } - lazy val tagOfSparkIMain = tagOfStaticClass[org.apache.spark.repl.SparkIMain] - // Bind intp somewhere out of the regular namespace where - // we can get at it in generated code. - addThunk(intp.quietBind(NamedParam[SparkIMain]("$intp", intp)(tagOfSparkIMain, classTag[SparkIMain]))) - addThunk({ - import scala.tools.nsc.io._ - import Properties.userHome - import scala.compat.Platform.EOL - val autorun = replProps.replAutorunCode.option flatMap (f => io.File(f).safeSlurp()) - if (autorun.isDefined) intp.quietRun(autorun.get) - }) - - addThunk(printWelcome()) - addThunk(initializeSpark()) - - // it is broken on startup; go ahead and exit - if (intp.reporter.hasErrors) - return false - - // This is about the illusion of snappiness. We call initialize() - // which spins off a separate thread, then print the prompt and try - // our best to look ready. The interlocking lazy vals tend to - // inter-deadlock, so we break the cycle with a single asynchronous - // message to an rpcEndpoint. - if (isAsync) { - intp initialize initializedCallback() - createAsyncListener() // listens for signal to run postInitialization - } - else { - intp.initializeSynchronous() - postInitialization() - } - // printWelcome() - - loadFiles(settings) - - try loop() - catch AbstractOrMissingHandler() - finally closeInterpreter() - - true - } - - // NOTE: Must be public for visibility - @DeveloperApi - def createSparkSession(): SparkSession = { - val execUri = System.getenv("SPARK_EXECUTOR_URI") - val jars = getAddedJars() - val conf = new SparkConf() - .setMaster(getMaster()) - .setJars(jars) - .setIfMissing("spark.app.name", "Spark shell") - // SparkContext will detect this configuration and register it with the RpcEnv's - // file server, setting spark.repl.class.uri to the actual URI for executors to - // use. This is sort of ugly but since executors are started as part of SparkContext - // initialization in certain cases, there's an initialization order issue that prevents - // this from being set after SparkContext is instantiated. - .set("spark.repl.class.outputDir", intp.outputDir.getAbsolutePath()) - if (execUri != null) { - conf.set("spark.executor.uri", execUri) - } - - val builder = SparkSession.builder.config(conf) - val sparkSession = if (SparkSession.hiveClassesArePresent) { - logInfo("Creating Spark session with Hive support") - builder.enableHiveSupport().getOrCreate() - } else { - logInfo("Creating Spark session") - builder.getOrCreate() - } - sparkContext = sparkSession.sparkContext - sparkSession - } - - private def getMaster(): String = { - val master = this.master match { - case Some(m) => m - case None => - val envMaster = sys.env.get("MASTER") - val propMaster = sys.props.get("spark.master") - propMaster.orElse(envMaster).getOrElse("local[*]") - } - master - } - - /** process command-line arguments and do as they request */ - def process(args: Array[String]): Boolean = { - val command = new SparkCommandLine(args.toList, msg => echo(msg)) - def neededHelp(): String = - (if (command.settings.help.value) command.usageMsg + "\n" else "") + - (if (command.settings.Xhelp.value) command.xusageMsg + "\n" else "") - - // if they asked for no help and command is valid, we call the real main - neededHelp() match { - case "" => command.ok && process(command.settings) - case help => echoNoNL(help) ; true - } - } - - @deprecated("Use `process` instead", "2.9.0") - private def main(settings: Settings): Unit = process(settings) - - @DeveloperApi - def getAddedJars(): Array[String] = { - val conf = new SparkConf().setMaster(getMaster()) - val envJars = sys.env.get("ADD_JARS") - if (envJars.isDefined) { - logWarning("ADD_JARS environment variable is deprecated, use --jar spark submit argument instead") - } - val jars = { - val userJars = Utils.getUserJars(conf, isShell = true) - if (userJars.isEmpty) { - envJars.getOrElse("") - } else { - userJars.mkString(",") - } - } - Utils.resolveURIs(jars).split(",").filter(_.nonEmpty) - } - -} - -object SparkILoop extends Logging { - implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp - private def echo(msg: String) = Console println msg - - // Designed primarily for use by test code: take a String with a - // bunch of code, and prints out a transcript of what it would look - // like if you'd just typed it into the repl. - private[repl] def runForTranscript(code: String, settings: Settings): String = { - import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - - stringFromStream { ostream => - Console.withOut(ostream) { - val output = new JPrintWriter(new OutputStreamWriter(ostream), true) { - override def write(str: String) = { - // completely skip continuation lines - if (str forall (ch => ch.isWhitespace || ch == '|')) () - // print a newline on empty scala prompts - else if ((str contains '\n') && (str.trim == "scala> ")) super.write("\n") - else super.write(str) - } - } - val input = new BufferedReader(new StringReader(code)) { - override def readLine(): String = { - val s = super.readLine() - // helping out by printing the line being interpreted. - if (s != null) - // scalastyle:off println - output.println(s) - // scalastyle:on println - s - } - } - val repl = new SparkILoop(input, output) - - if (settings.classpath.isDefault) - settings.classpath.value = sys.props("java.class.path") - - repl.getAddedJars().map(jar => new URI(jar).getPath).foreach(settings.classpath.append(_)) - - repl process settings - } - } - } - - /** Creates an interpreter loop with default settings and feeds - * the given code to it as input. - */ - private[repl] def run(code: String, sets: Settings = new Settings): String = { - import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - - stringFromStream { ostream => - Console.withOut(ostream) { - val input = new BufferedReader(new StringReader(code)) - val output = new JPrintWriter(new OutputStreamWriter(ostream), true) - val repl = new ILoop(input, output) - - if (sets.classpath.isDefault) - sets.classpath.value = sys.props("java.class.path") - - repl process sets - } - } - } - private[repl] def run(lines: List[String]): String = run(lines map (_ + "\n") mkString) -} diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala deleted file mode 100644 index 5f0d92bccd809..0000000000000 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ /dev/null @@ -1,168 +0,0 @@ -// scalastyle:off - -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Paul Phillips - */ - -package org.apache.spark.repl - -import scala.tools.nsc._ -import scala.tools.nsc.interpreter._ - -import scala.tools.nsc.util.stackTraceString - -import org.apache.spark.SPARK_VERSION - -/** - * Machinery for the asynchronous initialization of the repl. - */ -private[repl] trait SparkILoopInit { - self: SparkILoop => - - /** Print a welcome message */ - def printWelcome() { - echo("""Welcome to - ____ __ - / __/__ ___ _____/ /__ - _\ \/ _ \/ _ `/ __/ '_/ - /___/ .__/\_,_/_/ /_/\_\ version %s - /_/ -""".format(SPARK_VERSION)) - import Properties._ - val welcomeMsg = "Using Scala %s (%s, Java %s)".format( - versionString, javaVmName, javaVersion) - echo(welcomeMsg) - echo("Type in expressions to have them evaluated.") - echo("Type :help for more information.") - } - - protected def asyncMessage(msg: String) { - if (isReplInfo || isReplPower) - echoAndRefresh(msg) - } - - private val initLock = new java.util.concurrent.locks.ReentrantLock() - private val initCompilerCondition = initLock.newCondition() // signal the compiler is initialized - private val initLoopCondition = initLock.newCondition() // signal the whole repl is initialized - private val initStart = System.nanoTime - - private def withLock[T](body: => T): T = { - initLock.lock() - try body - finally initLock.unlock() - } - // a condition used to ensure serial access to the compiler. - @volatile private var initIsComplete = false - @volatile private var initError: String = null - private def elapsed() = "%.3f".format((System.nanoTime - initStart).toDouble / 1000000000L) - - // the method to be called when the interpreter is initialized. - // Very important this method does nothing synchronous (i.e. do - // not try to use the interpreter) because until it returns, the - // repl's lazy val `global` is still locked. - protected def initializedCallback() = withLock(initCompilerCondition.signal()) - - // Spins off a thread which awaits a single message once the interpreter - // has been initialized. - protected def createAsyncListener() = { - io.spawn { - withLock(initCompilerCondition.await()) - asyncMessage("[info] compiler init time: " + elapsed() + " s.") - postInitialization() - } - } - - // called from main repl loop - protected def awaitInitialized(): Boolean = { - if (!initIsComplete) - withLock { while (!initIsComplete) initLoopCondition.await() } - if (initError != null) { - // scalastyle:off println - println(""" - |Failed to initialize the REPL due to an unexpected error. - |This is a bug, please, report it along with the error diagnostics printed below. - |%s.""".stripMargin.format(initError) - ) - // scalastyle:on println - false - } else true - } - // private def warningsThunks = List( - // () => intp.bind("lastWarnings", "" + typeTag[List[(Position, String)]], intp.lastWarnings _), - // ) - - protected def postInitThunks = List[Option[() => Unit]]( - Some(intp.setContextClassLoader _), - if (isReplPower) Some(() => enablePowerMode(true)) else None - ).flatten - // ++ ( - // warningsThunks - // ) - // called once after init condition is signalled - protected def postInitialization() { - try { - postInitThunks foreach (f => addThunk(f())) - runThunks() - } catch { - case ex: Throwable => - initError = stackTraceString(ex) - throw ex - } finally { - initIsComplete = true - - if (isAsync) { - asyncMessage("[info] total init time: " + elapsed() + " s.") - withLock(initLoopCondition.signal()) - } - } - } - - def initializeSpark() { - intp.beQuietDuring { - command(""" - @transient val spark = org.apache.spark.repl.Main.interp.createSparkSession() - @transient val sc = { - val _sc = spark.sparkContext - if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) { - val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null) - if (proxyUrl != null) { - println(s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}") - } else { - println(s"Spark Context Web UI is available at Spark Master Public URL") - } - } else { - _sc.uiWebUrl.foreach { - webUrl => println(s"Spark context Web UI available at ${webUrl}") - } - } - println("Spark context available as 'sc' " + - s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") - println("Spark session available as 'spark'.") - _sc - } - """) - command("import org.apache.spark.SparkContext._") - command("import spark.implicits._") - command("import spark.sql") - command("import org.apache.spark.sql.functions._") - } - } - - // code to be executed only after the interpreter is initialized - // and the lazy val `global` can be accessed without risk of deadlock. - private var pendingThunks: List[() => Unit] = Nil - protected def addThunk(body: => Unit) = synchronized { - pendingThunks :+= (() => body) - } - protected def runThunks(): Unit = synchronized { - if (pendingThunks.nonEmpty) - logDebug("Clearing " + pendingThunks.size + " thunks.") - - while (pendingThunks.nonEmpty) { - val thunk = pendingThunks.head - pendingThunks = pendingThunks.tail - thunk() - } - } -} diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala deleted file mode 100644 index 74a04d5a42bb2..0000000000000 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ /dev/null @@ -1,1808 +0,0 @@ -// scalastyle:off - -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Martin Odersky - */ - -package org.apache.spark.repl - -import java.io.File - -import scala.tools.nsc._ -import scala.tools.nsc.backend.JavaPlatform -import scala.tools.nsc.interpreter._ - -import Predef.{ println => _, _ } -import scala.tools.nsc.util.{MergedClassPath, stringFromWriter, ScalaClassLoader, stackTraceString} -import scala.reflect.internal.util._ -import java.net.URL -import scala.sys.BooleanProp -import io.{AbstractFile, PlainFile, VirtualDirectory} - -import reporters._ -import symtab.Flags -import scala.reflect.internal.Names -import scala.tools.util.PathResolver -import ScalaClassLoader.URLClassLoader -import scala.tools.nsc.util.Exceptional.unwrap -import scala.collection.{ mutable, immutable } -import scala.util.control.Exception.{ ultimately } -import SparkIMain._ -import java.util.concurrent.Future -import typechecker.Analyzer -import scala.language.implicitConversions -import scala.reflect.runtime.{ universe => ru } -import scala.reflect.{ ClassTag, classTag } -import scala.tools.reflect.StdRuntimeTags._ -import scala.util.control.ControlThrowable - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils -import org.apache.spark.annotation.DeveloperApi - -// /** directory to save .class files to */ -// private class ReplVirtualDirectory(out: JPrintWriter) extends VirtualDirectory("((memory))", None) { -// private def pp(root: AbstractFile, indentLevel: Int) { -// val spaces = " " * indentLevel -// out.println(spaces + root.name) -// if (root.isDirectory) -// root.toList sortBy (_.name) foreach (x => pp(x, indentLevel + 1)) -// } -// // print the contents hierarchically -// def show() = pp(this, 0) -// } - - /** An interpreter for Scala code. - * - * The main public entry points are compile(), interpret(), and bind(). - * The compile() method loads a complete Scala file. The interpret() method - * executes one line of Scala code at the request of the user. The bind() - * method binds an object to a variable that can then be used by later - * interpreted code. - * - * The overall approach is based on compiling the requested code and then - * using a Java classloader and Java reflection to run the code - * and access its results. - * - * In more detail, a single compiler instance is used - * to accumulate all successfully compiled or interpreted Scala code. To - * "interpret" a line of code, the compiler generates a fresh object that - * includes the line of code and which has public member(s) to export - * all variables defined by that code. To extract the result of an - * interpreted line to show the user, a second "result object" is created - * which imports the variables exported by the above object and then - * exports members called "$eval" and "$print". To accommodate user expressions - * that read from variables or methods defined in previous statements, "import" - * statements are used. - * - * This interpreter shares the strengths and weaknesses of using the - * full compiler-to-Java. The main strength is that interpreted code - * behaves exactly as does compiled code, including running at full speed. - * The main weakness is that redefining classes and methods is not handled - * properly, because rebinding at the Java level is technically difficult. - * - * @author Moez A. Abdel-Gawad - * @author Lex Spoon - */ - @DeveloperApi - class SparkIMain( - initialSettings: Settings, - val out: JPrintWriter, - propagateExceptions: Boolean = false) - extends SparkImports with Logging { imain => - - private val conf = new SparkConf() - - private val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") - /** Local directory to save .class files too */ - private[repl] val outputDir = { - val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf)) - Utils.createTempDir(root = rootDir, namePrefix = "repl") - } - if (SPARK_DEBUG_REPL) { - echo("Output directory: " + outputDir) - } - - /** - * Returns the path to the output directory containing all generated - * class files that will be served by the REPL class server. - */ - @DeveloperApi - lazy val getClassOutputDirectory = outputDir - - private val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles - /** Jetty server that will serve our classes to worker nodes */ - private var currentSettings: Settings = initialSettings - private var printResults = true // whether to print result lines - private var totalSilence = false // whether to print anything - private var _initializeComplete = false // compiler is initialized - private var _isInitialized: Future[Boolean] = null // set up initialization future - private var bindExceptions = true // whether to bind the lastException variable - private var _executionWrapper = "" // code to be wrapped around all lines - - /** We're going to go to some trouble to initialize the compiler asynchronously. - * It's critical that nothing call into it until it's been initialized or we will - * run into unrecoverable issues, but the perceived repl startup time goes - * through the roof if we wait for it. So we initialize it with a future and - * use a lazy val to ensure that any attempt to use the compiler object waits - * on the future. - */ - private var _classLoader: AbstractFileClassLoader = null // active classloader - private val _compiler: Global = newCompiler(settings, reporter) // our private compiler - - private trait ExposeAddUrl extends URLClassLoader { def addNewUrl(url: URL) = this.addURL(url) } - private var _runtimeClassLoader: URLClassLoader with ExposeAddUrl = null // wrapper exposing addURL - - private val nextReqId = { - var counter = 0 - () => { counter += 1 ; counter } - } - - private def compilerClasspath: Seq[URL] = ( - if (isInitializeComplete) global.classPath.asURLs - else new PathResolver(settings).result.asURLs // the compiler's classpath - ) - // NOTE: Exposed to repl package since accessed indirectly from SparkIMain - private[repl] def settings = currentSettings - private def mostRecentLine = prevRequestList match { - case Nil => "" - case req :: _ => req.originalLine - } - // Run the code body with the given boolean settings flipped to true. - private def withoutWarnings[T](body: => T): T = beQuietDuring { - val saved = settings.nowarn.value - if (!saved) - settings.nowarn.value = true - - try body - finally if (!saved) settings.nowarn.value = false - } - - /** construct an interpreter that reports to Console */ - def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true)) - def this() = this(new Settings()) - - private lazy val repllog: Logger = new Logger { - val out: JPrintWriter = imain.out - val isInfo: Boolean = BooleanProp keyExists "scala.repl.info" - val isDebug: Boolean = BooleanProp keyExists "scala.repl.debug" - val isTrace: Boolean = BooleanProp keyExists "scala.repl.trace" - } - private[repl] lazy val formatting: Formatting = new Formatting { - val prompt = Properties.shellPromptString - } - - // NOTE: Exposed to repl package since used by SparkExprTyper and SparkILoop - private[repl] lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this) - - /** - * Determines if errors were reported (typically during compilation). - * - * @note This is not for runtime errors - * - * @return True if had errors, otherwise false - */ - @DeveloperApi - def isReportingErrors = reporter.hasErrors - - import formatting._ - import reporter.{ printMessage, withoutTruncating } - - // This exists mostly because using the reporter too early leads to deadlock. - private def echo(msg: String) { Console println msg } - private def _initSources = List(new BatchSourceFile("", "class $repl_$init { }")) - private def _initialize() = { - try { - // todo. if this crashes, REPL will hang - new _compiler.Run() compileSources _initSources - _initializeComplete = true - true - } - catch AbstractOrMissingHandler() - } - private def tquoted(s: String) = "\"\"\"" + s + "\"\"\"" - - // argument is a thunk to execute after init is done - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] def initialize(postInitSignal: => Unit) { - synchronized { - if (_isInitialized == null) { - _isInitialized = io.spawn { - try _initialize() - finally postInitSignal - } - } - } - } - - /** - * Initializes the underlying compiler/interpreter in a blocking fashion. - * - * @note Must be executed before using SparkIMain! - */ - @DeveloperApi - def initializeSynchronous(): Unit = { - if (!isInitializeComplete) { - _initialize() - assert(global != null, global) - } - } - private def isInitializeComplete = _initializeComplete - - /** the public, go through the future compiler */ - - /** - * The underlying compiler used to generate ASTs and execute code. - */ - @DeveloperApi - lazy val global: Global = { - if (isInitializeComplete) _compiler - else { - // If init hasn't been called yet you're on your own. - if (_isInitialized == null) { - logWarning("Warning: compiler accessed before init set up. Assuming no postInit code.") - initialize(()) - } - // // blocks until it is ; false means catastrophic failure - if (_isInitialized.get()) _compiler - else null - } - } - @deprecated("Use `global` for access to the compiler instance.", "2.9.0") - private lazy val compiler: global.type = global - - import global._ - import definitions.{ScalaPackage, JavaLangPackage, termMember, typeMember} - import rootMirror.{RootClass, getClassIfDefined, getModuleIfDefined, getRequiredModule, getRequiredClass} - - private implicit class ReplTypeOps(tp: Type) { - def orElse(other: => Type): Type = if (tp ne NoType) tp else other - def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp) - } - - // TODO: If we try to make naming a lazy val, we run into big time - // scalac unhappiness with what look like cycles. It has not been easy to - // reduce, but name resolution clearly takes different paths. - // NOTE: Exposed to repl package since used by SparkExprTyper - private[repl] object naming extends { - val global: imain.global.type = imain.global - } with Naming { - // make sure we don't overwrite their unwisely named res3 etc. - def freshUserTermName(): TermName = { - val name = newTermName(freshUserVarName()) - if (definedNameMap contains name) freshUserTermName() - else name - } - def isUserTermName(name: Name) = isUserVarName("" + name) - def isInternalTermName(name: Name) = isInternalVarName("" + name) - } - import naming._ - - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] object deconstruct extends { - val global: imain.global.type = imain.global - } with StructuredTypeStrings - - // NOTE: Exposed to repl package since used by SparkImports - private[repl] lazy val memberHandlers = new { - val intp: imain.type = imain - } with SparkMemberHandlers - import memberHandlers._ - - /** - * Suppresses overwriting print results during the operation. - * - * @param body The block to execute - * @tparam T The return type of the block - * - * @return The result from executing the block - */ - @DeveloperApi - def beQuietDuring[T](body: => T): T = { - val saved = printResults - printResults = false - try body - finally printResults = saved - } - - /** - * Completely masks all output during the operation (minus JVM standard - * out and error). - * - * @param operation The block to execute - * @tparam T The return type of the block - * - * @return The result from executing the block - */ - @DeveloperApi - def beSilentDuring[T](operation: => T): T = { - val saved = totalSilence - totalSilence = true - try operation - finally totalSilence = saved - } - - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] def quietRun[T](code: String) = beQuietDuring(interpret(code)) - - private def logAndDiscard[T](label: String, alt: => T): PartialFunction[Throwable, T] = { - case t: ControlThrowable => throw t - case t: Throwable => - logDebug(label + ": " + unwrap(t)) - logDebug(stackTraceString(unwrap(t))) - alt - } - /** takes AnyRef because it may be binding a Throwable or an Exceptional */ - - private def withLastExceptionLock[T](body: => T, alt: => T): T = { - assert(bindExceptions, "withLastExceptionLock called incorrectly.") - bindExceptions = false - - try beQuietDuring(body) - catch logAndDiscard("withLastExceptionLock", alt) - finally bindExceptions = true - } - - /** - * Contains the code (in string form) representing a wrapper around all - * code executed by this instance. - * - * @return The wrapper code as a string - */ - @DeveloperApi - def executionWrapper = _executionWrapper - - /** - * Sets the code to use as a wrapper around all code executed by this - * instance. - * - * @param code The wrapper code as a string - */ - @DeveloperApi - def setExecutionWrapper(code: String) = _executionWrapper = code - - /** - * Clears the code used as a wrapper around all code executed by - * this instance. - */ - @DeveloperApi - def clearExecutionWrapper() = _executionWrapper = "" - - /** interpreter settings */ - private lazy val isettings = new SparkISettings(this) - - /** - * Instantiates a new compiler used by SparkIMain. Overridable to provide - * own instance of a compiler. - * - * @param settings The settings to provide the compiler - * @param reporter The reporter to use for compiler output - * - * @return The compiler as a Global - */ - @DeveloperApi - protected def newCompiler(settings: Settings, reporter: Reporter): ReplGlobal = { - settings.outputDirs setSingleOutput virtualDirectory - settings.exposeEmptyPackage.value = true - new Global(settings, reporter) with ReplGlobal { - override def toString: String = "" - } - } - - /** - * Adds any specified jars to the compile and runtime classpaths. - * - * @note Currently only supports jars, not directories - * @param urls The list of items to add to the compile and runtime classpaths - */ - @DeveloperApi - def addUrlsToClassPath(urls: URL*): Unit = { - new Run // Needed to force initialization of "something" to correctly load Scala classes from jars - urls.foreach(_runtimeClassLoader.addNewUrl) // Add jars/classes to runtime for execution - updateCompilerClassPath(urls: _*) // Add jars/classes to compile time for compiling - } - - private def updateCompilerClassPath(urls: URL*): Unit = { - require(!global.forMSIL) // Only support JavaPlatform - - val platform = global.platform.asInstanceOf[JavaPlatform] - - val newClassPath = mergeUrlsIntoClassPath(platform, urls: _*) - - // NOTE: Must use reflection until this is exposed/fixed upstream in Scala - val fieldSetter = platform.getClass.getMethods - .find(_.getName.endsWith("currentClassPath_$eq")).get - fieldSetter.invoke(platform, Some(newClassPath)) - - // Reload all jars specified into our compiler - global.invalidateClassPathEntries(urls.map(_.getPath): _*) - } - - private def mergeUrlsIntoClassPath(platform: JavaPlatform, urls: URL*): MergedClassPath[AbstractFile] = { - // Collect our new jars/directories and add them to the existing set of classpaths - val allClassPaths = ( - platform.classPath.asInstanceOf[MergedClassPath[AbstractFile]].entries ++ - urls.map(url => { - platform.classPath.context.newClassPath( - if (url.getProtocol == "file") { - val f = new File(url.getPath) - if (f.isDirectory) - io.AbstractFile.getDirectory(f) - else - io.AbstractFile.getFile(f) - } else { - io.AbstractFile.getURL(url) - } - ) - }) - ).distinct - - // Combine all of our classpaths (old and new) into one merged classpath - new MergedClassPath(allClassPaths, platform.classPath.context) - } - - /** - * Represents the parent classloader used by this instance. Can be - * overridden to provide alternative classloader. - * - * @return The classloader used as the parent loader of this instance - */ - @DeveloperApi - protected def parentClassLoader: ClassLoader = - SparkHelper.explicitParentLoader(settings).getOrElse( this.getClass.getClassLoader() ) - - /* A single class loader is used for all commands interpreted by this Interpreter. - It would also be possible to create a new class loader for each command - to interpret. The advantages of the current approach are: - - - Expressions are only evaluated one time. This is especially - significant for I/O, e.g. "val x = Console.readLine" - - The main disadvantage is: - - - Objects, classes, and methods cannot be rebound. Instead, definitions - shadow the old ones, and old code objects refer to the old - definitions. - */ - private def resetClassLoader() = { - logDebug("Setting new classloader: was " + _classLoader) - _classLoader = null - ensureClassLoader() - } - private final def ensureClassLoader() { - if (_classLoader == null) - _classLoader = makeClassLoader() - } - - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] def classLoader: AbstractFileClassLoader = { - ensureClassLoader() - _classLoader - } - private class TranslatingClassLoader(parent: ClassLoader) extends AbstractFileClassLoader(virtualDirectory, parent) { - /** Overridden here to try translating a simple name to the generated - * class name if the original attempt fails. This method is used by - * getResourceAsStream as well as findClass. - */ - override protected def findAbstractFile(name: String): AbstractFile = { - super.findAbstractFile(name) match { - // deadlocks on startup if we try to translate names too early - case null if isInitializeComplete => - generatedName(name) map (x => super.findAbstractFile(x)) orNull - case file => - file - } - } - } - private def makeClassLoader(): AbstractFileClassLoader = - new TranslatingClassLoader(parentClassLoader match { - case null => ScalaClassLoader fromURLs compilerClasspath - case p => - _runtimeClassLoader = new URLClassLoader(compilerClasspath, p) with ExposeAddUrl - _runtimeClassLoader - }) - - private def getInterpreterClassLoader() = classLoader - - // Set the current Java "context" class loader to this interpreter's class loader - // NOTE: Exposed to repl package since used by SparkILoopInit - private[repl] def setContextClassLoader() = classLoader.setAsContext() - - /** - * Returns the real name of a class based on its repl-defined name. - * - * ==Example== - * Given a simple repl-defined name, returns the real name of - * the class representing it, e.g. for "Bippy" it may return - * {{{ - * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy - * }}} - * - * @param simpleName The repl-defined name whose real name to retrieve - * - * @return Some real name if the simple name exists, else None - */ - @DeveloperApi - def generatedName(simpleName: String): Option[String] = { - if (simpleName endsWith nme.MODULE_SUFFIX_STRING) optFlatName(simpleName.init) map (_ + nme.MODULE_SUFFIX_STRING) - else optFlatName(simpleName) - } - - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] def flatName(id: String) = optFlatName(id) getOrElse id - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id) - - /** - * Retrieves all simple names contained in the current instance. - * - * @return A list of sorted names - */ - @DeveloperApi - def allDefinedNames = definedNameMap.keys.toList.sorted - - private def pathToType(id: String): String = pathToName(newTypeName(id)) - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] def pathToTerm(id: String): String = pathToName(newTermName(id)) - - /** - * Retrieves the full code path to access the specified simple name - * content. - * - * @param name The simple name of the target whose path to determine - * - * @return The full path used to access the specified target (name) - */ - @DeveloperApi - def pathToName(name: Name): String = { - if (definedNameMap contains name) - definedNameMap(name) fullPath name - else name.toString - } - - /** Most recent tree handled which wasn't wholly synthetic. */ - private def mostRecentlyHandledTree: Option[Tree] = { - prevRequests.reverse foreach { req => - req.handlers.reverse foreach { - case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member) - case _ => () - } - } - None - } - - /** Stubs for work in progress. */ - private def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = { - for (t1 <- old.simpleNameOfType(name) ; t2 <- req.simpleNameOfType(name)) { - logDebug("Redefining type '%s'\n %s -> %s".format(name, t1, t2)) - } - } - - private def handleTermRedefinition(name: TermName, old: Request, req: Request) = { - for (t1 <- old.compilerTypeOf get name ; t2 <- req.compilerTypeOf get name) { - // Printing the types here has a tendency to cause assertion errors, like - // assertion failed: fatal: has owner value x, but a class owner is required - // so DBG is by-name now to keep it in the family. (It also traps the assertion error, - // but we don't want to unnecessarily risk hosing the compiler's internal state.) - logDebug("Redefining term '%s'\n %s -> %s".format(name, t1, t2)) - } - } - - private def recordRequest(req: Request) { - if (req == null || referencedNameMap == null) - return - - prevRequests += req - req.referencedNames foreach (x => referencedNameMap(x) = req) - - // warning about serially defining companions. It'd be easy - // enough to just redefine them together but that may not always - // be what people want so I'm waiting until I can do it better. - for { - name <- req.definedNames filterNot (x => req.definedNames contains x.companionName) - oldReq <- definedNameMap get name.companionName - newSym <- req.definedSymbols get name - oldSym <- oldReq.definedSymbols get name.companionName - if Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule } - } { - afterTyper(replwarn(s"warning: previously defined $oldSym is not a companion to $newSym.")) - replwarn("Companions must be defined together; you may wish to use :paste mode for this.") - } - - // Updating the defined name map - req.definedNames foreach { name => - if (definedNameMap contains name) { - if (name.isTypeName) handleTypeRedefinition(name.toTypeName, definedNameMap(name), req) - else handleTermRedefinition(name.toTermName, definedNameMap(name), req) - } - definedNameMap(name) = req - } - } - - private def replwarn(msg: => String) { - if (!settings.nowarnings.value) - printMessage(msg) - } - - private def isParseable(line: String): Boolean = { - beSilentDuring { - try parse(line) match { - case Some(xs) => xs.nonEmpty // parses as-is - case None => true // incomplete - } - catch { case x: Exception => // crashed the compiler - replwarn("Exception in isParseable(\"" + line + "\"): " + x) - false - } - } - } - - private def compileSourcesKeepingRun(sources: SourceFile*) = { - val run = new Run() - reporter.reset() - run compileSources sources.toList - (!reporter.hasErrors, run) - } - - /** - * Compiles specified source files. - * - * @param sources The sequence of source files to compile - * - * @return True if successful, otherwise false - */ - @DeveloperApi - def compileSources(sources: SourceFile*): Boolean = - compileSourcesKeepingRun(sources: _*)._1 - - /** - * Compiles a string of code. - * - * @param code The string of code to compile - * - * @return True if successful, otherwise false - */ - @DeveloperApi - def compileString(code: String): Boolean = - compileSources(new BatchSourceFile(" - UIUtils.headerSparkPage("SQL", content, parent, Some(5000)) + val summary: NodeSeq = +
+
    + { + if (listener.getRunningExecutions.nonEmpty) { +
  • + Running Queries: + {listener.getRunningExecutions.size} +
  • + } + } + { + if (listener.getCompletedExecutions.nonEmpty) { +
  • + Completed Queries: + {listener.getCompletedExecutions.size} +
  • + } + } + { + if (listener.getFailedExecutions.nonEmpty) { +
  • + Failed Queries: + {listener.getFailedExecutions.size} +
  • + } + } +
+
+ UIUtils.headerSparkPage("SQL", summary ++ content, parent, Some(5000)) } } From 12e740bba110c6ab017c73c5ef940cce39dd45b7 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 27 Sep 2017 23:19:10 +0900 Subject: [PATCH 630/779] [SPARK-22130][CORE] UTF8String.trim() scans " " twice ## What changes were proposed in this pull request? This PR allows us to scan a string including only white space (e.g. `" "`) once while the current implementation scans twice (right to left, and then left to right). ## How was this patch tested? Existing test suites Author: Kazuaki Ishizaki Closes #19355 from kiszk/SPARK-22130. --- .../org/apache/spark/unsafe/types/UTF8String.java | 11 +++++------ .../apache/spark/unsafe/types/UTF8StringSuite.java | 3 +++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index ce4a06bde80c4..b0d0c44823e68 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -498,17 +498,16 @@ private UTF8String copyUTF8String(int start, int end) { public UTF8String trim() { int s = 0; - int e = this.numBytes - 1; // skip all of the space (0x20) in the left side while (s < this.numBytes && getByte(s) == 0x20) s++; - // skip all of the space (0x20) in the right side - while (e >= 0 && getByte(e) == 0x20) e--; - if (s > e) { + if (s == this.numBytes) { // empty string return EMPTY_UTF8; - } else { - return copyUTF8String(s, e); } + // skip all of the space (0x20) in the right side + int e = this.numBytes - 1; + while (e > s && getByte(e) == 0x20) e--; + return copyUTF8String(s, e); } /** diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 7b03d2c650fc9..9b303fa5bc6c5 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -222,10 +222,13 @@ public void substring() { @Test public void trims() { + assertEquals(fromString("1"), fromString("1").trim()); + assertEquals(fromString("hello"), fromString(" hello ").trim()); assertEquals(fromString("hello "), fromString(" hello ").trimLeft()); assertEquals(fromString(" hello"), fromString(" hello ").trimRight()); + assertEquals(EMPTY_UTF8, EMPTY_UTF8.trim()); assertEquals(EMPTY_UTF8, fromString(" ").trim()); assertEquals(EMPTY_UTF8, fromString(" ").trimLeft()); assertEquals(EMPTY_UTF8, fromString(" ").trimRight()); From 09cbf3df20efea09c0941499249b7a3b2bf7e9fd Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 27 Sep 2017 23:21:44 +0900 Subject: [PATCH 631/779] [SPARK-22125][PYSPARK][SQL] Enable Arrow Stream format for vectorized UDF. ## What changes were proposed in this pull request? Currently we use Arrow File format to communicate with Python worker when invoking vectorized UDF but we can use Arrow Stream format. This pr replaces the Arrow File format with the Arrow Stream format. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #19349 from ueshin/issues/SPARK-22125. --- .../apache/spark/api/python/PythonRDD.scala | 325 +------------ .../spark/api/python/PythonRunner.scala | 441 ++++++++++++++++++ python/pyspark/serializers.py | 70 +-- python/pyspark/worker.py | 4 +- .../execution/vectorized/ColumnarBatch.java | 5 + .../python/ArrowEvalPythonExec.scala | 54 ++- .../execution/python/ArrowPythonRunner.scala | 181 +++++++ .../python/BatchEvalPythonExec.scala | 4 +- .../execution/python/PythonUDFRunner.scala | 113 +++++ 9 files changed, 825 insertions(+), 372 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 86d0405c678a7..f6293c0dc5091 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -48,7 +48,7 @@ private[spark] class PythonRDD( extends RDD[Array[Byte]](parent) { val bufferSize = conf.getInt("spark.buffer.size", 65536) - val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) + val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) override def getPartitions: Array[Partition] = firstParent.partitions @@ -59,7 +59,7 @@ private[spark] class PythonRDD( val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val runner = PythonRunner(func, bufferSize, reuse_worker) + val runner = PythonRunner(func, bufferSize, reuseWorker) runner.compute(firstParent.iterator(split, context), split.index, context) } } @@ -83,318 +83,9 @@ private[spark] case class PythonFunction( */ private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction]) -/** - * Enumerate the type of command that will be sent to the Python worker - */ -private[spark] object PythonEvalType { - val NON_UDF = 0 - val SQL_BATCHED_UDF = 1 - val SQL_PANDAS_UDF = 2 -} - -private[spark] object PythonRunner { - def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = { - new PythonRunner( - Seq(ChainedPythonFunctions(Seq(func))), - bufferSize, - reuse_worker, - PythonEvalType.NON_UDF, - Array(Array(0))) - } -} - -/** - * A helper class to run Python mapPartition/UDFs in Spark. - * - * funcs is a list of independent Python functions, each one of them is a list of chained Python - * functions (from bottom to top). - */ -private[spark] class PythonRunner( - funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuse_worker: Boolean, - evalType: Int, - argOffsets: Array[Array[Int]]) - extends Logging { - - require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") - - // All the Python functions should have the same exec, version and envvars. - private val envVars = funcs.head.funcs.head.envVars - private val pythonExec = funcs.head.funcs.head.pythonExec - private val pythonVer = funcs.head.funcs.head.pythonVer - - // TODO: support accumulator in multiple UDF - private val accumulator = funcs.head.funcs.head.accumulator - - def compute( - inputIterator: Iterator[_], - partitionIndex: Int, - context: TaskContext): Iterator[Array[Byte]] = { - val startTime = System.currentTimeMillis - val env = SparkEnv.get - val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") - envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread - if (reuse_worker) { - envVars.put("SPARK_REUSE_WORKER", "1") - } - val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) - // Whether is the worker released into idle pool - @volatile var released = false - - // Start a thread to feed the process input from our parent's iterator - val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context) - - context.addTaskCompletionListener { context => - writerThread.shutdownOnTaskCompletion() - if (!reuse_worker || !released) { - try { - worker.close() - } catch { - case e: Exception => - logWarning("Failed to close worker socket", e) - } - } - } - - writerThread.start() - new MonitorThread(env, worker, context).start() - - // Return an iterator that read lines from the process's stdout - val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) - val stdoutIterator = new Iterator[Array[Byte]] { - override def next(): Array[Byte] = { - val obj = _nextObj - if (hasNext) { - _nextObj = read() - } - obj - } - - private def read(): Array[Byte] = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - try { - stream.readInt() match { - case length if length > 0 => - val obj = new Array[Byte](length) - stream.readFully(obj) - obj - case 0 => Array.empty[Byte] - case SpecialLengths.TIMING_DATA => - // Timing data from worker - val bootTime = stream.readLong() - val initTime = stream.readLong() - val finishTime = stream.readLong() - val boot = bootTime - startTime - val init = initTime - bootTime - val finish = finishTime - initTime - val total = finishTime - startTime - logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, - init, finish)) - val memoryBytesSpilled = stream.readLong() - val diskBytesSpilled = stream.readLong() - context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - // Signals that an exception has been thrown in python - val exLength = stream.readInt() - val obj = new Array[Byte](exLength) - stream.readFully(obj) - throw new PythonException(new String(obj, StandardCharsets.UTF_8), - writerThread.exception.getOrElse(null)) - case SpecialLengths.END_OF_DATA_SECTION => - // We've finished the data section of the output, but we can still - // read some accumulator updates: - val numAccumulatorUpdates = stream.readInt() - (1 to numAccumulatorUpdates).foreach { _ => - val updateLen = stream.readInt() - val update = new Array[Byte](updateLen) - stream.readFully(update) - accumulator.add(update) - } - // Check whether the worker is ready to be re-used. - if (stream.readInt() == SpecialLengths.END_OF_STREAM) { - if (reuse_worker) { - env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) - released = true - } - } - null - } - } catch { - - case e: Exception if context.isInterrupted => - logDebug("Exception thrown after task interruption", e) - throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) - - case e: Exception if env.isStopped => - logDebug("Exception thrown after context is stopped", e) - null // exit silently - - case e: Exception if writerThread.exception.isDefined => - logError("Python worker exited unexpectedly (crashed)", e) - logError("This may have been caused by a prior exception:", writerThread.exception.get) - throw writerThread.exception.get - - case eof: EOFException => - throw new SparkException("Python worker exited unexpectedly (crashed)", eof) - } - } - - var _nextObj = read() - - override def hasNext: Boolean = _nextObj != null - } - new InterruptibleIterator(context, stdoutIterator) - } - - /** - * The thread responsible for writing the data from the PythonRDD's parent iterator to the - * Python process. - */ - class WriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[_], - partitionIndex: Int, - context: TaskContext) - extends Thread(s"stdout writer for $pythonExec") { - - @volatile private var _exception: Exception = null - - private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet - private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) - - setDaemon(true) - - /** Contains the exception thrown while writing the parent iterator to the Python process. */ - def exception: Option[Exception] = Option(_exception) - - /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */ - def shutdownOnTaskCompletion() { - assert(context.isCompleted) - this.interrupt() - } - - override def run(): Unit = Utils.logUncaughtExceptions { - try { - TaskContext.setTaskContext(context) - val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) - val dataOut = new DataOutputStream(stream) - // Partition index - dataOut.writeInt(partitionIndex) - // Python version of driver - PythonRDD.writeUTF(pythonVer, dataOut) - // Write out the TaskContextInfo - dataOut.writeInt(context.stageId()) - dataOut.writeInt(context.partitionId()) - dataOut.writeInt(context.attemptNumber()) - dataOut.writeLong(context.taskAttemptId()) - // sparkFilesDir - PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) - // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.size) - for (include <- pythonIncludes) { - PythonRDD.writeUTF(include, dataOut) - } - // Broadcast variables - val oldBids = PythonRDD.getWorkerBroadcasts(worker) - val newBids = broadcastVars.map(_.id).toSet - // number of different broadcasts - val toRemove = oldBids.diff(newBids) - val cnt = toRemove.size + newBids.diff(oldBids).size - dataOut.writeInt(cnt) - for (bid <- toRemove) { - // remove the broadcast from worker - dataOut.writeLong(- bid - 1) // bid >= 0 - oldBids.remove(bid) - } - for (broadcast <- broadcastVars) { - if (!oldBids.contains(broadcast.id)) { - // send new broadcast - dataOut.writeLong(broadcast.id) - PythonRDD.writeUTF(broadcast.value.path, dataOut) - oldBids.add(broadcast.id) - } - } - dataOut.flush() - // Serialized command: - dataOut.writeInt(evalType) - if (evalType != PythonEvalType.NON_UDF) { - dataOut.writeInt(funcs.length) - funcs.zip(argOffsets).foreach { case (chained, offsets) => - dataOut.writeInt(offsets.length) - offsets.foreach { offset => - dataOut.writeInt(offset) - } - dataOut.writeInt(chained.funcs.length) - chained.funcs.foreach { f => - dataOut.writeInt(f.command.length) - dataOut.write(f.command) - } - } - } else { - val command = funcs.head.funcs.head.command - dataOut.writeInt(command.length) - dataOut.write(command) - } - // Data values - PythonRDD.writeIteratorToStream(inputIterator, dataOut) - dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) - dataOut.writeInt(SpecialLengths.END_OF_STREAM) - dataOut.flush() - } catch { - case e: Exception if context.isCompleted || context.isInterrupted => - logDebug("Exception thrown after task completion (likely due to cleanup)", e) - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) - } - - case e: Exception => - // We must avoid throwing exceptions here, because the thread uncaught exception handler - // will kill the whole executor (see org.apache.spark.executor.Executor). - _exception = e - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) - } - } - } - } - - /** - * It is necessary to have a monitor thread for python workers if the user cancels with - * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the - * threads can block indefinitely. - */ - class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext) - extends Thread(s"Worker Monitor for $pythonExec") { - - setDaemon(true) - - override def run() { - // Kill the worker if it is interrupted, checking until task completion. - // TODO: This has a race condition if interruption occurs, as completed may still become true. - while (!context.isInterrupted && !context.isCompleted) { - Thread.sleep(2000) - } - if (!context.isCompleted) { - try { - logWarning("Incomplete task interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) - } catch { - case e: Exception => - logError("Exception when trying to kill worker", e) - } - } - } - } -} - /** Thrown for exceptions in user Python code. */ -private class PythonException(msg: String, cause: Exception) extends RuntimeException(msg, cause) +private[spark] class PythonException(msg: String, cause: Exception) + extends RuntimeException(msg, cause) /** * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. @@ -411,14 +102,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte] val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this) } -private object SpecialLengths { - val END_OF_DATA_SECTION = -1 - val PYTHON_EXCEPTION_THROWN = -2 - val TIMING_DATA = -3 - val END_OF_STREAM = -4 - val NULL = -5 -} - private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala new file mode 100644 index 0000000000000..3688a149443c1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -0,0 +1,441 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import java.io._ +import java.net._ +import java.nio.charset.StandardCharsets +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.JavaConverters._ + +import org.apache.spark._ +import org.apache.spark.internal.Logging +import org.apache.spark.util._ + + +/** + * Enumerate the type of command that will be sent to the Python worker + */ +private[spark] object PythonEvalType { + val NON_UDF = 0 + val SQL_BATCHED_UDF = 1 + val SQL_PANDAS_UDF = 2 +} + +/** + * A helper class to run Python mapPartition/UDFs in Spark. + * + * funcs is a list of independent Python functions, each one of them is a list of chained Python + * functions (from bottom to top). + */ +private[spark] abstract class BasePythonRunner[IN, OUT]( + funcs: Seq[ChainedPythonFunctions], + bufferSize: Int, + reuseWorker: Boolean, + evalType: Int, + argOffsets: Array[Array[Int]]) + extends Logging { + + require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") + + // All the Python functions should have the same exec, version and envvars. + protected val envVars = funcs.head.funcs.head.envVars + protected val pythonExec = funcs.head.funcs.head.pythonExec + protected val pythonVer = funcs.head.funcs.head.pythonVer + + // TODO: support accumulator in multiple UDF + protected val accumulator = funcs.head.funcs.head.accumulator + + def compute( + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext): Iterator[OUT] = { + val startTime = System.currentTimeMillis + val env = SparkEnv.get + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") + envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread + if (reuseWorker) { + envVars.put("SPARK_REUSE_WORKER", "1") + } + val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) + // Whether is the worker released into idle pool + val released = new AtomicBoolean(false) + + // Start a thread to feed the process input from our parent's iterator + val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) + + context.addTaskCompletionListener { _ => + writerThread.shutdownOnTaskCompletion() + if (!reuseWorker || !released.get) { + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } + + writerThread.start() + new MonitorThread(env, worker, context).start() + + // Return an iterator that read lines from the process's stdout + val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) + + val stdoutIterator = newReaderIterator( + stream, writerThread, startTime, env, worker, released, context) + new InterruptibleIterator(context, stdoutIterator) + } + + protected def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext): WriterThread + + protected def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext): Iterator[OUT] + + /** + * The thread responsible for writing the data from the PythonRDD's parent iterator to the + * Python process. + */ + abstract class WriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext) + extends Thread(s"stdout writer for $pythonExec") { + + @volatile private var _exception: Exception = null + + private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet + private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) + + setDaemon(true) + + /** Contains the exception thrown while writing the parent iterator to the Python process. */ + def exception: Option[Exception] = Option(_exception) + + /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */ + def shutdownOnTaskCompletion() { + assert(context.isCompleted) + this.interrupt() + } + + /** + * Writes a command section to the stream connected to the Python worker. + */ + protected def writeCommand(dataOut: DataOutputStream): Unit + + /** + * Writes input data to the stream connected to the Python worker. + */ + protected def writeIteratorToStream(dataOut: DataOutputStream): Unit + + override def run(): Unit = Utils.logUncaughtExceptions { + try { + TaskContext.setTaskContext(context) + val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) + val dataOut = new DataOutputStream(stream) + // Partition index + dataOut.writeInt(partitionIndex) + // Python version of driver + PythonRDD.writeUTF(pythonVer, dataOut) + // Write out the TaskContextInfo + dataOut.writeInt(context.stageId()) + dataOut.writeInt(context.partitionId()) + dataOut.writeInt(context.attemptNumber()) + dataOut.writeLong(context.taskAttemptId()) + // sparkFilesDir + PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) + // Python includes (*.zip and *.egg files) + dataOut.writeInt(pythonIncludes.size) + for (include <- pythonIncludes) { + PythonRDD.writeUTF(include, dataOut) + } + // Broadcast variables + val oldBids = PythonRDD.getWorkerBroadcasts(worker) + val newBids = broadcastVars.map(_.id).toSet + // number of different broadcasts + val toRemove = oldBids.diff(newBids) + val cnt = toRemove.size + newBids.diff(oldBids).size + dataOut.writeInt(cnt) + for (bid <- toRemove) { + // remove the broadcast from worker + dataOut.writeLong(- bid - 1) // bid >= 0 + oldBids.remove(bid) + } + for (broadcast <- broadcastVars) { + if (!oldBids.contains(broadcast.id)) { + // send new broadcast + dataOut.writeLong(broadcast.id) + PythonRDD.writeUTF(broadcast.value.path, dataOut) + oldBids.add(broadcast.id) + } + } + dataOut.flush() + + dataOut.writeInt(evalType) + writeCommand(dataOut) + writeIteratorToStream(dataOut) + + dataOut.writeInt(SpecialLengths.END_OF_STREAM) + dataOut.flush() + } catch { + case e: Exception if context.isCompleted || context.isInterrupted => + logDebug("Exception thrown after task completion (likely due to cleanup)", e) + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } + + case e: Exception => + // We must avoid throwing exceptions here, because the thread uncaught exception handler + // will kill the whole executor (see org.apache.spark.executor.Executor). + _exception = e + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } + } + } + } + + abstract class ReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext) + extends Iterator[OUT] { + + private var nextObj: OUT = _ + private var eos = false + + override def hasNext: Boolean = nextObj != null || { + if (!eos) { + nextObj = read() + hasNext + } else { + false + } + } + + override def next(): OUT = { + if (hasNext) { + val obj = nextObj + nextObj = null.asInstanceOf[OUT] + obj + } else { + Iterator.empty.next() + } + } + + /** + * Reads next object from the stream. + * When the stream reaches end of data, needs to process the following sections, + * and then returns null. + */ + protected def read(): OUT + + protected def handleTimingData(): Unit = { + // Timing data from worker + val bootTime = stream.readLong() + val initTime = stream.readLong() + val finishTime = stream.readLong() + val boot = bootTime - startTime + val init = initTime - bootTime + val finish = finishTime - initTime + val total = finishTime - startTime + logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, + init, finish)) + val memoryBytesSpilled = stream.readLong() + val diskBytesSpilled = stream.readLong() + context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) + } + + protected def handlePythonException(): PythonException = { + // Signals that an exception has been thrown in python + val exLength = stream.readInt() + val obj = new Array[Byte](exLength) + stream.readFully(obj) + new PythonException(new String(obj, StandardCharsets.UTF_8), + writerThread.exception.getOrElse(null)) + } + + protected def handleEndOfDataSection(): Unit = { + // We've finished the data section of the output, but we can still + // read some accumulator updates: + val numAccumulatorUpdates = stream.readInt() + (1 to numAccumulatorUpdates).foreach { _ => + val updateLen = stream.readInt() + val update = new Array[Byte](updateLen) + stream.readFully(update) + accumulator.add(update) + } + // Check whether the worker is ready to be re-used. + if (stream.readInt() == SpecialLengths.END_OF_STREAM) { + if (reuseWorker) { + env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) + released.set(true) + } + } + eos = true + } + + protected val handleException: PartialFunction[Throwable, OUT] = { + case e: Exception if context.isInterrupted => + logDebug("Exception thrown after task interruption", e) + throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) + + case e: Exception if env.isStopped => + logDebug("Exception thrown after context is stopped", e) + null.asInstanceOf[OUT] // exit silently + + case e: Exception if writerThread.exception.isDefined => + logError("Python worker exited unexpectedly (crashed)", e) + logError("This may have been caused by a prior exception:", writerThread.exception.get) + throw writerThread.exception.get + + case eof: EOFException => + throw new SparkException("Python worker exited unexpectedly (crashed)", eof) + } + } + + /** + * It is necessary to have a monitor thread for python workers if the user cancels with + * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the + * threads can block indefinitely. + */ + class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext) + extends Thread(s"Worker Monitor for $pythonExec") { + + setDaemon(true) + + override def run() { + // Kill the worker if it is interrupted, checking until task completion. + // TODO: This has a race condition if interruption occurs, as completed may still become true. + while (!context.isInterrupted && !context.isCompleted) { + Thread.sleep(2000) + } + if (!context.isCompleted) { + try { + logWarning("Incomplete task interrupted: Attempting to kill Python Worker") + env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) + } catch { + case e: Exception => + logError("Exception when trying to kill worker", e) + } + } + } + } +} + +private[spark] object PythonRunner { + + def apply(func: PythonFunction, bufferSize: Int, reuseWorker: Boolean): PythonRunner = { + new PythonRunner(Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuseWorker) + } +} + +/** + * A helper class to run Python mapPartition in Spark. + */ +private[spark] class PythonRunner( + funcs: Seq[ChainedPythonFunctions], + bufferSize: Int, + reuseWorker: Boolean) + extends BasePythonRunner[Array[Byte], Array[Byte]]( + funcs, bufferSize, reuseWorker, PythonEvalType.NON_UDF, Array(Array(0))) { + + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[Array[Byte]], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + val command = funcs.head.funcs.head.command + dataOut.writeInt(command.length) + dataOut.write(command) + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + PythonRDD.writeIteratorToStream(inputIterator, dataOut) + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + } + } + } + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext): Iterator[Array[Byte]] = { + new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) { + + protected override def read(): Array[Byte] = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + stream.readInt() match { + case length if length > 0 => + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + case 0 => Array.empty[Byte] + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } catch handleException + } + } + } +} + +private[spark] object SpecialLengths { + val END_OF_DATA_SECTION = -1 + val PYTHON_EXCEPTION_THROWN = -2 + val TIMING_DATA = -3 + val END_OF_STREAM = -4 + val NULL = -5 + val START_ARROW_STREAM = -6 +} diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 7c1fbadcb82be..db77b7e150b24 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -79,6 +79,7 @@ class SpecialLengths(object): TIMING_DATA = -3 END_OF_STREAM = -4 NULL = -5 + START_ARROW_STREAM = -6 class PythonEvalType(object): @@ -211,44 +212,61 @@ def __repr__(self): return "ArrowSerializer" -class ArrowPandasSerializer(ArrowSerializer): +def _create_batch(series): + import pyarrow as pa + # Make input conform to [(series1, type1), (series2, type2), ...] + if not isinstance(series, (list, tuple)) or \ + (len(series) == 2 and isinstance(series[1], pa.DataType)): + series = [series] + series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) + + # If a nullable integer series has been promoted to floating point with NaNs, need to cast + # NOTE: this is not necessary with Arrow >= 0.7 + def cast_series(s, t): + if t is None or s.dtype == t.to_pandas_dtype(): + return s + else: + return s.fillna(0).astype(t.to_pandas_dtype(), copy=False) + + arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series] + return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) + + +class ArrowStreamPandasSerializer(Serializer): """ - Serializes Pandas.Series as Arrow data. + Serializes Pandas.Series as Arrow data with Arrow streaming format. """ - def dumps(self, series): + def dump_stream(self, iterator, stream): """ - Make an ArrowRecordBatch from a Pandas Series and serialize. Input is a single series or + Make ArrowRecordBatches from Pandas Serieses and serialize. Input is a single series or a list of series accompanied by an optional pyarrow type to coerce the data to. """ import pyarrow as pa - # Make input conform to [(series1, type1), (series2, type2), ...] - if not isinstance(series, (list, tuple)) or \ - (len(series) == 2 and isinstance(series[1], pa.DataType)): - series = [series] - series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - - # If a nullable integer series has been promoted to floating point with NaNs, need to cast - # NOTE: this is not necessary with Arrow >= 0.7 - def cast_series(s, t): - if t is None or s.dtype == t.to_pandas_dtype(): - return s - else: - return s.fillna(0).astype(t.to_pandas_dtype(), copy=False) - - arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series] - batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) - return super(ArrowPandasSerializer, self).dumps(batch) + writer = None + try: + for series in iterator: + batch = _create_batch(series) + if writer is None: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + writer = pa.RecordBatchStreamWriter(stream, batch.schema) + writer.write_batch(batch) + finally: + if writer is not None: + writer.close() - def loads(self, obj): + def load_stream(self, stream): """ - Deserialize an ArrowRecordBatch to an Arrow table and return as a list of pandas.Series. + Deserialize ArrowRecordBatchs to an Arrow table and return as a list of pandas.Series. """ - table = super(ArrowPandasSerializer, self).loads(obj) - return [c.to_pandas() for c in table.itercolumns()] + import pyarrow as pa + reader = pa.open_stream(stream) + for batch in reader: + table = pa.Table.from_batches([batch]) + yield [c.to_pandas() for c in table.itercolumns()] def __repr__(self): - return "ArrowPandasSerializer" + return "ArrowStreamPandasSerializer" class BatchedSerializer(Serializer): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index fd917c400c872..4e24789cf010d 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -31,7 +31,7 @@ from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ - BatchedSerializer, ArrowPandasSerializer + BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import toArrowType from pyspark import shuffle @@ -123,7 +123,7 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) if eval_type == PythonEvalType.SQL_PANDAS_UDF: - ser = ArrowPandasSerializer() + ser = ArrowStreamPandasSerializer() else: ser = BatchedSerializer(PickleSerializer(), 100) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index e782756a3e781..bc546c7c425b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -462,6 +462,11 @@ public int numValidRows() { return numRows - numRowsFiltered; } + /** + * Returns the schema that makes up this batch. + */ + public StructType schema() { return schema; } + /** * Returns the max capacity (in number of rows) for this batch. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 5e72cd255873a..f7e8cbe416121 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.execution.python +import scala.collection.JavaConverters._ + import org.apache.spark.TaskContext -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload} import org.apache.spark.sql.types.StructType /** @@ -39,25 +40,36 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi iter: Iterator[InternalRow], schema: StructType, context: TaskContext): Iterator[InternalRow] = { - val inputIterator = ArrowConverters.toPayloadIterator( - iter, schema, conf.arrowMaxRecordsPerBatch, context).map(_.asPythonSerializable) - - // Output iterator for results from Python. - val outputIterator = new PythonRunner( - funcs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets) - .compute(inputIterator, context.partitionId(), context) - - val outputRowIterator = ArrowConverters.fromPayloadIterator( - outputIterator.map(new ArrowPayload(_)), context) - - // Verify that the output schema is correct - if (outputRowIterator.hasNext) { - val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex - .map { case (attr, i) => attr.withName(s"_$i") }) - assert(schemaOut.equals(outputRowIterator.schema), - s"Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}") - } - outputRowIterator + val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex + .map { case (attr, i) => attr.withName(s"_$i") }) + + val columnarBatchIter = new ArrowPythonRunner( + funcs, conf.arrowMaxRecordsPerBatch, bufferSize, reuseWorker, + PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) + .compute(iter, context.partitionId(), context) + + new Iterator[InternalRow] { + + var currentIter = if (columnarBatchIter.hasNext) { + val batch = columnarBatchIter.next() + assert(schemaOut.equals(batch.schema), + s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}") + batch.rowIterator.asScala + } else { + Iterator.empty + } + + override def hasNext: Boolean = currentIter.hasNext || { + if (columnarBatchIter.hasNext) { + currentIter = columnarBatchIter.next().rowIterator.asScala + hasNext + } else { + false + } + } + + override def next(): InternalRow = currentIter.next() + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala new file mode 100644 index 0000000000000..bbad9d6b631fd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io._ +import java.net._ +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.stream.{ArrowStreamReader, ArrowStreamWriter} + +import org.apache.spark._ +import org.apache.spark.api.python._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter} +import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +/** + * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream. + */ +class ArrowPythonRunner( + funcs: Seq[ChainedPythonFunctions], + batchSize: Int, + bufferSize: Int, + reuseWorker: Boolean, + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType) + extends BasePythonRunner[InternalRow, ColumnarBatch]( + funcs, bufferSize, reuseWorker, evalType, argOffsets) { + + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[InternalRow], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + val arrowSchema = ArrowUtils.toArrowSchema(schema) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdout writer for $pythonExec", 0, Long.MaxValue) + + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val arrowWriter = ArrowWriter.create(root) + + var closed = false + + context.addTaskCompletionListener { _ => + if (!closed) { + root.close() + allocator.close() + } + } + + val writer = new ArrowStreamWriter(root, null, dataOut) + writer.start() + + Utils.tryWithSafeFinally { + while (inputIterator.hasNext) { + var rowCount = 0 + while (inputIterator.hasNext && (batchSize <= 0 || rowCount < batchSize)) { + val row = inputIterator.next() + arrowWriter.write(row) + rowCount += 1 + } + arrowWriter.finish() + writer.writeBatch() + arrowWriter.reset() + } + } { + writer.end() + root.close() + allocator.close() + closed = true + } + } + } + } + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext): Iterator[ColumnarBatch] = { + new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) { + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdin reader for $pythonExec", 0, Long.MaxValue) + + private var reader: ArrowStreamReader = _ + private var root: VectorSchemaRoot = _ + private var schema: StructType = _ + private var vectors: Array[ColumnVector] = _ + + private var closed = false + + context.addTaskCompletionListener { _ => + // todo: we need something like `reader.end()`, which release all the resources, but leave + // the input stream open. `reader.close()` will close the socket and we can't reuse worker. + // So here we simply not close the reader, which is problematic. + if (!closed) { + if (root != null) { + root.close() + } + allocator.close() + } + } + + private var batchLoaded = true + + protected override def read(): ColumnarBatch = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + if (reader != null && batchLoaded) { + batchLoaded = reader.loadNextBatch() + if (batchLoaded) { + val batch = new ColumnarBatch(schema, vectors, root.getRowCount) + batch.setNumRows(root.getRowCount) + batch + } else { + root.close() + allocator.close() + closed = true + // Reach end of stream. Call `read()` again to read control data. + read() + } + } else { + stream.readInt() match { + case SpecialLengths.START_ARROW_STREAM => + reader = new ArrowStreamReader(stream, allocator) + root = reader.getVectorSchemaRoot() + schema = ArrowUtils.fromArrowSchema(root.getSchema()) + vectors = root.getFieldVectors().asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + read() + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } + } catch handleException + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 2978eac50554d..26ee25f633ea4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.spark.TaskContext -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan @@ -68,7 +68,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi }.grouped(100).map(x => pickle.dumps(x.toArray)) // Output iterator for results from Python. - val outputIterator = new PythonRunner( + val outputIterator = new PythonUDFRunner( funcs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets) .compute(inputIterator, context.partitionId(), context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala new file mode 100644 index 0000000000000..e28def1c4b423 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io._ +import java.net._ +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark._ +import org.apache.spark.api.python._ + +/** + * A helper class to run Python UDFs in Spark. + */ +class PythonUDFRunner( + funcs: Seq[ChainedPythonFunctions], + bufferSize: Int, + reuseWorker: Boolean, + evalType: Int, + argOffsets: Array[Array[Int]]) + extends BasePythonRunner[Array[Byte], Array[Byte]]( + funcs, bufferSize, reuseWorker, evalType, argOffsets) { + + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[Array[Byte]], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + PythonRDD.writeIteratorToStream(inputIterator, dataOut) + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + } + } + } + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext): Iterator[Array[Byte]] = { + new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) { + + protected override def read(): Array[Byte] = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + stream.readInt() match { + case length if length > 0 => + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + case 0 => Array.empty[Byte] + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } catch handleException + } + } + } +} + +object PythonUDFRunner { + + def writeUDFs( + dataOut: DataOutputStream, + funcs: Seq[ChainedPythonFunctions], + argOffsets: Array[Array[Int]]): Unit = { + dataOut.writeInt(funcs.length) + funcs.zip(argOffsets).foreach { case (chained, offsets) => + dataOut.writeInt(offsets.length) + offsets.foreach { offset => + dataOut.writeInt(offset) + } + dataOut.writeInt(chained.funcs.length) + chained.funcs.foreach { f => + dataOut.writeInt(f.command.length) + dataOut.write(f.command) + } + } + } +} From 9b98aef6a39a5a9ea9fc5481b5a0d92620ba6347 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 27 Sep 2017 13:40:21 -0700 Subject: [PATCH 632/779] [HOTFIX][BUILD] Fix finalizer checkstyle error and re-disable checkstyle ## What changes were proposed in this pull request? Fix finalizer checkstyle violation by just turning it off; re-disable checkstyle as it won't be run by SBT PR builder. See https://github.com/apache/spark/pull/18887#issuecomment-332580700 ## How was this patch tested? `./dev/lint-java` runs successfully Author: Sean Owen Closes #19371 from srowen/HotfixFinalizerCheckstlye. --- .../java/org/apache/spark/io/NioBufferedFileInputStream.java | 2 -- dev/checkstyle-suppressions.xml | 2 -- dev/checkstyle.xml | 1 - pom.xml | 1 + 4 files changed, 1 insertion(+), 5 deletions(-) diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index ea5f1a9abf69b..f6d1288cb263d 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -130,10 +130,8 @@ public synchronized void close() throws IOException { StorageUtils.dispose(byteBuffer); } - //checkstyle.off: NoFinalizer @Override protected void finalize() throws IOException { close(); } - //checkstyle.on: NoFinalizer } diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml index 6e15f6955984e..bbda824dd13b4 100644 --- a/dev/checkstyle-suppressions.xml +++ b/dev/checkstyle-suppressions.xml @@ -40,8 +40,6 @@ files="src/main/java/org/apache/hive/service/*"/> -
- diff --git a/pom.xml b/pom.xml index b0408ecca0f66..83a35006707da 100644 --- a/pom.xml +++ b/pom.xml @@ -2488,6 +2488,7 @@ maven-checkstyle-plugin 2.17 + false true ${basedir}/src/main/java,${basedir}/src/main/scala ${basedir}/src/test/java From 02bb0682e68a2ce81f3b98d33649d368da7f2b3d Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 27 Sep 2017 23:08:30 +0200 Subject: [PATCH 633/779] [SPARK-22143][SQL] Fix memory leak in OffHeapColumnVector ## What changes were proposed in this pull request? `WriteableColumnVector` does not close its child column vectors. This can create memory leaks for `OffHeapColumnVector` where we do not clean up the memory allocated by a vectors children. This can be especially bad for string columns (which uses a child byte column vector). ## How was this patch tested? I have updated the existing tests to always use both on-heap and off-heap vectors. Testing and diagnoses was done locally. Author: Herman van Hovell Closes #19367 from hvanhovell/SPARK-22143. --- .../vectorized/OffHeapColumnVector.java | 1 + .../vectorized/OnHeapColumnVector.java | 10 + .../vectorized/WritableColumnVector.java | 18 ++ .../vectorized/ColumnVectorSuite.scala | 102 +++++---- .../vectorized/ColumnarBatchSuite.scala | 194 ++++++++---------- 5 files changed, 165 insertions(+), 160 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index e1d36858d4eee..8cbc895506d91 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -85,6 +85,7 @@ public long nullsNativeAddress() { @Override public void close() { + super.close(); Platform.freeMemory(nulls); Platform.freeMemory(data); Platform.freeMemory(lengthData); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 96a452978cb35..2725a29eeabe8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -90,6 +90,16 @@ public long nullsNativeAddress() { @Override public void close() { + super.close(); + nulls = null; + byteData = null; + shortData = null; + intData = null; + longData = null; + floatData = null; + doubleData = null; + arrayLengths = null; + arrayOffsets = null; } // diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 0bddc351e1bed..163f2511e5f73 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -59,6 +59,24 @@ public void reset() { } } + @Override + public void close() { + if (childColumns != null) { + for (int i = 0; i < childColumns.length; i++) { + childColumns[i].close(); + childColumns[i] = null; + } + childColumns = null; + } + if (dictionaryIds != null) { + dictionaryIds.close(); + dictionaryIds = null; + } + dictionary = null; + resultStruct = null; + resultArray = null; + } + public void reserve(int requiredCapacity) { if (requiredCapacity > capacity) { int newCapacity = (int) Math.min(MAX_CAPACITY, requiredCapacity * 2L); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index f7b06c97f9db6..85da8270d4cba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -25,19 +25,24 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { - - var testVector: WritableColumnVector = _ - - private def allocate(capacity: Int, dt: DataType): WritableColumnVector = { - new OnHeapColumnVector(capacity, dt) + private def withVector( + vector: WritableColumnVector)( + block: WritableColumnVector => Unit): Unit = { + try block(vector) finally vector.close() } - override def afterEach(): Unit = { - testVector.close() + private def testVectors( + name: String, + size: Int, + dt: DataType)( + block: WritableColumnVector => Unit): Unit = { + test(name) { + withVector(new OnHeapColumnVector(size, dt))(block) + withVector(new OffHeapColumnVector(size, dt))(block) + } } - test("boolean") { - testVector = allocate(10, BooleanType) + testVectors("boolean", 10, BooleanType) { testVector => (0 until 10).foreach { i => testVector.appendBoolean(i % 2 == 0) } @@ -49,8 +54,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("byte") { - testVector = allocate(10, ByteType) + testVectors("byte", 10, ByteType) { testVector => (0 until 10).foreach { i => testVector.appendByte(i.toByte) } @@ -58,12 +62,11 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { val array = new ColumnVector.Array(testVector) (0 until 10).foreach { i => - assert(array.get(i, ByteType) === (i.toByte)) + assert(array.get(i, ByteType) === i.toByte) } } - test("short") { - testVector = allocate(10, ShortType) + testVectors("short", 10, ShortType) { testVector => (0 until 10).foreach { i => testVector.appendShort(i.toShort) } @@ -71,12 +74,11 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { val array = new ColumnVector.Array(testVector) (0 until 10).foreach { i => - assert(array.get(i, ShortType) === (i.toShort)) + assert(array.get(i, ShortType) === i.toShort) } } - test("int") { - testVector = allocate(10, IntegerType) + testVectors("int", 10, IntegerType) { testVector => (0 until 10).foreach { i => testVector.appendInt(i) } @@ -88,8 +90,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("long") { - testVector = allocate(10, LongType) + testVectors("long", 10, LongType) { testVector => (0 until 10).foreach { i => testVector.appendLong(i) } @@ -101,8 +102,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("float") { - testVector = allocate(10, FloatType) + testVectors("float", 10, FloatType) { testVector => (0 until 10).foreach { i => testVector.appendFloat(i.toFloat) } @@ -114,8 +114,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("double") { - testVector = allocate(10, DoubleType) + testVectors("double", 10, DoubleType) { testVector => (0 until 10).foreach { i => testVector.appendDouble(i.toDouble) } @@ -127,8 +126,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("string") { - testVector = allocate(10, StringType) + testVectors("string", 10, StringType) { testVector => (0 until 10).map { i => val utf8 = s"str$i".getBytes("utf8") testVector.appendByteArray(utf8, 0, utf8.length) @@ -141,8 +139,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("binary") { - testVector = allocate(10, BinaryType) + testVectors("binary", 10, BinaryType) { testVector => (0 until 10).map { i => val utf8 = s"str$i".getBytes("utf8") testVector.appendByteArray(utf8, 0, utf8.length) @@ -156,9 +153,8 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("array") { - val arrayType = ArrayType(IntegerType, true) - testVector = allocate(10, arrayType) + val arrayType: ArrayType = ArrayType(IntegerType, containsNull = true) + testVectors("array", 10, arrayType) { testVector => val data = testVector.arrayData() var i = 0 @@ -181,9 +177,8 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { assert(array.get(3, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(3, 4, 5)) } - test("struct") { - val schema = new StructType().add("int", IntegerType).add("double", DoubleType) - testVector = allocate(10, schema) + val structType: StructType = new StructType().add("int", IntegerType).add("double", DoubleType) + testVectors("struct", 10, structType) { testVector => val c1 = testVector.getChildColumn(0) val c2 = testVector.getChildColumn(1) c1.putInt(0, 123) @@ -193,35 +188,34 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { val array = new ColumnVector.Array(testVector) - assert(array.get(0, schema).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 123) - assert(array.get(0, schema).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 3.45) - assert(array.get(1, schema).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 456) - assert(array.get(1, schema).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 5.67) + assert(array.get(0, structType).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 123) + assert(array.get(0, structType).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 3.45) + assert(array.get(1, structType).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 456) + assert(array.get(1, structType).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 5.67) } test("[SPARK-22092] off-heap column vector reallocation corrupts array data") { - val arrayType = ArrayType(IntegerType, true) - testVector = new OffHeapColumnVector(8, arrayType) + withVector(new OffHeapColumnVector(8, arrayType)) { testVector => + val data = testVector.arrayData() + (0 until 8).foreach(i => data.putInt(i, i)) + (0 until 8).foreach(i => testVector.putArray(i, i, 1)) - val data = testVector.arrayData() - (0 until 8).foreach(i => data.putInt(i, i)) - (0 until 8).foreach(i => testVector.putArray(i, i, 1)) + // Increase vector's capacity and reallocate the data to new bigger buffers. + testVector.reserve(16) - // Increase vector's capacity and reallocate the data to new bigger buffers. - testVector.reserve(16) - - // Check that none of the values got lost/overwritten. - val array = new ColumnVector.Array(testVector) - (0 until 8).foreach { i => - assert(array.get(i, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(i)) + // Check that none of the values got lost/overwritten. + val array = new ColumnVector.Array(testVector) + (0 until 8).foreach { i => + assert(array.get(i, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(i)) + } } } test("[SPARK-22092] off-heap column vector reallocation corrupts struct nullability") { - val structType = new StructType().add("int", IntegerType).add("double", DoubleType) - testVector = new OffHeapColumnVector(8, structType) - (0 until 8).foreach(i => if (i % 2 == 0) testVector.putNull(i) else testVector.putNotNull(i)) - testVector.reserve(16) - (0 until 8).foreach(i => assert(testVector.isNullAt(i) == (i % 2 == 0))) + withVector(new OffHeapColumnVector(8, structType)) { testVector => + (0 until 8).foreach(i => if (i % 2 == 0) testVector.putNull(i) else testVector.putNotNull(i)) + testVector.reserve(16) + (0 until 8).foreach(i => assert(testVector.isNullAt(i) == (i % 2 == 0))) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index ebf76613343ba..983eb103682c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -38,7 +38,7 @@ import org.apache.spark.unsafe.types.CalendarInterval class ColumnarBatchSuite extends SparkFunSuite { - def allocate(capacity: Int, dt: DataType, memMode: MemoryMode): WritableColumnVector = { + private def allocate(capacity: Int, dt: DataType, memMode: MemoryMode): WritableColumnVector = { if (memMode == MemoryMode.OFF_HEAP) { new OffHeapColumnVector(capacity, dt) } else { @@ -46,23 +46,36 @@ class ColumnarBatchSuite extends SparkFunSuite { } } - test("Null Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { - val reference = mutable.ArrayBuffer.empty[Boolean] + private def testVector( + name: String, + size: Int, + dt: DataType)( + block: (WritableColumnVector, MemoryMode) => Unit): Unit = { + test(name) { + Seq(MemoryMode.ON_HEAP, MemoryMode.OFF_HEAP).foreach { mode => + val vector = allocate(size, dt, mode) + try block(vector, mode) finally { + vector.close() + } + } + } + } - val column = allocate(1024, IntegerType, memMode) + testVector("Null APIs", 1024, IntegerType) { + (column, memMode) => + val reference = mutable.ArrayBuffer.empty[Boolean] var idx = 0 - assert(column.anyNullsSet() == false) + assert(!column.anyNullsSet()) assert(column.numNulls() == 0) column.appendNotNull() reference += false - assert(column.anyNullsSet() == false) + assert(!column.anyNullsSet()) assert(column.numNulls() == 0) column.appendNotNulls(3) (1 to 3).foreach(_ => reference += false) - assert(column.anyNullsSet() == false) + assert(!column.anyNullsSet()) assert(column.numNulls() == 0) column.appendNull() @@ -113,16 +126,12 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == (Platform.getByte(null, addr + v._2) == 1), "index=" + v._2) } } - column.close - }} } - test("Byte Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Byte APIs", 1024, ByteType) { + (column, memMode) => val reference = mutable.ArrayBuffer.empty[Byte] - val column = allocate(1024, ByteType, memMode) - var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toByte).toArray column.appendBytes(2, values, 0) reference += 10.toByte @@ -170,17 +179,14 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getByte(null, addr + v._2)) } } - }} } - test("Short Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Short APIs", 1024, ShortType) { + (column, memMode) => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Short] - val column = allocate(1024, ShortType, memMode) - var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toShort).toArray column.appendShorts(2, values, 0) reference += 10.toShort @@ -248,19 +254,14 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getShort(null, addr + 2 * v._2)) } } - - column.close - }} } - test("Int Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Int APIs", 1024, IntegerType) { + (column, memMode) => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Int] - val column = allocate(1024, IntegerType, memMode) - var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).toArray column.appendInts(2, values, 0) reference += 10 @@ -334,18 +335,14 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getInt(null, addr + 4 * v._2)) } } - column.close - }} } - test("Long Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Long APIs", 1024, LongType) { + (column, memMode) => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Long] - val column = allocate(1024, LongType, memMode) - var values = (10L :: 20L :: 30L :: 40L :: 50L :: Nil).toArray column.appendLongs(2, values, 0) reference += 10L @@ -422,17 +419,14 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getLong(null, addr + 8 * v._2)) } } - }} } - test("Float APIs") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Float APIs", 1024, FloatType) { + (column, memMode) => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Float] - val column = allocate(1024, FloatType, memMode) - var values = (.1f :: .2f :: .3f :: .4f :: .5f :: Nil).toArray column.appendFloats(2, values, 0) reference += .1f @@ -512,18 +506,14 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getFloat(null, addr + 4 * v._2)) } } - column.close - }} } - test("Double APIs") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Double APIs", 1024, DoubleType) { + (column, memMode) => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Double] - val column = allocate(1024, DoubleType, memMode) - var values = (.1 :: .2 :: .3 :: .4 :: .5 :: Nil).toArray column.appendDoubles(2, values, 0) reference += .1 @@ -603,15 +593,12 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getDouble(null, addr + 8 * v._2)) } } - column.close - }} } - test("String APIs") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("String APIs", 6, StringType) { + (column, memMode) => val reference = mutable.ArrayBuffer.empty[String] - val column = allocate(6, BinaryType, memMode) assert(column.arrayData().elementsAppended == 0) val str = "string" @@ -663,15 +650,13 @@ class ColumnarBatchSuite extends SparkFunSuite { column.reset() assert(column.arrayData().elementsAppended == 0) - }} } - test("Int Array") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { - val column = allocate(10, new ArrayType(IntegerType, true), memMode) + testVector("Int Array", 10, new ArrayType(IntegerType, true)) { + (column, _) => // Fill the underlying data with all the arrays back to back. - val data = column.arrayData(); + val data = column.arrayData() var i = 0 while (i < 6) { data.putInt(i, i) @@ -709,7 +694,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.getArray(3).getInt(2) == 5) // Add a longer array which requires resizing - column.reset + column.reset() val array = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) assert(data.capacity == 10) data.reserve(array.length) @@ -718,63 +703,67 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putArray(0, 0, array.length) assert(ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] === array) - }} } test("toArray for primitive types") { - // (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { - (MemoryMode.ON_HEAP :: Nil).foreach { memMode => { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => val len = 4 val columnBool = allocate(len, new ArrayType(BooleanType, false), memMode) val boolArray = Array(false, true, false, true) - boolArray.zipWithIndex.map { case (v, i) => columnBool.arrayData.putBoolean(i, v) } + boolArray.zipWithIndex.foreach { case (v, i) => columnBool.arrayData.putBoolean(i, v) } columnBool.putArray(0, 0, len) assert(columnBool.getArray(0).toBooleanArray === boolArray) + columnBool.close() val columnByte = allocate(len, new ArrayType(ByteType, false), memMode) val byteArray = Array[Byte](0, 1, 2, 3) - byteArray.zipWithIndex.map { case (v, i) => columnByte.arrayData.putByte(i, v) } + byteArray.zipWithIndex.foreach { case (v, i) => columnByte.arrayData.putByte(i, v) } columnByte.putArray(0, 0, len) assert(columnByte.getArray(0).toByteArray === byteArray) + columnByte.close() val columnShort = allocate(len, new ArrayType(ShortType, false), memMode) val shortArray = Array[Short](0, 1, 2, 3) - shortArray.zipWithIndex.map { case (v, i) => columnShort.arrayData.putShort(i, v) } + shortArray.zipWithIndex.foreach { case (v, i) => columnShort.arrayData.putShort(i, v) } columnShort.putArray(0, 0, len) assert(columnShort.getArray(0).toShortArray === shortArray) + columnShort.close() val columnInt = allocate(len, new ArrayType(IntegerType, false), memMode) val intArray = Array(0, 1, 2, 3) - intArray.zipWithIndex.map { case (v, i) => columnInt.arrayData.putInt(i, v) } + intArray.zipWithIndex.foreach { case (v, i) => columnInt.arrayData.putInt(i, v) } columnInt.putArray(0, 0, len) assert(columnInt.getArray(0).toIntArray === intArray) + columnInt.close() val columnLong = allocate(len, new ArrayType(LongType, false), memMode) val longArray = Array[Long](0, 1, 2, 3) - longArray.zipWithIndex.map { case (v, i) => columnLong.arrayData.putLong(i, v) } + longArray.zipWithIndex.foreach { case (v, i) => columnLong.arrayData.putLong(i, v) } columnLong.putArray(0, 0, len) assert(columnLong.getArray(0).toLongArray === longArray) + columnLong.close() val columnFloat = allocate(len, new ArrayType(FloatType, false), memMode) val floatArray = Array(0.0F, 1.1F, 2.2F, 3.3F) - floatArray.zipWithIndex.map { case (v, i) => columnFloat.arrayData.putFloat(i, v) } + floatArray.zipWithIndex.foreach { case (v, i) => columnFloat.arrayData.putFloat(i, v) } columnFloat.putArray(0, 0, len) assert(columnFloat.getArray(0).toFloatArray === floatArray) + columnFloat.close() val columnDouble = allocate(len, new ArrayType(DoubleType, false), memMode) val doubleArray = Array(0.0, 1.1, 2.2, 3.3) - doubleArray.zipWithIndex.map { case (v, i) => columnDouble.arrayData.putDouble(i, v) } + doubleArray.zipWithIndex.foreach { case (v, i) => columnDouble.arrayData.putDouble(i, v) } columnDouble.putArray(0, 0, len) assert(columnDouble.getArray(0).toDoubleArray === doubleArray) - }} + columnDouble.close() + } } - test("Struct Column") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { - val schema = new StructType().add("int", IntegerType).add("double", DoubleType) - val column = allocate(1024, schema, memMode) - + testVector( + "Struct Column", + 10, + new StructType().add("int", IntegerType).add("double", DoubleType)) { (column, _) => val c1 = column.getChildColumn(0) val c2 = column.getChildColumn(1) assert(c1.dataType() == IntegerType) @@ -797,13 +786,10 @@ class ColumnarBatchSuite extends SparkFunSuite { val s2 = column.getStruct(1) assert(s2.getInt(0) == 456) assert(s2.getDouble(1) == 5.67) - }} } - test("Nest Array in Array.") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => - val column = allocate(10, new ArrayType(new ArrayType(IntegerType, true), true), - memMode) + testVector("Nest Array in Array", 10, new ArrayType(new ArrayType(IntegerType, true), true)) { + (column, _) => val childColumn = column.arrayData() val data = column.arrayData().arrayData() (0 until 6).foreach { @@ -829,13 +815,14 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.getArray(2).getArray(1).getInt(1) === 4) assert(column.getArray(2).getArray(1).getInt(2) === 5) assert(column.isNullAt(3)) - } } - test("Nest Struct in Array.") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => - val schema = new StructType().add("int", IntegerType).add("long", LongType) - val column = allocate(10, new ArrayType(schema, true), memMode) + private val structType: StructType = new StructType().add("i", IntegerType).add("l", LongType) + + testVector( + "Nest Struct in Array", + 10, + new ArrayType(structType, true)) { (column, _) => val data = column.arrayData() val c0 = data.getChildColumn(0) val c1 = data.getChildColumn(1) @@ -850,22 +837,21 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putArray(1, 1, 3) column.putArray(2, 4, 2) - assert(column.getArray(0).getStruct(0, 2).toSeq(schema) === Seq(0, 0)) - assert(column.getArray(0).getStruct(1, 2).toSeq(schema) === Seq(1, 10)) - assert(column.getArray(1).getStruct(0, 2).toSeq(schema) === Seq(1, 10)) - assert(column.getArray(1).getStruct(1, 2).toSeq(schema) === Seq(2, 20)) - assert(column.getArray(1).getStruct(2, 2).toSeq(schema) === Seq(3, 30)) - assert(column.getArray(2).getStruct(0, 2).toSeq(schema) === Seq(4, 40)) - assert(column.getArray(2).getStruct(1, 2).toSeq(schema) === Seq(5, 50)) - } + assert(column.getArray(0).getStruct(0, 2).toSeq(structType) === Seq(0, 0)) + assert(column.getArray(0).getStruct(1, 2).toSeq(structType) === Seq(1, 10)) + assert(column.getArray(1).getStruct(0, 2).toSeq(structType) === Seq(1, 10)) + assert(column.getArray(1).getStruct(1, 2).toSeq(structType) === Seq(2, 20)) + assert(column.getArray(1).getStruct(2, 2).toSeq(structType) === Seq(3, 30)) + assert(column.getArray(2).getStruct(0, 2).toSeq(structType) === Seq(4, 40)) + assert(column.getArray(2).getStruct(1, 2).toSeq(structType) === Seq(5, 50)) } - test("Nest Array in Struct.") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => - val schema = new StructType() - .add("int", IntegerType) - .add("array", new ArrayType(IntegerType, true)) - val column = allocate(10, schema, memMode) + testVector( + "Nest Array in Struct", + 10, + new StructType() + .add("int", IntegerType) + .add("array", new ArrayType(IntegerType, true))) { (column, _) => val c0 = column.getChildColumn(0) val c1 = column.getChildColumn(1) c0.putInt(0, 0) @@ -886,18 +872,15 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.getStruct(1).getArray(1).toIntArray() === Array(2)) assert(column.getStruct(2).getInt(0) === 2) assert(column.getStruct(2).getArray(1).toIntArray() === Array(3, 4, 5)) - } } - test("Nest Struct in Struct.") { - (MemoryMode.ON_HEAP :: Nil).foreach { memMode => - val subSchema = new StructType() - .add("int", IntegerType) - .add("int", IntegerType) - val schema = new StructType() - .add("int", IntegerType) - .add("struct", subSchema) - val column = allocate(10, schema, memMode) + private val subSchema: StructType = new StructType() + .add("int", IntegerType) + .add("int", IntegerType) + testVector( + "Nest Struct in Struct", + 10, + new StructType().add("int", IntegerType).add("struct", subSchema)) { (column, _) => val c0 = column.getChildColumn(0) val c1 = column.getChildColumn(1) c0.putInt(0, 0) @@ -919,7 +902,6 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.getStruct(1).getStruct(1, 2).toSeq(subSchema) === Seq(8, 80)) assert(column.getStruct(2).getInt(0) === 2) assert(column.getStruct(2).getStruct(1, 2).toSeq(subSchema) === Seq(9, 90)) - } } test("ColumnarBatch basic") { @@ -1040,7 +1022,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val it4 = batch.rowIterator() rowEquals(it4.next(), Row(null, 2.2, 2, "abc")) - batch.close + batch.close() }} } From 9244957b500cb2b458c32db2c63293a1444690d7 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 27 Sep 2017 17:03:42 -0700 Subject: [PATCH 634/779] [SPARK-22140] Add TPCDSQuerySuite ## What changes were proposed in this pull request? Now, we are not running TPC-DS queries as regular test cases. Thus, we need to add a test suite using empty tables for ensuring the new code changes will not break them. For example, optimizer/analyzer batches should not exceed the max iteration. ## How was this patch tested? N/A Author: gatorsmile Closes #19361 from gatorsmile/tpcdsQuerySuite. --- .../apache/spark/sql/TPCDSQuerySuite.scala | 348 ++++++++++++++++++ 1 file changed, 348 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala new file mode 100644 index 0000000000000..c0797fa55f5da --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.catalyst.util.resourceToString +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll { + + /** + * Drop all the tables + */ + protected override def afterAll(): Unit = { + try { + spark.sessionState.catalog.reset() + } finally { + super.afterAll() + } + } + + override def beforeAll() { + super.beforeAll() + sql( + """ + |CREATE TABLE `catalog_page` ( + |`cp_catalog_page_sk` INT, `cp_catalog_page_id` STRING, `cp_start_date_sk` INT, + |`cp_end_date_sk` INT, `cp_department` STRING, `cp_catalog_number` INT, + |`cp_catalog_page_number` INT, `cp_description` STRING, `cp_type` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `catalog_returns` ( + |`cr_returned_date_sk` INT, `cr_returned_time_sk` INT, `cr_item_sk` INT, + |`cr_refunded_customer_sk` INT, `cr_refunded_cdemo_sk` INT, `cr_refunded_hdemo_sk` INT, + |`cr_refunded_addr_sk` INT, `cr_returning_customer_sk` INT, `cr_returning_cdemo_sk` INT, + |`cr_returning_hdemo_sk` INT, `cr_returning_addr_sk` INT, `cr_call_center_sk` INT, + |`cr_catalog_page_sk` INT, `cr_ship_mode_sk` INT, `cr_warehouse_sk` INT, `cr_reason_sk` INT, + |`cr_order_number` INT, `cr_return_quantity` INT, `cr_return_amount` DECIMAL(7,2), + |`cr_return_tax` DECIMAL(7,2), `cr_return_amt_inc_tax` DECIMAL(7,2), `cr_fee` DECIMAL(7,2), + |`cr_return_ship_cost` DECIMAL(7,2), `cr_refunded_cash` DECIMAL(7,2), + |`cr_reversed_charge` DECIMAL(7,2), `cr_store_credit` DECIMAL(7,2), + |`cr_net_loss` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `customer` ( + |`c_customer_sk` INT, `c_customer_id` STRING, `c_current_cdemo_sk` INT, + |`c_current_hdemo_sk` INT, `c_current_addr_sk` INT, `c_first_shipto_date_sk` INT, + |`c_first_sales_date_sk` INT, `c_salutation` STRING, `c_first_name` STRING, + |`c_last_name` STRING, `c_preferred_cust_flag` STRING, `c_birth_day` INT, + |`c_birth_month` INT, `c_birth_year` INT, `c_birth_country` STRING, `c_login` STRING, + |`c_email_address` STRING, `c_last_review_date` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `customer_address` ( + |`ca_address_sk` INT, `ca_address_id` STRING, `ca_street_number` STRING, + |`ca_street_name` STRING, `ca_street_type` STRING, `ca_suite_number` STRING, + |`ca_city` STRING, `ca_county` STRING, `ca_state` STRING, `ca_zip` STRING, + |`ca_country` STRING, `ca_gmt_offset` DECIMAL(5,2), `ca_location_type` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `customer_demographics` ( + |`cd_demo_sk` INT, `cd_gender` STRING, `cd_marital_status` STRING, + |`cd_education_status` STRING, `cd_purchase_estimate` INT, `cd_credit_rating` STRING, + |`cd_dep_count` INT, `cd_dep_employed_count` INT, `cd_dep_college_count` INT) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `date_dim` ( + |`d_date_sk` INT, `d_date_id` STRING, `d_date` STRING, + |`d_month_seq` INT, `d_week_seq` INT, `d_quarter_seq` INT, `d_year` INT, `d_dow` INT, + |`d_moy` INT, `d_dom` INT, `d_qoy` INT, `d_fy_year` INT, `d_fy_quarter_seq` INT, + |`d_fy_week_seq` INT, `d_day_name` STRING, `d_quarter_name` STRING, `d_holiday` STRING, + |`d_weekend` STRING, `d_following_holiday` STRING, `d_first_dom` INT, `d_last_dom` INT, + |`d_same_day_ly` INT, `d_same_day_lq` INT, `d_current_day` STRING, `d_current_week` STRING, + |`d_current_month` STRING, `d_current_quarter` STRING, `d_current_year` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `household_demographics` ( + |`hd_demo_sk` INT, `hd_income_band_sk` INT, `hd_buy_potential` STRING, `hd_dep_count` INT, + |`hd_vehicle_count` INT) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `inventory` (`inv_date_sk` INT, `inv_item_sk` INT, `inv_warehouse_sk` INT, + |`inv_quantity_on_hand` INT) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `item` (`i_item_sk` INT, `i_item_id` STRING, `i_rec_start_date` STRING, + |`i_rec_end_date` STRING, `i_item_desc` STRING, `i_current_price` DECIMAL(7,2), + |`i_wholesale_cost` DECIMAL(7,2), `i_brand_id` INT, `i_brand` STRING, `i_class_id` INT, + |`i_class` STRING, `i_category_id` INT, `i_category` STRING, `i_manufact_id` INT, + |`i_manufact` STRING, `i_size` STRING, `i_formulation` STRING, `i_color` STRING, + |`i_units` STRING, `i_container` STRING, `i_manager_id` INT, `i_product_name` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `promotion` ( + |`p_promo_sk` INT, `p_promo_id` STRING, `p_start_date_sk` INT, `p_end_date_sk` INT, + |`p_item_sk` INT, `p_cost` DECIMAL(15,2), `p_response_target` INT, `p_promo_name` STRING, + |`p_channel_dmail` STRING, `p_channel_email` STRING, `p_channel_catalog` STRING, + |`p_channel_tv` STRING, `p_channel_radio` STRING, `p_channel_press` STRING, + |`p_channel_event` STRING, `p_channel_demo` STRING, `p_channel_details` STRING, + |`p_purpose` STRING, `p_discount_active` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `store` ( + |`s_store_sk` INT, `s_store_id` STRING, `s_rec_start_date` STRING, + |`s_rec_end_date` STRING, `s_closed_date_sk` INT, `s_store_name` STRING, + |`s_number_employees` INT, `s_floor_space` INT, `s_hours` STRING, `s_manager` STRING, + |`s_market_id` INT, `s_geography_class` STRING, `s_market_desc` STRING, + |`s_market_manager` STRING, `s_division_id` INT, `s_division_name` STRING, + |`s_company_id` INT, `s_company_name` STRING, `s_street_number` STRING, + |`s_street_name` STRING, `s_street_type` STRING, `s_suite_number` STRING, `s_city` STRING, + |`s_county` STRING, `s_state` STRING, `s_zip` STRING, `s_country` STRING, + |`s_gmt_offset` DECIMAL(5,2), `s_tax_precentage` DECIMAL(5,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `store_returns` ( + |`sr_returned_date_sk` BIGINT, `sr_return_time_sk` BIGINT, `sr_item_sk` BIGINT, + |`sr_customer_sk` BIGINT, `sr_cdemo_sk` BIGINT, `sr_hdemo_sk` BIGINT, `sr_addr_sk` BIGINT, + |`sr_store_sk` BIGINT, `sr_reason_sk` BIGINT, `sr_ticket_number` BIGINT, + |`sr_return_quantity` BIGINT, `sr_return_amt` DECIMAL(7,2), `sr_return_tax` DECIMAL(7,2), + |`sr_return_amt_inc_tax` DECIMAL(7,2), `sr_fee` DECIMAL(7,2), + |`sr_return_ship_cost` DECIMAL(7,2), `sr_refunded_cash` DECIMAL(7,2), + |`sr_reversed_charge` DECIMAL(7,2), `sr_store_credit` DECIMAL(7,2), + |`sr_net_loss` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `catalog_sales` ( + |`cs_sold_date_sk` INT, `cs_sold_time_sk` INT, `cs_ship_date_sk` INT, + |`cs_bill_customer_sk` INT, `cs_bill_cdemo_sk` INT, `cs_bill_hdemo_sk` INT, + |`cs_bill_addr_sk` INT, `cs_ship_customer_sk` INT, `cs_ship_cdemo_sk` INT, + |`cs_ship_hdemo_sk` INT, `cs_ship_addr_sk` INT, `cs_call_center_sk` INT, + |`cs_catalog_page_sk` INT, `cs_ship_mode_sk` INT, `cs_warehouse_sk` INT, + |`cs_item_sk` INT, `cs_promo_sk` INT, `cs_order_number` INT, `cs_quantity` INT, + |`cs_wholesale_cost` DECIMAL(7,2), `cs_list_price` DECIMAL(7,2), + |`cs_sales_price` DECIMAL(7,2), `cs_ext_discount_amt` DECIMAL(7,2), + |`cs_ext_sales_price` DECIMAL(7,2), `cs_ext_wholesale_cost` DECIMAL(7,2), + |`cs_ext_list_price` DECIMAL(7,2), `cs_ext_tax` DECIMAL(7,2), `cs_coupon_amt` DECIMAL(7,2), + |`cs_ext_ship_cost` DECIMAL(7,2), `cs_net_paid` DECIMAL(7,2), + |`cs_net_paid_inc_tax` DECIMAL(7,2), `cs_net_paid_inc_ship` DECIMAL(7,2), + |`cs_net_paid_inc_ship_tax` DECIMAL(7,2), `cs_net_profit` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `web_sales` ( + |`ws_sold_date_sk` INT, `ws_sold_time_sk` INT, `ws_ship_date_sk` INT, `ws_item_sk` INT, + |`ws_bill_customer_sk` INT, `ws_bill_cdemo_sk` INT, `ws_bill_hdemo_sk` INT, + |`ws_bill_addr_sk` INT, `ws_ship_customer_sk` INT, `ws_ship_cdemo_sk` INT, + |`ws_ship_hdemo_sk` INT, `ws_ship_addr_sk` INT, `ws_web_page_sk` INT, `ws_web_site_sk` INT, + |`ws_ship_mode_sk` INT, `ws_warehouse_sk` INT, `ws_promo_sk` INT, `ws_order_number` INT, + |`ws_quantity` INT, `ws_wholesale_cost` DECIMAL(7,2), `ws_list_price` DECIMAL(7,2), + |`ws_sales_price` DECIMAL(7,2), `ws_ext_discount_amt` DECIMAL(7,2), + |`ws_ext_sales_price` DECIMAL(7,2), `ws_ext_wholesale_cost` DECIMAL(7,2), + |`ws_ext_list_price` DECIMAL(7,2), `ws_ext_tax` DECIMAL(7,2), + |`ws_coupon_amt` DECIMAL(7,2), `ws_ext_ship_cost` DECIMAL(7,2), `ws_net_paid` DECIMAL(7,2), + |`ws_net_paid_inc_tax` DECIMAL(7,2), `ws_net_paid_inc_ship` DECIMAL(7,2), + |`ws_net_paid_inc_ship_tax` DECIMAL(7,2), `ws_net_profit` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `store_sales` ( + |`ss_sold_date_sk` INT, `ss_sold_time_sk` INT, `ss_item_sk` INT, `ss_customer_sk` INT, + |`ss_cdemo_sk` INT, `ss_hdemo_sk` INT, `ss_addr_sk` INT, `ss_store_sk` INT, + |`ss_promo_sk` INT, `ss_ticket_number` INT, `ss_quantity` INT, + |`ss_wholesale_cost` DECIMAL(7,2), `ss_list_price` DECIMAL(7,2), + |`ss_sales_price` DECIMAL(7,2), `ss_ext_discount_amt` DECIMAL(7,2), + |`ss_ext_sales_price` DECIMAL(7,2), `ss_ext_wholesale_cost` DECIMAL(7,2), + |`ss_ext_list_price` DECIMAL(7,2), `ss_ext_tax` DECIMAL(7,2), + |`ss_coupon_amt` DECIMAL(7,2), `ss_net_paid` DECIMAL(7,2), + |`ss_net_paid_inc_tax` DECIMAL(7,2), `ss_net_profit` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `web_returns` ( + |`wr_returned_date_sk` BIGINT, `wr_returned_time_sk` BIGINT, `wr_item_sk` BIGINT, + |`wr_refunded_customer_sk` BIGINT, `wr_refunded_cdemo_sk` BIGINT, + |`wr_refunded_hdemo_sk` BIGINT, `wr_refunded_addr_sk` BIGINT, + |`wr_returning_customer_sk` BIGINT, `wr_returning_cdemo_sk` BIGINT, + |`wr_returning_hdemo_sk` BIGINT, `wr_returning_addr_sk` BIGINT, `wr_web_page_sk` BIGINT, + |`wr_reason_sk` BIGINT, `wr_order_number` BIGINT, `wr_return_quantity` BIGINT, + |`wr_return_amt` DECIMAL(7,2), `wr_return_tax` DECIMAL(7,2), + |`wr_return_amt_inc_tax` DECIMAL(7,2), `wr_fee` DECIMAL(7,2), + |`wr_return_ship_cost` DECIMAL(7,2), `wr_refunded_cash` DECIMAL(7,2), + |`wr_reversed_charge` DECIMAL(7,2), `wr_account_credit` DECIMAL(7,2), + |`wr_net_loss` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `web_site` ( + |`web_site_sk` INT, `web_site_id` STRING, `web_rec_start_date` DATE, + |`web_rec_end_date` DATE, `web_name` STRING, `web_open_date_sk` INT, + |`web_close_date_sk` INT, `web_class` STRING, `web_manager` STRING, `web_mkt_id` INT, + |`web_mkt_class` STRING, `web_mkt_desc` STRING, `web_market_manager` STRING, + |`web_company_id` INT, `web_company_name` STRING, `web_street_number` STRING, + |`web_street_name` STRING, `web_street_type` STRING, `web_suite_number` STRING, + |`web_city` STRING, `web_county` STRING, `web_state` STRING, `web_zip` STRING, + |`web_country` STRING, `web_gmt_offset` STRING, `web_tax_percentage` DECIMAL(5,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `reason` ( + |`r_reason_sk` INT, `r_reason_id` STRING, `r_reason_desc` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `call_center` ( + |`cc_call_center_sk` INT, `cc_call_center_id` STRING, `cc_rec_start_date` DATE, + |`cc_rec_end_date` DATE, `cc_closed_date_sk` INT, `cc_open_date_sk` INT, `cc_name` STRING, + |`cc_class` STRING, `cc_employees` INT, `cc_sq_ft` INT, `cc_hours` STRING, + |`cc_manager` STRING, `cc_mkt_id` INT, `cc_mkt_class` STRING, `cc_mkt_desc` STRING, + |`cc_market_manager` STRING, `cc_division` INT, `cc_division_name` STRING, `cc_company` INT, + |`cc_company_name` STRING, `cc_street_number` STRING, `cc_street_name` STRING, + |`cc_street_type` STRING, `cc_suite_number` STRING, `cc_city` STRING, `cc_county` STRING, + |`cc_state` STRING, `cc_zip` STRING, `cc_country` STRING, `cc_gmt_offset` DECIMAL(5,2), + |`cc_tax_percentage` DECIMAL(5,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `warehouse` ( + |`w_warehouse_sk` INT, `w_warehouse_id` STRING, `w_warehouse_name` STRING, + |`w_warehouse_sq_ft` INT, `w_street_number` STRING, `w_street_name` STRING, + |`w_street_type` STRING, `w_suite_number` STRING, `w_city` STRING, `w_county` STRING, + |`w_state` STRING, `w_zip` STRING, `w_country` STRING, `w_gmt_offset` DECIMAL(5,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `ship_mode` ( + |`sm_ship_mode_sk` INT, `sm_ship_mode_id` STRING, `sm_type` STRING, `sm_code` STRING, + |`sm_carrier` STRING, `sm_contract` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `income_band` ( + |`ib_income_band_sk` INT, `ib_lower_bound` INT, `ib_upper_bound` INT) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `time_dim` ( + |`t_time_sk` INT, `t_time_id` STRING, `t_time` INT, `t_hour` INT, `t_minute` INT, + |`t_second` INT, `t_am_pm` STRING, `t_shift` STRING, `t_sub_shift` STRING, + |`t_meal_time` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `web_page` (`wp_web_page_sk` INT, `wp_web_page_id` STRING, + |`wp_rec_start_date` DATE, `wp_rec_end_date` DATE, `wp_creation_date_sk` INT, + |`wp_access_date_sk` INT, `wp_autogen_flag` STRING, `wp_customer_sk` INT, + |`wp_url` STRING, `wp_type` STRING, `wp_char_count` INT, `wp_link_count` INT, + |`wp_image_count` INT, `wp_max_ad_count` INT) + |USING parquet + """.stripMargin) + } + + val tpcdsQueries = Seq( + "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19", "q20", + "q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q27", "q28", "q29", "q30", + "q31", "q32", "q33", "q34", "q35", "q36", "q37", "q38", "q39a", "q39b", "q40", + "q41", "q42", "q43", "q44", "q45", "q46", "q47", "q48", "q49", "q50", + "q51", "q52", "q53", "q54", "q55", "q56", "q57", "q58", "q59", "q60", + "q61", "q62", "q63", "q64", "q65", "q66", "q67", "q68", "q69", "q70", + "q71", "q72", "q73", "q74", "q75", "q76", "q77", "q78", "q79", "q80", + "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90", + "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99") + + tpcdsQueries.foreach { name => + val queryString = resourceToString(s"tpcds/$name.sql", + classLoader = Thread.currentThread().getContextClassLoader) + test(name) { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + sql(queryString).collect() + } + } + } +} From 7bf4da8a33c33b03bbfddc698335fe9b86ce1e0e Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 28 Sep 2017 10:24:51 +0900 Subject: [PATCH 635/779] [MINOR] Fixed up pandas_udf related docs and formatting ## What changes were proposed in this pull request? Fixed some minor issues with pandas_udf related docs and formatting. ## How was this patch tested? NA Author: Bryan Cutler Closes #19375 from BryanCutler/arrow-pandas_udf-cleanup-minor. --- python/pyspark/serializers.py | 6 +++--- python/pyspark/sql/functions.py | 6 ++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index db77b7e150b24..ad18bd0c81eaa 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -191,7 +191,7 @@ def loads(self, obj): class ArrowSerializer(FramedSerializer): """ - Serializes an Arrow stream. + Serializes bytes as Arrow data with the Arrow file format. """ def dumps(self, batch): @@ -239,7 +239,7 @@ class ArrowStreamPandasSerializer(Serializer): def dump_stream(self, iterator, stream): """ - Make ArrowRecordBatches from Pandas Serieses and serialize. Input is a single series or + Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or a list of series accompanied by an optional pyarrow type to coerce the data to. """ import pyarrow as pa @@ -257,7 +257,7 @@ def dump_stream(self, iterator, stream): def load_stream(self, stream): """ - Deserialize ArrowRecordBatchs to an Arrow table and return as a list of pandas.Series. + Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. """ import pyarrow as pa reader = pa.open_stream(stream) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 63e9a830bbc9e..b45a59db93679 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2199,16 +2199,14 @@ def pandas_udf(f=None, returnType=StringType()): ... >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age")) >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ - ... .show() # doctest: +SKIP + ... .show() # doctest: +SKIP +----------+--------------+------------+ |slen(name)|to_upper(name)|add_one(age)| +----------+--------------+------------+ | 8| JOHN DOE| 22| +----------+--------------+------------+ """ - wrapped_udf = _create_udf(f, returnType=returnType, vectorized=True) - - return wrapped_udf + return _create_udf(f, returnType=returnType, vectorized=True) blacklist = ['map', 'since', 'ignore_unicode_prefix'] From 3b117d631e1ff387b70ed8efba229594f4594db5 Mon Sep 17 00:00:00 2001 From: zhoukang Date: Thu, 28 Sep 2017 09:25:21 +0800 Subject: [PATCH 636/779] [SPARK-22123][CORE] Add latest failure reason for task set blacklist ## What changes were proposed in this pull request? This patch add latest failure reason for task set blacklist.Which can be showed on spark ui and let user know failure reason directly. Till now , every job which aborted by completed blacklist just show log like below which has no more information: `Aborting $taskSet because task $indexInTaskSet (partition $partition) cannot run anywhere due to node and executor blacklist. Blacklisting behavior cannot run anywhere due to node and executor blacklist.Blacklisting behavior can be configured via spark.blacklist.*."` **After modify:** ``` Aborting TaskSet 0.0 because task 0 (partition 0) cannot run anywhere due to node and executor blacklist. Most recent failure: Some(Lost task 0.1 in stage 0.0 (TID 3,xxx, executor 1): java.lang.Exception: Fake error! at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:73) at org.apache.spark.scheduler.Task.run(Task.scala:99) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:305) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) at java.lang.Thread.run(Thread.java:745) ). Blacklisting behavior can be configured via spark.blacklist.*. ``` ## How was this patch tested? Unit test and manually test. Author: zhoukang Closes #19338 from caneGuy/zhoukang/improve-blacklist. --- .../spark/scheduler/TaskSetBlacklist.scala | 14 ++++- .../spark/scheduler/TaskSetManager.scala | 15 +++-- .../scheduler/BlacklistIntegrationSuite.scala | 5 +- .../scheduler/BlacklistTrackerSuite.scala | 60 ++++++++++++------- .../scheduler/TaskSchedulerImplSuite.scala | 11 +++- .../scheduler/TaskSetBlacklistSuite.scala | 45 +++++++++----- .../spark/scheduler/TaskSetManagerSuite.scala | 2 +- 7 files changed, 104 insertions(+), 48 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala index e815b7e0cf6c9..233781f3d9719 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala @@ -61,6 +61,16 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, private val blacklistedExecs = new HashSet[String]() private val blacklistedNodes = new HashSet[String]() + private var latestFailureReason: String = null + + /** + * Get the most recent failure reason of this TaskSet. + * @return + */ + def getLatestFailureReason: String = { + latestFailureReason + } + /** * Return true if this executor is blacklisted for the given task. This does *not* * need to return true if the executor is blacklisted for the entire stage, or blacklisted @@ -94,7 +104,9 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, private[scheduler] def updateBlacklistForFailedTask( host: String, exec: String, - index: Int): Unit = { + index: Int, + failureReason: String): Unit = { + latestFailureReason = failureReason val execFailures = execToFailures.getOrElseUpdate(exec, new ExecutorFailuresInTaskSet(host)) execFailures.updateWithFailure(index, clock.getTimeMillis()) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 3804ea863b4f9..bb867416a4fac 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -670,9 +670,14 @@ private[spark] class TaskSetManager( } if (blacklistedEverywhere) { val partition = tasks(indexInTaskSet).partitionId - abort(s"Aborting $taskSet because task $indexInTaskSet (partition $partition) " + - s"cannot run anywhere due to node and executor blacklist. Blacklisting behavior " + - s"can be configured via spark.blacklist.*.") + abort(s""" + |Aborting $taskSet because task $indexInTaskSet (partition $partition) + |cannot run anywhere due to node and executor blacklist. + |Most recent failure: + |${taskSetBlacklist.getLatestFailureReason} + | + |Blacklisting behavior can be configured via spark.blacklist.*. + |""".stripMargin) } } } @@ -837,9 +842,9 @@ private[spark] class TaskSetManager( sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info) if (!isZombie && reason.countTowardsTaskFailures) { - taskSetBlacklistHelperOpt.foreach(_.updateBlacklistForFailedTask( - info.host, info.executorId, index)) assert (null != failureReason) + taskSetBlacklistHelperOpt.foreach(_.updateBlacklistForFailedTask( + info.host, info.executorId, index, failureReason)) numFailures(index) += 1 if (numFailures(index) >= maxTaskFailures) { logError("Task %d in stage %s failed %d times; aborting job".format( diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala index f6015cd51c2bd..d3bbfd11d406d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala @@ -115,8 +115,9 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM withBackend(runBackend _) { val jobFuture = submit(new MockRDD(sc, 10, Nil), (0 until 10).toArray) awaitJobTermination(jobFuture, duration) - val pattern = ("Aborting TaskSet 0.0 because task .* " + - "cannot run anywhere due to node and executor blacklist").r + val pattern = ( + s"""|Aborting TaskSet 0.0 because task .* + |cannot run anywhere due to node and executor blacklist""".stripMargin).r assert(pattern.findFirstIn(failure.getMessage).isDefined, s"Couldn't find $pattern in ${failure.getMessage()}") } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index a136d69b36d6c..cd1b7a9e5ab18 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -110,7 +110,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M val taskSetBlacklist = createTaskSetBlacklist(stageId) if (stageId % 2 == 0) { // fail one task in every other taskset - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") failuresSoFar += 1 } blacklist.updateBlacklistForSuccessfulTaskSet(stageId, 0, taskSetBlacklist.execToFailures) @@ -132,7 +133,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // for many different stages, executor 1 fails a task, and then the taskSet fails. (0 until failuresUntilBlacklisted * 10).foreach { stage => val taskSetBlacklist = createTaskSetBlacklist(stage) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") } assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) } @@ -147,7 +149,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M val numFailures = math.max(conf.get(config.MAX_FAILURES_PER_EXEC), conf.get(config.MAX_FAILURES_PER_EXEC_STAGE)) (0 until numFailures).foreach { index => - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = index) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = index, failureReason = "testing") } assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) @@ -170,7 +173,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Fail 4 tasks in one task set on executor 1, so that executor gets blacklisted for the whole // application. (0 until 4).foreach { partition => - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist0.execToFailures) assert(blacklist.nodeBlacklist() === Set()) @@ -183,7 +187,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // application. Since that's the second executor that is blacklisted on the same node, we also // blacklist that node. (0 until 4).foreach { partition => - taskSetBlacklist1.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostA", exec = "2", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures) assert(blacklist.nodeBlacklist() === Set("hostA")) @@ -207,7 +212,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Fail one more task, but executor isn't put back into blacklist since the count of failures // on that executor should have been reset to 0. val taskSetBlacklist2 = createTaskSetBlacklist(stageId = 2) - taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist2.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") blacklist.updateBlacklistForSuccessfulTaskSet(2, 0, taskSetBlacklist2.execToFailures) assert(blacklist.nodeBlacklist() === Set()) assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set()) @@ -221,7 +227,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Lets say that executor 1 dies completely. We get some task failures, but // the taskset then finishes successfully (elsewhere). (0 until 4).foreach { partition => - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = partition, failureReason = "testing") } blacklist.handleRemovedExecutor("1") blacklist.updateBlacklistForSuccessfulTaskSet( @@ -236,7 +243,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Now another executor gets spun up on that host, but it also dies. val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) (0 until 4).foreach { partition => - taskSetBlacklist1.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostA", exec = "2", index = partition, failureReason = "testing") } blacklist.handleRemovedExecutor("2") blacklist.updateBlacklistForSuccessfulTaskSet( @@ -279,7 +287,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M def failOneTaskInTaskSet(exec: String): Unit = { val taskSetBlacklist = createTaskSetBlacklist(stageId = stageId) - taskSetBlacklist.updateBlacklistForFailedTask("host-" + exec, exec, 0) + taskSetBlacklist.updateBlacklistForFailedTask("host-" + exec, exec, 0, "testing") blacklist.updateBlacklistForSuccessfulTaskSet(stageId, 0, taskSetBlacklist.execToFailures) stageId += 1 } @@ -354,12 +362,12 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) val taskSetBlacklist2 = createTaskSetBlacklist(stageId = 2) // Taskset1 has one failure immediately - taskSetBlacklist1.updateBlacklistForFailedTask("host-1", "1", 0) + taskSetBlacklist1.updateBlacklistForFailedTask("host-1", "1", 0, "testing") // Then we have a *long* delay, much longer than the timeout, before any other failures or // taskset completion clock.advance(blacklist.BLACKLIST_TIMEOUT_MILLIS * 5) // After the long delay, we have one failure on taskset 2, on the same executor - taskSetBlacklist2.updateBlacklistForFailedTask("host-1", "1", 0) + taskSetBlacklist2.updateBlacklistForFailedTask("host-1", "1", 0, "testing") // Finally, we complete both tasksets. Its important here to complete taskset2 *first*. We // want to make sure that when taskset 1 finishes, even though we've now got two task failures, // we realize that the task failure we just added was well before the timeout. @@ -377,16 +385,20 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // we blacklist executors on two different hosts -- make sure that doesn't lead to any // node blacklisting val taskSetBlacklist0 = createTaskSetBlacklist(stageId = 0) - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = 1) + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 1, failureReason = "testing") blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist0.execToFailures) assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1")) verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "1", 2)) assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set()) val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) - taskSetBlacklist1.updateBlacklistForFailedTask("hostB", exec = "2", index = 0) - taskSetBlacklist1.updateBlacklistForFailedTask("hostB", exec = "2", index = 1) + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostB", exec = "2", index = 0, failureReason = "testing") + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostB", exec = "2", index = 1, failureReason = "testing") blacklist.updateBlacklistForSuccessfulTaskSet(1, 0, taskSetBlacklist1.execToFailures) assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "2")) verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "2", 2)) @@ -395,8 +407,10 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Finally, blacklist another executor on the same node as the original blacklisted executor, // and make sure this time we *do* blacklist the node. val taskSetBlacklist2 = createTaskSetBlacklist(stageId = 0) - taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "3", index = 0) - taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "3", index = 1) + taskSetBlacklist2.updateBlacklistForFailedTask( + "hostA", exec = "3", index = 0, failureReason = "testing") + taskSetBlacklist2.updateBlacklistForFailedTask( + "hostA", exec = "3", index = 1, failureReason = "testing") blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist2.execToFailures) assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "2", "3")) verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "3", 2)) @@ -486,7 +500,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Fail 4 tasks in one task set on executor 1, so that executor gets blacklisted for the whole // application. (0 until 4).foreach { partition => - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist0.execToFailures) @@ -497,7 +512,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // application. Since that's the second executor that is blacklisted on the same node, we also // blacklist that node. (0 until 4).foreach { partition => - taskSetBlacklist1.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostA", exec = "2", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures) @@ -512,7 +528,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Fail 4 tasks in one task set on executor 1, so that executor gets blacklisted for the whole // application. (0 until 4).foreach { partition => - taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + taskSetBlacklist2.updateBlacklistForFailedTask( + "hostA", exec = "1", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist2.execToFailures) @@ -523,7 +540,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // application. Since that's the second executor that is blacklisted on the same node, we also // blacklist that node. (0 until 4).foreach { partition => - taskSetBlacklist3.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + taskSetBlacklist3.updateBlacklistForFailedTask( + "hostA", exec = "2", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist3.execToFailures) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index b8626bf777598..6003899bb7bef 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -660,9 +660,14 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(tsm.isZombie) assert(failedTaskSet) val idx = failedTask.index - assert(failedTaskSetReason === s"Aborting TaskSet 0.0 because task $idx (partition $idx) " + - s"cannot run anywhere due to node and executor blacklist. Blacklisting behavior can be " + - s"configured via spark.blacklist.*.") + assert(failedTaskSetReason === s""" + |Aborting $taskSet because task $idx (partition $idx) + |cannot run anywhere due to node and executor blacklist. + |Most recent failure: + |${tsm.taskSetBlacklistHelperOpt.get.getLatestFailureReason} + | + |Blacklisting behavior can be configured via spark.blacklist.*. + |""".stripMargin) } test("don't abort if there is an executor available, though it hasn't had scheduled tasks yet") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala index f1392e9db6bfd..18981d5be2f94 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala @@ -37,7 +37,8 @@ class TaskSetBlacklistSuite extends SparkFunSuite { // First, mark task 0 as failed on exec1. // task 0 should be blacklisted on exec1, and nowhere else - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "exec1", index = 0, failureReason = "testing") for { executor <- (1 to 4).map(_.toString) index <- 0 until 10 @@ -49,17 +50,20 @@ class TaskSetBlacklistSuite extends SparkFunSuite { assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Mark task 1 failed on exec1 -- this pushes the executor into the blacklist - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec1", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "exec1", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Mark one task as failed on exec2 -- not enough for any further blacklisting yet. - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec2", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "exec2", index = 0, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Mark another task as failed on exec2 -- now we blacklist exec2, which also leads to // blacklisting the entire node. - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec2", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "exec2", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) @@ -108,34 +112,41 @@ class TaskSetBlacklistSuite extends SparkFunSuite { .set(config.MAX_FAILED_EXEC_PER_NODE_STAGE, 3) val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) // Fail a task twice on hostA, exec:1 - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTask("1", 0)) assert(!taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Fail the same task once more on hostA, exec:2 - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "2", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "2", index = 0, failureReason = "testing") assert(taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Fail another task on hostA, exec:1. Now that executor has failures on two different tasks, // so its blacklisted - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Fail a third task on hostA, exec:2, so that exec is blacklisted for the whole task set - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "2", index = 2) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "2", index = 2, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Fail a fourth & fifth task on hostA, exec:3. Now we've got three executors that are // blacklisted for the taskset, so blacklist the whole node. - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "3", index = 3) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "3", index = 4) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "3", index = 3, failureReason = "testing") + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "3", index = 4, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("3")) assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) } @@ -147,13 +158,17 @@ class TaskSetBlacklistSuite extends SparkFunSuite { val conf = new SparkConf().setAppName("test").setMaster("local") .set(config.BLACKLIST_ENABLED.key, "true") val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) - taskSetBlacklist.updateBlacklistForFailedTask("hostB", exec = "2", index = 0) - taskSetBlacklist.updateBlacklistForFailedTask("hostB", exec = "2", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostB", exec = "2", index = 0, failureReason = "testing") + taskSetBlacklist.updateBlacklistForFailedTask( + "hostB", exec = "2", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ae43f4cadc037..5c712bd6a545b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -1146,7 +1146,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Make sure that the blacklist ignored all of the task failures above, since they aren't // the fault of the executor where the task was running. verify(blacklist, never()) - .updateBlacklistForFailedTask(anyString(), anyString(), anyInt()) + .updateBlacklistForFailedTask(anyString(), anyString(), anyInt(), anyString()) } test("update application blacklist for shuffle-fetch") { From f20be4d70bf321f377020d1bde761a43e5c72f0a Mon Sep 17 00:00:00 2001 From: Paul Mackles Date: Thu, 28 Sep 2017 14:43:31 +0800 Subject: [PATCH 637/779] [SPARK-22135][MESOS] metrics in spark-dispatcher not being registered properly ## What changes were proposed in this pull request? Fix a trivial bug with how metrics are registered in the mesos dispatcher. Bug resulted in creating a new registry each time the metricRegistry() method was called. ## How was this patch tested? Verified manually on local mesos setup Author: Paul Mackles Closes #19358 from pmackles/SPARK-22135. --- .../cluster/mesos/MesosClusterSchedulerSource.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala index 1fe94974c8e36..76aded4edb431 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala @@ -23,8 +23,9 @@ import org.apache.spark.metrics.source.Source private[mesos] class MesosClusterSchedulerSource(scheduler: MesosClusterScheduler) extends Source { - override def sourceName: String = "mesos_cluster" - override def metricRegistry: MetricRegistry = new MetricRegistry() + + override val sourceName: String = "mesos_cluster" + override val metricRegistry: MetricRegistry = new MetricRegistry() metricRegistry.register(MetricRegistry.name("waitingDrivers"), new Gauge[Int] { override def getValue: Int = scheduler.getQueuedDriversSize From 01bd00d13532af1c7328997cbec446b0d3e21459 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 28 Sep 2017 08:22:48 +0100 Subject: [PATCH 638/779] [SPARK-22128][CORE] Update paranamer to 2.8 to avoid BytecodeReadingParanamer ArrayIndexOutOfBoundsException with Scala 2.12 + Java 8 lambda ## What changes were proposed in this pull request? Un-manage jackson-module-paranamer version to let it use the version desired by jackson-module-scala; manage paranamer up from 2.8 for jackson-module-scala 2.7.9, to override avro 1.7.7's desired paranamer 2.3 ## How was this patch tested? Existing tests Author: Sean Owen Closes #19352 from srowen/SPARK-22128. --- dev/deps/spark-deps-hadoop-2.6 | 4 ++-- dev/deps/spark-deps-hadoop-2.7 | 4 ++-- pom.xml | 10 ++++------ 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index e534e38213fb1..76fcbd15869f1 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -93,7 +93,7 @@ jackson-core-asl-1.9.13.jar jackson-databind-2.6.7.1.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-paranamer-2.6.7.jar +jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar janino-3.0.0.jar @@ -153,7 +153,7 @@ orc-core-1.4.0-nohive.jar orc-mapreduce-1.4.0-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar -paranamer-2.6.jar +paranamer-2.8.jar parquet-column-1.8.2.jar parquet-common-1.8.2.jar parquet-encoding-1.8.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 02c5a19d173be..cb20072bf8b30 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -93,7 +93,7 @@ jackson-core-asl-1.9.13.jar jackson-databind-2.6.7.1.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-paranamer-2.6.7.jar +jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar janino-3.0.0.jar @@ -154,7 +154,7 @@ orc-core-1.4.0-nohive.jar orc-mapreduce-1.4.0-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar -paranamer-2.6.jar +paranamer-2.8.jar parquet-column-1.8.2.jar parquet-common-1.8.2.jar parquet-encoding-1.8.2.jar diff --git a/pom.xml b/pom.xml index 83a35006707da..87a468c3a6f55 100644 --- a/pom.xml +++ b/pom.xml @@ -179,7 +179,10 @@ 4.7 1.1 2.52.0 - 2.6 + + 2.8 1.8 1.0.0 0.4.0 @@ -637,11 +640,6 @@ - - com.fasterxml.jackson.module - jackson-module-paranamer - ${fasterxml.jackson.version} - com.fasterxml.jackson.module jackson-module-jaxb-annotations From d74dee1336e7152cc0fb7d2b3bf1a44f4f452025 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 28 Sep 2017 09:20:37 -0700 Subject: [PATCH 639/779] [SPARK-22153][SQL] Rename ShuffleExchange -> ShuffleExchangeExec ## What changes were proposed in this pull request? For some reason when we added the Exec suffix to all physical operators, we missed this one. I was looking for this physical operator today and couldn't find it, because I was looking for ExchangeExec. ## How was this patch tested? This is a simple rename and should be covered by existing tests. Author: Reynold Xin Closes #19376 from rxin/SPARK-22153. --- .../spark/sql/execution/SparkStrategies.scala | 6 +-- .../exchange/EnsureRequirements.scala | 26 ++++++------- .../exchange/ExchangeCoordinator.scala | 38 +++++++++---------- ...change.scala => ShuffleExchangeExec.scala} | 10 ++--- .../apache/spark/sql/execution/limit.scala | 6 +-- .../streaming/IncrementalExecution.scala | 4 +- .../apache/spark/sql/CachedTableSuite.scala | 5 ++- .../org/apache/spark/sql/DataFrameSuite.scala | 10 ++--- .../org/apache/spark/sql/DatasetSuite.scala | 4 +- .../execution/ExchangeCoordinatorSuite.scala | 22 +++++------ .../spark/sql/execution/ExchangeSuite.scala | 12 +++--- .../spark/sql/execution/PlannerSuite.scala | 32 ++++++++-------- .../spark/sql/sources/BucketedReadSuite.scala | 10 ++--- .../EnsureStatefulOpPartitioningSuite.scala | 4 +- 14 files changed, 95 insertions(+), 94 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/{ShuffleExchange.scala => ShuffleExchangeExec.scala} (98%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4da7a73469537..92eaab5cd8f81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf @@ -411,7 +411,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Repartition(numPartitions, shuffle, child) => if (shuffle) { - ShuffleExchange(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil + ShuffleExchangeExec(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil } else { execution.CoalesceExec(numPartitions, planLater(child)) :: Nil } @@ -446,7 +446,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case r: logical.Range => execution.RangeExec(r) :: Nil case logical.RepartitionByExpression(expressions, child, numPartitions) => - exchange.ShuffleExchange(HashPartitioning( + exchange.ShuffleExchangeExec(HashPartitioning( expressions, numPartitions), planLater(child)) :: Nil case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil case r: LogicalRDD => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 1da72f2e92329..d28ce60e276d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.internal.SQLConf * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] * of input data meets the * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for - * each operator by inserting [[ShuffleExchange]] Operators where required. Also ensure that the - * input partition ordering requirements are met. + * each operator by inserting [[ShuffleExchangeExec]] Operators where required. Also ensure that + * the input partition ordering requirements are met. */ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { private def defaultNumPreShufflePartitions: Int = conf.numShufflePartitions @@ -57,17 +57,17 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } /** - * Adds [[ExchangeCoordinator]] to [[ShuffleExchange]]s if adaptive query execution is enabled - * and partitioning schemes of these [[ShuffleExchange]]s support [[ExchangeCoordinator]]. + * Adds [[ExchangeCoordinator]] to [[ShuffleExchangeExec]]s if adaptive query execution is enabled + * and partitioning schemes of these [[ShuffleExchangeExec]]s support [[ExchangeCoordinator]]. */ private def withExchangeCoordinator( children: Seq[SparkPlan], requiredChildDistributions: Seq[Distribution]): Seq[SparkPlan] = { val supportsCoordinator = - if (children.exists(_.isInstanceOf[ShuffleExchange])) { + if (children.exists(_.isInstanceOf[ShuffleExchangeExec])) { // Right now, ExchangeCoordinator only support HashPartitionings. children.forall { - case e @ ShuffleExchange(hash: HashPartitioning, _, _) => true + case e @ ShuffleExchangeExec(hash: HashPartitioning, _, _) => true case child => child.outputPartitioning match { case hash: HashPartitioning => true @@ -94,7 +94,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { targetPostShuffleInputSize, minNumPostShufflePartitions) children.zip(requiredChildDistributions).map { - case (e: ShuffleExchange, _) => + case (e: ShuffleExchangeExec, _) => // This child is an Exchange, we need to add the coordinator. e.copy(coordinator = Some(coordinator)) case (child, distribution) => @@ -138,7 +138,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { val targetPartitioning = createPartitioning(distribution, defaultNumPreShufflePartitions) assert(targetPartitioning.isInstanceOf[HashPartitioning]) - ShuffleExchange(targetPartitioning, child, Some(coordinator)) + ShuffleExchangeExec(targetPartitioning, child, Some(coordinator)) } } else { // If we do not need ExchangeCoordinator, the original children are returned. @@ -162,7 +162,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { case (child, BroadcastDistribution(mode)) => BroadcastExchangeExec(mode, child) case (child, distribution) => - ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child) + ShuffleExchangeExec(createPartitioning(distribution, defaultNumPreShufflePartitions), child) } // If the operator has multiple children and specifies child output distributions (e.g. join), @@ -215,8 +215,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { child match { // If child is an exchange, we replace it with // a new one having targetPartitioning. - case ShuffleExchange(_, c, _) => ShuffleExchange(targetPartitioning, c) - case _ => ShuffleExchange(targetPartitioning, child) + case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(targetPartitioning, c) + case _ => ShuffleExchangeExec(targetPartitioning, child) } } } @@ -246,9 +246,9 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case operator @ ShuffleExchange(partitioning, child, _) => + case operator @ ShuffleExchangeExec(partitioning, child, _) => child.children match { - case ShuffleExchange(childPartitioning, baseChild, _)::Nil => + case ShuffleExchangeExec(childPartitioning, baseChild, _)::Nil => if (childPartitioning.guarantees(partitioning)) child else operator case _ => operator } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala index 9fc4ffb651ec8..78f11ca8d8c78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala @@ -35,9 +35,9 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan} * * A coordinator is constructed with three parameters, `numExchanges`, * `targetPostShuffleInputSize`, and `minNumPostShufflePartitions`. - * - `numExchanges` is used to indicated that how many [[ShuffleExchange]]s that will be registered - * to this coordinator. So, when we start to do any actual work, we have a way to make sure that - * we have got expected number of [[ShuffleExchange]]s. + * - `numExchanges` is used to indicated that how many [[ShuffleExchangeExec]]s that will be + * registered to this coordinator. So, when we start to do any actual work, we have a way to + * make sure that we have got expected number of [[ShuffleExchangeExec]]s. * - `targetPostShuffleInputSize` is the targeted size of a post-shuffle partition's * input data size. With this parameter, we can estimate the number of post-shuffle partitions. * This parameter is configured through @@ -47,28 +47,28 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan} * partitions. * * The workflow of this coordinator is described as follows: - * - Before the execution of a [[SparkPlan]], for a [[ShuffleExchange]] operator, + * - Before the execution of a [[SparkPlan]], for a [[ShuffleExchangeExec]] operator, * if an [[ExchangeCoordinator]] is assigned to it, it registers itself to this coordinator. * This happens in the `doPrepare` method. - * - Once we start to execute a physical plan, a [[ShuffleExchange]] registered to this + * - Once we start to execute a physical plan, a [[ShuffleExchangeExec]] registered to this * coordinator will call `postShuffleRDD` to get its corresponding post-shuffle * [[ShuffledRowRDD]]. - * If this coordinator has made the decision on how to shuffle data, this [[ShuffleExchange]] + * If this coordinator has made the decision on how to shuffle data, this [[ShuffleExchangeExec]] * will immediately get its corresponding post-shuffle [[ShuffledRowRDD]]. * - If this coordinator has not made the decision on how to shuffle data, it will ask those - * registered [[ShuffleExchange]]s to submit their pre-shuffle stages. Then, based on the + * registered [[ShuffleExchangeExec]]s to submit their pre-shuffle stages. Then, based on the * size statistics of pre-shuffle partitions, this coordinator will determine the number of * post-shuffle partitions and pack multiple pre-shuffle partitions with continuous indices * to a single post-shuffle partition whenever necessary. * - Finally, this coordinator will create post-shuffle [[ShuffledRowRDD]]s for all registered - * [[ShuffleExchange]]s. So, when a [[ShuffleExchange]] calls `postShuffleRDD`, this coordinator - * can lookup the corresponding [[RDD]]. + * [[ShuffleExchangeExec]]s. So, when a [[ShuffleExchangeExec]] calls `postShuffleRDD`, this + * coordinator can lookup the corresponding [[RDD]]. * * The strategy used to determine the number of post-shuffle partitions is described as follows. * To determine the number of post-shuffle partitions, we have a target input size for a * post-shuffle partition. Once we have size statistics of pre-shuffle partitions from stages - * corresponding to the registered [[ShuffleExchange]]s, we will do a pass of those statistics and - * pack pre-shuffle partitions with continuous indices to a single post-shuffle partition until + * corresponding to the registered [[ShuffleExchangeExec]]s, we will do a pass of those statistics + * and pack pre-shuffle partitions with continuous indices to a single post-shuffle partition until * adding another pre-shuffle partition would cause the size of a post-shuffle partition to be * greater than the target size. * @@ -89,11 +89,11 @@ class ExchangeCoordinator( extends Logging { // The registered Exchange operators. - private[this] val exchanges = ArrayBuffer[ShuffleExchange]() + private[this] val exchanges = ArrayBuffer[ShuffleExchangeExec]() // This map is used to lookup the post-shuffle ShuffledRowRDD for an Exchange operator. - private[this] val postShuffleRDDs: JMap[ShuffleExchange, ShuffledRowRDD] = - new JHashMap[ShuffleExchange, ShuffledRowRDD](numExchanges) + private[this] val postShuffleRDDs: JMap[ShuffleExchangeExec, ShuffledRowRDD] = + new JHashMap[ShuffleExchangeExec, ShuffledRowRDD](numExchanges) // A boolean that indicates if this coordinator has made decision on how to shuffle data. // This variable will only be updated by doEstimationIfNecessary, which is protected by @@ -101,11 +101,11 @@ class ExchangeCoordinator( @volatile private[this] var estimated: Boolean = false /** - * Registers a [[ShuffleExchange]] operator to this coordinator. This method is only allowed to - * be called in the `doPrepare` method of a [[ShuffleExchange]] operator. + * Registers a [[ShuffleExchangeExec]] operator to this coordinator. This method is only allowed + * to be called in the `doPrepare` method of a [[ShuffleExchangeExec]] operator. */ @GuardedBy("this") - def registerExchange(exchange: ShuffleExchange): Unit = synchronized { + def registerExchange(exchange: ShuffleExchangeExec): Unit = synchronized { exchanges += exchange } @@ -200,7 +200,7 @@ class ExchangeCoordinator( // Make sure we have the expected number of registered Exchange operators. assert(exchanges.length == numExchanges) - val newPostShuffleRDDs = new JHashMap[ShuffleExchange, ShuffledRowRDD](numExchanges) + val newPostShuffleRDDs = new JHashMap[ShuffleExchangeExec, ShuffledRowRDD](numExchanges) // Submit all map stages val shuffleDependencies = ArrayBuffer[ShuffleDependency[Int, InternalRow, InternalRow]]() @@ -255,7 +255,7 @@ class ExchangeCoordinator( } } - def postShuffleRDD(exchange: ShuffleExchange): ShuffledRowRDD = { + def postShuffleRDD(exchange: ShuffleExchangeExec): ShuffledRowRDD = { doEstimationIfNecessary() if (!postShuffleRDDs.containsKey(exchange)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 0d06d83fb2f3c..11c4aa9b4acf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -35,7 +35,7 @@ import org.apache.spark.util.MutablePair /** * Performs a shuffle that will result in the desired `newPartitioning`. */ -case class ShuffleExchange( +case class ShuffleExchangeExec( var newPartitioning: Partitioning, child: SparkPlan, @transient coordinator: Option[ExchangeCoordinator]) extends Exchange { @@ -84,7 +84,7 @@ case class ShuffleExchange( */ private[exchange] def prepareShuffleDependency() : ShuffleDependency[Int, InternalRow, InternalRow] = { - ShuffleExchange.prepareShuffleDependency( + ShuffleExchangeExec.prepareShuffleDependency( child.execute(), child.output, newPartitioning, serializer) } @@ -129,9 +129,9 @@ case class ShuffleExchange( } } -object ShuffleExchange { - def apply(newPartitioning: Partitioning, child: SparkPlan): ShuffleExchange = { - ShuffleExchange(newPartitioning, child, coordinator = Option.empty[ExchangeCoordinator]) +object ShuffleExchangeExec { + def apply(newPartitioning: Partitioning, child: SparkPlan): ShuffleExchangeExec = { + ShuffleExchangeExec(newPartitioning, child, coordinator = Option.empty[ExchangeCoordinator]) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 1f515e29b4af5..13da4b26a5dcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.util.Utils /** @@ -40,7 +40,7 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode protected override def doExecute(): RDD[InternalRow] = { val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit)) val shuffled = new ShuffledRowRDD( - ShuffleExchange.prepareShuffleDependency( + ShuffleExchangeExec.prepareShuffleDependency( locallyLimited, child.output, SinglePartition, serializer)) shuffled.mapPartitionsInternal(_.take(limit)) } @@ -153,7 +153,7 @@ case class TakeOrderedAndProjectExec( } } val shuffled = new ShuffledRowRDD( - ShuffleExchange.prepareShuffleDependency( + ShuffleExchangeExec.prepareShuffleDependency( localTopK, child.output, SinglePartition, serializer)) shuffled.mapPartitions { iter => val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 8e0aae39cabb6..82f879c763c2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.streaming.OutputMode /** @@ -155,7 +155,7 @@ object EnsureStatefulOpPartitioning extends Rule[SparkPlan] { child.execute().getNumPartitions == expectedPartitioning.numPartitions) { child } else { - ShuffleExchange(expectedPartitioning, child) + ShuffleExchangeExec(expectedPartitioning, child) } } so.withNewChildren(children) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 3e4f619431599..1e52445f28fc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} @@ -420,7 +420,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext * Verifies that the plan for `df` contains `expected` number of Exchange operators. */ private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = { - assert(df.queryExecution.executedPlan.collect { case e: ShuffleExchange => e }.size == expected) + assert( + df.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => e }.size == expected) } test("A cached table preserves the partitioning and ordering of its cached SparkPlan") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6178661cf7b2b..0e2f2e5a193e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} import org.apache.spark.sql.execution.{FilterExec, QueryExecution} import org.apache.spark.sql.execution.aggregate.HashAggregateExec -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} @@ -1529,7 +1529,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { fail("Should not have back to back Aggregates") } atFirstAgg = true - case e: ShuffleExchange => atFirstAgg = false + case e: ShuffleExchangeExec => atFirstAgg = false case _ => } } @@ -1710,19 +1710,19 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val plan = join.queryExecution.executedPlan checkAnswer(join, df) assert( - join.queryExecution.executedPlan.collect { case e: ShuffleExchange => true }.size === 1) + join.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => true }.size === 1) assert( join.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size === 1) val broadcasted = broadcast(join) val join2 = join.join(broadcasted, "id").join(broadcasted, "id") checkAnswer(join2, df) assert( - join2.queryExecution.executedPlan.collect { case e: ShuffleExchange => true }.size === 1) + join2.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) assert( join2.queryExecution.executedPlan .collect { case e: BroadcastExchangeExec => true }.size === 1) assert( - join2.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size === 4) + join2.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size == 4) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 5015f3709f131..dace6825ee40e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec} -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -1206,7 +1206,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val agg = cp.groupBy('id % 2).agg(count('id)) agg.queryExecution.executedPlan.collectFirst { - case ShuffleExchange(_, _: RDDScanExec, _) => + case ShuffleExchangeExec(_, _: RDDScanExec, _) => case BroadcastExchangeExec(_, _: RDDScanExec) => }.foreach { _ => fail( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index f1b5e3be5b63f..737eeb0af586e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -21,7 +21,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchangeExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -300,13 +300,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = agg.queryExecution.executedPlan.collect { - case e: ShuffleExchange => e + case e: ShuffleExchangeExec => e } assert(exchanges.length === 1) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 5) case o => @@ -314,7 +314,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case None => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => @@ -351,13 +351,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = join.queryExecution.executedPlan.collect { - case e: ShuffleExchange => e + case e: ShuffleExchangeExec => e } assert(exchanges.length === 2) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 5) case o => @@ -365,7 +365,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case None => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 2) case o => @@ -407,13 +407,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = join.queryExecution.executedPlan.collect { - case e: ShuffleExchange => e + case e: ShuffleExchangeExec => e } assert(exchanges.length === 4) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 5) case o => @@ -459,13 +459,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = join.queryExecution.executedPlan.collect { - case e: ShuffleExchange => e + case e: ShuffleExchangeExec => e } assert(exchanges.length === 3) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 5) case o => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 59eaf4d1c29b7..aac8d56ba6201 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition} -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode import org.apache.spark.sql.test.SharedSQLContext @@ -31,7 +31,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( input.toDF(), - plan => ShuffleExchange(SinglePartition, plan), + plan => ShuffleExchangeExec(SinglePartition, plan), input.map(Row.fromTuple) ) } @@ -81,12 +81,12 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(plan sameResult plan) val part1 = HashPartitioning(output, 1) - val exchange1 = ShuffleExchange(part1, plan) - val exchange2 = ShuffleExchange(part1, plan) + val exchange1 = ShuffleExchangeExec(part1, plan) + val exchange2 = ShuffleExchangeExec(part1, plan) val part2 = HashPartitioning(output, 2) - val exchange3 = ShuffleExchange(part2, plan) + val exchange3 = ShuffleExchangeExec(part2, plan) val part3 = HashPartitioning(output ++ output, 2) - val exchange4 = ShuffleExchange(part3, plan) + val exchange4 = ShuffleExchangeExec(part3, plan) val exchange5 = ReusedExchangeExec(output, exchange4) assert(exchange1 sameResult exchange1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 63e17c7f372b0..86066362da9dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation -import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -214,7 +214,7 @@ class PlannerSuite extends SharedSQLContext { | JOIN tiny ON (small.key = tiny.key) """.stripMargin ).queryExecution.executedPlan.collect { - case exchange: ShuffleExchange => exchange + case exchange: ShuffleExchangeExec => exchange }.length assert(numExchanges === 5) } @@ -229,7 +229,7 @@ class PlannerSuite extends SharedSQLContext { | JOIN tiny ON (normal.key = tiny.key) """.stripMargin ).queryExecution.executedPlan.collect { - case exchange: ShuffleExchange => exchange + case exchange: ShuffleExchangeExec => exchange }.length assert(numExchanges === 5) } @@ -300,7 +300,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") } } @@ -338,7 +338,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") } } @@ -358,7 +358,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.nonEmpty) { fail(s"Exchange should not have been added:\n$outputPlan") } } @@ -381,7 +381,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.nonEmpty) { fail(s"No Exchanges should have been added:\n$outputPlan") } } @@ -391,7 +391,7 @@ class PlannerSuite extends SharedSQLContext { val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) assert(!childPartitioning.satisfies(distribution)) - val inputPlan = ShuffleExchange(finalPartitioning, + val inputPlan = ShuffleExchangeExec(finalPartitioning, DummySparkPlan( children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), @@ -400,7 +400,7 @@ class PlannerSuite extends SharedSQLContext { val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 2) { fail(s"Topmost Exchange should have been eliminated:\n$outputPlan") } } @@ -411,7 +411,7 @@ class PlannerSuite extends SharedSQLContext { val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8) val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) assert(!childPartitioning.satisfies(distribution)) - val inputPlan = ShuffleExchange(finalPartitioning, + val inputPlan = ShuffleExchangeExec(finalPartitioning, DummySparkPlan( children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), @@ -420,7 +420,7 @@ class PlannerSuite extends SharedSQLContext { val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) { fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan") } } @@ -430,7 +430,7 @@ class PlannerSuite extends SharedSQLContext { val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) assert(!childPartitioning.satisfies(distribution)) - val shuffle = ShuffleExchange(finalPartitioning, + val shuffle = ShuffleExchangeExec(finalPartitioning, DummySparkPlan( children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), @@ -449,7 +449,7 @@ class PlannerSuite extends SharedSQLContext { if (outputPlan.collect { case e: ReusedExchangeExec => true }.size != 1) { fail(s"Should re-use the shuffle:\n$outputPlan") } - if (outputPlan.collect { case e: ShuffleExchange => true }.size != 1) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size != 1) { fail(s"Should have only one shuffle:\n$outputPlan") } @@ -459,14 +459,14 @@ class PlannerSuite extends SharedSQLContext { Literal(1) :: Nil, Inner, None, - ShuffleExchange(finalPartitioning, inputPlan), - ShuffleExchange(finalPartitioning, inputPlan)) + ShuffleExchangeExec(finalPartitioning, inputPlan), + ShuffleExchangeExec(finalPartitioning, inputPlan)) val outputPlan2 = ReuseExchange(spark.sessionState.conf).apply(inputPlan2) if (outputPlan2.collect { case e: ReusedExchangeExec => true }.size != 2) { fail(s"Should re-use the two shuffles:\n$outputPlan2") } - if (outputPlan2.collect { case e: ShuffleExchange => true }.size != 2) { + if (outputPlan2.collect { case e: ShuffleExchangeExec => true }.size != 2) { fail(s"Should have only two shuffles:\n$outputPlan") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index eb9e6458fc61c..ab18905e2ddb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.{DataSourceScanExec, SortExec} import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -302,10 +302,10 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { // check existence of shuffle assert( - joinOperator.left.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleLeft, + joinOperator.left.find(_.isInstanceOf[ShuffleExchangeExec]).isDefined == shuffleLeft, s"expected shuffle in plan to be $shuffleLeft but found\n${joinOperator.left}") assert( - joinOperator.right.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleRight, + joinOperator.right.find(_.isInstanceOf[ShuffleExchangeExec]).isDefined == shuffleRight, s"expected shuffle in plan to be $shuffleRight but found\n${joinOperator.right}") // check existence of sort @@ -506,7 +506,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { agged.sort("i", "j"), df1.groupBy("i", "j").agg(max("k")).sort("i", "j")) - assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty) + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchangeExec]).isEmpty) } } @@ -520,7 +520,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { agged.sort("i", "j"), df1.groupBy("i", "j").agg(max("k")).sort("i", "j")) - assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty) + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchangeExec]).isEmpty) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala index 044bb03480aa4..ed9823fbddfda 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} -import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo} import org.apache.spark.sql.test.SharedSQLContext @@ -93,7 +93,7 @@ class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLCont fail(s"Was expecting an exchange but didn't get one in:\n$executed") } assert(exchange.get === - ShuffleExchange(expectedPartitioning(inputPlan.output.take(1)), inputPlan), + ShuffleExchangeExec(expectedPartitioning(inputPlan.output.take(1)), inputPlan), s"Exchange didn't have expected properties:\n${exchange.get}") } else { assert(!executed.children.exists(_.isInstanceOf[Exchange]), From d29d1e87995e02cb57ba3026c945c3cd66bb06e2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 28 Sep 2017 15:59:05 -0700 Subject: [PATCH 640/779] [SPARK-22159][SQL] Make config names consistently end with "enabled". ## What changes were proposed in this pull request? spark.sql.execution.arrow.enable and spark.sql.codegen.aggregate.map.twolevel.enable -> enabled ## How was this patch tested? N/A Author: Reynold Xin Closes #19384 from rxin/SPARK-22159. --- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d00c672487532..358cf62149070 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -668,7 +668,7 @@ object SQLConf { .createWithDefault(40) val ENABLE_TWOLEVEL_AGG_MAP = - buildConf("spark.sql.codegen.aggregate.map.twolevel.enable") + buildConf("spark.sql.codegen.aggregate.map.twolevel.enabled") .internal() .doc("Enable two-level aggregate hash map. When enabled, records will first be " + "inserted/looked-up at a 1st-level, small, fast map, and then fallback to a " + @@ -908,7 +908,7 @@ object SQLConf { .createWithDefault(false) val ARROW_EXECUTION_ENABLE = - buildConf("spark.sql.execution.arrow.enable") + buildConf("spark.sql.execution.arrow.enabled") .internal() .doc("Make use of Apache Arrow for columnar data transfers. Currently available " + "for use with pyspark.sql.DataFrame.toPandas with the following data types: " + From 323806e68f91f3c7521327186a37ddd1436267d0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 28 Sep 2017 21:07:12 -0700 Subject: [PATCH 641/779] [SPARK-22160][SQL] Make sample points per partition (in range partitioner) configurable and bump the default value up to 100 ## What changes were proposed in this pull request? Spark's RangePartitioner hard codes the number of sampling points per partition to be 20. This is sometimes too low. This ticket makes it configurable, via spark.sql.execution.rangeExchange.sampleSizePerPartition, and raises the default in Spark SQL to be 100. ## How was this patch tested? Added a pretty sophisticated test based on chi square test ... Author: Reynold Xin Closes #19387 from rxin/SPARK-22160. --- .../scala/org/apache/spark/Partitioner.scala | 15 ++++- .../apache/spark/sql/internal/SQLConf.scala | 10 +++ .../exchange/ShuffleExchangeExec.scala | 7 +- .../spark/sql/ConfigBehaviorSuite.scala | 66 +++++++++++++++++++ 4 files changed, 95 insertions(+), 3 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 1484f29525a4e..debbd8d7c26c9 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -108,11 +108,21 @@ class HashPartitioner(partitions: Int) extends Partitioner { class RangePartitioner[K : Ordering : ClassTag, V]( partitions: Int, rdd: RDD[_ <: Product2[K, V]], - private var ascending: Boolean = true) + private var ascending: Boolean = true, + val samplePointsPerPartitionHint: Int = 20) extends Partitioner { + // A constructor declared in order to maintain backward compatibility for Java, when we add the + // 4th constructor parameter samplePointsPerPartitionHint. See SPARK-22160. + // This is added to make sure from a bytecode point of view, there is still a 3-arg ctor. + def this(partitions: Int, rdd: RDD[_ <: Product2[K, V]], ascending: Boolean) = { + this(partitions, rdd, ascending, samplePointsPerPartitionHint = 20) + } + // We allow partitions = 0, which happens when sorting an empty RDD under the default settings. require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.") + require(samplePointsPerPartitionHint > 0, + s"Sample points per partition must be greater than 0 but found $samplePointsPerPartitionHint") private var ordering = implicitly[Ordering[K]] @@ -122,7 +132,8 @@ class RangePartitioner[K : Ordering : ClassTag, V]( Array.empty } else { // This is the sample size we need to have roughly balanced output partitions, capped at 1M. - val sampleSize = math.min(20.0 * partitions, 1e6) + // Cast to double to avoid overflowing ints or longs + val sampleSize = math.min(samplePointsPerPartitionHint.toDouble * partitions, 1e6) // Assume the input partitions are roughly balanced and over-sample a little bit. val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 358cf62149070..1a73d168b9b6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -907,6 +907,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION = + buildConf("spark.sql.execution.rangeExchange.sampleSizePerPartition") + .internal() + .doc("Number of points to sample per partition in order to determine the range boundaries" + + " for range partitioning, typically used in global sorting (without limit).") + .intConf + .createWithDefault(100) + val ARROW_EXECUTION_ENABLE = buildConf("spark.sql.execution.arrow.enabled") .internal() @@ -1199,6 +1207,8 @@ class SQLConf extends Serializable with Logging { def supportQuotedRegexColumnName: Boolean = getConf(SUPPORT_QUOTED_REGEX_COLUMN_NAME) + def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION) + def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE) def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 11c4aa9b4acf0..5a1e217082bc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.MutablePair /** @@ -218,7 +219,11 @@ object ShuffleExchangeExec { iter.map(row => mutablePair.update(row.copy(), null)) } implicit val ordering = new LazilyGeneratedOrdering(sortingExpressions, outputAttributes) - new RangePartitioner(numPartitions, rddForSampling, ascending = true) + new RangePartitioner( + numPartitions, + rddForSampling, + ascending = true, + samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) case SinglePartition => new Partitioner { override def numPartitions: Int = 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala new file mode 100644 index 0000000000000..2c1e5db5fd9bb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.commons.math3.stat.inference.ChiSquareTest + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + + +class ConfigBehaviorSuite extends QueryTest with SharedSQLContext { + + import testImplicits._ + + test("SPARK-22160 spark.sql.execution.rangeExchange.sampleSizePerPartition") { + // In this test, we run a sort and compute the histogram for partition size post shuffle. + // With a high sample count, the partition size should be more evenly distributed, and has a + // low chi-sq test value. + // Also the whole code path for range partitioning as implemented should be deterministic + // (it uses the partition id as the seed), so this test shouldn't be flaky. + + val numPartitions = 4 + + def computeChiSquareTest(): Double = { + val n = 10000 + // Trigger a sort + val data = spark.range(0, n, 1, 1).sort('id) + .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect() + + // Compute histogram for the number of records per partition post sort + val dist = data.groupBy(_._1).map(_._2.length.toLong).toArray + assert(dist.length == 4) + + new ChiSquareTest().chiSquare( + Array.fill(numPartitions) { n.toDouble / numPartitions }, + dist) + } + + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString) { + // The default chi-sq value should be low + assert(computeChiSquareTest() < 100) + + withSQLConf(SQLConf.RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION.key -> "1") { + // If we only sample one point, the range boundaries will be pretty bad and the + // chi-sq value would be very high. + assert(computeChiSquareTest() > 1000) + } + } + } + +} From 161ba7eaa4539f0a7f20d9e2a493e0e323ca5249 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 28 Sep 2017 23:14:53 -0700 Subject: [PATCH 642/779] [SPARK-22146] FileNotFoundException while reading ORC files containing special characters ## What changes were proposed in this pull request? Reading ORC files containing special characters like '%' fails with a FileNotFoundException. This PR aims to fix the problem. ## How was this patch tested? Added UT. Author: Marco Gaido Author: Marco Gaido Closes #19368 from mgaido91/SPARK-22146. --- .../apache/spark/sql/hive/orc/OrcFileFormat.scala | 2 +- .../spark/sql/hive/MetastoreDataSourcesSuite.scala | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 4d92a67044373..c76f0ebb36a60 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -58,7 +58,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { OrcFileOperator.readSchema( - files.map(_.getPath.toUri.toString), + files.map(_.getPath.toString), Some(sparkSession.sessionState.newHadoopConf()) ) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 29b0e6c8533ef..f5d41c91270a5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -993,7 +993,6 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv spark.sql("""drop database if exists testdb8156 CASCADE""") } - test("skip hive metadata on table creation") { withTempDir { tempPath => val schema = StructType((1 to 5).map(i => StructField(s"c_$i", StringType))) @@ -1345,6 +1344,17 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } + Seq("orc", "parquet", "csv", "json", "text").foreach { format => + test(s"SPARK-22146: read files containing special characters using $format") { + val nameWithSpecialChars = s"sp&cial%chars" + withTempDir { dir => + val tmpFile = s"$dir/$nameWithSpecialChars" + spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) + spark.read.format(format).load(tmpFile) + } + } + } + private def withDebugMode(f: => Unit): Unit = { val previousValue = sparkSession.sparkContext.conf.get(DEBUG_MODE) try { From 0fa4dbe4f4d7b988be2105b46590b5207f7c8121 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 28 Sep 2017 23:23:30 -0700 Subject: [PATCH 643/779] [SPARK-22141][FOLLOWUP][SQL] Add comments for the order of batches ## What changes were proposed in this pull request? Add comments for specifying the position of batch "Check Cartesian Products", as rxin suggested in https://github.com/apache/spark/pull/19362 . ## How was this patch tested? Unit test Author: Wang Gengliang Closes #19379 from gengliangwang/SPARK-22141-followup. --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a391c513ad384..b9fa39d6dad4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -134,6 +134,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) Batch("LocalRelation", fixedPoint, ConvertToLocalRelation, PropagateEmptyRelation) :: + // The following batch should be executed after batch "Join Reorder" and "LocalRelation". Batch("Check Cartesian Products", Once, CheckCartesianProducts) :: Batch("OptimizeCodegen", Once, @@ -1089,6 +1090,9 @@ object CombineLimits extends Rule[LogicalPlan] { * SELECT * from R, S where R.r = S.s, * the join between R and S is not a cartesian product and therefore should be allowed. * The predicate R.r = S.s is not recognized as a join condition until the ReorderJoin rule. + * + * This rule must be run AFTER the batch "LocalRelation", since a join with empty relation should + * not be a cartesian product. */ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { /** From a2516f41aef68e39df7f6380fd2618cc148a609e Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 29 Sep 2017 08:26:53 +0100 Subject: [PATCH 644/779] [SPARK-22142][BUILD][STREAMING] Move Flume support behind a profile ## What changes were proposed in this pull request? Add 'flume' profile to enable Flume-related integration modules ## How was this patch tested? Existing tests; no functional change Author: Sean Owen Closes #19365 from srowen/SPARK-22142. --- dev/create-release/release-build.sh | 4 ++-- dev/mima | 2 +- dev/scalastyle | 1 + dev/sparktestsupport/modules.py | 20 +++++++++++++++++++- dev/test-dependencies.sh | 2 +- docs/building-spark.md | 6 ++++++ pom.xml | 13 ++++++++++--- project/SparkBuild.scala | 17 +++++++++-------- python/pyspark/streaming/tests.py | 16 +++++++++++++--- 9 files changed, 62 insertions(+), 19 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 8de1d6a37dc25..c548a0a4e4bee 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -84,9 +84,9 @@ MVN="build/mvn --force" # Hive-specific profiles for some builds HIVE_PROFILES="-Phive -Phive-thriftserver" # Profiles for publishing snapshots and release to Maven Central -PUBLISH_PROFILES="-Pmesos -Pyarn $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" +PUBLISH_PROFILES="-Pmesos -Pyarn -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" # Profiles for building binary releases -BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Psparkr" +BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pflume -Psparkr" # Scala 2.11 only profiles for some builds SCALA_2_11_PROFILES="-Pkafka-0-8" # Scala 2.12 only profiles for some builds diff --git a/dev/mima b/dev/mima index fdb21f5007cf2..1e3ca9700bc07 100755 --- a/dev/mima +++ b/dev/mima @@ -24,7 +24,7 @@ set -e FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" -SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" +SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pflume -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | tail -n1)" OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)" diff --git a/dev/scalastyle b/dev/scalastyle index e5aa589869535..89ecc8abd6f8c 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -25,6 +25,7 @@ ERRORS=$(echo -e "q\n" \ -Pmesos \ -Pkafka-0-8 \ -Pyarn \ + -Pflume \ -Phive \ -Phive-thriftserver \ scalastyle test:scalastyle \ diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 50e14b60545af..91d5667ed1f07 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -279,6 +279,12 @@ def __hash__(self): source_file_regexes=[ "external/flume-sink", ], + build_profile_flags=[ + "-Pflume", + ], + environ={ + "ENABLE_FLUME_TESTS": "1" + }, sbt_test_goals=[ "streaming-flume-sink/test", ] @@ -291,6 +297,12 @@ def __hash__(self): source_file_regexes=[ "external/flume", ], + build_profile_flags=[ + "-Pflume", + ], + environ={ + "ENABLE_FLUME_TESTS": "1" + }, sbt_test_goals=[ "streaming-flume/test", ] @@ -302,7 +314,13 @@ def __hash__(self): dependencies=[streaming_flume, streaming_flume_sink], source_file_regexes=[ "external/flume-assembly", - ] + ], + build_profile_flags=[ + "-Pflume", + ], + environ={ + "ENABLE_FLUME_TESTS": "1" + } ) diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index c7714578bd005..58b295d4f6e00 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -29,7 +29,7 @@ export LC_ALL=C # TODO: This would be much nicer to do in SBT, once SBT supports Maven-style resolution. # NOTE: These should match those in the release publishing script -HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Phive" +HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Pflume -Phive" MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.6 diff --git a/docs/building-spark.md b/docs/building-spark.md index 57baa503259c1..e1532de16108d 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -100,6 +100,12 @@ Note: Kafka 0.8 support is deprecated as of Spark 2.3.0. Kafka 0.10 support is still automatically built. +## Building with Flume support + +Apache Flume support must be explicitly enabled with the `flume` profile. + + ./build/mvn -Pflume -DskipTests clean package + ## Building submodules individually It's possible to build Spark sub-modules using the `mvn -pl` option. diff --git a/pom.xml b/pom.xml index 87a468c3a6f55..9fac8b1e53788 100644 --- a/pom.xml +++ b/pom.xml @@ -98,15 +98,13 @@ sql/core sql/hive assembly - external/flume - external/flume-sink - external/flume-assembly examples repl launcher external/kafka-0-10 external/kafka-0-10-assembly external/kafka-0-10-sql + @@ -2583,6 +2581,15 @@ + + flume + + external/flume + external/flume-sink + external/flume-assembly + + + spark-ganglia-lgpl diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index a568d264cb2db..9501eed1e906b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -43,11 +43,8 @@ object BuildCommons { "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10" ).map(ProjectRef(buildLocation, _)) - val streamingProjects@Seq( - streaming, streamingFlumeSink, streamingFlume, streamingKafka010 - ) = Seq( - "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-10" - ).map(ProjectRef(buildLocation, _)) + val streamingProjects@Seq(streaming, streamingKafka010) = + Seq("streaming", "streaming-kafka-0-10").map(ProjectRef(buildLocation, _)) val allProjects@Seq( core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, kvstore, _* @@ -56,9 +53,13 @@ object BuildCommons { "tags", "sketch", "kvstore" ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects - val optionallyEnabledProjects@Seq(mesos, yarn, streamingKafka, sparkGangliaLgpl, - streamingKinesisAsl, dockerIntegrationTests, hadoopCloud) = - Seq("mesos", "yarn", "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", + val optionallyEnabledProjects@Seq(mesos, yarn, + streamingFlumeSink, streamingFlume, + streamingKafka, sparkGangliaLgpl, streamingKinesisAsl, + dockerIntegrationTests, hadoopCloud) = + Seq("mesos", "yarn", + "streaming-flume-sink", "streaming-flume", + "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", "docker-integration-tests", "hadoop-cloud").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) = diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 229cf53e47359..5b86c1cb2c390 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1478,7 +1478,7 @@ def search_kafka_assembly_jar(): ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/package streaming-kafka-0-8-assembly/assembly' or " - "'build/mvn package' before running this test.") + "'build/mvn -Pkafka-0-8 package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1495,7 +1495,7 @@ def search_flume_assembly_jar(): ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " - "'build/mvn package' before running this test.") + "'build/mvn -Pflume package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1516,6 +1516,9 @@ def search_kinesis_asl_assembly_jar(): return jars[0] +# Must be same as the variable and condition defined in modules.py +flume_test_environ_var = "ENABLE_FLUME_TESTS" +are_flume_tests_enabled = os.environ.get(flume_test_environ_var) == '1' # Must be same as the variable and condition defined in modules.py kafka_test_environ_var = "ENABLE_KAFKA_0_8_TESTS" are_kafka_tests_enabled = os.environ.get(kafka_test_environ_var) == '1' @@ -1538,9 +1541,16 @@ def search_kinesis_asl_assembly_jar(): os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, - FlumeStreamTests, FlumePollingStreamTests, StreamingListenerTests] + if are_flume_tests_enabled: + testcases.append(FlumeStreamTests) + testcases.append(FlumePollingStreamTests) + else: + sys.stderr.write( + "Skipped test_flume_stream (enable by setting environment variable %s=1" + % flume_test_environ_var) + if are_kafka_tests_enabled: testcases.append(KafkaStreamTests) else: From ecbe416ab5001b32737966c5a2407597a1dafc32 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 29 Sep 2017 08:04:14 -0700 Subject: [PATCH 645/779] [SPARK-22129][SPARK-22138] Release script improvements ## What changes were proposed in this pull request? Use the GPG_KEY param, fix lsof to non-hardcoded path, remove version swap since it wasn't really needed. Use EXPORT on JAVA_HOME for downstream scripts as well. ## How was this patch tested? Rolled 2.1.2 RC2 Author: Holden Karau Closes #19359 from holdenk/SPARK-22129-fix-signing. --- dev/create-release/release-build.sh | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index c548a0a4e4bee..7e8d5c7075195 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -74,7 +74,7 @@ GIT_REF=${GIT_REF:-master} # Destination directory parent on remote server REMOTE_PARENT_DIR=${REMOTE_PARENT_DIR:-/home/$ASF_USERNAME/public_html} -GPG="gpg --no-tty --batch" +GPG="gpg -u $GPG_KEY --no-tty --batch" NEXUS_ROOT=https://repository.apache.org/service/local/staging NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads BASE_DIR=$(pwd) @@ -125,7 +125,7 @@ else echo "Please set JAVA_HOME correctly." exit 1 else - JAVA_HOME="$JAVA_7_HOME" + export JAVA_HOME="$JAVA_7_HOME" fi fi fi @@ -140,7 +140,7 @@ DEST_DIR_NAME="spark-$SPARK_PACKAGE_VERSION" function LFTP { SSH="ssh -o ConnectTimeout=300 -o StrictHostKeyChecking=no -i $ASF_RSA_KEY" COMMANDS=$(cat < Date: Fri, 29 Sep 2017 08:59:42 -0700 Subject: [PATCH 646/779] [SPARK-22161][SQL] Add Impala-modified TPC-DS queries ## What changes were proposed in this pull request? Added IMPALA-modified TPCDS queries to TPC-DS query suites. - Ref: https://github.com/cloudera/impala-tpcds-kit/tree/master/queries ## How was this patch tested? N/A Author: gatorsmile Closes #19386 from gatorsmile/addImpalaQueries. --- .../resources/tpcds-modifiedQueries/q10.sql | 70 ++++++ .../resources/tpcds-modifiedQueries/q19.sql | 38 +++ .../resources/tpcds-modifiedQueries/q27.sql | 43 ++++ .../resources/tpcds-modifiedQueries/q3.sql | 228 ++++++++++++++++++ .../resources/tpcds-modifiedQueries/q34.sql | 45 ++++ .../resources/tpcds-modifiedQueries/q42.sql | 28 +++ .../resources/tpcds-modifiedQueries/q43.sql | 36 +++ .../resources/tpcds-modifiedQueries/q46.sql | 80 ++++++ .../resources/tpcds-modifiedQueries/q52.sql | 27 +++ .../resources/tpcds-modifiedQueries/q53.sql | 37 +++ .../resources/tpcds-modifiedQueries/q55.sql | 24 ++ .../resources/tpcds-modifiedQueries/q59.sql | 83 +++++++ .../resources/tpcds-modifiedQueries/q63.sql | 29 +++ .../resources/tpcds-modifiedQueries/q65.sql | 58 +++++ .../resources/tpcds-modifiedQueries/q68.sql | 62 +++++ .../resources/tpcds-modifiedQueries/q7.sql | 31 +++ .../resources/tpcds-modifiedQueries/q73.sql | 49 ++++ .../resources/tpcds-modifiedQueries/q79.sql | 59 +++++ .../resources/tpcds-modifiedQueries/q89.sql | 43 ++++ .../resources/tpcds-modifiedQueries/q98.sql | 32 +++ .../tpcds-modifiedQueries/ss_max.sql | 14 ++ .../apache/spark/sql/TPCDSQuerySuite.scala | 26 +- 22 files changed, 1141 insertions(+), 1 deletion(-) create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q10.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q19.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q27.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q3.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q34.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q42.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q43.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q46.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q52.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q53.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q55.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q59.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q63.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q65.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q68.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q7.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q73.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q79.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q89.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/q98.sql create mode 100755 sql/core/src/test/resources/tpcds-modifiedQueries/ss_max.sql diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q10.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q10.sql new file mode 100755 index 0000000000000..79dd3d516e8c7 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q10.sql @@ -0,0 +1,70 @@ +-- start query 10 in stream 0 using template query10.tpl +with +v1 as ( + select + ws_bill_customer_sk as customer_sk + from web_sales, + date_dim + where ws_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4+3 + union all + select + cs_ship_customer_sk as customer_sk + from catalog_sales, + date_dim + where cs_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4+3 +), +v2 as ( + select + ss_customer_sk as customer_sk + from store_sales, + date_dim + where ss_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4+3 +) +select + cd_gender, + cd_marital_status, + cd_education_status, + count(*) cnt1, + cd_purchase_estimate, + count(*) cnt2, + cd_credit_rating, + count(*) cnt3, + cd_dep_count, + count(*) cnt4, + cd_dep_employed_count, + count(*) cnt5, + cd_dep_college_count, + count(*) cnt6 +from customer c +join customer_address ca on (c.c_current_addr_sk = ca.ca_address_sk) +join customer_demographics on (cd_demo_sk = c.c_current_cdemo_sk) +left semi join v1 on (v1.customer_sk = c.c_customer_sk) +left semi join v2 on (v2.customer_sk = c.c_customer_sk) +where + ca_county in ('Walker County','Richland County','Gaines County','Douglas County','Dona Ana County') +group by + cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +order by + cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +limit 100 +-- end query 10 in stream 0 using template query10.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q19.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q19.sql new file mode 100755 index 0000000000000..1799827762916 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q19.sql @@ -0,0 +1,38 @@ +-- start query 19 in stream 0 using template query19.tpl +select + i_brand_id brand_id, + i_brand brand, + i_manufact_id, + i_manufact, + sum(ss_ext_sales_price) ext_price +from + date_dim, + store_sales, + item, + customer, + customer_address, + store +where + d_date_sk = ss_sold_date_sk + and ss_item_sk = i_item_sk + and i_manager_id = 7 + and d_moy = 11 + and d_year = 1999 + and ss_customer_sk = c_customer_sk + and c_current_addr_sk = ca_address_sk + and substr(ca_zip, 1, 5) <> substr(s_zip, 1, 5) + and ss_store_sk = s_store_sk + and ss_sold_date_sk between 2451484 and 2451513 -- partition key filter +group by + i_brand, + i_brand_id, + i_manufact_id, + i_manufact +order by + ext_price desc, + i_brand, + i_brand_id, + i_manufact_id, + i_manufact +limit 100 +-- end query 19 in stream 0 using template query19.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q27.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q27.sql new file mode 100755 index 0000000000000..dedbc62a2ab2e --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q27.sql @@ -0,0 +1,43 @@ +-- start query 27 in stream 0 using template query27.tpl + with results as + (select i_item_id, + s_state, + ss_quantity agg1, + ss_list_price agg2, + ss_coupon_amt agg3, + ss_sales_price agg4 + --0 as g_state, + --avg(ss_quantity) agg1, + --avg(ss_list_price) agg2, + --avg(ss_coupon_amt) agg3, + --avg(ss_sales_price) agg4 + from store_sales, customer_demographics, date_dim, store, item + where ss_sold_date_sk = d_date_sk and + ss_sold_date_sk between 2451545 and 2451910 and + ss_item_sk = i_item_sk and + ss_store_sk = s_store_sk and + ss_cdemo_sk = cd_demo_sk and + cd_gender = 'F' and + cd_marital_status = 'D' and + cd_education_status = 'Primary' and + d_year = 2000 and + s_state in ('TN','AL', 'SD', 'SD', 'SD', 'SD') + --group by i_item_id, s_state + ) + + select i_item_id, + s_state, g_state, agg1, agg2, agg3, agg4 + from ( + select i_item_id, s_state, 0 as g_state, avg(agg1) agg1, avg(agg2) agg2, avg(agg3) agg3, avg(agg4) agg4 from results + group by i_item_id, s_state + union all + select i_item_id, NULL AS s_state, 1 AS g_state, avg(agg1) agg1, avg(agg2) agg2, avg(agg3) agg3, + avg(agg4) agg4 from results + group by i_item_id + union all + select NULL AS i_item_id, NULL as s_state, 1 as g_state, avg(agg1) agg1, avg(agg2) agg2, avg(agg3) agg3, + avg(agg4) agg4 from results + ) foo + order by i_item_id, s_state + limit 100 +-- end query 27 in stream 0 using template query27.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q3.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q3.sql new file mode 100755 index 0000000000000..35b0a20f80a4e --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q3.sql @@ -0,0 +1,228 @@ +-- start query 3 in stream 0 using template query3.tpl +select + dt.d_year, + item.i_brand_id brand_id, + item.i_brand brand, + sum(ss_net_profit) sum_agg +from + date_dim dt, + store_sales, + item +where + dt.d_date_sk = store_sales.ss_sold_date_sk + and store_sales.ss_item_sk = item.i_item_sk + and item.i_manufact_id = 436 + and dt.d_moy = 12 + -- partition key filters + and ( +ss_sold_date_sk between 2415355 and 2415385 +or ss_sold_date_sk between 2415720 and 2415750 +or ss_sold_date_sk between 2416085 and 2416115 +or ss_sold_date_sk between 2416450 and 2416480 +or ss_sold_date_sk between 2416816 and 2416846 +or ss_sold_date_sk between 2417181 and 2417211 +or ss_sold_date_sk between 2417546 and 2417576 +or ss_sold_date_sk between 2417911 and 2417941 +or ss_sold_date_sk between 2418277 and 2418307 +or ss_sold_date_sk between 2418642 and 2418672 +or ss_sold_date_sk between 2419007 and 2419037 +or ss_sold_date_sk between 2419372 and 2419402 +or ss_sold_date_sk between 2419738 and 2419768 +or ss_sold_date_sk between 2420103 and 2420133 +or ss_sold_date_sk between 2420468 and 2420498 +or ss_sold_date_sk between 2420833 and 2420863 +or ss_sold_date_sk between 2421199 and 2421229 +or ss_sold_date_sk between 2421564 and 2421594 +or ss_sold_date_sk between 2421929 and 2421959 +or ss_sold_date_sk between 2422294 and 2422324 +or ss_sold_date_sk between 2422660 and 2422690 +or ss_sold_date_sk between 2423025 and 2423055 +or ss_sold_date_sk between 2423390 and 2423420 +or ss_sold_date_sk between 2423755 and 2423785 +or ss_sold_date_sk between 2424121 and 2424151 +or ss_sold_date_sk between 2424486 and 2424516 +or ss_sold_date_sk between 2424851 and 2424881 +or ss_sold_date_sk between 2425216 and 2425246 +or ss_sold_date_sk between 2425582 and 2425612 +or ss_sold_date_sk between 2425947 and 2425977 +or ss_sold_date_sk between 2426312 and 2426342 +or ss_sold_date_sk between 2426677 and 2426707 +or ss_sold_date_sk between 2427043 and 2427073 +or ss_sold_date_sk between 2427408 and 2427438 +or ss_sold_date_sk between 2427773 and 2427803 +or ss_sold_date_sk between 2428138 and 2428168 +or ss_sold_date_sk between 2428504 and 2428534 +or ss_sold_date_sk between 2428869 and 2428899 +or ss_sold_date_sk between 2429234 and 2429264 +or ss_sold_date_sk between 2429599 and 2429629 +or ss_sold_date_sk between 2429965 and 2429995 +or ss_sold_date_sk between 2430330 and 2430360 +or ss_sold_date_sk between 2430695 and 2430725 +or ss_sold_date_sk between 2431060 and 2431090 +or ss_sold_date_sk between 2431426 and 2431456 +or ss_sold_date_sk between 2431791 and 2431821 +or ss_sold_date_sk between 2432156 and 2432186 +or ss_sold_date_sk between 2432521 and 2432551 +or ss_sold_date_sk between 2432887 and 2432917 +or ss_sold_date_sk between 2433252 and 2433282 +or ss_sold_date_sk between 2433617 and 2433647 +or ss_sold_date_sk between 2433982 and 2434012 +or ss_sold_date_sk between 2434348 and 2434378 +or ss_sold_date_sk between 2434713 and 2434743 +or ss_sold_date_sk between 2435078 and 2435108 +or ss_sold_date_sk between 2435443 and 2435473 +or ss_sold_date_sk between 2435809 and 2435839 +or ss_sold_date_sk between 2436174 and 2436204 +or ss_sold_date_sk between 2436539 and 2436569 +or ss_sold_date_sk between 2436904 and 2436934 +or ss_sold_date_sk between 2437270 and 2437300 +or ss_sold_date_sk between 2437635 and 2437665 +or ss_sold_date_sk between 2438000 and 2438030 +or ss_sold_date_sk between 2438365 and 2438395 +or ss_sold_date_sk between 2438731 and 2438761 +or ss_sold_date_sk between 2439096 and 2439126 +or ss_sold_date_sk between 2439461 and 2439491 +or ss_sold_date_sk between 2439826 and 2439856 +or ss_sold_date_sk between 2440192 and 2440222 +or ss_sold_date_sk between 2440557 and 2440587 +or ss_sold_date_sk between 2440922 and 2440952 +or ss_sold_date_sk between 2441287 and 2441317 +or ss_sold_date_sk between 2441653 and 2441683 +or ss_sold_date_sk between 2442018 and 2442048 +or ss_sold_date_sk between 2442383 and 2442413 +or ss_sold_date_sk between 2442748 and 2442778 +or ss_sold_date_sk between 2443114 and 2443144 +or ss_sold_date_sk between 2443479 and 2443509 +or ss_sold_date_sk between 2443844 and 2443874 +or ss_sold_date_sk between 2444209 and 2444239 +or ss_sold_date_sk between 2444575 and 2444605 +or ss_sold_date_sk between 2444940 and 2444970 +or ss_sold_date_sk between 2445305 and 2445335 +or ss_sold_date_sk between 2445670 and 2445700 +or ss_sold_date_sk between 2446036 and 2446066 +or ss_sold_date_sk between 2446401 and 2446431 +or ss_sold_date_sk between 2446766 and 2446796 +or ss_sold_date_sk between 2447131 and 2447161 +or ss_sold_date_sk between 2447497 and 2447527 +or ss_sold_date_sk between 2447862 and 2447892 +or ss_sold_date_sk between 2448227 and 2448257 +or ss_sold_date_sk between 2448592 and 2448622 +or ss_sold_date_sk between 2448958 and 2448988 +or ss_sold_date_sk between 2449323 and 2449353 +or ss_sold_date_sk between 2449688 and 2449718 +or ss_sold_date_sk between 2450053 and 2450083 +or ss_sold_date_sk between 2450419 and 2450449 +or ss_sold_date_sk between 2450784 and 2450814 +or ss_sold_date_sk between 2451149 and 2451179 +or ss_sold_date_sk between 2451514 and 2451544 +or ss_sold_date_sk between 2451880 and 2451910 +or ss_sold_date_sk between 2452245 and 2452275 +or ss_sold_date_sk between 2452610 and 2452640 +or ss_sold_date_sk between 2452975 and 2453005 +or ss_sold_date_sk between 2453341 and 2453371 +or ss_sold_date_sk between 2453706 and 2453736 +or ss_sold_date_sk between 2454071 and 2454101 +or ss_sold_date_sk between 2454436 and 2454466 +or ss_sold_date_sk between 2454802 and 2454832 +or ss_sold_date_sk between 2455167 and 2455197 +or ss_sold_date_sk between 2455532 and 2455562 +or ss_sold_date_sk between 2455897 and 2455927 +or ss_sold_date_sk between 2456263 and 2456293 +or ss_sold_date_sk between 2456628 and 2456658 +or ss_sold_date_sk between 2456993 and 2457023 +or ss_sold_date_sk between 2457358 and 2457388 +or ss_sold_date_sk between 2457724 and 2457754 +or ss_sold_date_sk between 2458089 and 2458119 +or ss_sold_date_sk between 2458454 and 2458484 +or ss_sold_date_sk between 2458819 and 2458849 +or ss_sold_date_sk between 2459185 and 2459215 +or ss_sold_date_sk between 2459550 and 2459580 +or ss_sold_date_sk between 2459915 and 2459945 +or ss_sold_date_sk between 2460280 and 2460310 +or ss_sold_date_sk between 2460646 and 2460676 +or ss_sold_date_sk between 2461011 and 2461041 +or ss_sold_date_sk between 2461376 and 2461406 +or ss_sold_date_sk between 2461741 and 2461771 +or ss_sold_date_sk between 2462107 and 2462137 +or ss_sold_date_sk between 2462472 and 2462502 +or ss_sold_date_sk between 2462837 and 2462867 +or ss_sold_date_sk between 2463202 and 2463232 +or ss_sold_date_sk between 2463568 and 2463598 +or ss_sold_date_sk between 2463933 and 2463963 +or ss_sold_date_sk between 2464298 and 2464328 +or ss_sold_date_sk between 2464663 and 2464693 +or ss_sold_date_sk between 2465029 and 2465059 +or ss_sold_date_sk between 2465394 and 2465424 +or ss_sold_date_sk between 2465759 and 2465789 +or ss_sold_date_sk between 2466124 and 2466154 +or ss_sold_date_sk between 2466490 and 2466520 +or ss_sold_date_sk between 2466855 and 2466885 +or ss_sold_date_sk between 2467220 and 2467250 +or ss_sold_date_sk between 2467585 and 2467615 +or ss_sold_date_sk between 2467951 and 2467981 +or ss_sold_date_sk between 2468316 and 2468346 +or ss_sold_date_sk between 2468681 and 2468711 +or ss_sold_date_sk between 2469046 and 2469076 +or ss_sold_date_sk between 2469412 and 2469442 +or ss_sold_date_sk between 2469777 and 2469807 +or ss_sold_date_sk between 2470142 and 2470172 +or ss_sold_date_sk between 2470507 and 2470537 +or ss_sold_date_sk between 2470873 and 2470903 +or ss_sold_date_sk between 2471238 and 2471268 +or ss_sold_date_sk between 2471603 and 2471633 +or ss_sold_date_sk between 2471968 and 2471998 +or ss_sold_date_sk between 2472334 and 2472364 +or ss_sold_date_sk between 2472699 and 2472729 +or ss_sold_date_sk between 2473064 and 2473094 +or ss_sold_date_sk between 2473429 and 2473459 +or ss_sold_date_sk between 2473795 and 2473825 +or ss_sold_date_sk between 2474160 and 2474190 +or ss_sold_date_sk between 2474525 and 2474555 +or ss_sold_date_sk between 2474890 and 2474920 +or ss_sold_date_sk between 2475256 and 2475286 +or ss_sold_date_sk between 2475621 and 2475651 +or ss_sold_date_sk between 2475986 and 2476016 +or ss_sold_date_sk between 2476351 and 2476381 +or ss_sold_date_sk between 2476717 and 2476747 +or ss_sold_date_sk between 2477082 and 2477112 +or ss_sold_date_sk between 2477447 and 2477477 +or ss_sold_date_sk between 2477812 and 2477842 +or ss_sold_date_sk between 2478178 and 2478208 +or ss_sold_date_sk between 2478543 and 2478573 +or ss_sold_date_sk between 2478908 and 2478938 +or ss_sold_date_sk between 2479273 and 2479303 +or ss_sold_date_sk between 2479639 and 2479669 +or ss_sold_date_sk between 2480004 and 2480034 +or ss_sold_date_sk between 2480369 and 2480399 +or ss_sold_date_sk between 2480734 and 2480764 +or ss_sold_date_sk between 2481100 and 2481130 +or ss_sold_date_sk between 2481465 and 2481495 +or ss_sold_date_sk between 2481830 and 2481860 +or ss_sold_date_sk between 2482195 and 2482225 +or ss_sold_date_sk between 2482561 and 2482591 +or ss_sold_date_sk between 2482926 and 2482956 +or ss_sold_date_sk between 2483291 and 2483321 +or ss_sold_date_sk between 2483656 and 2483686 +or ss_sold_date_sk between 2484022 and 2484052 +or ss_sold_date_sk between 2484387 and 2484417 +or ss_sold_date_sk between 2484752 and 2484782 +or ss_sold_date_sk between 2485117 and 2485147 +or ss_sold_date_sk between 2485483 and 2485513 +or ss_sold_date_sk between 2485848 and 2485878 +or ss_sold_date_sk between 2486213 and 2486243 +or ss_sold_date_sk between 2486578 and 2486608 +or ss_sold_date_sk between 2486944 and 2486974 +or ss_sold_date_sk between 2487309 and 2487339 +or ss_sold_date_sk between 2487674 and 2487704 +or ss_sold_date_sk between 2488039 and 2488069 +) +group by + dt.d_year, + item.i_brand, + item.i_brand_id +order by + dt.d_year, + sum_agg desc, + brand_id +limit 100 +-- end query 3 in stream 0 using template query3.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q34.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q34.sql new file mode 100755 index 0000000000000..d11696e5e0c34 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q34.sql @@ -0,0 +1,45 @@ +-- start query 34 in stream 0 using template query34.tpl +select + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag, + ss_ticket_number, + cnt +from + (select + ss_ticket_number, + ss_customer_sk, + count(*) cnt + from + store_sales, + date_dim, + store, + household_demographics + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and (date_dim.d_dom between 1 and 3 + or date_dim.d_dom between 25 and 28) + and (household_demographics.hd_buy_potential = '>10000' + or household_demographics.hd_buy_potential = 'Unknown') + and household_demographics.hd_vehicle_count > 0 + and (case when household_demographics.hd_vehicle_count > 0 then household_demographics.hd_dep_count / household_demographics.hd_vehicle_count else null end) > 1.2 + and date_dim.d_year in (1998, 1998 + 1, 1998 + 2) + and store.s_county in ('Saginaw County', 'Sumner County', 'Appanoose County', 'Daviess County', 'Fairfield County', 'Raleigh County', 'Ziebach County', 'Williamson County') + and ss_sold_date_sk between 2450816 and 2451910 -- partition key filter + group by + ss_ticket_number, + ss_customer_sk + ) dn, + customer +where + ss_customer_sk = c_customer_sk + and cnt between 15 and 20 +order by + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag desc +-- end query 34 in stream 0 using template query34.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q42.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q42.sql new file mode 100755 index 0000000000000..b6332a8afbebe --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q42.sql @@ -0,0 +1,28 @@ +-- start query 42 in stream 0 using template query42.tpl +select + dt.d_year, + item.i_category_id, + item.i_category, + sum(ss_ext_sales_price) +from + date_dim dt, + store_sales, + item +where + dt.d_date_sk = store_sales.ss_sold_date_sk + and store_sales.ss_item_sk = item.i_item_sk + and item.i_manager_id = 1 + and dt.d_moy = 12 + and dt.d_year = 1998 + and ss_sold_date_sk between 2451149 and 2451179 -- partition key filter +group by + dt.d_year, + item.i_category_id, + item.i_category +order by + sum(ss_ext_sales_price) desc, + dt.d_year, + item.i_category_id, + item.i_category +limit 100 +-- end query 42 in stream 0 using template query42.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q43.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q43.sql new file mode 100755 index 0000000000000..cc2040b2fdb7c --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q43.sql @@ -0,0 +1,36 @@ +-- start query 43 in stream 0 using template query43.tpl +select + s_store_name, + s_store_id, + sum(case when (d_day_name = 'Sunday') then ss_sales_price else null end) sun_sales, + sum(case when (d_day_name = 'Monday') then ss_sales_price else null end) mon_sales, + sum(case when (d_day_name = 'Tuesday') then ss_sales_price else null end) tue_sales, + sum(case when (d_day_name = 'Wednesday') then ss_sales_price else null end) wed_sales, + sum(case when (d_day_name = 'Thursday') then ss_sales_price else null end) thu_sales, + sum(case when (d_day_name = 'Friday') then ss_sales_price else null end) fri_sales, + sum(case when (d_day_name = 'Saturday') then ss_sales_price else null end) sat_sales +from + date_dim, + store_sales, + store +where + d_date_sk = ss_sold_date_sk + and s_store_sk = ss_store_sk + and s_gmt_offset = -5 + and d_year = 1998 + and ss_sold_date_sk between 2450816 and 2451179 -- partition key filter +group by + s_store_name, + s_store_id +order by + s_store_name, + s_store_id, + sun_sales, + mon_sales, + tue_sales, + wed_sales, + thu_sales, + fri_sales, + sat_sales +limit 100 +-- end query 43 in stream 0 using template query43.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q46.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q46.sql new file mode 100755 index 0000000000000..52b7ba4f4b86b --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q46.sql @@ -0,0 +1,80 @@ +-- start query 46 in stream 0 using template query46.tpl +select + c_last_name, + c_first_name, + ca_city, + bought_city, + ss_ticket_number, + amt, + profit +from + (select + ss_ticket_number, + ss_customer_sk, + ca_city bought_city, + sum(ss_coupon_amt) amt, + sum(ss_net_profit) profit + from + store_sales, + date_dim, + store, + household_demographics, + customer_address + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and store_sales.ss_addr_sk = customer_address.ca_address_sk + and (household_demographics.hd_dep_count = 5 + or household_demographics.hd_vehicle_count = 3) + and date_dim.d_dow in (6, 0) + and date_dim.d_year in (1999, 1999 + 1, 1999 + 2) + and store.s_city in ('Midway', 'Concord', 'Spring Hill', 'Brownsville', 'Greenville') + -- partition key filter + and ss_sold_date_sk in (2451181, 2451182, 2451188, 2451189, 2451195, 2451196, 2451202, 2451203, 2451209, 2451210, 2451216, 2451217, + 2451223, 2451224, 2451230, 2451231, 2451237, 2451238, 2451244, 2451245, 2451251, 2451252, 2451258, 2451259, + 2451265, 2451266, 2451272, 2451273, 2451279, 2451280, 2451286, 2451287, 2451293, 2451294, 2451300, 2451301, + 2451307, 2451308, 2451314, 2451315, 2451321, 2451322, 2451328, 2451329, 2451335, 2451336, 2451342, 2451343, + 2451349, 2451350, 2451356, 2451357, 2451363, 2451364, 2451370, 2451371, 2451377, 2451378, 2451384, 2451385, + 2451391, 2451392, 2451398, 2451399, 2451405, 2451406, 2451412, 2451413, 2451419, 2451420, 2451426, 2451427, + 2451433, 2451434, 2451440, 2451441, 2451447, 2451448, 2451454, 2451455, 2451461, 2451462, 2451468, 2451469, + 2451475, 2451476, 2451482, 2451483, 2451489, 2451490, 2451496, 2451497, 2451503, 2451504, 2451510, 2451511, + 2451517, 2451518, 2451524, 2451525, 2451531, 2451532, 2451538, 2451539, 2451545, 2451546, 2451552, 2451553, + 2451559, 2451560, 2451566, 2451567, 2451573, 2451574, 2451580, 2451581, 2451587, 2451588, 2451594, 2451595, + 2451601, 2451602, 2451608, 2451609, 2451615, 2451616, 2451622, 2451623, 2451629, 2451630, 2451636, 2451637, + 2451643, 2451644, 2451650, 2451651, 2451657, 2451658, 2451664, 2451665, 2451671, 2451672, 2451678, 2451679, + 2451685, 2451686, 2451692, 2451693, 2451699, 2451700, 2451706, 2451707, 2451713, 2451714, 2451720, 2451721, + 2451727, 2451728, 2451734, 2451735, 2451741, 2451742, 2451748, 2451749, 2451755, 2451756, 2451762, 2451763, + 2451769, 2451770, 2451776, 2451777, 2451783, 2451784, 2451790, 2451791, 2451797, 2451798, 2451804, 2451805, + 2451811, 2451812, 2451818, 2451819, 2451825, 2451826, 2451832, 2451833, 2451839, 2451840, 2451846, 2451847, + 2451853, 2451854, 2451860, 2451861, 2451867, 2451868, 2451874, 2451875, 2451881, 2451882, 2451888, 2451889, + 2451895, 2451896, 2451902, 2451903, 2451909, 2451910, 2451916, 2451917, 2451923, 2451924, 2451930, 2451931, + 2451937, 2451938, 2451944, 2451945, 2451951, 2451952, 2451958, 2451959, 2451965, 2451966, 2451972, 2451973, + 2451979, 2451980, 2451986, 2451987, 2451993, 2451994, 2452000, 2452001, 2452007, 2452008, 2452014, 2452015, + 2452021, 2452022, 2452028, 2452029, 2452035, 2452036, 2452042, 2452043, 2452049, 2452050, 2452056, 2452057, + 2452063, 2452064, 2452070, 2452071, 2452077, 2452078, 2452084, 2452085, 2452091, 2452092, 2452098, 2452099, + 2452105, 2452106, 2452112, 2452113, 2452119, 2452120, 2452126, 2452127, 2452133, 2452134, 2452140, 2452141, + 2452147, 2452148, 2452154, 2452155, 2452161, 2452162, 2452168, 2452169, 2452175, 2452176, 2452182, 2452183, + 2452189, 2452190, 2452196, 2452197, 2452203, 2452204, 2452210, 2452211, 2452217, 2452218, 2452224, 2452225, + 2452231, 2452232, 2452238, 2452239, 2452245, 2452246, 2452252, 2452253, 2452259, 2452260, 2452266, 2452267, + 2452273, 2452274) + group by + ss_ticket_number, + ss_customer_sk, + ss_addr_sk, + ca_city + ) dn, + customer, + customer_address current_addr +where + ss_customer_sk = c_customer_sk + and customer.c_current_addr_sk = current_addr.ca_address_sk + and current_addr.ca_city <> bought_city +order by + c_last_name, + c_first_name, + ca_city, + bought_city, + ss_ticket_number +limit 100 +-- end query 46 in stream 0 using template query46.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q52.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q52.sql new file mode 100755 index 0000000000000..a510eefb13e17 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q52.sql @@ -0,0 +1,27 @@ +-- start query 52 in stream 0 using template query52.tpl +select + dt.d_year, + item.i_brand_id brand_id, + item.i_brand brand, + sum(ss_ext_sales_price) ext_price +from + date_dim dt, + store_sales, + item +where + dt.d_date_sk = store_sales.ss_sold_date_sk + and store_sales.ss_item_sk = item.i_item_sk + and item.i_manager_id = 1 + and dt.d_moy = 12 + and dt.d_year = 1998 + and ss_sold_date_sk between 2451149 and 2451179 -- added for partition pruning +group by + dt.d_year, + item.i_brand, + item.i_brand_id +order by + dt.d_year, + ext_price desc, + brand_id +limit 100 +-- end query 52 in stream 0 using template query52.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q53.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q53.sql new file mode 100755 index 0000000000000..fb7bb75183858 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q53.sql @@ -0,0 +1,37 @@ +-- start query 53 in stream 0 using template query53.tpl +select + * +from + (select + i_manufact_id, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) over (partition by i_manufact_id) avg_quarterly_sales + from + item, + store_sales, + date_dim, + store + where + ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and ss_store_sk = s_store_sk + and d_month_seq in (1212, 1212 + 1, 1212 + 2, 1212 + 3, 1212 + 4, 1212 + 5, 1212 + 6, 1212 + 7, 1212 + 8, 1212 + 9, 1212 + 10, 1212 + 11) + and ((i_category in ('Books', 'Children', 'Electronics') + and i_class in ('personal', 'portable', 'reference', 'self-help') + and i_brand in ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')) + or (i_category in ('Women', 'Music', 'Men') + and i_class in ('accessories', 'classical', 'fragrances', 'pants') + and i_brand in ('amalgimporto #1', 'edu packscholar #1', 'exportiimporto #1', 'importoamalg #1'))) + and ss_sold_date_sk between 2451911 and 2452275 -- partition key filter + group by + i_manufact_id, + d_qoy + ) tmp1 +where + case when avg_quarterly_sales > 0 then abs (sum_sales - avg_quarterly_sales) / avg_quarterly_sales else null end > 0.1 +order by + avg_quarterly_sales, + sum_sales, + i_manufact_id +limit 100 +-- end query 53 in stream 0 using template query53.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q55.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q55.sql new file mode 100755 index 0000000000000..47b1f0292d901 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q55.sql @@ -0,0 +1,24 @@ +-- start query 55 in stream 0 using template query55.tpl +select + i_brand_id brand_id, + i_brand brand, + sum(ss_ext_sales_price) ext_price +from + date_dim, + store_sales, + item +where + d_date_sk = ss_sold_date_sk + and ss_item_sk = i_item_sk + and i_manager_id = 48 + and d_moy = 11 + and d_year = 2001 + and ss_sold_date_sk between 2452215 and 2452244 +group by + i_brand, + i_brand_id +order by + ext_price desc, + i_brand_id +limit 100 +-- end query 55 in stream 0 using template query55.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q59.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q59.sql new file mode 100755 index 0000000000000..3d5c4e9d64419 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q59.sql @@ -0,0 +1,83 @@ +-- start query 59 in stream 0 using template query59.tpl +with + wss as + (select + d_week_seq, + ss_store_sk, + sum(case when (d_day_name = 'Sunday') then ss_sales_price else null end) sun_sales, + sum(case when (d_day_name = 'Monday') then ss_sales_price else null end) mon_sales, + sum(case when (d_day_name = 'Tuesday') then ss_sales_price else null end) tue_sales, + sum(case when (d_day_name = 'Wednesday') then ss_sales_price else null end) wed_sales, + sum(case when (d_day_name = 'Thursday') then ss_sales_price else null end) thu_sales, + sum(case when (d_day_name = 'Friday') then ss_sales_price else null end) fri_sales, + sum(case when (d_day_name = 'Saturday') then ss_sales_price else null end) sat_sales + from + store_sales, + date_dim + where + d_date_sk = ss_sold_date_sk + group by + d_week_seq, + ss_store_sk + ) +select + s_store_name1, + s_store_id1, + d_week_seq1, + sun_sales1 / sun_sales2, + mon_sales1 / mon_sales2, + tue_sales1 / tue_sales1, + wed_sales1 / wed_sales2, + thu_sales1 / thu_sales2, + fri_sales1 / fri_sales2, + sat_sales1 / sat_sales2 +from + (select + s_store_name s_store_name1, + wss.d_week_seq d_week_seq1, + s_store_id s_store_id1, + sun_sales sun_sales1, + mon_sales mon_sales1, + tue_sales tue_sales1, + wed_sales wed_sales1, + thu_sales thu_sales1, + fri_sales fri_sales1, + sat_sales sat_sales1 + from + wss, + store, + date_dim d + where + d.d_week_seq = wss.d_week_seq + and ss_store_sk = s_store_sk + and d_month_seq between 1185 and 1185 + 11 + ) y, + (select + s_store_name s_store_name2, + wss.d_week_seq d_week_seq2, + s_store_id s_store_id2, + sun_sales sun_sales2, + mon_sales mon_sales2, + tue_sales tue_sales2, + wed_sales wed_sales2, + thu_sales thu_sales2, + fri_sales fri_sales2, + sat_sales sat_sales2 + from + wss, + store, + date_dim d + where + d.d_week_seq = wss.d_week_seq + and ss_store_sk = s_store_sk + and d_month_seq between 1185 + 12 and 1185 + 23 + ) x +where + s_store_id1 = s_store_id2 + and d_week_seq1 = d_week_seq2 - 52 +order by + s_store_name1, + s_store_id1, + d_week_seq1 +limit 100 +-- end query 59 in stream 0 using template query59.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q63.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q63.sql new file mode 100755 index 0000000000000..b71199ab17d0b --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q63.sql @@ -0,0 +1,29 @@ +-- start query 63 in stream 0 using template query63.tpl +select * +from (select i_manager_id + ,sum(ss_sales_price) sum_sales + ,avg(sum(ss_sales_price)) over (partition by i_manager_id) avg_monthly_sales + from item + ,store_sales + ,date_dim + ,store + where ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and ss_sold_date_sk between 2452123 and 2452487 + and ss_store_sk = s_store_sk + and d_month_seq in (1219,1219+1,1219+2,1219+3,1219+4,1219+5,1219+6,1219+7,1219+8,1219+9,1219+10,1219+11) + and (( i_category in ('Books','Children','Electronics') + and i_class in ('personal','portable','reference','self-help') + and i_brand in ('scholaramalgamalg #14','scholaramalgamalg #7', + 'exportiunivamalg #9','scholaramalgamalg #9')) + or( i_category in ('Women','Music','Men') + and i_class in ('accessories','classical','fragrances','pants') + and i_brand in ('amalgimporto #1','edu packscholar #1','exportiimporto #1', + 'importoamalg #1'))) +group by i_manager_id, d_moy) tmp1 +where case when avg_monthly_sales > 0 then abs (sum_sales - avg_monthly_sales) / avg_monthly_sales else null end > 0.1 +order by i_manager_id + ,avg_monthly_sales + ,sum_sales +limit 100 +-- end query 63 in stream 0 using template query63.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q65.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q65.sql new file mode 100755 index 0000000000000..7344feeff6a9f --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q65.sql @@ -0,0 +1,58 @@ +-- start query 65 in stream 0 using template query65.tpl +select + s_store_name, + i_item_desc, + sc.revenue, + i_current_price, + i_wholesale_cost, + i_brand +from + store, + item, + (select + ss_store_sk, + avg(revenue) as ave + from + (select + ss_store_sk, + ss_item_sk, + sum(ss_sales_price) as revenue + from + store_sales, + date_dim + where + ss_sold_date_sk = d_date_sk + and d_month_seq between 1212 and 1212 + 11 + and ss_sold_date_sk between 2451911 and 2452275 -- partition key filter + group by + ss_store_sk, + ss_item_sk + ) sa + group by + ss_store_sk + ) sb, + (select + ss_store_sk, + ss_item_sk, + sum(ss_sales_price) as revenue + from + store_sales, + date_dim + where + ss_sold_date_sk = d_date_sk + and d_month_seq between 1212 and 1212 + 11 + and ss_sold_date_sk between 2451911 and 2452275 -- partition key filter + group by + ss_store_sk, + ss_item_sk + ) sc +where + sb.ss_store_sk = sc.ss_store_sk + and sc.revenue <= 0.1 * sb.ave + and s_store_sk = sc.ss_store_sk + and i_item_sk = sc.ss_item_sk +order by + s_store_name, + i_item_desc +limit 100 +-- end query 65 in stream 0 using template query65.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q68.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q68.sql new file mode 100755 index 0000000000000..94df4b3f57a90 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q68.sql @@ -0,0 +1,62 @@ +-- start query 68 in stream 0 using template query68.tpl +-- changed to match exact same partitions in original query +select + c_last_name, + c_first_name, + ca_city, + bought_city, + ss_ticket_number, + extended_price, + extended_tax, + list_price +from + (select + ss_ticket_number, + ss_customer_sk, + ca_city bought_city, + sum(ss_ext_sales_price) extended_price, + sum(ss_ext_list_price) list_price, + sum(ss_ext_tax) extended_tax + from + store_sales, + date_dim, + store, + household_demographics, + customer_address + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and store_sales.ss_addr_sk = customer_address.ca_address_sk + and date_dim.d_dom between 1 and 2 + and (household_demographics.hd_dep_count = 5 + or household_demographics.hd_vehicle_count = 3) + and date_dim.d_year in (1999, 1999 + 1, 1999 + 2) + and store.s_city in ('Midway', 'Fairview') + -- partition key filter + and ss_sold_date_sk in (2451180, 2451181, 2451211, 2451212, 2451239, 2451240, 2451270, 2451271, 2451300, 2451301, 2451331, + 2451332, 2451361, 2451362, 2451392, 2451393, 2451423, 2451424, 2451453, 2451454, 2451484, 2451485, + 2451514, 2451515, 2451545, 2451546, 2451576, 2451577, 2451605, 2451606, 2451636, 2451637, 2451666, + 2451667, 2451697, 2451698, 2451727, 2451728, 2451758, 2451759, 2451789, 2451790, 2451819, 2451820, + 2451850, 2451851, 2451880, 2451881, 2451911, 2451912, 2451942, 2451943, 2451970, 2451971, 2452001, + 2452002, 2452031, 2452032, 2452062, 2452063, 2452092, 2452093, 2452123, 2452124, 2452154, 2452155, + 2452184, 2452185, 2452215, 2452216, 2452245, 2452246) + --and ss_sold_date_sk between 2451180 and 2451269 -- partition key filter (3 months) + --and d_date between '1999-01-01' and '1999-03-31' + group by + ss_ticket_number, + ss_customer_sk, + ss_addr_sk, + ca_city + ) dn, + customer, + customer_address current_addr +where + ss_customer_sk = c_customer_sk + and customer.c_current_addr_sk = current_addr.ca_address_sk + and current_addr.ca_city <> bought_city +order by + c_last_name, + ss_ticket_number +limit 100 +-- end query 68 in stream 0 using template query68.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q7.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q7.sql new file mode 100755 index 0000000000000..c61a2d0d2a8fa --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q7.sql @@ -0,0 +1,31 @@ +-- start query 7 in stream 0 using template query7.tpl +select + i_item_id, + avg(ss_quantity) agg1, + avg(ss_list_price) agg2, + avg(ss_coupon_amt) agg3, + avg(ss_sales_price) agg4 +from + store_sales, + customer_demographics, + date_dim, + item, + promotion +where + ss_sold_date_sk = d_date_sk + and ss_item_sk = i_item_sk + and ss_cdemo_sk = cd_demo_sk + and ss_promo_sk = p_promo_sk + and cd_gender = 'F' + and cd_marital_status = 'W' + and cd_education_status = 'Primary' + and (p_channel_email = 'N' + or p_channel_event = 'N') + and d_year = 1998 + and ss_sold_date_sk between 2450815 and 2451179 -- partition key filter +group by + i_item_id +order by + i_item_id +limit 100 +-- end query 7 in stream 0 using template query7.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q73.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q73.sql new file mode 100755 index 0000000000000..8703910b305a8 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q73.sql @@ -0,0 +1,49 @@ +-- start query 73 in stream 0 using template query73.tpl +select + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag, + ss_ticket_number, + cnt +from + (select + ss_ticket_number, + ss_customer_sk, + count(*) cnt + from + store_sales, + date_dim, + store, + household_demographics + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and date_dim.d_dom between 1 and 2 + and (household_demographics.hd_buy_potential = '>10000' + or household_demographics.hd_buy_potential = 'Unknown') + and household_demographics.hd_vehicle_count > 0 + and case when household_demographics.hd_vehicle_count > 0 then household_demographics.hd_dep_count / household_demographics.hd_vehicle_count else null end > 1 + and date_dim.d_year in (1998, 1998 + 1, 1998 + 2) + and store.s_county in ('Fairfield County','Ziebach County','Bronx County','Barrow County') + -- partition key filter + and ss_sold_date_sk in (2450815, 2450816, 2450846, 2450847, 2450874, 2450875, 2450905, 2450906, 2450935, 2450936, 2450966, 2450967, + 2450996, 2450997, 2451027, 2451028, 2451058, 2451059, 2451088, 2451089, 2451119, 2451120, 2451149, + 2451150, 2451180, 2451181, 2451211, 2451212, 2451239, 2451240, 2451270, 2451271, 2451300, 2451301, + 2451331, 2451332, 2451361, 2451362, 2451392, 2451393, 2451423, 2451424, 2451453, 2451454, 2451484, + 2451485, 2451514, 2451515, 2451545, 2451546, 2451576, 2451577, 2451605, 2451606, 2451636, 2451637, + 2451666, 2451667, 2451697, 2451698, 2451727, 2451728, 2451758, 2451759, 2451789, 2451790, 2451819, + 2451820, 2451850, 2451851, 2451880, 2451881) + --and ss_sold_date_sk between 2451180 and 2451269 -- partition key filter (3 months) + group by + ss_ticket_number, + ss_customer_sk + ) dj, + customer +where + ss_customer_sk = c_customer_sk + and cnt between 1 and 5 +order by + cnt desc +-- end query 73 in stream 0 using template query73.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q79.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q79.sql new file mode 100755 index 0000000000000..4254310ecd10b --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q79.sql @@ -0,0 +1,59 @@ +-- start query 79 in stream 0 using template query79.tpl +select + c_last_name, + c_first_name, + substr(s_city, 1, 30), + ss_ticket_number, + amt, + profit +from + (select + ss_ticket_number, + ss_customer_sk, + store.s_city, + sum(ss_coupon_amt) amt, + sum(ss_net_profit) profit + from + store_sales, + date_dim, + store, + household_demographics + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and (household_demographics.hd_dep_count = 8 + or household_demographics.hd_vehicle_count > 0) + and date_dim.d_dow = 1 + and date_dim.d_year in (1998, 1998 + 1, 1998 + 2) + and store.s_number_employees between 200 and 295 + and ss_sold_date_sk between 2450819 and 2451904 + -- partition key filter + --and ss_sold_date_sk in (2450819, 2450826, 2450833, 2450840, 2450847, 2450854, 2450861, 2450868, 2450875, 2450882, 2450889, + -- 2450896, 2450903, 2450910, 2450917, 2450924, 2450931, 2450938, 2450945, 2450952, 2450959, 2450966, 2450973, 2450980, 2450987, + -- 2450994, 2451001, 2451008, 2451015, 2451022, 2451029, 2451036, 2451043, 2451050, 2451057, 2451064, 2451071, 2451078, 2451085, + -- 2451092, 2451099, 2451106, 2451113, 2451120, 2451127, 2451134, 2451141, 2451148, 2451155, 2451162, 2451169, 2451176, 2451183, + -- 2451190, 2451197, 2451204, 2451211, 2451218, 2451225, 2451232, 2451239, 2451246, 2451253, 2451260, 2451267, 2451274, 2451281, + -- 2451288, 2451295, 2451302, 2451309, 2451316, 2451323, 2451330, 2451337, 2451344, 2451351, 2451358, 2451365, 2451372, 2451379, + -- 2451386, 2451393, 2451400, 2451407, 2451414, 2451421, 2451428, 2451435, 2451442, 2451449, 2451456, 2451463, 2451470, 2451477, + -- 2451484, 2451491, 2451498, 2451505, 2451512, 2451519, 2451526, 2451533, 2451540, 2451547, 2451554, 2451561, 2451568, 2451575, + -- 2451582, 2451589, 2451596, 2451603, 2451610, 2451617, 2451624, 2451631, 2451638, 2451645, 2451652, 2451659, 2451666, 2451673, + -- 2451680, 2451687, 2451694, 2451701, 2451708, 2451715, 2451722, 2451729, 2451736, 2451743, 2451750, 2451757, 2451764, 2451771, + -- 2451778, 2451785, 2451792, 2451799, 2451806, 2451813, 2451820, 2451827, 2451834, 2451841, 2451848, 2451855, 2451862, 2451869, + -- 2451876, 2451883, 2451890, 2451897, 2451904) + group by + ss_ticket_number, + ss_customer_sk, + ss_addr_sk, + store.s_city + ) ms, + customer +where + ss_customer_sk = c_customer_sk +order by + c_last_name, + c_first_name, + substr(s_city, 1, 30), + profit + limit 100 +-- end query 79 in stream 0 using template query79.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q89.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q89.sql new file mode 100755 index 0000000000000..b1d814af5e57a --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q89.sql @@ -0,0 +1,43 @@ +-- start query 89 in stream 0 using template query89.tpl +select + * +from + (select + i_category, + i_class, + i_brand, + s_store_name, + s_company_name, + d_moy, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) over (partition by i_category, i_brand, s_store_name, s_company_name) avg_monthly_sales + from + item, + store_sales, + date_dim, + store + where + ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and ss_store_sk = s_store_sk + and d_year in (2000) + and ((i_category in ('Home', 'Books', 'Electronics') + and i_class in ('wallpaper', 'parenting', 'musical')) + or (i_category in ('Shoes', 'Jewelry', 'Men') + and i_class in ('womens', 'birdal', 'pants'))) + and ss_sold_date_sk between 2451545 and 2451910 -- partition key filter + group by + i_category, + i_class, + i_brand, + s_store_name, + s_company_name, + d_moy + ) tmp1 +where + case when (avg_monthly_sales <> 0) then (abs(sum_sales - avg_monthly_sales) / avg_monthly_sales) else null end > 0.1 +order by + sum_sales - avg_monthly_sales, + s_store_name +limit 100 +-- end query 89 in stream 0 using template query89.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q98.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q98.sql new file mode 100755 index 0000000000000..f53f2f5f9c5b6 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q98.sql @@ -0,0 +1,32 @@ +-- start query 98 in stream 0 using template query98.tpl +select + i_item_desc, + i_category, + i_class, + i_current_price, + sum(ss_ext_sales_price) as itemrevenue, + sum(ss_ext_sales_price) * 100 / sum(sum(ss_ext_sales_price)) over (partition by i_class) as revenueratio +from + store_sales, + item, + date_dim +where + ss_item_sk = i_item_sk + and i_category in ('Jewelry', 'Sports', 'Books') + and ss_sold_date_sk = d_date_sk + and ss_sold_date_sk between 2451911 and 2451941 -- partition key filter (1 calendar month) + and d_date between '2001-01-01' and '2001-01-31' +group by + i_item_id, + i_item_desc, + i_category, + i_class, + i_current_price +order by + i_category, + i_class, + i_item_id, + i_item_desc, + revenueratio +--limit 1000; -- added limit +-- end query 98 in stream 0 using template query98.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/ss_max.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/ss_max.sql new file mode 100755 index 0000000000000..bf58b4bb3c5a5 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/ss_max.sql @@ -0,0 +1,14 @@ +select + count(*) as total, + count(ss_sold_date_sk) as not_null_total, + count(distinct ss_sold_date_sk) as unique_days, + max(ss_sold_date_sk) as max_ss_sold_date_sk, + max(ss_sold_time_sk) as max_ss_sold_time_sk, + max(ss_item_sk) as max_ss_item_sk, + max(ss_customer_sk) as max_ss_customer_sk, + max(ss_cdemo_sk) as max_ss_cdemo_sk, + max(ss_hdemo_sk) as max_ss_hdemo_sk, + max(ss_addr_sk) as max_ss_addr_sk, + max(ss_store_sk) as max_ss_store_sk, + max(ss_promo_sk) as max_ss_promo_sk +from store_sales diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala index c0797fa55f5da..e47d4b0ee25d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -22,9 +22,18 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.catalyst.util.resourceToString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils +/** + * This test suite ensures all the TPC-DS queries can be successfully analyzed and optimized + * without hitting the max iteration threshold. + */ class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll { + // When Utils.isTesting is true, the RuleExecutor will issue an exception when hitting + // the max iteration of analyzer/optimizer batches. + assert(Utils.isTesting, "spark.testing is not set to true") + /** * Drop all the tables */ @@ -341,8 +350,23 @@ class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfte classLoader = Thread.currentThread().getContextClassLoader) test(name) { withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - sql(queryString).collect() + // Just check the plans can be properly generated + sql(queryString).queryExecution.executedPlan } } } + + // These queries are from https://github.com/cloudera/impala-tpcds-kit/tree/master/queries + val modifiedTPCDSQueries = Seq( + "q3", "q7", "q10", "q19", "q27", "q34", "q42", "q43", "q46", "q52", "q53", "q55", "q59", + "q63", "q65", "q68", "q73", "q79", "q89", "q98", "ss_max") + + modifiedTPCDSQueries.foreach { name => + val queryString = resourceToString(s"tpcds-modifiedQueries/$name.sql", + classLoader = Thread.currentThread().getContextClassLoader) + test(s"modified-$name") { + // Just check the plans can be properly generated + sql(queryString).queryExecution.executedPlan + } + } } From 472864014c42da08b9d3f3fffbe657c6fcf1e2ef Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 29 Sep 2017 11:45:58 -0700 Subject: [PATCH 647/779] Revert "[SPARK-22142][BUILD][STREAMING] Move Flume support behind a profile" This reverts commit a2516f41aef68e39df7f6380fd2618cc148a609e. --- dev/create-release/release-build.sh | 4 ++-- dev/mima | 2 +- dev/scalastyle | 1 - dev/sparktestsupport/modules.py | 20 +------------------- dev/test-dependencies.sh | 2 +- docs/building-spark.md | 6 ------ pom.xml | 13 +++---------- project/SparkBuild.scala | 17 ++++++++--------- python/pyspark/streaming/tests.py | 16 +++------------- 9 files changed, 19 insertions(+), 62 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 7e8d5c7075195..5390f5916fc0d 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -84,9 +84,9 @@ MVN="build/mvn --force" # Hive-specific profiles for some builds HIVE_PROFILES="-Phive -Phive-thriftserver" # Profiles for publishing snapshots and release to Maven Central -PUBLISH_PROFILES="-Pmesos -Pyarn -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" +PUBLISH_PROFILES="-Pmesos -Pyarn $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" # Profiles for building binary releases -BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pflume -Psparkr" +BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Psparkr" # Scala 2.11 only profiles for some builds SCALA_2_11_PROFILES="-Pkafka-0-8" # Scala 2.12 only profiles for some builds diff --git a/dev/mima b/dev/mima index 1e3ca9700bc07..fdb21f5007cf2 100755 --- a/dev/mima +++ b/dev/mima @@ -24,7 +24,7 @@ set -e FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" -SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pflume -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" +SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | tail -n1)" OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)" diff --git a/dev/scalastyle b/dev/scalastyle index 89ecc8abd6f8c..e5aa589869535 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -25,7 +25,6 @@ ERRORS=$(echo -e "q\n" \ -Pmesos \ -Pkafka-0-8 \ -Pyarn \ - -Pflume \ -Phive \ -Phive-thriftserver \ scalastyle test:scalastyle \ diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 91d5667ed1f07..50e14b60545af 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -279,12 +279,6 @@ def __hash__(self): source_file_regexes=[ "external/flume-sink", ], - build_profile_flags=[ - "-Pflume", - ], - environ={ - "ENABLE_FLUME_TESTS": "1" - }, sbt_test_goals=[ "streaming-flume-sink/test", ] @@ -297,12 +291,6 @@ def __hash__(self): source_file_regexes=[ "external/flume", ], - build_profile_flags=[ - "-Pflume", - ], - environ={ - "ENABLE_FLUME_TESTS": "1" - }, sbt_test_goals=[ "streaming-flume/test", ] @@ -314,13 +302,7 @@ def __hash__(self): dependencies=[streaming_flume, streaming_flume_sink], source_file_regexes=[ "external/flume-assembly", - ], - build_profile_flags=[ - "-Pflume", - ], - environ={ - "ENABLE_FLUME_TESTS": "1" - } + ] ) diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 58b295d4f6e00..c7714578bd005 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -29,7 +29,7 @@ export LC_ALL=C # TODO: This would be much nicer to do in SBT, once SBT supports Maven-style resolution. # NOTE: These should match those in the release publishing script -HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Pflume -Phive" +HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Phive" MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.6 diff --git a/docs/building-spark.md b/docs/building-spark.md index e1532de16108d..57baa503259c1 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -100,12 +100,6 @@ Note: Kafka 0.8 support is deprecated as of Spark 2.3.0. Kafka 0.10 support is still automatically built. -## Building with Flume support - -Apache Flume support must be explicitly enabled with the `flume` profile. - - ./build/mvn -Pflume -DskipTests clean package - ## Building submodules individually It's possible to build Spark sub-modules using the `mvn -pl` option. diff --git a/pom.xml b/pom.xml index 9fac8b1e53788..87a468c3a6f55 100644 --- a/pom.xml +++ b/pom.xml @@ -98,13 +98,15 @@ sql/core sql/hive assembly + external/flume + external/flume-sink + external/flume-assembly examples repl launcher external/kafka-0-10 external/kafka-0-10-assembly external/kafka-0-10-sql - @@ -2581,15 +2583,6 @@ - - flume - - external/flume - external/flume-sink - external/flume-assembly - - - spark-ganglia-lgpl diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 9501eed1e906b..a568d264cb2db 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -43,8 +43,11 @@ object BuildCommons { "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10" ).map(ProjectRef(buildLocation, _)) - val streamingProjects@Seq(streaming, streamingKafka010) = - Seq("streaming", "streaming-kafka-0-10").map(ProjectRef(buildLocation, _)) + val streamingProjects@Seq( + streaming, streamingFlumeSink, streamingFlume, streamingKafka010 + ) = Seq( + "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-10" + ).map(ProjectRef(buildLocation, _)) val allProjects@Seq( core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, kvstore, _* @@ -53,13 +56,9 @@ object BuildCommons { "tags", "sketch", "kvstore" ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects - val optionallyEnabledProjects@Seq(mesos, yarn, - streamingFlumeSink, streamingFlume, - streamingKafka, sparkGangliaLgpl, streamingKinesisAsl, - dockerIntegrationTests, hadoopCloud) = - Seq("mesos", "yarn", - "streaming-flume-sink", "streaming-flume", - "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", + val optionallyEnabledProjects@Seq(mesos, yarn, streamingKafka, sparkGangliaLgpl, + streamingKinesisAsl, dockerIntegrationTests, hadoopCloud) = + Seq("mesos", "yarn", "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", "docker-integration-tests", "hadoop-cloud").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) = diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5b86c1cb2c390..229cf53e47359 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1478,7 +1478,7 @@ def search_kafka_assembly_jar(): ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/package streaming-kafka-0-8-assembly/assembly' or " - "'build/mvn -Pkafka-0-8 package' before running this test.") + "'build/mvn package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1495,7 +1495,7 @@ def search_flume_assembly_jar(): ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " - "'build/mvn -Pflume package' before running this test.") + "'build/mvn package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1516,9 +1516,6 @@ def search_kinesis_asl_assembly_jar(): return jars[0] -# Must be same as the variable and condition defined in modules.py -flume_test_environ_var = "ENABLE_FLUME_TESTS" -are_flume_tests_enabled = os.environ.get(flume_test_environ_var) == '1' # Must be same as the variable and condition defined in modules.py kafka_test_environ_var = "ENABLE_KAFKA_0_8_TESTS" are_kafka_tests_enabled = os.environ.get(kafka_test_environ_var) == '1' @@ -1541,16 +1538,9 @@ def search_kinesis_asl_assembly_jar(): os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, + FlumeStreamTests, FlumePollingStreamTests, StreamingListenerTests] - if are_flume_tests_enabled: - testcases.append(FlumeStreamTests) - testcases.append(FlumePollingStreamTests) - else: - sys.stderr.write( - "Skipped test_flume_stream (enable by setting environment variable %s=1" - % flume_test_environ_var) - if are_kafka_tests_enabled: testcases.append(KafkaStreamTests) else: From 530fe683297cb11b920a4df6630eff5d7e7ddce2 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 29 Sep 2017 19:35:32 -0700 Subject: [PATCH 648/779] [SPARK-21904][SQL] Rename tempTables to tempViews in SessionCatalog ### What changes were proposed in this pull request? `tempTables` is not right. To be consistent, we need to rename the internal variable names/comments to tempViews in SessionCatalog too. ### How was this patch tested? N/A Author: gatorsmile Closes #19117 from gatorsmile/renameTempTablesToTempViews. --- .../sql/catalyst/catalog/SessionCatalog.scala | 79 +++++++++---------- .../sql/execution/command/DDLSuite.scala | 10 +-- 2 files changed, 43 insertions(+), 46 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 9407b727bca4c..6ba9ee5446a01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.catalog -import java.lang.reflect.InvocationTargetException import java.net.URI import java.util.Locale import java.util.concurrent.Callable @@ -25,7 +24,6 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.util.{Failure, Success, Try} -import scala.util.control.NonFatal import com.google.common.cache.{Cache, CacheBuilder} import org.apache.hadoop.conf.Configuration @@ -41,7 +39,6 @@ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -52,7 +49,7 @@ object SessionCatalog { /** * An internal catalog that is used by a Spark Session. This internal catalog serves as a * proxy to the underlying metastore (e.g. Hive Metastore) and it also manages temporary - * tables and functions of the Spark Session that it belongs to. + * views and functions of the Spark Session that it belongs to. * * This class must be thread-safe. */ @@ -90,13 +87,13 @@ class SessionCatalog( new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) } - /** List of temporary tables, mapping from table name to their logical plan. */ + /** List of temporary views, mapping from table name to their logical plan. */ @GuardedBy("this") - protected val tempTables = new mutable.HashMap[String, LogicalPlan] + protected val tempViews = new mutable.HashMap[String, LogicalPlan] // Note: we track current database here because certain operations do not explicitly // specify the database (e.g. DROP TABLE my_table). In these cases we must first - // check whether the temporary table or function exists, then, if not, operate on + // check whether the temporary view or function exists, then, if not, operate on // the corresponding item in the current database. @GuardedBy("this") protected var currentDb: String = formatDatabaseName(DEFAULT_DATABASE) @@ -272,8 +269,8 @@ class SessionCatalog( // ---------------------------------------------------------------------------- // Tables // ---------------------------------------------------------------------------- - // There are two kinds of tables, temporary tables and metastore tables. - // Temporary tables are isolated across sessions and do not belong to any + // There are two kinds of tables, temporary views and metastore tables. + // Temporary views are isolated across sessions and do not belong to any // particular database. Metastore tables can be used across multiple // sessions as their metadata is persisted in the underlying catalog. // ---------------------------------------------------------------------------- @@ -462,10 +459,10 @@ class SessionCatalog( tableDefinition: LogicalPlan, overrideIfExists: Boolean): Unit = synchronized { val table = formatTableName(name) - if (tempTables.contains(table) && !overrideIfExists) { + if (tempViews.contains(table) && !overrideIfExists) { throw new TempTableAlreadyExistsException(name) } - tempTables.put(table, tableDefinition) + tempViews.put(table, tableDefinition) } /** @@ -487,7 +484,7 @@ class SessionCatalog( viewDefinition: LogicalPlan): Boolean = synchronized { val viewName = formatTableName(name.table) if (name.database.isEmpty) { - if (tempTables.contains(viewName)) { + if (tempViews.contains(viewName)) { createTempView(viewName, viewDefinition, overrideIfExists = true) true } else { @@ -504,7 +501,7 @@ class SessionCatalog( * Return a local temporary view exactly as it was stored. */ def getTempView(name: String): Option[LogicalPlan] = synchronized { - tempTables.get(formatTableName(name)) + tempViews.get(formatTableName(name)) } /** @@ -520,7 +517,7 @@ class SessionCatalog( * Returns true if this view is dropped successfully, false otherwise. */ def dropTempView(name: String): Boolean = synchronized { - tempTables.remove(formatTableName(name)).isDefined + tempViews.remove(formatTableName(name)).isDefined } /** @@ -572,7 +569,7 @@ class SessionCatalog( * Rename a table. * * If a database is specified in `oldName`, this will rename the table in that database. - * If no database is specified, this will first attempt to rename a temporary table with + * If no database is specified, this will first attempt to rename a temporary view with * the same name, then, if that does not exist, rename the table in the current database. * * This assumes the database specified in `newName` matches the one in `oldName`. @@ -592,7 +589,7 @@ class SessionCatalog( globalTempViewManager.rename(oldTableName, newTableName) } else { requireDbExists(db) - if (oldName.database.isDefined || !tempTables.contains(oldTableName)) { + if (oldName.database.isDefined || !tempViews.contains(oldTableName)) { requireTableExists(TableIdentifier(oldTableName, Some(db))) requireTableNotExists(TableIdentifier(newTableName, Some(db))) validateName(newTableName) @@ -600,16 +597,16 @@ class SessionCatalog( } else { if (newName.database.isDefined) { throw new AnalysisException( - s"RENAME TEMPORARY TABLE from '$oldName' to '$newName': cannot specify database " + + s"RENAME TEMPORARY VIEW from '$oldName' to '$newName': cannot specify database " + s"name '${newName.database.get}' in the destination table") } - if (tempTables.contains(newTableName)) { - throw new AnalysisException(s"RENAME TEMPORARY TABLE from '$oldName' to '$newName': " + + if (tempViews.contains(newTableName)) { + throw new AnalysisException(s"RENAME TEMPORARY VIEW from '$oldName' to '$newName': " + "destination table already exists") } - val table = tempTables(oldTableName) - tempTables.remove(oldTableName) - tempTables.put(newTableName, table) + val table = tempViews(oldTableName) + tempViews.remove(oldTableName) + tempViews.put(newTableName, table) } } } @@ -618,7 +615,7 @@ class SessionCatalog( * Drop a table. * * If a database is specified in `name`, this will drop the table from that database. - * If no database is specified, this will first attempt to drop a temporary table with + * If no database is specified, this will first attempt to drop a temporary view with * the same name, then, if that does not exist, drop the table from the current database. */ def dropTable( @@ -633,7 +630,7 @@ class SessionCatalog( throw new NoSuchTableException(globalTempViewManager.database, table) } } else { - if (name.database.isDefined || !tempTables.contains(table)) { + if (name.database.isDefined || !tempViews.contains(table)) { requireDbExists(db) // When ignoreIfNotExists is false, no exception is issued when the table does not exist. // Instead, log it as an error message. @@ -643,7 +640,7 @@ class SessionCatalog( throw new NoSuchTableException(db = db, table = table) } } else { - tempTables.remove(table) + tempViews.remove(table) } } } @@ -652,7 +649,7 @@ class SessionCatalog( * Return a [[LogicalPlan]] that represents the given table or view. * * If a database is specified in `name`, this will return the table/view from that database. - * If no database is specified, this will first attempt to return a temporary table/view with + * If no database is specified, this will first attempt to return a temporary view with * the same name, then, if that does not exist, return the table/view from the current database. * * Note that, the global temp view database is also valid here, this will return the global temp @@ -671,7 +668,7 @@ class SessionCatalog( globalTempViewManager.get(table).map { viewDef => SubqueryAlias(table, viewDef) }.getOrElse(throw new NoSuchTableException(db, table)) - } else if (name.database.isDefined || !tempTables.contains(table)) { + } else if (name.database.isDefined || !tempViews.contains(table)) { val metadata = externalCatalog.getTable(db, table) if (metadata.tableType == CatalogTableType.VIEW) { val viewText = metadata.viewText.getOrElse(sys.error("Invalid view without text.")) @@ -687,21 +684,21 @@ class SessionCatalog( SubqueryAlias(table, UnresolvedCatalogRelation(metadata)) } } else { - SubqueryAlias(table, tempTables(table)) + SubqueryAlias(table, tempViews(table)) } } } /** - * Return whether a table with the specified name is a temporary table. + * Return whether a table with the specified name is a temporary view. * - * Note: The temporary table cache is checked only when database is not + * Note: The temporary view cache is checked only when database is not * explicitly specified. */ def isTemporaryTable(name: TableIdentifier): Boolean = synchronized { val table = formatTableName(name.table) if (name.database.isEmpty) { - tempTables.contains(table) + tempViews.contains(table) } else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) { globalTempViewManager.get(table).isDefined } else { @@ -710,7 +707,7 @@ class SessionCatalog( } /** - * List all tables in the specified database, including local temporary tables. + * List all tables in the specified database, including local temporary views. * * Note that, if the specified database is global temporary view database, we will list global * temporary views. @@ -718,7 +715,7 @@ class SessionCatalog( def listTables(db: String): Seq[TableIdentifier] = listTables(db, "*") /** - * List all matching tables in the specified database, including local temporary tables. + * List all matching tables in the specified database, including local temporary views. * * Note that, if the specified database is global temporary view database, we will list global * temporary views. @@ -736,7 +733,7 @@ class SessionCatalog( } } val localTempViews = synchronized { - StringUtils.filterPattern(tempTables.keys.toSeq, pattern).map { name => + StringUtils.filterPattern(tempViews.keys.toSeq, pattern).map { name => TableIdentifier(name) } } @@ -750,11 +747,11 @@ class SessionCatalog( val dbName = formatDatabaseName(name.database.getOrElse(currentDb)) val tableName = formatTableName(name.table) - // Go through temporary tables and invalidate them. + // Go through temporary views and invalidate them. // If the database is defined, this may be a global temporary view. - // If the database is not defined, there is a good chance this is a temp table. + // If the database is not defined, there is a good chance this is a temp view. if (name.database.isEmpty) { - tempTables.get(tableName).foreach(_.refresh()) + tempViews.get(tableName).foreach(_.refresh()) } else if (dbName == globalTempViewManager.database) { globalTempViewManager.get(tableName).foreach(_.refresh()) } @@ -765,11 +762,11 @@ class SessionCatalog( } /** - * Drop all existing temporary tables. + * Drop all existing temporary views. * For testing only. */ def clearTempTables(): Unit = synchronized { - tempTables.clear() + tempViews.clear() } // ---------------------------------------------------------------------------- @@ -1337,7 +1334,7 @@ class SessionCatalog( */ private[sql] def copyStateTo(target: SessionCatalog): Unit = synchronized { target.currentDb = currentDb - // copy over temporary tables - tempTables.foreach(kv => target.tempTables.put(kv._1, kv._2)) + // copy over temporary views + tempViews.foreach(kv => target.tempViews.put(kv._1, kv._2)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index d19cfeef7d19f..4ed2cecc5faff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -795,7 +795,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { checkAnswer(spark.table("teachers"), df) } - test("rename temporary table - destination table with database name") { + test("rename temporary view - destination table with database name") { withTempView("tab1") { sql( """ @@ -812,7 +812,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("ALTER TABLE tab1 RENAME TO default.tab2") } assert(e.getMessage.contains( - "RENAME TEMPORARY TABLE from '`tab1`' to '`default`.`tab2`': " + + "RENAME TEMPORARY VIEW from '`tab1`' to '`default`.`tab2`': " + "cannot specify database name 'default' in the destination table")) val catalog = spark.sessionState.catalog @@ -820,7 +820,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("rename temporary table") { + test("rename temporary view") { withTempView("tab1", "tab2") { spark.range(10).createOrReplaceTempView("tab1") sql("ALTER TABLE tab1 RENAME TO tab2") @@ -832,7 +832,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("rename temporary table - destination table already exists") { + test("rename temporary view - destination table already exists") { withTempView("tab1", "tab2") { sql( """ @@ -860,7 +860,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("ALTER TABLE tab1 RENAME TO tab2") } assert(e.getMessage.contains( - "RENAME TEMPORARY TABLE from '`tab1`' to '`tab2`': destination table already exists")) + "RENAME TEMPORARY VIEW from '`tab1`' to '`tab2`': destination table already exists")) val catalog = spark.sessionState.catalog assert(catalog.listTables("default") == Seq(TableIdentifier("tab1"), TableIdentifier("tab2"))) From c6610a997f69148a1f1bbf69360e8f39e24cb70a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 29 Sep 2017 21:36:52 -0700 Subject: [PATCH 649/779] [SPARK-22122][SQL] Use analyzed logical plans to count input rows in TPCDSQueryBenchmark ## What changes were proposed in this pull request? Since the current code ignores WITH clauses to check input relations in TPCDS queries, this leads to inaccurate per-row processing time for benchmark results. For example, in `q2`, this fix could catch all the input relations: `web_sales`, `date_dim`, and `catalog_sales` (the current code catches `date_dim` only). The one-third of the TPCDS queries uses WITH clauses, so I think it is worth fixing this. ## How was this patch tested? Manually checked. Author: Takeshi Yamamuro Closes #19344 from maropu/RespectWithInTPCDSBench. --- .../benchmark/TPCDSQueryBenchmark.scala | 32 +++++++------------ 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index 99c6df7389205..69247d7f4e9aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -20,11 +20,10 @@ package org.apache.spark.sql.execution.benchmark import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.util.Benchmark /** @@ -66,24 +65,15 @@ object TPCDSQueryBenchmark extends Logging { classLoader = Thread.currentThread().getContextClassLoader) // This is an indirect hack to estimate the size of each query's input by traversing the - // logical plan and adding up the sizes of all tables that appear in the plan. Note that this - // currently doesn't take WITH subqueries into account which might lead to fairly inaccurate - // per-row processing time for those cases. + // logical plan and adding up the sizes of all tables that appear in the plan. val queryRelations = scala.collection.mutable.HashSet[String]() - spark.sql(queryString).queryExecution.logical.map { - case UnresolvedRelation(t: TableIdentifier) => - queryRelations.add(t.table) - case lp: LogicalPlan => - lp.expressions.foreach { _ foreach { - case subquery: SubqueryExpression => - subquery.plan.foreach { - case UnresolvedRelation(t: TableIdentifier) => - queryRelations.add(t.table) - case _ => - } - case _ => - } - } + spark.sql(queryString).queryExecution.analyzed.foreach { + case SubqueryAlias(alias, _: LogicalRelation) => + queryRelations.add(alias) + case LogicalRelation(_, _, Some(catalogTable), _) => + queryRelations.add(catalogTable.identifier.table) + case HiveTableRelation(tableMeta, _, _) => + queryRelations.add(tableMeta.identifier.table) case _ => } val numRows = queryRelations.map(tableSizes.getOrElse(_, 0L)).sum From 02c91e03f975c2a6a05a9d5327057bb6b3c4a66f Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 1 Oct 2017 18:42:45 +0900 Subject: [PATCH 650/779] [SPARK-22063][R] Fixes lint check failures in R by latest commit sha1 ID of lint-r ## What changes were proposed in this pull request? Currently, we set lintr to jimhester/lintra769c0b (see [this](https://github.com/apache/spark/commit/7d1175011c976756efcd4e4e4f70a8fd6f287026) and [SPARK-14074](https://issues.apache.org/jira/browse/SPARK-14074)). I first tested and checked lintr-1.0.1 but it looks many important fixes are missing (for example, checking 100 length). So, I instead tried the latest commit, https://github.com/jimhester/lintr/commit/5431140ffea65071f1327625d4a8de9688fa7e72, in my local and fixed the check failures. It looks it has fixed many bugs and now finds many instances that I have observed and thought should be caught time to time, here I filed [the results](https://gist.github.com/HyukjinKwon/4f59ddcc7b6487a02da81800baca533c). The downside looks it now takes about 7ish mins, (it was 2ish mins before) in my local. ## How was this patch tested? Manually, `./dev/lint-r` after manually updating the lintr package. Author: hyukjinkwon Author: zuotingbing Closes #19290 from HyukjinKwon/upgrade-r-lint. --- R/pkg/.lintr | 2 +- R/pkg/R/DataFrame.R | 30 ++-- R/pkg/R/RDD.R | 6 +- R/pkg/R/WindowSpec.R | 2 +- R/pkg/R/column.R | 2 + R/pkg/R/context.R | 2 +- R/pkg/R/deserialize.R | 2 +- R/pkg/R/functions.R | 79 ++++++----- R/pkg/R/generics.R | 4 +- R/pkg/R/group.R | 4 +- R/pkg/R/mllib_classification.R | 137 +++++++++++-------- R/pkg/R/mllib_clustering.R | 15 +- R/pkg/R/mllib_regression.R | 62 +++++---- R/pkg/R/mllib_tree.R | 36 +++-- R/pkg/R/pairRDD.R | 4 +- R/pkg/R/schema.R | 2 +- R/pkg/R/stats.R | 14 +- R/pkg/R/utils.R | 4 +- R/pkg/inst/worker/worker.R | 2 +- R/pkg/tests/fulltests/test_binary_function.R | 2 +- R/pkg/tests/fulltests/test_rdd.R | 6 +- R/pkg/tests/fulltests/test_sparkSQL.R | 14 +- dev/lint-r.R | 4 +- 23 files changed, 242 insertions(+), 193 deletions(-) diff --git a/R/pkg/.lintr b/R/pkg/.lintr index ae50b28ec6166..c83ad2adfe0ef 100644 --- a/R/pkg/.lintr +++ b/R/pkg/.lintr @@ -1,2 +1,2 @@ -linters: with_defaults(line_length_linter(100), multiple_dots_linter = NULL, camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) +linters: with_defaults(line_length_linter(100), multiple_dots_linter = NULL, object_name_linter = NULL, camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 0728141fa483e..176bb3b8a8d0c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1923,13 +1923,15 @@ setMethod("[", signature(x = "SparkDataFrame"), #' @param i,subset (Optional) a logical expression to filter on rows. #' For extract operator [[ and replacement operator [[<-, the indexing parameter for #' a single Column. -#' @param j,select expression for the single Column or a list of columns to select from the SparkDataFrame. +#' @param j,select expression for the single Column or a list of columns to select from the +#' SparkDataFrame. #' @param drop if TRUE, a Column will be returned if the resulting dataset has only one column. #' Otherwise, a SparkDataFrame will always be returned. #' @param value a Column or an atomic vector in the length of 1 as literal value, or \code{NULL}. #' If \code{NULL}, the specified Column is dropped. #' @param ... currently not used. -#' @return A new SparkDataFrame containing only the rows that meet the condition with selected columns. +#' @return A new SparkDataFrame containing only the rows that meet the condition with selected +#' columns. #' @export #' @family SparkDataFrame functions #' @aliases subset,SparkDataFrame-method @@ -2608,12 +2610,12 @@ setMethod("merge", } else { # if by or both by.x and by.y have length 0, use Cartesian Product joinRes <- crossJoin(x, y) - return (joinRes) + return(joinRes) } # sets alias for making colnames unique in dataframes 'x' and 'y' - colsX <- generateAliasesForIntersectedCols(x, by, suffixes[1]) - colsY <- generateAliasesForIntersectedCols(y, by, suffixes[2]) + colsX <- genAliasesForIntersectedCols(x, by, suffixes[1]) + colsY <- genAliasesForIntersectedCols(y, by, suffixes[2]) # selects columns with their aliases from dataframes # in case same column names are present in both data frames @@ -2661,9 +2663,8 @@ setMethod("merge", #' @param intersectedColNames a list of intersected column names of the SparkDataFrame #' @param suffix a suffix for the column name #' @return list of columns -#' -#' @note generateAliasesForIntersectedCols since 1.6.0 -generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { +#' @noRd +genAliasesForIntersectedCols <- function(x, intersectedColNames, suffix) { allColNames <- names(x) # sets alias for making colnames unique in dataframe 'x' cols <- lapply(allColNames, function(colName) { @@ -2671,7 +2672,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { if (colName %in% intersectedColNames) { newJoin <- paste(colName, suffix, sep = "") if (newJoin %in% allColNames){ - stop ("The following column name: ", newJoin, " occurs more than once in the 'DataFrame'.", + stop("The following column name: ", newJoin, " occurs more than once in the 'DataFrame'.", "Please use different suffixes for the intersected columns.") } col <- alias(col, newJoin) @@ -3058,7 +3059,8 @@ setMethod("describe", #' summary(select(df, "age", "height")) #' } #' @note summary(SparkDataFrame) since 1.5.0 -#' @note The statistics provided by \code{summary} were change in 2.3.0 use \link{describe} for previous defaults. +#' @note The statistics provided by \code{summary} were change in 2.3.0 use \link{describe} for +#' previous defaults. #' @seealso \link{describe} setMethod("summary", signature(object = "SparkDataFrame"), @@ -3765,8 +3767,8 @@ setMethod("checkpoint", #' #' Create a multi-dimensional cube for the SparkDataFrame using the specified columns. #' -#' If grouping expression is missing \code{cube} creates a single global aggregate and is equivalent to -#' direct application of \link{agg}. +#' If grouping expression is missing \code{cube} creates a single global aggregate and is +#' equivalent to direct application of \link{agg}. #' #' @param x a SparkDataFrame. #' @param ... character name(s) or Column(s) to group on. @@ -3800,8 +3802,8 @@ setMethod("cube", #' #' Create a multi-dimensional rollup for the SparkDataFrame using the specified columns. #' -#' If grouping expression is missing \code{rollup} creates a single global aggregate and is equivalent to -#' direct application of \link{agg}. +#' If grouping expression is missing \code{rollup} creates a single global aggregate and is +#' equivalent to direct application of \link{agg}. #' #' @param x a SparkDataFrame. #' @param ... character name(s) or Column(s) to group on. diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 15ca212acf87f..6e89b4bb4d964 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -131,7 +131,7 @@ PipelinedRDD <- function(prev, func) { # Return the serialization mode for an RDD. setGeneric("getSerializedMode", function(rdd, ...) { standardGeneric("getSerializedMode") }) # For normal RDDs we can directly read the serializedMode -setMethod("getSerializedMode", signature(rdd = "RDD"), function(rdd) rdd@env$serializedMode ) +setMethod("getSerializedMode", signature(rdd = "RDD"), function(rdd) rdd@env$serializedMode) # For pipelined RDDs if jrdd_val is set then serializedMode should exist # if not we return the defaultSerialization mode of "byte" as we don't know the serialization # mode at this point in time. @@ -145,7 +145,7 @@ setMethod("getSerializedMode", signature(rdd = "PipelinedRDD"), }) # The jrdd accessor function. -setMethod("getJRDD", signature(rdd = "RDD"), function(rdd) rdd@jrdd ) +setMethod("getJRDD", signature(rdd = "RDD"), function(rdd) rdd@jrdd) setMethod("getJRDD", signature(rdd = "PipelinedRDD"), function(rdd, serializedMode = "byte") { if (!is.null(rdd@env$jrdd_val)) { @@ -893,7 +893,7 @@ setMethod("sampleRDD", if (withReplacement) { count <- stats::rpois(1, fraction) if (count > 0) { - res[ (len + 1) : (len + count) ] <- rep(list(elem), count) + res[(len + 1) : (len + count)] <- rep(list(elem), count) len <- len + count } } else { diff --git a/R/pkg/R/WindowSpec.R b/R/pkg/R/WindowSpec.R index 81beac9ea9925..debc7cbde55e7 100644 --- a/R/pkg/R/WindowSpec.R +++ b/R/pkg/R/WindowSpec.R @@ -73,7 +73,7 @@ setMethod("show", "WindowSpec", setMethod("partitionBy", signature(x = "WindowSpec"), function(x, col, ...) { - stopifnot (class(col) %in% c("character", "Column")) + stopifnot(class(col) %in% c("character", "Column")) if (class(col) == "character") { windowSpec(callJMethod(x@sws, "partitionBy", col, list(...))) diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index a5c2ea81f2490..3095adb918b67 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -238,8 +238,10 @@ setMethod("between", signature(x = "Column"), #' @param x a Column. #' @param dataType a character object describing the target data type. #' See +# nolint start #' \href{https://spark.apache.org/docs/latest/sparkr.html#data-type-mapping-between-r-and-spark}{ #' Spark Data Types} for available data types. +# nolint end #' @rdname cast #' @name cast #' @family colum_func diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 8349b57a30a93..443c2ff8f9ace 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -329,7 +329,7 @@ spark.addFile <- function(path, recursive = FALSE) { #' spark.getSparkFilesRootDirectory() #'} #' @note spark.getSparkFilesRootDirectory since 2.1.0 -spark.getSparkFilesRootDirectory <- function() { +spark.getSparkFilesRootDirectory <- function() { # nolint if (Sys.getenv("SPARKR_IS_RUNNING_ON_WORKER") == "") { # Running on driver. callJStatic("org.apache.spark.SparkFiles", "getRootDirectory") diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index 0e99b171cabeb..a90f7d381026b 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -43,7 +43,7 @@ readObject <- function(con) { } readTypedObject <- function(con, type) { - switch (type, + switch(type, "i" = readInt(con), "c" = readString(con), "b" = readBoolean(con), diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 9f286263c2162..0143a3e63ba61 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -38,7 +38,8 @@ NULL #' #' Date time functions defined for \code{Column}. #' -#' @param x Column to compute on. In \code{window}, it must be a time Column of \code{TimestampType}. +#' @param x Column to compute on. In \code{window}, it must be a time Column of +#' \code{TimestampType}. #' @param format For \code{to_date} and \code{to_timestamp}, it is the string to use to parse #' Column \code{x} to DateType or TimestampType. For \code{trunc}, it is the string #' to use to specify the truncation method. For example, "year", "yyyy", "yy" for @@ -90,8 +91,8 @@ NULL #' #' Math functions defined for \code{Column}. #' -#' @param x Column to compute on. In \code{shiftLeft}, \code{shiftRight} and \code{shiftRightUnsigned}, -#' this is the number of bits to shift. +#' @param x Column to compute on. In \code{shiftLeft}, \code{shiftRight} and +#' \code{shiftRightUnsigned}, this is the number of bits to shift. #' @param y Column to compute on. #' @param ... additional argument(s). #' @name column_math_functions @@ -480,7 +481,7 @@ setMethod("ceiling", setMethod("coalesce", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -676,7 +677,7 @@ setMethod("crc32", setMethod("hash", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -1310,9 +1311,9 @@ setMethod("round", #' Also known as Gaussian rounding or bankers' rounding that rounds to the nearest even number. #' bround(2.5, 0) = 2, bround(3.5, 0) = 4. #' -#' @param scale round to \code{scale} digits to the right of the decimal point when \code{scale} > 0, -#' the nearest even number when \code{scale} = 0, and \code{scale} digits to the left -#' of the decimal point when \code{scale} < 0. +#' @param scale round to \code{scale} digits to the right of the decimal point when +#' \code{scale} > 0, the nearest even number when \code{scale} = 0, and \code{scale} digits +#' to the left of the decimal point when \code{scale} < 0. #' @rdname column_math_functions #' @aliases bround bround,Column-method #' @export @@ -2005,8 +2006,9 @@ setMethod("months_between", signature(y = "Column"), }) #' @details -#' \code{nanvl}: Returns the first column (\code{y}) if it is not NaN, or the second column (\code{x}) if -#' the first column is NaN. Both inputs should be floating point columns (DoubleType or FloatType). +#' \code{nanvl}: Returns the first column (\code{y}) if it is not NaN, or the second column +#' (\code{x}) if the first column is NaN. Both inputs should be floating point columns +#' (DoubleType or FloatType). #' #' @rdname column_nonaggregate_functions #' @aliases nanvl nanvl,Column-method @@ -2061,7 +2063,7 @@ setMethod("approxCountDistinct", setMethod("countDistinct", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(...), function (x) { + jcols <- lapply(list(...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -2090,7 +2092,7 @@ setMethod("countDistinct", setMethod("concat", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -2110,7 +2112,7 @@ setMethod("greatest", signature(x = "Column"), function(x, ...) { stopifnot(length(list(...)) > 0) - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -2130,7 +2132,7 @@ setMethod("least", signature(x = "Column"), function(x, ...) { stopifnot(length(list(...)) > 0) - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -2406,8 +2408,8 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), }) #' @details -#' \code{shiftRight}: (Signed) shifts the given value numBits right. If the given value is a long value, -#' it will return a long value else it will return an integer value. +#' \code{shiftRight}: (Signed) shifts the given value numBits right. If the given value is a long +#' value, it will return a long value else it will return an integer value. #' #' @rdname column_math_functions #' @aliases shiftRight shiftRight,Column,numeric-method @@ -2505,9 +2507,10 @@ setMethod("format_string", signature(format = "character", x = "Column"), }) #' @details -#' \code{from_unixtime}: Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a -#' string representing the timestamp of that moment in the current system time zone in the JVM in the -#' given format. See \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ +#' \code{from_unixtime}: Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) +#' to a string representing the timestamp of that moment in the current system time zone in the JVM +#' in the given format. +#' See \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ #' Customizing Formats} for available options. #' #' @rdname column_datetime_functions @@ -2634,8 +2637,8 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), }) #' @details -#' \code{rand}: Generates a random column with independent and identically distributed (i.i.d.) samples -#' from U[0.0, 1.0]. +#' \code{rand}: Generates a random column with independent and identically distributed (i.i.d.) +#' samples from U[0.0, 1.0]. #' #' @rdname column_nonaggregate_functions #' @param seed a random seed. Can be missing. @@ -2664,8 +2667,8 @@ setMethod("rand", signature(seed = "numeric"), }) #' @details -#' \code{randn}: Generates a column with independent and identically distributed (i.i.d.) samples from -#' the standard normal distribution. +#' \code{randn}: Generates a column with independent and identically distributed (i.i.d.) samples +#' from the standard normal distribution. #' #' @rdname column_nonaggregate_functions #' @aliases randn randn,missing-method @@ -2831,8 +2834,8 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), }) #' @details -#' \code{when}: Evaluates a list of conditions and returns one of multiple possible result expressions. -#' For unmatched expressions null is returned. +#' \code{when}: Evaluates a list of conditions and returns one of multiple possible result +#' expressions. For unmatched expressions null is returned. #' #' @rdname column_nonaggregate_functions #' @param condition the condition to test on. Must be a Column expression. @@ -2859,8 +2862,8 @@ setMethod("when", signature(condition = "Column", value = "ANY"), }) #' @details -#' \code{ifelse}: Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied. -#' Otherwise \code{no} is returned for unmatched conditions. +#' \code{ifelse}: Evaluates a list of conditions and returns \code{yes} if the conditions are +#' satisfied. Otherwise \code{no} is returned for unmatched conditions. #' #' @rdname column_nonaggregate_functions #' @param test a Column expression that describes the condition. @@ -2990,7 +2993,8 @@ setMethod("ntile", }) #' @details -#' \code{percent_rank}: Returns the relative rank (i.e. percentile) of rows within a window partition. +#' \code{percent_rank}: Returns the relative rank (i.e. percentile) of rows within a window +#' partition. #' This is computed by: (rank of row in its partition - 1) / (number of rows in the partition - 1). #' This is equivalent to the \code{PERCENT_RANK} function in SQL. #' The method should be used with no argument. @@ -3160,7 +3164,8 @@ setMethod("posexplode", }) #' @details -#' \code{create_array}: Creates a new array column. The input columns must all have the same data type. +#' \code{create_array}: Creates a new array column. The input columns must all have the same data +#' type. #' #' @rdname column_nonaggregate_functions #' @aliases create_array create_array,Column-method @@ -3169,7 +3174,7 @@ setMethod("posexplode", setMethod("create_array", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -3178,8 +3183,8 @@ setMethod("create_array", }) #' @details -#' \code{create_map}: Creates a new map column. The input columns must be grouped as key-value pairs, -#' e.g. (key1, value1, key2, value2, ...). +#' \code{create_map}: Creates a new map column. The input columns must be grouped as key-value +#' pairs, e.g. (key1, value1, key2, value2, ...). #' The key columns must all have the same data type, and can't be null. #' The value columns must all have the same data type. #' @@ -3190,7 +3195,7 @@ setMethod("create_array", setMethod("create_map", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -3352,9 +3357,9 @@ setMethod("not", }) #' @details -#' \code{grouping_bit}: Indicates whether a specified column in a GROUP BY list is aggregated or not, -#' returns 1 for aggregated or 0 for not aggregated in the result set. Same as \code{GROUPING} in SQL -#' and \code{grouping} function in Scala. +#' \code{grouping_bit}: Indicates whether a specified column in a GROUP BY list is aggregated or +#' not, returns 1 for aggregated or 0 for not aggregated in the result set. Same as \code{GROUPING} +#' in SQL and \code{grouping} function in Scala. #' #' @rdname column_aggregate_functions #' @aliases grouping_bit grouping_bit,Column-method @@ -3412,7 +3417,7 @@ setMethod("grouping_bit", setMethod("grouping_id", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0fe8f0453b064..4e427489f6860 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -385,7 +385,7 @@ setGeneric("value", function(bcast) { standardGeneric("value") }) #' @return A SparkDataFrame. #' @rdname summarize #' @export -setGeneric("agg", function (x, ...) { standardGeneric("agg") }) +setGeneric("agg", function(x, ...) { standardGeneric("agg") }) #' alias #' @@ -731,7 +731,7 @@ setGeneric("schema", function(x) { standardGeneric("schema") }) #' @rdname select #' @export -setGeneric("select", function(x, col, ...) { standardGeneric("select") } ) +setGeneric("select", function(x, col, ...) { standardGeneric("select") }) #' @rdname selectExpr #' @export diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 0a7be0e993975..54ef9f07d6fae 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -133,8 +133,8 @@ setMethod("summarize", # Aggregate Functions by name methods <- c("avg", "max", "mean", "min", "sum") -# These are not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp", "stddev_pop", -# "variance", "var_samp", "var_pop" +# These are not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp", +# "stddev_pop", "variance", "var_samp", "var_pop" #' Pivot a column of the GroupedData and perform the specified aggregation. #' diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index 15af8298ba484..7cd072a1d6f89 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -58,22 +58,25 @@ setClass("NaiveBayesModel", representation(jobj = "jobj")) #' @param regParam The regularization parameter. Only supports L2 regularization currently. #' @param maxIter Maximum iteration number. #' @param tol Convergence tolerance of iterations. -#' @param standardization Whether to standardize the training features before fitting the model. The coefficients -#' of models will be always returned on the original scale, so it will be transparent for -#' users. Note that with/without standardization, the models should be always converged -#' to the same solution when no regularization is applied. +#' @param standardization Whether to standardize the training features before fitting the model. +#' The coefficients of models will be always returned on the original scale, +#' so it will be transparent for users. Note that with/without +#' standardization, the models should be always converged to the same +#' solution when no regularization is applied. #' @param threshold The threshold in binary classification applied to the linear model prediction. #' This threshold can be any real number, where Inf will make all predictions 0.0 #' and -Inf will make all predictions 1.0. #' @param weightCol The weight column name. -#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features -#' or the number of partitions are large, this param could be adjusted to a larger size. +#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the +#' dimensions of features or the number of partitions are large, this param +#' could be adjusted to a larger size. #' This is an expert parameter. Default value should be good for most cases. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @return \code{spark.svmLinear} returns a fitted linear SVM model. #' @rdname spark.svmLinear @@ -175,62 +178,80 @@ function(object, path, overwrite = FALSE) { #' Logistic Regression Model #' -#' Fits an logistic regression model against a SparkDataFrame. It supports "binomial": Binary logistic regression -#' with pivoting; "multinomial": Multinomial logistic (softmax) regression without pivoting, similar to glmnet. -#' Users can print, make predictions on the produced model and save the model to the input path. +#' Fits an logistic regression model against a SparkDataFrame. It supports "binomial": Binary +#' logistic regression with pivoting; "multinomial": Multinomial logistic (softmax) regression +#' without pivoting, similar to glmnet. Users can print, make predictions on the produced model +#' and save the model to the input path. #' #' @param data SparkDataFrame for training. #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. #' @param regParam the regularization parameter. -#' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 penalty. -#' For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, the penalty is a combination -#' of L1 and L2. Default is 0.0 which is an L2 penalty. +#' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 +#' penalty. For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, +#' the penalty is a combination of L1 and L2. Default is 0.0 which is an +#' L2 penalty. #' @param maxIter maximum iteration number. #' @param tol convergence tolerance of iterations. -#' @param family the name of family which is a description of the label distribution to be used in the model. +#' @param family the name of family which is a description of the label distribution to be used +#' in the model. #' Supported options: #' \itemize{ #' \item{"auto": Automatically select the family based on the number of classes: #' If number of classes == 1 || number of classes == 2, set to "binomial". #' Else, set to "multinomial".} #' \item{"binomial": Binary logistic regression with pivoting.} -#' \item{"multinomial": Multinomial logistic (softmax) regression without pivoting.} +#' \item{"multinomial": Multinomial logistic (softmax) regression without +#' pivoting.} #' } -#' @param standardization whether to standardize the training features before fitting the model. The coefficients -#' of models will be always returned on the original scale, so it will be transparent for -#' users. Note that with/without standardization, the models should be always converged -#' to the same solution when no regularization is applied. Default is TRUE, same as glmnet. -#' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of class label 1 -#' is > threshold, then predict 1, else 0. A high threshold encourages the model to predict 0 -#' more often; a low threshold encourages the model to predict 1 more often. Note: Setting this with -#' threshold p is equivalent to setting thresholds c(1-p, p). In multiclass (or binary) classification to adjust the probability of -#' predicting each class. Array must have length equal to the number of classes, with values > 0, -#' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p -#' is the original probability of that class and t is the class's threshold. +#' @param standardization whether to standardize the training features before fitting the model. +#' The coefficients of models will be always returned on the original scale, +#' so it will be transparent for users. Note that with/without +#' standardization, the models should be always converged to the same +#' solution when no regularization is applied. Default is TRUE, same as +#' glmnet. +#' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of +#' class label 1 is > threshold, then predict 1, else 0. A high threshold +#' encourages the model to predict 0 more often; a low threshold encourages the +#' model to predict 1 more often. Note: Setting this with threshold p is +#' equivalent to setting thresholds c(1-p, p). In multiclass (or binary) +#' classification to adjust the probability of predicting each class. Array must +#' have length equal to the number of classes, with values > 0, excepting that +#' at most one value may be 0. The class with largest value p/t is predicted, +#' where p is the original probability of that class and t is the class's +#' threshold. #' @param weightCol The weight column name. -#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features -#' or the number of partitions are large, this param could be adjusted to a larger size. -#' This is an expert parameter. Default value should be good for most cases. -#' @param lowerBoundsOnCoefficients The lower bounds on coefficients if fitting under bound constrained optimization. -#' The bound matrix must be compatible with the shape (1, number of features) for binomial -#' regression, or (number of classes, number of features) for multinomial regression. +#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the +#' dimensions of features or the number of partitions are large, this param +#' could be adjusted to a larger size. This is an expert parameter. Default +#' value should be good for most cases. +#' @param lowerBoundsOnCoefficients The lower bounds on coefficients if fitting under bound +#' constrained optimization. +#' The bound matrix must be compatible with the shape (1, number +#' of features) for binomial regression, or (number of classes, +#' number of features) for multinomial regression. #' It is a R matrix. -#' @param upperBoundsOnCoefficients The upper bounds on coefficients if fitting under bound constrained optimization. -#' The bound matrix must be compatible with the shape (1, number of features) for binomial -#' regression, or (number of classes, number of features) for multinomial regression. +#' @param upperBoundsOnCoefficients The upper bounds on coefficients if fitting under bound +#' constrained optimization. +#' The bound matrix must be compatible with the shape (1, number +#' of features) for binomial regression, or (number of classes, +#' number of features) for multinomial regression. #' It is a R matrix. -#' @param lowerBoundsOnIntercepts The lower bounds on intercepts if fitting under bound constrained optimization. -#' The bounds vector size must be equal to 1 for binomial regression, or the number -#' of classes for multinomial regression. -#' @param upperBoundsOnIntercepts The upper bounds on intercepts if fitting under bound constrained optimization. -#' The bound vector size must be equal to 1 for binomial regression, or the number +#' @param lowerBoundsOnIntercepts The lower bounds on intercepts if fitting under bound constrained +#' optimization. +#' The bounds vector size must be equal to 1 for binomial regression, +#' or the number #' of classes for multinomial regression. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type. +#' @param upperBoundsOnIntercepts The upper bounds on intercepts if fitting under bound constrained +#' optimization. +#' The bound vector size must be equal to 1 for binomial regression, +#' or the number of classes for multinomial regression. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @return \code{spark.logit} returns a fitted logistic regression model. #' @rdname spark.logit @@ -412,11 +433,12 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char #' @param seed seed parameter for weights initialization. #' @param initialWeights initialWeights parameter for weights initialization, it should be a #' numeric vector. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model. #' @rdname spark.mlp @@ -452,11 +474,11 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"), handleInvalid = c("error", "keep", "skip")) { formula <- paste(deparse(formula), collapse = "") if (is.null(layers)) { - stop ("layers must be a integer vector with length > 1.") + stop("layers must be a integer vector with length > 1.") } layers <- as.integer(na.omit(layers)) if (length(layers) <= 1) { - stop ("layers must be a integer vector with length > 1.") + stop("layers must be a integer vector with length > 1.") } if (!is.null(seed)) { seed <- as.character(as.integer(seed)) @@ -538,11 +560,12 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode #' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. #' @param smoothing smoothing parameter. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional argument(s) passed to the method. Currently only \code{smoothing}. #' @return \code{spark.naiveBayes} returns a fitted naive Bayes model. #' @rdname spark.naiveBayes diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index 97c9fa1b45840..a25bf81c6d977 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -60,9 +60,9 @@ setClass("LDAModel", representation(jobj = "jobj")) #' @param maxIter maximum iteration number. #' @param seed the random seed. #' @param minDivisibleClusterSize The minimum number of points (if greater than or equal to 1.0) -#' or the minimum proportion of points (if less than 1.0) of a divisible cluster. -#' Note that it is an expert parameter. The default value should be good enough -#' for most cases. +#' or the minimum proportion of points (if less than 1.0) of a +#' divisible cluster. Note that it is an expert parameter. The +#' default value should be good enough for most cases. #' @param ... additional argument(s) passed to the method. #' @return \code{spark.bisectingKmeans} returns a fitted bisecting k-means model. #' @rdname spark.bisectingKmeans @@ -325,10 +325,11 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact #' Note that the response variable of formula is empty in spark.kmeans. #' @param k number of centers. #' @param maxIter maximum iteration number. -#' @param initMode the initialization algorithm choosen to fit the model. +#' @param initMode the initialization algorithm chosen to fit the model. #' @param seed the random seed for cluster initialization. #' @param initSteps the number of steps for the k-means|| initialization mode. -#' This is an advanced setting, the default of 2 is almost always enough. Must be > 0. +#' This is an advanced setting, the default of 2 is almost always enough. +#' Must be > 0. #' @param tol convergence tolerance of iterations. #' @param ... additional argument(s) passed to the method. #' @return \code{spark.kmeans} returns a fitted k-means model. @@ -548,8 +549,8 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"), #' \item{\code{topics}}{top 10 terms and their weights of all topics} #' \item{\code{vocabulary}}{whole terms of the training corpus, NULL if libsvm format file #' used as training set} -#' \item{\code{trainingLogLikelihood}}{Log likelihood of the observed tokens in the training set, -#' given the current parameter estimates: +#' \item{\code{trainingLogLikelihood}}{Log likelihood of the observed tokens in the +#' training set, given the current parameter estimates: #' log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters) #' It is only for distributed LDA model (i.e., optimizer = "em")} #' \item{\code{logPrior}}{Log probability of the current parameter estimate: diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index ebaeae970218a..f734a0865ec3b 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -58,8 +58,8 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' Note that there are two ways to specify the tweedie family. #' \itemize{ #' \item Set \code{family = "tweedie"} and specify the var.power and link.power; -#' \item When package \code{statmod} is loaded, the tweedie family is specified using the -#' family definition therein, i.e., \code{tweedie(var.power, link.power)}. +#' \item When package \code{statmod} is loaded, the tweedie family is specified +#' using the family definition therein, i.e., \code{tweedie(var.power, link.power)}. #' } #' @param tol positive convergence tolerance of iterations. #' @param maxIter integer giving the maximal number of IRLS iterations. @@ -71,13 +71,15 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' applicable to the Tweedie family. #' @param link.power the index in the power link function. Only applicable to the Tweedie family. #' @param stringIndexerOrderType how to order categories of a string feature column. This is used to -#' decide the base level of a string feature as the last category after -#' ordering is dropped when encoding strings. Supported options are -#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". -#' The default value is "frequencyDesc". When the ordering is set to -#' "alphabetDesc", this drops the same category as R when encoding strings. -#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance offsets -#' as 0.0. The feature specified as offset has a constant coefficient of 1.0. +#' decide the base level of a string feature as the last category +#' after ordering is dropped when encoding strings. Supported options +#' are "frequencyDesc", "frequencyAsc", "alphabetDesc", and +#' "alphabetAsc". The default value is "frequencyDesc". When the +#' ordering is set to "alphabetDesc", this drops the same category +#' as R when encoding strings. +#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance +#' offsets as 0.0. The feature specified as offset has a constant coefficient of +#' 1.0. #' @param ... additional arguments passed to the method. #' @aliases spark.glm,SparkDataFrame,formula-method #' @return \code{spark.glm} returns a fitted generalized linear model. @@ -197,13 +199,15 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @param var.power the index of the power variance function in the Tweedie family. #' @param link.power the index of the power link function in the Tweedie family. #' @param stringIndexerOrderType how to order categories of a string feature column. This is used to -#' decide the base level of a string feature as the last category after -#' ordering is dropped when encoding strings. Supported options are -#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". -#' The default value is "frequencyDesc". When the ordering is set to -#' "alphabetDesc", this drops the same category as R when encoding strings. -#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance offsets -#' as 0.0. The feature specified as offset has a constant coefficient of 1.0. +#' decide the base level of a string feature as the last category +#' after ordering is dropped when encoding strings. Supported options +#' are "frequencyDesc", "frequencyAsc", "alphabetDesc", and +#' "alphabetAsc". The default value is "frequencyDesc". When the +#' ordering is set to "alphabetDesc", this drops the same category +#' as R when encoding strings. +#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance +#' offsets as 0.0. The feature specified as offset has a constant coefficient of +#' 1.0. #' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm #' @export @@ -233,11 +237,11 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat #' @param object a fitted generalized linear model. #' @return \code{summary} returns summary information of the fitted model, which is a list. -#' The list of components includes at least the \code{coefficients} (coefficients matrix, which includes -#' coefficients, standard error of coefficients, t value and p value), +#' The list of components includes at least the \code{coefficients} (coefficients matrix, +#' which includes coefficients, standard error of coefficients, t value and p value), #' \code{null.deviance} (null/residual degrees of freedom), \code{aic} (AIC) -#' and \code{iter} (number of iterations IRLS takes). If there are collinear columns in the data, -#' the coefficients matrix only provides coefficients. +#' and \code{iter} (number of iterations IRLS takes). If there are collinear columns in +#' the data, the coefficients matrix only provides coefficients. #' @rdname spark.glm #' @export #' @note summary(GeneralizedLinearRegressionModel) since 2.0.0 @@ -457,15 +461,17 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', ':', '+', and '-'. #' Note that operator '.' is not supported currently. -#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features -#' or the number of partitions are large, this param could be adjusted to a larger size. -#' This is an expert parameter. Default value should be good for most cases. +#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the +#' dimensions of features or the number of partitions are large, this +#' param could be adjusted to a larger size. This is an expert parameter. +#' Default value should be good for most cases. #' @param stringIndexerOrderType how to order categories of a string feature column. This is used to -#' decide the base level of a string feature as the last category after -#' ordering is dropped when encoding strings. Supported options are -#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". -#' The default value is "frequencyDesc". When the ordering is set to -#' "alphabetDesc", this drops the same category as R when encoding strings. +#' decide the base level of a string feature as the last category +#' after ordering is dropped when encoding strings. Supported options +#' are "frequencyDesc", "frequencyAsc", "alphabetDesc", and +#' "alphabetAsc". The default value is "frequencyDesc". When the +#' ordering is set to "alphabetDesc", this drops the same category +#' as R when encoding strings. #' @param ... additional arguments passed to the method. #' @return \code{spark.survreg} returns a fitted AFT survival regression model. #' @rdname spark.survreg diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 33c4653f4c184..89a58bf0aadae 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -132,10 +132,12 @@ print.summary.decisionTree <- function(x) { #' Gradient Boosted Tree model, \code{predict} to make predictions on new data, and #' \code{write.ml}/\code{read.ml} to save/load fitted models. #' For more details, see +# nolint start #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-regression}{ #' GBT Regression} and #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-classifier}{ #' GBT Classification} +# nolint end #' #' @param data a SparkDataFrame for training. #' @param formula a symbolic description of the model to be fitted. Currently only a few formula @@ -164,11 +166,12 @@ print.summary.decisionTree <- function(x) { #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type in classification model. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type in classification model. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @aliases spark.gbt,SparkDataFrame,formula-method #' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model. @@ -352,10 +355,12 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to #' save/load fitted models. #' For more details, see +# nolint start #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-regression}{ #' Random Forest Regression} and #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-classifier}{ #' Random Forest Classification} +# nolint end #' #' @param data a SparkDataFrame for training. #' @param formula a symbolic description of the model to be fitted. Currently only a few formula @@ -382,11 +387,12 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type in classification model. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type in classification model. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @aliases spark.randomForest,SparkDataFrame,formula-method #' @return \code{spark.randomForest} returns a fitted Random Forest model. @@ -567,10 +573,12 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to #' save/load fitted models. #' For more details, see +# nolint start #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-regression}{ #' Decision Tree Regression} and #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier}{ #' Decision Tree Classification} +# nolint end #' #' @param data a SparkDataFrame for training. #' @param formula a symbolic description of the model to be fitted. Currently only a few formula @@ -592,11 +600,12 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type in classification model. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type in classification model. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @aliases spark.decisionTree,SparkDataFrame,formula-method #' @return \code{spark.decisionTree} returns a fitted Decision Tree model. @@ -671,7 +680,8 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo #' @return \code{summary} returns summary information of the fitted model, which is a list. #' The list of components includes \code{formula} (formula), #' \code{numFeatures} (number of features), \code{features} (list of features), -#' \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of trees). +#' \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of +#' trees). #' @rdname spark.decisionTree #' @aliases summary,DecisionTreeRegressionModel-method #' @export diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 8fa21be3076b5..9c2e57d3067db 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -860,7 +860,7 @@ setMethod("subtractByKey", other, numPartitions = numPartitions), filterFunction), - function (v) { v[[1]] }) + function(v) { v[[1]] }) }) #' Return a subset of this RDD sampled by key. @@ -925,7 +925,7 @@ setMethod("sampleByKey", if (withReplacement) { count <- stats::rpois(1, frac) if (count > 0) { - res[ (len + 1) : (len + count) ] <- rep(list(elem), count) + res[(len + 1) : (len + count)] <- rep(list(elem), count) len <- len + count } } else { diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index d1ed6833d5d02..65f418740c643 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -155,7 +155,7 @@ checkType <- function(type) { } else { # Check complex types firstChar <- substr(type, 1, 1) - switch (firstChar, + switch(firstChar, a = { # Array type m <- regexec("^array<(.+)>$", type) diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index 9a9fa84044ce6..c8af798830b30 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -29,9 +29,9 @@ setOldClass("jobj") #' @param col1 name of the first column. Distinct items will make the first item of each row. #' @param col2 name of the second column. Distinct items will make the column names of the output. #' @return a local R data.frame representing the contingency table. The first column of each row -#' will be the distinct values of \code{col1} and the column names will be the distinct values -#' of \code{col2}. The name of the first column will be "\code{col1}_\code{col2}". Pairs -#' that have no occurrences will have zero as their counts. +#' will be the distinct values of \code{col1} and the column names will be the distinct +#' values of \code{col2}. The name of the first column will be "\code{col1}_\code{col2}". +#' Pairs that have no occurrences will have zero as their counts. #' #' @rdname crosstab #' @name crosstab @@ -53,8 +53,8 @@ setMethod("crosstab", }) #' @details -#' \code{cov}: When applied to SparkDataFrame, this calculates the sample covariance of two numerical -#' columns of \emph{one} SparkDataFrame. +#' \code{cov}: When applied to SparkDataFrame, this calculates the sample covariance of two +#' numerical columns of \emph{one} SparkDataFrame. #' #' @param colName1 the name of the first column #' @param colName2 the name of the second column @@ -159,8 +159,8 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), #' @param relativeError The relative target precision to achieve (>= 0). If set to zero, #' the exact quantiles are computed, which could be very expensive. #' Note that values greater than 1 are accepted but give the same result as 1. -#' @return The approximate quantiles at the given probabilities. If the input is a single column name, -#' the output is a list of approximate quantiles in that column; If the input is +#' @return The approximate quantiles at the given probabilities. If the input is a single column +#' name, the output is a list of approximate quantiles in that column; If the input is #' multiple column names, the output should be a list, and each element in it is a list of #' numeric values which represents the approximate quantiles in corresponding column. #' diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 91483a4d23d9b..4b716995f2c46 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -625,7 +625,7 @@ appendPartitionLengths <- function(x, other) { x <- lapplyPartition(x, appendLength) other <- lapplyPartition(other, appendLength) } - list (x, other) + list(x, other) } # Perform zip or cartesian between elements from two RDDs in each partition @@ -657,7 +657,7 @@ mergePartitions <- function(rdd, zip) { keys <- list() } if (lengthOfValues > 1) { - values <- part[ (lengthOfKeys + 1) : (len - 1) ] + values <- part[(lengthOfKeys + 1) : (len - 1)] } else { values <- list() } diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 03e7450147865..00789d815bba8 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -68,7 +68,7 @@ compute <- function(mode, partition, serializer, deserializer, key, } else { output <- computeFunc(partition, inputData) } - return (output) + return(output) } outputResult <- function(serializer, output, outputCon) { diff --git a/R/pkg/tests/fulltests/test_binary_function.R b/R/pkg/tests/fulltests/test_binary_function.R index 442bed509bb1d..c5d240f3e7344 100644 --- a/R/pkg/tests/fulltests/test_binary_function.R +++ b/R/pkg/tests/fulltests/test_binary_function.R @@ -73,7 +73,7 @@ test_that("zipPartitions() on RDDs", { rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 actual <- collectRDD(zipPartitions(rdd1, rdd2, rdd3, - func = function(x, y, z) { list(list(x, y, z))} )) + func = function(x, y, z) { list(list(x, y, z))})) expect_equal(actual, list(list(1, c(1, 2), c(1, 2, 3)), list(2, c(3, 4), c(4, 5, 6)))) diff --git a/R/pkg/tests/fulltests/test_rdd.R b/R/pkg/tests/fulltests/test_rdd.R index 6ee1fceffd822..0c702ea897f7c 100644 --- a/R/pkg/tests/fulltests/test_rdd.R +++ b/R/pkg/tests/fulltests/test_rdd.R @@ -698,14 +698,14 @@ test_that("fullOuterJoin() on pairwise RDDs", { }) test_that("sortByKey() on pairwise RDDs", { - numPairsRdd <- map(rdd, function(x) { list (x, x) }) + numPairsRdd <- map(rdd, function(x) { list(x, x) }) sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) actual <- collectRDD(sortedRdd) - numPairs <- lapply(nums, function(x) { list (x, x) }) + numPairs <- lapply(nums, function(x) { list(x, x) }) expect_equal(actual, sortKeyValueList(numPairs, decreasing = TRUE)) rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) - numPairsRdd2 <- map(rdd2, function(x) { list (x, x) }) + numPairsRdd2 <- map(rdd2, function(x) { list(x, x) }) sortedRdd2 <- sortByKey(numPairsRdd2) actual <- collectRDD(sortedRdd2) expect_equal(actual, numPairs) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 4e62be9b4d619..7f781f2f66a7f 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -560,9 +560,9 @@ test_that("Collect DataFrame with complex types", { expect_equal(nrow(ldf), 3) expect_equal(ncol(ldf), 3) expect_equal(names(ldf), c("c1", "c2", "c3")) - expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9))) - expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i"))) - expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0))) + expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list(7, 8, 9))) + expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list("g", "h", "i"))) + expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list(7.0, 8.0, 9.0))) # MapType schema <- structType(structField("name", "string"), @@ -1524,7 +1524,7 @@ test_that("column functions", { expect_equal(ncol(s), 1) expect_equal(nrow(s), 3) expect_is(s[[1]][[1]], "struct") - expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 } ))) + expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 }))) } # passing option @@ -2710,7 +2710,7 @@ test_that("freqItems() on a DataFrame", { input <- 1:1000 rdf <- data.frame(numbers = input, letters = as.character(input), negDoubles = input * -1.0, stringsAsFactors = F) - rdf[ input %% 3 == 0, ] <- c(1, "1", -1) + rdf[input %% 3 == 0, ] <- c(1, "1", -1) df <- createDataFrame(rdf) multiColResults <- freqItems(df, c("numbers", "letters"), support = 0.1) expect_true(1 %in% multiColResults$numbers[[1]]) @@ -3064,7 +3064,7 @@ test_that("coalesce, repartition, numPartitions", { }) test_that("gapply() and gapplyCollect() on a DataFrame", { - df <- createDataFrame ( + df <- createDataFrame( list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)), c("a", "b", "c", "d")) expected <- collect(df) @@ -3135,7 +3135,7 @@ test_that("gapply() and gapplyCollect() on a DataFrame", { actual <- df3Collect[order(df3Collect$a), ] expect_identical(actual$avg, expected$avg) - irisDF <- suppressWarnings(createDataFrame (iris)) + irisDF <- suppressWarnings(createDataFrame(iris)) schema <- structType(structField("Sepal_Length", "double"), structField("Avg", "double")) # Groups by `Sepal_Length` and computes the average for `Sepal_Width` df4 <- gapply( diff --git a/dev/lint-r.R b/dev/lint-r.R index 87ee36d5c9b68..a4261d266bbc0 100644 --- a/dev/lint-r.R +++ b/dev/lint-r.R @@ -26,8 +26,8 @@ if (! library(SparkR, lib.loc = LOCAL_LIB_LOC, logical.return = TRUE)) { # Installs lintr from Github in a local directory. # NOTE: The CRAN's version is too old to adapt to our rules. -if ("lintr" %in% row.names(installed.packages()) == FALSE) { - devtools::install_github("jimhester/lintr@a769c0b") +if ("lintr" %in% row.names(installed.packages()) == FALSE) { + devtools::install_github("jimhester/lintr@5431140") } library(lintr) From 3ca367083e196e6487207211e6c49d4bbfe31288 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 1 Oct 2017 10:49:22 -0700 Subject: [PATCH 651/779] [SPARK-22001][ML][SQL] ImputerModel can do withColumn for all input columns at one pass ## What changes were proposed in this pull request? SPARK-21690 makes one-pass `Imputer` by parallelizing the computation of all input columns. When we transform dataset with `ImputerModel`, we do `withColumn` on all input columns sequentially. We can also do this on all input columns at once by adding a `withColumns` API to `Dataset`. The new `withColumns` API is for internal use only now. ## How was this patch tested? Existing tests for `ImputerModel`'s change. Added tests for `withColumns` API. Author: Liang-Chi Hsieh Closes #19229 from viirya/SPARK-22001. --- .../org/apache/spark/ml/feature/Imputer.scala | 10 ++-- .../scala/org/apache/spark/sql/Dataset.scala | 42 ++++++++++----- .../org/apache/spark/sql/DataFrameSuite.scala | 52 +++++++++++++++++++ 3 files changed, 86 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 1f36eced3d08f..4663f16b5f5dc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -223,20 +223,18 @@ class ImputerModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - var outputDF = dataset val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq - $(inputCols).zip($(outputCols)).zip(surrogates).foreach { + val newCols = $(inputCols).zip($(outputCols)).zip(surrogates).map { case ((inputCol, outputCol), surrogate) => val inputType = dataset.schema(inputCol).dataType val ic = col(inputCol) - outputDF = outputDF.withColumn(outputCol, - when(ic.isNull, surrogate) + when(ic.isNull, surrogate) .when(ic === $(missingValue), surrogate) .otherwise(ic) - .cast(inputType)) + .cast(inputType) } - outputDF.toDF() + dataset.withColumns($(outputCols), newCols).toDF() } override def transformSchema(schema: StructType): StructType = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ab0c4126bcbdd..f2a76a506eb6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2083,22 +2083,40 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def withColumn(colName: String, col: Column): DataFrame = { + def withColumn(colName: String, col: Column): DataFrame = withColumns(Seq(colName), Seq(col)) + + /** + * Returns a new Dataset by adding columns or replacing the existing columns that has + * the same names. + */ + private[spark] def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame = { + require(colNames.size == cols.size, + s"The size of column names: ${colNames.size} isn't equal to " + + s"the size of columns: ${cols.size}") + SchemaUtils.checkColumnNameDuplication( + colNames, + "in given column names", + sparkSession.sessionState.conf.caseSensitiveAnalysis) + val resolver = sparkSession.sessionState.analyzer.resolver val output = queryExecution.analyzed.output - val shouldReplace = output.exists(f => resolver(f.name, colName)) - if (shouldReplace) { - val columns = output.map { field => - if (resolver(field.name, colName)) { - col.as(colName) - } else { - Column(field) - } + + val columnMap = colNames.zip(cols).toMap + + val replacedAndExistingColumns = output.map { field => + columnMap.find { case (colName, _) => + resolver(field.name, colName) + } match { + case Some((colName: String, col: Column)) => col.as(colName) + case _ => Column(field) } - select(columns : _*) - } else { - select(Column("*"), col.as(colName)) } + + val newColumns = columnMap.filter { case (colName, col) => + !output.exists(f => resolver(f.name, colName)) + }.map { case (colName, col) => col.as(colName) } + + select(replacedAndExistingColumns ++ newColumns : _*) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 0e2f2e5a193e1..672deeac597f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -641,6 +641,49 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.schema.map(_.name) === Seq("key", "value", "newCol")) } + test("withColumns") { + val df = testData.toDF().withColumns(Seq("newCol1", "newCol2"), + Seq(col("key") + 1, col("key") + 2)) + checkAnswer( + df, + testData.collect().map { case Row(key: Int, value: String) => + Row(key, value, key + 1, key + 2) + }.toSeq) + assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCol2")) + + val err = intercept[IllegalArgumentException] { + testData.toDF().withColumns(Seq("newCol1"), + Seq(col("key") + 1, col("key") + 2)) + } + assert( + err.getMessage.contains("The size of column names: 1 isn't equal to the size of columns: 2")) + + val err2 = intercept[AnalysisException] { + testData.toDF().withColumns(Seq("newCol1", "newCOL1"), + Seq(col("key") + 1, col("key") + 2)) + } + assert(err2.getMessage.contains("Found duplicate column(s)")) + } + + test("withColumns: case sensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val df = testData.toDF().withColumns(Seq("newCol1", "newCOL1"), + Seq(col("key") + 1, col("key") + 2)) + checkAnswer( + df, + testData.collect().map { case Row(key: Int, value: String) => + Row(key, value, key + 1, key + 2) + }.toSeq) + assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCOL1")) + + val err = intercept[AnalysisException] { + testData.toDF().withColumns(Seq("newCol1", "newCol1"), + Seq(col("key") + 1, col("key") + 2)) + } + assert(err.getMessage.contains("Found duplicate column(s)")) + } + } + test("replace column using withColumn") { val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) @@ -649,6 +692,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row(2) :: Row(3) :: Row(4) :: Nil) } + test("replace column using withColumns") { + val df2 = sparkContext.parallelize(Array((1, 2), (2, 3), (3, 4))).toDF("x", "y") + val df3 = df2.withColumns(Seq("x", "newCol1", "newCol2"), + Seq(df2("x") + 1, df2("y"), df2("y") + 1)) + checkAnswer( + df3.select("x", "newCol1", "newCol2"), + Row(2, 2, 3) :: Row(3, 3, 4) :: Row(4, 4, 5) :: Nil) + } + test("drop column using drop") { val df = testData.drop("key") checkAnswer( From 405c0e99e7697bfa88aa4abc9a55ce5e043e48b1 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Mon, 2 Oct 2017 08:07:56 +0100 Subject: [PATCH 652/779] [SPARK-22173][WEB-UI] Table CSS style needs to be adjusted in History Page and in Executors Page. ## What changes were proposed in this pull request? There is a problem with table CSS style. 1. At present, table CSS style is too crowded, and the table width cannot adapt itself. 2. Table CSS style is different from job page, stage page, task page, master page, worker page, etc. The Spark web UI needs to be consistent. fix before: ![01](https://user-images.githubusercontent.com/26266482/31041261-c6766c3a-a5c4-11e7-97a7-96bd51ef12bd.png) ![02](https://user-images.githubusercontent.com/26266482/31041266-d75b6a32-a5c4-11e7-8071-e3bbbba39b80.png) ---------------------------------------------------------------------------------------------------------- fix after: ![1](https://user-images.githubusercontent.com/26266482/31041162-808a5a3e-a5c3-11e7-8d92-d763b500ce53.png) ![2](https://user-images.githubusercontent.com/26266482/31041166-86e583e0-a5c3-11e7-949c-11c370db9e27.png) ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #19397 from guoxiaolongzte/SPARK-22173. --- .../scala/org/apache/spark/deploy/history/HistoryPage.scala | 4 ++-- .../main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index af14717633409..6399dccc1676a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -37,7 +37,7 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val content =
-
+
    {providerConfig.map { case (k, v) =>
  • {k}: {v}
  • }}
@@ -58,7 +58,7 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { if (allAppsSize > 0) { ++ - ++ +
++ ++ ++ diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index d63381c78bc3b..7b2767f0be3cd 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -82,7 +82,7 @@ private[ui] class ExecutorsPage(
++ - ++ +
++ ++ ++ From 8fab7995d36c7bc4524393b20a4e524dbf6bbf62 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 2 Oct 2017 11:46:51 -0700 Subject: [PATCH 653/779] [SPARK-22167][R][BUILD] sparkr packaging issue allow zinc ## What changes were proposed in this pull request? When zinc is running the pwd might be in the root of the project. A quick solution to this is to not go a level up incase we are in the root rather than root/core/. If we are in the root everything works fine, if we are in core add a script which goes and runs the level up ## How was this patch tested? set -x in the SparkR install scripts. Author: Holden Karau Closes #19402 from holdenk/SPARK-22167-sparkr-packaging-issue-allow-zinc. --- R/install-dev.sh | 1 + core/pom.xml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/R/install-dev.sh b/R/install-dev.sh index d613552718307..9fbc999f2e805 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -28,6 +28,7 @@ set -o pipefail set -e +set -x FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" LIB_DIR="$FWDIR/lib" diff --git a/core/pom.xml b/core/pom.xml index 09669149d8123..54f7a34a6c37e 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -499,7 +499,7 @@ - ..${file.separator}R${file.separator}install-dev${script.extension} + ${project.basedir}${file.separator}..${file.separator}R${file.separator}install-dev${script.extension} From e5431f2cfddc8e96194827a2123b92716c7a1467 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 2 Oct 2017 15:00:26 -0700 Subject: [PATCH 654/779] [SPARK-22158][SQL] convertMetastore should not ignore table property ## What changes were proposed in this pull request? From the beginning, convertMetastoreOrc ignores table properties and use an empty map instead. This PR fixes that. For the diff, please see [this](https://github.com/apache/spark/pull/19382/files?w=1). convertMetastoreParquet also ignore. ```scala val options = Map[String, String]() ``` - [SPARK-14070: HiveMetastoreCatalog.scala](https://github.com/apache/spark/pull/11891/files#diff-ee66e11b56c21364760a5ed2b783f863R650) - [Master branch: HiveStrategies.scala](https://github.com/apache/spark/blob/master/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala#L197 ) ## How was this patch tested? Pass the Jenkins with an updated test suite. Author: Dongjoon Hyun Closes #19382 from dongjoon-hyun/SPARK-22158. --- .../spark/sql/hive/HiveStrategies.scala | 4 +- .../sql/hive/execution/HiveDDLSuite.scala | 54 ++++++++++++++++--- 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 805b3171cdaab..3592b8f4846d1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -189,12 +189,12 @@ case class RelationConversions( private def convert(relation: HiveTableRelation): LogicalRelation = { val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) if (serde.contains("parquet")) { - val options = Map(ParquetOptions.MERGE_SCHEMA -> + val options = relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) sessionCatalog.metastoreCatalog .convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") } else { - val options = Map[String, String]() + val options = relation.tableMeta.storage.properties sessionCatalog.metastoreCatalog .convertToLogicalRelation(relation, options, classOf[OrcFileFormat], "orc") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 668da5fb47323..02e26bbe876a0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -23,6 +23,8 @@ import java.net.URI import scala.language.existentials import org.apache.hadoop.fs.Path +import org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER +import org.apache.parquet.hadoop.ParquetFileReader import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkException @@ -32,6 +34,7 @@ import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAl import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils} import org.apache.spark.sql.hive.HiveExternalCatalog +import org.apache.spark.sql.hive.HiveUtils.{CONVERT_METASTORE_ORC, CONVERT_METASTORE_PARQUET} import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} @@ -1455,12 +1458,8 @@ class HiveDDLSuite sql("INSERT INTO t SELECT 1") checkAnswer(spark.table("t"), Row(1)) // Check if this is compressed as ZLIB. - val maybeOrcFile = path.listFiles().find(!_.getName.endsWith(".crc")) - assert(maybeOrcFile.isDefined) - val orcFilePath = maybeOrcFile.get.toPath.toString - val expectedCompressionKind = - OrcFileOperator.getFileReader(orcFilePath).get.getCompression - assert("ZLIB" === expectedCompressionKind.name()) + val maybeOrcFile = path.listFiles().find(_.getName.startsWith("part")) + assertCompression(maybeOrcFile, "orc", "ZLIB") sql("CREATE TABLE t2 USING HIVE AS SELECT 1 AS c1, 'a' AS c2") val table2 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t2")) @@ -2009,4 +2008,47 @@ class HiveDDLSuite } } } + + private def assertCompression(maybeFile: Option[File], format: String, compression: String) = { + assert(maybeFile.isDefined) + + val actualCompression = format match { + case "orc" => + OrcFileOperator.getFileReader(maybeFile.get.toPath.toString).get.getCompression.name + + case "parquet" => + val footer = ParquetFileReader.readFooter( + sparkContext.hadoopConfiguration, new Path(maybeFile.get.getPath), NO_FILTER) + footer.getBlocks.get(0).getColumns.get(0).getCodec.toString + } + + assert(compression === actualCompression) + } + + Seq(("orc", "ZLIB"), ("parquet", "GZIP")).foreach { case (fileFormat, compression) => + test(s"SPARK-22158 convertMetastore should not ignore table property - $fileFormat") { + withSQLConf(CONVERT_METASTORE_ORC.key -> "true", CONVERT_METASTORE_PARQUET.key -> "true") { + withTable("t") { + withTempPath { path => + sql( + s""" + |CREATE TABLE t(id int) USING hive + |OPTIONS(fileFormat '$fileFormat', compression '$compression') + |LOCATION '${path.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.isHiveTable(table)) + assert(table.storage.serde.get.contains(fileFormat)) + assert(table.storage.properties.get("compression") == Some(compression)) + assert(spark.table("t").collect().isEmpty) + + sql("INSERT INTO t SELECT 1") + checkAnswer(spark.table("t"), Row(1)) + val maybeFile = path.listFiles().find(_.getName.startsWith("part")) + assertCompression(maybeFile, fileFormat, compression) + } + } + } + } + } } From 4329eb2e73181819bb712f57ca9c7feac0d640ea Mon Sep 17 00:00:00 2001 From: Gene Pang Date: Mon, 2 Oct 2017 15:09:11 -0700 Subject: [PATCH 655/779] [SPARK-16944][Mesos] Improve data locality when launching new executors when dynamic allocation is enabled ## What changes were proposed in this pull request? Improve the Spark-Mesos coarse-grained scheduler to consider the preferred locations when dynamic allocation is enabled. ## How was this patch tested? Added a unittest, and performed manual testing on AWS. Author: Gene Pang Closes #18098 from gpang/mesos_data_locality. --- .../spark/internal/config/package.scala | 4 ++ .../spark/scheduler/TaskSetManager.scala | 6 +- .../MesosCoarseGrainedSchedulerBackend.scala | 52 ++++++++++++++-- ...osCoarseGrainedSchedulerBackendSuite.scala | 62 +++++++++++++++++++ .../spark/scheduler/cluster/mesos/Utils.scala | 6 ++ 5 files changed, 123 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 44a2815b81a73..d85b6a0200b8d 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -72,6 +72,10 @@ package object config { private[spark] val DYN_ALLOCATION_MAX_EXECUTORS = ConfigBuilder("spark.dynamicAllocation.maxExecutors").intConf.createWithDefault(Int.MaxValue) + private[spark] val LOCALITY_WAIT = ConfigBuilder("spark.locality.wait") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("3s") + private[spark] val SHUFFLE_SERVICE_ENABLED = ConfigBuilder("spark.shuffle.service.enabled").booleanConf.createWithDefault(false) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index bb867416a4fac..3bdede6743d1b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -27,7 +27,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.TaskState.TaskState -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.scheduler.SchedulingMode._ import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils} import org.apache.spark.util.collection.MedianHeap @@ -980,7 +980,7 @@ private[spark] class TaskSetManager( } private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { - val defaultWait = conf.get("spark.locality.wait", "3s") + val defaultWait = conf.get(config.LOCALITY_WAIT) val localityWaitKey = level match { case TaskLocality.PROCESS_LOCAL => "spark.locality.wait.process" case TaskLocality.NODE_LOCAL => "spark.locality.wait.node" @@ -989,7 +989,7 @@ private[spark] class TaskSetManager( } if (localityWaitKey != null) { - conf.getTimeAsMs(localityWaitKey, defaultWait) + conf.getTimeAsMs(localityWaitKey, defaultWait.toString) } else { 0L } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 26699873145b4..80c0a041b7322 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -99,6 +99,14 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( private var totalCoresAcquired = 0 private var totalGpusAcquired = 0 + // The amount of time to wait for locality scheduling + private val localityWait = conf.get(config.LOCALITY_WAIT) + // The start of the waiting, for data local scheduling + private var localityWaitStartTime = System.currentTimeMillis() + // If true, the scheduler is in the process of launching executors to reach the requested + // executor limit + private var launchingExecutors = false + // SlaveID -> Slave // This map accumulates entries for the duration of the job. Slaves are never deleted, because // we need to maintain e.g. failure state and connection state. @@ -311,6 +319,19 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( return } + if (numExecutors >= executorLimit) { + logDebug("Executor limit reached. numExecutors: " + numExecutors + + " executorLimit: " + executorLimit) + offers.asScala.map(_.getId).foreach(d.declineOffer) + launchingExecutors = false + return + } else { + if (!launchingExecutors) { + launchingExecutors = true + localityWaitStartTime = System.currentTimeMillis() + } + } + logDebug(s"Received ${offers.size} resource offers.") val (matchedOffers, unmatchedOffers) = offers.asScala.partition { offer => @@ -413,7 +434,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( val offerId = offer.getId.getValue val resources = remainingResources(offerId) - if (canLaunchTask(slaveId, resources)) { + if (canLaunchTask(slaveId, offer.getHostname, resources)) { // Create a task launchTasks = true val taskId = newMesosTaskId() @@ -477,7 +498,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( cpuResourcesToUse ++ memResourcesToUse ++ portResourcesToUse ++ gpuResourcesToUse) } - private def canLaunchTask(slaveId: String, resources: JList[Resource]): Boolean = { + private def canLaunchTask(slaveId: String, offerHostname: String, + resources: JList[Resource]): Boolean = { val offerMem = getResource(resources, "mem") val offerCPUs = getResource(resources, "cpus").toInt val cpus = executorCores(offerCPUs) @@ -489,9 +511,10 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( cpus <= offerCPUs && cpus + totalCoresAcquired <= maxCores && mem <= offerMem && - numExecutors() < executorLimit && + numExecutors < executorLimit && slaves.get(slaveId).map(_.taskFailures).getOrElse(0) < MAX_SLAVE_FAILURES && - meetsPortRequirements + meetsPortRequirements && + satisfiesLocality(offerHostname) } private def executorCores(offerCPUs: Int): Int = { @@ -500,6 +523,25 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( ) } + private def satisfiesLocality(offerHostname: String): Boolean = { + if (!Utils.isDynamicAllocationEnabled(conf) || hostToLocalTaskCount.isEmpty) { + return true + } + + // Check the locality information + val currentHosts = slaves.values.filter(_.taskIDs.nonEmpty).map(_.hostname).toSet + val allDesiredHosts = hostToLocalTaskCount.keys.toSet + // Try to match locality for hosts which do not have executors yet, to potentially + // increase coverage. + val remainingHosts = allDesiredHosts -- currentHosts + if (!remainingHosts.contains(offerHostname) && + (System.currentTimeMillis() - localityWaitStartTime <= localityWait)) { + logDebug("Skipping host and waiting for locality. host: " + offerHostname) + return false + } + return true + } + override def statusUpdate(d: org.apache.mesos.SchedulerDriver, status: TaskStatus) { val taskId = status.getTaskId.getValue val slaveId = status.getSlaveId.getValue @@ -646,6 +688,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // since at coarse grain it depends on the amount of slaves available. logInfo("Capping the total amount of executors to " + requestedTotal) executorLimitOption = Some(requestedTotal) + // Update the locality wait start time to continue trying for locality. + localityWaitStartTime = System.currentTimeMillis() true } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index f6bae01c3af59..6c40792112f49 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -604,6 +604,55 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(backend.isReady) } + test("supports data locality with dynamic allocation") { + setBackend(Map( + "spark.dynamicAllocation.enabled" -> "true", + "spark.dynamicAllocation.testing" -> "true", + "spark.locality.wait" -> "1s")) + + assert(backend.getExecutorIds().isEmpty) + + backend.requestTotalExecutors(2, 2, Map("hosts10" -> 1, "hosts11" -> 1)) + + // Offer non-local resources, which should be rejected + offerResourcesAndVerify(1, false) + offerResourcesAndVerify(2, false) + + // Offer local resource + offerResourcesAndVerify(10, true) + + // Wait longer than spark.locality.wait + Thread.sleep(2000) + + // Offer non-local resource, which should be accepted + offerResourcesAndVerify(1, true) + + // Update total executors + backend.requestTotalExecutors(3, 3, Map("hosts10" -> 1, "hosts11" -> 1, "hosts12" -> 1)) + + // Offer non-local resources, which should be rejected + offerResourcesAndVerify(3, false) + + // Wait longer than spark.locality.wait + Thread.sleep(2000) + + // Update total executors + backend.requestTotalExecutors(4, 4, Map("hosts10" -> 1, "hosts11" -> 1, "hosts12" -> 1, + "hosts13" -> 1)) + + // Offer non-local resources, which should be rejected + offerResourcesAndVerify(3, false) + + // Offer local resource + offerResourcesAndVerify(13, true) + + // Wait longer than spark.locality.wait + Thread.sleep(2000) + + // Offer non-local resource, which should be accepted + offerResourcesAndVerify(2, true) + } + private case class Resources(mem: Int, cpus: Int, gpus: Int = 0) private def registerMockExecutor(executorId: String, slaveId: String, cores: Integer) = { @@ -631,6 +680,19 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite backend.resourceOffers(driver, mesosOffers.asJava) } + private def offerResourcesAndVerify(id: Int, expectAccept: Boolean): Unit = { + offerResources(List(Resources(backend.executorMemory(sc), 1)), id) + if (expectAccept) { + val numExecutors = backend.getExecutorIds().size + val launchedTasks = verifyTaskLaunched(driver, s"o$id") + assert(s"s$id" == launchedTasks.head.getSlaveId.getValue) + registerMockExecutor(launchedTasks.head.getTaskId.getValue, s"s$id", 1) + assert(backend.getExecutorIds().size == numExecutors + 1) + } else { + verifyTaskNotLaunched(driver, s"o$id") + } + } + private def createTaskStatus(taskId: String, slaveId: String, state: TaskState): TaskStatus = { TaskStatus.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId).build()) diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala index 2a67cbc913ffe..833db0c1ff334 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala @@ -84,6 +84,12 @@ object Utils { captor.getValue.asScala.toList } + def verifyTaskNotLaunched(driver: SchedulerDriver, offerId: String): Unit = { + verify(driver, times(0)).launchTasks( + Matchers.eq(Collections.singleton(createOfferId(offerId))), + Matchers.any(classOf[java.util.Collection[TaskInfo]])) + } + def createOfferId(offerId: String): OfferID = { OfferID.newBuilder().setValue(offerId).build() } From fa225da7463e384529da14706e44f4a09772e5c1 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 2 Oct 2017 15:25:33 -0700 Subject: [PATCH 656/779] [SPARK-22176][SQL] Fix overflow issue in Dataset.show ## What changes were proposed in this pull request? This pr fixed an overflow issue below in `Dataset.show`: ``` scala> Seq((1, 2), (3, 4)).toDF("a", "b").show(Int.MaxValue) org.apache.spark.sql.AnalysisException: The limit expression must be equal to or greater than 0, but got -2147483648;; GlobalLimit -2147483648 +- LocalLimit -2147483648 +- Project [_1#27218 AS a#27221, _2#27219 AS b#27222] +- LocalRelation [_1#27218, _2#27219] at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.failAnalysis(CheckAnalysis.scala:41) at org.apache.spark.sql.catalyst.analysis.Analyzer.failAnalysis(Analyzer.scala:89) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.org$apache$spark$sql$catalyst$analysis$CheckAnalysis$$checkLimitClause(CheckAnalysis.scala:70) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:234) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:80) at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:127) ``` ## How was this patch tested? Added tests in `DataFrameSuite`. Author: Takeshi Yamamuro Closes #19401 from maropu/MaxValueInShowString. --- .../main/scala/org/apache/spark/sql/Dataset.scala | 2 +- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f2a76a506eb6f..b70dfc05330f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -237,7 +237,7 @@ class Dataset[T] private[sql]( */ private[sql] def showString( _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { - val numRows = _numRows.max(0) + val numRows = _numRows.max(0).min(Int.MaxValue - 1) val takeResult = toDF().take(numRows + 1) val hasMoreData = takeResult.length > numRows val data = takeResult.take(numRows) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 672deeac597f1..dd8f54b690f64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1045,6 +1045,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").showString(0) === expectedAnswer) } + test("showString(Int.MaxValue)") { + val df = Seq((1, 2), (3, 4)).toDF("a", "b") + val expectedAnswer = """+---+---+ + || a| b| + |+---+---+ + || 1| 2| + || 3| 4| + |+---+---+ + |""".stripMargin + assert(df.showString(Int.MaxValue) === expectedAnswer) + } + test("showString(0), vertical = true") { val expectedAnswer = "(0 rows)\n" assert(testData.select($"*").showString(0, vertical = true) === expectedAnswer) From 4c5158eec9101ef105274df6b488e292a56156a2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 3 Oct 2017 12:38:13 -0700 Subject: [PATCH 657/779] [SPARK-21644][SQL] LocalLimit.maxRows is defined incorrectly ## What changes were proposed in this pull request? The definition of `maxRows` in `LocalLimit` operator was simply wrong. This patch introduces a new `maxRowsPerPartition` method and uses that in pruning. The patch also adds more documentation on why we need local limit vs global limit. Note that this previously has never been a bug because the way the code is structured, but future use of the maxRows could lead to bugs. ## How was this patch tested? Should be covered by existing test cases. Closes #18851 Author: gatorsmile Author: Reynold Xin Closes #19393 from gatorsmile/pr-18851. --- .../sql/catalyst/optimizer/Optimizer.scala | 29 ++++++----- .../catalyst/plans/logical/LogicalPlan.scala | 5 ++ .../plans/logical/basicLogicalOperators.scala | 49 ++++++++++++++++++- .../execution/basicPhysicalOperators.scala | 3 ++ 4 files changed, 74 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b9fa39d6dad4c..bc2d4a824cb49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -305,13 +305,20 @@ object LimitPushDown extends Rule[LogicalPlan] { } } - private def maybePushLimit(limitExp: Expression, plan: LogicalPlan): LogicalPlan = { - (limitExp, plan.maxRows) match { - case (IntegerLiteral(maxRow), Some(childMaxRows)) if maxRow < childMaxRows => + private def maybePushLocalLimit(limitExp: Expression, plan: LogicalPlan): LogicalPlan = { + (limitExp, plan.maxRowsPerPartition) match { + case (IntegerLiteral(newLimit), Some(childMaxRows)) if newLimit < childMaxRows => + // If the child has a cap on max rows per partition and the cap is larger than + // the new limit, put a new LocalLimit there. LocalLimit(limitExp, stripGlobalLimitIfPresent(plan)) + case (_, None) => + // If the child has no cap, put the new LocalLimit. LocalLimit(limitExp, stripGlobalLimitIfPresent(plan)) - case _ => plan + + case _ => + // Otherwise, don't put a new LocalLimit. + plan } } @@ -323,7 +330,7 @@ object LimitPushDown extends Rule[LogicalPlan] { // pushdown Limit through it. Once we add UNION DISTINCT, however, we will not be able to // pushdown Limit. case LocalLimit(exp, Union(children)) => - LocalLimit(exp, Union(children.map(maybePushLimit(exp, _)))) + LocalLimit(exp, Union(children.map(maybePushLocalLimit(exp, _)))) // Add extra limits below OUTER JOIN. For LEFT OUTER and FULL OUTER JOIN we push limits to the // left and right sides, respectively. For FULL OUTER JOIN, we can only push limits to one side // because we need to ensure that rows from the limited side still have an opportunity to match @@ -335,19 +342,19 @@ object LimitPushDown extends Rule[LogicalPlan] { // - If neither side is limited, limit the side that is estimated to be bigger. case LocalLimit(exp, join @ Join(left, right, joinType, _)) => val newJoin = joinType match { - case RightOuter => join.copy(right = maybePushLimit(exp, right)) - case LeftOuter => join.copy(left = maybePushLimit(exp, left)) + case RightOuter => join.copy(right = maybePushLocalLimit(exp, right)) + case LeftOuter => join.copy(left = maybePushLocalLimit(exp, left)) case FullOuter => (left.maxRows, right.maxRows) match { case (None, None) => if (left.stats.sizeInBytes >= right.stats.sizeInBytes) { - join.copy(left = maybePushLimit(exp, left)) + join.copy(left = maybePushLocalLimit(exp, left)) } else { - join.copy(right = maybePushLimit(exp, right)) + join.copy(right = maybePushLocalLimit(exp, right)) } case (Some(_), Some(_)) => join - case (Some(_), None) => join.copy(left = maybePushLimit(exp, left)) - case (None, Some(_)) => join.copy(right = maybePushLimit(exp, right)) + case (Some(_), None) => join.copy(left = maybePushLocalLimit(exp, left)) + case (None, Some(_)) => join.copy(right = maybePushLocalLimit(exp, right)) } case _ => join diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 68aae720e026a..14188829db2af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -97,6 +97,11 @@ abstract class LogicalPlan */ def maxRows: Option[Long] = None + /** + * Returns the maximum number of rows this plan may compute on each partition. + */ + def maxRowsPerPartition: Option[Long] = maxRows + /** * Returns true if this expression and all its children have been resolved to a specific schema * and false if it still contains any unresolved placeholders. Implementations of LogicalPlan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index f443cd5a69de3..80243d3d356ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -191,6 +191,9 @@ object Union { } } +/** + * Logical plan for unioning two plans, without a distinct. This is UNION ALL in SQL. + */ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { override def maxRows: Option[Long] = { if (children.exists(_.maxRows.isEmpty)) { @@ -200,6 +203,17 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { } } + /** + * Note the definition has assumption about how union is implemented physically. + */ + override def maxRowsPerPartition: Option[Long] = { + if (children.exists(_.maxRowsPerPartition.isEmpty)) { + None + } else { + Some(children.flatMap(_.maxRowsPerPartition).sum) + } + } + // updating nullability to make all the children consistent override def output: Seq[Attribute] = children.map(_.output).transpose.map(attrs => @@ -669,6 +683,27 @@ case class Pivot( } } +/** + * A constructor for creating a logical limit, which is split into two separate logical nodes: + * a [[LocalLimit]], which is a partition local limit, followed by a [[GlobalLimit]]. + * + * This muds the water for clean logical/physical separation, and is done for better limit pushdown. + * In distributed query processing, a non-terminal global limit is actually an expensive operation + * because it requires coordination (in Spark this is done using a shuffle). + * + * In most cases when we want to push down limit, it is often better to only push some partition + * local limit. Consider the following: + * + * GlobalLimit(Union(A, B)) + * + * It is better to do + * GlobalLimit(Union(LocalLimit(A), LocalLimit(B))) + * + * than + * Union(GlobalLimit(A), GlobalLimit(B)). + * + * So we introduced LocalLimit and GlobalLimit in the logical plan node for limit pushdown. + */ object Limit { def apply(limitExpr: Expression, child: LogicalPlan): UnaryNode = { GlobalLimit(limitExpr, LocalLimit(limitExpr, child)) @@ -682,6 +717,11 @@ object Limit { } } +/** + * A global (coordinated) limit. This operator can emit at most `limitExpr` number in total. + * + * See [[Limit]] for more information. + */ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = { @@ -692,9 +732,16 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN } } +/** + * A partition-local (non-coordinated) limit. This operator can emit at most `limitExpr` number + * of tuples on each physical partition. + * + * See [[Limit]] for more information. + */ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - override def maxRows: Option[Long] = { + + override def maxRowsPerPartition: Option[Long] = { limitExpr match { case IntegerLiteral(limit) => Some(limit) case _ => None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 8389e2f3d5be9..63cd1691f4cd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -554,6 +554,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) /** * Physical plan for unioning two plans, without a distinct. This is UNION ALL in SQL. + * + * If we change how this is implemented physically, we'd need to update + * [[org.apache.spark.sql.catalyst.plans.logical.Union.maxRowsPerPartition]]. */ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan { override def output: Seq[Attribute] = From e65b6b7ca1a7cff1b91ad2262bb7941e6bf057cd Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 3 Oct 2017 12:40:22 -0700 Subject: [PATCH 658/779] [SPARK-22178][SQL] Refresh Persistent Views by REFRESH TABLE Command ## What changes were proposed in this pull request? The underlying tables of persistent views are not refreshed when users issue the REFRESH TABLE command against the persistent views. ## How was this patch tested? Added a test case Author: gatorsmile Closes #19405 from gatorsmile/refreshView. --- .../apache/spark/sql/internal/CatalogImpl.scala | 15 +++++++++++---- .../spark/sql/hive/HiveMetadataCacheSuite.scala | 14 +++++++++++--- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 142b005850a49..fdd25330c5e67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -474,13 +474,20 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def refreshTable(tableName: String): Unit = { val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) - // Temp tables: refresh (or invalidate) any metadata/data cached in the plan recursively. - // Non-temp tables: refresh the metadata cache. - sessionCatalog.refreshTable(tableIdent) + val tableMetadata = sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent) + val table = sparkSession.table(tableIdent) + + if (tableMetadata.tableType == CatalogTableType.VIEW) { + // Temp or persistent views: refresh (or invalidate) any metadata/data cached + // in the plan recursively. + table.queryExecution.analyzed.foreach(_.refresh()) + } else { + // Non-temp tables: refresh the metadata cache. + sessionCatalog.refreshTable(tableIdent) + } // If this table is cached as an InMemoryRelation, drop the original // cached version and make the new version cached lazily. - val table = sparkSession.table(tableIdent) if (isCached(table)) { // Uncache the logicalPlan. sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala index 0c28a1b609bb8..e71aba72c31fe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala @@ -31,14 +31,22 @@ import org.apache.spark.sql.test.SQLTestUtils class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("SPARK-16337 temporary view refresh") { - withTempView("view_refresh") { + checkRefreshView(isTemp = true) + } + + test("view refresh") { + checkRefreshView(isTemp = false) + } + + private def checkRefreshView(isTemp: Boolean) { + withView("view_refresh") { withTable("view_table") { // Create a Parquet directory spark.range(start = 0, end = 100, step = 1, numPartitions = 3) .write.saveAsTable("view_table") - // Read the table in - spark.table("view_table").filter("id > -1").createOrReplaceTempView("view_refresh") + val temp = if (isTemp) "TEMPORARY" else "" + spark.sql(s"CREATE $temp VIEW view_refresh AS SELECT * FROM view_table WHERE id > -1") assert(sql("select count(*) from view_refresh").first().getLong(0) == 100) // Delete a file using the Hadoop file system interface since the path returned by From e36ec38d89472df0dfe12222b6af54cd6eea8e98 Mon Sep 17 00:00:00 2001 From: Sahil Takiar Date: Tue, 3 Oct 2017 16:53:32 -0700 Subject: [PATCH 659/779] [SPARK-20466][CORE] HadoopRDD#addLocalConfiguration throws NPE ## What changes were proposed in this pull request? Fix for SPARK-20466, full description of the issue in the JIRA. To summarize, `HadoopRDD` uses a metadata cache to cache `JobConf` objects. The cache uses soft-references, which means the JVM can delete entries from the cache whenever there is GC pressure. `HadoopRDD#getJobConf` had a bug where it would check if the cache contained the `JobConf`, if it did it would get the `JobConf` from the cache and return it. This doesn't work when soft-references are used as the JVM can delete the entry between the existence check and the get call. ## How was this patch tested? Haven't thought of a good way to test this yet given the issue only occurs sometimes, and happens during high GC pressure. Was thinking of using mocks to verify `#getJobConf` is doing the right thing. I deleted the method `HadoopRDD#containsCachedMetadata` so that we don't hit this issue again. Author: Sahil Takiar Closes #19413 from sahilTakiar/master. --- .../org/apache/spark/rdd/HadoopRDD.scala | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 76ea8b86c53d2..23b344230e490 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -157,20 +157,25 @@ class HadoopRDD[K, V]( if (conf.isInstanceOf[JobConf]) { logDebug("Re-using user-broadcasted JobConf") conf.asInstanceOf[JobConf] - } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) { - logDebug("Re-using cached JobConf") - HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf] } else { - // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the - // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). - // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. - // Synchronize to prevent ConcurrentModificationException (SPARK-1097, HADOOP-10456). - HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { - logDebug("Creating new JobConf and caching it for later re-use") - val newJobConf = new JobConf(conf) - initLocalJobConfFuncOpt.foreach(f => f(newJobConf)) - HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) - newJobConf + Option(HadoopRDD.getCachedMetadata(jobConfCacheKey)) + .map { conf => + logDebug("Re-using cached JobConf") + conf.asInstanceOf[JobConf] + } + .getOrElse { + // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in + // the local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). + // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary + // objects. Synchronize to prevent ConcurrentModificationException (SPARK-1097, + // HADOOP-10456). + HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Creating new JobConf and caching it for later re-use") + val newJobConf = new JobConf(conf) + initLocalJobConfFuncOpt.foreach(f => f(newJobConf)) + HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) + newJobConf + } } } } @@ -360,8 +365,6 @@ private[spark] object HadoopRDD extends Logging { */ def getCachedMetadata(key: String): Any = SparkEnv.get.hadoopJobMetadata.get(key) - def containsCachedMetadata(key: String): Boolean = SparkEnv.get.hadoopJobMetadata.containsKey(key) - private def putCachedMetadata(key: String, value: Any): Unit = SparkEnv.get.hadoopJobMetadata.put(key, value) From 5f694334534e4425fb9e8abf5b7e3e5efdfcef50 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 3 Oct 2017 21:27:58 -0700 Subject: [PATCH 660/779] [SPARK-22171][SQL] Describe Table Extended Failed when Table Owner is Empty ## What changes were proposed in this pull request? Users could hit `java.lang.NullPointerException` when the tables were created by Hive and the table's owner is `null` that are got from Hive metastore. `DESC EXTENDED` failed with the error: > SQLExecutionException: java.lang.NullPointerException at scala.collection.immutable.StringOps$.length$extension(StringOps.scala:47) at scala.collection.immutable.StringOps.length(StringOps.scala:47) at scala.collection.IndexedSeqOptimized$class.isEmpty(IndexedSeqOptimized.scala:27) at scala.collection.immutable.StringOps.isEmpty(StringOps.scala:29) at scala.collection.TraversableOnce$class.nonEmpty(TraversableOnce.scala:111) at scala.collection.immutable.StringOps.nonEmpty(StringOps.scala:29) at org.apache.spark.sql.catalyst.catalog.CatalogTable.toLinkedHashMap(interface.scala:300) at org.apache.spark.sql.execution.command.DescribeTableCommand.describeFormattedTableInfo(tables.scala:565) at org.apache.spark.sql.execution.command.DescribeTableCommand.run(tables.scala:543) at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:66) at ## How was this patch tested? Added a unit test case Author: gatorsmile Closes #19395 from gatorsmile/desc. --- .../sql/catalyst/catalog/interface.scala | 2 +- .../sql/catalyst/analysis/CatalogSuite.scala | 37 +++++++++++++++++++ .../sql/hive/client/HiveClientImpl.scala | 2 +- 3 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CatalogSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 1965144e81197..fe2af910a0ae5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -307,7 +307,7 @@ case class CatalogTable( identifier.database.foreach(map.put("Database", _)) map.put("Table", identifier.table) - if (owner.nonEmpty) map.put("Owner", owner) + if (owner != null && owner.nonEmpty) map.put("Owner", owner) map.put("Created Time", new Date(createTime).toString) map.put("Last Access", new Date(lastAccessTime).toString) map.put("Created By", "Spark " + createVersion) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CatalogSuite.scala new file mode 100644 index 0000000000000..d670053ba1b5d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CatalogSuite.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.types.StructType + + +class CatalogSuite extends AnalysisTest { + + test("desc table when owner is set to null") { + val table = CatalogTable( + identifier = TableIdentifier("tbl", Some("db1")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + owner = null, + schema = new StructType().add("col1", "int").add("col2", "string"), + provider = Some("parquet")) + table.toLinkedHashMap + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index c4e48c9360db7..66165c7228bca 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -461,7 +461,7 @@ private[hive] class HiveClientImpl( // in table properties. This means, if we have bucket spec in both hive metastore and // table properties, we will trust the one in table properties. bucketSpec = bucketSpec, - owner = h.getOwner, + owner = Option(h.getOwner).getOrElse(""), createTime = h.getTTable.getCreateTime.toLong * 1000, lastAccessTime = h.getLastAccessTime.toLong * 1000, storage = CatalogStorageFormat( From 3099c574c56cab86c3fcf759864f89151643f837 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 3 Oct 2017 21:42:51 -0700 Subject: [PATCH 661/779] [SPARK-22136][SS] Implement stream-stream outer joins. ## What changes were proposed in this pull request? Allow one-sided outer joins between two streams when a watermark is defined. ## How was this patch tested? new unit tests Author: Jose Torres Closes #19327 from joseph-torres/outerjoin. --- .../analysis/StreamingJoinHelper.scala | 286 +++++++++++++++++ .../UnsupportedOperationChecker.scala | 53 +++- .../analysis/StreamingJoinHelperSuite.scala | 140 ++++++++ .../analysis/UnsupportedOperationsSuite.scala | 108 ++++++- .../StreamingSymmetricHashJoinExec.scala | 152 +++++++-- .../StreamingSymmetricHashJoinHelper.scala | 241 +------------- .../state/SymmetricHashJoinStateManager.scala | 200 +++++++++--- .../SymmetricHashJoinStateManagerSuite.scala | 6 +- .../sql/streaming/StreamingJoinSuite.scala | 298 +++++++++++------- 9 files changed, 1051 insertions(+), 433 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala new file mode 100644 index 0000000000000..072dc954879ca --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import scala.util.control.NonFatal + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, AttributeSet, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + + +/** + * Helper object for stream joins. See [[StreamingSymmetricHashJoinExec]] in SQL for more details. + */ +object StreamingJoinHelper extends PredicateHelper with Logging { + + /** + * Check the provided logical plan to see if its join keys contain a watermark attribute. + * + * Will return false if the plan is not an equijoin. + * @param plan the logical plan to check + */ + def isWatermarkInJoinKeys(plan: LogicalPlan): Boolean = { + plan match { + case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _) => + (leftKeys ++ rightKeys).exists { + case a: AttributeReference => a.metadata.contains(EventTimeWatermark.delayKey) + case _ => false + } + case _ => false + } + } + + /** + * Get state value watermark (see [[StreamingSymmetricHashJoinExec]] for context about it) + * given the join condition and the event time watermark. This is how it works. + * - The condition is split into conjunctive predicates, and we find the predicates of the + * form `leftTime + c1 < rightTime + c2` (or <=, >, >=). + * - We canoncalize the predicate and solve it with the event time watermark value to find the + * value of the state watermark. + * This function is supposed to make best-effort attempt to get the state watermark. If there is + * any error, it will return None. + * + * @param attributesToFindStateWatermarkFor attributes of the side whose state watermark + * is to be calculated + * @param attributesWithEventWatermark attributes of the other side which has a watermark column + * @param joinCondition join condition + * @param eventWatermark watermark defined on the input event data + * @return state value watermark in milliseconds, is possible. + */ + def getStateValueWatermark( + attributesToFindStateWatermarkFor: AttributeSet, + attributesWithEventWatermark: AttributeSet, + joinCondition: Option[Expression], + eventWatermark: Option[Long]): Option[Long] = { + + // If condition or event time watermark is not provided, then cannot calculate state watermark + if (joinCondition.isEmpty || eventWatermark.isEmpty) return None + + // If there is not watermark attribute, then cannot define state watermark + if (!attributesWithEventWatermark.exists(_.metadata.contains(delayKey))) return None + + def getStateWatermarkSafely(l: Expression, r: Expression): Option[Long] = { + try { + getStateWatermarkFromLessThenPredicate( + l, r, attributesToFindStateWatermarkFor, attributesWithEventWatermark, eventWatermark) + } catch { + case NonFatal(e) => + logWarning(s"Error trying to extract state constraint from condition $joinCondition", e) + None + } + } + + val allStateWatermarks = splitConjunctivePredicates(joinCondition.get).flatMap { predicate => + + // The generated the state watermark cleanup expression is inclusive of the state watermark. + // If state watermark is W, all state where timestamp <= W will be cleaned up. + // Now when the canonicalized join condition solves to leftTime >= W, we dont want to clean + // up leftTime <= W. Rather we should clean up leftTime <= W - 1. Hence the -1 below. + val stateWatermark = predicate match { + case LessThan(l, r) => getStateWatermarkSafely(l, r) + case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_ - 1) + case GreaterThan(l, r) => getStateWatermarkSafely(r, l) + case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r, l).map(_ - 1) + case _ => None + } + if (stateWatermark.nonEmpty) { + logInfo(s"Condition $joinCondition generated watermark constraint = ${stateWatermark.get}") + } + stateWatermark + } + allStateWatermarks.reduceOption((x, y) => Math.min(x, y)) + } + + /** + * Extract the state value watermark (milliseconds) from the condition + * `LessThan(leftExpr, rightExpr)` where . For example: if we want to find the constraint for + * leftTime using the watermark on the rightTime. Example: + * + * Input: rightTime-with-watermark + c1 < leftTime + c2 + * Canonical form: rightTime-with-watermark + c1 + (-c2) + (-leftTime) < 0 + * Solving for rightTime: rightTime-with-watermark + c1 + (-c2) < leftTime + * With watermark value: watermark-value + c1 + (-c2) < leftTime + */ + private def getStateWatermarkFromLessThenPredicate( + leftExpr: Expression, + rightExpr: Expression, + attributesToFindStateWatermarkFor: AttributeSet, + attributesWithEventWatermark: AttributeSet, + eventWatermark: Option[Long]): Option[Long] = { + + val attributesInCondition = AttributeSet( + leftExpr.collect { case a: AttributeReference => a } ++ + rightExpr.collect { case a: AttributeReference => a } + ) + if (attributesInCondition.filter { attributesToFindStateWatermarkFor.contains(_) }.size > 1 || + attributesInCondition.filter { attributesWithEventWatermark.contains(_) }.size > 1) { + // If more than attributes present in condition from one side, then it cannot be solved + return None + } + + def containsAttributeToFindStateConstraintFor(e: Expression): Boolean = { + e.collectLeaves().collectFirst { + case a @ AttributeReference(_, _, _, _) + if attributesToFindStateWatermarkFor.contains(a) => a + }.nonEmpty + } + + // Canonicalization step 1: convert to (rightTime-with-watermark + c1) - (leftTime + c2) < 0 + val allOnLeftExpr = Subtract(leftExpr, rightExpr) + logDebug(s"All on Left:\n${allOnLeftExpr.treeString(true)}\n${allOnLeftExpr.asCode}") + + // Canonicalization step 2: extract commutative terms + // rightTime-with-watermark, c1, -leftTime, -c2 + val terms = ExpressionSet(collectTerms(allOnLeftExpr)) + logDebug("Terms extracted from join condition:\n\t" + terms.mkString("\n\t")) + + // Find the term that has leftTime (i.e. the one present in attributesToFindConstraintFor + val constraintTerms = terms.filter(containsAttributeToFindStateConstraintFor) + + // Verify there is only one correct constraint term and of the correct type + if (constraintTerms.size > 1) { + logWarning("Failed to extract state constraint terms: multiple time terms in condition\n\t" + + terms.mkString("\n\t")) + return None + } + if (constraintTerms.isEmpty) { + logDebug("Failed to extract state constraint terms: no time terms in condition\n\t" + + terms.mkString("\n\t")) + return None + } + val constraintTerm = constraintTerms.head + if (constraintTerm.collectFirst { case u: UnaryMinus => u }.isEmpty) { + // Incorrect condition. We want the constraint term in canonical form to be `-leftTime` + // so that resolve for it as `-leftTime + watermark + c < 0` ==> `watermark + c < leftTime`. + // Now, if the original conditions is `rightTime-with-watermark > leftTime` and watermark + // condition is `rightTime-with-watermark > watermarkValue`, then no constraint about + // `leftTime` can be inferred. In this case, after canonicalization and collection of terms, + // the constraintTerm would be `leftTime` and not `-leftTime`. Hence, we return None. + return None + } + + // Replace watermark attribute with watermark value, and generate the resolved expression + // from the other terms. That is, + // rightTime-with-watermark, c1, -c2 => watermark, c1, -c2 => watermark + c1 + (-c2) + logDebug(s"Constraint term from join condition:\t$constraintTerm") + val exprWithWatermarkSubstituted = (terms - constraintTerm).map { term => + term.transform { + case a @ AttributeReference(_, _, _, metadata) + if attributesWithEventWatermark.contains(a) && metadata.contains(delayKey) => + Multiply(Literal(eventWatermark.get.toDouble), Literal(1000.0)) + } + }.reduceLeft(Add) + + // Calculate the constraint value + logInfo(s"Final expression to evaluate constraint:\t$exprWithWatermarkSubstituted") + val constraintValue = exprWithWatermarkSubstituted.eval().asInstanceOf[java.lang.Double] + Some((Double2double(constraintValue) / 1000.0).toLong) + } + + /** + * Collect all the terms present in an expression after converting it into the form + * a + b + c + d where each term be either an attribute or a literal casted to long, + * optionally wrapped in a unary minus. + */ + private def collectTerms(exprToCollectFrom: Expression): Seq[Expression] = { + var invalid = false + + /** Wrap a term with UnaryMinus if its needs to be negated. */ + def negateIfNeeded(expr: Expression, minus: Boolean): Expression = { + if (minus) UnaryMinus(expr) else expr + } + + /** + * Recursively split the expression into its leaf terms contains attributes or literals. + * Returns terms only of the forms: + * Cast(AttributeReference), UnaryMinus(Cast(AttributeReference)), + * Cast(AttributeReference, Double), UnaryMinus(Cast(AttributeReference, Double)) + * Multiply(Literal), UnaryMinus(Multiply(Literal)) + * Multiply(Cast(Literal)), UnaryMinus(Multiple(Cast(Literal))) + * + * Note: + * - If term needs to be negated for making it a commutative term, + * then it will be wrapped in UnaryMinus(...) + * - Each terms will be representing timestamp value or time interval in microseconds, + * typed as doubles. + */ + def collect(expr: Expression, negate: Boolean): Seq[Expression] = { + expr match { + case Add(left, right) => + collect(left, negate) ++ collect(right, negate) + case Subtract(left, right) => + collect(left, negate) ++ collect(right, !negate) + case TimeAdd(left, right, _) => + collect(left, negate) ++ collect(right, negate) + case TimeSub(left, right, _) => + collect(left, negate) ++ collect(right, !negate) + case UnaryMinus(child) => + collect(child, !negate) + case CheckOverflow(child, _) => + collect(child, negate) + case Cast(child, dataType, _) => + dataType match { + case _: NumericType | _: TimestampType => collect(child, negate) + case _ => + invalid = true + Seq.empty + } + case a: AttributeReference => + val castedRef = if (a.dataType != DoubleType) Cast(a, DoubleType) else a + Seq(negateIfNeeded(castedRef, negate)) + case lit: Literal => + // If literal of type calendar interval, then explicitly convert to millis + // Convert other number like literal to doubles representing millis (by x1000) + val castedLit = lit.dataType match { + case CalendarIntervalType => + val calendarInterval = lit.value.asInstanceOf[CalendarInterval] + if (calendarInterval.months > 0) { + invalid = true + logWarning( + s"Failed to extract state value watermark from condition $exprToCollectFrom " + + s"as imprecise intervals like months and years cannot be used for" + + s"watermark calculation. Use interval in terms of day instead.") + Literal(0.0) + } else { + Literal(calendarInterval.microseconds.toDouble) + } + case DoubleType => + Multiply(lit, Literal(1000000.0)) + case _: NumericType => + Multiply(Cast(lit, DoubleType), Literal(1000000.0)) + case _: TimestampType => + Multiply(PreciseTimestampConversion(lit, TimestampType, LongType), Literal(1000000.0)) + } + Seq(negateIfNeeded(castedLit, negate)) + case a @ _ => + logWarning( + s"Failed to extract state value watermark from condition $exprToCollectFrom due to $a") + invalid = true + Seq.empty + } + } + + val terms = collect(exprToCollectFrom, negate = false) + if (!invalid) terms else Seq.empty + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index d1d705691b076..dee6fbe9d1514 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes @@ -217,7 +218,7 @@ object UnsupportedOperationChecker { throwError("dropDuplicates is not supported after aggregation on a " + "streaming DataFrame/Dataset") - case Join(left, right, joinType, _) => + case Join(left, right, joinType, condition) => joinType match { @@ -233,16 +234,52 @@ object UnsupportedOperationChecker { throwError("Full outer joins with streaming DataFrames/Datasets are not supported") } - case LeftOuter | LeftSemi | LeftAnti => + case LeftSemi | LeftAnti => if (right.isStreaming) { - throwError("Left outer/semi/anti joins with a streaming DataFrame/Dataset " + - "on the right is not supported") + throwError("Left semi/anti joins with a streaming DataFrame/Dataset " + + "on the right are not supported") } + // We support streaming left outer joins with static on the right always, and with + // stream on both sides under the appropriate conditions. + case LeftOuter => + if (!left.isStreaming && right.isStreaming) { + throwError("Left outer join with a streaming DataFrame/Dataset " + + "on the right and a static DataFrame/Dataset on the left is not supported") + } else if (left.isStreaming && right.isStreaming) { + val watermarkInJoinKeys = StreamingJoinHelper.isWatermarkInJoinKeys(subPlan) + + val hasValidWatermarkRange = + StreamingJoinHelper.getStateValueWatermark( + left.outputSet, right.outputSet, condition, Some(1000000)).isDefined + + if (!watermarkInJoinKeys && !hasValidWatermarkRange) { + throwError("Stream-stream outer join between two streaming DataFrame/Datasets " + + "is not supported without a watermark in the join keys, or a watermark on " + + "the nullable side and an appropriate range condition") + } + } + + // We support streaming right outer joins with static on the left always, and with + // stream on both sides under the appropriate conditions. case RightOuter => - if (left.isStreaming) { - throwError("Right outer join with a streaming DataFrame/Dataset on the left is " + - "not supported") + if (left.isStreaming && !right.isStreaming) { + throwError("Right outer join with a streaming DataFrame/Dataset on the left and " + + "a static DataFrame/DataSet on the right not supported") + } else if (left.isStreaming && right.isStreaming) { + val isWatermarkInJoinKeys = StreamingJoinHelper.isWatermarkInJoinKeys(subPlan) + + // Check if the nullable side has a watermark, and there's a range condition which + // implies a state value watermark on the first side. + val hasValidWatermarkRange = + StreamingJoinHelper.getStateValueWatermark( + right.outputSet, left.outputSet, condition, Some(1000000)).isDefined + + if (!isWatermarkInJoinKeys && !hasValidWatermarkRange) { + throwError("Stream-stream outer join between two streaming DataFrame/Datasets " + + "is not supported without a watermark in the join keys, or a watermark on " + + "the nullable side and an appropriate range condition") + } } case NaturalJoin(_) | UsingJoin(_, _) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala new file mode 100644 index 0000000000000..8cf41a02320d2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet} +import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter, LeafNode, LocalRelation} +import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, TimestampType} + +class StreamingJoinHelperSuite extends AnalysisTest { + + test("extract watermark from time condition") { + val attributesToFindConstraintFor = Seq( + AttributeReference("leftTime", TimestampType)(), + AttributeReference("leftOther", IntegerType)()) + val metadataWithWatermark = new MetadataBuilder() + .putLong(EventTimeWatermark.delayKey, 1000) + .build() + val attributesWithWatermark = Seq( + AttributeReference("rightTime", TimestampType, metadata = metadataWithWatermark)(), + AttributeReference("rightOther", IntegerType)()) + + case class DummyLeafNode() extends LeafNode { + override def output: Seq[Attribute] = + attributesToFindConstraintFor ++ attributesWithWatermark + } + + def watermarkFrom( + conditionStr: String, + rightWatermark: Option[Long] = Some(10000)): Option[Long] = { + val conditionExpr = Some(conditionStr).map { str => + val plan = + Filter( + CatalystSqlParser.parseExpression(str), + DummyLeafNode()) + val optimized = SimpleTestOptimizer.execute(SimpleAnalyzer.execute(plan)) + optimized.asInstanceOf[Filter].condition + } + StreamingJoinHelper.getStateValueWatermark( + AttributeSet(attributesToFindConstraintFor), AttributeSet(attributesWithWatermark), + conditionExpr, rightWatermark) + } + + // Test comparison directionality. E.g. if leftTime < rightTime and rightTime > watermark, + // then cannot define constraint on leftTime. + assert(watermarkFrom("leftTime > rightTime") === Some(10000)) + assert(watermarkFrom("leftTime >= rightTime") === Some(9999)) + assert(watermarkFrom("leftTime < rightTime") === None) + assert(watermarkFrom("leftTime <= rightTime") === None) + assert(watermarkFrom("rightTime > leftTime") === None) + assert(watermarkFrom("rightTime >= leftTime") === None) + assert(watermarkFrom("rightTime < leftTime") === Some(10000)) + assert(watermarkFrom("rightTime <= leftTime") === Some(9999)) + + // Test type conversions + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS LONG) < CAST(rightTime AS LONG)") === None) + assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS DOUBLE)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS DOUBLE)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS FLOAT)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS FLOAT)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS STRING) > CAST(rightTime AS STRING)") === None) + + // Test with timestamp type + calendar interval on either side of equation + // Note: timestamptype and calendar interval don't commute, so less valid combinations to test. + assert(watermarkFrom("leftTime > rightTime + interval 1 second") === Some(11000)) + assert(watermarkFrom("leftTime + interval 2 seconds > rightTime ") === Some(8000)) + assert(watermarkFrom("leftTime > rightTime - interval 3 second") === Some(7000)) + assert(watermarkFrom("rightTime < leftTime - interval 3 second") === Some(13000)) + assert(watermarkFrom("rightTime - interval 1 second < leftTime - interval 3 second") + === Some(12000)) + + // Test with casted long type + constants on either side of equation + // Note: long type and constants commute, so more combinations to test. + // -- Constants on the right + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) + 1") === Some(11000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 1") === Some(9000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST((rightTime + interval 1 second) AS LONG)") + === Some(11000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > 2 + CAST(rightTime AS LONG)") === Some(12000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > -0.5 + CAST(rightTime AS LONG)") === Some(9500)) + assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) > 2") === Some(12000)) + assert(watermarkFrom("-CAST(rightTime AS DOUBLE) + CAST(leftTime AS LONG) > 0.1") + === Some(10100)) + assert(watermarkFrom("0 > CAST(rightTime AS LONG) - CAST(leftTime AS LONG) + 0.2") + === Some(10200)) + // -- Constants on the left + assert(watermarkFrom("CAST(leftTime AS LONG) + 2 > CAST(rightTime AS LONG)") === Some(8000)) + assert(watermarkFrom("1 + CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(9000)) + assert(watermarkFrom("CAST((leftTime + interval 3 second) AS LONG) > CAST(rightTime AS LONG)") + === Some(7000)) + assert(watermarkFrom("CAST(leftTime AS LONG) - 2 > CAST(rightTime AS LONG)") === Some(12000)) + assert(watermarkFrom("CAST(leftTime AS LONG) + 0.5 > CAST(rightTime AS LONG)") === Some(9500)) + assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) - 2 > 0") + === Some(12000)) + assert(watermarkFrom("-CAST(rightTime AS LONG) + CAST(leftTime AS LONG) - 0.1 > 0") + === Some(10100)) + // -- Constants on both sides, mixed types + assert(watermarkFrom("CAST(leftTime AS LONG) - 2.0 > CAST(rightTime AS LONG) + 1") + === Some(13000)) + + // Test multiple conditions, should return minimum watermark + assert(watermarkFrom( + "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 2 seconds") === + Some(7000)) // first condition wins + assert(watermarkFrom( + "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 4 seconds") === + Some(6000)) // second condition wins + + // Test invalid comparisons + assert(watermarkFrom("cast(leftTime AS LONG) > leftOther") === None) // non-time attributes + assert(watermarkFrom("leftOther > rightOther") === None) // non-time attributes + assert(watermarkFrom("leftOther > rightOther AND leftTime > rightTime") === Some(10000)) + assert(watermarkFrom("cast(rightTime AS DOUBLE) < rightOther") === None) // non-time attributes + assert(watermarkFrom("leftTime > rightTime + interval 1 month") === None) // month not allowed + + // Test static comparisons + assert(watermarkFrom("cast(leftTime AS LONG) > 10") === Some(10000)) + + // Test non-positive results + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 10") === Some(0)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 100") === Some(-90000)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 11f48a39c1e25..e5057c451d5b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder} +import org.apache.spark.unsafe.types.CalendarInterval /** A dummy command for testing unsupported operations. */ case class DummyCommand() extends Command @@ -417,9 +418,57 @@ class UnsupportedOperationsSuite extends SparkFunSuite { testBinaryOperationInStreamingPlan( "left outer join", _.join(_, joinType = LeftOuter), - streamStreamSupported = false, batchStreamSupported = false, - expectedMsg = "left outer/semi/anti joins") + streamStreamSupported = false, + expectedMsg = "outer join") + + // Left outer joins: stream-stream allowed with join on watermark attribute + // Note that the attribute need not be watermarked on both sides. + assertSupportedInStreamingPlan( + s"left outer join with stream-stream relations and join on attribute with left watermark", + streamRelation.join(streamRelation, joinType = LeftOuter, + condition = Some(attributeWithWatermark === attribute)), + OutputMode.Append()) + assertSupportedInStreamingPlan( + s"left outer join with stream-stream relations and join on attribute with right watermark", + streamRelation.join(streamRelation, joinType = LeftOuter, + condition = Some(attribute === attributeWithWatermark)), + OutputMode.Append()) + assertNotSupportedInStreamingPlan( + s"left outer join with stream-stream relations and join on non-watermarked attribute", + streamRelation.join(streamRelation, joinType = LeftOuter, + condition = Some(attribute === attribute)), + OutputMode.Append(), + Seq("watermark in the join keys")) + + // Left outer joins: stream-stream allowed with range condition yielding state value watermark + assertSupportedInStreamingPlan( + s"left outer join with stream-stream relations and state value watermark", { + val leftRelation = streamRelation + val rightTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val rightRelation = new TestStreamingRelation(rightTimeWithWatermark) + leftRelation.join( + rightRelation, + joinType = LeftOuter, + condition = Some(attribute > rightTimeWithWatermark + 10)) + }, + OutputMode.Append()) + + // Left outer joins: stream-stream not allowed with insufficient range condition + assertNotSupportedInStreamingPlan( + s"left outer join with stream-stream relations and state value watermark", { + val leftRelation = streamRelation + val rightTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val rightRelation = new TestStreamingRelation(rightTimeWithWatermark) + leftRelation.join( + rightRelation, + joinType = LeftOuter, + condition = Some(attribute < rightTimeWithWatermark + 10)) + }, + OutputMode.Append(), + Seq("appropriate range condition")) // Left semi joins: stream-* not allowed testBinaryOperationInStreamingPlan( @@ -427,7 +476,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { _.join(_, joinType = LeftSemi), streamStreamSupported = false, batchStreamSupported = false, - expectedMsg = "left outer/semi/anti joins") + expectedMsg = "left semi/anti joins") // Left anti joins: stream-* not allowed testBinaryOperationInStreamingPlan( @@ -435,14 +484,63 @@ class UnsupportedOperationsSuite extends SparkFunSuite { _.join(_, joinType = LeftAnti), streamStreamSupported = false, batchStreamSupported = false, - expectedMsg = "left outer/semi/anti joins") + expectedMsg = "left semi/anti joins") // Right outer joins: stream-* not allowed testBinaryOperationInStreamingPlan( "right outer join", _.join(_, joinType = RightOuter), + streamBatchSupported = false, streamStreamSupported = false, - streamBatchSupported = false) + expectedMsg = "outer join") + + // Right outer joins: stream-stream allowed with join on watermark attribute + // Note that the attribute need not be watermarked on both sides. + assertSupportedInStreamingPlan( + s"right outer join with stream-stream relations and join on attribute with left watermark", + streamRelation.join(streamRelation, joinType = RightOuter, + condition = Some(attributeWithWatermark === attribute)), + OutputMode.Append()) + assertSupportedInStreamingPlan( + s"right outer join with stream-stream relations and join on attribute with right watermark", + streamRelation.join(streamRelation, joinType = RightOuter, + condition = Some(attribute === attributeWithWatermark)), + OutputMode.Append()) + assertNotSupportedInStreamingPlan( + s"right outer join with stream-stream relations and join on non-watermarked attribute", + streamRelation.join(streamRelation, joinType = RightOuter, + condition = Some(attribute === attribute)), + OutputMode.Append(), + Seq("watermark in the join keys")) + + // Right outer joins: stream-stream allowed with range condition yielding state value watermark + assertSupportedInStreamingPlan( + s"right outer join with stream-stream relations and state value watermark", { + val leftTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val leftRelation = new TestStreamingRelation(leftTimeWithWatermark) + val rightRelation = streamRelation + leftRelation.join( + rightRelation, + joinType = RightOuter, + condition = Some(leftTimeWithWatermark + 10 < attribute)) + }, + OutputMode.Append()) + + // Right outer joins: stream-stream not allowed with insufficient range condition + assertNotSupportedInStreamingPlan( + s"right outer join with stream-stream relations and state value watermark", { + val leftTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val leftRelation = new TestStreamingRelation(leftTimeWithWatermark) + val rightRelation = streamRelation + leftRelation.join( + rightRelation, + joinType = RightOuter, + condition = Some(leftTimeWithWatermark + 10 > attribute)) + }, + OutputMode.Append(), + Seq("appropriate range condition")) // Cogroup: only batch-batch is allowed testBinaryOperationInStreamingPlan( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 44f1fa58599d2..9bd2127a28ff6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, JoinedRow, Literal, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, GenericInternalRow, JoinedRow, Literal, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -146,7 +146,14 @@ case class StreamingSymmetricHashJoinExec( stateWatermarkPredicates = JoinStateWatermarkPredicates(), left, right) } - require(joinType == Inner, s"${getClass.getSimpleName} should not take $joinType as the JoinType") + private def throwBadJoinTypeException(): Nothing = { + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $joinType as the JoinType") + } + + require( + joinType == Inner || joinType == LeftOuter || joinType == RightOuter, + s"${getClass.getSimpleName} should not take $joinType as the JoinType") require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType)) private val storeConf = new StateStoreConf(sqlContext.conf) @@ -157,11 +164,18 @@ case class StreamingSymmetricHashJoinExec( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def output: Seq[Attribute] = left.output ++ right.output + override def output: Seq[Attribute] = joinType match { + case _: InnerLike => left.output ++ right.output + case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => left.output.map(_.withNullability(true)) ++ right.output + case _ => throwBadJoinTypeException() + } override def outputPartitioning: Partitioning = joinType match { case _: InnerLike => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + case LeftOuter => PartitioningCollection(Seq(left.outputPartitioning)) + case RightOuter => PartitioningCollection(Seq(right.outputPartitioning)) case x => throw new IllegalArgumentException( s"${getClass.getSimpleName} should not take $x as the JoinType") @@ -207,31 +221,108 @@ case class StreamingSymmetricHashJoinExec( // matching new left input with new right input, since the new left input has become stored // by that point. This tiny asymmetry is necessary to avoid duplication. val leftOutputIter = leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner) { - (inputRow: UnsafeRow, matchedRow: UnsafeRow) => - joinedRow.withLeft(inputRow).withRight(matchedRow) + (input: UnsafeRow, matched: UnsafeRow) => joinedRow.withLeft(input).withRight(matched) } val rightOutputIter = rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner) { - (inputRow: UnsafeRow, matchedRow: UnsafeRow) => - joinedRow.withLeft(matchedRow).withRight(inputRow) + (input: UnsafeRow, matched: UnsafeRow) => joinedRow.withLeft(matched).withRight(input) } // Filter the joined rows based on the given condition. - val outputFilterFunction = - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output).eval _ - val filteredOutputIter = - (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction).map { row => - numOutputRows += 1 - row - } + val outputFilterFunction = newPredicate(condition.getOrElse(Literal(true)), output).eval _ + + // We need to save the time that the inner join output iterator completes, since outer join + // output counts as both update and removal time. + var innerOutputCompletionTimeNs: Long = 0 + def onInnerOutputCompletion = { + innerOutputCompletionTimeNs = System.nanoTime + } + val filteredInnerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]]( + (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction), onInnerOutputCompletion) + + def matchesWithRightSideState(leftKeyValue: UnsafeRowPair) = { + rightSideJoiner.get(leftKeyValue.key).exists( + rightValue => { + outputFilterFunction( + joinedRow.withLeft(leftKeyValue.value).withRight(rightValue)) + }) + } + + def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = { + leftSideJoiner.get(rightKeyValue.key).exists( + leftValue => { + outputFilterFunction( + joinedRow.withLeft(leftValue).withRight(rightKeyValue.value)) + }) + } + + val outputIter: Iterator[InternalRow] = joinType match { + case Inner => + filteredInnerOutputIter + case LeftOuter => + // We generate the outer join input by: + // * Getting an iterator over the rows that have aged out on the left side. These rows are + // candidates for being null joined. Note that to avoid doing two passes, this iterator + // removes the rows from the state manager as they're processed. + // * Checking whether the current row matches a key in the right side state, and that key + // has any value which satisfies the filter function when joined. If it doesn't, + // we know we can join with null, since there was never (including this batch) a match + // within the watermark period. If it does, there must have been a match at some point, so + // we know we can't join with null. + val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) + val removedRowIter = leftSideJoiner.removeOldState() + val outerOutputIter = removedRowIter + .filterNot(pair => matchesWithRightSideState(pair)) + .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight)) + + filteredInnerOutputIter ++ outerOutputIter + case RightOuter => + // See comments for left outer case. + val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length) + val removedRowIter = rightSideJoiner.removeOldState() + val outerOutputIter = removedRowIter + .filterNot(pair => matchesWithLeftSideState(pair)) + .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value)) + + filteredInnerOutputIter ++ outerOutputIter + case _ => throwBadJoinTypeException() + } + + val outputIterWithMetrics = outputIter.map { row => + numOutputRows += 1 + row + } // Function to remove old state after all the input has been consumed and output generated def onOutputCompletion = { + // All processing time counts as update time. allUpdatesTimeMs += math.max(NANOSECONDS.toMillis(System.nanoTime - updateStartTimeNs), 0) - // Remove old state if needed + // Processing time between inner output completion and here comes from the outer portion of a + // join, and thus counts as removal time as we remove old state from one side while iterating. + if (innerOutputCompletionTimeNs != 0) { + allRemovalsTimeMs += + math.max(NANOSECONDS.toMillis(System.nanoTime - innerOutputCompletionTimeNs), 0) + } + allRemovalsTimeMs += timeTakenMs { - leftSideJoiner.removeOldState() - rightSideJoiner.removeOldState() + // Remove any remaining state rows which aren't needed because they're below the watermark. + // + // For inner joins, we have to remove unnecessary state rows from both sides if possible. + // For outer joins, we have already removed unnecessary state rows from the outer side + // (e.g., left side for left outer join) while generating the outer "null" outputs. Now, we + // have to remove unnecessary state rows from the other side (e.g., right side for the left + // outer join) if possible. In all cases, nothing needs to be outputted, hence the removal + // needs to be done greedily by immediately consuming the returned iterator. + val cleanupIter = joinType match { + case Inner => + leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() + case LeftOuter => rightSideJoiner.removeOldState() + case RightOuter => leftSideJoiner.removeOldState() + case _ => throwBadJoinTypeException() + } + while (cleanupIter.hasNext) { + cleanupIter.next() + } } // Commit all state changes and update state store metrics @@ -251,7 +342,8 @@ case class StreamingSymmetricHashJoinExec( } } - CompletionIterator[InternalRow, Iterator[InternalRow]](filteredOutputIter, onOutputCompletion) + CompletionIterator[InternalRow, Iterator[InternalRow]]( + outputIterWithMetrics, onOutputCompletion) } /** @@ -324,14 +416,32 @@ case class StreamingSymmetricHashJoinExec( } } - /** Remove old buffered state rows using watermarks for state keys and values */ - def removeOldState(): Unit = { + /** + * Get an iterator over the values stored in this joiner's state manager for the given key. + * + * Should not be interleaved with mutations. + */ + def get(key: UnsafeRow): Iterator[UnsafeRow] = { + joinStateManager.get(key) + } + + /** + * Builds an iterator over old state key-value pairs, removing them lazily as they're produced. + * + * @note This iterator must be consumed fully before any other operations are made + * against this joiner's join state manager. For efficiency reasons, the intermediate states of + * the iterator leave the state manager in an undefined state. + * + * We do this to avoid requiring either two passes or full materialization when + * processing the rows for outer join. + */ + def removeOldState(): Iterator[UnsafeRowPair] = { stateWatermarkPredicate match { case Some(JoinStateKeyWatermarkPredicate(expr)) => joinStateManager.removeByKeyCondition(stateKeyWatermarkPredicateFunc) case Some(JoinStateValueWatermarkPredicate(expr)) => joinStateManager.removeByValueCondition(stateValueWatermarkPredicateFunc) - case _ => + case _ => Iterator.empty } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index e50274a1baba1..64c7189f72ac3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -23,6 +23,7 @@ import scala.util.control.NonFatal import org.apache.spark.{Partition, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.{RDD, ZippedPartitionsRDD2} +import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, AttributeReference, AttributeSet, BoundReference, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, NamedExpression, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ import org.apache.spark.sql.execution.streaming.WatermarkSupport.watermarkExpression @@ -34,7 +35,7 @@ import org.apache.spark.unsafe.types.CalendarInterval /** * Helper object for [[StreamingSymmetricHashJoinExec]]. See that object for more details. */ -object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging { +object StreamingSymmetricHashJoinHelper extends Logging { sealed trait JoinSide case object LeftSide extends JoinSide { override def toString(): String = "left" } @@ -111,7 +112,7 @@ object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging { expr.map(JoinStateKeyWatermarkPredicate.apply _) } else if (isWatermarkDefinedOnInput) { // case 2 in the StreamingSymmetricHashJoinExec docs - val stateValueWatermark = getStateValueWatermark( + val stateValueWatermark = StreamingJoinHelper.getStateValueWatermark( attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes), attributesWithEventWatermark = AttributeSet(otherSideInputAttributes), condition, @@ -132,242 +133,6 @@ object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging { JoinStateWatermarkPredicates(leftStateWatermarkPredicate, rightStateWatermarkPredicate) } - /** - * Get state value watermark (see [[StreamingSymmetricHashJoinExec]] for context about it) - * given the join condition and the event time watermark. This is how it works. - * - The condition is split into conjunctive predicates, and we find the predicates of the - * form `leftTime + c1 < rightTime + c2` (or <=, >, >=). - * - We canoncalize the predicate and solve it with the event time watermark value to find the - * value of the state watermark. - * This function is supposed to make best-effort attempt to get the state watermark. If there is - * any error, it will return None. - * - * @param attributesToFindStateWatermarkFor attributes of the side whose state watermark - * is to be calculated - * @param attributesWithEventWatermark attributes of the other side which has a watermark column - * @param joinCondition join condition - * @param eventWatermark watermark defined on the input event data - * @return state value watermark in milliseconds, is possible. - */ - def getStateValueWatermark( - attributesToFindStateWatermarkFor: AttributeSet, - attributesWithEventWatermark: AttributeSet, - joinCondition: Option[Expression], - eventWatermark: Option[Long]): Option[Long] = { - - // If condition or event time watermark is not provided, then cannot calculate state watermark - if (joinCondition.isEmpty || eventWatermark.isEmpty) return None - - // If there is not watermark attribute, then cannot define state watermark - if (!attributesWithEventWatermark.exists(_.metadata.contains(delayKey))) return None - - def getStateWatermarkSafely(l: Expression, r: Expression): Option[Long] = { - try { - getStateWatermarkFromLessThenPredicate( - l, r, attributesToFindStateWatermarkFor, attributesWithEventWatermark, eventWatermark) - } catch { - case NonFatal(e) => - logWarning(s"Error trying to extract state constraint from condition $joinCondition", e) - None - } - } - - val allStateWatermarks = splitConjunctivePredicates(joinCondition.get).flatMap { predicate => - - // The generated the state watermark cleanup expression is inclusive of the state watermark. - // If state watermark is W, all state where timestamp <= W will be cleaned up. - // Now when the canonicalized join condition solves to leftTime >= W, we dont want to clean - // up leftTime <= W. Rather we should clean up leftTime <= W - 1. Hence the -1 below. - val stateWatermark = predicate match { - case LessThan(l, r) => getStateWatermarkSafely(l, r) - case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_ - 1) - case GreaterThan(l, r) => getStateWatermarkSafely(r, l) - case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r, l).map(_ - 1) - case _ => None - } - if (stateWatermark.nonEmpty) { - logInfo(s"Condition $joinCondition generated watermark constraint = ${stateWatermark.get}") - } - stateWatermark - } - allStateWatermarks.reduceOption((x, y) => Math.min(x, y)) - } - - /** - * Extract the state value watermark (milliseconds) from the condition - * `LessThan(leftExpr, rightExpr)` where . For example: if we want to find the constraint for - * leftTime using the watermark on the rightTime. Example: - * - * Input: rightTime-with-watermark + c1 < leftTime + c2 - * Canonical form: rightTime-with-watermark + c1 + (-c2) + (-leftTime) < 0 - * Solving for rightTime: rightTime-with-watermark + c1 + (-c2) < leftTime - * With watermark value: watermark-value + c1 + (-c2) < leftTime - */ - private def getStateWatermarkFromLessThenPredicate( - leftExpr: Expression, - rightExpr: Expression, - attributesToFindStateWatermarkFor: AttributeSet, - attributesWithEventWatermark: AttributeSet, - eventWatermark: Option[Long]): Option[Long] = { - - val attributesInCondition = AttributeSet( - leftExpr.collect { case a: AttributeReference => a } ++ - rightExpr.collect { case a: AttributeReference => a } - ) - if (attributesInCondition.filter { attributesToFindStateWatermarkFor.contains(_) }.size > 1 || - attributesInCondition.filter { attributesWithEventWatermark.contains(_) }.size > 1) { - // If more than attributes present in condition from one side, then it cannot be solved - return None - } - - def containsAttributeToFindStateConstraintFor(e: Expression): Boolean = { - e.collectLeaves().collectFirst { - case a @ AttributeReference(_, TimestampType, _, _) - if attributesToFindStateWatermarkFor.contains(a) => a - }.nonEmpty - } - - // Canonicalization step 1: convert to (rightTime-with-watermark + c1) - (leftTime + c2) < 0 - val allOnLeftExpr = Subtract(leftExpr, rightExpr) - logDebug(s"All on Left:\n${allOnLeftExpr.treeString(true)}\n${allOnLeftExpr.asCode}") - - // Canonicalization step 2: extract commutative terms - // rightTime-with-watermark, c1, -leftTime, -c2 - val terms = ExpressionSet(collectTerms(allOnLeftExpr)) - logDebug("Terms extracted from join condition:\n\t" + terms.mkString("\n\t")) - - - - // Find the term that has leftTime (i.e. the one present in attributesToFindConstraintFor - val constraintTerms = terms.filter(containsAttributeToFindStateConstraintFor) - - // Verify there is only one correct constraint term and of the correct type - if (constraintTerms.size > 1) { - logWarning("Failed to extract state constraint terms: multiple time terms in condition\n\t" + - terms.mkString("\n\t")) - return None - } - if (constraintTerms.isEmpty) { - logDebug("Failed to extract state constraint terms: no time terms in condition\n\t" + - terms.mkString("\n\t")) - return None - } - val constraintTerm = constraintTerms.head - if (constraintTerm.collectFirst { case u: UnaryMinus => u }.isEmpty) { - // Incorrect condition. We want the constraint term in canonical form to be `-leftTime` - // so that resolve for it as `-leftTime + watermark + c < 0` ==> `watermark + c < leftTime`. - // Now, if the original conditions is `rightTime-with-watermark > leftTime` and watermark - // condition is `rightTime-with-watermark > watermarkValue`, then no constraint about - // `leftTime` can be inferred. In this case, after canonicalization and collection of terms, - // the constraintTerm would be `leftTime` and not `-leftTime`. Hence, we return None. - return None - } - - // Replace watermark attribute with watermark value, and generate the resolved expression - // from the other terms. That is, - // rightTime-with-watermark, c1, -c2 => watermark, c1, -c2 => watermark + c1 + (-c2) - logDebug(s"Constraint term from join condition:\t$constraintTerm") - val exprWithWatermarkSubstituted = (terms - constraintTerm).map { term => - term.transform { - case a @ AttributeReference(_, TimestampType, _, metadata) - if attributesWithEventWatermark.contains(a) && metadata.contains(delayKey) => - Multiply(Literal(eventWatermark.get.toDouble), Literal(1000.0)) - } - }.reduceLeft(Add) - - // Calculate the constraint value - logInfo(s"Final expression to evaluate constraint:\t$exprWithWatermarkSubstituted") - val constraintValue = exprWithWatermarkSubstituted.eval().asInstanceOf[java.lang.Double] - Some((Double2double(constraintValue) / 1000.0).toLong) - } - - /** - * Collect all the terms present in an expression after converting it into the form - * a + b + c + d where each term be either an attribute or a literal casted to long, - * optionally wrapped in a unary minus. - */ - private def collectTerms(exprToCollectFrom: Expression): Seq[Expression] = { - var invalid = false - - /** Wrap a term with UnaryMinus if its needs to be negated. */ - def negateIfNeeded(expr: Expression, minus: Boolean): Expression = { - if (minus) UnaryMinus(expr) else expr - } - - /** - * Recursively split the expression into its leaf terms contains attributes or literals. - * Returns terms only of the forms: - * Cast(AttributeReference), UnaryMinus(Cast(AttributeReference)), - * Cast(AttributeReference, Double), UnaryMinus(Cast(AttributeReference, Double)) - * Multiply(Literal), UnaryMinus(Multiply(Literal)) - * Multiply(Cast(Literal)), UnaryMinus(Multiple(Cast(Literal))) - * - * Note: - * - If term needs to be negated for making it a commutative term, - * then it will be wrapped in UnaryMinus(...) - * - Each terms will be representing timestamp value or time interval in microseconds, - * typed as doubles. - */ - def collect(expr: Expression, negate: Boolean): Seq[Expression] = { - expr match { - case Add(left, right) => - collect(left, negate) ++ collect(right, negate) - case Subtract(left, right) => - collect(left, negate) ++ collect(right, !negate) - case TimeAdd(left, right, _) => - collect(left, negate) ++ collect(right, negate) - case TimeSub(left, right, _) => - collect(left, negate) ++ collect(right, !negate) - case UnaryMinus(child) => - collect(child, !negate) - case CheckOverflow(child, _) => - collect(child, negate) - case Cast(child, dataType, _) => - dataType match { - case _: NumericType | _: TimestampType => collect(child, negate) - case _ => - invalid = true - Seq.empty - } - case a: AttributeReference => - val castedRef = if (a.dataType != DoubleType) Cast(a, DoubleType) else a - Seq(negateIfNeeded(castedRef, negate)) - case lit: Literal => - // If literal of type calendar interval, then explicitly convert to millis - // Convert other number like literal to doubles representing millis (by x1000) - val castedLit = lit.dataType match { - case CalendarIntervalType => - val calendarInterval = lit.value.asInstanceOf[CalendarInterval] - if (calendarInterval.months > 0) { - invalid = true - logWarning( - s"Failed to extract state value watermark from condition $exprToCollectFrom " + - s"as imprecise intervals like months and years cannot be used for" + - s"watermark calculation. Use interval in terms of day instead.") - Literal(0.0) - } else { - Literal(calendarInterval.microseconds.toDouble) - } - case DoubleType => - Multiply(lit, Literal(1000000.0)) - case _: NumericType => - Multiply(Cast(lit, DoubleType), Literal(1000000.0)) - case _: TimestampType => - Multiply(PreciseTimestampConversion(lit, TimestampType, LongType), Literal(1000000.0)) - } - Seq(negateIfNeeded(castedLit, negate)) - case a @ _ => - logWarning( - s"Failed to extract state value watermark from condition $exprToCollectFrom due to $a") - invalid = true - Seq.empty - } - } - - val terms = collect(exprToCollectFrom, negate = false) - if (!invalid) terms else Seq.empty - } - /** * A custom RDD that allows partitions to be "zipped" together, while ensuring the tasks' * preferred location is based on which executors have the required join state stores already diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 37648710dfc2a..d256fb578d921 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -76,7 +76,7 @@ class SymmetricHashJoinStateManager( /** Get all the values of a key */ def get(key: UnsafeRow): Iterator[UnsafeRow] = { val numValues = keyToNumValues.get(key) - keyWithIndexToValue.getAll(key, numValues) + keyWithIndexToValue.getAll(key, numValues).map(_.value) } /** Append a new value to the key */ @@ -87,70 +87,163 @@ class SymmetricHashJoinStateManager( } /** - * Remove using a predicate on keys. See class docs for more context and implement details. + * Remove using a predicate on keys. + * + * This produces an iterator over the (key, value) pairs satisfying condition(key), where the + * underlying store is updated as a side-effect of producing next. + * + * This implies the iterator must be consumed fully without any other operations on this manager + * or the underlying store being interleaved. */ - def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = { - val allKeyToNumValues = keyToNumValues.iterator - - while (allKeyToNumValues.hasNext) { - val keyToNumValue = allKeyToNumValues.next - if (condition(keyToNumValue.key)) { - keyToNumValues.remove(keyToNumValue.key) - keyWithIndexToValue.removeAllValues(keyToNumValue.key, keyToNumValue.numValue) + def removeByKeyCondition(removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRowPair] = { + new NextIterator[UnsafeRowPair] { + + private val allKeyToNumValues = keyToNumValues.iterator + + private var currentKeyToNumValue: KeyAndNumValues = null + private var currentValues: Iterator[KeyWithIndexAndValue] = null + + private def currentKey = currentKeyToNumValue.key + + private val reusedPair = new UnsafeRowPair() + + private def getAndRemoveValue() = { + val keyWithIndexAndValue = currentValues.next() + keyWithIndexToValue.remove(currentKey, keyWithIndexAndValue.valueIndex) + reusedPair.withRows(currentKey, keyWithIndexAndValue.value) + } + + override def getNext(): UnsafeRowPair = { + // If there are more values for the current key, remove and return the next one. + if (currentValues != null && currentValues.hasNext) { + return getAndRemoveValue() + } + + // If there weren't any values left, try and find the next key that satisfies the removal + // condition and has values. + while (allKeyToNumValues.hasNext) { + currentKeyToNumValue = allKeyToNumValues.next() + if (removalCondition(currentKey)) { + currentValues = keyWithIndexToValue.getAll( + currentKey, currentKeyToNumValue.numValue) + keyToNumValues.remove(currentKey) + + if (currentValues.hasNext) { + return getAndRemoveValue() + } + } + } + + // We only reach here if there were no satisfying keys left, which means we're done. + finished = true + return null } + + override def close: Unit = {} } } /** - * Remove using a predicate on values. See class docs for more context and implementation details. + * Remove using a predicate on values. + * + * At a high level, this produces an iterator over the (key, value) pairs such that value + * satisfies the predicate, where producing an element removes the value from the state store + * and producing all elements with a given key updates it accordingly. + * + * This implies the iterator must be consumed fully without any other operations on this manager + * or the underlying store being interleaved. */ - def removeByValueCondition(condition: UnsafeRow => Boolean): Unit = { - val allKeyToNumValues = keyToNumValues.iterator + def removeByValueCondition(removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRowPair] = { + new NextIterator[UnsafeRowPair] { - while (allKeyToNumValues.hasNext) { - val keyToNumValue = allKeyToNumValues.next - val key = keyToNumValue.key + // Reuse this object to avoid creation+GC overhead. + private val reusedPair = new UnsafeRowPair() - var numValues: Long = keyToNumValue.numValue - var index: Long = 0L - var valueRemoved: Boolean = false - var valueForIndex: UnsafeRow = null + private val allKeyToNumValues = keyToNumValues.iterator - while (index < numValues) { - if (valueForIndex == null) { - valueForIndex = keyWithIndexToValue.get(key, index) + private var currentKey: UnsafeRow = null + private var numValues: Long = 0L + private var index: Long = 0L + private var valueRemoved: Boolean = false + + // Push the data for the current key to the numValues store, and reset the tracking variables + // to their empty state. + private def updateNumValueForCurrentKey(): Unit = { + if (valueRemoved) { + if (numValues >= 1) { + keyToNumValues.put(currentKey, numValues) + } else { + keyToNumValues.remove(currentKey) + } } - if (condition(valueForIndex)) { - if (numValues > 1) { - val valueAtMaxIndex = keyWithIndexToValue.get(key, numValues - 1) - keyWithIndexToValue.put(key, index, valueAtMaxIndex) - keyWithIndexToValue.remove(key, numValues - 1) - valueForIndex = valueAtMaxIndex + + currentKey = null + numValues = 0 + index = 0 + valueRemoved = false + } + + // Find the next value satisfying the condition, updating `currentKey` and `numValues` if + // needed. Returns null when no value can be found. + private def findNextValueForIndex(): UnsafeRow = { + // Loop across all values for the current key, and then all other keys, until we find a + // value satisfying the removal condition. + def hasMoreValuesForCurrentKey = currentKey != null && index < numValues + def hasMoreKeys = allKeyToNumValues.hasNext + while (hasMoreValuesForCurrentKey || hasMoreKeys) { + if (hasMoreValuesForCurrentKey) { + // First search the values for the current key. + val currentValue = keyWithIndexToValue.get(currentKey, index) + if (removalCondition(currentValue)) { + return currentValue + } else { + index += 1 + } + } else if (hasMoreKeys) { + // If we can't find a value for the current key, cleanup and start looking at the next. + // This will also happen the first time the iterator is called. + updateNumValueForCurrentKey() + + val currentKeyToNumValue = allKeyToNumValues.next() + currentKey = currentKeyToNumValue.key + numValues = currentKeyToNumValue.numValue } else { - keyWithIndexToValue.remove(key, 0) - valueForIndex = null + // Should be unreachable, but in any case means a value couldn't be found. + return null } - numValues -= 1 - valueRemoved = true - } else { - valueForIndex = null - index += 1 } + + // We tried and failed to find the next value. + return null } - if (valueRemoved) { - if (numValues >= 1) { - keyToNumValues.put(key, numValues) + + override def getNext(): UnsafeRowPair = { + val currentValue = findNextValueForIndex() + + // If there's no value, clean up and finish. There aren't any more available. + if (currentValue == null) { + updateNumValueForCurrentKey() + finished = true + return null + } + + // The backing store is arraylike - we as the caller are responsible for filling back in + // any hole. So we swap the last element into the hole and decrement numValues to shorten. + // clean + if (numValues > 1) { + val valueAtMaxIndex = keyWithIndexToValue.get(currentKey, numValues - 1) + keyWithIndexToValue.put(currentKey, index, valueAtMaxIndex) + keyWithIndexToValue.remove(currentKey, numValues - 1) } else { - keyToNumValues.remove(key) + keyWithIndexToValue.remove(currentKey, 0) } + numValues -= 1 + valueRemoved = true + + return reusedPair.withRows(currentKey, currentValue) } - } - } - def iterator(): Iterator[UnsafeRowPair] = { - val pair = new UnsafeRowPair() - keyWithIndexToValue.iterator.map { x => - pair.withRows(x.key, x.value) + override def close: Unit = {} } } @@ -309,19 +402,24 @@ class SymmetricHashJoinStateManager( stateStore.get(keyWithIndexRow(key, valueIndex)) } - /** Get all the values for key and all indices. */ - def getAll(key: UnsafeRow, numValues: Long): Iterator[UnsafeRow] = { + /** + * Get all values and indices for the provided key. + * Should not return null. + */ + def getAll(key: UnsafeRow, numValues: Long): Iterator[KeyWithIndexAndValue] = { + val keyWithIndexAndValue = new KeyWithIndexAndValue() var index = 0 - new NextIterator[UnsafeRow] { - override protected def getNext(): UnsafeRow = { + new NextIterator[KeyWithIndexAndValue] { + override protected def getNext(): KeyWithIndexAndValue = { if (index >= numValues) { finished = true null } else { val keyWithIndex = keyWithIndexRow(key, index) val value = stateStore.get(keyWithIndex) + keyWithIndexAndValue.withNew(key, index, value) index += 1 - value + keyWithIndexAndValue } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala index ffa4c3c22a194..d44af1d14c27a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala @@ -137,14 +137,16 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter BoundReference( 1, inputValueAttribWithWatermark.dataType, inputValueAttribWithWatermark.nullable), Literal(threshold)) - manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval _) + val iter = manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval _) + while (iter.hasNext) iter.next() } /** Remove values where `time <= threshold` */ def removeByValue(watermark: Long)(implicit manager: SymmetricHashJoinStateManager): Unit = { val expr = LessThanOrEqual(inputValueAttribWithWatermark, Literal(watermark)) - manager.removeByValueCondition( + val iter = manager.removeByValueCondition( GeneratePredicate.generate(expr, inputValueAttribs).eval _) + while (iter.hasNext) iter.next() } def numRows(implicit manager: SymmetricHashJoinStateManager): Long = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 533e1165fd59c..a6593b71e51de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -24,8 +24,9 @@ import scala.util.Random import org.scalatest.BeforeAndAfter import org.apache.spark.scheduler.ExecutorCacheTaskLocation -import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession} +import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Literal} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter} import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper} @@ -35,7 +36,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { +class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { before { SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec' @@ -322,111 +323,6 @@ class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with Befo assert(e.toString.contains("Stream stream joins without equality predicate is not supported")) } - testQuietly("extract watermark from time condition") { - val attributesToFindConstraintFor = Seq( - AttributeReference("leftTime", TimestampType)(), - AttributeReference("leftOther", IntegerType)()) - val metadataWithWatermark = new MetadataBuilder() - .putLong(EventTimeWatermark.delayKey, 1000) - .build() - val attributesWithWatermark = Seq( - AttributeReference("rightTime", TimestampType, metadata = metadataWithWatermark)(), - AttributeReference("rightOther", IntegerType)()) - - def watermarkFrom( - conditionStr: String, - rightWatermark: Option[Long] = Some(10000)): Option[Long] = { - val conditionExpr = Some(conditionStr).map { str => - val plan = - Filter( - spark.sessionState.sqlParser.parseExpression(str), - LogicalRDD( - attributesToFindConstraintFor ++ attributesWithWatermark, - spark.sparkContext.emptyRDD)(spark)) - plan.queryExecution.optimizedPlan.asInstanceOf[Filter].condition - } - StreamingSymmetricHashJoinHelper.getStateValueWatermark( - AttributeSet(attributesToFindConstraintFor), AttributeSet(attributesWithWatermark), - conditionExpr, rightWatermark) - } - - // Test comparison directionality. E.g. if leftTime < rightTime and rightTime > watermark, - // then cannot define constraint on leftTime. - assert(watermarkFrom("leftTime > rightTime") === Some(10000)) - assert(watermarkFrom("leftTime >= rightTime") === Some(9999)) - assert(watermarkFrom("leftTime < rightTime") === None) - assert(watermarkFrom("leftTime <= rightTime") === None) - assert(watermarkFrom("rightTime > leftTime") === None) - assert(watermarkFrom("rightTime >= leftTime") === None) - assert(watermarkFrom("rightTime < leftTime") === Some(10000)) - assert(watermarkFrom("rightTime <= leftTime") === Some(9999)) - - // Test type conversions - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS LONG) < CAST(rightTime AS LONG)") === None) - assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS DOUBLE)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS DOUBLE)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS FLOAT)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS FLOAT)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS STRING) > CAST(rightTime AS STRING)") === None) - - // Test with timestamp type + calendar interval on either side of equation - // Note: timestamptype and calendar interval don't commute, so less valid combinations to test. - assert(watermarkFrom("leftTime > rightTime + interval 1 second") === Some(11000)) - assert(watermarkFrom("leftTime + interval 2 seconds > rightTime ") === Some(8000)) - assert(watermarkFrom("leftTime > rightTime - interval 3 second") === Some(7000)) - assert(watermarkFrom("rightTime < leftTime - interval 3 second") === Some(13000)) - assert(watermarkFrom("rightTime - interval 1 second < leftTime - interval 3 second") - === Some(12000)) - - // Test with casted long type + constants on either side of equation - // Note: long type and constants commute, so more combinations to test. - // -- Constants on the right - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) + 1") === Some(11000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 1") === Some(9000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST((rightTime + interval 1 second) AS LONG)") - === Some(11000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > 2 + CAST(rightTime AS LONG)") === Some(12000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > -0.5 + CAST(rightTime AS LONG)") === Some(9500)) - assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) > 2") === Some(12000)) - assert(watermarkFrom("-CAST(rightTime AS DOUBLE) + CAST(leftTime AS LONG) > 0.1") - === Some(10100)) - assert(watermarkFrom("0 > CAST(rightTime AS LONG) - CAST(leftTime AS LONG) + 0.2") - === Some(10200)) - // -- Constants on the left - assert(watermarkFrom("CAST(leftTime AS LONG) + 2 > CAST(rightTime AS LONG)") === Some(8000)) - assert(watermarkFrom("1 + CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(9000)) - assert(watermarkFrom("CAST((leftTime + interval 3 second) AS LONG) > CAST(rightTime AS LONG)") - === Some(7000)) - assert(watermarkFrom("CAST(leftTime AS LONG) - 2 > CAST(rightTime AS LONG)") === Some(12000)) - assert(watermarkFrom("CAST(leftTime AS LONG) + 0.5 > CAST(rightTime AS LONG)") === Some(9500)) - assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) - 2 > 0") - === Some(12000)) - assert(watermarkFrom("-CAST(rightTime AS LONG) + CAST(leftTime AS LONG) - 0.1 > 0") - === Some(10100)) - // -- Constants on both sides, mixed types - assert(watermarkFrom("CAST(leftTime AS LONG) - 2.0 > CAST(rightTime AS LONG) + 1") - === Some(13000)) - - // Test multiple conditions, should return minimum watermark - assert(watermarkFrom( - "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 2 seconds") === - Some(7000)) // first condition wins - assert(watermarkFrom( - "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 4 seconds") === - Some(6000)) // second condition wins - - // Test invalid comparisons - assert(watermarkFrom("cast(leftTime AS LONG) > leftOther") === None) // non-time attributes - assert(watermarkFrom("leftOther > rightOther") === None) // non-time attributes - assert(watermarkFrom("leftOther > rightOther AND leftTime > rightTime") === Some(10000)) - assert(watermarkFrom("cast(rightTime AS DOUBLE) < rightOther") === None) // non-time attributes - assert(watermarkFrom("leftTime > rightTime + interval 1 month") === None) // month not allowed - - // Test static comparisons - assert(watermarkFrom("cast(leftTime AS LONG) > 10") === Some(10000)) - } - test("locality preferences of StateStoreAwareZippedRDD") { import StreamingSymmetricHashJoinHelper._ @@ -470,3 +366,189 @@ class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with Befo } } } + +class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { + + import testImplicits._ + import org.apache.spark.sql.functions._ + + before { + SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec' + spark.streams.stateStoreCoordinator // initialize the lazy coordinator + } + + after { + StateStore.stop() + } + + private def setupStream(prefix: String, multiplier: Int): (MemoryStream[Int], DataFrame) = { + val input = MemoryStream[Int] + val df = input.toDF + .select( + 'value as "key", + 'value.cast("timestamp") as s"${prefix}Time", + ('value * multiplier) as s"${prefix}Value") + .withWatermark(s"${prefix}Time", "10 seconds") + + return (input, df) + } + + private def setupWindowedJoin(joinType: String): + (MemoryStream[Int], MemoryStream[Int], DataFrame) = { + val (input1, df1) = setupStream("left", 2) + val (input2, df2) = setupStream("right", 3) + val windowed1 = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val windowed2 = df2.select('key, window('rightTime, "10 second"), 'rightValue) + val joined = windowed1.join(windowed2, Seq("key", "window"), joinType) + .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue) + + (input1, input2, joined) + } + + test("windowed left outer join") { + val (leftInput, rightInput, joined) = setupWindowedJoin("left_outer") + + testStream(joined)( + // Test inner part of the join. + AddData(leftInput, 1, 2, 3, 4, 5), + AddData(rightInput, 3, 4, 5, 6, 7), + CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + // Old state doesn't get dropped until the batch *after* it gets introduced, so the + // nulls won't show up until the next batch after the watermark advances. + AddData(leftInput, 21), + AddData(rightInput, 22), + CheckLastBatch(), + assertNumStateRows(total = 12, updated = 2), + AddData(leftInput, 22), + CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null)), + assertNumStateRows(total = 3, updated = 1) + ) + } + + test("windowed right outer join") { + val (leftInput, rightInput, joined) = setupWindowedJoin("right_outer") + + testStream(joined)( + // Test inner part of the join. + AddData(leftInput, 1, 2, 3, 4, 5), + AddData(rightInput, 3, 4, 5, 6, 7), + CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + // Old state doesn't get dropped until the batch *after* it gets introduced, so the + // nulls won't show up until the next batch after the watermark advances. + AddData(leftInput, 21), + AddData(rightInput, 22), + CheckLastBatch(), + assertNumStateRows(total = 12, updated = 2), + AddData(leftInput, 22), + CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21)), + assertNumStateRows(total = 3, updated = 1) + ) + } + + Seq( + ("left_outer", Row(3, null, 5, null)), + ("right_outer", Row(null, 2, null, 5)) + ).foreach { case (joinType: String, outerResult) => + test(s"${joinType.replaceAllLiterally("_", " ")} with watermark range condition") { + import org.apache.spark.sql.functions._ + + val leftInput = MemoryStream[(Int, Int)] + val rightInput = MemoryStream[(Int, Int)] + + val df1 = leftInput.toDF.toDF("leftKey", "time") + .select('leftKey, 'time.cast("timestamp") as "leftTime", ('leftKey * 2) as "leftValue") + .withWatermark("leftTime", "10 seconds") + + val df2 = rightInput.toDF.toDF("rightKey", "time") + .select('rightKey, 'time.cast("timestamp") as "rightTime", ('rightKey * 3) as "rightValue") + .withWatermark("rightTime", "10 seconds") + + val joined = + df1.join( + df2, + expr("leftKey = rightKey AND " + + "leftTime BETWEEN rightTime - interval 5 seconds AND rightTime + interval 5 seconds"), + joinType) + .select('leftKey, 'rightKey, 'leftTime.cast("int"), 'rightTime.cast("int")) + testStream(joined)( + AddData(leftInput, (1, 5), (3, 5)), + CheckAnswer(), + AddData(rightInput, (1, 10), (2, 5)), + CheckLastBatch((1, 1, 5, 10)), + AddData(rightInput, (1, 11)), + CheckLastBatch(), // no match as left time is too low + assertNumStateRows(total = 5, updated = 1), + + // Increase event time watermark to 20s by adding data with time = 30s on both inputs + AddData(leftInput, (1, 7), (1, 30)), + CheckLastBatch((1, 1, 7, 10), (1, 1, 7, 11)), + assertNumStateRows(total = 7, updated = 2), + AddData(rightInput, (0, 30)), + CheckLastBatch(), + assertNumStateRows(total = 8, updated = 1), + AddData(rightInput, (0, 30)), + CheckLastBatch(outerResult), + assertNumStateRows(total = 3, updated = 1) + ) + } + } + + // When the join condition isn't true, the outer null rows must be generated, even if the join + // keys themselves have a match. + test("left outer join with non-key condition violated on left") { + val (leftInput, simpleLeftDf) = setupStream("left", 2) + val (rightInput, simpleRightDf) = setupStream("right", 3) + + val left = simpleLeftDf.select('key, window('leftTime, "10 second"), 'leftValue) + val right = simpleRightDf.select('key, window('rightTime, "10 second"), 'rightValue) + + val joined = left.join( + right, + left("key") === right("key") && left("window") === right("window") && + 'leftValue > 10 && ('rightValue < 300 || 'rightValue > 1000), + "left_outer") + .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + // leftValue <= 10 should generate outer join rows even though it matches right keys + AddData(leftInput, 1, 2, 3), + AddData(rightInput, 1, 2, 3), + CheckLastBatch(), + AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + assertNumStateRows(total = 8, updated = 2), + AddData(rightInput, 20), + CheckLastBatch( + Row(20, 30, 40, 60), Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), + assertNumStateRows(total = 3, updated = 1), + // leftValue and rightValue both satisfying condition should not generate outer join rows + AddData(leftInput, 40, 41), + AddData(rightInput, 40, 41), + CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), + AddData(leftInput, 70), + AddData(rightInput, 71), + CheckLastBatch(), + assertNumStateRows(total = 6, updated = 2), + AddData(rightInput, 70), + CheckLastBatch((70, 80, 140, 210)), + assertNumStateRows(total = 3, updated = 1), + // rightValue between 300 and 1000 should generate outer join rows even though it matches left + AddData(leftInput, 101, 102, 103), + AddData(rightInput, 101, 102, 103), + CheckLastBatch(), + AddData(leftInput, 1000), + AddData(rightInput, 1001), + CheckLastBatch(), + assertNumStateRows(total = 8, updated = 2), + AddData(rightInput, 1000), + CheckLastBatch( + Row(1000, 1010, 2000, 3000), + Row(101, 110, 202, null), + Row(102, 110, 204, null), + Row(103, 110, 206, null)), + assertNumStateRows(total = 3, updated = 1) + ) + } +} + From d54670192a6acd892d13b511dfb62390be6ad39c Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Wed, 4 Oct 2017 07:11:00 +0100 Subject: [PATCH 662/779] [SPARK-22193][SQL] Minor typo fix ## What changes were proposed in this pull request? [SPARK-22193][SQL] Minor typo fix in SortMergeJoinExec. Nothing major, but it bothered me going into.Hence fixing ## How was this patch tested? existing tests Author: Rekha Joshi Author: rjoshi2 Closes #19422 from rekhajoshm/SPARK-22193. --- .../spark/sql/execution/joins/SortMergeJoinExec.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 14de2dc23e3c0..4e02803552e82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -402,7 +402,7 @@ case class SortMergeJoinExec( } } - private def genComparision(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = { + private def genComparison(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = { val comparisons = a.zip(b).zipWithIndex.map { case ((l, r), i) => s""" |if (comp == 0) { @@ -463,7 +463,7 @@ case class SortMergeJoinExec( | continue; | } | if (!$matches.isEmpty()) { - | ${genComparision(ctx, leftKeyVars, matchedKeyVars)} + | ${genComparison(ctx, leftKeyVars, matchedKeyVars)} | if (comp == 0) { | return true; | } @@ -484,7 +484,7 @@ case class SortMergeJoinExec( | } | ${rightKeyVars.map(_.code).mkString("\n")} | } - | ${genComparision(ctx, leftKeyVars, rightKeyVars)} + | ${genComparison(ctx, leftKeyVars, rightKeyVars)} | if (comp > 0) { | $rightRow = null; | } else if (comp < 0) { From 64df08b64779bab629a8a90a3797d8bd70f61703 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 4 Oct 2017 15:06:44 +0800 Subject: [PATCH 663/779] [SPARK-20783][SQL] Create ColumnVector to abstract existing compressed column (batch method) ## What changes were proposed in this pull request? This PR abstracts data compressed by `CompressibleColumnAccessor` using `ColumnVector` in batch method. When `ColumnAccessor.decompress` is called, `ColumnVector` will have uncompressed data. This batch decompress does not use `InternalRow` to reduce the number of memory accesses. As first step of this implementation, this JIRA supports primitive data types. Another PR will support array and other data types. This implementation decompress data in batch into uncompressed column batch, as rxin suggested at [here](https://github.com/apache/spark/pull/18468#issuecomment-316914076). Another implementation uses adapter approach [as cloud-fan suggested](https://github.com/apache/spark/pull/18468). ## How was this patch tested? Added test suites Author: Kazuaki Ishizaki Closes #18704 from kiszk/SPARK-20783a. --- .../execution/columnar/ColumnDictionary.java | 58 +++ .../vectorized/OffHeapColumnVector.java | 18 + .../vectorized/OnHeapColumnVector.java | 18 + .../vectorized/WritableColumnVector.java | 76 ++-- .../execution/columnar/ColumnAccessor.scala | 16 +- .../sql/execution/columnar/ColumnType.scala | 33 ++ .../CompressibleColumnAccessor.scala | 4 + .../compression/CompressionScheme.scala | 3 + .../compression/compressionSchemes.scala | 340 +++++++++++++++++- .../compression/BooleanBitSetSuite.scala | 52 +++ .../compression/DictionaryEncodingSuite.scala | 72 +++- .../compression/IntegralDeltaSuite.scala | 72 ++++ .../PassThroughEncodingSuite.scala | 189 ++++++++++ .../compression/RunLengthEncodingSuite.scala | 89 ++++- .../TestCompressibleColumnBuilder.scala | 9 +- .../vectorized/ColumnVectorSuite.scala | 183 +++++++++- .../vectorized/ColumnarBatchSuite.scala | 4 +- 17 files changed, 1192 insertions(+), 44 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/columnar/ColumnDictionary.java create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/columnar/ColumnDictionary.java b/sql/core/src/main/java/org/apache/spark/sql/execution/columnar/ColumnDictionary.java new file mode 100644 index 0000000000000..f1785853a94ae --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/columnar/ColumnDictionary.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar; + +import org.apache.spark.sql.execution.vectorized.Dictionary; + +public final class ColumnDictionary implements Dictionary { + private int[] intDictionary; + private long[] longDictionary; + + public ColumnDictionary(int[] dictionary) { + this.intDictionary = dictionary; + } + + public ColumnDictionary(long[] dictionary) { + this.longDictionary = dictionary; + } + + @Override + public int decodeToInt(int id) { + return intDictionary[id]; + } + + @Override + public long decodeToLong(int id) { + return longDictionary[id]; + } + + @Override + public float decodeToFloat(int id) { + throw new UnsupportedOperationException("Dictionary encoding does not support float"); + } + + @Override + public double decodeToDouble(int id) { + throw new UnsupportedOperationException("Dictionary encoding does not support double"); + } + + @Override + public byte[] decodeToBinary(int id) { + throw new UnsupportedOperationException("Dictionary encoding does not support String"); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 8cbc895506d91..a7522ebf5821a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -228,6 +228,12 @@ public void putShorts(int rowId, int count, short[] src, int srcIndex) { null, data + 2 * rowId, count * 2); } + @Override + public void putShorts(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 2, count * 2); + } + @Override public short getShort(int rowId) { if (dictionary == null) { @@ -268,6 +274,12 @@ public void putInts(int rowId, int count, int[] src, int srcIndex) { null, data + 4 * rowId, count * 4); } + @Override + public void putInts(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 4, count * 4); + } + @Override public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { @@ -334,6 +346,12 @@ public void putLongs(int rowId, int count, long[] src, int srcIndex) { null, data + 8 * rowId, count * 8); } + @Override + public void putLongs(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 8, count * 8); + } + @Override public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 2725a29eeabe8..166a39e0fabd9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -233,6 +233,12 @@ public void putShorts(int rowId, int count, short[] src, int srcIndex) { System.arraycopy(src, srcIndex, shortData, rowId, count); } + @Override + public void putShorts(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, shortData, + Platform.SHORT_ARRAY_OFFSET + rowId * 2, count * 2); + } + @Override public short getShort(int rowId) { if (dictionary == null) { @@ -272,6 +278,12 @@ public void putInts(int rowId, int count, int[] src, int srcIndex) { System.arraycopy(src, srcIndex, intData, rowId, count); } + @Override + public void putInts(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, intData, + Platform.INT_ARRAY_OFFSET + rowId * 4, count * 4); + } + @Override public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; @@ -332,6 +344,12 @@ public void putLongs(int rowId, int count, long[] src, int srcIndex) { System.arraycopy(src, srcIndex, longData, rowId, count); } + @Override + public void putLongs(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, longData, + Platform.LONG_ARRAY_OFFSET + rowId * 8, count * 8); + } + @Override public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 163f2511e5f73..da72954ddc448 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -113,138 +113,156 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { protected abstract void reserveInternal(int capacity); /** - * Sets the value at rowId to null/not null. + * Sets null/not null to the value at rowId. */ public abstract void putNotNull(int rowId); public abstract void putNull(int rowId); /** - * Sets the values from [rowId, rowId + count) to null/not null. + * Sets null/not null to the values at [rowId, rowId + count). */ public abstract void putNulls(int rowId, int count); public abstract void putNotNulls(int rowId, int count); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putBoolean(int rowId, boolean value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putBooleans(int rowId, int count, boolean value); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putByte(int rowId, byte value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putBytes(int rowId, int count, byte value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putBytes(int rowId, int count, byte[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putShort(int rowId, short value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putShorts(int rowId, int count, short value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putShorts(int rowId, int count, short[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets values from [src[srcIndex], src[srcIndex + count * 2]) to [rowId, rowId + count) + * The data in src must be 2-byte platform native endian shorts. + */ + public abstract void putShorts(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets `value` to the value at rowId. */ public abstract void putInt(int rowId, int value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putInts(int rowId, int count, int value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putInts(int rowId, int count, int[] src, int srcIndex); /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * Sets values from [src[srcIndex], src[srcIndex + count * 4]) to [rowId, rowId + count) + * The data in src must be 4-byte platform native endian ints. + */ + public abstract void putInts(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets values from [src[srcIndex], src[srcIndex + count * 4]) to [rowId, rowId + count) * The data in src must be 4-byte little endian ints. */ public abstract void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putLong(int rowId, long value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putLongs(int rowId, int count, long value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putLongs(int rowId, int count, long[] src, int srcIndex); /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * Sets values from [src[srcIndex], src[srcIndex + count * 8]) to [rowId, rowId + count) + * The data in src must be 8-byte platform native endian longs. + */ + public abstract void putLongs(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets values from [src + srcIndex, src + srcIndex + count * 8) to [rowId, rowId + count) * The data in src must be 8-byte little endian longs. */ public abstract void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putFloat(int rowId, float value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putFloats(int rowId, int count, float value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putFloats(int rowId, int count, float[] src, int srcIndex); /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be ieee formatted floats. + * Sets values from [src[srcIndex], src[srcIndex + count * 4]) to [rowId, rowId + count) + * The data in src must be ieee formatted floats in platform native endian. */ public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putDouble(int rowId, double value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putDoubles(int rowId, int count, double value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putDoubles(int rowId, int count, double[] src, int srcIndex); /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be ieee formatted doubles. + * Sets values from [src[srcIndex], src[srcIndex + count * 8]) to [rowId, rowId + count) + * The data in src must be ieee formatted doubles in platform native endian. */ public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex); @@ -254,7 +272,7 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { public abstract void putArray(int rowId, int offset, int length); /** - * Sets the value at rowId to `value`. + * Sets values from [value + offset, value + offset + count) to the values at rowId. */ public abstract int putByteArray(int rowId, byte[] value, int offset, int count); public final int putByteArray(int rowId, byte[] value) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 6241b79d9affc..24c8ac81420cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -24,6 +24,7 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeMapData, UnsafeRow} import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor +import org.apache.spark.sql.execution.vectorized.WritableColumnVector import org.apache.spark.sql.types._ /** @@ -62,6 +63,9 @@ private[columnar] abstract class BasicColumnAccessor[JvmType]( } protected def underlyingBuffer = buffer + + def getByteBuffer: ByteBuffer = + buffer.duplicate.order(ByteOrder.nativeOrder()) } private[columnar] class NullColumnAccessor(buffer: ByteBuffer) @@ -122,7 +126,7 @@ private[columnar] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType) extends BasicColumnAccessor[UnsafeMapData](buffer, MAP(dataType)) with NullableColumnAccessor -private[columnar] object ColumnAccessor { +private[sql] object ColumnAccessor { @tailrec def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = { val buf = buffer.order(ByteOrder.nativeOrder) @@ -149,4 +153,14 @@ private[columnar] object ColumnAccessor { throw new Exception(s"not support type: $other") } } + + def decompress(columnAccessor: ColumnAccessor, columnVector: WritableColumnVector, numRows: Int): + Unit = { + if (columnAccessor.isInstanceOf[NativeColumnAccessor[_]]) { + val nativeAccessor = columnAccessor.asInstanceOf[NativeColumnAccessor[_]] + nativeAccessor.decompress(columnVector, numRows) + } else { + throw new RuntimeException("Not support non-primitive type now") + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index 5cfb003e4f150..e9b150fd86095 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -43,6 +43,12 @@ import org.apache.spark.unsafe.types.UTF8String * WARNING: This only works with HeapByteBuffer */ private[columnar] object ByteBufferHelper { + def getShort(buffer: ByteBuffer): Short = { + val pos = buffer.position() + buffer.position(pos + 2) + Platform.getShort(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) + } + def getInt(buffer: ByteBuffer): Int = { val pos = buffer.position() buffer.position(pos + 4) @@ -66,6 +72,33 @@ private[columnar] object ByteBufferHelper { buffer.position(pos + 8) Platform.getDouble(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) } + + def putShort(buffer: ByteBuffer, value: Short): Unit = { + val pos = buffer.position() + buffer.position(pos + 2) + Platform.putShort(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos, value) + } + + def putInt(buffer: ByteBuffer, value: Int): Unit = { + val pos = buffer.position() + buffer.position(pos + 4) + Platform.putInt(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos, value) + } + + def putLong(buffer: ByteBuffer, value: Long): Unit = { + val pos = buffer.position() + buffer.position(pos + 8) + Platform.putLong(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos, value) + } + + def copyMemory(src: ByteBuffer, dst: ByteBuffer, len: Int): Unit = { + val srcPos = src.position() + val dstPos = dst.position() + src.position(srcPos + len) + dst.position(dstPos + len) + Platform.copyMemory(src.array(), Platform.BYTE_ARRAY_OFFSET + srcPos, + dst.array(), Platform.BYTE_ARRAY_OFFSET + dstPos, len) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala index e1d13ad0e94e5..774011f1e3de8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.columnar.{ColumnAccessor, NativeColumnAccessor} +import org.apache.spark.sql.execution.vectorized.WritableColumnVector import org.apache.spark.sql.types.AtomicType private[columnar] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor { @@ -36,4 +37,7 @@ private[columnar] trait CompressibleColumnAccessor[T <: AtomicType] extends Colu override def extractSingle(row: InternalRow, ordinal: Int): Unit = { decoder.next(row, ordinal) } + + def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = + decoder.decompress(columnVector, capacity) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala index 6e4f1c5b80684..f8aeba44257d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala @@ -21,6 +21,7 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.columnar.{ColumnType, NativeColumnType} +import org.apache.spark.sql.execution.vectorized.WritableColumnVector import org.apache.spark.sql.types.AtomicType private[columnar] trait Encoder[T <: AtomicType] { @@ -41,6 +42,8 @@ private[columnar] trait Decoder[T <: AtomicType] { def next(row: InternalRow, ordinal: Int): Unit def hasNext: Boolean + + def decompress(columnVector: WritableColumnVector, capacity: Int): Unit } private[columnar] trait CompressionScheme { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index ee99c90a751d9..bf00ad997c76e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -18,12 +18,14 @@ package org.apache.spark.sql.execution.columnar.compression import java.nio.ByteBuffer +import java.nio.ByteOrder import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.vectorized.WritableColumnVector import org.apache.spark.sql.types._ @@ -61,6 +63,101 @@ private[columnar] case object PassThrough extends CompressionScheme { } override def hasNext: Boolean = buffer.hasRemaining + + private def putBooleans( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + for (i <- 0 until len) { + columnVector.putBoolean(pos + i, (buffer.get(bufferPos + i) != 0)) + } + } + + private def putBytes( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putBytes(pos, len, buffer.array, bufferPos) + } + + private def putShorts( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putShorts(pos, len, buffer.array, bufferPos) + } + + private def putInts( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putInts(pos, len, buffer.array, bufferPos) + } + + private def putLongs( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putLongs(pos, len, buffer.array, bufferPos) + } + + private def putFloats( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putFloats(pos, len, buffer.array, bufferPos) + } + + private def putDoubles( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putDoubles(pos, len, buffer.array, bufferPos) + } + + private def decompress0( + columnVector: WritableColumnVector, + capacity: Int, + unitSize: Int, + putFunction: (WritableColumnVector, Int, Int, Int) => Unit): Unit = { + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else capacity + var pos = 0 + var seenNulls = 0 + var bufferPos = buffer.position + while (pos < capacity) { + if (pos != nextNullIndex) { + val len = nextNullIndex - pos + assert(len * unitSize < Int.MaxValue) + putFunction(columnVector, pos, bufferPos, len) + bufferPos += len * unitSize + pos += len + } else { + seenNulls += 1 + nextNullIndex = if (seenNulls < nullCount) { + ByteBufferHelper.getInt(nullsBuffer) + } else { + capacity + } + columnVector.putNull(pos) + pos += 1 + } + } + } + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + columnType.dataType match { + case _: BooleanType => + val unitSize = 1 + decompress0(columnVector, capacity, unitSize, putBooleans) + case _: ByteType => + val unitSize = 1 + decompress0(columnVector, capacity, unitSize, putBytes) + case _: ShortType => + val unitSize = 2 + decompress0(columnVector, capacity, unitSize, putShorts) + case _: IntegerType => + val unitSize = 4 + decompress0(columnVector, capacity, unitSize, putInts) + case _: LongType => + val unitSize = 8 + decompress0(columnVector, capacity, unitSize, putLongs) + case _: FloatType => + val unitSize = 4 + decompress0(columnVector, capacity, unitSize, putFloats) + case _: DoubleType => + val unitSize = 8 + decompress0(columnVector, capacity, unitSize, putDoubles) + } + } } } @@ -169,6 +266,94 @@ private[columnar] case object RunLengthEncoding extends CompressionScheme { } override def hasNext: Boolean = valueCount < run || buffer.hasRemaining + + private def putBoolean(columnVector: WritableColumnVector, pos: Int, value: Long): Unit = { + columnVector.putBoolean(pos, value == 1) + } + + private def getByte(buffer: ByteBuffer): Long = { + buffer.get().toLong + } + + private def putByte(columnVector: WritableColumnVector, pos: Int, value: Long): Unit = { + columnVector.putByte(pos, value.toByte) + } + + private def getShort(buffer: ByteBuffer): Long = { + buffer.getShort().toLong + } + + private def putShort(columnVector: WritableColumnVector, pos: Int, value: Long): Unit = { + columnVector.putShort(pos, value.toShort) + } + + private def getInt(buffer: ByteBuffer): Long = { + buffer.getInt().toLong + } + + private def putInt(columnVector: WritableColumnVector, pos: Int, value: Long): Unit = { + columnVector.putInt(pos, value.toInt) + } + + private def getLong(buffer: ByteBuffer): Long = { + buffer.getLong() + } + + private def putLong(columnVector: WritableColumnVector, pos: Int, value: Long): Unit = { + columnVector.putLong(pos, value) + } + + private def decompress0( + columnVector: WritableColumnVector, + capacity: Int, + getFunction: (ByteBuffer) => Long, + putFunction: (WritableColumnVector, Int, Long) => Unit): Unit = { + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + var runLocal = 0 + var valueCountLocal = 0 + var currentValueLocal: Long = 0 + + while (valueCountLocal < runLocal || (pos < capacity)) { + if (pos != nextNullIndex) { + if (valueCountLocal == runLocal) { + currentValueLocal = getFunction(buffer) + runLocal = ByteBufferHelper.getInt(buffer) + valueCountLocal = 1 + } else { + valueCountLocal += 1 + } + putFunction(columnVector, pos, currentValueLocal) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + } + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + columnType.dataType match { + case _: BooleanType => + decompress0(columnVector, capacity, getByte, putBoolean) + case _: ByteType => + decompress0(columnVector, capacity, getByte, putByte) + case _: ShortType => + decompress0(columnVector, capacity, getShort, putShort) + case _: IntegerType => + decompress0(columnVector, capacity, getInt, putInt) + case _: LongType => + decompress0(columnVector, capacity, getLong, putLong) + case _ => throw new IllegalStateException("Not supported type in RunLengthEncoding.") + } + } } } @@ -266,11 +451,32 @@ private[columnar] case object DictionaryEncoding extends CompressionScheme { } class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) - extends compression.Decoder[T] { - - private val dictionary: Array[Any] = { - val elementNum = ByteBufferHelper.getInt(buffer) - Array.fill[Any](elementNum)(columnType.extract(buffer).asInstanceOf[Any]) + extends compression.Decoder[T] { + val elementNum = ByteBufferHelper.getInt(buffer) + private val dictionary: Array[Any] = new Array[Any](elementNum) + private var intDictionary: Array[Int] = null + private var longDictionary: Array[Long] = null + + columnType.dataType match { + case _: IntegerType => + intDictionary = new Array[Int](elementNum) + for (i <- 0 until elementNum) { + val v = columnType.extract(buffer).asInstanceOf[Int] + intDictionary(i) = v + dictionary(i) = v + } + case _: LongType => + longDictionary = new Array[Long](elementNum) + for (i <- 0 until elementNum) { + val v = columnType.extract(buffer).asInstanceOf[Long] + longDictionary(i) = v + dictionary(i) = v + } + case _: StringType => + for (i <- 0 until elementNum) { + val v = columnType.extract(buffer).asInstanceOf[Any] + dictionary(i) = v + } } override def next(row: InternalRow, ordinal: Int): Unit = { @@ -278,6 +484,46 @@ private[columnar] case object DictionaryEncoding extends CompressionScheme { } override def hasNext: Boolean = buffer.hasRemaining + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + columnType.dataType match { + case _: IntegerType => + val dictionaryIds = columnVector.reserveDictionaryIds(capacity) + columnVector.setDictionary(new ColumnDictionary(intDictionary)) + while (pos < capacity) { + if (pos != nextNullIndex) { + dictionaryIds.putInt(pos, buffer.getShort()) + } else { + seenNulls += 1 + if (seenNulls < nullCount) nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + columnVector.putNull(pos) + } + pos += 1 + } + case _: LongType => + val dictionaryIds = columnVector.reserveDictionaryIds(capacity) + columnVector.setDictionary(new ColumnDictionary(longDictionary)) + while (pos < capacity) { + if (pos != nextNullIndex) { + dictionaryIds.putInt(pos, buffer.getShort()) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + case _ => throw new IllegalStateException("Not supported type in DictionaryEncoding.") + } + } } } @@ -368,6 +614,38 @@ private[columnar] case object BooleanBitSet extends CompressionScheme { } override def hasNext: Boolean = visited < count + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + val countLocal = count + var currentWordLocal: Long = 0 + var visitedLocal: Int = 0 + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + + while (visitedLocal < countLocal) { + if (pos != nextNullIndex) { + val bit = visitedLocal % BITS_PER_LONG + + visitedLocal += 1 + if (bit == 0) { + currentWordLocal = ByteBufferHelper.getLong(buffer) + } + + columnVector.putBoolean(pos, ((currentWordLocal >> bit) & 1) != 0) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + } } } @@ -448,6 +726,32 @@ private[columnar] case object IntDelta extends CompressionScheme { prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getInt(buffer) row.setInt(ordinal, prev) } + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + var prevLocal: Int = 0 + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + + while (pos < capacity) { + if (pos != nextNullIndex) { + val delta = buffer.get + prevLocal = if (delta > Byte.MinValue) { prevLocal + delta } else + { ByteBufferHelper.getInt(buffer) } + columnVector.putInt(pos, prevLocal) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + } } } @@ -528,5 +832,31 @@ private[columnar] case object LongDelta extends CompressionScheme { prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getLong(buffer) row.setLong(ordinal, prev) } + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + var prevLocal: Long = 0 + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + + while (pos < capacity) { + if (pos != nextNullIndex) { + val delta = buffer.get() + prevLocal = if (delta > Byte.MinValue) { prevLocal + delta } else + { ByteBufferHelper.getLong(buffer) } + columnVector.putLong(pos, prevLocal) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala index d01bf911e3a77..2d71a42628dfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar.{BOOLEAN, NoopColumnStats} import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.types.BooleanType class BooleanBitSetSuite extends SparkFunSuite { import BooleanBitSet._ @@ -85,6 +87,36 @@ class BooleanBitSetSuite extends SparkFunSuite { assert(!decoder.hasNext) } + def skeletonForDecompress(count: Int) { + val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet) + val rows = Seq.fill[InternalRow](count)(makeRandomRow(BOOLEAN)) + val values = rows.map(_.getBoolean(0)) + + rows.foreach(builder.appendFrom(_, 0)) + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(BooleanBitSet.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = BooleanBitSet.decoder(buffer, BOOLEAN) + val columnVector = new OnHeapColumnVector(values.length, BooleanType) + decoder.decompress(columnVector, values.length) + + if (values.nonEmpty) { + values.zipWithIndex.foreach { case (b: Boolean, index: Int) => + assertResult(b, s"Wrong ${index}-th decoded boolean value") { + columnVector.getBoolean(index) + } + } + } + } + test(s"$BooleanBitSet: empty") { skeleton(0) } @@ -104,4 +136,24 @@ class BooleanBitSetSuite extends SparkFunSuite { test(s"$BooleanBitSet: multiple words and 1 more bit") { skeleton(BITS_PER_LONG * 2 + 1) } + + test(s"$BooleanBitSet: empty for decompression()") { + skeletonForDecompress(0) + } + + test(s"$BooleanBitSet: less than 1 word for decompression()") { + skeletonForDecompress(BITS_PER_LONG - 1) + } + + test(s"$BooleanBitSet: exactly 1 word for decompression()") { + skeletonForDecompress(BITS_PER_LONG) + } + + test(s"$BooleanBitSet: multiple whole words for decompression()") { + skeletonForDecompress(BITS_PER_LONG * 2) + } + + test(s"$BooleanBitSet: multiple words and 1 more bit for decompression()") { + skeletonForDecompress(BITS_PER_LONG * 2 + 1) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala index 67139b13d7882..28950b74cf1c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala @@ -23,16 +23,19 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.types.AtomicType class DictionaryEncodingSuite extends SparkFunSuite { + val nullValue = -1 testDictionaryEncoding(new IntColumnStats, INT) testDictionaryEncoding(new LongColumnStats, LONG) - testDictionaryEncoding(new StringColumnStats, STRING) + testDictionaryEncoding(new StringColumnStats, STRING, false) def testDictionaryEncoding[T <: AtomicType]( columnStats: ColumnStats, - columnType: NativeColumnType[T]) { + columnType: NativeColumnType[T], + testDecompress: Boolean = true) { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") @@ -113,6 +116,58 @@ class DictionaryEncodingSuite extends SparkFunSuite { } } + def skeletonForDecompress(uniqueValueCount: Int, inputSeq: Seq[Int]) { + if (!testDecompress) return + val builder = TestCompressibleColumnBuilder(columnStats, columnType, DictionaryEncoding) + val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, uniqueValueCount) + val dictValues = stableDistinct(inputSeq) + + val nullRow = new GenericInternalRow(1) + nullRow.setNullAt(0) + inputSeq.foreach { i => + if (i == nullValue) { + builder.appendFrom(nullRow, 0) + } else { + builder.appendFrom(rows(i), 0) + } + } + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(DictionaryEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = DictionaryEncoding.decoder(buffer, columnType) + val columnVector = new OnHeapColumnVector(inputSeq.length, columnType.dataType) + decoder.decompress(columnVector, inputSeq.length) + + if (inputSeq.nonEmpty) { + inputSeq.zipWithIndex.foreach { case (i: Any, index: Int) => + if (i == nullValue) { + assertResult(true, s"Wrong null ${index}-th position") { + columnVector.isNullAt(index) + } + } else { + columnType match { + case INT => + assertResult(values(i), s"Wrong ${index}-th decoded int value") { + columnVector.getInt(index) + } + case LONG => + assertResult(values(i), s"Wrong ${index}-th decoded long value") { + columnVector.getLong(index) + } + case _ => fail("Unsupported type") + } + } + } + } + } + test(s"$DictionaryEncoding with $typeName: empty") { skeleton(0, Seq.empty) } @@ -124,5 +179,18 @@ class DictionaryEncodingSuite extends SparkFunSuite { test(s"$DictionaryEncoding with $typeName: dictionary overflow") { skeleton(DictionaryEncoding.MAX_DICT_SIZE + 1, 0 to DictionaryEncoding.MAX_DICT_SIZE) } + + test(s"$DictionaryEncoding with $typeName: empty for decompress()") { + skeletonForDecompress(0, Seq.empty) + } + + test(s"$DictionaryEncoding with $typeName: simple case for decompress()") { + skeletonForDecompress(2, Seq(0, nullValue, 0, nullValue)) + } + + test(s"$DictionaryEncoding with $typeName: dictionary overflow for decompress()") { + skeletonForDecompress(DictionaryEncoding.MAX_DICT_SIZE + 2, + Seq(nullValue) ++ (0 to DictionaryEncoding.MAX_DICT_SIZE - 1) ++ Seq(nullValue)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala index 411d31fa0e29b..0d9f1fb0c02c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala @@ -21,9 +21,11 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.types.IntegralType class IntegralDeltaSuite extends SparkFunSuite { + val nullValue = -1 testIntegralDelta(new IntColumnStats, INT, IntDelta) testIntegralDelta(new LongColumnStats, LONG, LongDelta) @@ -109,6 +111,53 @@ class IntegralDeltaSuite extends SparkFunSuite { assert(!decoder.hasNext) } + def skeletonForDecompress(input: Seq[I#InternalType]) { + val builder = TestCompressibleColumnBuilder(columnStats, columnType, scheme) + val row = new GenericInternalRow(1) + val nullRow = new GenericInternalRow(1) + nullRow.setNullAt(0) + input.map { value => + if (value == nullValue) { + builder.appendFrom(nullRow, 0) + } else { + columnType.setField(row, 0, value) + builder.appendFrom(row, 0) + } + } + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(scheme.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = scheme.decoder(buffer, columnType) + val columnVector = new OnHeapColumnVector(input.length, columnType.dataType) + decoder.decompress(columnVector, input.length) + + if (input.nonEmpty) { + input.zipWithIndex.foreach { + case (expected: Any, index: Int) if expected == nullValue => + assertResult(true, s"Wrong null ${index}th-position") { + columnVector.isNullAt(index) + } + case (expected: Int, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded int value") { + columnVector.getInt(index) + } + case (expected: Long, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded long value") { + columnVector.getLong(index) + } + case _ => + fail("Unsupported type") + } + } + } + test(s"$scheme: empty column") { skeleton(Seq.empty) } @@ -127,5 +176,28 @@ class IntegralDeltaSuite extends SparkFunSuite { val input = Array.fill[Any](10000)(makeRandomValue(columnType)) skeleton(input.map(_.asInstanceOf[I#InternalType])) } + + + test(s"$scheme: empty column for decompress()") { + skeletonForDecompress(Seq.empty) + } + + test(s"$scheme: simple case for decompress()") { + val input = columnType match { + case INT => Seq(2: Int, 1: Int, 2: Int, 130: Int) + case LONG => Seq(2: Long, 1: Long, 2: Long, 130: Long) + } + + skeletonForDecompress(input.map(_.asInstanceOf[I#InternalType])) + } + + test(s"$scheme: simple case with null for decompress()") { + val input = columnType match { + case INT => Seq(2: Int, 1: Int, 2: Int, nullValue: Int, 5: Int) + case LONG => Seq(2: Long, 1: Long, 2: Long, nullValue: Long, 5: Long) + } + + skeletonForDecompress(input.map(_.asInstanceOf[I#InternalType])) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala new file mode 100644 index 0000000000000..b6f0b5e6277b4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar.compression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.types.AtomicType + +class PassThroughSuite extends SparkFunSuite { + val nullValue = -1 + testPassThrough(new ByteColumnStats, BYTE) + testPassThrough(new ShortColumnStats, SHORT) + testPassThrough(new IntColumnStats, INT) + testPassThrough(new LongColumnStats, LONG) + testPassThrough(new FloatColumnStats, FLOAT) + testPassThrough(new DoubleColumnStats, DOUBLE) + + def testPassThrough[T <: AtomicType]( + columnStats: ColumnStats, + columnType: NativeColumnType[T]) { + + val typeName = columnType.getClass.getSimpleName.stripSuffix("$") + + def skeleton(input: Seq[T#InternalType]) { + // ------------- + // Tests encoder + // ------------- + + val builder = TestCompressibleColumnBuilder(columnStats, columnType, PassThrough) + + input.map { value => + val row = new GenericInternalRow(1) + columnType.setField(row, 0, value) + builder.appendFrom(row, 0) + } + + val buffer = builder.build() + // Column type ID + null count + null positions + val headerSize = CompressionScheme.columnHeaderSize(buffer) + + // Compression scheme ID + compressed contents + val compressedSize = 4 + input.size * columnType.defaultSize + + // 4 extra bytes for compression scheme type ID + assertResult(headerSize + compressedSize, "Wrong buffer capacity")(buffer.capacity) + + buffer.position(headerSize) + assertResult(PassThrough.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + if (input.nonEmpty) { + input.foreach { value => + assertResult(value, "Wrong value")(columnType.extract(buffer)) + } + } + + // ------------- + // Tests decoder + // ------------- + + // Rewinds, skips column header and 4 more bytes for compression scheme ID + buffer.rewind().position(headerSize + 4) + + val decoder = PassThrough.decoder(buffer, columnType) + val mutableRow = new GenericInternalRow(1) + + if (input.nonEmpty) { + input.foreach{ + assert(decoder.hasNext) + assertResult(_, "Wrong decoded value") { + decoder.next(mutableRow, 0) + columnType.getField(mutableRow, 0) + } + } + } + assert(!decoder.hasNext) + } + + def skeletonForDecompress(input: Seq[T#InternalType]) { + val builder = TestCompressibleColumnBuilder(columnStats, columnType, PassThrough) + val row = new GenericInternalRow(1) + val nullRow = new GenericInternalRow(1) + nullRow.setNullAt(0) + input.map { value => + if (value == nullValue) { + builder.appendFrom(nullRow, 0) + } else { + columnType.setField(row, 0, value) + builder.appendFrom(row, 0) + } + } + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(PassThrough.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = PassThrough.decoder(buffer, columnType) + val columnVector = new OnHeapColumnVector(input.length, columnType.dataType) + decoder.decompress(columnVector, input.length) + + if (input.nonEmpty) { + input.zipWithIndex.foreach { + case (expected: Any, index: Int) if expected == nullValue => + assertResult(true, s"Wrong null ${index}th-position") { + columnVector.isNullAt(index) + } + case (expected: Byte, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded byte value") { + columnVector.getByte(index) + } + case (expected: Short, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded short value") { + columnVector.getShort(index) + } + case (expected: Int, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded int value") { + columnVector.getInt(index) + } + case (expected: Long, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded long value") { + columnVector.getLong(index) + } + case (expected: Float, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded float value") { + columnVector.getFloat(index) + } + case (expected: Double, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded double value") { + columnVector.getDouble(index) + } + case _ => fail("Unsupported type") + } + } + } + + test(s"$PassThrough with $typeName: empty column") { + skeleton(Seq.empty) + } + + test(s"$PassThrough with $typeName: long random series") { + val input = Array.fill[Any](10000)(makeRandomValue(columnType)) + skeleton(input.map(_.asInstanceOf[T#InternalType])) + } + + test(s"$PassThrough with $typeName: empty column for decompress()") { + skeletonForDecompress(Seq.empty) + } + + test(s"$PassThrough with $typeName: long random series for decompress()") { + val input = Array.fill[Any](10000)(makeRandomValue(columnType)) + skeletonForDecompress(input.map(_.asInstanceOf[T#InternalType])) + } + + test(s"$PassThrough with $typeName: simple case with null for decompress()") { + val input = columnType match { + case BYTE => Seq(2: Byte, 1: Byte, 2: Byte, nullValue.toByte: Byte, 5: Byte) + case SHORT => Seq(2: Short, 1: Short, 2: Short, nullValue.toShort: Short, 5: Short) + case INT => Seq(2: Int, 1: Int, 2: Int, nullValue: Int, 5: Int) + case LONG => Seq(2: Long, 1: Long, 2: Long, nullValue: Long, 5: Long) + case FLOAT => Seq(2: Float, 1: Float, 2: Float, nullValue: Float, 5: Float) + case DOUBLE => Seq(2: Double, 1: Double, 2: Double, nullValue: Double, 5: Double) + } + + skeletonForDecompress(input.map(_.asInstanceOf[T#InternalType])) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala index dffa9b364ebfe..eb1cdd9bbceff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala @@ -21,19 +21,22 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.types.AtomicType class RunLengthEncodingSuite extends SparkFunSuite { + val nullValue = -1 testRunLengthEncoding(new NoopColumnStats, BOOLEAN) testRunLengthEncoding(new ByteColumnStats, BYTE) testRunLengthEncoding(new ShortColumnStats, SHORT) testRunLengthEncoding(new IntColumnStats, INT) testRunLengthEncoding(new LongColumnStats, LONG) - testRunLengthEncoding(new StringColumnStats, STRING) + testRunLengthEncoding(new StringColumnStats, STRING, false) def testRunLengthEncoding[T <: AtomicType]( columnStats: ColumnStats, - columnType: NativeColumnType[T]) { + columnType: NativeColumnType[T], + testDecompress: Boolean = true) { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") @@ -95,6 +98,72 @@ class RunLengthEncodingSuite extends SparkFunSuite { assert(!decoder.hasNext) } + def skeletonForDecompress(uniqueValueCount: Int, inputRuns: Seq[(Int, Int)]) { + if (!testDecompress) return + val builder = TestCompressibleColumnBuilder(columnStats, columnType, RunLengthEncoding) + val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, uniqueValueCount) + val inputSeq = inputRuns.flatMap { case (index, run) => + Seq.fill(run)(index) + } + + val nullRow = new GenericInternalRow(1) + nullRow.setNullAt(0) + inputSeq.foreach { i => + if (i == nullValue) { + builder.appendFrom(nullRow, 0) + } else { + builder.appendFrom(rows(i), 0) + } + } + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(RunLengthEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = RunLengthEncoding.decoder(buffer, columnType) + val columnVector = new OnHeapColumnVector(inputSeq.length, columnType.dataType) + decoder.decompress(columnVector, inputSeq.length) + + if (inputSeq.nonEmpty) { + inputSeq.zipWithIndex.foreach { + case (expected: Any, index: Int) if expected == nullValue => + assertResult(true, s"Wrong null ${index}th-position") { + columnVector.isNullAt(index) + } + case (i: Int, index: Int) => + columnType match { + case BOOLEAN => + assertResult(values(i), s"Wrong ${index}-th decoded boolean value") { + columnVector.getBoolean(index) + } + case BYTE => + assertResult(values(i), s"Wrong ${index}-th decoded byte value") { + columnVector.getByte(index) + } + case SHORT => + assertResult(values(i), s"Wrong ${index}-th decoded short value") { + columnVector.getShort(index) + } + case INT => + assertResult(values(i), s"Wrong ${index}-th decoded int value") { + columnVector.getInt(index) + } + case LONG => + assertResult(values(i), s"Wrong ${index}-th decoded long value") { + columnVector.getLong(index) + } + case _ => fail("Unsupported type") + } + case _ => fail("Unsupported type") + } + } + } + test(s"$RunLengthEncoding with $typeName: empty column") { skeleton(0, Seq.empty) } @@ -110,5 +179,21 @@ class RunLengthEncodingSuite extends SparkFunSuite { test(s"$RunLengthEncoding with $typeName: single long run") { skeleton(1, Seq(0 -> 1000)) } + + test(s"$RunLengthEncoding with $typeName: empty column for decompress()") { + skeletonForDecompress(0, Seq.empty) + } + + test(s"$RunLengthEncoding with $typeName: simple case for decompress()") { + skeletonForDecompress(2, Seq(0 -> 2, 1 -> 2)) + } + + test(s"$RunLengthEncoding with $typeName: single long run for decompress()") { + skeletonForDecompress(1, Seq(0 -> 1000)) + } + + test(s"$RunLengthEncoding with $typeName: single case with null for decompress()") { + skeletonForDecompress(2, Seq(0 -> 2, nullValue -> 2)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala index 5e078f251375a..310cb0be5f5a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.sql.execution.columnar._ -import org.apache.spark.sql.types.AtomicType +import org.apache.spark.sql.types.{AtomicType, DataType} class TestCompressibleColumnBuilder[T <: AtomicType]( override val columnStats: ColumnStats, @@ -42,3 +42,10 @@ object TestCompressibleColumnBuilder { builder } } + +object ColumnBuilderHelper { + def apply( + dataType: DataType, batchSize: Int, name: String, useCompression: Boolean): ColumnBuilder = { + ColumnBuilder(dataType, batchSize, name, useCompression) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 85da8270d4cba..c5c8ae3a17c6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -20,7 +20,10 @@ package org.apache.spark.sql.execution.vectorized import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.execution.columnar.ColumnAccessor +import org.apache.spark.sql.execution.columnar.compression.ColumnBuilderHelper import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -31,14 +34,21 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { try block(vector) finally vector.close() } + private def withVectors( + size: Int, + dt: DataType)( + block: WritableColumnVector => Unit): Unit = { + withVector(new OnHeapColumnVector(size, dt))(block) + withVector(new OffHeapColumnVector(size, dt))(block) + } + private def testVectors( name: String, size: Int, dt: DataType)( block: WritableColumnVector => Unit): Unit = { test(name) { - withVector(new OnHeapColumnVector(size, dt))(block) - withVector(new OffHeapColumnVector(size, dt))(block) + withVectors(size, dt)(block) } } @@ -218,4 +228,173 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { (0 until 8).foreach(i => assert(testVector.isNullAt(i) == (i % 2 == 0))) } } + + test("CachedBatch boolean Apis") { + val dataType = BooleanType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setBoolean(0, i % 2 == 0) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getBoolean(i) == (i % 2 == 0)) + } + } + } + + test("CachedBatch byte Apis") { + val dataType = ByteType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setByte(0, i.toByte) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getByte(i) == i) + } + } + } + + test("CachedBatch short Apis") { + val dataType = ShortType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setShort(0, i.toShort) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getShort(i) == i) + } + } + } + + test("CachedBatch int Apis") { + val dataType = IntegerType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setInt(0, i) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getInt(i) == i) + } + } + } + + test("CachedBatch long Apis") { + val dataType = LongType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setLong(0, i.toLong) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getLong(i) == i.toLong) + } + } + } + + test("CachedBatch float Apis") { + val dataType = FloatType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setFloat(0, i.toFloat) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getFloat(i) == i.toFloat) + } + } + } + + test("CachedBatch double Apis") { + val dataType = DoubleType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setDouble(0, i.toDouble) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getDouble(i) == i.toDouble) + } + } + } } + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 983eb103682c1..0b179aa97c479 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -413,7 +413,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference.zipWithIndex.foreach { v => assert(v._1 == column.getLong(v._2), "idx=" + v._2 + - " Seed = " + seed + " MemMode=" + memMode) + " Seed = " + seed + " MemMode=" + memMode) if (memMode == MemoryMode.OFF_HEAP) { val addr = column.valuesNativeAddress() assert(v._1 == Platform.getLong(null, addr + 8 * v._2)) @@ -1120,7 +1120,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } batch.close() } - }} + }} /** * This test generates a random schema data, serializes it to column batches and verifies the From 4a779bdac3e75c17b7d36c5a009ba6c948fa9fb6 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 4 Oct 2017 10:08:24 -0700 Subject: [PATCH 664/779] [SPARK-21871][SQL] Check actual bytecode size when compiling generated code ## What changes were proposed in this pull request? This pr added code to check actual bytecode size when compiling generated code. In #18810, we added code to give up code compilation and use interpreter execution in `SparkPlan` if the line number of generated functions goes over `maxLinesPerFunction`. But, we already have code to collect metrics for compiled bytecode size in `CodeGenerator` object. So,we could easily reuse the code for this purpose. ## How was this patch tested? Added tests in `WholeStageCodegenSuite`. Author: Takeshi Yamamuro Closes #19083 from maropu/SPARK-21871. --- .../expressions/codegen/CodeFormatter.scala | 8 --- .../expressions/codegen/CodeGenerator.scala | 59 +++++++++---------- .../codegen/GenerateMutableProjection.scala | 4 +- .../codegen/GenerateOrdering.scala | 3 +- .../codegen/GeneratePredicate.scala | 3 +- .../codegen/GenerateSafeProjection.scala | 4 +- .../codegen/GenerateUnsafeProjection.scala | 4 +- .../codegen/GenerateUnsafeRowJoiner.scala | 4 +- .../apache/spark/sql/internal/SQLConf.scala | 15 ++--- .../codegen/CodeFormatterSuite.scala | 32 ---------- .../sql/execution/WholeStageCodegenExec.scala | 25 ++++---- .../columnar/GenerateColumnAccessor.scala | 3 +- .../execution/WholeStageCodegenSuite.scala | 43 ++++---------- .../benchmark/AggregateBenchmark.scala | 36 +++++------ 14 files changed, 94 insertions(+), 149 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index 7b398f424cead..60e600d8dbd8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -89,14 +89,6 @@ object CodeFormatter { } new CodeAndComment(code.result().trim(), map) } - - def stripExtraNewLinesAndComments(input: String): String = { - val commentReg = - ("""([ |\t]*?\/\*[\s|\S]*?\*\/[ |\t]*?)|""" + // strip /*comment*/ - """([ |\t]*?\/\/[\s\S]*?\n)""").r // strip //comment - val codeWithoutComment = commentReg.replaceAllIn(input, "") - codeWithoutComment.replaceAll("""\n\s*\n""", "\n") // strip ExtraNewLines - } } private class CodeFormatter { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f3b45799c5688..f9c5ef8439085 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -373,20 +373,6 @@ class CodegenContext { */ private val placeHolderToComments = new mutable.HashMap[String, String] - /** - * It will count the lines of every Java function generated by whole-stage codegen, - * if there is a function of length greater than spark.sql.codegen.maxLinesPerFunction, - * it will return true. - */ - def isTooLongGeneratedFunction: Boolean = { - classFunctions.values.exists { _.values.exists { - code => - val codeWithoutComments = CodeFormatter.stripExtraNewLinesAndComments(code) - codeWithoutComments.count(_ == '\n') > SQLConf.get.maxLinesPerFunction - } - } - } - /** * Returns a term name that is unique within this instance of a `CodegenContext`. */ @@ -1020,10 +1006,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } object CodeGenerator extends Logging { + + // This is the value of HugeMethodLimit in the OpenJDK JVM settings + val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000 + /** * Compile the Java source code into a Java class, using Janino. + * + * @return a pair of a generated class and the max bytecode size of generated functions. */ - def compile(code: CodeAndComment): GeneratedClass = try { + def compile(code: CodeAndComment): (GeneratedClass, Int) = try { cache.get(code) } catch { // Cache.get() may wrap the original exception. See the following URL @@ -1036,7 +1028,7 @@ object CodeGenerator extends Logging { /** * Compile the Java source code into a Java class, using Janino. */ - private[this] def doCompile(code: CodeAndComment): GeneratedClass = { + private[this] def doCompile(code: CodeAndComment): (GeneratedClass, Int) = { val evaluator = new ClassBodyEvaluator() // A special classloader used to wrap the actual parent classloader of @@ -1075,9 +1067,9 @@ object CodeGenerator extends Logging { s"\n${CodeFormatter.format(code)}" }) - try { + val maxCodeSize = try { evaluator.cook("generated.java", code.body) - recordCompilationStats(evaluator) + updateAndGetCompilationStats(evaluator) } catch { case e: JaninoRuntimeException => val msg = s"failed to compile: $e" @@ -1092,13 +1084,15 @@ object CodeGenerator extends Logging { logInfo(s"\n${CodeFormatter.format(code, maxLines)}") throw new CompileException(msg, e.getLocation) } - evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass] + + (evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass], maxCodeSize) } /** - * Records the generated class and method bytecode sizes by inspecting janino private fields. + * Returns the max bytecode size of the generated functions by inspecting janino private fields. + * Also, this method updates the metrics information. */ - private def recordCompilationStats(evaluator: ClassBodyEvaluator): Unit = { + private def updateAndGetCompilationStats(evaluator: ClassBodyEvaluator): Int = { // First retrieve the generated classes. val classes = { val resultField = classOf[SimpleCompiler].getDeclaredField("result") @@ -1113,23 +1107,26 @@ object CodeGenerator extends Logging { val codeAttr = Utils.classForName("org.codehaus.janino.util.ClassFile$CodeAttribute") val codeAttrField = codeAttr.getDeclaredField("code") codeAttrField.setAccessible(true) - classes.foreach { case (_, classBytes) => + val codeSizes = classes.flatMap { case (_, classBytes) => CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.update(classBytes.length) try { val cf = new ClassFile(new ByteArrayInputStream(classBytes)) - cf.methodInfos.asScala.foreach { method => - method.getAttributes().foreach { a => - if (a.getClass.getName == codeAttr.getName) { - CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update( - codeAttrField.get(a).asInstanceOf[Array[Byte]].length) - } + val stats = cf.methodInfos.asScala.flatMap { method => + method.getAttributes().filter(_.getClass.getName == codeAttr.getName).map { a => + val byteCodeSize = codeAttrField.get(a).asInstanceOf[Array[Byte]].length + CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update(byteCodeSize) + byteCodeSize } } + Some(stats) } catch { case NonFatal(e) => logWarning("Error calculating stats of compiled class.", e) + None } - } + }.flatten + + codeSizes.max } /** @@ -1144,8 +1141,8 @@ object CodeGenerator extends Logging { private val cache = CacheBuilder.newBuilder() .maximumSize(100) .build( - new CacheLoader[CodeAndComment, GeneratedClass]() { - override def load(code: CodeAndComment): GeneratedClass = { + new CacheLoader[CodeAndComment, (GeneratedClass, Int)]() { + override def load(code: CodeAndComment): (GeneratedClass, Int) = { val startTime = System.nanoTime() val result = doCompile(code) val endTime = System.nanoTime() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 3768dcde00a4e..b5429fade53cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -142,7 +142,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = CodeGenerator.compile(code) - c.generate(ctx.references.toArray).asInstanceOf[MutableProjection] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(ctx.references.toArray).asInstanceOf[MutableProjection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 4e47895985209..1639d1b9dda1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -185,7 +185,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"Generated Ordering by ${ordering.mkString(",")}:\n${CodeFormatter.format(code)}") - CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(ctx.references.toArray).asInstanceOf[BaseOrdering] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index e35b9dda6c017..e0fabad6d089a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -78,6 +78,7 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] { new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}") - CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(ctx.references.toArray).asInstanceOf[Predicate] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 192701a829686..1e4ac3f2afd52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -189,8 +189,8 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = CodeGenerator.compile(code) + val (clazz, _) = CodeGenerator.compile(code) val resultRow = new SpecificInternalRow(expressions.map(_.dataType)) - c.generate(ctx.references.toArray :+ resultRow).asInstanceOf[Projection] + clazz.generate(ctx.references.toArray :+ resultRow).asInstanceOf[Projection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index f2a66efc98e71..4bd50aee05514 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -409,7 +409,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = CodeGenerator.compile(code) - c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 4aa5ec82471ec..6bc72a0d75c6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -196,7 +196,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U val code = CodeFormatter.stripOverlappingComments(new CodeAndComment(codeBody, Map.empty)) logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}") - val c = CodeGenerator.compile(code) - c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(Array.empty).asInstanceOf[UnsafeRowJoiner] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1a73d168b9b6e..58323740b80cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -575,15 +576,15 @@ object SQLConf { "disable logging or -1 to apply no limit.") .createWithDefault(1000) - val WHOLESTAGE_MAX_LINES_PER_FUNCTION = buildConf("spark.sql.codegen.maxLinesPerFunction") + val WHOLESTAGE_HUGE_METHOD_LIMIT = buildConf("spark.sql.codegen.hugeMethodLimit") .internal() - .doc("The maximum lines of a single Java function generated by whole-stage codegen. " + - "When the generated function exceeds this threshold, " + + .doc("The maximum bytecode size of a single compiled Java function generated by whole-stage " + + "codegen. When the compiled function exceeds this threshold, " + "the whole-stage codegen is deactivated for this subtree of the current query plan. " + - "The default value 4000 is the max length of byte code JIT supported " + - "for a single function(8000) divided by 2.") + s"The default value is ${CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT} and " + + "this is a limit in the OpenJDK JVM implementation.") .intConf - .createWithDefault(4000) + .createWithDefault(CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT) val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes") .doc("The maximum number of bytes to pack into a single partition when reading files.") @@ -1058,7 +1059,7 @@ class SQLConf extends Serializable with Logging { def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES) - def maxLinesPerFunction: Int = getConf(WHOLESTAGE_MAX_LINES_PER_FUNCTION) + def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT) def tableRelationCacheSize: Int = getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala index a0f1a64b0ab08..9d0a41661beaa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala @@ -53,38 +53,6 @@ class CodeFormatterSuite extends SparkFunSuite { assert(reducedCode.body === "/*project_c4*/") } - test("removing extra new lines and comments") { - val code = - """ - |/* - | * multi - | * line - | * comments - | */ - | - |public function() { - |/*comment*/ - | /*comment_with_space*/ - |code_body - |//comment - |code_body - | //comment_with_space - | - |code_body - |} - """.stripMargin - - val reducedCode = CodeFormatter.stripExtraNewLinesAndComments(code) - assert(reducedCode === - """ - |public function() { - |code_body - |code_body - |code_body - |} - """.stripMargin) - } - testCase("basic example") { """ |class A { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 268ccfa4edfa0..9073d599ac43d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -380,16 +380,8 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co override def doExecute(): RDD[InternalRow] = { val (ctx, cleanedSource) = doCodeGen() - if (ctx.isTooLongGeneratedFunction) { - logWarning("Found too long generated codes and JIT optimization might not work, " + - "Whole-stage codegen disabled for this plan, " + - "You can change the config spark.sql.codegen.MaxFunctionLength " + - "to adjust the function length limit:\n " - + s"$treeString") - return child.execute() - } // try to compile and fallback if it failed - try { + val (_, maxCodeSize) = try { CodeGenerator.compile(cleanedSource) } catch { case _: Exception if !Utils.isTesting && sqlContext.conf.codegenFallback => @@ -397,6 +389,17 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co logWarning(s"Whole-stage codegen disabled for this plan:\n $treeString") return child.execute() } + + // Check if compiled code has a too large function + if (maxCodeSize > sqlContext.conf.hugeMethodLimit) { + logWarning(s"Found too long generated codes and JIT optimization might not work: " + + s"the bytecode size was $maxCodeSize, this value went over the limit " + + s"${sqlContext.conf.hugeMethodLimit}, and the whole-stage codegen was disabled " + + s"for this plan. To avoid this, you can raise the limit " + + s"${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}:\n$treeString") + return child.execute() + } + val references = ctx.references.toArray val durationMs = longMetric("pipelineTime") @@ -405,7 +408,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co assert(rdds.size <= 2, "Up to two input RDDs can be supported") if (rdds.length == 1) { rdds.head.mapPartitionsWithIndex { (index, iter) => - val clazz = CodeGenerator.compile(cleanedSource) + val (clazz, _) = CodeGenerator.compile(cleanedSource) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.init(index, Array(iter)) new Iterator[InternalRow] { @@ -424,7 +427,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co // a small hack to obtain the correct partition index }.mapPartitionsWithIndex { (index, zippedIter) => val (leftIter, rightIter) = zippedIter.next() - val clazz = CodeGenerator.compile(cleanedSource) + val (clazz, _) = CodeGenerator.compile(cleanedSource) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.init(index, Array(leftIter, rightIter)) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index da34643281911..ae600c1ffae8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -227,6 +227,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"Generated ColumnarIterator:\n${CodeFormatter.format(code)}") - CodeGenerator.compile(code).generate(Array.empty).asInstanceOf[ColumnarIterator] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(Array.empty).asInstanceOf[ColumnarIterator] } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index beeee6a97c8dd..aaa77b3ee6201 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{Column, Dataset, Row} -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack} -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec @@ -151,7 +149,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { } } - def genGroupByCodeGenContext(caseNum: Int): CodegenContext = { + def genGroupByCode(caseNum: Int): CodeAndComment = { val caseExp = (1 to caseNum).map { i => s"case when id > $i and id <= ${i + 1} then 1 else 0 end as v$i" }.toList @@ -176,34 +174,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { }) assert(wholeStageCodeGenExec.isDefined) - wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._1 + wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 } - test("SPARK-21603 check there is a too long generated function") { - withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "1500") { - val ctx = genGroupByCodeGenContext(30) - assert(ctx.isTooLongGeneratedFunction === true) - } - } - - test("SPARK-21603 check there is not a too long generated function") { - withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "1500") { - val ctx = genGroupByCodeGenContext(1) - assert(ctx.isTooLongGeneratedFunction === false) - } - } - - test("SPARK-21603 check there is not a too long generated function when threshold is Int.Max") { - withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> Int.MaxValue.toString) { - val ctx = genGroupByCodeGenContext(30) - assert(ctx.isTooLongGeneratedFunction === false) - } - } - - test("SPARK-21603 check there is a too long generated function when threshold is 0") { - withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "0") { - val ctx = genGroupByCodeGenContext(1) - assert(ctx.isTooLongGeneratedFunction === true) - } + test("SPARK-21871 check if we can get large code size when compiling too long functions") { + val codeWithShortFunctions = genGroupByCode(3) + val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions) + assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) + val codeWithLongFunctions = genGroupByCode(20) + val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions) + assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index 691fa9ac5e1e7..aca1be01fa3da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -24,6 +24,7 @@ import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap import org.apache.spark.sql.execution.vectorized.AggregateHashMap +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 @@ -301,10 +302,10 @@ class AggregateBenchmark extends BenchmarkBase { */ } - ignore("max function length of wholestagecodegen") { + ignore("max function bytecode size of wholestagecodegen") { val N = 20 << 15 - val benchmark = new Benchmark("max function length of wholestagecodegen", N) + val benchmark = new Benchmark("max function bytecode size", N) def f(): Unit = sparkSession.range(N) .selectExpr( "id", @@ -333,33 +334,34 @@ class AggregateBenchmark extends BenchmarkBase { .sum() .collect() - benchmark.addCase(s"codegen = F") { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "false") + benchmark.addCase("codegen = F") { iter => + sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "false") f() } - benchmark.addCase(s"codegen = T maxLinesPerFunction = 10000") { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.maxLinesPerFunction", "10000") + benchmark.addCase("codegen = T hugeMethodLimit = 10000") { iter => + sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + sparkSession.conf.set(SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key, "10000") f() } - benchmark.addCase(s"codegen = T maxLinesPerFunction = 1500") { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.maxLinesPerFunction", "1500") + benchmark.addCase("codegen = T hugeMethodLimit = 1500") { iter => + sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + sparkSession.conf.set(SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key, "1500") f() } benchmark.run() /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_111-b14 on Windows 7 6.1 - Intel64 Family 6 Model 58 Stepping 9, GenuineIntel - max function length of wholestagecodegen: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ---------------------------------------------------------------------------------------------- - codegen = F 462 / 533 1.4 704.4 1.0X - codegen = T maxLinesPerFunction = 10000 3444 / 3447 0.2 5255.3 0.1X - codegen = T maxLinesPerFunction = 1500 447 / 478 1.5 682.1 1.0X + Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2 + Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + + max function bytecode size: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + codegen = F 709 / 803 0.9 1082.1 1.0X + codegen = T hugeMethodLimit = 10000 3485 / 3548 0.2 5317.7 0.2X + codegen = T hugeMethodLimit = 1500 636 / 701 1.0 969.9 1.1X */ } From bb035f1ee5cdf88e476b7ed83d59140d669fbe12 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 4 Oct 2017 13:13:51 -0700 Subject: [PATCH 665/779] [SPARK-22169][SQL] support byte length literal as identifier ## What changes were proposed in this pull request? By definition the table name in Spark can be something like `123x`, `25a`, etc., with exceptions for literals like `12L`, `23BD`, etc. However, Spark SQL has a special byte length literal, which stops users to use digits followed by `b`, `k`, `m`, `g` as identifiers. byte length literal is not a standard sql literal and is only used in the `tableSample` parser rule. This PR move the parsing of byte length literal from lexer to parser, so that users can use it as identifiers. ## How was this patch tested? regression test Author: Wenchen Fan Closes #19392 from cloud-fan/parser-bug. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 25 +++++++----------- .../sql/catalyst/catalog/SessionCatalog.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 26 +++++++++++++------ .../sql/catalyst/parser/PlanParserSuite.scala | 1 + .../execution/command/DDLParserSuite.scala | 19 ++++++++++++++ 5 files changed, 49 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index d0a54288780ea..17c8404f8a79c 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -25,7 +25,7 @@ grammar SqlBase; * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. - * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is folllowed + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' * which is not a digit or letter or underscore. */ @@ -40,10 +40,6 @@ grammar SqlBase; } } -tokens { - DELIMITER -} - singleStatement : statement EOF ; @@ -447,12 +443,15 @@ joinCriteria ; sample - : TABLESAMPLE '(' - ( (negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT) - | (expression sampleType=ROWS) - | sampleType=BYTELENGTH_LITERAL - | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON (identifier | qualifiedName '(' ')'))?)) - ')' + : TABLESAMPLE '(' sampleMethod? ')' + ; + +sampleMethod + : negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) PERCENTLIT #sampleByPercentile + | expression ROWS #sampleByRows + | sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE + (ON (identifier | qualifiedName '(' ')'))? #sampleByBucket + | bytes=expression #sampleByBytes ; identifierList @@ -1004,10 +1003,6 @@ TINYINT_LITERAL : DIGIT+ 'Y' ; -BYTELENGTH_LITERAL - : DIGIT+ ('B' | 'K' | 'M' | 'G') - ; - INTEGER_VALUE : DIGIT+ ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 6ba9ee5446a01..95bc3d674b4f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -99,7 +99,7 @@ class SessionCatalog( protected var currentDb: String = formatDatabaseName(DEFAULT_DATABASE) /** - * Checks if the given name conforms the Hive standard ("[a-zA-z_0-9]+"), + * Checks if the given name conforms the Hive standard ("[a-zA-Z_0-9]+"), * i.e. if this name only contains characters, numbers, and _. * * This method is intended to have the same behavior of diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 85b492e83446e..ce367145bc637 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -699,20 +699,30 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query) } - ctx.sampleType.getType match { - case SqlBaseParser.ROWS => + if (ctx.sampleMethod() == null) { + throw new ParseException("TABLESAMPLE does not accept empty inputs.", ctx) + } + + ctx.sampleMethod() match { + case ctx: SampleByRowsContext => Limit(expression(ctx.expression), query) - case SqlBaseParser.PERCENTLIT => + case ctx: SampleByPercentileContext => val fraction = ctx.percentage.getText.toDouble val sign = if (ctx.negativeSign == null) 1 else -1 sample(sign * fraction / 100.0d) - case SqlBaseParser.BYTELENGTH_LITERAL => - throw new ParseException( - "TABLESAMPLE(byteLengthLiteral) is not supported", ctx) + case ctx: SampleByBytesContext => + val bytesStr = ctx.bytes.getText + if (bytesStr.matches("[0-9]+[bBkKmMgG]")) { + throw new ParseException("TABLESAMPLE(byteLengthLiteral) is not supported", ctx) + } else { + throw new ParseException( + bytesStr + " is not a valid byte length literal, " + + "expected syntax: DIGIT+ ('B' | 'K' | 'M' | 'G')", ctx) + } - case SqlBaseParser.BUCKET if ctx.ON != null => + case ctx: SampleByBucketContext if ctx.ON() != null => if (ctx.identifier != null) { throw new ParseException( "TABLESAMPLE(BUCKET x OUT OF y ON colname) is not supported", ctx) @@ -721,7 +731,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging "TABLESAMPLE(BUCKET x OUT OF y ON function) is not supported", ctx) } - case SqlBaseParser.BUCKET => + case ctx: SampleByBucketContext => sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 306e6f2cfbd37..d34a83c42c67e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -110,6 +110,7 @@ class PlanParserSuite extends AnalysisTest { assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b))) assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) assertEqual("select from tbl", OneRowRelation().select('from.as("tbl"))) + assertEqual("select a from 1k.2m", table("1k", "2m").select('a)) } test("reverse select query") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index fa5172ca8a3e7..eb7c33590b602 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -525,6 +525,25 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { assert(e.message.contains("you can only specify one of them.")) } + test("create table - byte length literal table name") { + val sql = "CREATE TABLE 1m.2g(a INT) USING parquet" + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("2g", Some("1m")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType().add("a", IntegerType), + provider = Some("parquet")) + + parser.parsePlan(sql) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + test("insert overwrite directory") { val v1 = "INSERT OVERWRITE DIRECTORY '/tmp/file' USING parquet SELECT 1 as a" parser.parsePlan(v1) match { From 969ffd631746125eb2b83722baf6f6e7ddd2092c Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 4 Oct 2017 19:25:22 -0700 Subject: [PATCH 666/779] [SPARK-22187][SS] Update unsaferow format for saved state such that we can set timeouts when state is null ## What changes were proposed in this pull request? Currently, the group state of user-defined-type is encoded as top-level columns in the UnsafeRows stores in the state store. The timeout timestamp is also saved as (when needed) as the last top-level column. Since the group state is serialized to top-level columns, you cannot save "null" as a value of state (setting null in all the top-level columns is not equivalent). So we don't let the user set the timeout without initializing the state for a key. Based on user experience, this leads to confusion. This PR is to change the row format such that the state is saved as nested columns. This would allow the state to be set to null, and avoid these confusing corner cases. ## How was this patch tested? Refactored tests. Author: Tathagata Das Closes #19416 from tdas/SPARK-22187. --- .../FlatMapGroupsWithStateExec.scala | 133 +++------------ .../FlatMapGroupsWithState_StateManager.scala | 153 ++++++++++++++++++ .../FlatMapGroupsWithStateSuite.scala | 130 ++++++++------- 3 files changed, 246 insertions(+), 170 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index ab690fd5fbbca..aab06d611a5ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -23,10 +23,8 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} -import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.CompletionIterator /** @@ -62,26 +60,7 @@ case class FlatMapGroupsWithStateExec( import GroupStateImpl._ private val isTimeoutEnabled = timeoutConf != NoTimeout - private val timestampTimeoutAttribute = - AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() - private val stateAttributes: Seq[Attribute] = { - val encSchemaAttribs = stateEncoder.schema.toAttributes - if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs - } - // Get the serializer for the state, taking into account whether we need to save timestamps - private val stateSerializer = { - val encoderSerializer = stateEncoder.namedExpressions - if (isTimeoutEnabled) { - encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) - } else { - encoderSerializer - } - } - // Get the deserializer for the state. Note that this must be done in the driver, as - // resolving and binding of deserializer expressions to the encoded type can be safely done - // only in the driver. - private val stateDeserializer = stateEncoder.resolveAndBind().deserializer - + val stateManager = new FlatMapGroupsWithState_StateManager(stateEncoder, isTimeoutEnabled) /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = @@ -109,11 +88,11 @@ case class FlatMapGroupsWithStateExec( child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, groupingAttributes.toStructType, - stateAttributes.toStructType, + stateManager.stateSchema, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - val updater = new StateStoreUpdater(store) + val processor = new InputProcessor(store) // If timeout is based on event time, then filter late data based on watermark val filteredIter = watermarkPredicateForData match { @@ -128,7 +107,7 @@ case class FlatMapGroupsWithStateExec( // all the data has been processed. This is to ensure that the timeout information of all // the keys with data is updated before they are processed for timeouts. val outputIterator = - updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys() + processor.processNewData(filteredIter) ++ processor.processTimedOutState() // Return an iterator of all the rows generated by all the keys, such that when fully // consumed, all the state updates will be committed by the state store @@ -143,7 +122,7 @@ case class FlatMapGroupsWithStateExec( } /** Helper class to update the state store */ - class StateStoreUpdater(store: StateStore) { + class InputProcessor(store: StateStore) { // Converters for translating input keys, values, output data between rows and Java objects private val getKeyObj = @@ -152,14 +131,6 @@ case class FlatMapGroupsWithStateExec( ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) - // Converters for translating state between rows and Java objects - private val getStateObjFromRow = ObjectOperator.deserializeRowToObject( - stateDeserializer, stateAttributes) - private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) - - // Index of the additional metadata fields in the state row - private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute) - // Metrics private val numUpdatedStateRows = longMetric("numUpdatedStateRows") private val numOutputRows = longMetric("numOutputRows") @@ -168,20 +139,19 @@ case class FlatMapGroupsWithStateExec( * For every group, get the key, values and corresponding state and call the function, * and return an iterator of rows */ - def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) groupedIter.flatMap { case (keyRow, valueRowIter) => val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] callFunctionAndUpdateState( - keyUnsafeRow, + stateManager.getState(store, keyUnsafeRow), valueRowIter, - store.get(keyUnsafeRow), hasTimedOut = false) } } /** Find the groups that have timeout set and are timing out right now, and call the function */ - def updateStateForTimedOutKeys(): Iterator[InternalRow] = { + def processTimedOutState(): Iterator[InternalRow] = { if (isTimeoutEnabled) { val timeoutThreshold = timeoutConf match { case ProcessingTimeTimeout => batchTimestampMs.get @@ -190,12 +160,11 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") } - val timingOutKeys = store.getRange(None, None).filter { rowPair => - val timeoutTimestamp = getTimeoutTimestamp(rowPair.value) - timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold + val timingOutKeys = stateManager.getAllState(store).filter { state => + state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold } - timingOutKeys.flatMap { rowPair => - callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true) + timingOutKeys.flatMap { stateData => + callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true) } } else Iterator.empty } @@ -205,72 +174,43 @@ case class FlatMapGroupsWithStateExec( * iterator. Note that the store updating is lazy, that is, the store will be updated only * after the returned iterator is fully consumed. * - * @param keyRow Row representing the key, cannot be null + * @param stateData All the data related to the state to be updated * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty - * @param prevStateRow Row representing the previous state, can be null * @param hasTimedOut Whether this function is being called for a key timeout */ private def callFunctionAndUpdateState( - keyRow: UnsafeRow, + stateData: FlatMapGroupsWithState_StateData, valueRowIter: Iterator[InternalRow], - prevStateRow: UnsafeRow, hasTimedOut: Boolean): Iterator[InternalRow] = { - val keyObj = getKeyObj(keyRow) // convert key to objects + val keyObj = getKeyObj(stateData.keyRow) // convert key to objects val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects - val stateObj = getStateObj(prevStateRow) - val keyedState = GroupStateImpl.createForStreaming( - Option(stateObj), + val groupState = GroupStateImpl.createForStreaming( + Option(stateData.stateObj), batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), timeoutConf, hasTimedOut) // Call function, get the returned objects and convert them to rows - val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj => + val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj => numOutputRows += 1 getOutputRow(obj) } // When the iterator is consumed, then write changes to state def onIteratorCompletion: Unit = { - - val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp - // If the state has not yet been set but timeout has been set, then - // we have to generate a row to save the timeout. However, attempting serialize - // null using case class encoder throws - - // java.lang.NullPointerException: Null value appeared in non-nullable field: - // If the schema is inferred from a Scala tuple / case class, or a Java bean, please - // try to use scala.Option[_] or other nullable types. - if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) { - throw new IllegalStateException( - "Cannot set timeout when state is not defined, that is, state has not been" + - "initialized or has been removed") - } - - if (keyedState.hasRemoved) { - store.remove(keyRow) + if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) { + stateManager.removeState(store, stateData.keyRow) numUpdatedStateRows += 1 - } else { - val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow) - val stateRowToWrite = if (keyedState.hasUpdated) { - getStateRow(keyedState.get) - } else { - prevStateRow - } - - val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp - val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged + val currentTimeoutTimestamp = groupState.getTimeoutTimestamp + val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp + val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged if (shouldWriteState) { - if (stateRowToWrite == null) { - // This should never happen because checks in GroupStateImpl should avoid cases - // where empty state would need to be written - throw new IllegalStateException("Attempting to write empty state") - } - setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp) - store.put(keyRow, stateRowToWrite) + val updatedStateObj = if (groupState.exists) groupState.get else null + stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp) numUpdatedStateRows += 1 } } @@ -279,28 +219,5 @@ case class FlatMapGroupsWithStateExec( // Return an iterator of rows such that fully consumed, the updated state value will be saved CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) } - - /** Returns the state as Java object if defined */ - def getStateObj(stateRow: UnsafeRow): Any = { - if (stateRow != null) getStateObjFromRow(stateRow) else null - } - - /** Returns the row for an updated state */ - def getStateRow(obj: Any): UnsafeRow = { - assert(obj != null) - getStateRowFromObj(obj) - } - - /** Returns the timeout timestamp of a state row is set */ - def getTimeoutTimestamp(stateRow: UnsafeRow): Long = { - if (isTimeoutEnabled && stateRow != null) { - stateRow.getLong(timeoutTimestampIndex) - } else NO_TIMESTAMP - } - - /** Set the timestamp in a state row */ - def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { - if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps) - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala new file mode 100644 index 0000000000000..d077836da847c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, CaseWhen, CreateNamedStruct, GetStructField, IsNull, Literal, UnsafeRow} +import org.apache.spark.sql.execution.ObjectOperator +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.types.{IntegerType, LongType, StructType} + + +/** + * Class to serialize/write/read/deserialize state for + * [[org.apache.spark.sql.execution.streaming.FlatMapGroupsWithStateExec]]. + */ +class FlatMapGroupsWithState_StateManager( + stateEncoder: ExpressionEncoder[Any], + shouldStoreTimestamp: Boolean) extends Serializable { + + /** Schema of the state rows saved in the state store */ + val stateSchema = { + val schema = new StructType().add("groupState", stateEncoder.schema, nullable = true) + if (shouldStoreTimestamp) schema.add("timeoutTimestamp", LongType) else schema + } + + /** Get deserialized state and corresponding timeout timestamp for a key */ + def getState(store: StateStore, keyRow: UnsafeRow): FlatMapGroupsWithState_StateData = { + val stateRow = store.get(keyRow) + stateDataForGets.withNew( + keyRow, stateRow, getStateObj(stateRow), getTimestamp(stateRow)) + } + + /** Put state and timeout timestamp for a key */ + def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timestamp: Long): Unit = { + val stateRow = getStateRow(state) + setTimestamp(stateRow, timestamp) + store.put(keyRow, stateRow) + } + + /** Removed all information related to a key */ + def removeState(store: StateStore, keyRow: UnsafeRow): Unit = { + store.remove(keyRow) + } + + /** Get all the keys and corresponding state rows in the state store */ + def getAllState(store: StateStore): Iterator[FlatMapGroupsWithState_StateData] = { + val stateDataForGetAllState = FlatMapGroupsWithState_StateData() + store.getRange(None, None).map { pair => + stateDataForGetAllState.withNew( + pair.key, pair.value, getStateObjFromRow(pair.value), getTimestamp(pair.value)) + } + } + + // Ordinals of the information stored in the state row + private lazy val nestedStateOrdinal = 0 + private lazy val timeoutTimestampOrdinal = 1 + + // Get the serializer for the state, taking into account whether we need to save timestamps + private val stateSerializer = { + val nestedStateExpr = CreateNamedStruct( + stateEncoder.namedExpressions.flatMap(e => Seq(Literal(e.name), e))) + if (shouldStoreTimestamp) { + Seq(nestedStateExpr, Literal(GroupStateImpl.NO_TIMESTAMP)) + } else { + Seq(nestedStateExpr) + } + } + + // Get the deserializer for the state. Note that this must be done in the driver, as + // resolving and binding of deserializer expressions to the encoded type can be safely done + // only in the driver. + private val stateDeserializer = { + val boundRefToNestedState = BoundReference(nestedStateOrdinal, stateEncoder.schema, true) + val deser = stateEncoder.resolveAndBind().deserializer.transformUp { + case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal) + } + CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser).toCodegen() + } + + // Converters for translating state between rows and Java objects + private lazy val getStateObjFromRow = ObjectOperator.deserializeRowToObject( + stateDeserializer, stateSchema.toAttributes) + private lazy val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) + + // Reusable instance for returning state information + private lazy val stateDataForGets = FlatMapGroupsWithState_StateData() + + /** Returns the state as Java object if defined */ + private def getStateObj(stateRow: UnsafeRow): Any = { + if (stateRow == null) null + else getStateObjFromRow(stateRow) + } + + /** Returns the row for an updated state */ + private def getStateRow(obj: Any): UnsafeRow = { + val row = getStateRowFromObj(obj) + if (obj == null) { + row.setNullAt(nestedStateOrdinal) + } + row + } + + /** Returns the timeout timestamp of a state row is set */ + private def getTimestamp(stateRow: UnsafeRow): Long = { + if (shouldStoreTimestamp && stateRow != null) { + stateRow.getLong(timeoutTimestampOrdinal) + } else NO_TIMESTAMP + } + + /** Set the timestamp in a state row */ + private def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { + if (shouldStoreTimestamp) stateRow.setLong(timeoutTimestampOrdinal, timeoutTimestamps) + } +} + +/** + * Class to capture deserialized state and timestamp return by the state manager. + * This is intended for reuse. + */ +case class FlatMapGroupsWithState_StateData( + var keyRow: UnsafeRow = null, + var stateRow: UnsafeRow = null, + var stateObj: Any = null, + var timeoutTimestamp: Long = -1) { + def withNew( + newKeyRow: UnsafeRow, + newStateRow: UnsafeRow, + newStateObj: Any, + newTimeout: Long): this.type = { + keyRow = newKeyRow + stateRow = newStateRow + stateObj = newStateObj + timeoutTimestamp = newTimeout + this + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 9d74a5c701ef1..d2e8beb2f5290 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -289,13 +289,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } } - // Values used for testing StateStoreUpdater + // Values used for testing InputProcessor val currentBatchTimestamp = 1000 val currentBatchWatermark = 1000 val beforeTimeoutThreshold = 999 val afterTimeoutThreshold = 1001 - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout + // Tests for InputProcessor.processNewData() when timeout = NoTimeout for (priorState <- Seq(None, Some(0))) { val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" val testName = s"NoTimeout - $priorStateStr - " @@ -322,7 +322,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf expectedState = None) // should be removed } - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout != NoTimeout + // Tests for InputProcessor.processTimedOutState() when timeout != NoTimeout for (priorState <- Seq(None, Some(0))) { for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { var testName = "" @@ -365,6 +365,18 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf expectedState = None) // state should be removed } + // Tests with ProcessingTimeTimeout + if (priorState == None) { + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - timeout updated without initializing state", + stateUpdates = state => { state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) + } + testStateUpdateWithData( s"ProcessingTimeTimeout - $testName - state and timeout duration updated", stateUpdates = @@ -375,10 +387,36 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf expectedState = Some(5), // state should change expectedTimeoutTimestamp = currentBatchTimestamp + 5000) // timestamp should change + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - timeout updated after state removed", + stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) + + // Tests with EventTimeTimeout + + if (priorState == None) { + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout without init state not allowed", + stateUpdates = state => { + state.setTimeoutTimestamp(10000) + }, + timeoutConf = EventTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = 10000) + } + testStateUpdateWithData( s"EventTimeTimeout - $testName - state and timeout timestamp updated", stateUpdates = - (state: GroupState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) }, + (state: GroupState[Int]) => { + state.update(5); state.setTimeoutTimestamp(5000) + }, timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, @@ -397,50 +435,23 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = Some(5), // state should change - expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update - } - } - - // Currently disallowed cases for StateStoreUpdater.updateStateForKeysWithData(), - // Try to remove these cases in the future - for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { - val testName = - if (priorTimeoutTimestamp != NO_TIMESTAMP) "prior timeout set" else "no prior timeout" - testStateUpdateWithData( - s"ProcessingTimeTimeout - $testName - setting timeout without init state not allowed", - stateUpdates = state => { state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"ProcessingTimeTimeout - $testName - setting timeout with state removal not allowed", - stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = Some(5), - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"EventTimeTimeout - $testName - setting timeout without init state not allowed", - stateUpdates = state => { state.setTimeoutTimestamp(10000) }, - timeoutConf = EventTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update - testStateUpdateWithData( - s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", - stateUpdates = state => { state.remove(); state.setTimeoutTimestamp(10000) }, - timeoutConf = EventTimeTimeout, - priorState = Some(5), - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", + stateUpdates = state => { + state.remove(); state.setTimeoutTimestamp(10000) + }, + timeoutConf = EventTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = 10000) + } } - // Tests for StateStoreUpdater.updateStateForTimedOutKeys() + // Tests for InputProcessor.processTimedOutState() val preTimeoutState = Some(5) for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { testStateUpdateWithTimeout( @@ -924,7 +935,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) { return // there can be no prior timestamp, when there is no prior state } - test(s"StateStoreUpdater - updates with data - $testName") { + test(s"InputProcessor - process new data - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === false, "hasTimedOut not false") assert(values.nonEmpty, "Some value is expected") @@ -946,7 +957,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf expectedState: Option[Int], expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { - test(s"StateStoreUpdater - updates for timeout - $testName") { + test(s"InputProcessor - process timed out state - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === true, "hasTimedOut not true") assert(values.isEmpty, "values not empty") @@ -973,21 +984,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val store = newStateStore() val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec( mapGroupsFunc, timeoutConf, currentBatchTimestamp) - val updater = new mapGroupsSparkPlan.StateStoreUpdater(store) + val inputProcessor = new mapGroupsSparkPlan.InputProcessor(store) + val stateManager = mapGroupsSparkPlan.stateManager val key = intToRow(0) // Prepare store with prior state configs - if (priorState.nonEmpty) { - val row = updater.getStateRow(priorState.get) - updater.setTimeoutTimestamp(row, priorTimeoutTimestamp) - store.put(key.copy(), row.copy()) + if (priorState.nonEmpty || priorTimeoutTimestamp != NO_TIMESTAMP) { + stateManager.putState(store, key, priorState.orNull, priorTimeoutTimestamp) } // Call updating function to update state store def callFunction() = { val returnedIter = if (testTimeoutUpdates) { - updater.updateStateForTimedOutKeys() + inputProcessor.processTimedOutState() } else { - updater.updateStateForKeysWithData(Iterator(key)) + inputProcessor.processNewData(Iterator(key)) } returnedIter.size // consume the iterator to force state updates } @@ -998,15 +1008,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } else { // Call function to update and verify updated state in store callFunction() - val updatedStateRow = store.get(key) - assert( - Option(updater.getStateObj(updatedStateRow)).map(_.toString.toInt) === expectedState, + val updatedState = stateManager.getState(store, key) + assert(Option(updatedState.stateObj).map(_.toString.toInt) === expectedState, "final state not as expected") - if (updatedStateRow != null) { - assert( - updater.getTimeoutTimestamp(updatedStateRow) === expectedTimeoutTimestamp, - "final timeout timestamp not as expected") - } + assert(updatedState.timeoutTimestamp === expectedTimeoutTimestamp, + "final timeout timestamp not as expected") } } From c8affec21c91d638009524955515fc143ad86f20 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 4 Oct 2017 20:58:48 -0700 Subject: [PATCH 667/779] [SPARK-22203][SQL] Add job description for file listing Spark jobs ## What changes were proposed in this pull request? The user may be confused about some 10000-tasks jobs. We can add a job description for these jobs so that the user can figure it out. ## How was this patch tested? The new unit test. Before: screen shot 2017-10-04 at 3 22 09 pm After: screen shot 2017-10-04 at 3 13 51 pm Author: Shixiong Zhu Closes #19432 from zsxwing/SPARK-22203. --- .../datasources/InMemoryFileIndex.scala | 85 +++++++++++-------- .../sql/test/DataFrameReaderWriterSuite.scala | 31 +++++++ 2 files changed, 81 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 203d449717512..318ada0ceefc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} +import org.apache.spark.SparkContext import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.SparkSession @@ -187,42 +188,56 @@ object InMemoryFileIndex extends Logging { // in case of large #defaultParallelism. val numParallelism = Math.min(paths.size, parallelPartitionDiscoveryParallelism) - val statusMap = sparkContext - .parallelize(serializedPaths, numParallelism) - .mapPartitions { pathStrings => - val hadoopConf = serializableConfiguration.value - pathStrings.map(new Path(_)).toSeq.map { path => - (path, listLeafFiles(path, hadoopConf, filter, None)) - }.iterator - }.map { case (path, statuses) => - val serializableStatuses = statuses.map { status => - // Turn FileStatus into SerializableFileStatus so we can send it back to the driver - val blockLocations = status match { - case f: LocatedFileStatus => - f.getBlockLocations.map { loc => - SerializableBlockLocation( - loc.getNames, - loc.getHosts, - loc.getOffset, - loc.getLength) - } - - case _ => - Array.empty[SerializableBlockLocation] - } - - SerializableFileStatus( - status.getPath.toString, - status.getLen, - status.isDirectory, - status.getReplication, - status.getBlockSize, - status.getModificationTime, - status.getAccessTime, - blockLocations) + val previousJobDescription = sparkContext.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) + val statusMap = try { + val description = paths.size match { + case 0 => + s"Listing leaf files and directories 0 paths" + case 1 => + s"Listing leaf files and directories for 1 path:
${paths(0)}" + case s => + s"Listing leaf files and directories for $s paths:
${paths(0)}, ..." } - (path.toString, serializableStatuses) - }.collect() + sparkContext.setJobDescription(description) + sparkContext + .parallelize(serializedPaths, numParallelism) + .mapPartitions { pathStrings => + val hadoopConf = serializableConfiguration.value + pathStrings.map(new Path(_)).toSeq.map { path => + (path, listLeafFiles(path, hadoopConf, filter, None)) + }.iterator + }.map { case (path, statuses) => + val serializableStatuses = statuses.map { status => + // Turn FileStatus into SerializableFileStatus so we can send it back to the driver + val blockLocations = status match { + case f: LocatedFileStatus => + f.getBlockLocations.map { loc => + SerializableBlockLocation( + loc.getNames, + loc.getHosts, + loc.getOffset, + loc.getLength) + } + + case _ => + Array.empty[SerializableBlockLocation] + } + + SerializableFileStatus( + status.getPath.toString, + status.getLen, + status.isDirectory, + status.getReplication, + status.getBlockSize, + status.getModificationTime, + status.getAccessTime, + blockLocations) + } + (path.toString, serializableStatuses) + }.collect() + } finally { + sparkContext.setJobDescription(previousJobDescription) + } // turn SerializableFileStatus back to Status statusMap.map { case (path, serializableStatuses) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 569bac156b531..a5d7e6257a6df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -21,10 +21,14 @@ import java.io.File import java.util.Locale import java.util.concurrent.ConcurrentLinkedQueue +import scala.collection.JavaConverters._ + import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkContext import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.internal.SQLConf @@ -775,4 +779,31 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be } } } + + test("use Spark jobs to list files") { + withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "1") { + withTempDir { dir => + val jobDescriptions = new ConcurrentLinkedQueue[String]() + val jobListener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobDescriptions.add(jobStart.properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) + } + } + sparkContext.addSparkListener(jobListener) + try { + spark.range(0, 3).map(i => (i, i)) + .write.partitionBy("_1").mode("overwrite").parquet(dir.getCanonicalPath) + // normal file paths + checkDatasetUnorderly( + spark.read.parquet(dir.getCanonicalPath).as[(Long, Long)], + 0L -> 0L, 1L -> 1L, 2L -> 2L) + sparkContext.listenerBus.waitUntilEmpty(10000) + assert(jobDescriptions.asScala.toList.exists( + _.contains("Listing leaf files and directories for 3 paths"))) + } finally { + sparkContext.removeSparkListener(jobListener) + } + } + } + } } From ae61f187aa0471242c046fdeac6ed55b9b98a3f6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 Oct 2017 23:36:18 +0900 Subject: [PATCH 668/779] [SPARK-22206][SQL][SPARKR] gapply in R can't work on empty grouping columns ## What changes were proposed in this pull request? Looks like `FlatMapGroupsInRExec.requiredChildDistribution` didn't consider empty grouping attributes. It should be a problem when running `EnsureRequirements` and `gapply` in R can't work on empty grouping columns. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #19436 from viirya/fix-flatmapinr-distribution. --- R/pkg/tests/fulltests/test_sparkSQL.R | 5 +++++ .../main/scala/org/apache/spark/sql/execution/objects.scala | 6 +++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 7f781f2f66a7f..bbea25bc4da5c 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -3075,6 +3075,11 @@ test_that("gapply() and gapplyCollect() on a DataFrame", { df1Collect <- gapplyCollect(df, list("a"), function(key, x) { x }) expect_identical(df1Collect, expected) + # gapply on empty grouping columns. + df1 <- gapply(df, c(), function(key, x) { x }, schema(df)) + actual <- collect(df1) + expect_identical(actual, expected) + # Computes the sum of second column by grouping on the first and third columns # and checks if the sum is larger than 2 schemas <- list(structType(structField("a", "integer"), structField("e", "boolean")), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 5a3fcad38888e..c68975bea490f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -394,7 +394,11 @@ case class FlatMapGroupsInRExec( override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(groupingAttributes) :: Nil + if (groupingAttributes.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingAttributes) :: Nil + } override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(groupingAttributes.map(SortOrder(_, Ascending))) From 83488cc3180ca18f829516f550766efb3095881e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 5 Oct 2017 23:33:49 -0700 Subject: [PATCH 669/779] [SPARK-21871][SQL] Fix infinite loop when bytecode size is larger than spark.sql.codegen.hugeMethodLimit ## What changes were proposed in this pull request? When exceeding `spark.sql.codegen.hugeMethodLimit`, the runtime fallbacks to the Volcano iterator solution. This could cause an infinite loop when `FileSourceScanExec` can use the columnar batch to read the data. This PR is to fix the issue. ## How was this patch tested? Added a test Author: gatorsmile Closes #19440 from gatorsmile/testt. --- .../sql/execution/WholeStageCodegenExec.scala | 12 ++++++---- .../execution/WholeStageCodegenSuite.scala | 23 +++++++++++++++++-- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 9073d599ac43d..1aaaf896692d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -392,12 +392,16 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co // Check if compiled code has a too large function if (maxCodeSize > sqlContext.conf.hugeMethodLimit) { - logWarning(s"Found too long generated codes and JIT optimization might not work: " + - s"the bytecode size was $maxCodeSize, this value went over the limit " + + logInfo(s"Found too long generated codes and JIT optimization might not work: " + + s"the bytecode size ($maxCodeSize) is above the limit " + s"${sqlContext.conf.hugeMethodLimit}, and the whole-stage codegen was disabled " + s"for this plan. To avoid this, you can raise the limit " + - s"${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}:\n$treeString") - return child.execute() + s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString") + child match { + // The fallback solution of batch file source scan still uses WholeStageCodegenExec + case f: FileSourceScanExec if f.supportsBatch => // do nothing + case _ => return child.execute() + } } val references = ctx.references.toArray diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index aaa77b3ee6201..098e4cfeb15b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Row +import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec @@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} -class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { +class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { test("range/filter should be combined") { val df = spark.range(10).filter("id = 1").selectExpr("id + 1") @@ -185,4 +185,23 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions) assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) } + + test("bytecode of batch file scan exceeds the limit of WHOLESTAGE_HUGE_METHOD_LIMIT") { + import testImplicits._ + withTempPath { dir => + val path = dir.getCanonicalPath + val df = spark.range(10).select(Seq.tabulate(201) {i => ('id + i).as(s"c$i")} : _*) + df.write.mode(SaveMode.Overwrite).parquet(path) + + withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "202", + SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key -> "2000") { + // wide table batch scan causes the byte code of codegen exceeds the limit of + // WHOLESTAGE_HUGE_METHOD_LIMIT + val df2 = spark.read.parquet(path) + val fileScan2 = df2.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get + assert(fileScan2.asInstanceOf[FileSourceScanExec].supportsBatch) + checkAnswer(df2, df) + } + } + } } From 0c03297bf0e87944f9fe0535fdae5518228e3e29 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 6 Oct 2017 15:08:28 +0100 Subject: [PATCH 670/779] [SPARK-22142][BUILD][STREAMING] Move Flume support behind a profile, take 2 ## What changes were proposed in this pull request? Move flume behind a profile, take 2. See https://github.com/apache/spark/pull/19365 for most of the back-story. This change should fix the problem by removing the examples module dependency and moving Flume examples to the module itself. It also adds deprecation messages, per a discussion on dev about deprecating for 2.3.0. ## How was this patch tested? Existing tests, which still enable flume integration. Author: Sean Owen Closes #19412 from srowen/SPARK-22142.2. --- dev/create-release/release-build.sh | 4 ++-- dev/mima | 2 +- dev/scalastyle | 1 + dev/sparktestsupport/modules.py | 20 ++++++++++++++++++- dev/test-dependencies.sh | 2 +- docs/building-spark.md | 7 +++++++ docs/streaming-flume-integration.md | 13 +++++------- examples/pom.xml | 7 ------- .../spark/examples}/JavaFlumeEventCount.java | 2 -- .../spark/examples}/FlumeEventCount.scala | 2 -- .../examples}/FlumePollingEventCount.scala | 2 -- .../spark/streaming/flume/FlumeUtils.scala | 1 + pom.xml | 13 +++++++++--- project/SparkBuild.scala | 17 ++++++++-------- python/pyspark/streaming/flume.py | 4 ++++ python/pyspark/streaming/tests.py | 16 ++++++++++++--- 16 files changed, 73 insertions(+), 40 deletions(-) rename {examples/src/main/java/org/apache/spark/examples/streaming => external/flume/src/main/java/org/apache/spark/examples}/JavaFlumeEventCount.java (98%) rename {examples/src/main/scala/org/apache/spark/examples/streaming => external/flume/src/main/scala/org/apache/spark/examples}/FlumeEventCount.scala (98%) rename {examples/src/main/scala/org/apache/spark/examples/streaming => external/flume/src/main/scala/org/apache/spark/examples}/FlumePollingEventCount.scala (98%) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 5390f5916fc0d..7e8d5c7075195 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -84,9 +84,9 @@ MVN="build/mvn --force" # Hive-specific profiles for some builds HIVE_PROFILES="-Phive -Phive-thriftserver" # Profiles for publishing snapshots and release to Maven Central -PUBLISH_PROFILES="-Pmesos -Pyarn $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" +PUBLISH_PROFILES="-Pmesos -Pyarn -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" # Profiles for building binary releases -BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Psparkr" +BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pflume -Psparkr" # Scala 2.11 only profiles for some builds SCALA_2_11_PROFILES="-Pkafka-0-8" # Scala 2.12 only profiles for some builds diff --git a/dev/mima b/dev/mima index fdb21f5007cf2..1e3ca9700bc07 100755 --- a/dev/mima +++ b/dev/mima @@ -24,7 +24,7 @@ set -e FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" -SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" +SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pflume -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | tail -n1)" OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)" diff --git a/dev/scalastyle b/dev/scalastyle index e5aa589869535..89ecc8abd6f8c 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -25,6 +25,7 @@ ERRORS=$(echo -e "q\n" \ -Pmesos \ -Pkafka-0-8 \ -Pyarn \ + -Pflume \ -Phive \ -Phive-thriftserver \ scalastyle test:scalastyle \ diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 50e14b60545af..91d5667ed1f07 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -279,6 +279,12 @@ def __hash__(self): source_file_regexes=[ "external/flume-sink", ], + build_profile_flags=[ + "-Pflume", + ], + environ={ + "ENABLE_FLUME_TESTS": "1" + }, sbt_test_goals=[ "streaming-flume-sink/test", ] @@ -291,6 +297,12 @@ def __hash__(self): source_file_regexes=[ "external/flume", ], + build_profile_flags=[ + "-Pflume", + ], + environ={ + "ENABLE_FLUME_TESTS": "1" + }, sbt_test_goals=[ "streaming-flume/test", ] @@ -302,7 +314,13 @@ def __hash__(self): dependencies=[streaming_flume, streaming_flume_sink], source_file_regexes=[ "external/flume-assembly", - ] + ], + build_profile_flags=[ + "-Pflume", + ], + environ={ + "ENABLE_FLUME_TESTS": "1" + } ) diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index c7714578bd005..58b295d4f6e00 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -29,7 +29,7 @@ export LC_ALL=C # TODO: This would be much nicer to do in SBT, once SBT supports Maven-style resolution. # NOTE: These should match those in the release publishing script -HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Phive" +HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Pflume -Phive" MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.6 diff --git a/docs/building-spark.md b/docs/building-spark.md index 57baa503259c1..98f7df155456f 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -100,6 +100,13 @@ Note: Kafka 0.8 support is deprecated as of Spark 2.3.0. Kafka 0.10 support is still automatically built. +## Building with Flume support + +Apache Flume support must be explicitly enabled with the `flume` profile. +Note: Flume support is deprecated as of Spark 2.3.0. + + ./build/mvn -Pflume -DskipTests clean package + ## Building submodules individually It's possible to build Spark sub-modules using the `mvn -pl` option. diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index a5d36da5b6de9..257a4f7d4f3ca 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -5,6 +5,8 @@ title: Spark Streaming + Flume Integration Guide [Apache Flume](https://flume.apache.org/) is a distributed, reliable, and available service for efficiently collecting, aggregating, and moving large amounts of log data. Here we explain how to configure Flume and Spark Streaming to receive data from Flume. There are two approaches to this. +**Note: Flume support is deprecated as of Spark 2.3.0.** + ## Approach 1: Flume-style Push-based Approach Flume is designed to push data between Flume agents. In this approach, Spark Streaming essentially sets up a receiver that acts an Avro agent for Flume, to which Flume can push the data. Here are the configuration steps. @@ -44,8 +46,7 @@ configuring Flume agents. val flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]) - See the [API docs](api/scala/index.html#org.apache.spark.streaming.flume.FlumeUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala). + See the [API docs](api/scala/index.html#org.apache.spark.streaming.flume.FlumeUtils$).
import org.apache.spark.streaming.flume.*; @@ -53,8 +54,7 @@ configuring Flume agents. JavaReceiverInputDStream flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]); - See the [API docs](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java). + See the [API docs](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html).
from pyspark.streaming.flume import FlumeUtils @@ -62,8 +62,7 @@ configuring Flume agents. flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]) By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. - See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/flume_wordcount.py). + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils).
@@ -162,8 +161,6 @@ configuring Flume agents. - See the Scala example [FlumePollingEventCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala). - Note that each input DStream can be configured to receive data from multiple sinks. 3. **Deploying:** This is same as the first approach. diff --git a/examples/pom.xml b/examples/pom.xml index 52a6764ae26a5..1791dbaad775e 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -34,7 +34,6 @@ examples none package - provided provided provided provided @@ -78,12 +77,6 @@ ${project.version} provided
- - org.apache.spark - spark-streaming-flume_${scala.binary.version} - ${project.version} - provided - org.apache.spark spark-streaming-kafka-0-10_${scala.binary.version} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java b/external/flume/src/main/java/org/apache/spark/examples/JavaFlumeEventCount.java similarity index 98% rename from examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java rename to external/flume/src/main/java/org/apache/spark/examples/JavaFlumeEventCount.java index 0c651049d0ffa..4e3420d9c3b06 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java +++ b/external/flume/src/main/java/org/apache/spark/examples/JavaFlumeEventCount.java @@ -48,8 +48,6 @@ public static void main(String[] args) throws Exception { System.exit(1); } - StreamingExamples.setStreamingLogLevels(); - String host = args[0]; int port = Integer.parseInt(args[1]); diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala b/external/flume/src/main/scala/org/apache/spark/examples/FlumeEventCount.scala similarity index 98% rename from examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala rename to external/flume/src/main/scala/org/apache/spark/examples/FlumeEventCount.scala index 91e52e4eff5a7..f877f79391b37 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala +++ b/external/flume/src/main/scala/org/apache/spark/examples/FlumeEventCount.scala @@ -47,8 +47,6 @@ object FlumeEventCount { System.exit(1) } - StreamingExamples.setStreamingLogLevels() - val Array(host, IntParam(port)) = args val batchInterval = Milliseconds(2000) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala b/external/flume/src/main/scala/org/apache/spark/examples/FlumePollingEventCount.scala similarity index 98% rename from examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala rename to external/flume/src/main/scala/org/apache/spark/examples/FlumePollingEventCount.scala index dd725d72c23ef..79a4027ca5bde 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala +++ b/external/flume/src/main/scala/org/apache/spark/examples/FlumePollingEventCount.scala @@ -44,8 +44,6 @@ object FlumePollingEventCount { System.exit(1) } - StreamingExamples.setStreamingLogLevels() - val Array(host, IntParam(port)) = args val batchInterval = Milliseconds(2000) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 3e3ed712f0dbf..707193a957700 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -30,6 +30,7 @@ import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream +@deprecated("Deprecated without replacement", "2.3.0") object FlumeUtils { private val DEFAULT_POLLING_PARALLELISM = 5 private val DEFAULT_POLLING_BATCH_SIZE = 1000 diff --git a/pom.xml b/pom.xml index 87a468c3a6f55..9fac8b1e53788 100644 --- a/pom.xml +++ b/pom.xml @@ -98,15 +98,13 @@ sql/core sql/hive assembly - external/flume - external/flume-sink - external/flume-assembly examples repl launcher external/kafka-0-10 external/kafka-0-10-assembly external/kafka-0-10-sql + @@ -2583,6 +2581,15 @@ + + flume + + external/flume + external/flume-sink + external/flume-assembly + + + spark-ganglia-lgpl diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index a568d264cb2db..9501eed1e906b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -43,11 +43,8 @@ object BuildCommons { "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10" ).map(ProjectRef(buildLocation, _)) - val streamingProjects@Seq( - streaming, streamingFlumeSink, streamingFlume, streamingKafka010 - ) = Seq( - "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-10" - ).map(ProjectRef(buildLocation, _)) + val streamingProjects@Seq(streaming, streamingKafka010) = + Seq("streaming", "streaming-kafka-0-10").map(ProjectRef(buildLocation, _)) val allProjects@Seq( core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, kvstore, _* @@ -56,9 +53,13 @@ object BuildCommons { "tags", "sketch", "kvstore" ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects - val optionallyEnabledProjects@Seq(mesos, yarn, streamingKafka, sparkGangliaLgpl, - streamingKinesisAsl, dockerIntegrationTests, hadoopCloud) = - Seq("mesos", "yarn", "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", + val optionallyEnabledProjects@Seq(mesos, yarn, + streamingFlumeSink, streamingFlume, + streamingKafka, sparkGangliaLgpl, streamingKinesisAsl, + dockerIntegrationTests, hadoopCloud) = + Seq("mesos", "yarn", + "streaming-flume-sink", "streaming-flume", + "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", "docker-integration-tests", "hadoop-cloud").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) = diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py index cd30483fc636a..2fed5940b31ea 100644 --- a/python/pyspark/streaming/flume.py +++ b/python/pyspark/streaming/flume.py @@ -53,6 +53,8 @@ def createStream(ssc, hostname, port, :param enableDecompression: Should netty server decompress input stream :param bodyDecoder: A function used to decode body (default is utf8_decoder) :return: A DStream object + + .. note:: Deprecated in 2.3.0 """ jlevel = ssc._sc._getJavaStorageLevel(storageLevel) helper = FlumeUtils._get_helper(ssc._sc) @@ -79,6 +81,8 @@ def createPollingStream(ssc, addresses, will result in this stream using more threads :param bodyDecoder: A function used to decode body (default is utf8_decoder) :return: A DStream object + + .. note:: Deprecated in 2.3.0 """ jlevel = ssc._sc._getJavaStorageLevel(storageLevel) hosts = [] diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 229cf53e47359..5b86c1cb2c390 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1478,7 +1478,7 @@ def search_kafka_assembly_jar(): ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/package streaming-kafka-0-8-assembly/assembly' or " - "'build/mvn package' before running this test.") + "'build/mvn -Pkafka-0-8 package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1495,7 +1495,7 @@ def search_flume_assembly_jar(): ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " - "'build/mvn package' before running this test.") + "'build/mvn -Pflume package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1516,6 +1516,9 @@ def search_kinesis_asl_assembly_jar(): return jars[0] +# Must be same as the variable and condition defined in modules.py +flume_test_environ_var = "ENABLE_FLUME_TESTS" +are_flume_tests_enabled = os.environ.get(flume_test_environ_var) == '1' # Must be same as the variable and condition defined in modules.py kafka_test_environ_var = "ENABLE_KAFKA_0_8_TESTS" are_kafka_tests_enabled = os.environ.get(kafka_test_environ_var) == '1' @@ -1538,9 +1541,16 @@ def search_kinesis_asl_assembly_jar(): os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, - FlumeStreamTests, FlumePollingStreamTests, StreamingListenerTests] + if are_flume_tests_enabled: + testcases.append(FlumeStreamTests) + testcases.append(FlumePollingStreamTests) + else: + sys.stderr.write( + "Skipped test_flume_stream (enable by setting environment variable %s=1" + % flume_test_environ_var) + if are_kafka_tests_enabled: testcases.append(KafkaStreamTests) else: From c7b46d4d8aa8da24131d79d2bfa36e8db19662e4 Mon Sep 17 00:00:00 2001 From: minixalpha Date: Fri, 6 Oct 2017 23:38:47 +0900 Subject: [PATCH 671/779] [SPARK-21877][DEPLOY, WINDOWS] Handle quotes in Windows command scripts ## What changes were proposed in this pull request? All the windows command scripts can not handle quotes in parameter. Run a windows command shell with parameter which has quotes can reproduce the bug: ``` C:\Users\meng\software\spark-2.2.0-bin-hadoop2.7> bin\spark-shell --driver-java-options " -Dfile.encoding=utf-8 " 'C:\Users\meng\software\spark-2.2.0-bin-hadoop2.7\bin\spark-shell2.cmd" --driver-java-options "' is not recognized as an internal or external command, operable program or batch file. ``` Windows recognize "--driver-java-options" as part of the command. All the Windows command script has the following code have the bug. ``` cmd /V /E /C "" %* ``` We should quote command and parameters like ``` cmd /V /E /C """ %*" ``` ## How was this patch tested? Test manually on Windows 10 and Windows 7 We can verify it by the following demo: ``` C:\Users\meng\program\demo>cat a.cmd echo off cmd /V /E /C "b.cmd" %* C:\Users\meng\program\demo>cat b.cmd echo off echo %* C:\Users\meng\program\demo>cat c.cmd echo off cmd /V /E /C ""b.cmd" %*" C:\Users\meng\program\demo>a.cmd "123" 'b.cmd" "123' is not recognized as an internal or external command, operable program or batch file. C:\Users\meng\program\demo>c.cmd "123" "123" ``` With the spark-shell.cmd example, change it to the following code will make the command execute succeed. ``` cmd /V /E /C ""%~dp0spark-shell2.cmd" %*" ``` ``` C:\Users\meng\software\spark-2.2.0-bin-hadoop2.7> bin\spark-shell --driver-java-options " -Dfile.encoding=utf-8 " Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). ... ``` Author: minixalpha Closes #19090 from minixalpha/master. --- bin/beeline.cmd | 4 +++- bin/pyspark.cmd | 4 +++- bin/run-example.cmd | 5 ++++- bin/spark-class.cmd | 4 +++- bin/spark-shell.cmd | 4 +++- bin/spark-submit.cmd | 4 +++- bin/sparkR.cmd | 4 +++- 7 files changed, 22 insertions(+), 7 deletions(-) diff --git a/bin/beeline.cmd b/bin/beeline.cmd index 02464bd088792..288059a28cd74 100644 --- a/bin/beeline.cmd +++ b/bin/beeline.cmd @@ -17,4 +17,6 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -cmd /V /E /C "%~dp0spark-class.cmd" org.apache.hive.beeline.BeeLine %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-class.cmd" org.apache.hive.beeline.BeeLine %*" diff --git a/bin/pyspark.cmd b/bin/pyspark.cmd index 72d046a4ba2cf..3dcf1d45a8189 100644 --- a/bin/pyspark.cmd +++ b/bin/pyspark.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running PySpark. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0pyspark2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0pyspark2.cmd" %*" diff --git a/bin/run-example.cmd b/bin/run-example.cmd index f9b786e92b823..efa5f81d08f7f 100644 --- a/bin/run-example.cmd +++ b/bin/run-example.cmd @@ -19,4 +19,7 @@ rem set SPARK_HOME=%~dp0.. set _SPARK_CMD_USAGE=Usage: ./bin/run-example [options] example-class [example args] -cmd /V /E /C "%~dp0spark-submit.cmd" run-example %* + +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-submit.cmd" run-example %*" diff --git a/bin/spark-class.cmd b/bin/spark-class.cmd index 3bf3d20cb57b5..b22536ab6f458 100644 --- a/bin/spark-class.cmd +++ b/bin/spark-class.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running a Spark class. To avoid polluting rem the environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0spark-class2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-class2.cmd" %*" diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd index 991423da6ab99..e734f13097d61 100644 --- a/bin/spark-shell.cmd +++ b/bin/spark-shell.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running Spark shell. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0spark-shell2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-shell2.cmd" %*" diff --git a/bin/spark-submit.cmd b/bin/spark-submit.cmd index f301606933a95..da62a8777524d 100644 --- a/bin/spark-submit.cmd +++ b/bin/spark-submit.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running Spark submit. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0spark-submit2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-submit2.cmd" %*" diff --git a/bin/sparkR.cmd b/bin/sparkR.cmd index 1e5ea6a623219..fcd172b083e1e 100644 --- a/bin/sparkR.cmd +++ b/bin/sparkR.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running SparkR. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0sparkR2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0sparkR2.cmd" %*" From 08b204fd2c731e87d3bc2cc0bccb6339ef7e3a6e Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 6 Oct 2017 12:53:35 -0700 Subject: [PATCH 672/779] [SPARK-22214][SQL] Refactor the list hive partitions code ## What changes were proposed in this pull request? In this PR we make a few changes to the list hive partitions code, to make the code more extensible. The following changes are made: 1. In `HiveClientImpl.getPartitions()`, call `client.getPartitions` instead of `shim.getAllPartitions` when `spec` is empty; 2. In `HiveTableScanExec`, previously we always call `listPartitionsByFilter` if the config `metastorePartitionPruning` is enabled, but actually, we'd better call `listPartitions` if `partitionPruningPred` is empty; 3. We should use sessionCatalog instead of SharedState.externalCatalog in `HiveTableScanExec`. ## How was this patch tested? Tested by existing test cases since this is code refactor, no regression or behavior change is expected. Author: Xingbo Jiang Closes #19444 from jiangxb1987/hivePartitions. --- .../sql/catalyst/catalog/interface.scala | 5 ++++ .../sql/hive/client/HiveClientImpl.scala | 7 +++-- .../hive/execution/HiveTableScanExec.scala | 28 +++++++++---------- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index fe2af910a0ae5..975b084aa6188 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -405,6 +405,11 @@ object CatalogTypes { * Specifications of a table partition. Mapping column name to column value. */ type TablePartitionSpec = Map[String, String] + + /** + * Initialize an empty spec. + */ + lazy val emptyTablePartitionSpec: TablePartitionSpec = Map.empty[String, String] } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 66165c7228bca..a01c312d5e497 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -638,12 +638,13 @@ private[hive] class HiveClientImpl( table: CatalogTable, spec: Option[TablePartitionSpec]): Seq[CatalogTablePartition] = withHiveState { val hiveTable = toHiveTable(table, Some(userName)) - val parts = spec match { - case None => shim.getAllPartitions(client, hiveTable).map(fromHivePartition) + val partSpec = spec match { + case None => CatalogTypes.emptyTablePartitionSpec case Some(s) => assert(s.values.forall(_.nonEmpty), s"partition spec '$s' is invalid") - client.getPartitions(hiveTable, s.asJava).asScala.map(fromHivePartition) + s } + val parts = client.getPartitions(hiveTable, partSpec.asJava).asScala.map(fromHivePartition) HiveCatalogMetrics.incrementFetchedPartitions(parts.length) parts } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 48d0b4a63e54a..4f8dab9cd6172 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -162,21 +162,19 @@ case class HiveTableScanExec( // exposed for tests @transient lazy val rawPartitions = { - val prunedPartitions = if (sparkSession.sessionState.conf.metastorePartitionPruning) { - // Retrieve the original attributes based on expression ID so that capitalization matches. - val normalizedFilters = partitionPruningPred.map(_.transform { - case a: AttributeReference => originalAttributes(a) - }) - sparkSession.sharedState.externalCatalog.listPartitionsByFilter( - relation.tableMeta.database, - relation.tableMeta.identifier.table, - normalizedFilters, - sparkSession.sessionState.conf.sessionLocalTimeZone) - } else { - sparkSession.sharedState.externalCatalog.listPartitions( - relation.tableMeta.database, - relation.tableMeta.identifier.table) - } + val prunedPartitions = + if (sparkSession.sessionState.conf.metastorePartitionPruning && + partitionPruningPred.size > 0) { + // Retrieve the original attributes based on expression ID so that capitalization matches. + val normalizedFilters = partitionPruningPred.map(_.transform { + case a: AttributeReference => originalAttributes(a) + }) + sparkSession.sessionState.catalog.listPartitionsByFilter( + relation.tableMeta.identifier, + normalizedFilters) + } else { + sparkSession.sessionState.catalog.listPartitions(relation.tableMeta.identifier) + } prunedPartitions.map(HiveClientImpl.toHivePartition(_, hiveQlTable)) } From debcbec7491d3a23b19ef149e50d2887590b6de0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 6 Oct 2017 13:10:04 -0700 Subject: [PATCH 673/779] [SPARK-21947][SS] Check and report error when monotonically_increasing_id is used in streaming query ## What changes were proposed in this pull request? `monotonically_increasing_id` doesn't work in Structured Streaming. We should throw an exception if a streaming query uses it. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #19336 from viirya/SPARK-21947. --- .../analysis/UnsupportedOperationChecker.scala | 15 ++++++++++++++- .../analysis/UnsupportedOperationsSuite.scala | 10 +++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index dee6fbe9d1514..04502d04d9509 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, MonotonicallyIncreasingID} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ @@ -129,6 +129,16 @@ object UnsupportedOperationChecker { !subplan.isStreaming || (aggs.nonEmpty && outputMode == InternalOutputModes.Complete) } + def checkUnsupportedExpressions(implicit operator: LogicalPlan): Unit = { + val unsupportedExprs = operator.expressions.flatMap(_.collect { + case m: MonotonicallyIncreasingID => m + }).distinct + if (unsupportedExprs.nonEmpty) { + throwError("Expression(s): " + unsupportedExprs.map(_.sql).mkString(", ") + + " is not supported with streaming DataFrames/Datasets") + } + } + plan.foreachUp { implicit subPlan => // Operations that cannot exists anywhere in a streaming plan @@ -323,6 +333,9 @@ object UnsupportedOperationChecker { case _ => } + + // Check if there are unsupported expressions in streaming query plan. + checkUnsupportedExpressions(subPlan) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index e5057c451d5b8..60d1351fda264 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, MonotonicallyIncreasingID, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _} @@ -614,6 +614,14 @@ class UnsupportedOperationsSuite extends SparkFunSuite { testOutputMode(Update, shouldSupportAggregation = true, shouldSupportNonAggregation = true) testOutputMode(Complete, shouldSupportAggregation = true, shouldSupportNonAggregation = false) + // Unsupported expressions in streaming plan + assertNotSupportedInStreamingPlan( + "MonotonicallyIncreasingID", + streamRelation.select(MonotonicallyIncreasingID()), + outputMode = Append, + expectedMsgs = Seq("monotonically_increasing_id")) + + /* ======================================================================================= TESTING FUNCTIONS From 2030f19511f656e9534f3fd692e622e45f9a074e Mon Sep 17 00:00:00 2001 From: Sergey Zhemzhitsky Date: Fri, 6 Oct 2017 20:43:53 -0700 Subject: [PATCH 674/779] [SPARK-21549][CORE] Respect OutputFormats with no output directory provided ## What changes were proposed in this pull request? Fix for https://issues.apache.org/jira/browse/SPARK-21549 JIRA issue. Since version 2.2 Spark does not respect OutputFormat with no output paths provided. The examples of such formats are [Cassandra OutputFormat](https://github.com/finn-no/cassandra-hadoop/blob/08dfa3a7ac727bb87269f27a1c82ece54e3f67e6/src/main/java/org/apache/cassandra/hadoop2/AbstractColumnFamilyOutputFormat.java), [Aerospike OutputFormat](https://github.com/aerospike/aerospike-hadoop/blob/master/mapreduce/src/main/java/com/aerospike/hadoop/mapreduce/AerospikeOutputFormat.java), etc. which do not have an ability to rollback the results written to an external systems on job failure. Provided output directory is required by Spark to allows files to be committed to an absolute output location, that is not the case for output formats which write data to external systems. This pull request prevents accessing `absPathStagingDir` method that causes the error described in SPARK-21549 unless there are files to rename in `addedAbsPathFiles`. ## How was this patch tested? Unit tests Author: Sergey Zhemzhitsky Closes #19294 from szhem/SPARK-21549-abs-output-commits. --- .../io/HadoopMapReduceCommitProtocol.scala | 28 ++++++++++++---- .../spark/rdd/PairRDDFunctionsSuite.scala | 33 ++++++++++++++++++- 2 files changed, 54 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index b1d07ab2c9199..a7e6859ef6b64 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -35,6 +35,9 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil * (from the newer mapreduce API, not the old mapred API). * * Unlike Hadoop's OutputCommitter, this implementation is serializable. + * + * @param jobId the job's or stage's id + * @param path the job's output path, or null if committer acts as a noop */ class HadoopMapReduceCommitProtocol(jobId: String, path: String) extends FileCommitProtocol with Serializable with Logging { @@ -57,6 +60,15 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) */ private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) + /** + * Checks whether there are files to be committed to an absolute output location. + * + * As committing and aborting a job occurs on driver, where `addedAbsPathFiles` is always null, + * it is necessary to check whether the output path is specified. Output path may not be required + * for committers not writing to distributed file systems. + */ + private def hasAbsPathFiles: Boolean = path != null + protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { val format = context.getOutputFormatClass.newInstance() // If OutputFormat is Configurable, we should set conf to it. @@ -130,17 +142,21 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) .foldLeft(Map[String, String]())(_ ++ _) logDebug(s"Committing files staged for absolute locations $filesToMove") - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) - for ((src, dst) <- filesToMove) { - fs.rename(new Path(src), new Path(dst)) + if (hasAbsPathFiles) { + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + for ((src, dst) <- filesToMove) { + fs.rename(new Path(src), new Path(dst)) + } + fs.delete(absPathStagingDir, true) } - fs.delete(absPathStagingDir, true) } override def abortJob(jobContext: JobContext): Unit = { committer.abortJob(jobContext, JobStatus.State.FAILED) - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) - fs.delete(absPathStagingDir, true) + if (hasAbsPathFiles) { + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + fs.delete(absPathStagingDir, true) + } } override def setupTask(taskContext: TaskAttemptContext): Unit = { diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 44dd955ce8690..07579c5098014 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -26,7 +26,7 @@ import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistr import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.mapred._ -import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, +import org.apache.hadoop.mapreduce.{Job => NewJob, JobContext => NewJobContext, OutputCommitter => NewOutputCommitter, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext} import org.apache.hadoop.util.Progressable @@ -568,6 +568,37 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { assert(FakeWriterWithCallback.exception.getMessage contains "failed to write") } + test("saveAsNewAPIHadoopDataset should respect empty output directory when " + + "there are no files to be committed to an absolute output location") { + val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + + val job = NewJob.getInstance(new Configuration(sc.hadoopConfiguration)) + job.setOutputKeyClass(classOf[Integer]) + job.setOutputValueClass(classOf[Integer]) + job.setOutputFormatClass(classOf[NewFakeFormat]) + val jobConfiguration = job.getConfiguration + + // just test that the job does not fail with + // java.lang.IllegalArgumentException: Can not create a Path from a null string + pairs.saveAsNewAPIHadoopDataset(jobConfiguration) + } + + test("saveAsHadoopDataset should respect empty output directory when " + + "there are no files to be committed to an absolute output location") { + val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + + val conf = new JobConf() + conf.setOutputKeyClass(classOf[Integer]) + conf.setOutputValueClass(classOf[Integer]) + conf.setOutputFormat(classOf[FakeOutputFormat]) + conf.setOutputCommitter(classOf[FakeOutputCommitter]) + + FakeOutputCommitter.ran = false + pairs.saveAsHadoopDataset(conf) + + assert(FakeOutputCommitter.ran, "OutputCommitter was never called") + } + test("lookup") { val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (5, 7))) From 5eacc3bfa9b9c1435ce04222ac7f943b5f930cf4 Mon Sep 17 00:00:00 2001 From: Kento NOZAWA Date: Sat, 7 Oct 2017 08:30:48 +0100 Subject: [PATCH 675/779] [SPARK-22156][MLLIB] Fix update equation of learning rate in Word2Vec.scala ## What changes were proposed in this pull request? Current equation of learning rate is incorrect when `numIterations` > `1`. This PR is based on [original C code](https://github.com/tmikolov/word2vec/blob/master/word2vec.c#L393). cc: mengxr ## How was this patch tested? manual tests I modified [this example code](https://spark.apache.org/docs/2.1.1/mllib-feature-extraction.html#example). ### `numIteration=1` #### Code ```scala import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel} val input = sc.textFile("data/mllib/sample_lda_data.txt").map(line => line.split(" ").toSeq) val word2vec = new Word2Vec() val model = word2vec.fit(input) val synonyms = model.findSynonyms("1", 5) for((synonym, cosineSimilarity) <- synonyms) { println(s"$synonym $cosineSimilarity") } ``` #### Result ``` 2 0.175856813788414 0 0.10971353203058243 4 0.09818313270807266 3 0.012947646901011467 9 -0.09881238639354706 ``` ### `numIteration=5` #### Code ```scala import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel} val input = sc.textFile("data/mllib/sample_lda_data.txt").map(line => line.split(" ").toSeq) val word2vec = new Word2Vec() word2vec.setNumIterations(5) val model = word2vec.fit(input) val synonyms = model.findSynonyms("1", 5) for((synonym, cosineSimilarity) <- synonyms) { println(s"$synonym $cosineSimilarity") } ``` #### Result ``` 0 0.9898583889007568 2 0.9808019399642944 4 0.9794934391975403 3 0.9506527781486511 9 -0.9065656661987305 ``` Author: Kento NOZAWA Closes #19372 from nzw0301/master. --- .../org/apache/spark/mllib/feature/Word2Vec.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 6f96813497b62..b8c306d86bace 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -353,11 +353,14 @@ class Word2Vec extends Serializable with Logging { val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) + val totalWordsCounts = numIterations * trainWordsCount + 1 var alpha = learningRate for (k <- 1 to numIterations) { val bcSyn0Global = sc.broadcast(syn0Global) val bcSyn1Global = sc.broadcast(syn1Global) + val numWordsProcessedInPreviousIterations = (k - 1) * trainWordsCount + val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) val syn0Modify = new Array[Int](vocabSize) @@ -368,11 +371,12 @@ class Word2Vec extends Serializable with Logging { var wc = wordCount if (wordCount - lastWordCount > 10000) { lwc = wordCount - // TODO: discount by iteration? - alpha = - learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) + alpha = learningRate * + (1 - (numPartitions * wordCount.toDouble + numWordsProcessedInPreviousIterations) / + totalWordsCounts) if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001 - logInfo("wordCount = " + wordCount + ", alpha = " + alpha) + logInfo(s"wordCount = ${wordCount + numWordsProcessedInPreviousIterations}, " + + s"alpha = $alpha") } wc += sentence.length var pos = 0 From c998a2ae0ea019dfb9b39cef6ddfac07c496e083 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Sun, 8 Oct 2017 12:58:39 +0100 Subject: [PATCH 676/779] [SPARK-22147][CORE] Removed redundant allocations from BlockId ## What changes were proposed in this pull request? Prior to this commit BlockId.hashCode and BlockId.equals were defined in terms of BlockId.name. This allowed the subclasses to be concise and enforced BlockId.name as a single unique identifier for a block. All subclasses override BlockId.name with an expression involving an allocation of StringBuilder and ultimatelly String. This is suboptimal since it induced unnecessary GC pressure on the dirver, see BlockManagerMasterEndpoint. The commit removes the definition of hashCode and equals from the base class. No other change is necessary since all subclasses are in fact case classes and therefore have auto-generated hashCode and equals. No change of behaviour is expected. Sidenote: you might be wondering, why did the subclasses use the base implementation and the auto-generated one? Apparently, this behaviour is documented in the spec. See this SO answer for details https://stackoverflow.com/a/44990210/262432. ## How was this patch tested? BlockIdSuite Author: Sergei Lebedev Closes #19369 from superbobry/blockid-equals-hashcode. --- .../netty/NettyBlockTransferService.scala | 2 +- .../org/apache/spark/storage/BlockId.scala | 5 -- .../org/apache/spark/storage/DiskStore.scala | 8 +-- .../BlockManagerReplicationSuite.scala | 49 ------------------- 4 files changed, 5 insertions(+), 59 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index ac4d85004bad1..6a29e18bf3cbb 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -151,7 +151,7 @@ private[spark] class NettyBlockTransferService( // Convert or copy nio buffer into array in order to serialize it. val array = JavaUtils.bufferToArray(blockData.nioByteBuffer()) - client.sendRpc(new UploadBlock(appId, execId, blockId.toString, metadata, array).toByteBuffer, + client.sendRpc(new UploadBlock(appId, execId, blockId.name, metadata, array).toByteBuffer, new RpcResponseCallback { override def onSuccess(response: ByteBuffer): Unit = { logTrace(s"Successfully uploaded block $blockId") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 524f6970992a5..a441baed2800e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -41,11 +41,6 @@ sealed abstract class BlockId { def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId] override def toString: String = name - override def hashCode: Int = name.hashCode - override def equals(other: Any): Boolean = other match { - case o: BlockId => getClass == o.getClass && name.equals(o.name) - case _ => false - } } @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 3579acf8d83d9..97abd92d4b70f 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -47,9 +47,9 @@ private[spark] class DiskStore( private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") private val maxMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapLimitForTests", Int.MaxValue.toString) - private val blockSizes = new ConcurrentHashMap[String, Long]() + private val blockSizes = new ConcurrentHashMap[BlockId, Long]() - def getSize(blockId: BlockId): Long = blockSizes.get(blockId.name) + def getSize(blockId: BlockId): Long = blockSizes.get(blockId) /** * Invokes the provided callback function to write the specific block. @@ -67,7 +67,7 @@ private[spark] class DiskStore( var threwException: Boolean = true try { writeFunc(out) - blockSizes.put(blockId.name, out.getCount) + blockSizes.put(blockId, out.getCount) threwException = false } finally { try { @@ -113,7 +113,7 @@ private[spark] class DiskStore( } def remove(blockId: BlockId): Boolean = { - blockSizes.remove(blockId.name) + blockSizes.remove(blockId) val file = diskManager.getFile(blockId.name) if (file.exists()) { val ret = file.delete() diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index dd61dcd11bcda..c2101ba828553 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -198,55 +198,6 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite } } - test("block replication - deterministic node selection") { - val blockSize = 1000 - val storeSize = 10000 - val stores = (1 to 5).map { - i => makeBlockManager(storeSize, s"store$i") - } - val storageLevel2x = StorageLevel.MEMORY_AND_DISK_2 - val storageLevel3x = StorageLevel(true, true, false, true, 3) - val storageLevel4x = StorageLevel(true, true, false, true, 4) - - def putBlockAndGetLocations(blockId: String, level: StorageLevel): Set[BlockManagerId] = { - stores.head.putSingle(blockId, new Array[Byte](blockSize), level) - val locations = master.getLocations(blockId).sortBy { _.executorId }.toSet - stores.foreach { _.removeBlock(blockId) } - master.removeBlock(blockId) - locations - } - - // Test if two attempts to 2x replication returns same set of locations - val a1Locs = putBlockAndGetLocations("a1", storageLevel2x) - assert(putBlockAndGetLocations("a1", storageLevel2x) === a1Locs, - "Inserting a 2x replicated block second time gave different locations from the first") - - // Test if two attempts to 3x replication returns same set of locations - val a2Locs3x = putBlockAndGetLocations("a2", storageLevel3x) - assert(putBlockAndGetLocations("a2", storageLevel3x) === a2Locs3x, - "Inserting a 3x replicated block second time gave different locations from the first") - - // Test if 2x replication of a2 returns a strict subset of the locations of 3x replication - val a2Locs2x = putBlockAndGetLocations("a2", storageLevel2x) - assert( - a2Locs2x.subsetOf(a2Locs3x), - "Inserting a with 2x replication gave locations that are not a subset of locations" + - s" with 3x replication [3x: ${a2Locs3x.mkString(",")}; 2x: ${a2Locs2x.mkString(",")}" - ) - - // Test if 4x replication of a2 returns a strict superset of the locations of 3x replication - val a2Locs4x = putBlockAndGetLocations("a2", storageLevel4x) - assert( - a2Locs3x.subsetOf(a2Locs4x), - "Inserting a with 4x replication gave locations that are not a superset of locations " + - s"with 3x replication [3x: ${a2Locs3x.mkString(",")}; 4x: ${a2Locs4x.mkString(",")}" - ) - - // Test if 3x replication of two different blocks gives two different sets of locations - val a3Locs3x = putBlockAndGetLocations("a3", storageLevel3x) - assert(a3Locs3x !== a2Locs3x, "Two blocks gave same locations with 3x replication") - } - test("block replication - replication failures") { /* Create a system of three block managers / stores. One of them (say, failableStore) From fe7b219ae3e8a045655a836cbb77219036ec5740 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Mon, 9 Oct 2017 14:16:25 +0800 Subject: [PATCH 677/779] [SPARK-22074][CORE] Task killed by other attempt task should not be resubmitted ## What changes were proposed in this pull request? As the detail scenario described in [SPARK-22074](https://issues.apache.org/jira/browse/SPARK-22074), unnecessary resubmitted may cause stage hanging in currently release versions. This patch add a new var in TaskInfo to mark this task killed by other attempt or not. ## How was this patch tested? Add a new UT `[SPARK-22074] Task killed by other attempt task should not be resubmitted` in TaskSetManagerSuite, this UT recreate the scenario in JIRA description, it failed without the changes in this PR and passed conversely. Author: Yuanjian Li Closes #19287 from xuanyuanking/SPARK-22074. --- .../spark/scheduler/TaskSetManager.scala | 8 +- .../org/apache/spark/scheduler/FakeTask.scala | 20 +++- .../spark/scheduler/TaskSetManagerSuite.scala | 107 ++++++++++++++++++ 3 files changed, 132 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 3bdede6743d1b..de4711f461df2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -83,6 +83,11 @@ private[spark] class TaskSetManager( val successful = new Array[Boolean](numTasks) private val numFailures = new Array[Int](numTasks) + // Set the coresponding index of Boolean var when the task killed by other attempt tasks, + // this happened while we set the `spark.speculation` to true. The task killed by others + // should not resubmit while executor lost. + private val killedByOtherAttempt: Array[Boolean] = new Array[Boolean](numTasks) + val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) private[scheduler] var tasksSuccessful = 0 @@ -729,6 +734,7 @@ private[spark] class TaskSetManager( logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " + s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " + s"as the attempt ${info.attemptNumber} succeeded on ${info.host}") + killedByOtherAttempt(index) = true sched.backend.killTask( attemptInfo.taskId, attemptInfo.executorId, @@ -915,7 +921,7 @@ private[spark] class TaskSetManager( && !isZombie) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index - if (successful(index)) { + if (successful(index) && !killedByOtherAttempt(index)) { successful(index) = false copiesRunning(index) -= 1 tasksSuccessful -= 1 diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index fe6de2bd98850..109d4a0a870b8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -19,8 +19,7 @@ package org.apache.spark.scheduler import java.util.Properties -import org.apache.spark.SparkEnv -import org.apache.spark.TaskContext +import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.executor.TaskMetrics class FakeTask( @@ -58,4 +57,21 @@ object FakeTask { } new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) } + + def createShuffleMapTaskSet( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + prefLocs: Seq[TaskLocation]*): TaskSet = { + if (prefLocs.size != 0 && prefLocs.size != numTasks) { + throw new IllegalArgumentException("Wrong number of task locations") + } + val tasks = Array.tabulate[Task[_]](numTasks) { i => + new ShuffleMapTask(stageId, stageAttemptId, null, new Partition { + override def index: Int = i + }, prefLocs(i), new Properties, + SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array()) + } + new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 5c712bd6a545b..2ce81ae27daf6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -744,6 +744,113 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(resubmittedTasks === 0) } + + test("[SPARK-22074] Task killed by other attempt task should not be resubmitted") { + val conf = new SparkConf().set("spark.speculation", "true") + sc = new SparkContext("local", "test", conf) + // Set the speculation multiplier to be 0 so speculative tasks are launched immediately + sc.conf.set("spark.speculation.multiplier", "0.0") + sc.conf.set("spark.speculation.quantile", "0.5") + sc.conf.set("spark.speculation", "true") + + var killTaskCalled = false + val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), + ("exec2", "host2"), ("exec3", "host3")) + sched.initialize(new FakeSchedulerBackend() { + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = { + // Check the only one killTask event in this case, which triggered by + // task 2.1 completed. + assert(taskId === 2) + assert(executorId === "exec3") + assert(interruptThread) + assert(reason === "another attempt succeeded") + killTaskCalled = true + } + }) + + // Keep track of the number of tasks that are resubmitted, + // so that the test can check that no tasks were resubmitted. + var resubmittedTasks = 0 + val dagScheduler = new FakeDAGScheduler(sc, sched) { + override def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: Seq[AccumulatorV2[_, _]], + taskInfo: TaskInfo): Unit = { + super.taskEnded(task, reason, result, accumUpdates, taskInfo) + reason match { + case Resubmitted => resubmittedTasks += 1 + case _ => + } + } + } + sched.setDAGScheduler(dagScheduler) + + val taskSet = FakeTask.createShuffleMapTaskSet(4, 0, 0, + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host3", "exec3")), + Seq(TaskLocation("host2", "exec2"))) + + val clock = new ManualClock() + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) + val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => + task.metrics.internalAccums + } + // Offer resources for 4 tasks to start + for ((exec, host) <- Seq( + "exec1" -> "host1", + "exec1" -> "host1", + "exec3" -> "host3", + "exec2" -> "host2")) { + val taskOption = manager.resourceOffer(exec, host, NO_PREF) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === exec) + // Add an extra assert to make sure task 2.0 is running on exec3 + if (task.index == 2) { + assert(task.attemptNumber === 0) + assert(task.executorId === "exec3") + } + } + assert(sched.startedTasks.toSet === Set(0, 1, 2, 3)) + clock.advance(1) + // Complete the 2 tasks and leave 2 task in running + for (id <- Set(0, 1)) { + manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) + assert(sched.endedTasks(id) === Success) + } + + // checkSpeculatableTasks checks that the task runtime is greater than the threshold for + // speculating. Since we use a threshold of 0 for speculation, tasks need to be running for + // > 0ms, so advance the clock by 1ms here. + clock.advance(1) + assert(manager.checkSpeculatableTasks(0)) + assert(sched.speculativeTasks.toSet === Set(2, 3)) + + // Offer resource to start the speculative attempt for the running task 2.0 + val taskOption = manager.resourceOffer("exec2", "host2", ANY) + assert(taskOption.isDefined) + val task4 = taskOption.get + assert(task4.index === 2) + assert(task4.taskId === 4) + assert(task4.executorId === "exec2") + assert(task4.attemptNumber === 1) + // Complete the speculative attempt for the running task + manager.handleSuccessfulTask(4, createTaskResult(2, accumUpdatesByTask(2))) + // Make sure schedBackend.killTask(2, "exec3", true, "another attempt succeeded") gets called + assert(killTaskCalled) + // Host 3 Losts, there's only task 2.0 on it, which killed by task 2.1 + manager.executorLost("exec3", "host3", SlaveLost()) + // Check the resubmittedTasks + assert(resubmittedTasks === 0) + } + test("speculative and noPref task should be scheduled after node-local") { sc = new SparkContext("local", "test") sched = new FakeTaskScheduler( From 98057583dd2787c0e396c2658c7dd76412f86936 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Mon, 9 Oct 2017 10:42:33 +0200 Subject: [PATCH 678/779] [SPARK-20679][ML] Support recommending for a subset of users/items in ALSModel This PR adds methods `recommendForUserSubset` and `recommendForItemSubset` to `ALSModel`. These allow recommending for a specified set of user / item ids rather than for every user / item (as in the `recommendForAllX` methods). The subset methods take a `DataFrame` as input, containing ids in the column specified by the param `userCol` or `itemCol`. The model will generate recommendations for each _unique_ id in this input dataframe. ## How was this patch tested? New unit tests in `ALSSuite` and Python doctests in `ALS`. Ran updated examples locally. Author: Nick Pentreath Closes #18748 from MLnick/als-recommend-df. --- .../spark/examples/ml/JavaALSExample.java | 9 ++ examples/src/main/python/ml/als_example.py | 9 ++ .../apache/spark/examples/ml/ALSExample.scala | 9 ++ .../apache/spark/ml/recommendation/ALS.scala | 48 +++++++++ .../spark/ml/recommendation/ALSSuite.scala | 100 ++++++++++++++++-- python/pyspark/ml/recommendation.py | 38 +++++++ 6 files changed, 205 insertions(+), 8 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java index fe4d6bc83f04a..27052be87b82e 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java @@ -118,9 +118,18 @@ public static void main(String[] args) { Dataset userRecs = model.recommendForAllUsers(10); // Generate top 10 user recommendations for each movie Dataset movieRecs = model.recommendForAllItems(10); + + // Generate top 10 movie recommendations for a specified set of users + Dataset users = ratings.select(als.getUserCol()).distinct().limit(3); + Dataset userSubsetRecs = model.recommendForUserSubset(users, 10); + // Generate top 10 user recommendations for a specified set of movies + Dataset movies = ratings.select(als.getItemCol()).distinct().limit(3); + Dataset movieSubSetRecs = model.recommendForItemSubset(movies, 10); // $example off$ userRecs.show(); movieRecs.show(); + userSubsetRecs.show(); + movieSubSetRecs.show(); spark.stop(); } diff --git a/examples/src/main/python/ml/als_example.py b/examples/src/main/python/ml/als_example.py index 1672d552eb1d5..8b7ec9c439f9f 100644 --- a/examples/src/main/python/ml/als_example.py +++ b/examples/src/main/python/ml/als_example.py @@ -60,8 +60,17 @@ userRecs = model.recommendForAllUsers(10) # Generate top 10 user recommendations for each movie movieRecs = model.recommendForAllItems(10) + + # Generate top 10 movie recommendations for a specified set of users + users = ratings.select(als.getUserCol()).distinct().limit(3) + userSubsetRecs = model.recommendForUserSubset(users, 10) + # Generate top 10 user recommendations for a specified set of movies + movies = ratings.select(als.getItemCol()).distinct().limit(3) + movieSubSetRecs = model.recommendForItemSubset(movies, 10) # $example off$ userRecs.show() movieRecs.show() + userSubsetRecs.show() + movieSubSetRecs.show() spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala index 07b15dfa178f7..8091838a2301e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala @@ -80,9 +80,18 @@ object ALSExample { val userRecs = model.recommendForAllUsers(10) // Generate top 10 user recommendations for each movie val movieRecs = model.recommendForAllItems(10) + + // Generate top 10 movie recommendations for a specified set of users + val users = ratings.select(als.getUserCol).distinct().limit(3) + val userSubsetRecs = model.recommendForUserSubset(users, 10) + // Generate top 10 user recommendations for a specified set of movies + val movies = ratings.select(als.getItemCol).distinct().limit(3) + val movieSubSetRecs = model.recommendForItemSubset(movies, 10) // $example off$ userRecs.show() movieRecs.show() + userSubsetRecs.show() + movieSubSetRecs.show() spark.stop() } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 3d5fd1794de23..a8843661c873b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -344,6 +344,21 @@ class ALSModel private[ml] ( recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems) } + /** + * Returns top `numItems` items recommended for each user id in the input data set. Note that if + * there are duplicate ids in the input dataset, only one set of recommendations per unique id + * will be returned. + * @param dataset a Dataset containing a column of user ids. The column name must match `userCol`. + * @param numItems max number of recommendations for each user. + * @return a DataFrame of (userCol: Int, recommendations), where recommendations are + * stored as an array of (itemCol: Int, rating: Float) Rows. + */ + @Since("2.3.0") + def recommendForUserSubset(dataset: Dataset[_], numItems: Int): DataFrame = { + val srcFactorSubset = getSourceFactorSubset(dataset, userFactors, $(userCol)) + recommendForAll(srcFactorSubset, itemFactors, $(userCol), $(itemCol), numItems) + } + /** * Returns top `numUsers` users recommended for each item, for all items. * @param numUsers max number of recommendations for each item @@ -355,6 +370,39 @@ class ALSModel private[ml] ( recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers) } + /** + * Returns top `numUsers` users recommended for each item id in the input data set. Note that if + * there are duplicate ids in the input dataset, only one set of recommendations per unique id + * will be returned. + * @param dataset a Dataset containing a column of item ids. The column name must match `itemCol`. + * @param numUsers max number of recommendations for each item. + * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are + * stored as an array of (userCol: Int, rating: Float) Rows. + */ + @Since("2.3.0") + def recommendForItemSubset(dataset: Dataset[_], numUsers: Int): DataFrame = { + val srcFactorSubset = getSourceFactorSubset(dataset, itemFactors, $(itemCol)) + recommendForAll(srcFactorSubset, userFactors, $(itemCol), $(userCol), numUsers) + } + + /** + * Returns a subset of a factor DataFrame limited to only those unique ids contained + * in the input dataset. + * @param dataset input Dataset containing id column to user to filter factors. + * @param factors factor DataFrame to filter. + * @param column column name containing the ids in the input dataset. + * @return DataFrame containing factors only for those ids present in both the input dataset and + * the factor DataFrame. + */ + private def getSourceFactorSubset( + dataset: Dataset[_], + factors: DataFrame, + column: String): DataFrame = { + factors + .join(dataset.select(column), factors("id") === dataset(column), joinType = "left_semi") + .select(factors("id"), factors("features")) + } + /** * Makes recommendations for all users (or items). * diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index ac7319110159b..addcd21d50aac 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -723,9 +723,9 @@ class ALSSuite val numUsers = model.userFactors.count val numItems = model.itemFactors.count val expected = Map( - 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)), - 1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)), - 2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f)) + 0 -> Seq((3, 54f), (4, 44f), (5, 42f), (6, 28f)), + 1 -> Seq((3, 39f), (5, 33f), (4, 26f), (6, 16f)), + 2 -> Seq((3, 51f), (5, 45f), (4, 30f), (6, 18f)) ) Seq(2, 4, 6).foreach { k => @@ -743,10 +743,10 @@ class ALSSuite val numUsers = model.userFactors.count val numItems = model.itemFactors.count val expected = Map( - 3 -> Array((0, 54f), (2, 51f), (1, 39f)), - 4 -> Array((0, 44f), (2, 30f), (1, 26f)), - 5 -> Array((2, 45f), (0, 42f), (1, 33f)), - 6 -> Array((0, 28f), (2, 18f), (1, 16f)) + 3 -> Seq((0, 54f), (2, 51f), (1, 39f)), + 4 -> Seq((0, 44f), (2, 30f), (1, 26f)), + 5 -> Seq((2, 45f), (0, 42f), (1, 33f)), + 6 -> Seq((0, 28f), (2, 18f), (1, 16f)) ) Seq(2, 3, 4).foreach { k => @@ -759,9 +759,93 @@ class ALSSuite } } + test("recommendForUserSubset with k <, = and > num_items") { + val spark = this.spark + import spark.implicits._ + val model = getALSModel + val numItems = model.itemFactors.count + val expected = Map( + 0 -> Seq((3, 54f), (4, 44f), (5, 42f), (6, 28f)), + 2 -> Seq((3, 51f), (5, 45f), (4, 30f), (6, 18f)) + ) + val userSubset = expected.keys.toSeq.toDF("user") + val numUsersSubset = userSubset.count + + Seq(2, 4, 6).foreach { k => + val n = math.min(k, numItems).toInt + val expectedUpToN = expected.mapValues(_.slice(0, n)) + val topItems = model.recommendForUserSubset(userSubset, k) + assert(topItems.count() == numUsersSubset) + assert(topItems.columns.contains("user")) + checkRecommendations(topItems, expectedUpToN, "item") + } + } + + test("recommendForItemSubset with k <, = and > num_users") { + val spark = this.spark + import spark.implicits._ + val model = getALSModel + val numUsers = model.userFactors.count + val expected = Map( + 3 -> Seq((0, 54f), (2, 51f), (1, 39f)), + 6 -> Seq((0, 28f), (2, 18f), (1, 16f)) + ) + val itemSubset = expected.keys.toSeq.toDF("item") + val numItemsSubset = itemSubset.count + + Seq(2, 3, 4).foreach { k => + val n = math.min(k, numUsers).toInt + val expectedUpToN = expected.mapValues(_.slice(0, n)) + val topUsers = model.recommendForItemSubset(itemSubset, k) + assert(topUsers.count() == numItemsSubset) + assert(topUsers.columns.contains("item")) + checkRecommendations(topUsers, expectedUpToN, "user") + } + } + + test("subset recommendations eliminate duplicate ids, returns same results as unique ids") { + val spark = this.spark + import spark.implicits._ + val model = getALSModel + val k = 2 + + val users = Seq(0, 1).toDF("user") + val dupUsers = Seq(0, 1, 0, 1).toDF("user") + val singleUserRecs = model.recommendForUserSubset(users, k) + val dupUserRecs = model.recommendForUserSubset(dupUsers, k) + .as[(Int, Seq[(Int, Float)])].collect().toMap + assert(singleUserRecs.count == dupUserRecs.size) + checkRecommendations(singleUserRecs, dupUserRecs, "item") + + val items = Seq(3, 4, 5).toDF("item") + val dupItems = Seq(3, 4, 5, 4, 5).toDF("item") + val singleItemRecs = model.recommendForItemSubset(items, k) + val dupItemRecs = model.recommendForItemSubset(dupItems, k) + .as[(Int, Seq[(Int, Float)])].collect().toMap + assert(singleItemRecs.count == dupItemRecs.size) + checkRecommendations(singleItemRecs, dupItemRecs, "user") + } + + test("subset recommendations on full input dataset equivalent to recommendForAll") { + val spark = this.spark + import spark.implicits._ + val model = getALSModel + val k = 2 + + val userSubset = model.userFactors.withColumnRenamed("id", "user").drop("features") + val userSubsetRecs = model.recommendForUserSubset(userSubset, k) + val allUserRecs = model.recommendForAllUsers(k).as[(Int, Seq[(Int, Float)])].collect().toMap + checkRecommendations(userSubsetRecs, allUserRecs, "item") + + val itemSubset = model.itemFactors.withColumnRenamed("id", "item").drop("features") + val itemSubsetRecs = model.recommendForItemSubset(itemSubset, k) + val allItemRecs = model.recommendForAllItems(k).as[(Int, Seq[(Int, Float)])].collect().toMap + checkRecommendations(itemSubsetRecs, allItemRecs, "user") + } + private def checkRecommendations( topK: DataFrame, - expected: Map[Int, Array[(Int, Float)]], + expected: Map[Int, Seq[(Int, Float)]], dstColName: String): Unit = { val spark = this.spark import spark.implicits._ diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index bcfb36880eb02..e8bcbe4cd34cb 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -90,6 +90,14 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha >>> item_recs.where(item_recs.item == 2)\ .select("recommendations.user", "recommendations.rating").collect() [Row(user=[2, 1, 0], rating=[4.901..., 3.981..., -0.138...])] + >>> user_subset = df.where(df.user == 2) + >>> user_subset_recs = model.recommendForUserSubset(user_subset, 3) + >>> user_subset_recs.select("recommendations.item", "recommendations.rating").first() + Row(item=[2, 1, 0], rating=[4.901..., 1.056..., -1.501...]) + >>> item_subset = df.where(df.item == 0) + >>> item_subset_recs = model.recommendForItemSubset(item_subset, 3) + >>> item_subset_recs.select("recommendations.user", "recommendations.rating").first() + Row(user=[0, 1, 2], rating=[3.910..., 2.625..., -1.501...]) >>> als_path = temp_path + "/als" >>> als.save(als_path) >>> als2 = ALS.load(als_path) @@ -414,6 +422,36 @@ def recommendForAllItems(self, numUsers): """ return self._call_java("recommendForAllItems", numUsers) + @since("2.3.0") + def recommendForUserSubset(self, dataset, numItems): + """ + Returns top `numItems` items recommended for each user id in the input data set. Note that + if there are duplicate ids in the input dataset, only one set of recommendations per unique + id will be returned. + + :param dataset: a Dataset containing a column of user ids. The column name must match + `userCol`. + :param numItems: max number of recommendations for each user + :return: a DataFrame of (userCol, recommendations), where recommendations are + stored as an array of (itemCol, rating) Rows. + """ + return self._call_java("recommendForUserSubset", dataset, numItems) + + @since("2.3.0") + def recommendForItemSubset(self, dataset, numUsers): + """ + Returns top `numUsers` users recommended for each item id in the input data set. Note that + if there are duplicate ids in the input dataset, only one set of recommendations per unique + id will be returned. + + :param dataset: a Dataset containing a column of item ids. The column name must match + `itemCol`. + :param numUsers: max number of recommendations for each item + :return: a DataFrame of (itemCol, recommendations), where recommendations are + stored as an array of (userCol, rating) Rows. + """ + return self._call_java("recommendForItemSubset", dataset, numUsers) + if __name__ == "__main__": import doctest From f31e11404d6d5ee28b574c242ecbee94f35e9370 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 9 Oct 2017 12:53:10 -0700 Subject: [PATCH 679/779] [SPARK-21568][CORE] ConsoleProgressBar should only be enabled in shells ## What changes were proposed in this pull request? This PR disables console progress bar feature in non-shell environment by overriding the configuration. ## How was this patch tested? Manual. Run the following examples with and without `spark.ui.showConsoleProgress` in order to see progress bar on master branch and this PR. **Scala Shell** ```scala spark.range(1000000000).map(_ + 1).count ``` **PySpark** ```python spark.range(10000000).rdd.map(lambda x: len(x)).count() ``` **Spark Submit** ```python from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession.builder.getOrCreate() spark.range(2000000).rdd.map(lambda row: len(row)).count() spark.stop() ``` Author: Dongjoon Hyun Closes #19061 from dongjoon-hyun/SPARK-21568. --- .../main/scala/org/apache/spark/SparkContext.scala | 2 +- .../scala/org/apache/spark/deploy/SparkSubmit.scala | 5 +++++ .../org/apache/spark/internal/config/package.scala | 5 +++++ .../org/apache/spark/deploy/SparkSubmitSuite.scala | 12 ++++++++++++ 4 files changed, 23 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cec61d85ccf38..b3cd03c0cfbe1 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -434,7 +434,7 @@ class SparkContext(config: SparkConf) extends Logging { _statusTracker = new SparkStatusTracker(this) _progressBar = - if (_conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) { + if (_conf.get(UI_SHOW_CONSOLE_PROGRESS) && !log.isInfoEnabled) { Some(new ConsoleProgressBar(this)) } else { None diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 286a4379d2040..135bbe93bf28e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -598,6 +598,11 @@ object SparkSubmit extends CommandLineUtils with Logging { } } + // In case of shells, spark.ui.showConsoleProgress can be true by default or by user. + if (isShell(args.primaryResource) && !sparkConf.contains(UI_SHOW_CONSOLE_PROGRESS)) { + sysProps(UI_SHOW_CONSOLE_PROGRESS.key) = "true" + } + // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" // For python and R files, the primary resource is already distributed as a regular file diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index d85b6a0200b8d..5278e5e0fb270 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -203,6 +203,11 @@ package object config { private[spark] val HISTORY_UI_MAX_APPS = ConfigBuilder("spark.history.ui.maxApplications").intConf.createWithDefault(Integer.MAX_VALUE) + private[spark] val UI_SHOW_CONSOLE_PROGRESS = ConfigBuilder("spark.ui.showConsoleProgress") + .doc("When true, show the progress bar in the console.") + .booleanConf + .createWithDefault(false) + private[spark] val IO_ENCRYPTION_ENABLED = ConfigBuilder("spark.io.encryption.enabled") .booleanConf .createWithDefault(false) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index ad801bf8519a6..b06f2e26a4a7a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -399,6 +399,18 @@ class SparkSubmitSuite mainClass should be ("org.apache.spark.deploy.yarn.Client") } + test("SPARK-21568 ConsoleProgressBar should be enabled only in shells") { + val clArgs1 = Seq("--class", "org.apache.spark.repl.Main", "spark-shell") + val appArgs1 = new SparkSubmitArguments(clArgs1) + val (_, _, sysProps1, _) = prepareSubmitEnvironment(appArgs1) + sysProps1(UI_SHOW_CONSOLE_PROGRESS.key) should be ("true") + + val clArgs2 = Seq("--class", "org.SomeClass", "thejar.jar") + val appArgs2 = new SparkSubmitArguments(clArgs2) + val (_, _, sysProps2, _) = prepareSubmitEnvironment(appArgs2) + sysProps2.keys should not contain UI_SHOW_CONSOLE_PROGRESS.key + } + test("launch simple application with spark-submit") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val args = Seq( From a74ec6d7bbfe185ba995dcb02d69e90a089c293e Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Mon, 9 Oct 2017 12:56:37 -0700 Subject: [PATCH 680/779] [SPARK-22218] spark shuffle services fails to update secret on app re-attempts This patch fixes application re-attempts when running spark on yarn using the external shuffle service with security on. Currently executors will fail to launch on any application re-attempt when launched on a nodemanager that had an executor from the first attempt. The reason for this is because we aren't updating the secret key after the first application attempt. The fix here is to just remove the containskey check to see if it already exists. In this way, we always add it and make sure its the most recent secret. Similarly remove the check for containsKey on the remove since its just adding extra check that isn't really needed. Note this worked before spark 2.2 because the check used to be contains (which was looking for the value) rather then containsKey, so that never matched and it was just always adding the new secret. Patch was tested on a 10 node cluster as well as added the unit test. The test ran was a wordcount where the output directory already existed. With the bug present the application attempt failed with max number of executor Failures which were all saslExceptions. With the fix present the application re-attempts fail with directory already exists or when you remove the directory between attempts the re-attemps succeed. Author: Thomas Graves Closes #19450 from tgravescs/SPARK-22218. --- .../network/sasl/ShuffleSecretManager.java | 19 +++---- .../sasl/ShuffleSecretManagerSuite.java | 55 +++++++++++++++++++ 2 files changed, 62 insertions(+), 12 deletions(-) create mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/sasl/ShuffleSecretManagerSuite.java diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java index d2d008f8a3d35..7253101f41df6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java @@ -47,12 +47,11 @@ public ShuffleSecretManager() { * fetching shuffle files written by other executors in this application. */ public void registerApp(String appId, String shuffleSecret) { - if (!shuffleSecretMap.containsKey(appId)) { - shuffleSecretMap.put(appId, shuffleSecret); - logger.info("Registered shuffle secret for application {}", appId); - } else { - logger.debug("Application {} already registered", appId); - } + // Always put the new secret information to make sure it's the most up to date. + // Otherwise we have to specifically look at the application attempt in addition + // to the applicationId since the secrets change between application attempts on yarn. + shuffleSecretMap.put(appId, shuffleSecret); + logger.info("Registered shuffle secret for application {}", appId); } /** @@ -67,12 +66,8 @@ public void registerApp(String appId, ByteBuffer shuffleSecret) { * This is called when the application terminates. */ public void unregisterApp(String appId) { - if (shuffleSecretMap.containsKey(appId)) { - shuffleSecretMap.remove(appId); - logger.info("Unregistered shuffle secret for application {}", appId); - } else { - logger.warn("Attempted to unregister application {} when it is not registered", appId); - } + shuffleSecretMap.remove(appId); + logger.info("Unregistered shuffle secret for application {}", appId); } /** diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/ShuffleSecretManagerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/ShuffleSecretManagerSuite.java new file mode 100644 index 0000000000000..46c4c33865eea --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/ShuffleSecretManagerSuite.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.nio.ByteBuffer; + +import org.junit.Test; +import static org.junit.Assert.*; + +public class ShuffleSecretManagerSuite { + static String app1 = "app1"; + static String app2 = "app2"; + static String pw1 = "password1"; + static String pw2 = "password2"; + static String pw1update = "password1update"; + static String pw2update = "password2update"; + + @Test + public void testMultipleRegisters() { + ShuffleSecretManager secretManager = new ShuffleSecretManager(); + secretManager.registerApp(app1, pw1); + assertEquals(pw1, secretManager.getSecretKey(app1)); + secretManager.registerApp(app2, ByteBuffer.wrap(pw2.getBytes())); + assertEquals(pw2, secretManager.getSecretKey(app2)); + + // now update the password for the apps and make sure it takes affect + secretManager.registerApp(app1, pw1update); + assertEquals(pw1update, secretManager.getSecretKey(app1)); + secretManager.registerApp(app2, ByteBuffer.wrap(pw2update.getBytes())); + assertEquals(pw2update, secretManager.getSecretKey(app2)); + + secretManager.unregisterApp(app1); + assertNull(secretManager.getSecretKey(app1)); + assertEquals(pw2update, secretManager.getSecretKey(app2)); + + secretManager.unregisterApp(app2); + assertNull(secretManager.getSecretKey(app2)); + assertNull(secretManager.getSecretKey(app1)); + } +} From b650ee0265477ada68220cbf286fa79906608ef5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 9 Oct 2017 13:55:55 -0700 Subject: [PATCH 681/779] [INFRA] Close stale PRs. Closes #19423 Closes #19455 From dadd13f365aad0d9228cd8b8e6d57ad32175b155 Mon Sep 17 00:00:00 2001 From: Pavel Sakun Date: Mon, 9 Oct 2017 23:00:04 +0100 Subject: [PATCH 682/779] [SPARK] Misleading error message for missing --proxy-user value Fix misleading error message when argument is expected. ## What changes were proposed in this pull request? Change message to be accurate. ## How was this patch tested? Messaging change, was tested manually. Author: Pavel Sakun Closes #19457 from pavel-sakun/patch-1. --- .../src/main/java/org/apache/spark/launcher/SparkLauncher.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index 718a368a8e731..75b8ef5ca5ef4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -625,7 +625,7 @@ private static class ArgumentValidator extends SparkSubmitOptionParser { @Override protected boolean handle(String opt, String value) { if (value == null && hasValue) { - throw new IllegalArgumentException(String.format("'%s' does not expect a value.", opt)); + throw new IllegalArgumentException(String.format("'%s' expects a value.", opt)); } return true; } From 155ab6347ec7be06c937372a51e8013fdd371d93 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Mon, 9 Oct 2017 15:22:41 -0700 Subject: [PATCH 683/779] [SPARK-22170][SQL] Reduce memory consumption in broadcast joins. ## What changes were proposed in this pull request? This updates the broadcast join code path to lazily decompress pages and iterate through UnsafeRows to prevent all rows from being held in memory while the broadcast table is being built. ## How was this patch tested? Existing tests. Author: Ryan Blue Closes #19394 from rdblue/broadcast-driver-memory. --- .../plans/physical/broadcastMode.scala | 6 ++++ .../spark/sql/execution/SparkPlan.scala | 19 ++++++++---- .../exchange/BroadcastExchangeExec.scala | 29 ++++++++++++++----- .../sql/execution/joins/HashedRelation.scala | 13 ++++++++- .../spark/sql/ConfigBehaviorSuite.scala | 2 +- .../execution/metric/SQLMetricsSuite.scala | 3 +- 6 files changed, 54 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala index 2ab46dc8330aa..9fac95aed8f12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.InternalRow trait BroadcastMode { def transform(rows: Array[InternalRow]): Any + def transform(rows: Iterator[InternalRow], sizeHint: Option[Long]): Any + def canonicalized: BroadcastMode } @@ -36,5 +38,9 @@ case object IdentityBroadcastMode extends BroadcastMode { // TODO: pack the UnsafeRows into single bytes array. override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows + override def transform( + rows: Iterator[InternalRow], + sizeHint: Option[Long]): Array[InternalRow] = rows.toArray + override def canonicalized: BroadcastMode = this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index b263f100e6068..2ffd948f984bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -223,7 +223,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also * compressed. */ - private def getByteArrayRdd(n: Int = -1): RDD[Array[Byte]] = { + private def getByteArrayRdd(n: Int = -1): RDD[(Long, Array[Byte])] = { execute().mapPartitionsInternal { iter => var count = 0 val buffer = new Array[Byte](4 << 10) // 4K @@ -239,7 +239,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ out.writeInt(-1) out.flush() out.close() - Iterator(bos.toByteArray) + Iterator((count, bos.toByteArray)) } } @@ -274,19 +274,26 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val byteArrayRdd = getByteArrayRdd() val results = ArrayBuffer[InternalRow]() - byteArrayRdd.collect().foreach { bytes => - decodeUnsafeRows(bytes).foreach(results.+=) + byteArrayRdd.collect().foreach { countAndBytes => + decodeUnsafeRows(countAndBytes._2).foreach(results.+=) } results.toArray } + private[spark] def executeCollectIterator(): (Long, Iterator[InternalRow]) = { + val countsAndBytes = getByteArrayRdd().collect() + val total = countsAndBytes.map(_._1).sum + val rows = countsAndBytes.iterator.flatMap(countAndBytes => decodeUnsafeRows(countAndBytes._2)) + (total, rows) + } + /** * Runs this query returning the result as an iterator of InternalRow. * * @note Triggers multiple jobs (one for each partition). */ def executeToIterator(): Iterator[InternalRow] = { - getByteArrayRdd().toLocalIterator.flatMap(decodeUnsafeRows) + getByteArrayRdd().map(_._2).toLocalIterator.flatMap(decodeUnsafeRows) } /** @@ -307,7 +314,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ return new Array[InternalRow](0) } - val childRDD = getByteArrayRdd(n) + val childRDD = getByteArrayRdd(n).map(_._2) val buf = new ArrayBuffer[InternalRow] val totalParts = childRDD.partitions.length diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 9c859e41f8762..880e18c6808b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.joins.HashedRelation import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.ThreadUtils @@ -72,26 +72,39 @@ case class BroadcastExchangeExec( SQLExecution.withExecutionId(sparkContext, executionId) { try { val beforeCollect = System.nanoTime() - // Note that we use .executeCollect() because we don't want to convert data to Scala types - val input: Array[InternalRow] = child.executeCollect() - if (input.length >= 512000000) { + // Use executeCollect/executeCollectIterator to avoid conversion to Scala types + val (numRows, input) = child.executeCollectIterator() + if (numRows >= 512000000) { throw new SparkException( - s"Cannot broadcast the table with more than 512 millions rows: ${input.length} rows") + s"Cannot broadcast the table with more than 512 millions rows: $numRows rows") } + val beforeBuild = System.nanoTime() longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000 - val dataSize = input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum + + // Construct the relation. + val relation = mode.transform(input, Some(numRows)) + + val dataSize = relation match { + case map: HashedRelation => + map.estimatedSize + case arr: Array[InternalRow] => + arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum + case _ => + throw new SparkException("[BUG] BroadcastMode.transform returned unexpected type: " + + relation.getClass.getName) + } + longMetric("dataSize") += dataSize if (dataSize >= (8L << 30)) { throw new SparkException( s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB") } - // Construct and broadcast the relation. - val relation = mode.transform(input) val beforeBroadcast = System.nanoTime() longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000 + // Broadcast the relation val broadcasted = sparkContext.broadcast(relation) longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index f8058b2f7813b..b2dcbe5aa9877 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -866,7 +866,18 @@ private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression]) extends BroadcastMode { override def transform(rows: Array[InternalRow]): HashedRelation = { - HashedRelation(rows.iterator, canonicalized.key, rows.length) + transform(rows.iterator, Some(rows.length)) + } + + override def transform( + rows: Iterator[InternalRow], + sizeHint: Option[Long]): HashedRelation = { + sizeHint match { + case Some(numRows) => + HashedRelation(rows, canonicalized.key, numRows.toInt) + case None => + HashedRelation(rows, canonicalized.key) + } } override lazy val canonicalized: HashedRelationBroadcastMode = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala index 2c1e5db5fd9bb..cee85ec8af04d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala @@ -58,7 +58,7 @@ class ConfigBehaviorSuite extends QueryTest with SharedSQLContext { withSQLConf(SQLConf.RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION.key -> "1") { // If we only sample one point, the range boundaries will be pretty bad and the // chi-sq value would be very high. - assert(computeChiSquareTest() > 1000) + assert(computeChiSquareTest() > 300) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 0dc612ef735fa..58a194b8af62b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -227,8 +227,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared val df = df1.join(broadcast(df2), "key") testSparkPlanMetrics(df, 2, Map( 1L -> (("BroadcastHashJoin", Map( - "number of output rows" -> 2L, - "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")))) + "number of output rows" -> 2L)))) ) } From 71c2b81aa0e0db70013821f5512df1fbd8e59445 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 9 Oct 2017 16:34:39 -0700 Subject: [PATCH 684/779] [SPARK-22230] Swap per-row order in state store restore. ## What changes were proposed in this pull request? In state store restore, for each row, put the saved state before the row in the iterator instead of after. This fixes an issue where agg(last('attr)) will forever return the last value of 'attr from the first microbatch. ## How was this patch tested? new unit test Author: Jose Torres Closes #19461 from joseph-torres/SPARK-22230. --- .../execution/streaming/statefulOperators.scala | 2 +- .../streaming/StreamingAggregationSuite.scala | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index fb960fbdde8b3..0d85542928ee6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -225,7 +225,7 @@ case class StateStoreRestoreExec( val key = getKey(row) val savedState = store.get(key) numOutputRows += 1 - row +: Option(savedState).toSeq + Option(savedState).toSeq :+ row } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 995cea3b37d4f..fe7efa69f7e31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -520,6 +520,22 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } + test("SPARK-22230: last should change with new batches") { + val input = MemoryStream[Int] + + val aggregated = input.toDF().agg(last('value)) + testStream(aggregated, OutputMode.Complete())( + AddData(input, 1, 2, 3), + CheckLastBatch(3), + AddData(input, 4, 5, 6), + CheckLastBatch(6), + AddData(input), + CheckLastBatch(6), + AddData(input, 0), + CheckLastBatch(0) + ) + } + /** Add blocks of data to the `BlockRDDBackedSource`. */ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { From bebd2e1ce10a460555f75cda75df33f39a783469 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Mon, 9 Oct 2017 21:34:37 -0700 Subject: [PATCH 685/779] [SPARK-22222][CORE] Fix the ARRAY_MAX in BufferHolder and add a test ## What changes were proposed in this pull request? We should not break the assumption that the length of the allocated byte array is word rounded: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java#L170 So we want to use `Integer.MAX_VALUE - 15` instead of `Integer.MAX_VALUE - 8` as the upper bound of an allocated byte array. cc: srowen gatorsmile ## How was this patch tested? Since the Spark unit test JVM has less than 1GB heap, here we run the test code as a submit job, so it can run on a JVM has 4GB memory. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #19460 from liufengdb/fix_array_max. --- .../spark/unsafe/array/ByteArrayMethods.java | 7 ++ .../unsafe/map/HashMapGrowthStrategy.java | 6 +- .../collection/PartitionedPairBuffer.scala | 6 +- .../spark/deploy/SparkSubmitSuite.scala | 52 +++++++------ .../expressions/codegen/BufferHolder.java | 7 +- .../BufferHolderSparkSubmitSutie.scala | 78 +++++++++++++++++++ .../vectorized/WritableColumnVector.java | 3 +- 7 files changed, 124 insertions(+), 35 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSutie.scala diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index 9c551ab19e9aa..f121b1cd745b8 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -40,6 +40,13 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { } } + // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat smaller. + // Be conservative and lower the cap a little. + // Refer to "http://hg.openjdk.java.net/jdk8/jdk8/jdk/file/tip/src/share/classes/java/util/ArrayList.java#l229" + // This value is word rounded. Use this value if the allocated byte arrays are used to store other + // types rather than bytes. + public static int MAX_ROUNDED_ARRAY_LENGTH = Integer.MAX_VALUE - 15; + private static final boolean unaligned = Platform.unaligned(); /** * Optimized byte array equality check for byte arrays. diff --git a/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java index b8c2294c7b7ab..ee6d9f75ac5aa 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java @@ -17,6 +17,8 @@ package org.apache.spark.unsafe.map; +import org.apache.spark.unsafe.array.ByteArrayMethods; + /** * Interface that defines how we can grow the size of a hash map when it is over a threshold. */ @@ -31,9 +33,7 @@ public interface HashMapGrowthStrategy { class Doubling implements HashMapGrowthStrategy { - // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat - // smaller. Be conservative and lower the cap a little. - private static final int ARRAY_MAX = Integer.MAX_VALUE - 8; + private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; @Override public int nextCapacity(int currentCapacity) { diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala index b755e5da51684..e17a9de97e335 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala @@ -19,6 +19,8 @@ package org.apache.spark.util.collection import java.util.Comparator +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.collection.WritablePartitionedPairCollection._ /** @@ -96,7 +98,5 @@ private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64) } private object PartitionedPairBuffer { - // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat - // smaller. Be conservative and lower the cap a little. - val MAXIMUM_CAPACITY: Int = (Int.MaxValue - 8) / 2 + val MAXIMUM_CAPACITY: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH / 2 } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index b06f2e26a4a7a..b52da4c0c8bc3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -100,6 +100,8 @@ class SparkSubmitSuite with TimeLimits with TestPrematureExit { + import SparkSubmitSuite._ + override def beforeEach() { super.beforeEach() System.setProperty("spark.testing", "true") @@ -974,30 +976,6 @@ class SparkSubmitSuite } } - // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. - private def runSparkSubmit(args: Seq[String]): Unit = { - val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - val sparkSubmitFile = if (Utils.isWindows) { - new File("..\\bin\\spark-submit.cmd") - } else { - new File("../bin/spark-submit") - } - val process = Utils.executeCommand( - Seq(sparkSubmitFile.getCanonicalPath) ++ args, - new File(sparkHome), - Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) - - try { - val exitCode = failAfter(60 seconds) { process.waitFor() } - if (exitCode != 0) { - fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") - } - } finally { - // Ensure we still kill the process in case it timed out - process.destroy() - } - } - private def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { val tmpDir = Utils.createTempDir() @@ -1020,6 +998,32 @@ class SparkSubmitSuite } } +object SparkSubmitSuite extends SparkFunSuite with TimeLimits { + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. + def runSparkSubmit(args: Seq[String], root: String = ".."): Unit = { + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) + val sparkSubmitFile = if (Utils.isWindows) { + new File(s"$root\\bin\\spark-submit.cmd") + } else { + new File(s"$root/bin/spark-submit") + } + val process = Utils.executeCommand( + Seq(sparkSubmitFile.getCanonicalPath) ++ args, + new File(sparkHome), + Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) + + try { + val exitCode = failAfter(60 seconds) { process.waitFor() } + if (exitCode != 0) { + fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") + } + } finally { + // Ensure we still kill the process in case it timed out + process.destroy() + } + } +} + object JarCreationTest extends Logging { def main(args: Array[String]) { Utils.configTestLog4j("INFO") diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 971d19973f067..259976118c12f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -19,6 +19,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; /** * A helper class to manage the data buffer for an unsafe row. The data buffer can grow and @@ -36,9 +37,7 @@ */ public class BufferHolder { - // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat - // smaller. Be conservative and lower the cap a little. - private static final int ARRAY_MAX = Integer.MAX_VALUE - 8; + private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; public byte[] buffer; public int cursor = Platform.BYTE_ARRAY_OFFSET; @@ -51,7 +50,7 @@ public BufferHolder(UnsafeRow row) { public BufferHolder(UnsafeRow row, int initialSize) { int bitsetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()); - if (row.numFields() > (Integer.MAX_VALUE - initialSize - bitsetWidthInBytes) / 8) { + if (row.numFields() > (ARRAY_MAX - initialSize - bitsetWidthInBytes) / 8) { throw new UnsupportedOperationException( "Cannot create BufferHolder for input UnsafeRow because there are " + "too many fields (number of fields: " + row.numFields() + ")"); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSutie.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSutie.scala new file mode 100644 index 0000000000000..1167d2f3f3891 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSutie.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.scalatest.concurrent.Timeouts + +import org.apache.spark.{SparkFunSuite, TestUtils} +import org.apache.spark.deploy.SparkSubmitSuite +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.util.ResetSystemProperties + +// A test for growing the buffer holder to nearly 2GB. Due to the heap size limitation of the Spark +// unit tests JVM, the actually test code is running as a submit job. +class BufferHolderSparkSubmitSuite + extends SparkFunSuite + with Matchers + with BeforeAndAfterEach + with ResetSystemProperties + with Timeouts { + + test("SPARK-22222: Buffer holder should be able to allocate memory larger than 1GB") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + + val argsForSparkSubmit = Seq( + "--class", BufferHolderSparkSubmitSuite.getClass.getName.stripSuffix("$"), + "--name", "SPARK-22222", + "--master", "local-cluster[2,1,1024]", + "--driver-memory", "4g", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--conf", "spark.driver.extraJavaOptions=-ea", + unusedJar.toString) + SparkSubmitSuite.runSparkSubmit(argsForSparkSubmit, "../..") + } +} + +object BufferHolderSparkSubmitSuite { + + def main(args: Array[String]): Unit = { + + val ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + val holder = new BufferHolder(new UnsafeRow(1000)) + + holder.reset() + holder.grow(roundToWord(ARRAY_MAX / 2)) + + holder.reset() + holder.grow(roundToWord(ARRAY_MAX / 2 + 8)) + + holder.reset() + holder.grow(roundToWord(Integer.MAX_VALUE / 2)) + + holder.reset() + holder.grow(roundToWord(Integer.MAX_VALUE)) + } + + private def roundToWord(len: Int): Int = { + ByteArrayMethods.roundNumberOfBytesToNearestWord(len) + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index da72954ddc448..d3a14b9d8bd74 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; /** @@ -595,7 +596,7 @@ public final int appendStruct(boolean isNull) { * Upper limit for the maximum capacity for this column. */ @VisibleForTesting - protected int MAX_CAPACITY = Integer.MAX_VALUE - 8; + protected int MAX_CAPACITY = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; /** * Number of nulls in this column. This is an optimization for the reader, to skip NULL checks. From af8a34c787dc3d68f5148a7d9975b52650bb7729 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 9 Oct 2017 22:35:34 -0700 Subject: [PATCH 686/779] [SPARK-22159][SQL][FOLLOW-UP] Make config names consistently end with "enabled". ## What changes were proposed in this pull request? This is a follow-up of #19384. In the previous pr, only definitions of the config names were modified, but we also need to modify the names in runtime or tests specified as string literal. ## How was this patch tested? Existing tests but modified the config names. Author: Takuya UESHIN Closes #19462 from ueshin/issues/SPARK-22159/fup1. --- python/pyspark/sql/dataframe.py | 4 ++-- python/pyspark/sql/tests.py | 6 +++--- .../aggregate/HashAggregateExec.scala | 2 +- .../spark/sql/AggregateHashMapSuite.scala | 12 +++++------ .../benchmark/AggregateBenchmark.scala | 20 +++++++++---------- .../execution/AggregationQuerySuite.scala | 2 +- 6 files changed, 23 insertions(+), 23 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b7ce9a83a616d..fe69e588fe098 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1878,7 +1878,7 @@ def toPandas(self): 1 5 Bob """ import pandas as pd - if self.sql_ctx.getConf("spark.sql.execution.arrow.enable", "false").lower() == "true": + if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: import pyarrow tables = self._collectAsArrow() @@ -1889,7 +1889,7 @@ def toPandas(self): return pd.DataFrame.from_records([], columns=self.columns) except ImportError as e: msg = "note: pyarrow must be installed and available on calling Python process " \ - "if using spark.sql.execution.arrow.enable=true" + "if using spark.sql.execution.arrow.enabled=true" raise ImportError("%s\n%s" % (e.message, msg)) else: pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1b3af42c47ad2..a59378b5e848a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3088,7 +3088,7 @@ class ArrowTests(ReusedPySparkTestCase): def setUpClass(cls): ReusedPySparkTestCase.setUpClass() cls.spark = SparkSession(cls.sc) - cls.spark.conf.set("spark.sql.execution.arrow.enable", "true") + cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true") cls.schema = StructType([ StructField("1_str_t", StringType(), True), StructField("2_int_t", IntegerType(), True), @@ -3120,9 +3120,9 @@ def test_null_conversion(self): def test_toPandas_arrow_toggle(self): df = self.spark.createDataFrame(self.data, schema=self.schema) - self.spark.conf.set("spark.sql.execution.arrow.enable", "false") + self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") pdf = df.toPandas() - self.spark.conf.set("spark.sql.execution.arrow.enable", "true") + self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") pdf_arrow = df.toPandas() self.assertFramesEqual(pdf_arrow, pdf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index f424096b330e3..8b573fdcf25e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -539,7 +539,7 @@ case class HashAggregateExec( private def enableTwoLevelHashMap(ctx: CodegenContext) = { if (!checkIfFastHashMapSupported(ctx)) { if (modes.forall(mode => mode == Partial || mode == PartialMerge) && !Utils.isTesting) { - logInfo("spark.sql.codegen.aggregate.map.twolevel.enable is set to true, but" + logInfo("spark.sql.codegen.aggregate.map.twolevel.enabled is set to true, but" + " current version of codegened fast hashmap does not support this aggregate.") } } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala index 7e61a68025158..938d76c9f0837 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala @@ -24,14 +24,14 @@ import org.apache.spark.SparkConf class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { override protected def sparkConf: SparkConf = super.sparkConf .set("spark.sql.codegen.fallback", "false") - .set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + .set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") // adding some checking after each test is run, assuring that the configs are not changed // in test code after { assert(sparkConf.get("spark.sql.codegen.fallback") == "false", "configuration parameter changed in test body") - assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "false", + assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enabled") == "false", "configuration parameter changed in test body") } } @@ -39,14 +39,14 @@ class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with Befo class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { override protected def sparkConf: SparkConf = super.sparkConf .set("spark.sql.codegen.fallback", "false") - .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + .set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") // adding some checking after each test is run, assuring that the configs are not changed // in test code after { assert(sparkConf.get("spark.sql.codegen.fallback") == "false", "configuration parameter changed in test body") - assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "true", + assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enabled") == "true", "configuration parameter changed in test body") } } @@ -57,7 +57,7 @@ class TwoLevelAggregateHashMapWithVectorizedMapSuite override protected def sparkConf: SparkConf = super.sparkConf .set("spark.sql.codegen.fallback", "false") - .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + .set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") .set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") // adding some checking after each test is run, assuring that the configs are not changed @@ -65,7 +65,7 @@ class TwoLevelAggregateHashMapWithVectorizedMapSuite after { assert(sparkConf.get("spark.sql.codegen.fallback") == "false", "configuration parameter changed in test body") - assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "true", + assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enabled") == "true", "configuration parameter changed in test body") assert(sparkConf.get("spark.sql.codegen.aggregate.map.vectorized.enable") == "true", "configuration parameter changed in test body") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index aca1be01fa3da..a834b7cd2c69f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -107,14 +107,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } @@ -149,14 +149,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } @@ -189,14 +189,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } @@ -228,14 +228,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } @@ -277,14 +277,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index f245a79f805a2..ae675149df5e2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -1015,7 +1015,7 @@ class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySu override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { Seq("true", "false").foreach { enableTwoLevelMaps => - withSQLConf("spark.sql.codegen.aggregate.map.twolevel.enable" -> + withSQLConf("spark.sql.codegen.aggregate.map.twolevel.enabled" -> enableTwoLevelMaps) { (1 to 3).foreach { fallbackStartsAt => withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> From 3b5c2a84bfa311a94c1c0a57f2cb3e421fb05650 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 10 Oct 2017 08:27:45 +0100 Subject: [PATCH 687/779] [SPARK-21770][ML] ProbabilisticClassificationModel fix corner case: normalization of all-zero raw predictions ## What changes were proposed in this pull request? Fix probabilisticClassificationModel corner case: normalization of all-zero raw predictions, throw IllegalArgumentException with description. ## How was this patch tested? Test case added. Author: WeichenXu Closes #19106 from WeichenXu123/SPARK-21770. --- .../ProbabilisticClassifier.scala | 20 ++++++++++--------- .../ProbabilisticClassifierSuite.scala | 18 +++++++++++++++++ 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index ef08134809915..730fcab333e11 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -230,21 +230,23 @@ private[ml] object ProbabilisticClassificationModel { * Normalize a vector of raw predictions to be a multinomial probability vector, in place. * * The input raw predictions should be nonnegative. - * The output vector sums to 1, unless the input vector is all-0 (in which case the output is - * all-0 too). + * The output vector sums to 1. * * NOTE: This is NOT applicable to all models, only ones which effectively use class * instance counts for raw predictions. + * + * @throws IllegalArgumentException if the input vector is all-0 or including negative values */ def normalizeToProbabilitiesInPlace(v: DenseVector): Unit = { + v.values.foreach(value => require(value >= 0, + "The input raw predictions should be nonnegative.")) val sum = v.values.sum - if (sum != 0) { - var i = 0 - val size = v.size - while (i < size) { - v.values(i) /= sum - i += 1 - } + require(sum > 0, "Can't normalize the 0-vector.") + var i = 0 + val size = v.size + while (i < size) { + v.values(i) /= sum + i += 1 } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index 4ecd5a05365eb..d649ceac949c4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -80,6 +80,24 @@ class ProbabilisticClassifierSuite extends SparkFunSuite { new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(-0.1, 0.1)) } } + + test("normalizeToProbabilitiesInPlace") { + val vec1 = Vectors.dense(1.0, 2.0, 3.0).toDense + ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(vec1) + assert(vec1 ~== Vectors.dense(1.0 / 6, 2.0 / 6, 3.0 / 6) relTol 1e-3) + + // all-0 input test + val vec2 = Vectors.dense(0.0, 0.0, 0.0).toDense + intercept[IllegalArgumentException] { + ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(vec2) + } + + // negative input test + val vec3 = Vectors.dense(1.0, -1.0, 2.0).toDense + intercept[IllegalArgumentException] { + ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(vec3) + } + } } object ProbabilisticClassifierSuite { From b8a08f25cc64ed3034f3c90790931c30e5b0f236 Mon Sep 17 00:00:00 2001 From: liuxian Date: Tue, 10 Oct 2017 20:44:33 +0800 Subject: [PATCH 688/779] [SPARK-21506][DOC] The description of "spark.executor.cores" may be not correct ## What changes were proposed in this pull request? The number of cores assigned to each executor is configurable. When this is not explicitly set, multiple executors from the same application may be launched on the same worker too. ## How was this patch tested? N/A Author: liuxian Closes #18711 from 10110346/executorcores. --- .../spark/deploy/client/StandaloneAppClient.scala | 2 +- .../scala/org/apache/spark/deploy/master/Master.scala | 8 +++++++- .../cluster/StandaloneSchedulerBackend.scala | 2 +- docs/configuration.md | 11 ++++------- docs/spark-standalone.md | 8 ++++++++ 5 files changed, 21 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala index 757c930b84eb2..34ade4ce6f39b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala @@ -170,7 +170,7 @@ private[spark] class StandaloneAppClient( case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = appId + "/" + id - logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, + logInfo("Executor added: %s on %s (%s) with %d core(s)".format(fullId, workerId, hostPort, cores)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index e030cac60a8e4..2c78c15773af2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -581,7 +581,13 @@ private[deploy] class Master( * The number of cores assigned to each executor is configurable. When this is explicitly set, * multiple executors from the same application may be launched on the same worker if the worker * has enough cores and memory. Otherwise, each executor grabs all the cores available on the - * worker by default, in which case only one executor may be launched on each worker. + * worker by default, in which case only one executor per application may be launched on each + * worker during one single schedule iteration. + * Note that when `spark.executor.cores` is not set, we may still launch multiple executors from + * the same application on the same worker. Consider appA and appB both have one executor running + * on worker1, and appA.coresLeft > 0, then appB is finished and release all its cores on worker1, + * thus for the next schedule iteration, appA launches a new executor that grabs all the free + * cores on worker1, therefore we get multiple executors from appA running on worker1. * * It is important to allocate coresPerExecutor on each worker at a time (instead of 1 core * at a time). Consider the following example: cluster has 4 workers with 16 cores each. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index a4e2a74341283..505c342a889ee 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -153,7 +153,7 @@ private[spark] class StandaloneSchedulerBackend( override def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int) { - logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format( + logInfo("Granted executor ID %s on hostPort %s with %d core(s), %s RAM".format( fullId, hostPort, cores, Utils.megabytesToString(memory))) } diff --git a/docs/configuration.md b/docs/configuration.md index 6e9fe591b70a3..7a777d3c6fa3d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1015,7 +1015,7 @@ Apart from these, the following properties are also available, and may be useful 0.5 Amount of storage memory immune to eviction, expressed as a fraction of the size of the - region set aside by s​park.memory.fraction. The higher this is, the less + region set aside by spark.memory.fraction. The higher this is, the less working memory may be available to execution and tasks may spill to disk more often. Leaving this at the default value is recommended. For more detail, see this description. @@ -1041,7 +1041,7 @@ Apart from these, the following properties are also available, and may be useful spark.memory.useLegacyMode false - ​Whether to enable the legacy memory management mode used in Spark 1.5 and before. + Whether to enable the legacy memory management mode used in Spark 1.5 and before. The legacy mode rigidly partitions the heap space into fixed-size regions, potentially leading to excessive spilling if the application was not tuned. The following deprecated memory fraction configurations are not read unless this is enabled: @@ -1115,11 +1115,8 @@ Apart from these, the following properties are also available, and may be useful The number of cores to use on each executor. - In standalone and Mesos coarse-grained modes, setting this - parameter allows an application to run multiple executors on the - same worker, provided that there are enough cores on that - worker. Otherwise, only one executor per application will run on - each worker. + In standalone and Mesos coarse-grained modes, for more detail, see + this description. diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 1095386c31ab8..f51c5cc38f4de 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -328,6 +328,14 @@ export SPARK_MASTER_OPTS="-Dspark.deploy.defaultCores=" This is useful on shared clusters where users might not have configured a maximum number of cores individually. +# Executors Scheduling + +The number of cores assigned to each executor is configurable. When `spark.executor.cores` is +explicitly set, multiple executors from the same application may be launched on the same worker +if the worker has enough cores and memory. Otherwise, each executor grabs all the cores available +on the worker by default, in which case only one executor per application may be launched on each +worker during one single schedule iteration. + # Monitoring and Logging Spark's standalone mode offers a web-based user interface to monitor the cluster. The master and each worker has its own web UI that shows cluster and job statistics. By default you can access the web UI for the master at port 8080. The port can be changed either in the configuration file or via command-line options. From 23af2d79ad9a3c83936485ee57513b39193a446b Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Tue, 10 Oct 2017 20:48:42 +0800 Subject: [PATCH 689/779] [SPARK-20025][CORE] Ignore SPARK_LOCAL* env, while deploying via cluster mode. ## What changes were proposed in this pull request? In a bare metal system with No DNS setup, spark may be configured with SPARK_LOCAL* for IP and host properties. During a driver failover, in cluster deployment mode. SPARK_LOCAL* should be ignored while restarting on another node and should be picked up from target system's local environment. ## How was this patch tested? Distributed deployment against a spark standalone cluster of 6 Workers. Tested by killing JVM's running driver and verified the restarted JVMs have right configurations on them. Author: Prashant Sharma Author: Prashant Sharma Closes #17357 from ScrapCodes/driver-failover-fix. --- core/src/main/scala/org/apache/spark/deploy/Client.scala | 6 +++--- .../apache/spark/deploy/rest/StandaloneRestServer.scala | 4 +++- .../org/apache/spark/deploy/worker/DriverWrapper.scala | 9 ++++++--- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index bf6093236d92b..7acb5c55bb252 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -93,19 +93,19 @@ private class ClientEndpoint( driverArgs.cores, driverArgs.supervise, command) - ayncSendToMasterAndForwardReply[SubmitDriverResponse]( + asyncSendToMasterAndForwardReply[SubmitDriverResponse]( RequestSubmitDriver(driverDescription)) case "kill" => val driverId = driverArgs.driverId - ayncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId)) + asyncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId)) } } /** * Send the message to master and forward the reply to self asynchronously. */ - private def ayncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = { + private def asyncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = { for (masterEndpoint <- masterEndpoints) { masterEndpoint.ask[T](message).onComplete { case Success(v) => self.send(v) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 0164084ab129e..22b65abce611a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -139,7 +139,9 @@ private[rest] class StandaloneSubmitRequestServlet( val driverExtraLibraryPath = sparkProperties.get("spark.driver.extraLibraryPath") val superviseDriver = sparkProperties.get("spark.driver.supervise") val appArgs = request.appArgs - val environmentVariables = request.environmentVariables + // Filter SPARK_LOCAL_(IP|HOSTNAME) environment variables from being set on the remote system. + val environmentVariables = + request.environmentVariables.filterNot(x => x._1.matches("SPARK_LOCAL_(IP|HOSTNAME)")) // Construct driver description val conf = new SparkConf(false) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index c1671192e0c64..b19c9904d5982 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -23,6 +23,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.{DependencyUtils, SparkHadoopUtil, SparkSubmit} +import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -30,7 +31,7 @@ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, U * Utility object for launching driver programs such that they share fate with the Worker process. * This is used in standalone cluster mode only. */ -object DriverWrapper { +object DriverWrapper extends Logging { def main(args: Array[String]) { args.toList match { /* @@ -41,8 +42,10 @@ object DriverWrapper { */ case workerUrl :: userJar :: mainClass :: extraArgs => val conf = new SparkConf() - val rpcEnv = RpcEnv.create("Driver", - Utils.localHostName(), 0, conf, new SecurityManager(conf)) + val host: String = Utils.localHostName() + val port: Int = sys.props.getOrElse("spark.driver.port", "0").toInt + val rpcEnv = RpcEnv.create("Driver", host, port, conf, new SecurityManager(conf)) + logInfo(s"Driver address: ${rpcEnv.address}") rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl)) val currentLoader = Thread.currentThread.getContextClassLoader From 633ffd816d285480bab1f346471135b10ec092bb Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 10 Oct 2017 11:01:02 -0700 Subject: [PATCH 690/779] rename the file. --- ...rSparkSubmitSutie.scala => BufferHolderSparkSubmitSuite.scala} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/{BufferHolderSparkSubmitSutie.scala => BufferHolderSparkSubmitSuite.scala} (100%) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSutie.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala similarity index 100% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSutie.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala From 2028e5a82bc3e9a79f9b84f376bdf606b8c9bb0f Mon Sep 17 00:00:00 2001 From: Eyal Farago Date: Tue, 10 Oct 2017 22:49:47 +0200 Subject: [PATCH 691/779] [SPARK-21907][CORE] oom during spill ## What changes were proposed in this pull request? 1. a test reproducing [SPARK-21907](https://issues.apache.org/jira/browse/SPARK-21907) 2. a fix for the root cause of the issue. `org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.spill` calls `org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset` which may trigger another spill, when this happens the `array` member is already de-allocated but still referenced by the code, this causes the nested spill to fail with an NPE in `org.apache.spark.memory.TaskMemoryManager.getPage`. This patch introduces a reproduction in a test case and a fix, the fix simply sets the in-mem sorter's array member to an empty array before actually performing the allocation. This prevents the spilling code from 'touching' the de-allocated array. ## How was this patch tested? introduced a new test case: `org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorterSuite#testOOMDuringSpill`. Author: Eyal Farago Closes #19181 from eyalfa/SPARK-21907__oom_during_spill. --- .../unsafe/sort/UnsafeExternalSorter.java | 4 ++ .../unsafe/sort/UnsafeInMemorySorter.java | 12 ++++- .../sort/UnsafeExternalSorterSuite.java | 33 +++++++++++++ .../sort/UnsafeInMemorySorterSuite.java | 46 +++++++++++++++++++ .../spark/memory/TestMemoryManager.scala | 12 +++-- 5 files changed, 102 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 39eda00dd7efb..e749f7ba87c6e 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -480,6 +480,10 @@ public UnsafeSorterIterator getSortedIterator() throws IOException { } } + @VisibleForTesting boolean hasSpaceForAnotherRecord() { + return inMemSorter.hasSpaceForAnotherRecord(); + } + private static void spillIterator(UnsafeSorterIterator inMemIterator, UnsafeSorterSpillWriter spillWriter) throws IOException { while (inMemIterator.hasNext()) { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index c14c12664f5ab..869ec908be1fb 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -162,7 +162,9 @@ private int getUsableCapacity() { */ public void free() { if (consumer != null) { - consumer.freeArray(array); + if (array != null) { + consumer.freeArray(array); + } array = null; } } @@ -170,6 +172,14 @@ public void free() { public void reset() { if (consumer != null) { consumer.freeArray(array); + // the call to consumer.allocateArray may trigger a spill + // which in turn access this instance and eventually re-enter this method and try to free the array again. + // by setting the array to null and its length to 0 we effectively make the spill code-path a no-op. + // setting the array to null also indicates that it has already been de-allocated which prevents a double de-allocation in free(). + array = null; + usableCapacity = 0; + pos = 0; + nullBoundaryPos = 0; array = consumer.allocateArray(initialSize); usableCapacity = getUsableCapacity(); } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 5330a688e63e3..6c5451d0fd2a5 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -23,6 +23,7 @@ import java.util.LinkedList; import java.util.UUID; +import org.hamcrest.Matchers; import scala.Tuple2$; import org.junit.After; @@ -503,6 +504,38 @@ public void testGetIterator() throws Exception { verifyIntIterator(sorter.getIterator(279), 279, 300); } + @Test + public void testOOMDuringSpill() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + // we assume that given default configuration, + // the size of the data we insert to the sorter (ints) + // and assuming we shouldn't spill before pointers array is exhausted + // (memory manager is not configured to throw at this point) + // - so this loop runs a reasonable number of iterations (<2000). + // test indeed completed within <30ms (on a quad i7 laptop). + for (int i = 0; sorter.hasSpaceForAnotherRecord(); ++i) { + insertNumber(sorter, i); + } + // we expect the next insert to attempt growing the pointerssArray + // first allocation is expected to fail, then a spill is triggered which attempts another allocation + // which also fails and we expect to see this OOM here. + // the original code messed with a released array within the spill code + // and ended up with a failed assertion. + // we also expect the location of the OOM to be org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset + memoryManager.markconsequentOOM(2); + try { + insertNumber(sorter, 1024); + fail("expected OutOfMmoryError but it seems operation surprisingly succeeded"); + } + // we expect an OutOfMemoryError here, anything else (i.e the original NPE is a failure) + catch (OutOfMemoryError oom){ + String oomStackTrace = Utils.exceptionString(oom); + assertThat("expected OutOfMemoryError in org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset", + oomStackTrace, + Matchers.containsString("org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset")); + } + } + private void verifyIntIterator(UnsafeSorterIterator iter, int start, int end) throws IOException { for (int i = start; i < end; i++) { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index bd89085aa9a14..1a3e11efe9787 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -35,6 +35,7 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.isIn; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; public class UnsafeInMemorySorterSuite { @@ -139,4 +140,49 @@ public int compare( } assertEquals(dataToSort.length, iterLength); } + + @Test + public void freeAfterOOM() { + final SparkConf sparkConf = new SparkConf(); + sparkConf.set("spark.memory.offHeap.enabled", "false"); + + final TestMemoryManager testMemoryManager = + new TestMemoryManager(sparkConf); + final TaskMemoryManager memoryManager = new TaskMemoryManager( + testMemoryManager, 0); + final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); + final MemoryBlock dataPage = memoryManager.allocatePage(2048, consumer); + final Object baseObject = dataPage.getBaseObject(); + // Write the records into the data page: + long position = dataPage.getBaseOffset(); + + final HashPartitioner hashPartitioner = new HashPartitioner(4); + // Use integer comparison for comparing prefixes (which are partition ids, in this case) + final PrefixComparator prefixComparator = PrefixComparators.LONG; + final RecordComparator recordComparator = new RecordComparator() { + @Override + public int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset) { + return 0; + } + }; + UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager, + recordComparator, prefixComparator, 100, shouldUseRadixSort()); + + testMemoryManager.markExecutionAsOutOfMemoryOnce(); + try { + sorter.reset(); + fail("expected OutOfMmoryError but it seems operation surprisingly succeeded"); + } catch (OutOfMemoryError oom) { + // as expected + } + // [SPARK-21907] this failed on NPE at org.apache.spark.memory.MemoryConsumer.freeArray(MemoryConsumer.java:108) + sorter.free(); + // simulate a 'back to back' free. + sorter.free(); + } + } diff --git a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala index 5f699df8211de..c26945fa5fa31 100644 --- a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala +++ b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala @@ -27,8 +27,8 @@ class TestMemoryManager(conf: SparkConf) numBytes: Long, taskAttemptId: Long, memoryMode: MemoryMode): Long = { - if (oomOnce) { - oomOnce = false + if (consequentOOM > 0) { + consequentOOM -= 1 0 } else if (available >= numBytes) { available -= numBytes @@ -58,11 +58,15 @@ class TestMemoryManager(conf: SparkConf) override def maxOffHeapStorageMemory: Long = 0L - private var oomOnce = false + private var consequentOOM = 0 private var available = Long.MaxValue def markExecutionAsOutOfMemoryOnce(): Unit = { - oomOnce = true + markconsequentOOM(1) + } + + def markconsequentOOM(n : Int) : Unit = { + consequentOOM += n } def limit(avail: Long): Unit = { From bfc7e1fe1ad5f9777126f2941e29bbe51ea5da7c Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 11 Oct 2017 07:32:01 +0900 Subject: [PATCH 692/779] [SPARK-20396][SQL][PYSPARK] groupby().apply() with pandas udf ## What changes were proposed in this pull request? This PR adds an apply() function on df.groupby(). apply() takes a pandas udf that is a transformation on `pandas.DataFrame` -> `pandas.DataFrame`. Static schema ------------------- ``` schema = df.schema pandas_udf(schema) def normalize(df): df = df.assign(v1 = (df.v1 - df.v1.mean()) / df.v1.std() return df df.groupBy('id').apply(normalize) ``` Dynamic schema ----------------------- **This use case is removed from the PR and we will discuss this as a follow up. See discussion https://github.com/apache/spark/pull/18732#pullrequestreview-66583248** Another example to use pd.DataFrame dtypes as output schema of the udf: ``` sample_df = df.filter(df.id == 1).toPandas() def foo(df): ret = # Some transformation on the input pd.DataFrame return ret foo_udf = pandas_udf(foo, foo(sample_df).dtypes) df.groupBy('id').apply(foo_udf) ``` In interactive use case, user usually have a sample pd.DataFrame to test function `foo` in their notebook. Having been able to use `foo(sample_df).dtypes` frees user from specifying the output schema of `foo`. Design doc: https://github.com/icexelloss/spark/blob/pandas-udf-doc/docs/pyspark-pandas-udf.md ## How was this patch tested? * Added GroupbyApplyTest Author: Li Jin Author: Takuya UESHIN Author: Bryan Cutler Closes #18732 from icexelloss/groupby-apply-SPARK-20396. --- python/pyspark/sql/dataframe.py | 6 +- python/pyspark/sql/functions.py | 98 ++++++++--- python/pyspark/sql/group.py | 88 +++++++++- python/pyspark/sql/tests.py | 157 +++++++++++++++++- python/pyspark/sql/types.py | 2 +- python/pyspark/worker.py | 35 ++-- .../sql/catalyst/optimizer/Optimizer.scala | 2 + .../logical/pythonLogicalOperators.scala | 39 +++++ .../spark/sql/RelationalGroupedDataset.scala | 36 +++- .../spark/sql/execution/SparkStrategies.scala | 2 + .../python/ArrowEvalPythonExec.scala | 39 ++++- .../execution/python/ArrowPythonRunner.scala | 15 +- .../execution/python/ExtractPythonUDFs.scala | 8 +- .../python/FlatMapGroupsInPandasExec.scala | 103 ++++++++++++ 14 files changed, 561 insertions(+), 69 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index fe69e588fe098..2d596229ced7e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1227,7 +1227,7 @@ def groupBy(self, *cols): """ jgd = self._jdf.groupBy(self._jcols(*cols)) from pyspark.sql.group import GroupedData - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self) @since(1.4) def rollup(self, *cols): @@ -1248,7 +1248,7 @@ def rollup(self, *cols): """ jgd = self._jdf.rollup(self._jcols(*cols)) from pyspark.sql.group import GroupedData - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self) @since(1.4) def cube(self, *cols): @@ -1271,7 +1271,7 @@ def cube(self, *cols): """ jgd = self._jdf.cube(self._jcols(*cols)) from pyspark.sql.group import GroupedData - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self) @since(1.3) def agg(self, *exprs): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b45a59db93679..9bc12c3b7a162 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2058,7 +2058,7 @@ def __init__(self, func, returnType, name=None, vectorized=False): self._name = name or ( func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) - self._vectorized = vectorized + self.vectorized = vectorized @property def returnType(self): @@ -2090,7 +2090,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self._vectorized) + self._name, wrapped_func, jdt, self.vectorized) return judf def __call__(self, *cols): @@ -2118,8 +2118,10 @@ def wrapper(*args): wrapper.__name__ = self._name wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__') else self.func.__class__.__module__) + wrapper.func = self.func wrapper.returnType = self.returnType + wrapper.vectorized = self.vectorized return wrapper @@ -2129,8 +2131,12 @@ def _create_udf(f, returnType, vectorized): def _udf(f, returnType=StringType(), vectorized=vectorized): if vectorized: import inspect - if len(inspect.getargspec(f).args) == 0: - raise NotImplementedError("0-parameter pandas_udfs are not currently supported") + argspec = inspect.getargspec(f) + if len(argspec.args) == 0 and argspec.varargs is None: + raise ValueError( + "0-arg pandas_udfs are not supported. " + "Instead, create a 1-arg pandas_udf and ignore the arg in your function." + ) udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized) return udf_obj._wrapped() @@ -2146,7 +2152,7 @@ def _udf(f, returnType=StringType(), vectorized=vectorized): @since(1.3) def udf(f=None, returnType=StringType()): - """Creates a :class:`Column` expression representing a user defined function (UDF). + """Creates a user defined function (UDF). .. note:: The user-defined functions must be deterministic. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked more times than @@ -2181,30 +2187,70 @@ def udf(f=None, returnType=StringType()): @since(2.3) def pandas_udf(f=None, returnType=StringType()): """ - Creates a :class:`Column` expression representing a user defined function (UDF) that accepts - `Pandas.Series` as input arguments and outputs a `Pandas.Series` of the same length. + Creates a vectorized user defined function (UDF). - :param f: python function if used as a standalone function + :param f: user-defined function. A python function if used as a standalone function :param returnType: a :class:`pyspark.sql.types.DataType` object - >>> from pyspark.sql.types import IntegerType, StringType - >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) - >>> @pandas_udf(returnType=StringType()) - ... def to_upper(s): - ... return s.str.upper() - ... - >>> @pandas_udf(returnType="integer") - ... def add_one(x): - ... return x + 1 - ... - >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age")) - >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ - ... .show() # doctest: +SKIP - +----------+--------------+------------+ - |slen(name)|to_upper(name)|add_one(age)| - +----------+--------------+------------+ - | 8| JOHN DOE| 22| - +----------+--------------+------------+ + The user-defined function can define one of the following transformations: + + 1. One or more `pandas.Series` -> A `pandas.Series` + + This udf is used with :meth:`pyspark.sql.DataFrame.withColumn` and + :meth:`pyspark.sql.DataFrame.select`. + The returnType should be a primitive data type, e.g., `DoubleType()`. + The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. + + >>> from pyspark.sql.types import IntegerType, StringType + >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) + >>> @pandas_udf(returnType=StringType()) + ... def to_upper(s): + ... return s.str.upper() + ... + >>> @pandas_udf(returnType="integer") + ... def add_one(x): + ... return x + 1 + ... + >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age")) + >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ + ... .show() # doctest: +SKIP + +----------+--------------+------------+ + |slen(name)|to_upper(name)|add_one(age)| + +----------+--------------+------------+ + | 8| JOHN DOE| 22| + +----------+--------------+------------+ + + 2. A `pandas.DataFrame` -> A `pandas.DataFrame` + + This udf is only used with :meth:`pyspark.sql.GroupedData.apply`. + The returnType should be a :class:`StructType` describing the schema of the returned + `pandas.DataFrame`. + + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf(returnType=df.schema) + ... def normalize(pdf): + ... v = pdf.v + ... return pdf.assign(v=(v - v.mean()) / v.std()) + >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP + +---+-------------------+ + | id| v| + +---+-------------------+ + | 1|-0.7071067811865475| + | 1| 0.7071067811865475| + | 2|-0.8320502943378437| + | 2|-0.2773500981126146| + | 2| 1.1094003924504583| + +---+-------------------+ + + .. note:: This type of udf cannot be used with functions such as `withColumn` or `select` + because it defines a `DataFrame` transformation rather than a `Column` + transformation. + + .. seealso:: :meth:`pyspark.sql.GroupedData.apply` + + .. note:: The user-defined function must be deterministic. """ return _create_udf(f, returnType=returnType, vectorized=True) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index f2092f9c63054..817d0bc83bb77 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -54,9 +54,10 @@ class GroupedData(object): .. versionadded:: 1.3 """ - def __init__(self, jgd, sql_ctx): + def __init__(self, jgd, df): self._jgd = jgd - self.sql_ctx = sql_ctx + self._df = df + self.sql_ctx = df.sql_ctx @ignore_unicode_prefix @since(1.3) @@ -170,7 +171,7 @@ def sum(self, *cols): @since(1.6) def pivot(self, pivot_col, values=None): """ - Pivots a column of the current [[DataFrame]] and perform the specified aggregation. + Pivots a column of the current :class:`DataFrame` and perform the specified aggregation. There are two versions of pivot function: one that requires the caller to specify the list of distinct values to pivot on, and one that does not. The latter is more concise but less efficient, because Spark needs to first compute the list of distinct values internally. @@ -192,7 +193,85 @@ def pivot(self, pivot_col, values=None): jgd = self._jgd.pivot(pivot_col) else: jgd = self._jgd.pivot(pivot_col, values) - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self._df) + + @since(2.3) + def apply(self, udf): + """ + Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result + as a `DataFrame`. + + The user-defined function should take a `pandas.DataFrame` and return another + `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame` + to the user-function and the returned `pandas.DataFrame`s are combined as a + :class:`DataFrame`. + The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the + returnType of the pandas udf. + + This function does not support partial aggregation, and requires shuffling all the data in + the :class:`DataFrame`. + + :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf` + + >>> from pyspark.sql.functions import pandas_udf + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf(returnType=df.schema) + ... def normalize(pdf): + ... v = pdf.v + ... return pdf.assign(v=(v - v.mean()) / v.std()) + >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP + +---+-------------------+ + | id| v| + +---+-------------------+ + | 1|-0.7071067811865475| + | 1| 0.7071067811865475| + | 2|-0.8320502943378437| + | 2|-0.2773500981126146| + | 2| 1.1094003924504583| + +---+-------------------+ + + .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` + + """ + from pyspark.sql.functions import pandas_udf + + # Columns are special because hasattr always return True + if isinstance(udf, Column) or not hasattr(udf, 'func') or not udf.vectorized: + raise ValueError("The argument to apply must be a pandas_udf") + if not isinstance(udf.returnType, StructType): + raise ValueError("The returnType of the pandas_udf must be a StructType") + + df = self._df + func = udf.func + returnType = udf.returnType + + # The python executors expects the function to use pd.Series as input and output + # So we to create a wrapper function that turns that to a pd.DataFrame before passing + # down to the user function, then turn the result pd.DataFrame back into pd.Series + columns = df.columns + + def wrapped(*cols): + from pyspark.sql.types import to_arrow_type + import pandas as pd + result = func(pd.concat(cols, axis=1, keys=columns)) + if not isinstance(result, pd.DataFrame): + raise TypeError("Return type of the user-defined function should be " + "Pandas.DataFrame, but is {}".format(type(result))) + if not len(result.columns) == len(returnType): + raise RuntimeError( + "Number of columns of the returned Pandas.DataFrame " + "doesn't match specified schema. " + "Expected: {} Actual: {}".format(len(returnType), len(result.columns))) + arrow_return_types = (to_arrow_type(field.dataType) for field in returnType) + return [(result[result.columns[i]], arrow_type) + for i, arrow_type in enumerate(arrow_return_types)] + + wrapped_udf_obj = pandas_udf(wrapped, returnType) + udf_column = wrapped_udf_obj(*[df[col] for col in df.columns]) + jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) + return DataFrame(jdf, self.sql_ctx) def _test(): @@ -206,6 +285,7 @@ def _test(): .getOrCreate() sc = spark.sparkContext globs['sc'] = sc + globs['spark'] = spark globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a59378b5e848a..bac2ef84ae7a7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3256,17 +3256,17 @@ def test_vectorized_udf_null_string(self): def test_vectorized_udf_zero_parameter(self): from pyspark.sql.functions import pandas_udf - error_str = '0-parameter pandas_udfs.*not.*supported' + error_str = '0-arg pandas_udfs.*not.*supported' with QuietTest(self.sc): - with self.assertRaisesRegexp(NotImplementedError, error_str): + with self.assertRaisesRegexp(ValueError, error_str): pandas_udf(lambda: 1, LongType()) - with self.assertRaisesRegexp(NotImplementedError, error_str): + with self.assertRaisesRegexp(ValueError, error_str): @pandas_udf def zero_no_type(): return 1 - with self.assertRaisesRegexp(NotImplementedError, error_str): + with self.assertRaisesRegexp(ValueError, error_str): @pandas_udf(LongType()) def zero_with_type(): return 1 @@ -3348,7 +3348,7 @@ def test_vectorized_udf_wrong_return_type(self): df = self.spark.range(10) f = pandas_udf(lambda x: x * 1.0, StringType()) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Invalid.*type.*string'): + with self.assertRaisesRegexp(Exception, 'Invalid.*type'): df.select(f(col('id'))).collect() def test_vectorized_udf_return_scalar(self): @@ -3356,7 +3356,7 @@ def test_vectorized_udf_return_scalar(self): df = self.spark.range(10) f = pandas_udf(lambda x: 1.0, DoubleType()) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Return.*type.*pandas_udf.*Series'): + with self.assertRaisesRegexp(Exception, 'Return.*type.*Series'): df.select(f(col('id'))).collect() def test_vectorized_udf_decorator(self): @@ -3376,6 +3376,151 @@ def test_vectorized_udf_empty_partition(self): res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_varargs(self): + from pyspark.sql.functions import pandas_udf, col + df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2)) + f = pandas_udf(lambda *v: v[0], LongType()) + res = df.select(f(col('id'))) + self.assertEquals(df.collect(), res.collect()) + + +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +class GroupbyApplyTests(ReusedPySparkTestCase): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + cls.spark.stop() + + def assertFramesEqual(self, expected, result): + msg = ("DataFrames are not equal: " + + ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) + + ("\n\nResult:\n%s\n%s" % (result, result.dtypes))) + self.assertTrue(expected.equals(result), msg=msg) + + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i) for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))).drop('vs') + + def test_simple(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + foo_udf = pandas_udf( + lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('v1', DoubleType()), + StructField('v2', LongType())])) + + result = df.groupby('id').apply(foo_udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) + self.assertFramesEqual(expected, result) + + def test_decorator(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + @pandas_udf(StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('v1', DoubleType()), + StructField('v2', LongType())])) + def foo(pdf): + return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id) + + result = df.groupby('id').apply(foo).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) + self.assertFramesEqual(expected, result) + + def test_coerce(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + foo = pandas_udf( + lambda pdf: pdf, + StructType([StructField('id', LongType()), StructField('v', DoubleType())])) + + result = df.groupby('id').apply(foo).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) + expected = expected.assign(v=expected.v.astype('float64')) + self.assertFramesEqual(expected, result) + + def test_complex_groupby(self): + from pyspark.sql.functions import pandas_udf, col + df = self.data + + @pandas_udf(StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('norm', DoubleType())])) + def normalize(pdf): + v = pdf.v + return pdf.assign(norm=(v - v.mean()) / v.std()) + + result = df.groupby(col('id') % 2 == 0).apply(normalize).sort('id', 'v').toPandas() + pdf = df.toPandas() + expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func) + expected = expected.sort_values(['id', 'v']).reset_index(drop=True) + expected = expected.assign(norm=expected.norm.astype('float64')) + self.assertFramesEqual(expected, result) + + def test_empty_groupby(self): + from pyspark.sql.functions import pandas_udf, col + df = self.data + + @pandas_udf(StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('norm', DoubleType())])) + def normalize(pdf): + v = pdf.v + return pdf.assign(norm=(v - v.mean()) / v.std()) + + result = df.groupby().apply(normalize).sort('id', 'v').toPandas() + pdf = df.toPandas() + expected = normalize.func(pdf) + expected = expected.sort_values(['id', 'v']).reset_index(drop=True) + expected = expected.assign(norm=expected.norm.astype('float64')) + self.assertFramesEqual(expected, result) + + def test_wrong_return_type(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + foo = pandas_udf( + lambda pdf: pdf, + StructType([StructField('id', LongType()), StructField('v', StringType())])) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Invalid.*type'): + df.groupby('id').apply(foo).sort('id').toPandas() + + def test_wrong_args(self): + from pyspark.sql.functions import udf, pandas_udf, sum + df = self.data + + with QuietTest(self.sc): + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(lambda x: x) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(udf(lambda x: x, DoubleType())) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(sum(df.v)) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(df.v + 1) + with self.assertRaisesRegexp(ValueError, 'returnType'): + df.groupby('id').apply(pandas_udf(lambda x: x, DoubleType())) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ebdc11c3b744a..f65273d5f0b6c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1597,7 +1597,7 @@ def convert(self, obj, gateway_client): register_input_converter(DateConverter()) -def toArrowType(dt): +def to_arrow_type(dt): """ Convert Spark data type to pyarrow type """ import pyarrow as pa diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 4e24789cf010d..eb6d48688dc0a 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -32,7 +32,7 @@ from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer -from pyspark.sql.types import toArrowType +from pyspark.sql.types import to_arrow_type, StructType from pyspark import shuffle pickleSer = PickleSerializer() @@ -74,17 +74,28 @@ def wrap_udf(f, return_type): def wrap_pandas_udf(f, return_type): - arrow_return_type = toArrowType(return_type) - - def verify_result_length(*a): - result = f(*a) - if not hasattr(result, "__len__"): - raise TypeError("Return type of pandas_udf should be a Pandas.Series") - if len(result) != len(a[0]): - raise RuntimeError("Result vector from pandas_udf was not the required length: " - "expected %d, got %d" % (len(a[0]), len(result))) - return result - return lambda *a: (verify_result_length(*a), arrow_return_type) + # If the return_type is a StructType, it indicates this is a groupby apply udf, + # and has already been wrapped under apply(), otherwise, it's a vectorized column udf. + # We can distinguish these two by return type because in groupby apply, we always specify + # returnType as a StructType, and in vectorized column udf, StructType is not supported. + # + # TODO: Look into refactoring use of StructType to be more flexible for future pandas_udfs + if isinstance(return_type, StructType): + return lambda *a: f(*a) + else: + arrow_return_type = to_arrow_type(return_type) + + def verify_result_length(*a): + result = f(*a) + if not hasattr(result, "__len__"): + raise TypeError("Return type of the user-defined functon should be " + "Pandas.Series, but is {}".format(type(result))) + if len(result) != len(a[0]): + raise RuntimeError("Result vector from pandas_udf was not the required length: " + "expected %d, got %d" % (len(a[0]), len(result))) + return result + + return lambda *a: (verify_result_length(*a), arrow_return_type) def read_single_udf(pickleSer, infile, eval_type): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index bc2d4a824cb49..d829e01441dcc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -452,6 +452,8 @@ object ColumnPruning extends Rule[LogicalPlan] { // Prunes the unused columns from child of Aggregate/Expand/Generate case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = prunedChild(child, a.references)) + case f @ FlatMapGroupsInPandas(_, _, _, child) if (child.outputSet -- f.references).nonEmpty => + f.copy(child = prunedChild(child, f.references)) case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => e.copy(child = prunedChild(child, e.references)) case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala new file mode 100644 index 0000000000000..8abab24bc9b44 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression} + +/** + * FlatMap groups using an udf: pandas.Dataframe -> pandas.DataFrame. + * This is used by DataFrame.groupby().apply(). + */ +case class FlatMapGroupsInPandas( + groupingAttributes: Seq[Attribute], + functionExpr: Expression, + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + /** + * This is needed because output attributes are considered `references` when + * passed through the constructor. + * + * Without this, catalyst will complain that output attributes are missing + * from the input. + */ + override val producedAttributes = AttributeSet(output) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 147b549964913..cd0ac1feffa51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -27,12 +27,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, FlatMapGroupsInR, Pivot} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.execution.python.PythonUDF import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.NumericType -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{NumericType, StructField, StructType} /** * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], @@ -435,6 +435,36 @@ class RelationalGroupedDataset protected[sql]( df.logicalPlan.output, df.logicalPlan)) } + + /** + * Applies a vectorized python user-defined function to each group of data. + * The user-defined function defines a transformation: `pandas.DataFrame` -> `pandas.DataFrame`. + * For each group, all elements in the group are passed as a `pandas.DataFrame` and the results + * for all groups are combined into a new [[DataFrame]]. + * + * This function does not support partial aggregation, and requires shuffling all the data in + * the [[DataFrame]]. + * + * This function uses Apache Arrow as serialization format between Java executors and Python + * workers. + */ + private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { + require(expr.vectorized, "Must pass a vectorized python udf") + require(expr.dataType.isInstanceOf[StructType], + "The returnType of the vectorized python udf must be a StructType") + + val groupingNamedExpressions = groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + val groupingAttributes = groupingNamedExpressions.map(_.toAttribute) + val child = df.logicalPlan + val project = Project(groupingNamedExpressions ++ child.output, child) + val output = expr.dataType.asInstanceOf[StructType].toAttributes + val plan = FlatMapGroupsInPandas(groupingAttributes, expr, output, project) + + Dataset.ofRows(df.sparkSession, plan) + } } private[sql] object RelationalGroupedDataset { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 92eaab5cd8f81..4cdcc73faacd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -392,6 +392,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) => execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping, data, objAttr, planLater(child)) :: Nil + case logical.FlatMapGroupsInPandas(grouping, func, output, child) => + execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil case logical.MapElements(f, _, _, objAttr, child) => execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil case logical.AppendColumns(f, _, _, in, out, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index f7e8cbe416121..81896187ecc46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -26,6 +26,35 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.types.StructType +/** + * Grouped a iterator into batches. + * This is similar to iter.grouped but returns Iterator[T] instead of Seq[T]. + * This is necessary because sometimes we cannot hold reference of input rows + * because the some input rows are mutable and can be reused. + */ +private class BatchIterator[T](iter: Iterator[T], batchSize: Int) + extends Iterator[Iterator[T]] { + + override def hasNext: Boolean = iter.hasNext + + override def next(): Iterator[T] = { + new Iterator[T] { + var count = 0 + + override def hasNext: Boolean = iter.hasNext && count < batchSize + + override def next(): T = { + if (!hasNext) { + Iterator.empty.next() + } else { + count += 1 + iter.next() + } + } + } + } +} + /** * A physical plan that evaluates a [[PythonUDF]], */ @@ -44,14 +73,18 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex .map { case (attr, i) => attr.withName(s"_$i") }) + val batchSize = conf.arrowMaxRecordsPerBatch + // DO NOT use iter.grouped(). See BatchIterator. + val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) + val columnarBatchIter = new ArrowPythonRunner( - funcs, conf.arrowMaxRecordsPerBatch, bufferSize, reuseWorker, + funcs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) - .compute(iter, context.partitionId(), context) + .compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { - var currentIter = if (columnarBatchIter.hasNext) { + private var currentIter = if (columnarBatchIter.hasNext) { val batch = columnarBatchIter.next() assert(schemaOut.equals(batch.schema), s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index bbad9d6b631fd..f6c03c415dc66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -39,19 +39,18 @@ import org.apache.spark.util.Utils */ class ArrowPythonRunner( funcs: Seq[ChainedPythonFunctions], - batchSize: Int, bufferSize: Int, reuseWorker: Boolean, evalType: Int, argOffsets: Array[Array[Int]], schema: StructType) - extends BasePythonRunner[InternalRow, ColumnarBatch]( + extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( funcs, bufferSize, reuseWorker, evalType, argOffsets) { protected override def newWriterThread( env: SparkEnv, worker: Socket, - inputIterator: Iterator[InternalRow], + inputIterator: Iterator[Iterator[InternalRow]], partitionIndex: Int, context: TaskContext): WriterThread = { new WriterThread(env, worker, inputIterator, partitionIndex, context) { @@ -82,12 +81,12 @@ class ArrowPythonRunner( Utils.tryWithSafeFinally { while (inputIterator.hasNext) { - var rowCount = 0 - while (inputIterator.hasNext && (batchSize <= 0 || rowCount < batchSize)) { - val row = inputIterator.next() - arrowWriter.write(row) - rowCount += 1 + val nextBatch = inputIterator.next() + + while (nextBatch.hasNext) { + arrowWriter.write(nextBatch.next()) } + arrowWriter.finish() writer.writeBatch() arrowWriter.reset() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index fec456d86dbe2..e3f952e221d53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution -import org.apache.spark.sql.execution.{FilterExec, SparkPlan} +import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} /** @@ -111,6 +110,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } def apply(plan: SparkPlan): SparkPlan = plan transformUp { + // FlatMapGroupsInPandas can be evaluated directly in python worker + // Therefore we don't need to extract the UDFs + case plan: FlatMapGroupsInPandasExec => plan case plan: SparkPlan => extract(plan) } @@ -169,7 +171,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { val newPlan = extract(rewritten) if (newPlan.output != plan.output) { // Trim away the new UDF value if it was only used for filtering or something. - execution.ProjectExec(plan.output, newPlan) + ProjectExec(plan.output, newPlan) } else { newPlan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala new file mode 100644 index 0000000000000..b996b5bb38ba5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.StructType + +/** + * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]] + * + * Rows in each group are passed to the Python worker as an Arrow record batch. + * The Python worker turns the record batch to a `pandas.DataFrame`, invoke the + * user-defined function, and passes the resulting `pandas.DataFrame` + * as an Arrow record batch. Finally, each record batch is turned to + * Iterator[InternalRow] using ColumnarBatch. + * + * Note on memory usage: + * Both the Python worker and the Java executor need to have enough memory to + * hold the largest group. The memory on the Java side is used to construct the + * record batch (off heap memory). The memory on the Python side is used for + * holding the `pandas.DataFrame`. It's possible to further split one group into + * multiple record batches to reduce the memory footprint on the Java side, this + * is left as future work. + */ +case class FlatMapGroupsInPandasExec( + groupingAttributes: Seq[Attribute], + func: Expression, + output: Seq[Attribute], + child: SparkPlan) + extends UnaryExecNode { + + private val pandasFunction = func.asInstanceOf[PythonUDF].func + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def requiredChildDistribution: Seq[Distribution] = { + if (groupingAttributes.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingAttributes) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute() + + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) + val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray) + val schema = StructType(child.schema.drop(groupingAttributes.length)) + + inputRDD.mapPartitionsInternal { iter => + val grouped = if (groupingAttributes.isEmpty) { + Iterator(iter) + } else { + val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) + val dropGrouping = + UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output) + groupedIter.map { + case (_, groupedRowIter) => groupedRowIter.map(dropGrouping) + } + } + + val context = TaskContext.get() + + val columnarBatchIter = new ArrowPythonRunner( + chainedFunc, bufferSize, reuseWorker, + PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) + .compute(grouped, context.partitionId(), context) + + columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) + } + } +} From bd4eb9ce57da7bacff69d9ed958c94f349b7e6fb Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 10 Oct 2017 15:50:37 -0700 Subject: [PATCH 693/779] [SPARK-19558][SQL] Add config key to register QueryExecutionListeners automatically. This change adds a new SQL config key that is equivalent to SparkContext's "spark.extraListeners", allowing users to register QueryExecutionListener instances through the Spark configuration system instead of having to explicitly do it in code. The code used by SparkContext to implement the feature was refactored into a helper method in the Utils class, and SQL's ExecutionListenerManager was modified to use it to initialize listener declared in the configuration. Unit tests were added to verify all the new functionality. Author: Marcelo Vanzin Closes #19309 from vanzin/SPARK-19558. --- .../scala/org/apache/spark/SparkContext.scala | 38 ++-------- .../spark/internal/config/package.scala | 7 ++ .../scala/org/apache/spark/util/Utils.scala | 57 ++++++++++++++- .../spark/scheduler/SparkListenerSuite.scala | 6 +- .../org/apache/spark/util/UtilsSuite.scala | 56 ++++++++++++++- .../spark/sql/internal/StaticSQLConf.scala | 8 +++ .../internal/BaseSessionStateBuilder.scala | 3 +- .../sql/util/QueryExecutionListener.scala | 12 +++- .../util/ExecutionListenerManagerSuite.scala | 69 +++++++++++++++++++ 9 files changed, 216 insertions(+), 40 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b3cd03c0cfbe1..6f25d346e6e54 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2344,41 +2344,13 @@ class SparkContext(config: SparkConf) extends Logging { * (e.g. after the web UI and event logging listeners have been registered). */ private def setupAndStartListenerBus(): Unit = { - // Use reflection to instantiate listeners specified via `spark.extraListeners` try { - val listenerClassNames: Seq[String] = - conf.get("spark.extraListeners", "").split(',').map(_.trim).filter(_ != "") - for (className <- listenerClassNames) { - // Use reflection to find the right constructor - val constructors = { - val listenerClass = Utils.classForName(className) - listenerClass - .getConstructors - .asInstanceOf[Array[Constructor[_ <: SparkListenerInterface]]] + conf.get(EXTRA_LISTENERS).foreach { classNames => + val listeners = Utils.loadExtensions(classOf[SparkListenerInterface], classNames, conf) + listeners.foreach { listener => + listenerBus.addToSharedQueue(listener) + logInfo(s"Registered listener ${listener.getClass().getName()}") } - val constructorTakingSparkConf = constructors.find { c => - c.getParameterTypes.sameElements(Array(classOf[SparkConf])) - } - lazy val zeroArgumentConstructor = constructors.find { c => - c.getParameterTypes.isEmpty - } - val listener: SparkListenerInterface = { - if (constructorTakingSparkConf.isDefined) { - constructorTakingSparkConf.get.newInstance(conf) - } else if (zeroArgumentConstructor.isDefined) { - zeroArgumentConstructor.get.newInstance() - } else { - throw new SparkException( - s"$className did not have a zero-argument constructor or a" + - " single-argument constructor that accepts SparkConf. Note: if the class is" + - " defined inside of another Scala class, then its constructors may accept an" + - " implicit parameter that references the enclosing class; in this case, you must" + - " define the listener as a top-level class in order to prevent this extra" + - " parameter from breaking Spark's ability to find a valid constructor.") - } - } - listenerBus.addToSharedQueue(listener) - logInfo(s"Registered listener $className") } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 5278e5e0fb270..19336f854145f 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -419,4 +419,11 @@ package object config { .stringConf .toSequence .createWithDefault(Nil) + + private[spark] val EXTRA_LISTENERS = ConfigBuilder("spark.extraListeners") + .doc("Class names of listeners to add to SparkContext during initialization.") + .stringConf + .toSequence + .createOptional + } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 836e33c36d9a1..930e09d90c2f5 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.io._ import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} +import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} import java.net._ import java.nio.ByteBuffer @@ -37,7 +38,7 @@ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source import scala.reflect.ClassTag -import scala.util.Try +import scala.util.{Failure, Success, Try} import scala.util.control.{ControlThrowable, NonFatal} import scala.util.matching.Regex @@ -2687,6 +2688,60 @@ private[spark] object Utils extends Logging { def stringToSeq(str: String): Seq[String] = { str.split(",").map(_.trim()).filter(_.nonEmpty) } + + /** + * Create instances of extension classes. + * + * The classes in the given list must: + * - Be sub-classes of the given base class. + * - Provide either a no-arg constructor, or a 1-arg constructor that takes a SparkConf. + * + * The constructors are allowed to throw "UnsupportedOperationException" if the extension does not + * want to be registered; this allows the implementations to check the Spark configuration (or + * other state) and decide they do not need to be added. A log message is printed in that case. + * Other exceptions are bubbled up. + */ + def loadExtensions[T](extClass: Class[T], classes: Seq[String], conf: SparkConf): Seq[T] = { + classes.flatMap { name => + try { + val klass = classForName(name) + require(extClass.isAssignableFrom(klass), + s"$name is not a subclass of ${extClass.getName()}.") + + val ext = Try(klass.getConstructor(classOf[SparkConf])) match { + case Success(ctor) => + ctor.newInstance(conf) + + case Failure(_) => + klass.getConstructor().newInstance() + } + + Some(ext.asInstanceOf[T]) + } catch { + case _: NoSuchMethodException => + throw new SparkException( + s"$name did not have a zero-argument constructor or a" + + " single-argument constructor that accepts SparkConf. Note: if the class is" + + " defined inside of another Scala class, then its constructors may accept an" + + " implicit parameter that references the enclosing class; in this case, you must" + + " define the class as a top-level class in order to prevent this extra" + + " parameter from breaking Spark's ability to find a valid constructor.") + + case e: InvocationTargetException => + e.getCause() match { + case uoe: UnsupportedOperationException => + logDebug(s"Extension $name not being initialized.", uoe) + logInfo(s"Extension $name not being initialized.") + None + + case null => throw e + + case cause => throw cause + } + } + } + } + } private[util] object CallerContext extends Logging { diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index d061c7845f4a6..1beb36afa95f0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.Matchers import org.apache.spark._ import org.apache.spark.executor.TaskMetrics -import org.apache.spark.internal.config.LISTENER_BUS_EVENT_QUEUE_CAPACITY +import org.apache.spark.internal.config._ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{ResetSystemProperties, RpcUtils} @@ -446,13 +446,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match classOf[FirehoseListenerThatAcceptsSparkConf], classOf[BasicJobCounter]) val conf = new SparkConf().setMaster("local").setAppName("test") - .set("spark.extraListeners", listeners.map(_.getName).mkString(",")) + .set(EXTRA_LISTENERS, listeners.map(_.getName)) sc = new SparkContext(conf) sc.listenerBus.listeners.asScala.count(_.isInstanceOf[BasicJobCounter]) should be (1) sc.listenerBus.listeners.asScala .count(_.isInstanceOf[ListenerThatAcceptsSparkConf]) should be (1) sc.listenerBus.listeners.asScala - .count(_.isInstanceOf[FirehoseListenerThatAcceptsSparkConf]) should be (1) + .count(_.isInstanceOf[FirehoseListenerThatAcceptsSparkConf]) should be (1) } test("add and remove listeners to/from LiveListenerBus queues") { diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 2b16cc4852ba8..4d3adeb968e84 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -38,9 +38,10 @@ import org.apache.commons.math3.stat.inference.ChiSquareTest import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit +import org.apache.spark.scheduler.SparkListener class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { @@ -1110,4 +1111,57 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { Utils.tryWithSafeFinallyAndFailureCallbacks {}(catchBlock = {}, finallyBlock = {}) TaskContext.unset } + + test("load extensions") { + val extensions = Seq( + classOf[SimpleExtension], + classOf[ExtensionWithConf], + classOf[UnregisterableExtension]).map(_.getName()) + + val conf = new SparkConf(false) + val instances = Utils.loadExtensions(classOf[Object], extensions, conf) + assert(instances.size === 2) + assert(instances.count(_.isInstanceOf[SimpleExtension]) === 1) + + val extWithConf = instances.find(_.isInstanceOf[ExtensionWithConf]) + .map(_.asInstanceOf[ExtensionWithConf]) + .get + assert(extWithConf.conf eq conf) + + class NestedExtension { } + + val invalid = Seq(classOf[NestedExtension].getName()) + intercept[SparkException] { + Utils.loadExtensions(classOf[Object], invalid, conf) + } + + val error = Seq(classOf[ExtensionWithError].getName()) + intercept[IllegalArgumentException] { + Utils.loadExtensions(classOf[Object], error, conf) + } + + val wrongType = Seq(classOf[ListenerImpl].getName()) + intercept[IllegalArgumentException] { + Utils.loadExtensions(classOf[Seq[_]], wrongType, conf) + } + } + +} + +private class SimpleExtension + +private class ExtensionWithConf(val conf: SparkConf) + +private class UnregisterableExtension { + + throw new UnsupportedOperationException() + +} + +private class ExtensionWithError { + + throw new IllegalArgumentException() + } + +private class ListenerImpl extends SparkListener diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index c6c0a605d89ff..c018fc8a332fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -87,4 +87,12 @@ object StaticSQLConf { "implement Function1[SparkSessionExtension, Unit], and must have a no-args constructor.") .stringConf .createOptional + + val QUERY_EXECUTION_LISTENERS = buildStaticConf("spark.sql.queryExecutionListeners") + .doc("List of class names implementing QueryExecutionListener that will be automatically " + + "added to newly created sessions. The classes should have either a no-arg constructor, " + + "or a constructor that expects a SparkConf argument.") + .stringConf + .toSequence + .createOptional } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 4e756084bbdbb..2867b4cd7da5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -266,7 +266,8 @@ abstract class BaseSessionStateBuilder( * This gets cloned from parent if available, otherwise is a new instance is created. */ protected def listenerManager: ExecutionListenerManager = { - parentState.map(_.listenerManager.clone()).getOrElse(new ExecutionListenerManager) + parentState.map(_.listenerManager.clone()).getOrElse( + new ExecutionListenerManager(session.sparkContext.conf)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index f6240d85fba6f..2b46233e1a5df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -22,9 +22,12 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import scala.collection.mutable.ListBuffer import scala.util.control.NonFatal +import org.apache.spark.SparkConf import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.internal.StaticSQLConf._ +import org.apache.spark.util.Utils /** * :: Experimental :: @@ -72,7 +75,14 @@ trait QueryExecutionListener { */ @Experimental @InterfaceStability.Evolving -class ExecutionListenerManager private[sql] () extends Logging { +class ExecutionListenerManager private extends Logging { + + private[sql] def this(conf: SparkConf) = { + this() + conf.get(QUERY_EXECUTION_LISTENERS).foreach { classNames => + Utils.loadExtensions(classOf[QueryExecutionListener], classNames, conf).foreach(register) + } + } /** * Registers the specified [[QueryExecutionListener]]. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala new file mode 100644 index 0000000000000..4205e23ae240a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark._ +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.internal.StaticSQLConf._ + +class ExecutionListenerManagerSuite extends SparkFunSuite { + + import CountingQueryExecutionListener._ + + test("register query execution listeners using configuration") { + val conf = new SparkConf(false) + .set(QUERY_EXECUTION_LISTENERS, Seq(classOf[CountingQueryExecutionListener].getName())) + + val mgr = new ExecutionListenerManager(conf) + assert(INSTANCE_COUNT.get() === 1) + mgr.onSuccess(null, null, 42L) + assert(CALLBACK_COUNT.get() === 1) + + val clone = mgr.clone() + assert(INSTANCE_COUNT.get() === 1) + + clone.onSuccess(null, null, 42L) + assert(CALLBACK_COUNT.get() === 2) + } + +} + +private class CountingQueryExecutionListener extends QueryExecutionListener { + + import CountingQueryExecutionListener._ + + INSTANCE_COUNT.incrementAndGet() + + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + CALLBACK_COUNT.incrementAndGet() + } + + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + CALLBACK_COUNT.incrementAndGet() + } + +} + +private object CountingQueryExecutionListener { + + val CALLBACK_COUNT = new AtomicInteger() + val INSTANCE_COUNT = new AtomicInteger() + +} From 76fb173dd639baa9534486488155fc05a71f850e Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 10 Oct 2017 20:29:02 -0700 Subject: [PATCH 694/779] [SPARK-21751][SQL] CodeGeneraor.splitExpressions counts code size more precisely ## What changes were proposed in this pull request? Current `CodeGeneraor.splitExpressions` splits statements into methods if the total length of statements is more than 1024 characters. The length may include comments or empty line. This PR excludes comment or empty line from the length to reduce the number of generated methods in a class, by using `CodeFormatter.stripExtraNewLinesAndComments()` method. ## How was this patch tested? Existing tests Author: Kazuaki Ishizaki Closes #18966 from kiszk/SPARK-21751. --- .../expressions/codegen/CodeFormatter.scala | 8 +++++ .../expressions/codegen/CodeGenerator.scala | 5 ++- .../codegen/CodeFormatterSuite.scala | 32 +++++++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index 60e600d8dbd8f..7b398f424cead 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -89,6 +89,14 @@ object CodeFormatter { } new CodeAndComment(code.result().trim(), map) } + + def stripExtraNewLinesAndComments(input: String): String = { + val commentReg = + ("""([ |\t]*?\/\*[\s|\S]*?\*\/[ |\t]*?)|""" + // strip /*comment*/ + """([ |\t]*?\/\/[\s\S]*?\n)""").r // strip //comment + val codeWithoutComment = commentReg.replaceAllIn(input, "") + codeWithoutComment.replaceAll("""\n\s*\n""", "\n") // strip ExtraNewLines + } } private class CodeFormatter { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f9c5ef8439085..2cb66599076a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -772,16 +772,19 @@ class CodegenContext { foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = { val blocks = new ArrayBuffer[String]() val blockBuilder = new StringBuilder() + var length = 0 for (code <- expressions) { // We can't know how many bytecode will be generated, so use the length of source code // as metric. A method should not go beyond 8K, otherwise it will not be JITted, should // also not be too small, or it will have many function calls (for wide table), see the // results in BenchmarkWideTable. - if (blockBuilder.length > 1024) { + if (length > 1024) { blocks += blockBuilder.toString() blockBuilder.clear() + length = 0 } blockBuilder.append(code) + length += CodeFormatter.stripExtraNewLinesAndComments(code).length } blocks += blockBuilder.toString() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala index 9d0a41661beaa..a0f1a64b0ab08 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala @@ -53,6 +53,38 @@ class CodeFormatterSuite extends SparkFunSuite { assert(reducedCode.body === "/*project_c4*/") } + test("removing extra new lines and comments") { + val code = + """ + |/* + | * multi + | * line + | * comments + | */ + | + |public function() { + |/*comment*/ + | /*comment_with_space*/ + |code_body + |//comment + |code_body + | //comment_with_space + | + |code_body + |} + """.stripMargin + + val reducedCode = CodeFormatter.stripExtraNewLinesAndComments(code) + assert(reducedCode === + """ + |public function() { + |code_body + |code_body + |code_body + |} + """.stripMargin) + } + testCase("basic example") { """ |class A { From 655f6f86f84ff5241d1d20766e1ef83bb32ca5e0 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Wed, 11 Oct 2017 00:16:12 -0700 Subject: [PATCH 695/779] [SPARK-22208][SQL] Improve percentile_approx by not rounding up targetError and starting from index 0 ## What changes were proposed in this pull request? Currently percentile_approx never returns the first element when percentile is in (relativeError, 1/N], where relativeError default 1/10000, and N is the total number of elements. But ideally, percentiles in [0, 1/N] should all return the first element as the answer. For example, given input data 1 to 10, if a user queries 10% (or even less) percentile, it should return 1, because the first value 1 already reaches 10%. Currently it returns 2. Based on the paper, targetError is not rounded up, and searching index should start from 0 instead of 1. By following the paper, we should be able to fix the cases mentioned above. ## How was this patch tested? Added a new test case and fix existing test cases. Author: Zhenhua Wang Closes #19438 from wzhfy/improve_percentile_approx. --- R/pkg/tests/fulltests/test_sparkSQL.R | 8 ++++---- .../apache/spark/ml/feature/ImputerSuite.scala | 2 +- python/pyspark/sql/dataframe.py | 6 +++--- .../sql/catalyst/util/QuantileSummaries.scala | 4 ++-- .../catalyst/util/QuantileSummariesSuite.scala | 10 ++++++++-- .../sql/ApproximatePercentileQuerySuite.scala | 17 ++++++++++++++++- .../apache/spark/sql/DataFrameStatSuite.scala | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- 8 files changed, 36 insertions(+), 15 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index bbea25bc4da5c..4382ef2ed4525 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2538,7 +2538,7 @@ test_that("describe() and summary() on a DataFrame", { stats2 <- summary(df) expect_equal(collect(stats2)[5, "summary"], "25%") - expect_equal(collect(stats2)[5, "age"], "30") + expect_equal(collect(stats2)[5, "age"], "19") stats3 <- summary(df, "min", "max", "55.1%") @@ -2738,7 +2738,7 @@ test_that("sampleBy() on a DataFrame", { }) test_that("approxQuantile() on a DataFrame", { - l <- lapply(c(0:99), function(i) { list(i, 99 - i) }) + l <- lapply(c(0:100), function(i) { list(i, 100 - i) }) df <- createDataFrame(l, list("a", "b")) quantiles <- approxQuantile(df, "a", c(0.5, 0.8), 0.0) expect_equal(quantiles, list(50, 80)) @@ -2749,8 +2749,8 @@ test_that("approxQuantile() on a DataFrame", { dfWithNA <- createDataFrame(data.frame(a = c(NA, 30, 19, 11, 28, 15), b = c(-30, -19, NA, -11, -28, -15))) quantiles3 <- approxQuantile(dfWithNA, c("a", "b"), c(0.5), 0.0) - expect_equal(quantiles3[[1]], list(28)) - expect_equal(quantiles3[[2]], list(-15)) + expect_equal(quantiles3[[1]], list(19)) + expect_equal(quantiles3[[2]], list(-19)) }) test_that("SQL error message is returned from JVM", { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index ee2ba73fa96d5..c08b35b419266 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -43,7 +43,7 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default (0, 1.0, 1.0, 1.0), (1, 3.0, 3.0, 3.0), (2, Double.NaN, Double.NaN, Double.NaN), - (3, -1.0, 2.0, 3.0) + (3, -1.0, 2.0, 1.0) )).toDF("id", "value", "expected_mean_value", "expected_median_value") val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) .setMissingValue(-1.0) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 2d596229ced7e..38b01f0011671 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1038,8 +1038,8 @@ def summary(self, *statistics): | mean| 3.5| null| | stddev|2.1213203435596424| null| | min| 2|Alice| - | 25%| 5| null| - | 50%| 5| null| + | 25%| 2| null| + | 50%| 2| null| | 75%| 5| null| | max| 5| Bob| +-------+------------------+-----+ @@ -1050,7 +1050,7 @@ def summary(self, *statistics): +-------+---+-----+ | count| 2| 2| | min| 2|Alice| - | 25%| 5| null| + | 25%| 2| null| | 75%| 5| null| | max| 5| Bob| +-------+---+-----+ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index af543b04ba780..eb7941cf9e6af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -193,10 +193,10 @@ class QuantileSummaries( // Target rank val rank = math.ceil(quantile * count).toInt - val targetError = math.ceil(relativeError * count) + val targetError = relativeError * count // Minimum rank at current sample var minRank = 0 - var i = 1 + var i = 0 while (i < sampled.length - 1) { val curSample = sampled(i) minRank += curSample.g diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala index df579d5ec1ddf..650813975d75c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala @@ -57,8 +57,14 @@ class QuantileSummariesSuite extends SparkFunSuite { private def checkQuantile(quant: Double, data: Seq[Double], summary: QuantileSummaries): Unit = { if (data.nonEmpty) { val approx = summary.query(quant).get - // The rank of the approximation. - val rank = data.count(_ < approx) // has to be <, not <= to be exact + // Get the rank of the approximation. + val rankOfValue = data.count(_ <= approx) + val rankOfPreValue = data.count(_ < approx) + // `rankOfValue` is the last position of the quantile value. If the input repeats the value + // chosen as the quantile, e.g. in (1,2,2,2,2,2,3), the 50% quantile is 2, then it's + // improper to choose the last position as its rank. Instead, we get the rank by averaging + // `rankOfValue` and `rankOfPreValue`. + val rank = math.ceil((rankOfValue + rankOfPreValue) / 2.0) val lower = math.floor((quant - summary.relativeError) * data.size) val upper = math.ceil((quant + summary.relativeError) * data.size) val msg = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala index 1aea33766407f..137c5bea2abb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -53,6 +53,21 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { } } + test("percentile_approx, the first element satisfies small percentages") { + withTempView(table) { + (1 to 10).toDF("col").createOrReplaceTempView(table) + checkAnswer( + spark.sql( + s""" + |SELECT + | percentile_approx(col, array(0.01, 0.1, 0.11)) + |FROM $table + """.stripMargin), + Row(Seq(1, 1, 2)) + ) + } + } + test("percentile_approx, array of percentile value") { withTempView(table) { (1 to 1000).toDF("col").createOrReplaceTempView(table) @@ -130,7 +145,7 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { (1 to 1000).toDF("col").createOrReplaceTempView(table) checkAnswer( spark.sql(s"SELECT percentile_approx(col, array(0.25 + 0.25D), 200 + 800D) FROM $table"), - Row(Seq(500D)) + Row(Seq(499)) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 247c30e2ee65b..46b21c3b64a2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -141,7 +141,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { test("approximate quantile") { val n = 1000 - val df = Seq.tabulate(n)(i => (i, 2.0 * i)).toDF("singles", "doubles") + val df = Seq.tabulate(n + 1)(i => (i, 2.0 * i)).toDF("singles", "doubles") val q1 = 0.5 val q2 = 0.8 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index dd8f54b690f64..ad461fa6144b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -855,7 +855,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row("mean", null, "33.0", "178.0"), Row("stddev", null, "19.148542155126762", "11.547005383792516"), Row("min", "Alice", "16", "164"), - Row("25%", null, "24", "176"), + Row("25%", null, "16", "164"), Row("50%", null, "24", "176"), Row("75%", null, "32", "180"), Row("max", "David", "60", "192")) From 645e108eeb6364e57f5d7213dbbd42dbcf1124d3 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 11 Oct 2017 13:51:33 -0700 Subject: [PATCH 696/779] [SPARK-21988][SS] Implement StreamingRelation.computeStats to fix explain ## What changes were proposed in this pull request? Implement StreamingRelation.computeStats to fix explain ## How was this patch tested? - unit tests: `StreamingRelation.computeStats` and `StreamingExecutionRelation.computeStats`. - regression tests: `explain join with a normal source` and `explain join with MemoryStream`. Author: Shixiong Zhu Closes #19465 from zsxwing/SPARK-21988. --- .../streaming/StreamingRelation.scala | 8 +++ .../spark/sql/streaming/StreamSuite.scala | 65 ++++++++++++++++--- 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index ab716052c28ba..6b82c78ea653d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -44,6 +44,14 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: extends LeafNode { override def isStreaming: Boolean = true override def toString: String = sourceName + + // There's no sensible value here. On the execution path, this relation will be + // swapped out with microbatches. But some dataframe operations (in particular explain) do lead + // to this node surviving analysis. So we satisfy the LeafNode contract with the session default + // value. + override def computeStats(): Statistics = Statistics( + sizeInBytes = BigInt(dataSource.sparkSession.sessionState.conf.defaultSizeInBytes) + ) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 9c901062d570a..3d687d2214e90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -76,20 +76,65 @@ class StreamSuite extends StreamTest { CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two"), Row(4, 4, "four"))) } + test("StreamingRelation.computeStats") { + val streamingRelation = spark.readStream.format("rate").load().logicalPlan collect { + case s: StreamingRelation => s + } + assert(streamingRelation.nonEmpty, "cannot find StreamingRelation") + assert( + streamingRelation.head.computeStats.sizeInBytes == spark.sessionState.conf.defaultSizeInBytes) + } - test("explain join") { - // Make a table and ensure it will be broadcast. - val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + test("StreamingExecutionRelation.computeStats") { + val streamingExecutionRelation = MemoryStream[Int].toDF.logicalPlan collect { + case s: StreamingExecutionRelation => s + } + assert(streamingExecutionRelation.nonEmpty, "cannot find StreamingExecutionRelation") + assert(streamingExecutionRelation.head.computeStats.sizeInBytes + == spark.sessionState.conf.defaultSizeInBytes) + } - // Join the input stream with a table. - val inputData = MemoryStream[Int] - val joined = inputData.toDF().join(smallTable, smallTable("number") === $"value") + test("explain join with a normal source") { + // This test triggers CostBasedJoinReorder to call `computeStats` + withSQLConf(SQLConf.CBO_ENABLED.key -> "true", SQLConf.JOIN_REORDER_ENABLED.key -> "true") { + val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable2 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable3 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + + // Join the input stream with a table. + val df = spark.readStream.format("rate").load() + val joined = df.join(smallTable, smallTable("number") === $"value") + .join(smallTable2, smallTable2("number") === $"value") + .join(smallTable3, smallTable3("number") === $"value") + + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + joined.explain(true) + } + assert(outputStream.toString.contains("StreamingRelation")) + } + } - val outputStream = new java.io.ByteArrayOutputStream() - Console.withOut(outputStream) { - joined.explain() + test("explain join with MemoryStream") { + // This test triggers CostBasedJoinReorder to call `computeStats` + // Because MemoryStream doesn't use DataSource code path, we need a separate test. + withSQLConf(SQLConf.CBO_ENABLED.key -> "true", SQLConf.JOIN_REORDER_ENABLED.key -> "true") { + val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable2 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable3 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + + // Join the input stream with a table. + val df = MemoryStream[Int].toDF + val joined = df.join(smallTable, smallTable("number") === $"value") + .join(smallTable2, smallTable2("number") === $"value") + .join(smallTable3, smallTable3("number") === $"value") + + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + joined.explain(true) + } + assert(outputStream.toString.contains("StreamingRelation")) } - assert(outputStream.toString.contains("StreamingRelation")) } test("SPARK-20432: union one stream with itself") { From ccdf21f56e4ff5497d7770dcbee2f7a60bb9e3a7 Mon Sep 17 00:00:00 2001 From: Jorge Machado Date: Wed, 11 Oct 2017 22:13:07 -0700 Subject: [PATCH 697/779] [SPARK-20055][DOCS] Added documentation for loading csv files into DataFrames ## What changes were proposed in this pull request? Added documentation for loading csv files into Dataframes ## How was this patch tested? /dev/run-tests Author: Jorge Machado Closes #19429 from jomach/master. --- docs/sql-programming-guide.md | 32 ++++++++++++++++--- .../sql/JavaSQLDataSourceExample.java | 7 ++++ examples/src/main/python/sql/datasource.py | 5 +++ examples/src/main/r/RSparkSQLExample.R | 6 ++++ examples/src/main/resources/people.csv | 3 ++ .../examples/sql/SQLDataSourceExample.scala | 8 +++++ 6 files changed, 56 insertions(+), 5 deletions(-) create mode 100644 examples/src/main/resources/people.csv diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a095263bfa619..639a8ea7bb8ad 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -461,6 +461,8 @@ name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can al names (`json`, `parquet`, `jdbc`, `orc`, `libsvm`, `csv`, `text`). DataFrames loaded from any data source type can be converted into other types using this syntax. +To load a JSON file you can use: +
{% include_example manual_load_options scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} @@ -479,6 +481,26 @@ source type can be converted into other types using this syntax.
+To load a CSV file you can use: + +
+
+{% include_example manual_load_options_csv scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} +
+ +
+{% include_example manual_load_options_csv java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %} +
+ +
+{% include_example manual_load_options_csv python/sql/datasource.py %} +
+ +
+{% include_example manual_load_options_csv r/RSparkSQLExample.R %} + +
+
### Run SQL on files directly Instead of using read API to load a file into DataFrame and query it, you can also query that @@ -573,7 +595,7 @@ Note that partition information is not gathered by default when creating externa ### Bucketing, Sorting and Partitioning -For file-based data source, it is also possible to bucket and sort or partition the output. +For file-based data source, it is also possible to bucket and sort or partition the output. Bucketing and sorting are applicable only to persistent tables:
@@ -598,7 +620,7 @@ CREATE TABLE users_bucketed_by_name( name STRING, favorite_color STRING, favorite_numbers array -) USING parquet +) USING parquet CLUSTERED BY(name) INTO 42 BUCKETS; {% endhighlight %} @@ -629,7 +651,7 @@ while partitioning can be used with both `save` and `saveAsTable` when using the {% highlight sql %} CREATE TABLE users_by_favorite_color( - name STRING, + name STRING, favorite_color STRING, favorite_numbers array ) USING csv PARTITIONED BY(favorite_color); @@ -664,7 +686,7 @@ CREATE TABLE users_bucketed_and_partitioned( name STRING, favorite_color STRING, favorite_numbers array -) USING parquet +) USING parquet PARTITIONED BY (favorite_color) CLUSTERED BY(name) SORTED BY (favorite_numbers) INTO 42 BUCKETS; @@ -675,7 +697,7 @@ CLUSTERED BY(name) SORTED BY (favorite_numbers) INTO 42 BUCKETS;
`partitionBy` creates a directory structure as described in the [Partition Discovery](#partition-discovery) section. -Thus, it has limited applicability to columns with high cardinality. In contrast +Thus, it has limited applicability to columns with high cardinality. In contrast `bucketBy` distributes data across a fixed number of buckets and can be used when a number of unique values is unbounded. diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index 95859c52c2aeb..ef3c904775697 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -116,6 +116,13 @@ private static void runBasicDataSourceExample(SparkSession spark) { spark.read().format("json").load("examples/src/main/resources/people.json"); peopleDF.select("name", "age").write().format("parquet").save("namesAndAges.parquet"); // $example off:manual_load_options$ + // $example on:manual_load_options_csv$ + Dataset peopleDFCsv = spark.read().format("csv") + .option("sep", ";") + .option("inferSchema", "true") + .option("header", "true") + .load("examples/src/main/resources/people.csv"); + // $example off:manual_load_options_csv$ // $example on:direct_sql$ Dataset sqlDF = spark.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`"); diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index f86012ea382e8..b375fa775de39 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -53,6 +53,11 @@ def basic_datasource_example(spark): df.select("name", "age").write.save("namesAndAges.parquet", format="parquet") # $example off:manual_load_options$ + # $example on:manual_load_options_csv$ + df = spark.read.load("examples/src/main/resources/people.csv", + format="csv", sep=":", inferSchema="true", header="true") + # $example off:manual_load_options_csv$ + # $example on:write_sorting_and_bucketing$ df.write.bucketBy(42, "name").sortBy("age").saveAsTable("people_bucketed") # $example off:write_sorting_and_bucketing$ diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R index 3734568d872d0..a5ed723da47ca 100644 --- a/examples/src/main/r/RSparkSQLExample.R +++ b/examples/src/main/r/RSparkSQLExample.R @@ -113,6 +113,12 @@ write.df(namesAndAges, "namesAndAges.parquet", "parquet") # $example off:manual_load_options$ +# $example on:manual_load_options_csv$ +df <- read.df("examples/src/main/resources/people.csv", "csv") +namesAndAges <- select(df, "name", "age") +# $example off:manual_load_options_csv$ + + # $example on:direct_sql$ df <- sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") # $example off:direct_sql$ diff --git a/examples/src/main/resources/people.csv b/examples/src/main/resources/people.csv new file mode 100644 index 0000000000000..7fe5adba93d77 --- /dev/null +++ b/examples/src/main/resources/people.csv @@ -0,0 +1,3 @@ +name;age;job +Jorge;30;Developer +Bob;32;Developer diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index 86b3dc4a84f58..f9477969a4bb5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -49,6 +49,14 @@ object SQLDataSourceExample { val peopleDF = spark.read.format("json").load("examples/src/main/resources/people.json") peopleDF.select("name", "age").write.format("parquet").save("namesAndAges.parquet") // $example off:manual_load_options$ + // $example on:manual_load_options_csv$ + val peopleDFCsv = spark.read.format("csv") + .option("sep", ";") + .option("inferSchema", "true") + .option("header", "true") + .load("examples/src/main/resources/people.csv") + // $example off:manual_load_options_csv$ + // $example on:direct_sql$ val sqlDF = spark.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") // $example off:direct_sql$ From 274f0efefa0c063649bccddb787e8863910f4366 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 12 Oct 2017 20:20:44 +0800 Subject: [PATCH 698/779] [SPARK-22252][SQL] FileFormatWriter should respect the input query schema ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/18064, we allowed `RunnableCommand` to have children in order to fix some UI issues. Then we made `InsertIntoXXX` commands take the input `query` as a child, when we do the actual writing, we just pass the physical plan to the writer(`FileFormatWriter.write`). However this is problematic. In Spark SQL, optimizer and planner are allowed to change the schema names a little bit. e.g. `ColumnPruning` rule will remove no-op `Project`s, like `Project("A", Scan("a"))`, and thus change the output schema from "" to ``. When it comes to writing, especially for self-description data format like parquet, we may write the wrong schema to the file and cause null values at the read path. Fortunately, in https://github.com/apache/spark/pull/18450 , we decided to allow nested execution and one query can map to multiple executions in the UI. This releases the major restriction in #18604 , and now we don't have to take the input `query` as child of `InsertIntoXXX` commands. So the fix is simple, this PR partially revert #18064 and make `InsertIntoXXX` commands leaf nodes again. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #19474 from cloud-fan/bug. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 2 +- .../sql/catalyst/plans/logical/Command.scala | 3 +- .../spark/sql/execution/QueryExecution.scala | 4 +-- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../execution/columnar/InMemoryRelation.scala | 3 +- .../columnar/InMemoryTableScanExec.scala | 2 +- .../command/DataWritingCommand.scala | 13 ++++++++ .../InsertIntoDataSourceDirCommand.scala | 6 ++-- .../spark/sql/execution/command/cache.scala | 2 +- .../sql/execution/command/commands.scala | 30 ++++++------------- .../command/createDataSourceTables.scala | 2 +- .../spark/sql/execution/command/views.scala | 4 +-- .../execution/datasources/DataSource.scala | 13 +++++++- .../datasources/FileFormatWriter.scala | 14 +++++---- .../InsertIntoDataSourceCommand.scala | 2 +- .../InsertIntoHadoopFsRelationCommand.scala | 9 ++---- .../SaveIntoDataSourceCommand.scala | 2 +- .../execution/streaming/FileStreamSink.scala | 2 +- .../datasources/FileFormatWriterSuite.scala | 16 +++++++++- .../execution/InsertIntoHiveDirCommand.scala | 10 ++----- .../hive/execution/InsertIntoHiveTable.scala | 14 ++------- .../sql/hive/execution/SaveAsHiveFile.scala | 6 ++-- 22 files changed, 85 insertions(+), 76 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 7addbaaa9afa5..c7952e3ff8280 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -178,7 +178,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT }) } - override def innerChildren: Seq[QueryPlan[_]] = subqueries + override protected def innerChildren: Seq[QueryPlan[_]] = subqueries /** * Returns a plan where a best effort attempt has been made to transform `this` in a way diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala index ec5766e1f67f2..38f47081b6f55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.Attribute * commands can be used by parsers to represent DDL operations. Commands, unlike queries, are * eagerly executed. */ -trait Command extends LogicalPlan { +trait Command extends LeafNode { override def output: Seq[Attribute] = Seq.empty - override def children: Seq[LogicalPlan] = Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 4accf54a18232..f404621399cea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -119,7 +119,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { * `SparkSQLDriver` for CLI applications. */ def hiveResultString(): Seq[String] = executedPlan match { - case ExecutedCommandExec(desc: DescribeTableCommand, _) => + case ExecutedCommandExec(desc: DescribeTableCommand) => // If it is a describe command for a Hive table, we want to have the output format // be similar with Hive. desc.run(sparkSession).map { @@ -130,7 +130,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { .mkString("\t") } // SHOW TABLES in Hive only output table names, while ours output database, table name, isTemp. - case command @ ExecutedCommandExec(s: ShowTablesCommand, _) if !s.isExtended => + case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended => command.executeCollect().map(_.getString(1)) case other => val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4cdcc73faacd7..19b858faba6ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -364,7 +364,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // Can we automate these 'pass through' operations? object BasicOperators extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case r: RunnableCommand => ExecutedCommandExec(r, r.children.map(planLater)) :: Nil + case r: RunnableCommand => ExecutedCommandExec(r) :: Nil case MemoryPlan(sink, output) => val encoder = RowEncoder(sink.schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index bc98d8d9d6d61..a1c62a729900e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -62,7 +62,8 @@ case class InMemoryRelation( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator) extends logical.LeafNode with MultiInstanceRelation { - override def innerChildren: Seq[SparkPlan] = Seq(child) + + override protected def innerChildren: Seq[SparkPlan] = Seq(child) override def producedAttributes: AttributeSet = outputSet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index c7ddec55682e1..af3636a5a2ca7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -34,7 +34,7 @@ case class InMemoryTableScanExec( @transient relation: InMemoryRelation) extends LeafExecNode { - override def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index 4e1c5e4846f36..2cf06982e25f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.command import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.SerializableConfiguration @@ -30,6 +31,18 @@ import org.apache.spark.util.SerializableConfiguration */ trait DataWritingCommand extends RunnableCommand { + /** + * The input query plan that produces the data to be written. + */ + def query: LogicalPlan + + // We make the input `query` an inner child instead of a child in order to hide it from the + // optimizer. This is because optimizer may not preserve the output schema names' case, and we + // have to keep the original analyzed plan here so that we can pass the corrected schema to the + // writer. The schema of analyzed plan is what user expects(or specifies), so we should respect + // it when writing. + override protected def innerChildren: Seq[LogicalPlan] = query :: Nil + override lazy val metrics: Map[String, SQLMetric] = { val sparkContext = SparkContext.getActive.get Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala index 633de4c37af94..9e3519073303c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala @@ -21,7 +21,6 @@ import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources._ /** @@ -45,10 +44,9 @@ case class InsertIntoDataSourceDirCommand( query: LogicalPlan, overwrite: Boolean) extends RunnableCommand { - override def children: Seq[LogicalPlan] = Seq(query) + override protected def innerChildren: Seq[LogicalPlan] = query :: Nil - override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - assert(children.length == 1) + override def run(sparkSession: SparkSession): Seq[Row] = { assert(storage.locationUri.nonEmpty, "Directory path is required") assert(provider.nonEmpty, "Data source is required") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 792290bef0163..140f920eaafae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -30,7 +30,7 @@ case class CacheTableCommand( require(plan.isEmpty || tableIdent.database.isEmpty, "Database name is not allowed in CACHE TABLE AS SELECT") - override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq + override protected def innerChildren: Seq[QueryPlan[_]] = plan.toSeq override def run(sparkSession: SparkSession): Seq[Row] = { plan.foreach { logicalPlan => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 7cd4baef89e75..e28b5eb2e2a2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -24,9 +24,9 @@ import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.{logical, QueryPlan} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} +import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.debug._ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata} @@ -37,19 +37,13 @@ import org.apache.spark.sql.types._ * A logical command that is executed for its side-effects. `RunnableCommand`s are * wrapped in `ExecutedCommand` during execution. */ -trait RunnableCommand extends logical.Command { +trait RunnableCommand extends Command { // The map used to record the metrics of running the command. This will be passed to // `ExecutedCommand` during query planning. lazy val metrics: Map[String, SQLMetric] = Map.empty - def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - throw new NotImplementedError - } - - def run(sparkSession: SparkSession): Seq[Row] = { - throw new NotImplementedError - } + def run(sparkSession: SparkSession): Seq[Row] } /** @@ -57,9 +51,8 @@ trait RunnableCommand extends logical.Command { * saves the result to prevent multiple executions. * * @param cmd the `RunnableCommand` this operator will run. - * @param children the children physical plans ran by the `RunnableCommand`. */ -case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) extends SparkPlan { +case class ExecutedCommandExec(cmd: RunnableCommand) extends LeafExecNode { override lazy val metrics: Map[String, SQLMetric] = cmd.metrics @@ -74,19 +67,14 @@ case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) e */ protected[sql] lazy val sideEffectResult: Seq[InternalRow] = { val converter = CatalystTypeConverters.createToCatalystConverter(schema) - val rows = if (children.isEmpty) { - cmd.run(sqlContext.sparkSession) - } else { - cmd.run(sqlContext.sparkSession, children) - } - rows.map(converter(_).asInstanceOf[InternalRow]) + cmd.run(sqlContext.sparkSession).map(converter(_).asInstanceOf[InternalRow]) } - override def innerChildren: Seq[QueryPlan[_]] = cmd.innerChildren + override protected def innerChildren: Seq[QueryPlan[_]] = cmd :: Nil override def output: Seq[Attribute] = cmd.output - override def nodeName: String = cmd.nodeName + override def nodeName: String = "Execute " + cmd.nodeName override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 04b2534ca5eb1..9e3907996995c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -120,7 +120,7 @@ case class CreateDataSourceTableAsSelectCommand( query: LogicalPlan) extends RunnableCommand { - override def innerChildren: Seq[LogicalPlan] = Seq(query) + override protected def innerChildren: Seq[LogicalPlan] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { assert(table.tableType != CatalogTableType.VIEW) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index ffdfd527fa701..5172f32ec7b9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -98,7 +98,7 @@ case class CreateViewCommand( import ViewHelper._ - override def innerChildren: Seq[QueryPlan[_]] = Seq(child) + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child) if (viewType == PersistedView) { require(originalText.isDefined, "'originalText' must be provided to create permanent view") @@ -267,7 +267,7 @@ case class AlterViewAsCommand( import ViewHelper._ - override def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def run(session: SparkSession): Seq[Row] = { // If the plan cannot be analyzed, throw an exception and don't proceed. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b9502a95a7c08..b43d282bd434c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -453,6 +453,17 @@ case class DataSource( val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive) + + // SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does + // not need to have the query as child, to avoid to analyze an optimized query, + // because InsertIntoHadoopFsRelationCommand will be optimized first. + val partitionAttributes = partitionColumns.map { name => + data.output.find(a => equality(a.name, name)).getOrElse { + throw new AnalysisException( + s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]") + } + } + val fileIndex = catalogTable.map(_.identifier).map { tableIdent => sparkSession.table(tableIdent).queryExecution.analyzed.collect { case LogicalRelation(t: HadoopFsRelation, _, _, _) => t.location @@ -465,7 +476,7 @@ case class DataSource( outputPath = outputPath, staticPartitions = Map.empty, ifPartitionNotExists = false, - partitionColumns = partitionColumns.map(UnresolvedAttribute.quoted), + partitionColumns = partitionAttributes, bucketSpec = bucketSpec, fileFormat = format, options = options, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 514969715091a..75b1695fbc275 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution} import org.apache.spark.sql.types.StringType import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -101,7 +101,7 @@ object FileFormatWriter extends Logging { */ def write( sparkSession: SparkSession, - plan: SparkPlan, + queryExecution: QueryExecution, fileFormat: FileFormat, committer: FileCommitProtocol, outputSpec: OutputSpec, @@ -117,7 +117,9 @@ object FileFormatWriter extends Logging { job.setOutputValueClass(classOf[InternalRow]) FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath)) - val allColumns = plan.output + // Pick the attributes from analyzed plan, as optimizer may not preserve the output schema + // names' case. + val allColumns = queryExecution.analyzed.output val partitionSet = AttributeSet(partitionColumns) val dataColumns = allColumns.filterNot(partitionSet.contains) @@ -158,7 +160,7 @@ object FileFormatWriter extends Logging { // We should first sort by partition columns, then bucket id, and finally sorting columns. val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns // the sort order doesn't matter - val actualOrdering = plan.outputOrdering.map(_.child) + val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child) val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { false } else { @@ -176,12 +178,12 @@ object FileFormatWriter extends Logging { try { val rdd = if (orderingMatched) { - plan.execute() + queryExecution.toRdd } else { SortExec( requiredOrdering.map(SortOrder(_, Ascending)), global = false, - child = plan).execute() + child = queryExecution.executedPlan).execute() } val ret = new Array[WriteTaskResult](rdd.partitions.length) sparkSession.sparkContext.runJob( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index 08b2f4f31170f..a813829d50cb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -33,7 +33,7 @@ case class InsertIntoDataSourceCommand( overwrite: Boolean) extends RunnableCommand { - override def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 64e5a57adc37c..675bee85bf61e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogT import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.util.SchemaUtils @@ -57,11 +56,7 @@ case class InsertIntoHadoopFsRelationCommand( extends DataWritingCommand { import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName - override def children: Seq[LogicalPlan] = query :: Nil - - override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - assert(children.length == 1) - + override def run(sparkSession: SparkSession): Seq[Row] = { // Most formats don't do well with duplicate columns, so lets not allow that SchemaUtils.checkSchemaColumnNameDuplication( query.schema, @@ -144,7 +139,7 @@ case class InsertIntoHadoopFsRelationCommand( val updatedPartitionPaths = FileFormatWriter.write( sparkSession = sparkSession, - plan = children.head, + queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, fileFormat = fileFormat, committer = committer, outputSpec = FileFormatWriter.OutputSpec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 5eb6a8471be0d..96c84eab1c894 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -38,7 +38,7 @@ case class SaveIntoDataSourceCommand( options: Map[String, String], mode: SaveMode) extends RunnableCommand { - override def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { dataSource.createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 72e5ac40bbfed..6bd0696622005 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -121,7 +121,7 @@ class FileStreamSink( FileFormatWriter.write( sparkSession = sparkSession, - plan = data.queryExecution.executedPlan, + queryExecution = data.queryExecution, fileFormat = fileFormat, committer = committer, outputSpec = FileFormatWriter.OutputSpec(path, Map.empty), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala index a0c1ea63d3827..6f8767db176aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.test.SharedSQLContext class FileFormatWriterSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("empty file should be skipped while write to file") { withTempPath { path => @@ -30,4 +31,17 @@ class FileFormatWriterSuite extends QueryTest with SharedSQLContext { assert(partFiles.length === 2) } } + + test("FileFormatWriter should respect the input query schema") { + withTable("t1", "t2", "t3", "t4") { + spark.range(1).select('id as 'col1, 'id as 'col2).write.saveAsTable("t1") + spark.sql("select COL1, COL2 from t1").write.saveAsTable("t2") + checkAnswer(spark.table("t2"), Row(0, 0)) + + // Test picking part of the columns when writing. + spark.range(1).select('id, 'id as 'col1, 'id as 'col2).write.saveAsTable("t3") + spark.sql("select COL1, COL2 from t3").write.saveAsTable("t4") + checkAnswer(spark.table("t4"), Row(0, 0)) + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala index 918c8be00d69d..1c6f8dd77fc2c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala @@ -27,11 +27,10 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.mapred._ import org.apache.spark.SparkException -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.hive.client.HiveClientImpl /** @@ -57,10 +56,7 @@ case class InsertIntoHiveDirCommand( query: LogicalPlan, overwrite: Boolean) extends SaveAsHiveFile { - override def children: Seq[LogicalPlan] = query :: Nil - - override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - assert(children.length == 1) + override def run(sparkSession: SparkSession): Seq[Row] = { assert(storage.locationUri.nonEmpty) val hiveTable = HiveClientImpl.toHiveTable(CatalogTable( @@ -102,7 +98,7 @@ case class InsertIntoHiveDirCommand( try { saveAsHiveFile( sparkSession = sparkSession, - plan = children.head, + queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, hadoopConf = hadoopConf, fileSinkConf = fileSinkConf, outputLocation = tmpPath.toString) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index e5b59ed7a1a6b..56e10bc457a00 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -17,20 +17,16 @@ package org.apache.spark.sql.hive.execution -import scala.util.control.NonFatal - import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.ErrorMsg import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.CommandUtils -import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive.client.HiveClientImpl @@ -72,16 +68,12 @@ case class InsertIntoHiveTable( overwrite: Boolean, ifPartitionNotExists: Boolean) extends SaveAsHiveFile { - override def children: Seq[LogicalPlan] = query :: Nil - /** * Inserts all the rows in the table into Hive. Row objects are properly serialized with the * `org.apache.hadoop.hive.serde2.SerDe` and the * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition. */ - override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - assert(children.length == 1) - + override def run(sparkSession: SparkSession): Seq[Row] = { val externalCatalog = sparkSession.sharedState.externalCatalog val hadoopConf = sparkSession.sessionState.newHadoopConf() @@ -170,7 +162,7 @@ case class InsertIntoHiveTable( saveAsHiveFile( sparkSession = sparkSession, - plan = children.head, + queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, hadoopConf = hadoopConf, fileSinkConf = fileSinkConf, outputLocation = tmpLocation.toString, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 2d74ef040ef5a..63657590e5e79 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -33,7 +33,7 @@ import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.DataWritingCommand import org.apache.spark.sql.execution.datasources.FileFormatWriter import org.apache.spark.sql.hive.HiveExternalCatalog @@ -47,7 +47,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { protected def saveAsHiveFile( sparkSession: SparkSession, - plan: SparkPlan, + queryExecution: QueryExecution, hadoopConf: Configuration, fileSinkConf: FileSinkDesc, outputLocation: String, @@ -75,7 +75,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { FileFormatWriter.write( sparkSession = sparkSession, - plan = plan, + queryExecution = queryExecution, fileFormat = new HiveFileFormat(fileSinkConf), committer = committer, outputSpec = FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations), From b5c1ef7a8e4db4067bc361d10d554ee9a538423f Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Thu, 12 Oct 2017 20:26:51 +0800 Subject: [PATCH 699/779] [SPARK-22097][CORE] Request an accurate memory after we unrolled the block ## What changes were proposed in this pull request? We only need request `bbos.size - unrollMemoryUsedByThisBlock` after unrolled the block. ## How was this patch tested? Existing UT. Author: Xianyang Liu Closes #19316 from ConeyLiu/putIteratorAsBytes. --- .../org/apache/spark/storage/memory/MemoryStore.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 651e9c7b2ab61..17f7a69ad6ba1 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -388,7 +388,13 @@ private[spark] class MemoryStore( // perform one final call to attempt to allocate additional memory if necessary. if (keepUnrolling) { serializationStream.close() - reserveAdditionalMemoryIfNecessary() + if (bbos.size > unrollMemoryUsedByThisBlock) { + val amountToRequest = bbos.size - unrollMemoryUsedByThisBlock + keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) + if (keepUnrolling) { + unrollMemoryUsedByThisBlock += amountToRequest + } + } } if (keepUnrolling) { From 73d80ec49713605d6a589e688020f0fc2d6feab2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 12 Oct 2017 20:34:03 +0800 Subject: [PATCH 700/779] [SPARK-22197][SQL] push down operators to data source before planning ## What changes were proposed in this pull request? As we discussed in https://github.com/apache/spark/pull/19136#discussion_r137023744 , we should push down operators to data source before planning, so that data source can report statistics more accurate. This PR also includes some cleanup for the read path. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #19424 from cloud-fan/follow. --- .../spark/sql/sources/v2/ReadSupport.java | 5 +- .../sql/sources/v2/ReadSupportWithSchema.java | 5 +- .../sql/sources/v2/reader/DataReader.java | 4 + .../sources/v2/reader/DataSourceV2Reader.java | 2 +- .../spark/sql/sources/v2/reader/ReadTask.java | 3 +- .../SupportsPushDownCatalystFilters.java | 8 + .../v2/reader/SupportsPushDownFilters.java | 8 + .../apache/spark/sql/DataFrameReader.scala | 5 +- .../spark/sql/execution/SparkOptimizer.scala | 4 +- .../v2/DataSourceReaderHolder.scala | 68 +++++++++ .../datasources/v2/DataSourceV2Relation.scala | 8 +- .../datasources/v2/DataSourceV2ScanExec.scala | 22 +-- .../datasources/v2/DataSourceV2Strategy.scala | 60 +------- .../v2/PushDownOperatorsToDataSource.scala | 140 ++++++++++++++++++ .../sources/v2/JavaAdvancedDataSourceV2.java | 5 + .../sql/sources/v2/DataSourceV2Suite.scala | 2 + 16 files changed, 262 insertions(+), 87 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java index ab5254a688d5a..ee489ad0f608f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -30,9 +30,8 @@ public interface ReadSupport { /** * Creates a {@link DataSourceV2Reader} to scan the data from this data source. * - * @param options the options for this data source reader, which is an immutable case-insensitive - * string-to-string map. - * @return a reader that implements the actual read logic. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. */ DataSourceV2Reader createReader(DataSourceV2Options options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java index c13aeca2ef36f..74e81a2c84d68 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java @@ -39,9 +39,8 @@ public interface ReadSupportWithSchema { * physical schema of the underlying storage of this data source reader, e.g. * CSV files, JSON files, etc, while this reader may not read data with full * schema, as column pruning or other optimizations may happen. - * @param options the options for this data source reader, which is an immutable case-insensitive - * string-to-string map. - * @return a reader that implements the actual read logic. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. */ DataSourceV2Reader createReader(StructType schema, DataSourceV2Options options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java index cfafc1a576793..95e091569b614 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java @@ -24,6 +24,10 @@ /** * A data reader returned by {@link ReadTask#createReader()} and is responsible for outputting data * for a RDD partition. + * + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data + * source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source + * readers that mix in {@link SupportsScanUnsafeRow}. */ @InterfaceStability.Evolving public interface DataReader extends Closeable { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java index fb4d5c0d7ae41..5989a4ac8440b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java @@ -30,7 +30,7 @@ * {@link org.apache.spark.sql.sources.v2.ReadSupportWithSchema#createReader( * StructType, org.apache.spark.sql.sources.v2.DataSourceV2Options)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan - * logic should be delegated to {@link ReadTask}s that are returned by {@link #createReadTasks()}. + * logic is delegated to {@link ReadTask}s that are returned by {@link #createReadTasks()}. * * There are mainly 3 kinds of query optimizations: * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java index 7885bfcdd49e4..01362df0978cb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java @@ -27,7 +27,8 @@ * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. * * Note that, the read task will be serialized and sent to executors, then the data reader will be - * created on executors and do the actual reading. + * created on executors and do the actual reading. So {@link ReadTask} must be serializable and + * {@link DataReader} doesn't need to be. */ @InterfaceStability.Evolving public interface ReadTask extends Serializable { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index 19d706238ec8e..d6091774d75aa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -40,4 +40,12 @@ public interface SupportsPushDownCatalystFilters { * Pushes down filters, and returns unsupported filters. */ Expression[] pushCatalystFilters(Expression[] filters); + + /** + * Returns the catalyst filters that are pushed in {@link #pushCatalystFilters(Expression[])}. + * It's possible that there is no filters in the query and + * {@link #pushCatalystFilters(Expression[])} is never called, empty array should be returned for + * this case. + */ + Expression[] pushedCatalystFilters(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index d4b509e7080f2..d6f297c013375 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources.v2.reader; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.sources.Filter; /** @@ -35,4 +36,11 @@ public interface SupportsPushDownFilters { * Pushes down filters, and returns unsupported filters. */ Filter[] pushFilters(Filter[] filters); + + /** + * Returns the filters that are pushed in {@link #pushFilters(Filter[])}. + * It's possible that there is no filters in the query and {@link #pushFilters(Filter[])} + * is never called, empty array should be returned for this case. + */ + Filter[] pushedFilters(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 78b668c04fd5c..17966eecfc051 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -184,7 +184,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val dataSource = cls.newInstance() val options = new DataSourceV2Options(extraOptions.asJava) val reader = (cls.newInstance(), userSpecifiedSchema) match { @@ -194,8 +193,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { case (ds: ReadSupport, None) => ds.createReader(options) - case (_: ReadSupportWithSchema, None) => - throw new AnalysisException(s"A schema needs to be specified when using $dataSource.") + case (ds: ReadSupportWithSchema, None) => + throw new AnalysisException(s"A schema needs to be specified when using $ds.") case (ds: ReadSupport, Some(schema)) => val reader = ds.createReader(options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 00ff4c8ac310b..1c8e4050978dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions +import org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate class SparkOptimizer( @@ -31,7 +32,8 @@ class SparkOptimizer( override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ - Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ + Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ + Batch("Push down operators to data source scan", Once, PushDownOperatorsToDataSource)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala new file mode 100644 index 0000000000000..6093df26630cd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.Objects + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.sources.v2.reader._ + +/** + * A base class for data source reader holder with customized equals/hashCode methods. + */ +trait DataSourceReaderHolder { + + /** + * The full output of the data source reader, without column pruning. + */ + def fullOutput: Seq[AttributeReference] + + /** + * The held data source reader. + */ + def reader: DataSourceV2Reader + + /** + * The metadata of this data source reader that can be used for equality test. + */ + private def metadata: Seq[Any] = { + val filters: Any = reader match { + case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet + case s: SupportsPushDownFilters => s.pushedFilters().toSet + case _ => Nil + } + Seq(fullOutput, reader.getClass, reader.readSchema(), filters) + } + + def canEqual(other: Any): Boolean + + override def equals(other: Any): Boolean = other match { + case other: DataSourceReaderHolder => + canEqual(other) && metadata.length == other.metadata.length && + metadata.zip(other.metadata).forall { case (l, r) => l == r } + case _ => false + } + + override def hashCode(): Int = { + metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) + } + + lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name => + fullOutput.find(_.name == name).get + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 3c9b598fd07c9..7eb99a645001a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} -import org.apache.spark.sql.sources.v2.reader.{DataSourceV2Reader, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( - output: Seq[AttributeReference], - reader: DataSourceV2Reader) extends LeafNode { + fullOutput: Seq[AttributeReference], + reader: DataSourceV2Reader) extends LeafNode with DataSourceReaderHolder { + + override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 7999c0ceb5749..addc12a3f0901 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -29,20 +29,14 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.types.StructType +/** + * Physical plan node for scanning data from a data source. + */ case class DataSourceV2ScanExec( - fullOutput: Array[AttributeReference], - @transient reader: DataSourceV2Reader, - // TODO: these 3 parameters are only used to determine the equality of the scan node, however, - // the reader also have this information, and ideally we can just rely on the equality of the - // reader. The only concern is, the reader implementation is outside of Spark and we have no - // control. - readSchema: StructType, - @transient filters: ExpressionSet, - hashPartitionKeys: Seq[String]) extends LeafExecNode { - - def output: Seq[Attribute] = readSchema.map(_.name).map { name => - fullOutput.find(_.name == name).get - } + fullOutput: Seq[AttributeReference], + @transient reader: DataSourceV2Reader) extends LeafExecNode with DataSourceReaderHolder { + + override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] override def references: AttributeSet = AttributeSet.empty @@ -74,7 +68,7 @@ class RowToUnsafeRowReadTask(rowReadTask: ReadTask[Row], schema: StructType) override def preferredLocations: Array[String] = rowReadTask.preferredLocations override def createReader: DataReader[UnsafeRow] = { - new RowToUnsafeDataReader(rowReadTask.createReader, RowEncoder.apply(schema)) + new RowToUnsafeDataReader(rowReadTask.createReader, RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index b80f695b2a87f..f2cda002245e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -29,64 +29,8 @@ import org.apache.spark.sql.sources.v2.reader._ object DataSourceV2Strategy extends Strategy { // TODO: write path override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(projects, filters, DataSourceV2Relation(output, reader)) => - val stayUpFilters: Seq[Expression] = reader match { - case r: SupportsPushDownCatalystFilters => - r.pushCatalystFilters(filters.toArray) - - case r: SupportsPushDownFilters => - // A map from original Catalyst expressions to corresponding translated data source - // filters. If a predicate is not in this map, it means it cannot be pushed down. - val translatedMap: Map[Expression, Filter] = filters.flatMap { p => - DataSourceStrategy.translateFilter(p).map(f => p -> f) - }.toMap - - // Catalyst predicate expressions that cannot be converted to data source filters. - val nonConvertiblePredicates = filters.filterNot(translatedMap.contains) - - // Data source filters that cannot be pushed down. An unhandled filter means - // the data source cannot guarantee the rows returned can pass the filter. - // As a result we must return it so Spark can plan an extra filter operator. - val unhandledFilters = r.pushFilters(translatedMap.values.toArray).toSet - val unhandledPredicates = translatedMap.filter { case (_, f) => - unhandledFilters.contains(f) - }.keys - - nonConvertiblePredicates ++ unhandledPredicates - - case _ => filters - } - - val attrMap = AttributeMap(output.zip(output)) - val projectSet = AttributeSet(projects.flatMap(_.references)) - val filterSet = AttributeSet(stayUpFilters.flatMap(_.references)) - - // Match original case of attributes. - // TODO: nested fields pruning - val requiredColumns = (projectSet ++ filterSet).toSeq.map(attrMap) - reader match { - case r: SupportsPushDownRequiredColumns => - r.pruneColumns(requiredColumns.toStructType) - case _ => - } - - val scan = DataSourceV2ScanExec( - output.toArray, - reader, - reader.readSchema(), - ExpressionSet(filters), - Nil) - - val filterCondition = stayUpFilters.reduceLeftOption(And) - val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) - - val withProject = if (projects == withFilter.output) { - withFilter - } else { - ProjectExec(projects, withFilter) - } - - withProject :: Nil + case DataSourceV2Relation(output, reader) => + DataSourceV2ScanExec(output, reader) :: Nil case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala new file mode 100644 index 0000000000000..0c1708131ae46 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, Expression, NamedExpression, PredicateHelper} +import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources +import org.apache.spark.sql.sources.v2.reader._ + +/** + * Pushes down various operators to the underlying data source for better performance. Operators are + * being pushed down with a specific order. As an example, given a LIMIT has a FILTER child, you + * can't push down LIMIT if FILTER is not completely pushed down. When both are pushed down, the + * data source should execute FILTER before LIMIT. And required columns are calculated at the end, + * because when more operators are pushed down, we may need less columns at Spark side. + */ +object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHelper { + override def apply(plan: LogicalPlan): LogicalPlan = { + // Note that, we need to collect the target operator along with PROJECT node, as PROJECT may + // appear in many places for column pruning. + // TODO: Ideally column pruning should be implemented via a plan property that is propagated + // top-down, then we can simplify the logic here and only collect target operators. + val filterPushed = plan transformUp { + case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => + // Non-deterministic expressions are stateful and we must keep the input sequence unchanged + // to avoid changing the result. This means, we can't evaluate the filter conditions that + // are after the first non-deterministic condition ahead. Here we only try to push down + // deterministic conditions that are before the first non-deterministic condition. + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(condition).span(_.deterministic) + + val stayUpFilters: Seq[Expression] = reader match { + case r: SupportsPushDownCatalystFilters => + r.pushCatalystFilters(candidates.toArray) + + case r: SupportsPushDownFilters => + // A map from original Catalyst expressions to corresponding translated data source + // filters. If a predicate is not in this map, it means it cannot be pushed down. + val translatedMap: Map[Expression, sources.Filter] = candidates.flatMap { p => + DataSourceStrategy.translateFilter(p).map(f => p -> f) + }.toMap + + // Catalyst predicate expressions that cannot be converted to data source filters. + val nonConvertiblePredicates = candidates.filterNot(translatedMap.contains) + + // Data source filters that cannot be pushed down. An unhandled filter means + // the data source cannot guarantee the rows returned can pass the filter. + // As a result we must return it so Spark can plan an extra filter operator. + val unhandledFilters = r.pushFilters(translatedMap.values.toArray).toSet + val unhandledPredicates = translatedMap.filter { case (_, f) => + unhandledFilters.contains(f) + }.keys + + nonConvertiblePredicates ++ unhandledPredicates + + case _ => candidates + } + + val filterCondition = (stayUpFilters ++ containingNonDeterministic).reduceLeftOption(And) + val withFilter = filterCondition.map(Filter(_, r)).getOrElse(r) + if (withFilter.output == fields) { + withFilter + } else { + Project(fields, withFilter) + } + } + + // TODO: add more push down rules. + + // TODO: nested fields pruning + def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: Seq[Attribute]): Unit = { + plan match { + case Project(projectList, child) => + val required = projectList.filter(requiredByParent.contains).flatMap(_.references) + pushDownRequiredColumns(child, required) + + case Filter(condition, child) => + val required = requiredByParent ++ condition.references + pushDownRequiredColumns(child, required) + + case DataSourceV2Relation(fullOutput, reader) => reader match { + case r: SupportsPushDownRequiredColumns => + // Match original case of attributes. + val attrMap = AttributeMap(fullOutput.zip(fullOutput)) + val requiredColumns = requiredByParent.map(attrMap) + r.pruneColumns(requiredColumns.toStructType) + case _ => + } + + // TODO: there may be more operators can be used to calculate required columns, we can add + // more and more in the future. + case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.output)) + } + } + + pushDownRequiredColumns(filterPushed, filterPushed.output) + // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. + RemoveRedundantProject(filterPushed) + } + + /** + * Finds a Filter node(with an optional Project child) above data source relation. + */ + object FilterAndProject { + // returns the project list, the filter condition and the data source relation. + def unapply(plan: LogicalPlan) + : Option[(Seq[NamedExpression], Expression, DataSourceV2Relation)] = plan match { + + case Filter(condition, r: DataSourceV2Relation) => Some((r.output, condition, r)) + + case Filter(condition, Project(fields, r: DataSourceV2Relation)) + if fields.forall(_.deterministic) => + val attributeMap = AttributeMap(fields.map(e => e.toAttribute -> e)) + val substituted = condition.transform { + case a: Attribute => attributeMap.getOrElse(a, a) + } + Some((fields, substituted, r)) + + case _ => None + } + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 7aacf0346d2fb..da2c13f70c52a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -54,6 +54,11 @@ public Filter[] pushFilters(Filter[] filters) { return new Filter[0]; } + @Override + public Filter[] pushedFilters() { + return filters; + } + @Override public List> createReadTasks() { List> res = new ArrayList<>(); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 9ce93d7ae926c..f238e565dc2fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -129,6 +129,8 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { Array.empty } + override def pushedFilters(): Array[Filter] = filters + override def readSchema(): StructType = { requiredSchema } From 02218c4c73c32741390d9906b6190ef2124ce518 Mon Sep 17 00:00:00 2001 From: Ala Luszczak Date: Thu, 12 Oct 2017 17:00:22 +0200 Subject: [PATCH 701/779] [SPARK-22251][SQL] Metric 'aggregate time' is incorrect when codegen is off ## What changes were proposed in this pull request? Adding the code for setting 'aggregate time' metric to non-codegen path in HashAggregateExec and to ObjectHashAggregateExces. ## How was this patch tested? Tested manually. Author: Ala Luszczak Closes #19473 from ala/fix-agg-time. --- .../sql/execution/aggregate/HashAggregateExec.scala | 6 +++++- .../execution/aggregate/ObjectHashAggregateExec.scala | 9 +++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8b573fdcf25e1..43e5ff89afee6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -95,11 +95,13 @@ case class HashAggregateExec( val peakMemory = longMetric("peakMemory") val spillSize = longMetric("spillSize") val avgHashProbe = longMetric("avgHashProbe") + val aggTime = longMetric("aggTime") child.execute().mapPartitionsWithIndex { (partIndex, iter) => + val beforeAgg = System.nanoTime() val hasInput = iter.hasNext - if (!hasInput && groupingExpressions.nonEmpty) { + val res = if (!hasInput && groupingExpressions.nonEmpty) { // This is a grouped aggregate and the input iterator is empty, // so return an empty iterator. Iterator.empty @@ -128,6 +130,8 @@ case class HashAggregateExec( aggregationIterator } } + aggTime += (System.nanoTime() - beforeAgg) / 1000000 + res } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 6316e06a8f34e..ec3f9a05b5ccc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -76,7 +76,8 @@ case class ObjectHashAggregateExec( aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows") + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time") ) override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -96,11 +97,13 @@ case class ObjectHashAggregateExec( protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { val numOutputRows = longMetric("numOutputRows") + val aggTime = longMetric("aggTime") val fallbackCountThreshold = sqlContext.conf.objectAggSortBasedFallbackThreshold child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) => + val beforeAgg = System.nanoTime() val hasInput = iter.hasNext - if (!hasInput && groupingExpressions.nonEmpty) { + val res = if (!hasInput && groupingExpressions.nonEmpty) { // This is a grouped aggregate and the input kvIterator is empty, // so return an empty kvIterator. Iterator.empty @@ -127,6 +130,8 @@ case class ObjectHashAggregateExec( aggregationIterator } } + aggTime += (System.nanoTime() - beforeAgg) / 1000000 + res } } From 9104add4c7c6b578df15b64a8533a1266f90734e Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Fri, 13 Oct 2017 08:40:26 +0900 Subject: [PATCH 702/779] [SPARK-22217][SQL] ParquetFileFormat to support arbitrary OutputCommitters ## What changes were proposed in this pull request? `ParquetFileFormat` to relax its requirement of output committer class from `org.apache.parquet.hadoop.ParquetOutputCommitter` or subclass thereof (and so implicitly Hadoop `FileOutputCommitter`) to any committer implementing `org.apache.hadoop.mapreduce.OutputCommitter` This enables output committers which don't write to the filesystem the way `FileOutputCommitter` does to save parquet data from a dataframe: at present you cannot do this. Before a committer which isn't a subclass of `ParquetOutputCommitter`, it checks to see if the context has requested summary metadata by setting `parquet.enable.summary-metadata`. If true, and the committer class isn't a parquet committer, it raises a RuntimeException with an error message. (It could downgrade, of course, but raising an exception makes it clear there won't be an summary. It also makes the behaviour testable.) Note that `SQLConf` already states that any `OutputCommitter` can be used, but that typically it's a subclass of ParquetOutputCommitter. That's not currently true. This patch will make the code consistent with the docs, adding tests to verify, ## How was this patch tested? The patch includes a test suite, `ParquetCommitterSuite`, with a new committer, `MarkingFileOutputCommitter` which extends `FileOutputCommitter` and writes a marker file in the destination directory. The presence of the marker file can be used to verify the new committer was used. The tests then try the combinations of Parquet committer summary/no-summary and marking committer summary/no-summary. | committer | summary | outcome | |-----------|---------|---------| | parquet | true | success | | parquet | false | success | | marking | false | success with marker | | marking | true | exception | All tests are happy. Author: Steve Loughran Closes #19448 from steveloughran/cloud/SPARK-22217-committer. --- .../apache/spark/sql/internal/SQLConf.scala | 5 +- .../parquet/ParquetFileFormat.scala | 12 +- .../parquet/ParquetCommitterSuite.scala | 152 ++++++++++++++++++ 3 files changed, 165 insertions(+), 4 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 58323740b80cc..618d4a0d6148a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -306,8 +306,9 @@ object SQLConf { val PARQUET_OUTPUT_COMMITTER_CLASS = buildConf("spark.sql.parquet.output.committer.class") .doc("The output committer class used by Parquet. The specified class needs to be a " + - "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + - "of org.apache.parquet.hadoop.ParquetOutputCommitter.") + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + + "of org.apache.parquet.hadoop.ParquetOutputCommitter. If it is not, then metadata summaries" + + "will never be created, irrespective of the value of parquet.enable.summary-metadata") .internal() .stringConf .createWithDefault("org.apache.parquet.hadoop.ParquetOutputCommitter") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index e1e740500205a..c1535babbae1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -86,7 +86,7 @@ class ParquetFileFormat conf.getClass( SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, classOf[ParquetOutputCommitter], - classOf[ParquetOutputCommitter]) + classOf[OutputCommitter]) if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) { logInfo("Using default output committer for Parquet: " + @@ -98,7 +98,7 @@ class ParquetFileFormat conf.setClass( SQLConf.OUTPUT_COMMITTER_CLASS.key, committerClass, - classOf[ParquetOutputCommitter]) + classOf[OutputCommitter]) // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why @@ -138,6 +138,14 @@ class ParquetFileFormat conf.setBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) } + if (conf.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) + && !classOf[ParquetOutputCommitter].isAssignableFrom(committerClass)) { + // output summary is requested, but the class is not a Parquet Committer + logWarning(s"Committer $committerClass is not a ParquetOutputCommitter and cannot" + + s" create job summaries. " + + s"Set Parquet option ${ParquetOutputFormat.ENABLE_JOB_SUMMARY} to false.") + } + new OutputWriterFactory { // This OutputWriterFactory instance is deserialized when writing Parquet files on the // executor side without constructing or deserializing ParquetFileFormat. Therefore, we hold diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala new file mode 100644 index 0000000000000..caa4f6d70c6a9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.FileNotFoundException + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} + +import org.apache.spark.{LocalSparkContext, SparkFunSuite} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils + +/** + * Test logic related to choice of output committers. + */ +class ParquetCommitterSuite extends SparkFunSuite with SQLTestUtils + with LocalSparkContext { + + private val PARQUET_COMMITTER = classOf[ParquetOutputCommitter].getCanonicalName + + protected var spark: SparkSession = _ + + /** + * Create a new [[SparkSession]] running in local-cluster mode with unsafe and codegen enabled. + */ + override def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession.builder() + .master("local-cluster[2,1,1024]") + .appName("testing") + .getOrCreate() + } + + override def afterAll(): Unit = { + try { + if (spark != null) { + spark.stop() + spark = null + } + } finally { + super.afterAll() + } + } + + test("alternative output committer, merge schema") { + writeDataFrame(MarkingFileOutput.COMMITTER, summary = true, check = true) + } + + test("alternative output committer, no merge schema") { + writeDataFrame(MarkingFileOutput.COMMITTER, summary = false, check = true) + } + + test("Parquet output committer, merge schema") { + writeDataFrame(PARQUET_COMMITTER, summary = true, check = false) + } + + test("Parquet output committer, no merge schema") { + writeDataFrame(PARQUET_COMMITTER, summary = false, check = false) + } + + /** + * Write a trivial dataframe as Parquet, using the given committer + * and job summary option. + * @param committer committer to use + * @param summary create a job summary + * @param check look for a marker file + * @return if a marker file was sought, it's file status. + */ + private def writeDataFrame( + committer: String, + summary: Boolean, + check: Boolean): Option[FileStatus] = { + var result: Option[FileStatus] = None + withSQLConf( + SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key -> committer, + ParquetOutputFormat.ENABLE_JOB_SUMMARY -> summary.toString) { + withTempPath { dest => + val df = spark.createDataFrame(Seq((1, "4"), (2, "2"))) + val destPath = new Path(dest.toURI) + df.write.format("parquet").save(destPath.toString) + if (check) { + result = Some(MarkingFileOutput.checkMarker( + destPath, + spark.sparkContext.hadoopConfiguration)) + } + } + } + result + } +} + +/** + * A file output committer which explicitly touches a file "marker"; this + * is how tests can verify that this committer was used. + * @param outputPath output path + * @param context task context + */ +private class MarkingFileOutputCommitter( + outputPath: Path, + context: TaskAttemptContext) extends FileOutputCommitter(outputPath, context) { + + override def commitJob(context: JobContext): Unit = { + super.commitJob(context) + MarkingFileOutput.touch(outputPath, context.getConfiguration) + } +} + +private object MarkingFileOutput { + + val COMMITTER = classOf[MarkingFileOutputCommitter].getCanonicalName + + /** + * Touch the marker. + * @param outputPath destination directory + * @param conf configuration to create the FS with + */ + def touch(outputPath: Path, conf: Configuration): Unit = { + outputPath.getFileSystem(conf).create(new Path(outputPath, "marker")).close() + } + + /** + * Get the file status of the marker + * + * @param outputPath destination directory + * @param conf configuration to create the FS with + * @return the status of the marker + * @throws FileNotFoundException if the marker is absent + */ + def checkMarker(outputPath: Path, conf: Configuration): FileStatus = { + outputPath.getFileSystem(conf).getFileStatus(new Path(outputPath, "marker")) + } +} From 3ff766f61afbd09dcc7a73eae02e68a39114ce3f Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 12 Oct 2017 18:47:16 -0700 Subject: [PATCH 703/779] [SPARK-22263][SQL] Refactor deterministic as lazy value ## What changes were proposed in this pull request? The method `deterministic` is frequently called in optimizer. Refactor `deterministic` as lazy value, in order to avoid redundant computations. ## How was this patch tested? Simple benchmark test over TPC-DS queries, run time from query string to optimized plan(continuous 20 runs, and get the average of last 5 results): Before changes: 12601 ms After changes: 11993ms This is 4.8% performance improvement. Also run test with Unit test. Author: Wang Gengliang Closes #19478 from gengliangwang/deterministicAsLazyVal. --- .../sql/catalyst/expressions/CallMethodViaReflection.scala | 2 +- .../apache/spark/sql/catalyst/expressions/Expression.scala | 4 ++-- .../org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala | 2 +- .../spark/sql/catalyst/expressions/aggregate/First.scala | 2 +- .../spark/sql/catalyst/expressions/aggregate/Last.scala | 2 +- .../spark/sql/catalyst/expressions/aggregate/collect.scala | 2 +- .../org/apache/spark/sql/catalyst/expressions/misc.scala | 2 +- .../sql/execution/aggregate/TypedAggregateExpression.scala | 4 ++-- .../scala/org/apache/spark/sql/execution/aggregate/udaf.scala | 2 +- .../org/apache/spark/sql/TypedImperativeAggregateSuite.scala | 2 +- .../src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala | 4 ++-- 11 files changed, 14 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index cd97304302e48..65bb9a8c642b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -76,7 +76,7 @@ case class CallMethodViaReflection(children: Seq[Expression]) } } - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false override def nullable: Boolean = true override val dataType: DataType = StringType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c058425b4bc36..0e75ac88dc2b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -79,7 +79,7 @@ abstract class Expression extends TreeNode[Expression] { * An example would be `SparkPartitionID` that relies on the partition id returned by TaskContext. * By default leaf expressions are deterministic as Nil.forall(_.deterministic) returns true. */ - def deterministic: Boolean = children.forall(_.deterministic) + lazy val deterministic: Boolean = children.forall(_.deterministic) def nullable: Boolean @@ -265,7 +265,7 @@ trait NonSQLExpression extends Expression { * An expression that is nondeterministic. */ trait Nondeterministic extends Expression { - final override def deterministic: Boolean = false + final override lazy val deterministic: Boolean = false final override def foldable: Boolean = false @transient diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 527f1670c25e1..179853032035e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -49,7 +49,7 @@ case class ScalaUDF( udfDeterministic: Boolean = true) extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression { - override def deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) + override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) override def toString: String = s"${udfName.map(name => s"UDF:$name").getOrElse("UDF")}(${children.mkString(", ")})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index bfc58c22886cc..4e671e1f3e6eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -44,7 +44,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) override def nullable: Boolean = true // First is not a deterministic function. - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false // Return data type. override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 96a6ec08a160a..0ccabb9d98914 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -44,7 +44,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) override def nullable: Boolean = true // Last is not a deterministic function. - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false // Return data type. override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 405c2065680f5..be972f006352e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -44,7 +44,7 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper // Both `CollectList` and `CollectSet` are non-deterministic since their results depend on the // actual order of input rows. - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false override def update(buffer: T, input: InternalRow): T = { val value = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index ef293ff3f18ea..b86e271fe2958 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -119,7 +119,7 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { // scalastyle:on line.size.limit case class Uuid() extends LeafExpression { - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false override def nullable: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 717758fdf716f..aab8cc50b9526 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -127,7 +127,7 @@ case class SimpleTypedAggregateExpression( nullable: Boolean) extends DeclarativeAggregate with TypedAggregateExpression with NonSQLExpression { - override def deterministic: Boolean = true + override lazy val deterministic: Boolean = true override def children: Seq[Expression] = inputDeserializer.toSeq :+ bufferDeserializer @@ -221,7 +221,7 @@ case class ComplexTypedAggregateExpression( inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[Any] with TypedAggregateExpression with NonSQLExpression { - override def deterministic: Boolean = true + override lazy val deterministic: Boolean = true override def children: Seq[Expression] = inputDeserializer.toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index fec1add18cbf2..72aa4adff4e64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -340,7 +340,7 @@ case class ScalaUDAF( override def dataType: DataType = udaf.dataType - override def deterministic: Boolean = udaf.deterministic + override lazy val deterministic: Boolean = udaf.deterministic override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index b76f168220d84..c5fb17345222a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -268,7 +268,7 @@ object TypedImperativeAggregateSuite { } } - override def deterministic: Boolean = true + override lazy val deterministic: Boolean = true override def children: Seq[Expression] = Seq(child) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index e9bdcf00b9346..68af99ea272a8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -48,7 +48,7 @@ private[hive] case class HiveSimpleUDF( with Logging with UserDefinedExpression { - override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) + override lazy val deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) override def nullable: Boolean = true @@ -131,7 +131,7 @@ private[hive] case class HiveGenericUDF( override def nullable: Boolean = true - override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) + override lazy val deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) override def foldable: Boolean = isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] From ec122209fb35a65637df42eded64b0203e105aae Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 13 Oct 2017 13:09:35 +0800 Subject: [PATCH 704/779] [SPARK-21165][SQL] FileFormatWriter should handle mismatched attribute ids between logical and physical plan ## What changes were proposed in this pull request? Due to optimizer removing some unnecessary aliases, the logical and physical plan may have different output attribute ids. FileFormatWriter should handle this when creating the physical sort node. ## How was this patch tested? new regression test. Author: Wenchen Fan Closes #19483 from cloud-fan/bug2. --- .../datasources/FileFormatWriter.scala | 7 +++++- .../datasources/FileFormatWriterSuite.scala | 2 +- .../apache/spark/sql/hive/InsertSuite.scala | 22 +++++++++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 75b1695fbc275..1fac01a2c26c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -180,8 +180,13 @@ object FileFormatWriter extends Logging { val rdd = if (orderingMatched) { queryExecution.toRdd } else { + // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and + // the physical plan may have different attribute ids due to optimizer removing some + // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. + val orderingExpr = requiredOrdering + .map(SortOrder(_, Ascending)).map(BindReferences.bindReference(_, allColumns)) SortExec( - requiredOrdering.map(SortOrder(_, Ascending)), + orderingExpr, global = false, child = queryExecution.executedPlan).execute() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala index 6f8767db176aa..13f0e0bca86c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala @@ -32,7 +32,7 @@ class FileFormatWriterSuite extends QueryTest with SharedSQLContext { } } - test("FileFormatWriter should respect the input query schema") { + test("SPARK-22252: FileFormatWriter should respect the input query schema") { withTable("t1", "t2", "t3", "t4") { spark.range(1).select('id as 'col1, 'id as 'col2).write.saveAsTable("t1") spark.sql("select COL1, COL2 from t1").write.saveAsTable("t2") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala index aa5cae33f5cd9..ab91727049ff5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala @@ -728,4 +728,26 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter assert(e.contains("mismatched input 'ROW'")) } } + + test("SPARK-21165: FileFormatWriter should only rely on attributes from analyzed plan") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + withTable("tab1", "tab2") { + Seq(("a", "b", 3)).toDF("word", "first", "length").write.saveAsTable("tab1") + + spark.sql( + """ + |CREATE TABLE tab2 (word string, length int) + |PARTITIONED BY (first string) + """.stripMargin) + + spark.sql( + """ + |INSERT INTO TABLE tab2 PARTITION(first) + |SELECT word, length, cast(first as string) as first FROM tab1 + """.stripMargin) + + checkAnswer(spark.table("tab2"), Row("a", 3, "b")) + } + } + } } From 2f00a71a876321af02865d7cd53ada167e1ce2e3 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 12 Oct 2017 22:45:19 -0700 Subject: [PATCH 705/779] [SPARK-22257][SQL] Reserve all non-deterministic expressions in ExpressionSet ## What changes were proposed in this pull request? For non-deterministic expressions, they should be considered as not contained in the [[ExpressionSet]]. This is consistent with how we define `semanticEquals` between two expressions. Otherwise, combining expressions will remove non-deterministic expressions which should be reserved. E.g. Combine filters of ```scala testRelation.where(Rand(0) > 0.1).where(Rand(0) > 0.1) ``` should result in ```scala testRelation.where(Rand(0) > 0.1 && Rand(0) > 0.1) ``` ## How was this patch tested? Unit test Author: Wang Gengliang Closes #19475 from gengliangwang/non-deterministic-expressionSet. --- .../catalyst/expressions/ExpressionSet.scala | 23 ++++++--- .../expressions/ExpressionSetSuite.scala | 51 +++++++++++++++---- .../optimizer/FilterPushdownSuite.scala | 15 ++++++ 3 files changed, 72 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala index 305ac90e245b8..7e8e7b8cd5f18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala @@ -30,8 +30,9 @@ object ExpressionSet { } /** - * A [[Set]] where membership is determined based on a canonical representation of an [[Expression]] - * (i.e. one that attempts to ignore cosmetic differences). See [[Canonicalize]] for more details. + * A [[Set]] where membership is determined based on determinacy and a canonical representation of + * an [[Expression]] (i.e. one that attempts to ignore cosmetic differences). + * See [[Canonicalize]] for more details. * * Internally this set uses the canonical representation, but keeps also track of the original * expressions to ease debugging. Since different expressions can share the same canonical @@ -46,6 +47,10 @@ object ExpressionSet { * set.contains(1 + a) => true * set.contains(a + 2) => false * }}} + * + * For non-deterministic expressions, they are always considered as not contained in the [[Set]]. + * On adding a non-deterministic expression, simply append it to the original expressions. + * This is consistent with how we define `semanticEquals` between two expressions. */ class ExpressionSet protected( protected val baseSet: mutable.Set[Expression] = new mutable.HashSet, @@ -53,7 +58,9 @@ class ExpressionSet protected( extends Set[Expression] { protected def add(e: Expression): Unit = { - if (!baseSet.contains(e.canonicalized)) { + if (!e.deterministic) { + originals += e + } else if (!baseSet.contains(e.canonicalized) ) { baseSet.add(e.canonicalized) originals += e } @@ -74,9 +81,13 @@ class ExpressionSet protected( } override def -(elem: Expression): ExpressionSet = { - val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized) - val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized) - new ExpressionSet(newBaseSet, newOriginals) + if (elem.deterministic) { + val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized) + val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized) + new ExpressionSet(newBaseSet, newOriginals) + } else { + new ExpressionSet(baseSet.clone(), originals.clone()) + } } override def iterator: Iterator[Expression] = originals.iterator diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala index a1000a0e80799..12eddf557109f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala @@ -175,20 +175,14 @@ class ExpressionSetSuite extends SparkFunSuite { aUpper > bUpper || aUpper <= Rand(1L) || aUpper <= 10, aUpper <= Rand(1L) || aUpper <= 10 || aUpper > bUpper) - // Partial reorder case: we don't reorder non-deterministic expressions, - // but we can reorder sub-expressions in deterministic AND/OR expressions. - // There are two predicates: - // (aUpper > bUpper || bUpper > 100) => we can reorder sub-expressions in it. - // (aUpper === Rand(1L)) - setTest(1, + // Keep all the non-deterministic expressions even they are semantically equal. + setTest(2, Rand(1L), Rand(1L)) + + setTest(2, (aUpper > bUpper || bUpper > 100) && aUpper === Rand(1L), (bUpper > 100 || aUpper > bUpper) && aUpper === Rand(1L)) - // There are three predicates: - // (Rand(1L) > aUpper) - // (aUpper <= Rand(1L) && aUpper > bUpper) - // (aUpper > 10 && bUpper > 10) => we can reorder sub-expressions in it. - setTest(1, + setTest(2, Rand(1L) > aUpper || (aUpper <= Rand(1L) && aUpper > bUpper) || (aUpper > 10 && bUpper > 10), Rand(1L) > aUpper || (aUpper <= Rand(1L) && aUpper > bUpper) || (bUpper > 10 && aUpper > 10)) @@ -219,4 +213,39 @@ class ExpressionSetSuite extends SparkFunSuite { assert((initialSet ++ setToAddWithSameExpression).size == 2) assert((initialSet ++ setToAddWithOutSameExpression).size == 3) } + + test("add single element to set with non-deterministic expressions") { + val initialSet = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + + assert((initialSet + (aUpper + 1)).size == 2) + assert((initialSet + Rand(0)).size == 3) + assert((initialSet + (aUpper + 2)).size == 3) + } + + test("remove single element to set with non-deterministic expressions") { + val initialSet = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + + assert((initialSet - (aUpper + 1)).size == 1) + assert((initialSet - Rand(0)).size == 2) + assert((initialSet - (aUpper + 2)).size == 2) + } + + test("add multiple elements to set with non-deterministic expressions") { + val initialSet = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + val setToAddWithSameDeterministicExpression = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + val setToAddWithOutSameExpression = ExpressionSet(aUpper + 3 :: aUpper + 4 :: Nil) + + assert((initialSet ++ setToAddWithSameDeterministicExpression).size == 3) + assert((initialSet ++ setToAddWithOutSameExpression).size == 4) + } + + test("remove multiple elements to set with non-deterministic expressions") { + val initialSet = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + val setToRemoveWithSameDeterministicExpression = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + val setToRemoveWithOutSameExpression = ExpressionSet(aUpper + 3 :: aUpper + 4 :: Nil) + + assert((initialSet -- setToRemoveWithSameDeterministicExpression).size == 1) + assert((initialSet -- setToRemoveWithOutSameExpression).size == 2) + } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 582b3ead5e54a..de0e7c7ee49ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -94,6 +94,21 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("combine redundant deterministic filters") { + val originalQuery = + testRelation + .where(Rand(0) > 0.1 && 'a === 1) + .where(Rand(0) > 0.1 && 'a === 1) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(Rand(0) > 0.1 && 'a === 1 && Rand(0) > 0.1) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("SPARK-16164: Filter pushdown should keep the ordering in the logical plan") { val originalQuery = testRelation From e6e36004afc3f9fc8abea98542248e9de11b4435 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 13 Oct 2017 23:09:12 +0800 Subject: [PATCH 706/779] [SPARK-14387][SPARK-16628][SPARK-18355][SQL] Use Spark schema to read ORC table instead of ORC file schema ## What changes were proposed in this pull request? Before Hive 2.0, ORC File schema has invalid column names like `_col1` and `_col2`. This is a well-known limitation and there are several Apache Spark issues with `spark.sql.hive.convertMetastoreOrc=true`. This PR ignores ORC File schema and use Spark schema. ## How was this patch tested? Pass the newly added test case. Author: Dongjoon Hyun Closes #19470 from dongjoon-hyun/SPARK-18355. --- .../spark/sql/hive/orc/OrcFileFormat.scala | 31 ++++++---- .../sql/hive/execution/SQLQuerySuite.scala | 62 ++++++++++++++++++- 2 files changed, 80 insertions(+), 13 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index c76f0ebb36a60..194e69c93e1a8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -134,12 +134,11 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this // case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file // using the given physical schema. Instead, we simply return an empty iterator. - val maybePhysicalSchema = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf)) - if (maybePhysicalSchema.isEmpty) { + val isEmptyFile = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf)).isEmpty + if (isEmptyFile) { Iterator.empty } else { - val physicalSchema = maybePhysicalSchema.get - OrcRelation.setRequiredColumns(conf, physicalSchema, requiredSchema) + OrcRelation.setRequiredColumns(conf, dataSchema, requiredSchema) val orcRecordReader = { val job = Job.getInstance(conf) @@ -163,6 +162,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable // Unwraps `OrcStruct`s to `UnsafeRow`s OrcRelation.unwrapOrcStructs( conf, + dataSchema, requiredSchema, Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]), recordsIterator) @@ -272,25 +272,32 @@ private[orc] object OrcRelation extends HiveInspectors { def unwrapOrcStructs( conf: Configuration, dataSchema: StructType, + requiredSchema: StructType, maybeStructOI: Option[StructObjectInspector], iterator: Iterator[Writable]): Iterator[InternalRow] = { val deserializer = new OrcSerde - val mutableRow = new SpecificInternalRow(dataSchema.map(_.dataType)) - val unsafeProjection = UnsafeProjection.create(dataSchema) + val mutableRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) + val unsafeProjection = UnsafeProjection.create(requiredSchema) def unwrap(oi: StructObjectInspector): Iterator[InternalRow] = { - val (fieldRefs, fieldOrdinals) = dataSchema.zipWithIndex.map { - case (field, ordinal) => oi.getStructFieldRef(field.name) -> ordinal + val (fieldRefs, fieldOrdinals) = requiredSchema.zipWithIndex.map { + case (field, ordinal) => + var ref = oi.getStructFieldRef(field.name) + if (ref == null) { + ref = oi.getStructFieldRef("_col" + dataSchema.fieldIndex(field.name)) + } + ref -> ordinal }.unzip - val unwrappers = fieldRefs.map(unwrapperFor) + val unwrappers = fieldRefs.map(r => if (r == null) null else unwrapperFor(r)) iterator.map { value => val raw = deserializer.deserialize(value) var i = 0 val length = fieldRefs.length while (i < length) { - val fieldValue = oi.getStructFieldData(raw, fieldRefs(i)) + val fieldRef = fieldRefs(i) + val fieldValue = if (fieldRef == null) null else oi.getStructFieldData(raw, fieldRef) if (fieldValue == null) { mutableRow.setNullAt(fieldOrdinals(i)) } else { @@ -306,8 +313,8 @@ private[orc] object OrcRelation extends HiveInspectors { } def setRequiredColumns( - conf: Configuration, physicalSchema: StructType, requestedSchema: StructType): Unit = { - val ids = requestedSchema.map(a => physicalSchema.fieldIndex(a.name): Integer) + conf: Configuration, dataSchema: StructType, requestedSchema: StructType): Unit = { + val ids = requestedSchema.map(a => dataSchema.fieldIndex(a.name): Integer) val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip HiveShim.appendReadColumns(conf, sortedIDs, sortedNames) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 09c59000b3e3f..94fa43dec7313 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -2050,4 +2050,64 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + Seq("orc", "parquet").foreach { format => + test(s"SPARK-18355 Read data from a hive table with a new column - $format") { + val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + + Seq("true", "false").foreach { value => + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> value, + HiveUtils.CONVERT_METASTORE_PARQUET.key -> value) { + withTempDatabase { db => + client.runSqlHive( + s""" + |CREATE TABLE $db.t( + | click_id string, + | search_id string, + | uid bigint) + |PARTITIONED BY ( + | ts string, + | hour string) + |STORED AS $format + """.stripMargin) + + client.runSqlHive( + s""" + |INSERT INTO TABLE $db.t + |PARTITION (ts = '98765', hour = '01') + |VALUES (12, 2, 12345) + """.stripMargin + ) + + checkAnswer( + sql(s"SELECT click_id, search_id, uid, ts, hour FROM $db.t"), + Row("12", "2", 12345, "98765", "01")) + + client.runSqlHive(s"ALTER TABLE $db.t ADD COLUMNS (dummy string)") + + checkAnswer( + sql(s"SELECT click_id, search_id FROM $db.t"), + Row("12", "2")) + + checkAnswer( + sql(s"SELECT search_id, click_id FROM $db.t"), + Row("2", "12")) + + checkAnswer( + sql(s"SELECT search_id FROM $db.t"), + Row("2")) + + checkAnswer( + sql(s"SELECT dummy, click_id FROM $db.t"), + Row(null, "12")) + + checkAnswer( + sql(s"SELECT click_id, search_id, uid, dummy, ts, hour FROM $db.t"), + Row("12", "2", 12345, null, "98765", "01")) + } + } + } + } + } } From 6412ea1759d39a2380c572ec24cfd8ae4f2d81f7 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 14 Oct 2017 00:35:12 +0800 Subject: [PATCH 707/779] [SPARK-21247][SQL] Type comparison should respect case-sensitive SQL conf ## What changes were proposed in this pull request? This is an effort to reduce the difference between Hive and Spark. Spark supports case-sensitivity in columns. Especially, for Struct types, with `spark.sql.caseSensitive=true`, the following is supported. ```scala scala> sql("select named_struct('a', 1, 'A', 2).a").show +--------------------------+ |named_struct(a, 1, A, 2).a| +--------------------------+ | 1| +--------------------------+ scala> sql("select named_struct('a', 1, 'A', 2).A").show +--------------------------+ |named_struct(a, 1, A, 2).A| +--------------------------+ | 2| +--------------------------+ ``` And vice versa, with `spark.sql.caseSensitive=false`, the following is supported. ```scala scala> sql("select named_struct('a', 1).A, named_struct('A', 1).a").show +--------------------+--------------------+ |named_struct(a, 1).A|named_struct(A, 1).a| +--------------------+--------------------+ | 1| 1| +--------------------+--------------------+ ``` However, types are considered different. For example, SET operations fail. ```scala scala> sql("SELECT named_struct('a',1) union all (select named_struct('A',2))").show org.apache.spark.sql.AnalysisException: Union can only be performed on tables with the compatible column types. struct <> struct at the first column of the second table;; 'Union :- Project [named_struct(a, 1) AS named_struct(a, 1)#57] : +- OneRowRelation$ +- Project [named_struct(A, 2) AS named_struct(A, 2)#58] +- OneRowRelation$ ``` This PR aims to support case-insensitive type equality. For example, in Set operation, the above operation succeed when `spark.sql.caseSensitive=false`. ```scala scala> sql("SELECT named_struct('a',1) union all (select named_struct('A',2))").show +------------------+ |named_struct(a, 1)| +------------------+ | [1]| | [2]| +------------------+ ``` ## How was this patch tested? Pass the Jenkins with a newly add test case. Author: Dongjoon Hyun Closes #18460 from dongjoon-hyun/SPARK-21247. --- .../sql/catalyst/analysis/TypeCoercion.scala | 10 ++++ .../org/apache/spark/sql/types/DataType.scala | 7 ++- .../catalyst/analysis/TypeCoercionSuite.scala | 52 +++++++++++++++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 38 ++++++++++++++ 4 files changed, 102 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 9ffe646b5e4ec..532d22dbf2321 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -100,6 +100,16 @@ object TypeCoercion { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) + case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => + Some(StructType(fields1.zip(fields2).map { case (f1, f2) => + // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType + // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. + // - Different names: use f1.name + // - Different nullabilities: `nullable` is true iff one of them is nullable. + val dataType = findTightestCommonType(f1.dataType, f2.dataType).get + StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) + })) + case _ => None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 30745c6a9d42a..d6e0df12218ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -26,6 +26,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils /** @@ -80,7 +81,11 @@ abstract class DataType extends AbstractDataType { * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def sameType(other: DataType): Boolean = - DataType.equalsIgnoreNullability(this, other) + if (SQLConf.get.caseSensitiveAnalysis) { + DataType.equalsIgnoreNullability(this, other) + } else { + DataType.equalsIgnoreCaseAndNullability(this, other) + } /** * Returns the same data type but set all nullability fields are true diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index d62e3b6dfe34f..793e04f66f0f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -131,14 +131,17 @@ class TypeCoercionSuite extends AnalysisTest { widenFunc: (DataType, DataType) => Option[DataType], t1: DataType, t2: DataType, - expected: Option[DataType]): Unit = { + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { var found = widenFunc(t1, t2) assert(found == expected, s"Expected $expected as wider common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. - found = widenFunc(t2, t1) - assert(found == expected, - s"Expected $expected as wider common type for $t2 and $t1, found $found") + if (isSymmetric) { + found = widenFunc(t2, t1) + assert(found == expected, + s"Expected $expected as wider common type for $t2 and $t1, found $found") + } } test("implicit type cast - ByteType") { @@ -385,6 +388,47 @@ class TypeCoercionSuite extends AnalysisTest { widenTest(NullType, StructType(Seq()), Some(StructType(Seq()))) widenTest(StringType, MapType(IntegerType, StringType, true), None) widenTest(ArrayType(IntegerType), StructType(Seq()), None) + + widenTest( + StructType(Seq(StructField("a", IntegerType))), + StructType(Seq(StructField("b", IntegerType))), + None) + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = false))), + StructType(Seq(StructField("a", DoubleType, nullable = false))), + None) + + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = false))), + StructType(Seq(StructField("a", IntegerType, nullable = false))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = false))))) + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = false))), + StructType(Seq(StructField("a", IntegerType, nullable = true))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = true))))) + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = true))), + StructType(Seq(StructField("a", IntegerType, nullable = false))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = true))))) + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = true))), + StructType(Seq(StructField("a", IntegerType, nullable = true))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = true))))) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + widenTest( + StructType(Seq(StructField("a", IntegerType))), + StructType(Seq(StructField("A", IntegerType))), + None) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkWidenType( + TypeCoercion.findTightestCommonType, + StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType))), + StructType(Seq(StructField("A", IntegerType), StructField("b", IntegerType))), + Some(StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType)))), + isSymmetric = false) + } } test("wider common type for decimal and array") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 93a7777b70b46..f0c58e2e5bf45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2646,6 +2646,44 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("SPARK-21247: Allow case-insensitive type equality in Set operation") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + sql("SELECT struct(1 a) UNION ALL (SELECT struct(2 A))") + sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))") + + withTable("t", "S") { + sql("CREATE TABLE t(c struct) USING parquet") + sql("CREATE TABLE S(C struct) USING parquet") + Seq(("c", "C"), ("C", "c"), ("c.f", "C.F"), ("C.F", "c.f")).foreach { + case (left, right) => + checkAnswer(sql(s"SELECT * FROM t, S WHERE t.$left = S.$right"), Seq.empty) + } + } + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val m1 = intercept[AnalysisException] { + sql("SELECT struct(1 a) UNION ALL (SELECT struct(2 A))") + }.message + assert(m1.contains("Union can only be performed on tables with the compatible column types")) + + val m2 = intercept[AnalysisException] { + sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))") + }.message + assert(m2.contains("Except can only be performed on tables with the compatible column types")) + + withTable("t", "S") { + sql("CREATE TABLE t(c struct) USING parquet") + sql("CREATE TABLE S(C struct) USING parquet") + checkAnswer(sql("SELECT * FROM t, S WHERE t.c.f = S.C.F"), Seq.empty) + val m = intercept[AnalysisException] { + sql("SELECT * FROM t, S WHERE c = C") + }.message + assert(m.contains("cannot resolve '(t.`c` = S.`C`)' due to data type mismatch")) + } + } + } + test("SPARK-21335: support un-aliased subquery") { withTempView("v") { Seq(1 -> "a").toDF("i", "j").createOrReplaceTempView("v") From 3823dc88d3816c7d1099f9601426108acc90574c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 13 Oct 2017 10:49:48 -0700 Subject: [PATCH 708/779] [SPARK-22252][SQL][FOLLOWUP] Command should not be a LeafNode ## What changes were proposed in this pull request? This is a minor folllowup of #19474 . #19474 partially reverted #18064 but accidentally introduced a behavior change. `Command` extended `LogicalPlan` before #18064 , but #19474 made it extend `LeafNode`. This is an internal behavior change as now all `Command` subclasses can't define children, and they have to implement `computeStatistic` method. This PR fixes this by making `Command` extend `LogicalPlan` ## How was this patch tested? N/A Author: Wenchen Fan Closes #19493 from cloud-fan/minor. --- .../org/apache/spark/sql/catalyst/plans/logical/Command.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala index 38f47081b6f55..ec5766e1f67f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute * commands can be used by parsers to represent DDL operations. Commands, unlike queries, are * eagerly executed. */ -trait Command extends LeafNode { +trait Command extends LogicalPlan { override def output: Seq[Attribute] = Seq.empty + override def children: Seq[LogicalPlan] = Seq.empty } From 1bb8b76045420b61a37806d6f4765af15c4052a7 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Fri, 13 Oct 2017 15:13:06 -0700 Subject: [PATCH 709/779] [MINOR][SS] keyWithIndexToNumValues" -> "keyWithIndexToValue" ## What changes were proposed in this pull request? This PR changes `keyWithIndexToNumValues` to `keyWithIndexToValue`. There will be directories on HDFS named with this `keyWithIndexToNumValues`. So if we ever want to fix this, let's fix it now. ## How was this patch tested? existing unit test cases. Author: Liwei Lin Closes #19435 from lw-lin/keyWithIndex. --- .../streaming/state/SymmetricHashJoinStateManager.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index d256fb578d921..6b386308c79fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -384,7 +384,7 @@ class SymmetricHashJoinStateManager( } /** A wrapper around a [[StateStore]] that stores [(key, index) -> value]. */ - private class KeyWithIndexToValueStore extends StateStoreHandler(KeyWithIndexToValuesType) { + private class KeyWithIndexToValueStore extends StateStoreHandler(KeyWithIndexToValueType) { private val keyWithIndexExprs = keyAttributes :+ Literal(1L) private val keyWithIndexSchema = keySchema.add("index", LongType) private val indexOrdinalInKeyWithIndexRow = keyAttributes.size @@ -471,7 +471,7 @@ class SymmetricHashJoinStateManager( object SymmetricHashJoinStateManager { def allStateStoreNames(joinSides: JoinSide*): Seq[String] = { - val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToValuesType) + val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToValueType) for (joinSide <- joinSides; stateStoreType <- allStateStoreTypes) yield { getStateStoreName(joinSide, stateStoreType) } @@ -483,8 +483,8 @@ object SymmetricHashJoinStateManager { override def toString(): String = "keyToNumValues" } - private case object KeyWithIndexToValuesType extends StateStoreType { - override def toString(): String = "keyWithIndexToNumValues" + private case object KeyWithIndexToValueType extends StateStoreType { + override def toString(): String = "keyWithIndexToValue" } private def getStateStoreName(joinSide: JoinSide, storeType: StateStoreType): String = { From 06df34d35ec088277445ef09cfb24bfe996f072e Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Fri, 13 Oct 2017 17:12:50 -0700 Subject: [PATCH 710/779] [SPARK-11034][LAUNCHER][MESOS] Launcher: add support for monitoring Mesos apps ## What changes were proposed in this pull request? Added Launcher support for monitoring Mesos apps in Client mode. SPARK-11033 can handle the support for Mesos/Cluster mode since the Standalone/Cluster and Mesos/Cluster modes use the same code at client side. ## How was this patch tested? I verified it manually by running launcher application, able to launch, stop and kill the mesos applications and also can invoke other launcher API's. Author: Devaraj K Closes #19385 from devaraj-kavali/SPARK-11034. --- .../MesosCoarseGrainedSchedulerBackend.scala | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 80c0a041b7322..603c980cb268d 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -32,6 +32,7 @@ import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskStat import org.apache.spark.deploy.mesos.config._ import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.config +import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.RpcEndpointAddress @@ -89,6 +90,13 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // Synchronization protected by stateLock private[this] var stopCalled: Boolean = false + private val launcherBackend = new LauncherBackend() { + override protected def onStopRequest(): Unit = { + stopSchedulerBackend() + setState(SparkAppHandle.State.KILLED) + } + } + // If shuffle service is enabled, the Spark driver will register with the shuffle service. // This is for cleaning up shuffle files reliably. private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) @@ -182,6 +190,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( override def start() { super.start() + if (sc.deployMode == "client") { + launcherBackend.connect() + } val startedBefore = IdHelper.startedBefore.getAndSet(true) val suffix = if (startedBefore) { @@ -202,6 +213,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( sc.conf.getOption("spark.mesos.driver.frameworkId").map(_ + suffix) ) + launcherBackend.setState(SparkAppHandle.State.SUBMITTED) startScheduler(driver) } @@ -295,15 +307,21 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( this.mesosExternalShuffleClient.foreach(_.init(appId)) this.schedulerDriver = driver markRegistered() + launcherBackend.setAppId(appId) + launcherBackend.setState(SparkAppHandle.State.RUNNING) } override def sufficientResourcesRegistered(): Boolean = { totalCoreCount.get >= maxCoresOption.getOrElse(0) * minRegisteredRatio } - override def disconnected(d: org.apache.mesos.SchedulerDriver) {} + override def disconnected(d: org.apache.mesos.SchedulerDriver) { + launcherBackend.setState(SparkAppHandle.State.SUBMITTED) + } - override def reregistered(d: org.apache.mesos.SchedulerDriver, masterInfo: MasterInfo) {} + override def reregistered(d: org.apache.mesos.SchedulerDriver, masterInfo: MasterInfo) { + launcherBackend.setState(SparkAppHandle.State.RUNNING) + } /** * Method called by Mesos to offer resources on slaves. We respond by launching an executor, @@ -611,6 +629,12 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } override def stop() { + stopSchedulerBackend() + launcherBackend.setState(SparkAppHandle.State.FINISHED) + launcherBackend.close() + } + + private def stopSchedulerBackend() { // Make sure we're not launching tasks during shutdown stateLock.synchronized { if (stopCalled) { From e3536406ec6ff65a8b41ba2f2fd40517a760cfd6 Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Fri, 13 Oct 2017 23:08:17 -0700 Subject: [PATCH 711/779] [SPARK-21762][SQL] FileFormatWriter/BasicWriteTaskStatsTracker metrics collection fails if a new file isn't yet visible ## What changes were proposed in this pull request? `BasicWriteTaskStatsTracker.getFileSize()` to catch `FileNotFoundException`, log info and then return 0 as a file size. This ensures that if a newly created file isn't visible due to the store not always having create consistency, the metric collection doesn't cause the failure. ## How was this patch tested? New test suite included, `BasicWriteTaskStatsTrackerSuite`. This not only checks the resilience to missing files, but verifies the existing logic as to how file statistics are gathered. Note that in the current implementation 1. if you call `Tracker..getFinalStats()` more than once, the file size count will increase by size of the last file. This could be fixed by clearing the filename field inside `getFinalStats()` itself. 2. If you pass in an empty or null string to `Tracker.newFile(path)` then IllegalArgumentException is raised, but only in `getFinalStats()`, rather than in `newFile`. There's a test for this behaviour in the new suite, as it verifies that only FNFEs get swallowed. Author: Steve Loughran Closes #18979 from steveloughran/cloud/SPARK-21762-missing-files-in-metrics. --- .../datasources/BasicWriteStatsTracker.scala | 49 +++- .../BasicWriteTaskStatsTrackerSuite.scala | 220 ++++++++++++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 8 + 3 files changed, 265 insertions(+), 12 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index b8f7d130d569f..11af0aaa7b206 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.execution.datasources +import java.io.FileNotFoundException + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -44,20 +47,32 @@ case class BasicWriteTaskStats( * @param hadoopConf */ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) - extends WriteTaskStatsTracker { + extends WriteTaskStatsTracker with Logging { private[this] var numPartitions: Int = 0 private[this] var numFiles: Int = 0 + private[this] var submittedFiles: Int = 0 private[this] var numBytes: Long = 0L private[this] var numRows: Long = 0L - private[this] var curFile: String = null - + private[this] var curFile: Option[String] = None - private def getFileSize(filePath: String): Long = { + /** + * Get the size of the file expected to have been written by a worker. + * @param filePath path to the file + * @return the file size or None if the file was not found. + */ + private def getFileSize(filePath: String): Option[Long] = { val path = new Path(filePath) val fs = path.getFileSystem(hadoopConf) - fs.getFileStatus(path).getLen() + try { + Some(fs.getFileStatus(path).getLen()) + } catch { + case e: FileNotFoundException => + // may arise against eventually consistent object stores + logDebug(s"File $path is not yet visible", e) + None + } } @@ -70,12 +85,19 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) } override def newFile(filePath: String): Unit = { - if (numFiles > 0) { - // we assume here that we've finished writing to disk the previous file by now - numBytes += getFileSize(curFile) + statCurrentFile() + curFile = Some(filePath) + submittedFiles += 1 + } + + private def statCurrentFile(): Unit = { + curFile.foreach { path => + getFileSize(path).foreach { len => + numBytes += len + numFiles += 1 + } + curFile = None } - curFile = filePath - numFiles += 1 } override def newRow(row: InternalRow): Unit = { @@ -83,8 +105,11 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) } override def getFinalStats(): WriteTaskStats = { - if (numFiles > 0) { - numBytes += getFileSize(curFile) + statCurrentFile() + if (submittedFiles != numFiles) { + logInfo(s"Expected $submittedFiles files, but only saw $numFiles. " + + "This could be due to the output format not writing empty files, " + + "or files being not immediately visible in the filesystem.") } BasicWriteTaskStats(numPartitions, numFiles, numBytes, numRows) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala new file mode 100644 index 0000000000000..bf3c8ede9a980 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.nio.charset.Charset + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.Utils + +/** + * Test how BasicWriteTaskStatsTracker handles files. + * + * Two different datasets are written (alongside 0), one of + * length 10, one of 3. They were chosen to be distinct enough + * that it is straightforward to determine which file lengths were added + * from the sum of all files added. Lengths like "10" and "5" would + * be less informative. + */ +class BasicWriteTaskStatsTrackerSuite extends SparkFunSuite { + + private val tempDir = Utils.createTempDir() + private val tempDirPath = new Path(tempDir.toURI) + private val conf = new Configuration() + private val localfs = tempDirPath.getFileSystem(conf) + private val data1 = "0123456789".getBytes(Charset.forName("US-ASCII")) + private val data2 = "012".getBytes(Charset.forName("US-ASCII")) + private val len1 = data1.length + private val len2 = data2.length + + /** + * In teardown delete the temp dir. + */ + protected override def afterAll(): Unit = { + Utils.deleteRecursively(tempDir) + } + + /** + * Assert that the stats match that expected. + * @param tracker tracker to check + * @param files number of files expected + * @param bytes total number of bytes expected + */ + private def assertStats( + tracker: BasicWriteTaskStatsTracker, + files: Int, + bytes: Int): Unit = { + val stats = finalStatus(tracker) + assert(files === stats.numFiles, "Wrong number of files") + assert(bytes === stats.numBytes, "Wrong byte count of file size") + } + + private def finalStatus(tracker: BasicWriteTaskStatsTracker): BasicWriteTaskStats = { + tracker.getFinalStats().asInstanceOf[BasicWriteTaskStats] + } + + test("No files in run") { + val tracker = new BasicWriteTaskStatsTracker(conf) + assertStats(tracker, 0, 0) + } + + test("Missing File") { + val missing = new Path(tempDirPath, "missing") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(missing.toString) + assertStats(tracker, 0, 0) + } + + test("Empty filename is forwarded") { + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile("") + intercept[IllegalArgumentException] { + finalStatus(tracker) + } + } + + test("Null filename is only picked up in final status") { + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(null) + intercept[IllegalArgumentException] { + finalStatus(tracker) + } + } + + test("0 byte file") { + val file = new Path(tempDirPath, "file0") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file.toString) + touch(file) + assertStats(tracker, 1, 0) + } + + test("File with data") { + val file = new Path(tempDirPath, "file-with-data") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file.toString) + write1(file) + assertStats(tracker, 1, len1) + } + + test("Open file") { + val file = new Path(tempDirPath, "file-open") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file.toString) + val stream = localfs.create(file, true) + try { + assertStats(tracker, 1, 0) + stream.write(data1) + stream.flush() + assert(1 === finalStatus(tracker).numFiles, "Wrong number of files") + } finally { + stream.close() + } + } + + test("Two files") { + val file1 = new Path(tempDirPath, "f-2-1") + val file2 = new Path(tempDirPath, "f-2-2") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file1.toString) + write1(file1) + tracker.newFile(file2.toString) + write2(file2) + assertStats(tracker, 2, len1 + len2) + } + + test("Three files, last one empty") { + val file1 = new Path(tempDirPath, "f-3-1") + val file2 = new Path(tempDirPath, "f-3-2") + val file3 = new Path(tempDirPath, "f-3-2") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file1.toString) + write1(file1) + tracker.newFile(file2.toString) + write2(file2) + tracker.newFile(file3.toString) + touch(file3) + assertStats(tracker, 3, len1 + len2) + } + + test("Three files, one not found") { + val file1 = new Path(tempDirPath, "f-4-1") + val file2 = new Path(tempDirPath, "f-4-2") + val file3 = new Path(tempDirPath, "f-3-2") + val tracker = new BasicWriteTaskStatsTracker(conf) + // file 1 + tracker.newFile(file1.toString) + write1(file1) + + // file 2 is noted, but not created + tracker.newFile(file2.toString) + + // file 3 is noted & then created + tracker.newFile(file3.toString) + write2(file3) + + // the expected size is file1 + file3; only two files are reported + // as found + assertStats(tracker, 2, len1 + len2) + } + + /** + * Write a 0-byte file. + * @param file file path + */ + private def touch(file: Path): Unit = { + localfs.create(file, true).close() + } + + /** + * Write a byte array. + * @param file path to file + * @param data data + * @return bytes written + */ + private def write(file: Path, data: Array[Byte]): Integer = { + val stream = localfs.create(file, true) + try { + stream.write(data) + } finally { + stream.close() + } + data.length + } + + /** + * Write a data1 array. + * @param file file + */ + private def write1(file: Path): Unit = { + write(file, data1) + } + + /** + * Write a data2 array. + * + * @param file file + */ + private def write2(file: Path): Unit = { + write(file, data2) + } + +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 94fa43dec7313..60935c3e85c43 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2110,4 +2110,12 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + Seq("orc", "parquet", "csv", "json", "text").foreach { format => + test(s"Writing empty datasets should not fail - $format") { + withTempDir { dir => + Seq("str").toDS.limit(0).write.format(format).save(dir.getCanonicalPath + "/tmp") + } + } + } } From e0503a7223410289d01bc4b20da3a451730577da Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 13 Oct 2017 23:24:36 -0700 Subject: [PATCH 712/779] [SPARK-22273][SQL] Fix key/value schema field names in HashMapGenerators. ## What changes were proposed in this pull request? When fixing schema field names using escape characters with `addReferenceMinorObj()` at [SPARK-18952](https://issues.apache.org/jira/browse/SPARK-18952) (#16361), double-quotes around the names were remained and the names become something like `"((java.lang.String) references[1])"`. ```java /* 055 */ private int maxSteps = 2; /* 056 */ private int numRows = 0; /* 057 */ private org.apache.spark.sql.types.StructType keySchema = new org.apache.spark.sql.types.StructType().add("((java.lang.String) references[1])", org.apache.spark.sql.types.DataTypes.StringType); /* 058 */ private org.apache.spark.sql.types.StructType valueSchema = new org.apache.spark.sql.types.StructType().add("((java.lang.String) references[2])", org.apache.spark.sql.types.DataTypes.LongType); /* 059 */ private Object emptyVBase; ``` We should remove the double-quotes to refer the values in `references` properly: ```java /* 055 */ private int maxSteps = 2; /* 056 */ private int numRows = 0; /* 057 */ private org.apache.spark.sql.types.StructType keySchema = new org.apache.spark.sql.types.StructType().add(((java.lang.String) references[1]), org.apache.spark.sql.types.DataTypes.StringType); /* 058 */ private org.apache.spark.sql.types.StructType valueSchema = new org.apache.spark.sql.types.StructType().add(((java.lang.String) references[2]), org.apache.spark.sql.types.DataTypes.LongType); /* 059 */ private Object emptyVBase; ``` ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #19491 from ueshin/issues/SPARK-22273. --- .../execution/aggregate/RowBasedHashMapGenerator.scala | 8 ++++---- .../execution/aggregate/VectorizedHashMapGenerator.scala | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index 9316ebcdf105c..3718424931b40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -50,10 +50,10 @@ class RowBasedHashMapGenerator( val keyName = ctx.addReferenceMinorObj(key.name) key.dataType match { case d: DecimalType => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType( + s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( |${d.precision}, ${d.scale}))""".stripMargin case _ => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})""" + s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" } }.mkString("\n").concat(";") @@ -63,10 +63,10 @@ class RowBasedHashMapGenerator( val keyName = ctx.addReferenceMinorObj(key.name) key.dataType match { case d: DecimalType => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType( + s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( |${d.precision}, ${d.scale}))""".stripMargin case _ => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})""" + s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" } }.mkString("\n").concat(";") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 13f79275cac41..812d405d5ebfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -55,10 +55,10 @@ class VectorizedHashMapGenerator( val keyName = ctx.addReferenceMinorObj(key.name) key.dataType match { case d: DecimalType => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType( + s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( |${d.precision}, ${d.scale}))""".stripMargin case _ => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})""" + s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" } }.mkString("\n").concat(";") @@ -68,10 +68,10 @@ class VectorizedHashMapGenerator( val keyName = ctx.addReferenceMinorObj(key.name) key.dataType match { case d: DecimalType => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType( + s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( |${d.precision}, ${d.scale}))""".stripMargin case _ => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})""" + s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" } }.mkString("\n").concat(";") From 014dc8471200518d63005eed531777d30d8a6639 Mon Sep 17 00:00:00 2001 From: liulijia Date: Sat, 14 Oct 2017 17:37:33 +0900 Subject: [PATCH 713/779] [SPARK-22233][CORE] Allow user to filter out empty split in HadoopRDD ## What changes were proposed in this pull request? Add a flag spark.files.ignoreEmptySplits. When true, methods like that use HadoopRDD and NewHadoopRDD such as SparkContext.textFiles will not create a partition for input splits that are empty. Author: liulijia Closes #19464 from liutang123/SPARK-22233. --- .../spark/internal/config/package.scala | 6 ++ .../org/apache/spark/rdd/HadoopRDD.scala | 12 ++- .../org/apache/spark/rdd/NewHadoopRDD.scala | 13 ++- .../scala/org/apache/spark/FileSuite.scala | 95 +++++++++++++++++-- 4 files changed, 112 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 19336f854145f..ce013d69579c1 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -270,6 +270,12 @@ package object config { .longConf .createWithDefault(4 * 1024 * 1024) + private[spark] val IGNORE_EMPTY_SPLITS = ConfigBuilder("spark.files.ignoreEmptySplits") + .doc("If true, methods that use HadoopRDD and NewHadoopRDD such as " + + "SparkContext.textFiles will not create a partition for input splits that are empty.") + .booleanConf + .createWithDefault(false) + private[spark] val SECRET_REDACTION_PATTERN = ConfigBuilder("spark.redaction.regex") .doc("Regex to decide which Spark configuration properties and environment variables in " + diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 23b344230e490..1f33c0a2b709f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -35,7 +35,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES +import org.apache.spark.internal.config.{IGNORE_CORRUPT_FILES, IGNORE_EMPTY_SPLITS} import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.scheduler.{HDFSCacheTaskLocation, HostTaskLocation} import org.apache.spark.storage.StorageLevel @@ -134,6 +134,8 @@ class HadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + private val ignoreEmptySplits = sparkContext.getConf.get(IGNORE_EMPTY_SPLITS) + // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. protected def getJobConf(): JobConf = { val conf: Configuration = broadcastedConf.value.value @@ -195,8 +197,12 @@ class HadoopRDD[K, V]( val jobConf = getJobConf() // add the credentials here as this can be called before SparkContext initialized SparkHadoopUtil.get.addCredentials(jobConf) - val inputFormat = getInputFormat(jobConf) - val inputSplits = inputFormat.getSplits(jobConf, minPartitions) + val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions) + val inputSplits = if (ignoreEmptySplits) { + allInputSplits.filter(_.getLength > 0) + } else { + allInputSplits + } val array = new Array[Partition](inputSplits.size) for (i <- 0 until inputSplits.size) { array(i) = new HadoopPartition(id, i, inputSplits(i)) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 482875e6c1ac5..db4eac1d0a775 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -21,6 +21,7 @@ import java.io.IOException import java.text.SimpleDateFormat import java.util.{Date, Locale} +import scala.collection.JavaConverters.asScalaBufferConverter import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} @@ -34,7 +35,7 @@ import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES +import org.apache.spark.internal.config.{IGNORE_CORRUPT_FILES, IGNORE_EMPTY_SPLITS} import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} @@ -89,6 +90,8 @@ class NewHadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + private val ignoreEmptySplits = sparkContext.getConf.get(IGNORE_EMPTY_SPLITS) + def getConf: Configuration = { val conf: Configuration = confBroadcast.value.value if (shouldCloneJobConf) { @@ -121,8 +124,12 @@ class NewHadoopRDD[K, V]( configurable.setConf(_conf) case _ => } - val jobContext = new JobContextImpl(_conf, jobId) - val rawSplits = inputFormat.getSplits(jobContext).toArray + val allRowSplits = inputFormat.getSplits(new JobContextImpl(_conf, jobId)).asScala + val rawSplits = if (ignoreEmptySplits) { + allRowSplits.filter(_.getLength > 0) + } else { + allRowSplits + } val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 02728180ac82d..4da4323ceb5c8 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES +import org.apache.spark.internal.config.{IGNORE_CORRUPT_FILES, IGNORE_EMPTY_SPLITS} import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -347,10 +347,10 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } } - test ("allow user to disable the output directory existence checking (old Hadoop API") { - val sf = new SparkConf() - sf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") - sc = new SparkContext(sf) + test ("allow user to disable the output directory existence checking (old Hadoop API)") { + val conf = new SparkConf() + conf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") + sc = new SparkContext(conf) val randomRDD = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 1) randomRDD.saveAsTextFile(tempDir.getPath + "/output") assert(new File(tempDir.getPath + "/output/part-00000").exists() === true) @@ -380,9 +380,9 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } test ("allow user to disable the output directory existence checking (new Hadoop API") { - val sf = new SparkConf() - sf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") - sc = new SparkContext(sf) + val conf = new SparkConf() + conf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") + sc = new SparkContext(conf) val randomRDD = sc.parallelize( Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]]( @@ -510,4 +510,83 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } } + test("spark.files.ignoreEmptySplits work correctly (old Hadoop API)") { + val conf = new SparkConf() + conf.setAppName("test").setMaster("local").set(IGNORE_EMPTY_SPLITS, true) + sc = new SparkContext(conf) + + def testIgnoreEmptySplits( + data: Array[Tuple2[String, String]], + actualPartitionNum: Int, + expectedPartitionNum: Int): Unit = { + val output = new File(tempDir, "output") + sc.parallelize(data, actualPartitionNum) + .saveAsHadoopFile[TextOutputFormat[String, String]](output.getPath) + for (i <- 0 until actualPartitionNum) { + assert(new File(output, s"part-0000$i").exists() === true) + } + val hadoopRDD = sc.textFile(new File(output, "part-*").getPath) + assert(hadoopRDD.partitions.length === expectedPartitionNum) + Utils.deleteRecursively(output) + } + + // Ensure that if all of the splits are empty, we remove the splits correctly + testIgnoreEmptySplits( + data = Array.empty[Tuple2[String, String]], + actualPartitionNum = 1, + expectedPartitionNum = 0) + + // Ensure that if no split is empty, we don't lose any splits + testIgnoreEmptySplits( + data = Array(("key1", "a"), ("key2", "a"), ("key3", "b")), + actualPartitionNum = 2, + expectedPartitionNum = 2) + + // Ensure that if part of the splits are empty, we remove the splits correctly + testIgnoreEmptySplits( + data = Array(("key1", "a"), ("key2", "a")), + actualPartitionNum = 5, + expectedPartitionNum = 2) + } + + test("spark.files.ignoreEmptySplits work correctly (new Hadoop API)") { + val conf = new SparkConf() + conf.setAppName("test").setMaster("local").set(IGNORE_EMPTY_SPLITS, true) + sc = new SparkContext(conf) + + def testIgnoreEmptySplits( + data: Array[Tuple2[String, String]], + actualPartitionNum: Int, + expectedPartitionNum: Int): Unit = { + val output = new File(tempDir, "output") + sc.parallelize(data, actualPartitionNum) + .saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](output.getPath) + for (i <- 0 until actualPartitionNum) { + assert(new File(output, s"part-r-0000$i").exists() === true) + } + val hadoopRDD = sc.newAPIHadoopFile(new File(output, "part-r-*").getPath, + classOf[NewTextInputFormat], classOf[LongWritable], classOf[Text]) + .asInstanceOf[NewHadoopRDD[_, _]] + assert(hadoopRDD.partitions.length === expectedPartitionNum) + Utils.deleteRecursively(output) + } + + // Ensure that if all of the splits are empty, we remove the splits correctly + testIgnoreEmptySplits( + data = Array.empty[Tuple2[String, String]], + actualPartitionNum = 1, + expectedPartitionNum = 0) + + // Ensure that if no split is empty, we don't lose any splits + testIgnoreEmptySplits( + data = Array(("1", "a"), ("2", "a"), ("3", "b")), + actualPartitionNum = 2, + expectedPartitionNum = 2) + + // Ensure that if part of the splits are empty, we remove the splits correctly + testIgnoreEmptySplits( + data = Array(("1", "a"), ("2", "b")), + actualPartitionNum = 5, + expectedPartitionNum = 2) + } } From e8547ffb49071525c06876c856cecc0d4731b918 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Sat, 14 Oct 2017 17:39:15 -0700 Subject: [PATCH 714/779] [SPARK-22238] Fix plan resolution bug caused by EnsureStatefulOpPartitioning ## What changes were proposed in this pull request? In EnsureStatefulOpPartitioning, we check that the inputRDD to a SparkPlan has the expected partitioning for Streaming Stateful Operators. The problem is that we are not allowed to access this information during planning. The reason we added that check was because CoalesceExec could actually create RDDs with 0 partitions. We should fix it such that when CoalesceExec says that there is a SinglePartition, there is in fact an inputRDD of 1 partition instead of 0 partitions. ## How was this patch tested? Regression test in StreamingQuerySuite Author: Burak Yavuz Closes #19467 from brkyvz/stateful-op. --- .../plans/physical/partitioning.scala | 15 +- .../execution/basicPhysicalOperators.scala | 27 +++- .../exchange/EnsureRequirements.scala | 5 +- .../FlatMapGroupsWithStateExec.scala | 2 +- .../streaming/IncrementalExecution.scala | 39 ++--- .../streaming/statefulOperators.scala | 11 +- .../org/apache/spark/sql/DataFrameSuite.scala | 2 + .../spark/sql/execution/PlannerSuite.scala | 17 +++ .../streaming/state/StateStoreRDDSuite.scala | 2 +- .../SymmetricHashJoinStateManagerSuite.scala | 2 +- .../sql/streaming/DeduplicateSuite.scala | 11 +- .../EnsureStatefulOpPartitioningSuite.scala | 138 ------------------ .../FlatMapGroupsWithStateSuite.scala | 6 +- .../sql/streaming/StatefulOperatorTest.scala | 49 +++++++ .../streaming/StreamingAggregationSuite.scala | 8 +- .../sql/streaming/StreamingJoinSuite.scala | 2 +- .../sql/streaming/StreamingQuerySuite.scala | 13 ++ 17 files changed, 160 insertions(+), 189 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 51d78dd1233fe..e57c842ce2a36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -49,7 +49,9 @@ case object AllTuples extends Distribution * can mean such tuples are either co-located in the same partition or they will be contiguous * within a single partition. */ -case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution { +case class ClusteredDistribution( + clustering: Seq[Expression], + numPartitions: Option[Int] = None) extends Distribution { require( clustering != Nil, "The clustering expressions of a ClusteredDistribution should not be Nil. " + @@ -221,6 +223,7 @@ case object SinglePartition extends Partitioning { override def satisfies(required: Distribution): Boolean = required match { case _: BroadcastDistribution => false + case ClusteredDistribution(_, desiredPartitions) => desiredPartitions.forall(_ == 1) case _ => true } @@ -243,8 +246,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true - case ClusteredDistribution(requiredClustering) => - expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) + case ClusteredDistribution(requiredClustering, desiredPartitions) => + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && + desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true case _ => false } @@ -289,8 +293,9 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering) => - ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) + case ClusteredDistribution(requiredClustering, desiredPartitions) => + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && + desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 63cd1691f4cd7..d15ece304cac4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration -import org.apache.spark.{InterruptibleIterator, TaskContext} +import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -590,10 +590,33 @@ case class CoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecN } protected override def doExecute(): RDD[InternalRow] = { - child.execute().coalesce(numPartitions, shuffle = false) + if (numPartitions == 1 && child.execute().getNumPartitions < 1) { + // Make sure we don't output an RDD with 0 partitions, when claiming that we have a + // `SinglePartition`. + new CoalesceExec.EmptyRDDWithPartitions(sparkContext, numPartitions) + } else { + child.execute().coalesce(numPartitions, shuffle = false) + } } } +object CoalesceExec { + /** A simple RDD with no data, but with the given number of partitions. */ + class EmptyRDDWithPartitions( + @transient private val sc: SparkContext, + numPartitions: Int) extends RDD[InternalRow](sc, Nil) { + + override def getPartitions: Array[Partition] = + Array.tabulate(numPartitions)(i => EmptyPartition(i)) + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + Iterator.empty + } + } + + case class EmptyPartition(index: Int) extends Partition +} + /** * Physical plan for a subquery. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index d28ce60e276d5..4e2ca37bc1a59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -44,13 +44,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { /** * Given a required distribution, returns a partitioning that satisfies that distribution. + * @param requiredDistribution The distribution that is required by the operator + * @param numPartitions Used when the distribution doesn't require a specific number of partitions */ private def createPartitioning( requiredDistribution: Distribution, numPartitions: Int): Partitioning = { requiredDistribution match { case AllTuples => SinglePartition - case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions) + case ClusteredDistribution(clustering, desiredPartitions) => + HashPartitioning(clustering, desiredPartitions.getOrElse(numPartitions)) case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions) case dist => sys.error(s"Do not know how to satisfy distribution $dist") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index aab06d611a5ea..c81f1a8142784 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -64,7 +64,7 @@ case class FlatMapGroupsWithStateExec( /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(groupingAttributes) :: Nil + ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) :: Nil /** Ordering needed for using GroupingIterator */ override def requiredChildOrdering: Seq[Seq[SortOrder]] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 82f879c763c2b..2e378637727fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode /** @@ -61,6 +62,10 @@ class IncrementalExecution( StreamingDeduplicationStrategy :: Nil } + private val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) + .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter) + .getOrElse(sparkSession.sessionState.conf.numShufflePartitions) + /** * See [SPARK-18339] * Walk the optimized logical plan and replace CurrentBatchTimestamp @@ -83,7 +88,11 @@ class IncrementalExecution( /** Get the state info of the next stateful operator */ private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = { StatefulOperatorStateInfo( - checkpointLocation, runId, statefulOperatorId.getAndIncrement(), currentBatchId) + checkpointLocation, + runId, + statefulOperatorId.getAndIncrement(), + currentBatchId, + numStateStores) } /** Locates save/restore pairs surrounding aggregation. */ @@ -130,34 +139,8 @@ class IncrementalExecution( } } - override def preparations: Seq[Rule[SparkPlan]] = - Seq(state, EnsureStatefulOpPartitioning) ++ super.preparations + override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations /** No need assert supported, as this check has already been done */ override def assertSupported(): Unit = { } } - -object EnsureStatefulOpPartitioning extends Rule[SparkPlan] { - // Needs to be transformUp to avoid extra shuffles - override def apply(plan: SparkPlan): SparkPlan = plan transformUp { - case so: StatefulOperator => - val numPartitions = plan.sqlContext.sessionState.conf.numShufflePartitions - val distributions = so.requiredChildDistribution - val children = so.children.zip(distributions).map { case (child, reqDistribution) => - val expectedPartitioning = reqDistribution match { - case AllTuples => SinglePartition - case ClusteredDistribution(keys) => HashPartitioning(keys, numPartitions) - case _ => throw new AnalysisException("Unexpected distribution expected for " + - s"Stateful Operator: $so. Expect AllTuples or ClusteredDistribution but got " + - s"$reqDistribution.") - } - if (child.outputPartitioning.guarantees(expectedPartitioning) && - child.execute().getNumPartitions == expectedPartitioning.numPartitions) { - child - } else { - ShuffleExchangeExec(expectedPartitioning, child) - } - } - so.withNewChildren(children) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 0d85542928ee6..b9b07a2e688f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -43,10 +43,11 @@ case class StatefulOperatorStateInfo( checkpointLocation: String, queryRunId: UUID, operatorId: Long, - storeVersion: Long) { + storeVersion: Long, + numPartitions: Int) { override def toString(): String = { s"state info [ checkpoint = $checkpointLocation, runId = $queryRunId, " + - s"opId = $operatorId, ver = $storeVersion]" + s"opId = $operatorId, ver = $storeVersion, numPartitions = $numPartitions]" } } @@ -239,7 +240,7 @@ case class StateStoreRestoreExec( if (keyExpressions.isEmpty) { AllTuples :: Nil } else { - ClusteredDistribution(keyExpressions) :: Nil + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } } @@ -386,7 +387,7 @@ case class StateStoreSaveExec( if (keyExpressions.isEmpty) { AllTuples :: Nil } else { - ClusteredDistribution(keyExpressions) :: Nil + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } } @@ -401,7 +402,7 @@ case class StreamingDeduplicateExec( /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(keyExpressions) :: Nil + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ad461fa6144b3..50de2fd3bca8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -368,6 +368,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer( testData.select('key).coalesce(1).select('key), testData.select('key).collect().toSeq) + + assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 1) } test("convert $\"attribute name\" into unresolved attribute") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 86066362da9dd..c25c90d0c70e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -425,6 +425,23 @@ class PlannerSuite extends SharedSQLContext { } } + test("EnsureRequirements should respect ClusteredDistribution's num partitioning") { + val distribution = ClusteredDistribution(Literal(1) :: Nil, Some(13)) + // Number of partitions differ + val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 13) + val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5) + assert(!childPartitioning.satisfies(distribution)) + val inputPlan = DummySparkPlan( + children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, + requiredChildDistribution = Seq(distribution), + requiredChildOrdering = Seq(Seq.empty)) + + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) + val shuffle = outputPlan.collect { case e: ShuffleExchangeExec => e } + assert(shuffle.size === 1) + assert(shuffle.head.newPartitioning === finalPartitioning) + } + test("Reuse exchanges") { val distribution = ClusteredDistribution(Literal(1) :: Nil) val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index defb9ed63a881..65b39f0fbd73d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -214,7 +214,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn path: String, queryRunId: UUID = UUID.randomUUID, version: Int = 0): StatefulOperatorStateInfo = { - StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version) + StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version, numPartitions = 5) } private val increment = (store: StateStore, iter: Iterator[String]) => { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala index d44af1d14c27a..c0216a2ef3e61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala @@ -160,7 +160,7 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter withTempDir { file => val storeConf = new StateStoreConf() - val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0) + val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5) val manager = new SymmetricHashJoinStateManager( LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration) try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala index e858b7d9998a8..caf2bab8a5859 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.streaming import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplicateExec} import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ -class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class DeduplicateSuite extends StateStoreMetricsTest + with BeforeAndAfterAll + with StatefulOperatorTest { import testImplicits._ @@ -41,6 +44,8 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { AddData(inputData, "a"), CheckLastBatch("a"), assertNumStateRows(total = 1, updated = 1), + AssertOnQuery(sq => + checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("value"))), AddData(inputData, "a"), CheckLastBatch(), assertNumStateRows(total = 1, updated = 0), @@ -58,6 +63,8 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { AddData(inputData, "a" -> 1), CheckLastBatch("a" -> 1), assertNumStateRows(total = 1, updated = 1), + AssertOnQuery(sq => + checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("_1"))), AddData(inputData, "a" -> 2), // Dropped CheckLastBatch(), assertNumStateRows(total = 1, updated = 0), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala deleted file mode 100644 index ed9823fbddfda..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.streaming - -import java.util.UUID - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} -import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} -import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo} -import org.apache.spark.sql.test.SharedSQLContext - -class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext { - - import testImplicits._ - - private var baseDf: DataFrame = null - - override def beforeAll(): Unit = { - super.beforeAll() - baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char") - } - - test("ClusteredDistribution generates Exchange with HashPartitioning") { - testEnsureStatefulOpPartitioning( - baseDf.queryExecution.sparkPlan, - requiredDistribution = keys => ClusteredDistribution(keys), - expectedPartitioning = - keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), - expectShuffle = true) - } - - test("ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning") { - testEnsureStatefulOpPartitioning( - baseDf.coalesce(1).queryExecution.sparkPlan, - requiredDistribution = keys => ClusteredDistribution(keys), - expectedPartitioning = - keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), - expectShuffle = true) - } - - test("AllTuples generates Exchange with SinglePartition") { - testEnsureStatefulOpPartitioning( - baseDf.queryExecution.sparkPlan, - requiredDistribution = _ => AllTuples, - expectedPartitioning = _ => SinglePartition, - expectShuffle = true) - } - - test("AllTuples with coalesce(1) doesn't need Exchange") { - testEnsureStatefulOpPartitioning( - baseDf.coalesce(1).queryExecution.sparkPlan, - requiredDistribution = _ => AllTuples, - expectedPartitioning = _ => SinglePartition, - expectShuffle = false) - } - - /** - * For `StatefulOperator` with the given `requiredChildDistribution`, and child SparkPlan - * `inputPlan`, ensures that the incremental planner adds exchanges, if required, in order to - * ensure the expected partitioning. - */ - private def testEnsureStatefulOpPartitioning( - inputPlan: SparkPlan, - requiredDistribution: Seq[Attribute] => Distribution, - expectedPartitioning: Seq[Attribute] => Partitioning, - expectShuffle: Boolean): Unit = { - val operator = TestStatefulOperator(inputPlan, requiredDistribution(inputPlan.output.take(1))) - val executed = executePlan(operator, OutputMode.Complete()) - if (expectShuffle) { - val exchange = executed.children.find(_.isInstanceOf[Exchange]) - if (exchange.isEmpty) { - fail(s"Was expecting an exchange but didn't get one in:\n$executed") - } - assert(exchange.get === - ShuffleExchangeExec(expectedPartitioning(inputPlan.output.take(1)), inputPlan), - s"Exchange didn't have expected properties:\n${exchange.get}") - } else { - assert(!executed.children.exists(_.isInstanceOf[Exchange]), - s"Unexpected exchange found in:\n$executed") - } - } - - /** Executes a SparkPlan using the IncrementalPlanner used for Structured Streaming. */ - private def executePlan( - p: SparkPlan, - outputMode: OutputMode = OutputMode.Append()): SparkPlan = { - val execution = new IncrementalExecution( - spark, - null, - OutputMode.Complete(), - "chk", - UUID.randomUUID(), - 0L, - OffsetSeqMetadata()) { - override lazy val sparkPlan: SparkPlan = p transform { - case plan: SparkPlan => - val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap - plan transformExpressions { - case UnresolvedAttribute(Seq(u)) => - inputMap.getOrElse(u, - sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) - } - } - } - execution.executedPlan - } -} - -/** Used to emulate a `StatefulOperator` with the given requiredDistribution. */ -case class TestStatefulOperator( - child: SparkPlan, - requiredDist: Distribution) extends UnaryExecNode with StatefulOperator { - override def output: Seq[Attribute] = child.output - override def doExecute(): RDD[InternalRow] = child.execute() - override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil - override def stateInfo: Option[StatefulOperatorStateInfo] = None -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index d2e8beb2f5290..aeb83835f981a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -41,7 +41,9 @@ case class RunningCount(count: Long) case class Result(key: Long, count: Int) -class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest + with BeforeAndAfterAll + with StatefulOperatorTest { import testImplicits._ import GroupStateImpl._ @@ -544,6 +546,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf AddData(inputData, "a"), CheckLastBatch(("a", "1")), assertNumStateRows(total = 1, updated = 1), + AssertOnQuery(sq => checkChildOutputHashPartitioning[FlatMapGroupsWithStateExec]( + sq, Seq("value"))), AddData(inputData, "a", "b"), CheckLastBatch(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala new file mode 100644 index 0000000000000..45142278993bb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.streaming._ + +trait StatefulOperatorTest { + /** + * Check that the output partitioning of a child operator of a Stateful operator satisfies the + * distribution that we expect for our Stateful operator. + */ + protected def checkChildOutputHashPartitioning[T <: StatefulOperator]( + sq: StreamingQuery, + colNames: Seq[String]): Boolean = { + val attr = sq.asInstanceOf[StreamExecution].lastExecution.analyzed.output + val partitions = sq.sparkSession.sessionState.conf.numShufflePartitions + val groupingAttr = attr.filter(a => colNames.contains(a.name)) + checkChildOutputPartitioning(sq, HashPartitioning(groupingAttr, partitions)) + } + + /** + * Check that the output partitioning of a child operator of a Stateful operator satisfies the + * distribution that we expect for our Stateful operator. + */ + protected def checkChildOutputPartitioning[T <: StatefulOperator]( + sq: StreamingQuery, + expectedPartitioning: Partitioning): Boolean = { + val operator = sq.asInstanceOf[StreamExecution].lastExecution + .executedPlan.collect { case p: T => p } + operator.head.children.forall( + _.outputPartitioning.numPartitions == expectedPartitioning.numPartitions) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index fe7efa69f7e31..1b4d8556f6ae5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -44,7 +44,7 @@ object FailureSingleton { } class StreamingAggregationSuite extends StateStoreMetricsTest - with BeforeAndAfterAll with Assertions { + with BeforeAndAfterAll with Assertions with StatefulOperatorTest { override def afterAll(): Unit = { super.afterAll() @@ -281,6 +281,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest AddData(inputData, 0L, 5L, 5L, 10L), AdvanceManualClock(10 * 1000), CheckLastBatch((0L, 1), (5L, 2), (10L, 1)), + AssertOnQuery(sq => + checkChildOutputHashPartitioning[StateStoreRestoreExec](sq, Seq("value"))), // advance clock to 20 seconds, should retain keys >= 10 AddData(inputData, 15L, 15L, 20L), @@ -455,8 +457,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest }, AddBlockData(inputSource), // create an empty trigger CheckLastBatch(1), - AssertOnQuery("Verify addition of exchange operator") { se => - checkAggregationChain(se, expectShuffling = true, 1) + AssertOnQuery("Verify that no exchange is required") { se => + checkAggregationChain(se, expectShuffling = false, 1) }, AddBlockData(inputSource, Seq(2, 3)), CheckLastBatch(3), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index a6593b71e51de..d32617275aadc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -330,7 +330,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with val queryId = UUID.randomUUID val opId = 0 val path = Utils.createDirectory(tempDir.getAbsolutePath, Random.nextString(10)).toString - val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L) + val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L, 5) implicit val sqlContext = spark.sqlContext val coordinatorRef = sqlContext.streams.stateStoreCoordinator diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index ab35079dca23f..c53889bb8566c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -652,6 +652,19 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + test("SPARK-22238: don't check for RDD partitions during streaming aggregation preparation") { + val stream = MemoryStream[(Int, Int)] + val baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char").where("char = 'A'") + val otherDf = stream.toDF().toDF("num", "numSq") + .join(broadcast(baseDf), "num") + .groupBy('char) + .agg(sum('numSq)) + + testStream(otherDf, OutputMode.Complete())( + AddData(stream, (1, 1), (2, 4)), + CheckLastBatch(("A", 1))) + } + /** Create a streaming DF that only execute one batch in which it returns the given static DF */ private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = { require(!triggerDF.isStreaming) From 13c1559587d0eb533c94f5a492390f81b048b347 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Sun, 15 Oct 2017 18:40:53 -0700 Subject: [PATCH 715/779] [SPARK-21549][CORE] Respect OutputFormats with no/invalid output directory provided ## What changes were proposed in this pull request? PR #19294 added support for null's - but spark 2.1 handled other error cases where path argument can be invalid. Namely: * empty string * URI parse exception while creating Path This is resubmission of PR #19487, which I messed up while updating my repo. ## How was this patch tested? Enhanced test to cover new support added. Author: Mridul Muralidharan Closes #19497 from mridulm/master. --- .../io/HadoopMapReduceCommitProtocol.scala | 24 +++++++------- .../spark/rdd/PairRDDFunctionsSuite.scala | 31 +++++++++++++------ 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index a7e6859ef6b64..95c99d29c3a9c 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -20,6 +20,7 @@ package org.apache.spark.internal.io import java.util.{Date, UUID} import scala.collection.mutable +import scala.util.Try import org.apache.hadoop.conf.Configurable import org.apache.hadoop.fs.Path @@ -47,6 +48,16 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) /** OutputCommitter from Hadoop is not serializable so marking it transient. */ @transient private var committer: OutputCommitter = _ + /** + * Checks whether there are files to be committed to a valid output location. + * + * As committing and aborting a job occurs on driver, where `addedAbsPathFiles` is always null, + * it is necessary to check whether a valid output path is specified. + * [[HadoopMapReduceCommitProtocol#path]] need not be a valid [[org.apache.hadoop.fs.Path]] for + * committers not writing to distributed file systems. + */ + private val hasValidPath = Try { new Path(path) }.isSuccess + /** * Tracks files staged by this task for absolute output paths. These outputs are not managed by * the Hadoop OutputCommitter, so we must move these to their final locations on job commit. @@ -60,15 +71,6 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) */ private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) - /** - * Checks whether there are files to be committed to an absolute output location. - * - * As committing and aborting a job occurs on driver, where `addedAbsPathFiles` is always null, - * it is necessary to check whether the output path is specified. Output path may not be required - * for committers not writing to distributed file systems. - */ - private def hasAbsPathFiles: Boolean = path != null - protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { val format = context.getOutputFormatClass.newInstance() // If OutputFormat is Configurable, we should set conf to it. @@ -142,7 +144,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) .foldLeft(Map[String, String]())(_ ++ _) logDebug(s"Committing files staged for absolute locations $filesToMove") - if (hasAbsPathFiles) { + if (hasValidPath) { val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) for ((src, dst) <- filesToMove) { fs.rename(new Path(src), new Path(dst)) @@ -153,7 +155,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) override def abortJob(jobContext: JobContext): Unit = { committer.abortJob(jobContext, JobStatus.State.FAILED) - if (hasAbsPathFiles) { + if (hasValidPath) { val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) fs.delete(absPathStagingDir, true) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 07579c5098014..0a248b6064ee8 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -568,21 +568,34 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { assert(FakeWriterWithCallback.exception.getMessage contains "failed to write") } - test("saveAsNewAPIHadoopDataset should respect empty output directory when " + + test("saveAsNewAPIHadoopDataset should support invalid output paths when " + "there are no files to be committed to an absolute output location") { val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) - val job = NewJob.getInstance(new Configuration(sc.hadoopConfiguration)) - job.setOutputKeyClass(classOf[Integer]) - job.setOutputValueClass(classOf[Integer]) - job.setOutputFormatClass(classOf[NewFakeFormat]) - val jobConfiguration = job.getConfiguration + def saveRddWithPath(path: String): Unit = { + val job = NewJob.getInstance(new Configuration(sc.hadoopConfiguration)) + job.setOutputKeyClass(classOf[Integer]) + job.setOutputValueClass(classOf[Integer]) + job.setOutputFormatClass(classOf[NewFakeFormat]) + if (null != path) { + job.getConfiguration.set("mapred.output.dir", path) + } else { + job.getConfiguration.unset("mapred.output.dir") + } + val jobConfiguration = job.getConfiguration + + // just test that the job does not fail with java.lang.IllegalArgumentException. + pairs.saveAsNewAPIHadoopDataset(jobConfiguration) + } - // just test that the job does not fail with - // java.lang.IllegalArgumentException: Can not create a Path from a null string - pairs.saveAsNewAPIHadoopDataset(jobConfiguration) + saveRddWithPath(null) + saveRddWithPath("") + saveRddWithPath("::invalid::") } + // In spark 2.1, only null was supported - not other invalid paths. + // org.apache.hadoop.mapred.FileOutputFormat.getOutputPath fails with IllegalArgumentException + // for non-null invalid paths. test("saveAsHadoopDataset should respect empty output directory when " + "there are no files to be committed to an absolute output location") { val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) From 0ae96495dedb54b3b6bae0bd55560820c5ca29a2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 16 Oct 2017 13:37:58 +0800 Subject: [PATCH 716/779] [SPARK-22223][SQL] ObjectHashAggregate should not introduce unnecessary shuffle ## What changes were proposed in this pull request? `ObjectHashAggregateExec` should override `outputPartitioning` in order to avoid unnecessary shuffle. ## How was this patch tested? Added Jenkins test. Author: Liang-Chi Hsieh Closes #19501 from viirya/SPARK-22223. --- .../aggregate/ObjectHashAggregateExec.scala | 2 ++ .../spark/sql/DataFrameAggregateSuite.scala | 30 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index ec3f9a05b5ccc..66955b8ef723c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -95,6 +95,8 @@ case class ObjectHashAggregateExec( } } + override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { val numOutputRows = longMetric("numOutputRows") val aggTime = longMetric("aggTime") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 8549eac58ee95..06848e4d2b297 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -636,4 +637,33 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { spark.sql("SELECT 3 AS c, 4 AS d, SUM(b) FROM testData2 GROUP BY c, d"), Seq(Row(3, 4, 9))) } + + test("SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle") { + withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") { + val df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4)).toDF("a", "b", "c") + .repartition(col("a")) + + val objHashAggDF = df + .withColumn("d", expr("(a, b, c)")) + .groupBy("a", "b").agg(collect_list("d").as("e")) + .withColumn("f", expr("(b, e)")) + .groupBy("a").agg(collect_list("f").as("g")) + val aggPlan = objHashAggDF.queryExecution.executedPlan + + val sortAggPlans = aggPlan.collect { + case sortAgg: SortAggregateExec => sortAgg + } + assert(sortAggPlans.isEmpty) + + val objHashAggPlans = aggPlan.collect { + case objHashAgg: ObjectHashAggregateExec => objHashAgg + } + assert(objHashAggPlans.nonEmpty) + + val exchangePlans = aggPlan.collect { + case shuffle: ShuffleExchangeExec => shuffle + } + assert(exchangePlans.length == 1) + } + } } From 0fa10666cf75e3c4929940af49c8a6f6ea874759 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Mon, 16 Oct 2017 22:15:50 +0800 Subject: [PATCH 717/779] [SPARK-22233][CORE][FOLLOW-UP] Allow user to filter out empty split in HadoopRDD ## What changes were proposed in this pull request? Update the config `spark.files.ignoreEmptySplits`, rename it and make it internal. This is followup of #19464 ## How was this patch tested? Exsiting tests. Author: Xingbo Jiang Closes #19504 from jiangxb1987/partitionsplit. --- .../org/apache/spark/internal/config/package.scala | 11 ++++++----- .../scala/org/apache/spark/rdd/HadoopRDD.scala | 4 ++-- .../scala/org/apache/spark/rdd/NewHadoopRDD.scala | 4 ++-- .../test/scala/org/apache/spark/FileSuite.scala | 14 +++++++++----- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index ce013d69579c1..efffdca1ea59b 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -270,11 +270,12 @@ package object config { .longConf .createWithDefault(4 * 1024 * 1024) - private[spark] val IGNORE_EMPTY_SPLITS = ConfigBuilder("spark.files.ignoreEmptySplits") - .doc("If true, methods that use HadoopRDD and NewHadoopRDD such as " + - "SparkContext.textFiles will not create a partition for input splits that are empty.") - .booleanConf - .createWithDefault(false) + private[spark] val HADOOP_RDD_IGNORE_EMPTY_SPLITS = + ConfigBuilder("spark.hadoopRDD.ignoreEmptySplits") + .internal() + .doc("When true, HadoopRDD/NewHadoopRDD will not create partitions for empty input splits.") + .booleanConf + .createWithDefault(false) private[spark] val SECRET_REDACTION_PATTERN = ConfigBuilder("spark.redaction.regex") diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 1f33c0a2b709f..2480559a41b7a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -35,7 +35,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.{IGNORE_CORRUPT_FILES, IGNORE_EMPTY_SPLITS} +import org.apache.spark.internal.config._ import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.scheduler.{HDFSCacheTaskLocation, HostTaskLocation} import org.apache.spark.storage.StorageLevel @@ -134,7 +134,7 @@ class HadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) - private val ignoreEmptySplits = sparkContext.getConf.get(IGNORE_EMPTY_SPLITS) + private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS) // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. protected def getJobConf(): JobConf = { diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index db4eac1d0a775..e4dd1b6a82498 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -35,7 +35,7 @@ import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.{IGNORE_CORRUPT_FILES, IGNORE_EMPTY_SPLITS} +import org.apache.spark.internal.config._ import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} @@ -90,7 +90,7 @@ class NewHadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) - private val ignoreEmptySplits = sparkContext.getConf.get(IGNORE_EMPTY_SPLITS) + private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS) def getConf: Configuration = { val conf: Configuration = confBroadcast.value.value diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 4da4323ceb5c8..e9539dc73f6fa 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.apache.spark.internal.config.{IGNORE_CORRUPT_FILES, IGNORE_EMPTY_SPLITS} +import org.apache.spark.internal.config._ import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -510,9 +510,11 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } } - test("spark.files.ignoreEmptySplits work correctly (old Hadoop API)") { + test("spark.hadoopRDD.ignoreEmptySplits work correctly (old Hadoop API)") { val conf = new SparkConf() - conf.setAppName("test").setMaster("local").set(IGNORE_EMPTY_SPLITS, true) + .setAppName("test") + .setMaster("local") + .set(HADOOP_RDD_IGNORE_EMPTY_SPLITS, true) sc = new SparkContext(conf) def testIgnoreEmptySplits( @@ -549,9 +551,11 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { expectedPartitionNum = 2) } - test("spark.files.ignoreEmptySplits work correctly (new Hadoop API)") { + test("spark.hadoopRDD.ignoreEmptySplits work correctly (new Hadoop API)") { val conf = new SparkConf() - conf.setAppName("test").setMaster("local").set(IGNORE_EMPTY_SPLITS, true) + .setAppName("test") + .setMaster("local") + .set(HADOOP_RDD_IGNORE_EMPTY_SPLITS, true) sc = new SparkContext(conf) def testIgnoreEmptySplits( From 561505e2fc290fc2cee3b8464ec49df773dca5eb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 16 Oct 2017 11:27:08 -0700 Subject: [PATCH 718/779] [SPARK-22282][SQL] Rename OrcRelation to OrcFileFormat and remove ORC_COMPRESSION ## What changes were proposed in this pull request? This PR aims to - Rename `OrcRelation` to `OrcFileFormat` object. - Replace `OrcRelation.ORC_COMPRESSION` with `org.apache.orc.OrcConf.COMPRESS`. Since [SPARK-21422](https://issues.apache.org/jira/browse/SPARK-21422), we can use `OrcConf.COMPRESS` instead of Hive's. ```scala // The references of Hive's classes will be minimized. val ORC_COMPRESSION = "orc.compress" ``` ## How was this patch tested? Pass the Jenkins with the existing and updated test cases. Author: Dongjoon Hyun Closes #19502 from dongjoon-hyun/SPARK-22282. --- .../org/apache/spark/sql/DataFrameWriter.scala | 4 ++-- .../spark/sql/hive/orc/OrcFileFormat.scala | 18 ++++++++---------- .../apache/spark/sql/hive/orc/OrcOptions.scala | 8 +++++--- .../spark/sql/hive/orc/OrcQuerySuite.scala | 11 ++++++----- .../spark/sql/hive/orc/OrcSourceSuite.scala | 9 ++++++--- 5 files changed, 27 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 07347d2748544..c9e45436ed42f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -520,8 +520,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `compression` (default is the value specified in `spark.sql.orc.compression.codec`): * compression codec to use when saving to file. This can be one of the known case-insensitive * shorten names(`none`, `snappy`, `zlib`, and `lzo`). This will override - * `orc.compress` and `spark.sql.parquet.compression.codec`. If `orc.compress` is given, - * it overrides `spark.sql.parquet.compression.codec`.
  • + * `orc.compress` and `spark.sql.orc.compression.codec`. If `orc.compress` is given, + * it overrides `spark.sql.orc.compression.codec`. * * * @since 1.5.0 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 194e69c93e1a8..d26ec15410d95 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -32,6 +32,7 @@ import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.{JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} +import org.apache.orc.OrcConf.COMPRESS import org.apache.spark.TaskContext import org.apache.spark.sql.SparkSession @@ -72,7 +73,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val configuration = job.getConfiguration - configuration.set(OrcRelation.ORC_COMPRESSION, orcOptions.compressionCodec) + configuration.set(COMPRESS.getAttribute, orcOptions.compressionCodec) configuration match { case conf: JobConf => conf.setOutputFormat(classOf[OrcOutputFormat]) @@ -93,8 +94,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable override def getFileExtension(context: TaskAttemptContext): String = { val compressionExtension: String = { - val name = context.getConfiguration.get(OrcRelation.ORC_COMPRESSION) - OrcRelation.extensionsForCompressionCodecNames.getOrElse(name, "") + val name = context.getConfiguration.get(COMPRESS.getAttribute) + OrcFileFormat.extensionsForCompressionCodecNames.getOrElse(name, "") } compressionExtension + ".orc" @@ -120,7 +121,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable if (sparkSession.sessionState.conf.orcFilterPushDown) { // Sets pushed predicates OrcFilters.createFilter(requiredSchema, filters.toArray).foreach { f => - hadoopConf.set(OrcRelation.SARG_PUSHDOWN, f.toKryo) + hadoopConf.set(OrcFileFormat.SARG_PUSHDOWN, f.toKryo) hadoopConf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) } } @@ -138,7 +139,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable if (isEmptyFile) { Iterator.empty } else { - OrcRelation.setRequiredColumns(conf, dataSchema, requiredSchema) + OrcFileFormat.setRequiredColumns(conf, dataSchema, requiredSchema) val orcRecordReader = { val job = Job.getInstance(conf) @@ -160,7 +161,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => recordsIterator.close())) // Unwraps `OrcStruct`s to `UnsafeRow`s - OrcRelation.unwrapOrcStructs( + OrcFileFormat.unwrapOrcStructs( conf, dataSchema, requiredSchema, @@ -255,10 +256,7 @@ private[orc] class OrcOutputWriter( } } -private[orc] object OrcRelation extends HiveInspectors { - // The references of Hive's classes will be minimized. - val ORC_COMPRESSION = "orc.compress" - +private[orc] object OrcFileFormat extends HiveInspectors { // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. private[orc] val SARG_PUSHDOWN = "sarg.pushdown" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala index 7f94c8c579026..6ce90c07b4921 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.orc import java.util.Locale +import org.apache.orc.OrcConf.COMPRESS + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.internal.SQLConf @@ -40,9 +42,9 @@ private[orc] class OrcOptions( * Acceptable values are defined in [[shortOrcCompressionCodecNames]]. */ val compressionCodec: String = { - // `compression`, `orc.compress`, and `spark.sql.orc.compression.codec` are - // in order of precedence from highest to lowest. - val orcCompressionConf = parameters.get(OrcRelation.ORC_COMPRESSION) + // `compression`, `orc.compress`(i.e., OrcConf.COMPRESS), and `spark.sql.orc.compression.codec` + // are in order of precedence from highest to lowest. + val orcCompressionConf = parameters.get(COMPRESS.getAttribute) val codecName = parameters .get("compression") .orElse(orcCompressionConf) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 60ccd996d6d58..1fa9091f967a3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -22,6 +22,7 @@ import java.sql.Timestamp import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.io.orc.{OrcStruct, SparkOrcNewRecordReader} +import org.apache.orc.OrcConf.COMPRESS import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ @@ -176,11 +177,11 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } - test("SPARK-16610: Respect orc.compress option when compression is unset") { - // Respect `orc.compress`. + test("SPARK-16610: Respect orc.compress (i.e., OrcConf.COMPRESS) when compression is unset") { + // Respect `orc.compress` (i.e., OrcConf.COMPRESS). withTempPath { file => spark.range(0, 10).write - .option("orc.compress", "ZLIB") + .option(COMPRESS.getAttribute, "ZLIB") .orc(file.getCanonicalPath) val expectedCompressionKind = OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression @@ -191,7 +192,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withTempPath { file => spark.range(0, 10).write .option("compression", "ZLIB") - .option("orc.compress", "SNAPPY") + .option(COMPRESS.getAttribute, "SNAPPY") .orc(file.getCanonicalPath) val expectedCompressionKind = OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression @@ -598,7 +599,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { val requestedSchema = StructType(Nil) val conf = new Configuration() val physicalSchema = OrcFileOperator.readSchema(Seq(path), Some(conf)).get - OrcRelation.setRequiredColumns(conf, physicalSchema, requestedSchema) + OrcFileFormat.setRequiredColumns(conf, physicalSchema, requestedSchema) val maybeOrcReader = OrcFileOperator.getFileReader(path, Some(conf)) assert(maybeOrcReader.isDefined) val orcRecordReader = new SparkOrcNewRecordReader( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 781de6631f324..ef9e67c743837 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.hive.orc import java.io.File +import java.util.Locale +import org.apache.orc.OrcConf.COMPRESS import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.{QueryTest, Row} @@ -150,7 +152,8 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { val conf = sqlContext.sessionState.conf - assert(new OrcOptions(Map("Orc.Compress" -> "NONE"), conf).compressionCodec == "NONE") + val option = new OrcOptions(Map(COMPRESS.getAttribute.toUpperCase(Locale.ROOT) -> "NONE"), conf) + assert(option.compressionCodec == "NONE") } test("SPARK-19459/SPARK-18220: read char/varchar column written by Hive") { @@ -205,8 +208,8 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA // `compression` -> `orc.compression` -> `spark.sql.orc.compression.codec` withSQLConf(SQLConf.ORC_COMPRESSION.key -> "uncompressed") { assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == "NONE") - val map1 = Map("orc.compress" -> "zlib") - val map2 = Map("orc.compress" -> "zlib", "compression" -> "lzo") + val map1 = Map(COMPRESS.getAttribute -> "zlib") + val map2 = Map(COMPRESS.getAttribute -> "zlib", "compression" -> "lzo") assert(new OrcOptions(map1, conf).compressionCodec == "ZLIB") assert(new OrcOptions(map2, conf).compressionCodec == "LZO") } From c09a2a76b52905a784d2767cb899dc886c330628 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 16 Oct 2017 16:16:34 -0700 Subject: [PATCH 719/779] [SPARK-22280][SQL][TEST] Improve StatisticsSuite to test `convertMetastore` properly ## What changes were proposed in this pull request? This PR aims to improve **StatisticsSuite** to test `convertMetastore` configuration properly. Currently, some test logic in `test statistics of LogicalRelation converted from Hive serde tables` depends on the default configuration. New test case is shorter and covers both(true/false) cases explicitly. This test case was previously modified by SPARK-17410 and SPARK-17284 in Spark 2.3.0. - https://github.com/apache/spark/commit/a2460be9c30b67b9159fe339d115b84d53cc288a#diff-1c464c86b68c2d0b07e73b7354e74ce7R443 ## How was this patch tested? Pass the Jenkins with the improved test case. Author: Dongjoon Hyun Closes #19500 from dongjoon-hyun/SPARK-22280. --- .../spark/sql/hive/StatisticsSuite.scala | 34 ++++++++----------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 9ff9ecf7f3677..b9a5ad7657134 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -937,26 +937,20 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } test("test statistics of LogicalRelation converted from Hive serde tables") { - val parquetTable = "parquetTable" - val orcTable = "orcTable" - withTable(parquetTable, orcTable) { - sql(s"CREATE TABLE $parquetTable (key STRING, value STRING) STORED AS PARQUET") - sql(s"CREATE TABLE $orcTable (key STRING, value STRING) STORED AS ORC") - sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") - sql(s"INSERT INTO TABLE $orcTable SELECT * FROM src") - - // the default value for `spark.sql.hive.convertMetastoreParquet` is true, here we just set it - // for robustness - withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "true") { - checkTableStats(parquetTable, hasSizeInBytes = false, expectedRowCounts = None) - sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") - checkTableStats(parquetTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) - } - withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") { - // We still can get tableSize from Hive before Analyze - checkTableStats(orcTable, hasSizeInBytes = true, expectedRowCounts = None) - sql(s"ANALYZE TABLE $orcTable COMPUTE STATISTICS") - checkTableStats(orcTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) + Seq("orc", "parquet").foreach { format => + Seq(true, false).foreach { isConverted => + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> s"$isConverted", + HiveUtils.CONVERT_METASTORE_PARQUET.key -> s"$isConverted") { + withTable(format) { + sql(s"CREATE TABLE $format (key STRING, value STRING) STORED AS $format") + sql(s"INSERT INTO TABLE $format SELECT * FROM src") + + checkTableStats(format, hasSizeInBytes = !isConverted, expectedRowCounts = None) + sql(s"ANALYZE TABLE $format COMPUTE STATISTICS") + checkTableStats(format, hasSizeInBytes = true, expectedRowCounts = Some(500)) + } + } } } } From e66cabb0215204605ca7928406d4787d41853dd1 Mon Sep 17 00:00:00 2001 From: Ben Barnard Date: Tue, 17 Oct 2017 09:36:09 +0200 Subject: [PATCH 720/779] [SPARK-20992][SCHEDULER] Add links in documentation to Nomad integration. ## What changes were proposed in this pull request? Adds links to the fork that provides integration with Nomad, in the same places the k8s integration is linked to. ## How was this patch tested? I clicked on the links to make sure they're correct ;) Author: Ben Barnard Closes #19354 from barnardb/link-to-nomad-integration. --- docs/cluster-overview.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index a2ad958959a50..c42bb4bb8377e 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -58,6 +58,9 @@ for providing container-centric infrastructure. Kubernetes support is being acti developed in an [apache-spark-on-k8s](https://github.com/apache-spark-on-k8s/) Github organization. For documentation, refer to that project's README. +A third-party project (not supported by the Spark project) exists to add support for +[Nomad](https://github.com/hashicorp/nomad-spark) as a cluster manager. + # Submitting Applications Applications can be submitted to a cluster of any type using the `spark-submit` script. From 8148f19ca1f0e0375603cb4f180c1bad8b0b8042 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 17 Oct 2017 09:41:23 +0200 Subject: [PATCH 721/779] [SPARK-22249][SQL] isin with empty list throws exception on cached DataFrame ## What changes were proposed in this pull request? As pointed out in the JIRA, there is a bug which causes an exception to be thrown if `isin` is called with an empty list on a cached DataFrame. The PR fixes it. ## How was this patch tested? Added UT. Author: Marco Gaido Closes #19494 from mgaido91/SPARK-22249. --- .../columnar/InMemoryTableScanExec.scala | 1 + .../columnar/InMemoryColumnarQuerySuite.scala | 15 +++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index af3636a5a2ca7..846ec03e46a12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -102,6 +102,7 @@ case class InMemoryTableScanExec( case IsNull(a: Attribute) => statsFor(a).nullCount > 0 case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 + case In(_: AttributeReference, list: Seq[Expression]) if list.isEmpty => Literal.FalseLiteral case In(a: AttributeReference, list: Seq[Expression]) if list.forall(_.isInstanceOf[Literal]) => list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] && l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 8d411eb191cd9..75d17bc79477d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -429,4 +429,19 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(agg_without_cache, agg_with_cache) } } + + test("SPARK-22249: IN should work also with cached DataFrame") { + val df = spark.range(10).cache() + // with an empty list + assert(df.filter($"id".isin()).count() == 0) + // with a non-empty list + assert(df.filter($"id".isin(2)).count() == 1) + assert(df.filter($"id".isin(2, 3)).count() == 2) + df.unpersist() + val dfNulls = spark.range(10).selectExpr("null as id").cache() + // with null as value for the attribute + assert(dfNulls.filter($"id".isin()).count() == 0) + assert(dfNulls.filter($"id".isin(2, 3)).count() == 0) + dfNulls.unpersist() + } } From 99e32f8ba5d908d5408e9857fd96ac1d7d7e5876 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 17 Oct 2017 17:58:45 +0800 Subject: [PATCH 722/779] [SPARK-22224][SQL] Override toString of KeyValue/Relational-GroupedDataset ## What changes were proposed in this pull request? #### before ```scala scala> val words = spark.read.textFile("README.md").flatMap(_.split(" ")) words: org.apache.spark.sql.Dataset[String] = [value: string] scala> val grouped = words.groupByKey(identity) grouped: org.apache.spark.sql.KeyValueGroupedDataset[String,String] = org.apache.spark.sql.KeyValueGroupedDataset65214862 ``` #### after ```scala scala> val words = spark.read.textFile("README.md").flatMap(_.split(" ")) words: org.apache.spark.sql.Dataset[String] = [value: string] scala> val grouped = words.groupByKey(identity) grouped: org.apache.spark.sql.KeyValueGroupedDataset[String,String] = [key: [value: string], value: [value: string]] ``` ## How was this patch tested? existing ut cc gatorsmile cloud-fan Author: Kent Yao Closes #19363 from yaooqinn/minor-dataset-tostring. --- .../spark/sql/KeyValueGroupedDataset.scala | 22 ++++++- .../spark/sql/RelationalGroupedDataset.scala | 19 +++++- .../org/apache/spark/sql/DatasetSuite.scala | 61 +++++++++++++++++++ .../org/apache/spark/sql/QueryTest.scala | 12 +--- 4 files changed, 100 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index cb42e9e4560cf..6bab21dca0cbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -24,7 +24,6 @@ import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode} @@ -564,4 +563,25 @@ class KeyValueGroupedDataset[K, V] private[sql]( encoder: Encoder[R]): Dataset[R] = { cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) } + + override def toString: String = { + val builder = new StringBuilder + val kFields = kExprEnc.schema.map { + case f => s"${f.name}: ${f.dataType.simpleString(2)}" + } + val vFields = vExprEnc.schema.map { + case f => s"${f.name}: ${f.dataType.simpleString(2)}" + } + builder.append("KeyValueGroupedDataset: [key: [") + builder.append(kFields.take(2).mkString(", ")) + if (kFields.length > 2) { + builder.append(" ... " + (kFields.length - 2) + " more field(s)") + } + builder.append("], value: [") + builder.append(vFields.take(2).mkString(", ")) + if (vFields.length > 2) { + builder.append(" ... " + (vFields.length - 2) + " more field(s)") + } + builder.append("]]").toString() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index cd0ac1feffa51..33ec3a27110a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.execution.python.PythonUDF import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{NumericType, StructField, StructType} +import org.apache.spark.sql.types.{NumericType, StructType} /** * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], @@ -465,6 +465,19 @@ class RelationalGroupedDataset protected[sql]( Dataset.ofRows(df.sparkSession, plan) } + + override def toString: String = { + val builder = new StringBuilder + builder.append("RelationalGroupedDataset: [grouping expressions: [") + val kFields = groupingExprs.map(_.asInstanceOf[NamedExpression]).map { + case f => s"${f.name}: ${f.dataType.simpleString(2)}" + } + builder.append(kFields.take(2).mkString(", ")) + if (kFields.length > 2) { + builder.append(" ... " + (kFields.length - 2) + " more field(s)") + } + builder.append(s"], value: ${df.toString}, type: $groupType]").toString() + } } private[sql] object RelationalGroupedDataset { @@ -479,7 +492,9 @@ private[sql] object RelationalGroupedDataset { /** * The Grouping Type */ - private[sql] trait GroupType + private[sql] trait GroupType { + override def toString: String = getClass.getSimpleName.stripSuffix("$").stripSuffix("Type") + } /** * To indicate it's the GroupBy diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index dace6825ee40e..1537ce3313c09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1341,8 +1341,69 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq(1).toDS().map(_ => ("", TestForTypeAlias.seqOfTupleTypeAlias)), ("", Seq((1, 1), (2, 2)))) } + + test("Check RelationalGroupedDataset toString: Single data") { + val kvDataset = (1 to 3).toDF("id").groupBy("id") + val expected = "RelationalGroupedDataset: [" + + "grouping expressions: [id: int], value: [id: int], type: GroupBy]" + val actual = kvDataset.toString + assert(expected === actual) + } + + test("Check RelationalGroupedDataset toString: over length schema ") { + val kvDataset = (1 to 3).map( x => (x, x.toString, x.toLong)) + .toDF("id", "val1", "val2").groupBy("id") + val expected = "RelationalGroupedDataset:" + + " [grouping expressions: [id: int]," + + " value: [id: int, val1: string ... 1 more field]," + + " type: GroupBy]" + val actual = kvDataset.toString + assert(expected === actual) + } + + + test("Check KeyValueGroupedDataset toString: Single data") { + val kvDataset = (1 to 3).toDF("id").as[SingleData].groupByKey(identity) + val expected = "KeyValueGroupedDataset: [key: [id: int], value: [id: int]]" + val actual = kvDataset.toString + assert(expected === actual) + } + + test("Check KeyValueGroupedDataset toString: Unnamed KV-pair") { + val kvDataset = (1 to 3).map(x => (x, x.toString)) + .toDF("id", "val1").as[DoubleData].groupByKey(x => (x.id, x.val1)) + val expected = "KeyValueGroupedDataset:" + + " [key: [_1: int, _2: string]," + + " value: [id: int, val1: string]]" + val actual = kvDataset.toString + assert(expected === actual) + } + + test("Check KeyValueGroupedDataset toString: Named KV-pair") { + val kvDataset = (1 to 3).map( x => (x, x.toString)) + .toDF("id", "val1").as[DoubleData].groupByKey(x => DoubleData(x.id, x.val1)) + val expected = "KeyValueGroupedDataset:" + + " [key: [id: int, val1: string]," + + " value: [id: int, val1: string]]" + val actual = kvDataset.toString + assert(expected === actual) + } + + test("Check KeyValueGroupedDataset toString: over length schema ") { + val kvDataset = (1 to 3).map( x => (x, x.toString, x.toLong)) + .toDF("id", "val1", "val2").as[TripleData].groupByKey(identity) + val expected = "KeyValueGroupedDataset:" + + " [key: [id: int, val1: string ... 1 more field(s)]," + + " value: [id: int, val1: string ... 1 more field(s)]]" + val actual = kvDataset.toString + assert(expected === actual) + } } +case class SingleData(id: Int) +case class DoubleData(id: Int, val1: String) +case class TripleData(id: Int, val1: String, val2: Long) + case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String]) case class WithMap(id: String, map_test: scala.collection.Map[Long, String]) case class WithMapInOption(m: Option[scala.collection.Map[Int, Int]]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index f9808834df4a5..fcaca3d75b74f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -17,23 +17,13 @@ package org.apache.spark.sql -import java.util.{ArrayDeque, Locale, TimeZone} +import java.util.{Locale, TimeZone} import scala.collection.JavaConverters._ -import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.execution.columnar.InMemoryRelation -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.streaming.MemoryPlan -import org.apache.spark.sql.types.{Metadata, ObjectType} abstract class QueryTest extends PlanTest { From e1960c3d6f380b0dfbba6ee5d8ac6da4bc29a698 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 17 Oct 2017 22:54:38 +0800 Subject: [PATCH 723/779] [SPARK-22062][CORE] Spill large block to disk in BlockManager's remote fetch to avoid OOM ## What changes were proposed in this pull request? In the current BlockManager's `getRemoteBytes`, it will call `BlockTransferService#fetchBlockSync` to get remote block. In the `fetchBlockSync`, Spark will allocate a temporary `ByteBuffer` to store the whole fetched block. This will potentially lead to OOM if block size is too big or several blocks are fetched simultaneously in this executor. So here leveraging the idea of shuffle fetch, to spill the large block to local disk before consumed by upstream code. The behavior is controlled by newly added configuration, if block size is smaller than the threshold, then this block will be persisted in memory; otherwise it will first spill to disk, and then read from disk file. To achieve this feature, what I did is: 1. Rename `TempShuffleFileManager` to `TempFileManager`, since now it is not only used by shuffle. 2. Add a new `TempFileManager` to manage the files of fetched remote blocks, the files are tracked by weak reference, will be deleted when no use at all. ## How was this patch tested? This was tested by adding UT, also manual verification in local test to perform GC to clean the files. Author: jerryshao Closes #19476 from jerryshao/SPARK-22062. --- .../shuffle/ExternalShuffleClient.java | 4 +- .../shuffle/OneForOneBlockFetcher.java | 12 +-- .../spark/network/shuffle/ShuffleClient.java | 10 +- ...eFileManager.java => TempFileManager.java} | 12 +-- .../scala/org/apache/spark/SparkConf.scala | 4 +- .../spark/internal/config/package.scala | 15 +-- .../spark/network/BlockTransferService.scala | 28 +++-- .../netty/NettyBlockTransferService.scala | 6 +- .../shuffle/BlockStoreShuffleReader.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 102 ++++++++++++++++-- .../spark/storage/BlockManagerMaster.scala | 6 ++ .../storage/BlockManagerMasterEndpoint.scala | 14 +++ .../spark/storage/BlockManagerMessages.scala | 7 ++ .../storage/ShuffleBlockFetcherIterator.scala | 8 +- .../org/apache/spark/DistributedSuite.scala | 2 +- .../spark/storage/BlockManagerSuite.scala | 57 +++++++--- .../ShuffleBlockFetcherIteratorSuite.scala | 10 +- docs/configuration.md | 11 +- 18 files changed, 236 insertions(+), 74 deletions(-) rename common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/{TempShuffleFileManager.java => TempFileManager.java} (74%) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 77702447edb88..510017fee2db5 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -91,7 +91,7 @@ public void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener, - TempShuffleFileManager tempShuffleFileManager) { + TempFileManager tempFileManager) { checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { @@ -99,7 +99,7 @@ public void fetchBlocks( (blockIds1, listener1) -> { TransportClient client = clientFactory.createClient(host, port); new OneForOneBlockFetcher(client, appId, execId, - blockIds1, listener1, conf, tempShuffleFileManager).start(); + blockIds1, listener1, conf, tempFileManager).start(); }; int maxRetries = conf.maxIORetries(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 66b67e282c80d..3f2f20b4149f1 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -58,7 +58,7 @@ public class OneForOneBlockFetcher { private final BlockFetchingListener listener; private final ChunkReceivedCallback chunkCallback; private final TransportConf transportConf; - private final TempShuffleFileManager tempShuffleFileManager; + private final TempFileManager tempFileManager; private StreamHandle streamHandle = null; @@ -79,14 +79,14 @@ public OneForOneBlockFetcher( String[] blockIds, BlockFetchingListener listener, TransportConf transportConf, - TempShuffleFileManager tempShuffleFileManager) { + TempFileManager tempFileManager) { this.client = client; this.openMessage = new OpenBlocks(appId, execId, blockIds); this.blockIds = blockIds; this.listener = listener; this.chunkCallback = new ChunkCallback(); this.transportConf = transportConf; - this.tempShuffleFileManager = tempShuffleFileManager; + this.tempFileManager = tempFileManager; } /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ @@ -125,7 +125,7 @@ public void onSuccess(ByteBuffer response) { // Immediately request all chunks -- we expect that the total size of the request is // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. for (int i = 0; i < streamHandle.numChunks; i++) { - if (tempShuffleFileManager != null) { + if (tempFileManager != null) { client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), new DownloadCallback(i)); } else { @@ -164,7 +164,7 @@ private class DownloadCallback implements StreamCallback { private int chunkIndex; DownloadCallback(int chunkIndex) throws IOException { - this.targetFile = tempShuffleFileManager.createTempShuffleFile(); + this.targetFile = tempFileManager.createTempFile(); this.channel = Channels.newChannel(Files.newOutputStream(targetFile.toPath())); this.chunkIndex = chunkIndex; } @@ -180,7 +180,7 @@ public void onComplete(String streamId) throws IOException { ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, targetFile.length()); listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); - if (!tempShuffleFileManager.registerTempShuffleFileToClean(targetFile)) { + if (!tempFileManager.registerTempFileToClean(targetFile)) { targetFile.delete(); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index 5bd4412b75275..18b04fedcac5b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -43,10 +43,10 @@ public void init(String appId) { } * @param execId the executor id. * @param blockIds block ids to fetch. * @param listener the listener to receive block fetching status. - * @param tempShuffleFileManager TempShuffleFileManager to create and clean temp shuffle files. - * If it's not null, the remote blocks will be streamed - * into temp shuffle files to reduce the memory usage, otherwise, - * they will be kept in memory. + * @param tempFileManager TempFileManager to create and clean temp files. + * If it's not null, the remote blocks will be streamed + * into temp shuffle files to reduce the memory usage, otherwise, + * they will be kept in memory. */ public abstract void fetchBlocks( String host, @@ -54,7 +54,7 @@ public abstract void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener, - TempShuffleFileManager tempShuffleFileManager); + TempFileManager tempFileManager); /** * Get the shuffle MetricsSet from ShuffleClient, this will be used in MetricsSystem to diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java similarity index 74% rename from common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java index 84a5ed6a276bd..552364d274f19 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java @@ -20,17 +20,17 @@ import java.io.File; /** - * A manager to create temp shuffle block files to reduce the memory usage and also clean temp + * A manager to create temp block files to reduce the memory usage and also clean temp * files when they won't be used any more. */ -public interface TempShuffleFileManager { +public interface TempFileManager { - /** Create a temp shuffle block file. */ - File createTempShuffleFile(); + /** Create a temp block file. */ + File createTempFile(); /** - * Register a temp shuffle file to clean up when it won't be used any more. Return whether the + * Register a temp file to clean up when it won't be used any more. Return whether the * file is registered successfully. If `false`, the caller should clean up the file by itself. */ - boolean registerTempShuffleFileToClean(File file); + boolean registerTempFileToClean(File file); } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index e61f943af49f2..57b3744e9c30a 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -662,7 +662,9 @@ private[spark] object SparkConf extends Logging { "spark.yarn.jars" -> Seq( AlternateConfig("spark.yarn.jar", "2.0")), "spark.yarn.access.hadoopFileSystems" -> Seq( - AlternateConfig("spark.yarn.access.namenodes", "2.2")) + AlternateConfig("spark.yarn.access.namenodes", "2.2")), + "spark.maxRemoteBlockSizeFetchToMem" -> Seq( + AlternateConfig("spark.reducer.maxReqSizeShuffleToMem", "2.3")) ) /** diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index efffdca1ea59b..e7b406af8d9b1 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -357,13 +357,15 @@ package object config { .checkValue(_ > 0, "The max no. of blocks in flight cannot be non-positive.") .createWithDefault(Int.MaxValue) - private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM = - ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem") - .doc("The blocks of a shuffle request will be fetched to disk when size of the request is " + + private[spark] val MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM = + ConfigBuilder("spark.maxRemoteBlockSizeFetchToMem") + .doc("Remote block will be fetched to disk when size of the block is " + "above this threshold. This is to avoid a giant request takes too much memory. We can " + - "enable this config by setting a specific value(e.g. 200m). Note that this config can " + - "be enabled only when the shuffle shuffle service is newer than Spark-2.2 or the shuffle" + - " service is disabled.") + "enable this config by setting a specific value(e.g. 200m). Note this configuration will " + + "affect both shuffle fetch and block manager remote block fetch. For users who " + + "enabled external shuffle service, this feature can only be worked when external shuffle" + + " service is newer than Spark 2.2.") + .withAlternative("spark.reducer.maxReqSizeShuffleToMem") .bytesConf(ByteUnit.BYTE) .createWithDefault(Long.MaxValue) @@ -432,5 +434,4 @@ package object config { .stringConf .toSequence .createOptional - } diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index fe5fd2da039bb..1d8a266d0079c 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -25,8 +25,8 @@ import scala.concurrent.duration.Duration import scala.reflect.ClassTag import org.apache.spark.internal.Logging -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager} import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.ThreadUtils @@ -68,7 +68,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempShuffleFileManager: TempShuffleFileManager): Unit + tempFileManager: TempFileManager): Unit /** * Upload a single block to a remote node, available only after [[init]] is invoked. @@ -87,7 +87,12 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo * * It is also only available after [[init]] is invoked. */ - def fetchBlockSync(host: String, port: Int, execId: String, blockId: String): ManagedBuffer = { + def fetchBlockSync( + host: String, + port: Int, + execId: String, + blockId: String, + tempFileManager: TempFileManager): ManagedBuffer = { // A monitor for the thread to wait on. val result = Promise[ManagedBuffer]() fetchBlocks(host, port, execId, Array(blockId), @@ -96,12 +101,17 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo result.failure(exception) } override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - val ret = ByteBuffer.allocate(data.size.toInt) - ret.put(data.nioByteBuffer()) - ret.flip() - result.success(new NioManagedBuffer(ret)) + data match { + case f: FileSegmentManagedBuffer => + result.success(f) + case _ => + val ret = ByteBuffer.allocate(data.size.toInt) + ret.put(data.nioByteBuffer()) + ret.flip() + result.success(new NioManagedBuffer(ret)) + } } - }, tempShuffleFileManager = null) + }, tempFileManager) ThreadUtils.awaitResult(result.future, Duration.Inf) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 6a29e18bf3cbb..b7d8c35032763 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -32,7 +32,7 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.server._ -import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempShuffleFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempFileManager} import org.apache.spark.network.shuffle.protocol.UploadBlock import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer @@ -105,14 +105,14 @@ private[spark] class NettyBlockTransferService( execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempShuffleFileManager: TempShuffleFileManager): Unit = { + tempFileManager: TempFileManager): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { val client = clientFactory.createClient(host, port) new OneForOneBlockFetcher(client, appId, execId, blockIds, listener, - transportConf, tempShuffleFileManager).start() + transportConf, tempFileManager).start() } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index c8d1460300934..0562d45ff57c5 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -52,7 +52,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), - SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a98083df5bd84..e0276a4dc4224 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -18,8 +18,11 @@ package org.apache.spark.storage import java.io._ +import java.lang.ref.{ReferenceQueue => JReferenceQueue, WeakReference} import java.nio.ByteBuffer import java.nio.channels.Channels +import java.util.Collections +import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable import scala.collection.mutable.HashMap @@ -39,7 +42,7 @@ import org.apache.spark.metrics.source.Source import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.netty.SparkTransportConf -import org.apache.spark.network.shuffle.ExternalShuffleClient +import org.apache.spark.network.shuffle.{ExternalShuffleClient, TempFileManager} import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv import org.apache.spark.serializer.{SerializerInstance, SerializerManager} @@ -203,6 +206,13 @@ private[spark] class BlockManager( private var blockReplicationPolicy: BlockReplicationPolicy = _ + // A TempFileManager used to track all the files of remote blocks which above the + // specified memory threshold. Files will be deleted automatically based on weak reference. + // Exposed for test + private[storage] val remoteBlockTempFileManager = + new BlockManager.RemoteBlockTempFileManager(this) + private val maxRemoteBlockToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) + /** * Initializes the BlockManager with the given appId. This is not performed in the constructor as * the appId may not be known at BlockManager instantiation time (in particular for the driver, @@ -632,8 +642,8 @@ private[spark] class BlockManager( * Return a list of locations for the given block, prioritizing the local machine since * multiple block managers can share the same host, followed by hosts on the same rack. */ - private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { - val locs = Random.shuffle(master.getLocations(blockId)) + private def sortLocations(locations: Seq[BlockManagerId]): Seq[BlockManagerId] = { + val locs = Random.shuffle(locations) val (preferredLocs, otherLocs) = locs.partition { loc => blockManagerId.host == loc.host } blockManagerId.topologyInfo match { case None => preferredLocs ++ otherLocs @@ -653,7 +663,25 @@ private[spark] class BlockManager( require(blockId != null, "BlockId is null") var runningFailureCount = 0 var totalFailureCount = 0 - val locations = getLocations(blockId) + + // Because all the remote blocks are registered in driver, it is not necessary to ask + // all the slave executors to get block status. + val locationsAndStatus = master.getLocationsAndStatus(blockId) + val blockSize = locationsAndStatus.map { b => + b.status.diskSize.max(b.status.memSize) + }.getOrElse(0L) + val blockLocations = locationsAndStatus.map(_.locations).getOrElse(Seq.empty) + + // If the block size is above the threshold, we should pass our FileManger to + // BlockTransferService, which will leverage it to spill the block; if not, then passed-in + // null value means the block will be persisted in memory. + val tempFileManager = if (blockSize > maxRemoteBlockToMem) { + remoteBlockTempFileManager + } else { + null + } + + val locations = sortLocations(blockLocations) val maxFetchFailures = locations.size var locationIterator = locations.iterator while (locationIterator.hasNext) { @@ -661,7 +689,7 @@ private[spark] class BlockManager( logDebug(s"Getting remote block $blockId from $loc") val data = try { blockTransferService.fetchBlockSync( - loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer() + loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager).nioByteBuffer() } catch { case NonFatal(e) => runningFailureCount += 1 @@ -684,7 +712,7 @@ private[spark] class BlockManager( // take a significant amount of time. To get rid of these stale entries // we refresh the block locations after a certain number of fetch failures if (runningFailureCount >= maxFailuresBeforeLocationRefresh) { - locationIterator = getLocations(blockId).iterator + locationIterator = sortLocations(master.getLocations(blockId)).iterator logDebug(s"Refreshed locations from the driver " + s"after ${runningFailureCount} fetch failures.") runningFailureCount = 0 @@ -1512,6 +1540,7 @@ private[spark] class BlockManager( // Closing should be idempotent, but maybe not for the NioBlockTransferService. shuffleClient.close() } + remoteBlockTempFileManager.stop() diskBlockManager.stop() rpcEnv.stop(slaveEndpoint) blockInfoManager.clear() @@ -1552,4 +1581,65 @@ private[spark] object BlockManager { override val metricRegistry = new MetricRegistry metricRegistry.registerAll(metricSet) } + + class RemoteBlockTempFileManager(blockManager: BlockManager) + extends TempFileManager with Logging { + + private class ReferenceWithCleanup(file: File, referenceQueue: JReferenceQueue[File]) + extends WeakReference[File](file, referenceQueue) { + private val filePath = file.getAbsolutePath + + def cleanUp(): Unit = { + logDebug(s"Clean up file $filePath") + + if (!new File(filePath).delete()) { + logDebug(s"Fail to delete file $filePath") + } + } + } + + private val referenceQueue = new JReferenceQueue[File] + private val referenceBuffer = Collections.newSetFromMap[ReferenceWithCleanup]( + new ConcurrentHashMap) + + private val POLL_TIMEOUT = 1000 + @volatile private var stopped = false + + private val cleaningThread = new Thread() { override def run() { keepCleaning() } } + cleaningThread.setDaemon(true) + cleaningThread.setName("RemoteBlock-temp-file-clean-thread") + cleaningThread.start() + + override def createTempFile(): File = { + blockManager.diskBlockManager.createTempLocalBlock()._2 + } + + override def registerTempFileToClean(file: File): Boolean = { + referenceBuffer.add(new ReferenceWithCleanup(file, referenceQueue)) + } + + def stop(): Unit = { + stopped = true + cleaningThread.interrupt() + cleaningThread.join() + } + + private def keepCleaning(): Unit = { + while (!stopped) { + try { + Option(referenceQueue.remove(POLL_TIMEOUT)) + .map(_.asInstanceOf[ReferenceWithCleanup]) + .foreach { ref => + referenceBuffer.remove(ref) + ref.cleanUp() + } + } catch { + case _: InterruptedException => + // no-op + case NonFatal(e) => + logError("Error in cleaning thread", e) + } + } + } + } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 8b1dc0ba6356a..d24421b962774 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -84,6 +84,12 @@ class BlockManagerMaster( driverEndpoint.askSync[Seq[BlockManagerId]](GetLocations(blockId)) } + /** Get locations as well as status of the blockId from the driver */ + def getLocationsAndStatus(blockId: BlockId): Option[BlockLocationsAndStatus] = { + driverEndpoint.askSync[Option[BlockLocationsAndStatus]]( + GetLocationsAndStatus(blockId)) + } + /** Get locations of multiple blockIds from the driver */ def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { driverEndpoint.askSync[IndexedSeq[Seq[BlockManagerId]]]( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index df0a5f5e229fb..56d0266b8edad 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -82,6 +82,9 @@ class BlockManagerMasterEndpoint( case GetLocations(blockId) => context.reply(getLocations(blockId)) + case GetLocationsAndStatus(blockId) => + context.reply(getLocationsAndStatus(blockId)) + case GetLocationsMultipleBlockIds(blockIds) => context.reply(getLocationsMultipleBlockIds(blockIds)) @@ -422,6 +425,17 @@ class BlockManagerMasterEndpoint( if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty } + private def getLocationsAndStatus(blockId: BlockId): Option[BlockLocationsAndStatus] = { + val locations = Option(blockLocations.get(blockId)).map(_.toSeq).getOrElse(Seq.empty) + val status = locations.headOption.flatMap { bmId => blockManagerInfo(bmId).getStatus(blockId) } + + if (locations.nonEmpty && status.isDefined) { + Some(BlockLocationsAndStatus(locations, status.get)) + } else { + None + } + } + private def getLocationsMultipleBlockIds( blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { blockIds.map(blockId => getLocations(blockId)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 0c0ff144596ac..1bbe7a5b39509 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -93,6 +93,13 @@ private[spark] object BlockManagerMessages { case class GetLocations(blockId: BlockId) extends ToBlockManagerMaster + case class GetLocationsAndStatus(blockId: BlockId) extends ToBlockManagerMaster + + // The response message of `GetLocationsAndStatus` request. + case class BlockLocationsAndStatus(locations: Seq[BlockManagerId], status: BlockStatus) { + assert(locations.nonEmpty) + } + case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 2d176b62f8b36..98b5a735a4529 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -28,7 +28,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBufferOutputStream @@ -69,7 +69,7 @@ final class ShuffleBlockFetcherIterator( maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean) - extends Iterator[(BlockId, InputStream)] with TempShuffleFileManager with Logging { + extends Iterator[(BlockId, InputStream)] with TempFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -162,11 +162,11 @@ final class ShuffleBlockFetcherIterator( currentResult = null } - override def createTempShuffleFile(): File = { + override def createTempFile(): File = { blockManager.diskBlockManager.createTempLocalBlock()._2 } - override def registerTempShuffleFileToClean(file: File): Boolean = synchronized { + override def registerTempFileToClean(file: File): Boolean = synchronized { if (isZombie) { false } else { diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index bea67b71a5a12..f8005610f7e4f 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -171,7 +171,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val serializerManager = SparkEnv.get.serializerManager blockManager.master.getLocations(blockId).foreach { cmId => val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, - blockId.toString) + blockId.toString, null) val deserialized = serializerManager.dataDeserializeStream(blockId, new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream())(data.elementClassTag).toList assert(deserialized === (1 to 100).toList) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index cfe89fde63f88..d45c194d31adc 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.storage -import java.io.File import java.nio.ByteBuffer import scala.collection.JavaConverters._ @@ -45,14 +44,14 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat +import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer @@ -512,8 +511,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE when(bmMaster.getLocations(mc.any[BlockId])).thenReturn(Seq(bmId1, bmId2, bmId3)) val blockManager = makeBlockManager(128, "exec", bmMaster) - val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) - val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) + val sortLocations = PrivateMethod[Seq[BlockManagerId]]('sortLocations) + val locations = blockManager invokePrivate sortLocations(bmMaster.getLocations("test")) assert(locations.map(_.host) === Seq(localHost, localHost, otherHost)) } @@ -535,8 +534,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val blockManager = makeBlockManager(128, "exec", bmMaster) blockManager.blockManagerId = BlockManagerId(SparkContext.DRIVER_IDENTIFIER, localHost, 1, Some(localRack)) - val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) - val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) + val sortLocations = PrivateMethod[Seq[BlockManagerId]]('sortLocations) + val locations = blockManager invokePrivate sortLocations(bmMaster.getLocations("test")) assert(locations.map(_.host) === Seq(localHost, localHost, otherHost, otherHost, otherHost)) assert(locations.flatMap(_.topologyInfo) === Seq(localRack, localRack, localRack, otherRack, otherRack)) @@ -1274,13 +1273,18 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // so that we have a chance to do location refresh val blockManagerIds = (0 to maxFailuresBeforeLocationRefresh) .map { i => BlockManagerId(s"id-$i", s"host-$i", i + 1) } - when(mockBlockManagerMaster.getLocations(mc.any[BlockId])).thenReturn(blockManagerIds) + when(mockBlockManagerMaster.getLocationsAndStatus(mc.any[BlockId])).thenReturn( + Option(BlockLocationsAndStatus(blockManagerIds, BlockStatus.empty))) + when(mockBlockManagerMaster.getLocations(mc.any[BlockId])).thenReturn( + blockManagerIds) + store = makeBlockManager(8000, "executor1", mockBlockManagerMaster, transferService = Option(mockBlockTransferService)) val block = store.getRemoteBytes("item") .asInstanceOf[Option[ByteBuffer]] assert(block.isDefined) - verify(mockBlockManagerMaster, times(2)).getLocations("item") + verify(mockBlockManagerMaster, times(1)).getLocationsAndStatus("item") + verify(mockBlockManagerMaster, times(1)).getLocations("item") } test("SPARK-17484: block status is properly updated following an exception in put()") { @@ -1371,8 +1375,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE server.close() } + test("fetch remote block to local disk if block size is larger than threshold") { + conf.set(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM, 1000L) + + val mockBlockManagerMaster = mock(classOf[BlockManagerMaster]) + val mockBlockTransferService = new MockBlockTransferService(0) + val blockLocations = Seq(BlockManagerId("id-0", "host-0", 1)) + val blockStatus = BlockStatus(StorageLevel.DISK_ONLY, 0L, 2000L) + + when(mockBlockManagerMaster.getLocationsAndStatus(mc.any[BlockId])).thenReturn( + Option(BlockLocationsAndStatus(blockLocations, blockStatus))) + when(mockBlockManagerMaster.getLocations(mc.any[BlockId])).thenReturn(blockLocations) + + store = makeBlockManager(8000, "executor1", mockBlockManagerMaster, + transferService = Option(mockBlockTransferService)) + val block = store.getRemoteBytes("item") + .asInstanceOf[Option[ByteBuffer]] + + assert(block.isDefined) + assert(mockBlockTransferService.numCalls === 1) + // assert FileManager is not null if the block size is larger than threshold. + assert(mockBlockTransferService.tempFileManager === store.remoteBlockTempFileManager) + } + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { var numCalls = 0 + var tempFileManager: TempFileManager = null override def init(blockDataManager: BlockDataManager): Unit = {} @@ -1382,7 +1410,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempShuffleFileManager: TempShuffleFileManager): Unit = { + tempFileManager: TempFileManager): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) } @@ -1394,7 +1422,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE override def uploadBlock( hostname: String, - port: Int, execId: String, + port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel, @@ -1407,12 +1436,14 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE host: String, port: Int, execId: String, - blockId: String): ManagedBuffer = { + blockId: String, + tempFileManager: TempFileManager): ManagedBuffer = { numCalls += 1 + this.tempFileManager = tempFileManager if (numCalls <= maxFailures) { throw new RuntimeException("Failing block fetch in the mock block transfer service") } - super.fetchBlockSync(host, port, execId, blockId) + super.fetchBlockSync(host, port, execId, blockId, tempFileManager) } } } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index c371cbcf8dff5..5bfe9905ff17b 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -33,7 +33,7 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, TempShuffleFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager} import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils @@ -437,12 +437,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val remoteBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) val transfer = mock(classOf[BlockTransferService]) - var tempShuffleFileManager: TempShuffleFileManager = null + var tempFileManager: TempFileManager = null when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - tempShuffleFileManager = invocation.getArguments()(5).asInstanceOf[TempShuffleFileManager] + tempFileManager = invocation.getArguments()(5).asInstanceOf[TempFileManager] Future { listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) @@ -472,13 +472,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. - assert(tempShuffleFileManager == null) + assert(tempFileManager == null) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. - assert(tempShuffleFileManager != null) + assert(tempFileManager != null) } } diff --git a/docs/configuration.md b/docs/configuration.md index 7a777d3c6fa3d..bb06c8faaaed7 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -547,13 +547,14 @@ Apart from these, the following properties are also available, and may be useful - spark.reducer.maxReqSizeShuffleToMem + spark.maxRemoteBlockSizeFetchToMem Long.MaxValue - The blocks of a shuffle request will be fetched to disk when size of the request is above - this threshold. This is to avoid a giant request takes too much memory. We can enable this - config by setting a specific value(e.g. 200m). Note that this config can be enabled only when - the shuffle shuffle service is newer than Spark-2.2 or the shuffle service is disabled. + The remote block will be fetched to disk when size of the block is above this threshold. + This is to avoid a giant request takes too much memory. We can enable this config by setting + a specific value(e.g. 200m). Note this configuration will affect both shuffle fetch + and block manager remote block fetch. For users who enabled external shuffle service, + this feature can only be worked when external shuffle service is newer than Spark 2.2. From 75d666b95a711787355ca3895057dabadd429023 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 17 Oct 2017 12:26:53 -0700 Subject: [PATCH 724/779] [SPARK-22136][SS] Evaluate one-sided conditions early in stream-stream joins. ## What changes were proposed in this pull request? Evaluate one-sided conditions early in stream-stream joins. This is in addition to normal filter pushdown, because integrating it with the join logic allows it to take place in outer join scenarios. This means that rows which can never satisfy the join condition won't clog up the state. ## How was this patch tested? new unit tests Author: Jose Torres Closes #19452 from joseph-torres/SPARK-22136. --- .../streaming/IncrementalExecution.scala | 2 +- .../StreamingSymmetricHashJoinExec.scala | 134 ++++++++++------ .../StreamingSymmetricHashJoinHelper.scala | 70 +++++++- .../sql/streaming/StreamingJoinSuite.scala | 150 +++++++++++++++++- ...treamingSymmetricHashJoinHelperSuite.scala | 130 +++++++++++++++ 5 files changed, 433 insertions(+), 53 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 2e378637727fc..a10ed5f2df1b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -133,7 +133,7 @@ class IncrementalExecution( eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs), stateWatermarkPredicates = StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates( - j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition, + j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full, Some(offsetSeqMetadata.batchWatermarkMs)) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 9bd2127a28ff6..c351f658cb955 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, GenericInternalRow, JoinedRow, Literal, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Literal, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -29,7 +29,6 @@ import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.internal.SessionState -import org.apache.spark.sql.types.{LongType, TimestampType} import org.apache.spark.util.{CompletionIterator, SerializableConfiguration} @@ -115,7 +114,8 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration} * @param leftKeys Expression to generate key rows for joining from left input * @param rightKeys Expression to generate key rows for joining from right input * @param joinType Type of join (inner, left outer, etc.) - * @param condition Optional, additional condition to filter output of the equi-join + * @param condition Conditions to filter rows, split by left, right, and joined. See + * [[JoinConditionSplitPredicates]] * @param stateInfo Version information required to read join state (buffered rows) * @param eventTimeWatermark Watermark of input event, same for both sides * @param stateWatermarkPredicates Predicates for removal of state, see @@ -127,7 +127,7 @@ case class StreamingSymmetricHashJoinExec( leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, - condition: Option[Expression], + condition: JoinConditionSplitPredicates, stateInfo: Option[StatefulOperatorStateInfo], eventTimeWatermark: Option[Long], stateWatermarkPredicates: JoinStateWatermarkPredicates, @@ -141,8 +141,10 @@ case class StreamingSymmetricHashJoinExec( condition: Option[Expression], left: SparkPlan, right: SparkPlan) = { + this( - leftKeys, rightKeys, joinType, condition, stateInfo = None, eventTimeWatermark = None, + leftKeys, rightKeys, joinType, JoinConditionSplitPredicates(condition, left, right), + stateInfo = None, eventTimeWatermark = None, stateWatermarkPredicates = JoinStateWatermarkPredicates(), left, right) } @@ -161,6 +163,9 @@ case class StreamingSymmetricHashJoinExec( new SerializableConfiguration(SessionState.newHadoopConf( sparkContext.hadoopConfiguration, sqlContext.conf))) + val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length) + val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) + override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil @@ -206,10 +211,15 @@ case class StreamingSymmetricHashJoinExec( val updateStartTimeNs = System.nanoTime val joinedRow = new JoinedRow + + val postJoinFilter = + newPredicate(condition.bothSides.getOrElse(Literal(true)), left.output ++ right.output).eval _ val leftSideJoiner = new OneSideHashJoiner( - LeftSide, left.output, leftKeys, leftInputIter, stateWatermarkPredicates.left) + LeftSide, left.output, leftKeys, leftInputIter, + condition.leftSideOnly, postJoinFilter, stateWatermarkPredicates.left) val rightSideJoiner = new OneSideHashJoiner( - RightSide, right.output, rightKeys, rightInputIter, stateWatermarkPredicates.right) + RightSide, right.output, rightKeys, rightInputIter, + condition.rightSideOnly, postJoinFilter, stateWatermarkPredicates.right) // Join one side input using the other side's buffered/state rows. Here is how it is done. // @@ -221,43 +231,28 @@ case class StreamingSymmetricHashJoinExec( // matching new left input with new right input, since the new left input has become stored // by that point. This tiny asymmetry is necessary to avoid duplication. val leftOutputIter = leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner) { - (input: UnsafeRow, matched: UnsafeRow) => joinedRow.withLeft(input).withRight(matched) + (input: InternalRow, matched: InternalRow) => joinedRow.withLeft(input).withRight(matched) } val rightOutputIter = rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner) { - (input: UnsafeRow, matched: UnsafeRow) => joinedRow.withLeft(matched).withRight(input) + (input: InternalRow, matched: InternalRow) => joinedRow.withLeft(matched).withRight(input) } - // Filter the joined rows based on the given condition. - val outputFilterFunction = newPredicate(condition.getOrElse(Literal(true)), output).eval _ - // We need to save the time that the inner join output iterator completes, since outer join // output counts as both update and removal time. var innerOutputCompletionTimeNs: Long = 0 def onInnerOutputCompletion = { innerOutputCompletionTimeNs = System.nanoTime } - val filteredInnerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]]( - (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction), onInnerOutputCompletion) - - def matchesWithRightSideState(leftKeyValue: UnsafeRowPair) = { - rightSideJoiner.get(leftKeyValue.key).exists( - rightValue => { - outputFilterFunction( - joinedRow.withLeft(leftKeyValue.value).withRight(rightValue)) - }) - } + // This is the iterator which produces the inner join rows. For outer joins, this will be + // prepended to a second iterator producing outer join rows; for inner joins, this is the full + // output. + val innerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]]( + (leftOutputIter ++ rightOutputIter), onInnerOutputCompletion) - def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = { - leftSideJoiner.get(rightKeyValue.key).exists( - leftValue => { - outputFilterFunction( - joinedRow.withLeft(leftValue).withRight(rightKeyValue.value)) - }) - } val outputIter: Iterator[InternalRow] = joinType match { case Inner => - filteredInnerOutputIter + innerOutputIter case LeftOuter => // We generate the outer join input by: // * Getting an iterator over the rows that have aged out on the left side. These rows are @@ -268,28 +263,37 @@ case class StreamingSymmetricHashJoinExec( // we know we can join with null, since there was never (including this batch) a match // within the watermark period. If it does, there must have been a match at some point, so // we know we can't join with null. - val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) + def matchesWithRightSideState(leftKeyValue: UnsafeRowPair) = { + rightSideJoiner.get(leftKeyValue.key).exists { rightValue => + postJoinFilter(joinedRow.withLeft(leftKeyValue.value).withRight(rightValue)) + } + } val removedRowIter = leftSideJoiner.removeOldState() val outerOutputIter = removedRowIter .filterNot(pair => matchesWithRightSideState(pair)) .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight)) - filteredInnerOutputIter ++ outerOutputIter + innerOutputIter ++ outerOutputIter case RightOuter => // See comments for left outer case. - val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length) + def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = { + leftSideJoiner.get(rightKeyValue.key).exists { leftValue => + postJoinFilter(joinedRow.withLeft(leftValue).withRight(rightKeyValue.value)) + } + } val removedRowIter = rightSideJoiner.removeOldState() val outerOutputIter = removedRowIter .filterNot(pair => matchesWithLeftSideState(pair)) .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value)) - filteredInnerOutputIter ++ outerOutputIter + innerOutputIter ++ outerOutputIter case _ => throwBadJoinTypeException() } + val outputProjection = UnsafeProjection.create(left.output ++ right.output, output) val outputIterWithMetrics = outputIter.map { row => numOutputRows += 1 - row + outputProjection(row) } // Function to remove old state after all the input has been consumed and output generated @@ -349,14 +353,36 @@ case class StreamingSymmetricHashJoinExec( /** * Internal helper class to consume input rows, generate join output rows using other sides * buffered state rows, and finally clean up this sides buffered state rows + * + * @param joinSide The JoinSide - either left or right. + * @param inputAttributes The input attributes for this side of the join. + * @param joinKeys The join keys. + * @param inputIter The iterator of input rows on this side to be joined. + * @param preJoinFilterExpr A filter over rows on this side. This filter rejects rows that could + * never pass the overall join condition no matter what other side row + * they're joined with. + * @param postJoinFilter A filter over joined rows. This filter completes the application of + * the overall join condition, assuming that preJoinFilter on both sides + * of the join has already been passed. + * Passed as a function rather than expression to avoid creating the + * predicate twice; we also need this filter later on in the parent exec. + * @param stateWatermarkPredicate The state watermark predicate. See + * [[StreamingSymmetricHashJoinExec]] for further description of + * state watermarks. */ private class OneSideHashJoiner( joinSide: JoinSide, inputAttributes: Seq[Attribute], joinKeys: Seq[Expression], inputIter: Iterator[InternalRow], + preJoinFilterExpr: Option[Expression], + postJoinFilter: (InternalRow) => Boolean, stateWatermarkPredicate: Option[JoinStateWatermarkPredicate]) { + // Filter the joined rows based on the given condition. + val preJoinFilter = + newPredicate(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _ + private val joinStateManager = new SymmetricHashJoinStateManager( joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value) private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes) @@ -388,8 +414,8 @@ case class StreamingSymmetricHashJoinExec( */ def storeAndJoinWithOtherSide( otherSideJoiner: OneSideHashJoiner)( - generateJoinedRow: (UnsafeRow, UnsafeRow) => JoinedRow): Iterator[InternalRow] = { - + generateJoinedRow: (InternalRow, InternalRow) => JoinedRow): + Iterator[InternalRow] = { val watermarkAttribute = inputAttributes.find(_.metadata.contains(delayKey)) val nonLateRows = WatermarkSupport.watermarkExpression(watermarkAttribute, eventTimeWatermark) match { @@ -402,17 +428,31 @@ case class StreamingSymmetricHashJoinExec( nonLateRows.flatMap { row => val thisRow = row.asInstanceOf[UnsafeRow] - val key = keyGenerator(thisRow) - val outputIter = otherSideJoiner.joinStateManager.get(key).map { thatRow => - generateJoinedRow(thisRow, thatRow) - } - val shouldAddToState = // add only if both removal predicates do not match - !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) - if (shouldAddToState) { - joinStateManager.append(key, thisRow) - updatedStateRowsCount += 1 + // If this row fails the pre join filter, that means it can never satisfy the full join + // condition no matter what other side row it's matched with. This allows us to avoid + // adding it to the state, and generate an outer join row immediately (or do nothing in + // the case of inner join). + if (preJoinFilter(thisRow)) { + val key = keyGenerator(thisRow) + val outputIter = otherSideJoiner.joinStateManager.get(key).map { thatRow => + generateJoinedRow(thisRow, thatRow) + }.filter(postJoinFilter) + val shouldAddToState = // add only if both removal predicates do not match + !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) + if (shouldAddToState) { + joinStateManager.append(key, thisRow) + updatedStateRowsCount += 1 + } + outputIter + } else { + joinSide match { + case LeftSide if joinType == LeftOuter => + Iterator(generateJoinedRow(thisRow, nullRight)) + case RightSide if joinType == RightOuter => + Iterator(generateJoinedRow(thisRow, nullLeft)) + case _ => Iterator() + } } - outputIter } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index 64c7189f72ac3..167e991ca62f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -24,8 +24,9 @@ import org.apache.spark.{Partition, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.{RDD, ZippedPartitionsRDD2} import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper -import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, AttributeReference, AttributeSet, BoundReference, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, NamedExpression, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} +import org.apache.spark.sql.catalyst.expressions.{Add, And, Attribute, AttributeReference, AttributeSet, BoundReference, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, NamedExpression, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.WatermarkSupport.watermarkExpression import org.apache.spark.sql.execution.streaming.state.{StateStoreCoordinatorRef, StateStoreProvider, StateStoreProviderId} import org.apache.spark.sql.types._ @@ -66,6 +67,73 @@ object StreamingSymmetricHashJoinHelper extends Logging { } } + /** + * Wrapper around various useful splits of the join condition. + * left AND right AND joined is equivalent to full. + * + * Note that left and right do not necessarily contain *all* conjuncts which satisfy + * their condition. Any conjuncts after the first nondeterministic one are treated as + * nondeterministic for purposes of the split. + * + * @param leftSideOnly Deterministic conjuncts which reference only the left side of the join. + * @param rightSideOnly Deterministic conjuncts which reference only the right side of the join. + * @param bothSides Conjuncts which are nondeterministic, occur after a nondeterministic conjunct, + * or reference both left and right sides of the join. + * @param full The full join condition. + */ + case class JoinConditionSplitPredicates( + leftSideOnly: Option[Expression], + rightSideOnly: Option[Expression], + bothSides: Option[Expression], + full: Option[Expression]) { + override def toString(): String = { + s"condition = [ leftOnly = ${leftSideOnly.map(_.toString).getOrElse("null")}, " + + s"rightOnly = ${rightSideOnly.map(_.toString).getOrElse("null")}, " + + s"both = ${bothSides.map(_.toString).getOrElse("null")}, " + + s"full = ${full.map(_.toString).getOrElse("null")} ]" + } + } + + object JoinConditionSplitPredicates extends PredicateHelper { + def apply(condition: Option[Expression], left: SparkPlan, right: SparkPlan): + JoinConditionSplitPredicates = { + // Split the condition into 3 parts: + // * Conjuncts that can be evaluated on only the left input. + // * Conjuncts that can be evaluated on only the right input. + // * Conjuncts that require both left and right input. + // + // Note that we treat nondeterministic conjuncts as though they require both left and right + // input. To maintain their semantics, they need to be evaluated exactly once per joined row. + val (leftCondition, rightCondition, joinedCondition) = { + if (condition.isEmpty) { + (None, None, None) + } else { + // Span rather than partition, because nondeterministic expressions don't commute + // across AND. + val (deterministicConjuncts, nonDeterministicConjuncts) = + splitConjunctivePredicates(condition.get).span(_.deterministic) + + val (leftConjuncts, nonLeftConjuncts) = deterministicConjuncts.partition { cond => + cond.references.subsetOf(left.outputSet) + } + + val (rightConjuncts, nonRightConjuncts) = deterministicConjuncts.partition { cond => + cond.references.subsetOf(right.outputSet) + } + + ( + leftConjuncts.reduceOption(And), + rightConjuncts.reduceOption(And), + (nonLeftConjuncts.intersect(nonRightConjuncts) ++ nonDeterministicConjuncts) + .reduceOption(And) + ) + } + } + + JoinConditionSplitPredicates(leftCondition, rightCondition, joinedCondition, condition) + } + } + /** Get the predicates defining the state watermarks for both sides of the join */ def getStateWatermarkPredicates( leftAttributes: Seq[Attribute], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index d32617275aadc..54eb863dacc83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -365,6 +365,24 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with } } } + + test("join between three streams") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val input3 = MemoryStream[Int] + + val df1 = input1.toDF.select('value as "leftKey", ('value * 2) as "leftValue") + val df2 = input2.toDF.select('value as "middleKey", ('value * 3) as "middleValue") + val df3 = input3.toDF.select('value as "rightKey", ('value * 5) as "rightValue") + + val joined = df1.join(df2, expr("leftKey = middleKey")).join(df3, expr("rightKey = middleKey")) + + testStream(joined)( + AddData(input1, 1, 5), + AddData(input2, 1, 5, 10), + AddData(input3, 5, 10), + CheckLastBatch((5, 10, 5, 15, 5, 25))) + } } class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { @@ -405,6 +423,130 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with (input1, input2, joined) } + test("left outer early state exclusion on left") { + val (leftInput, df1) = setupStream("left", 2) + val (rightInput, df2) = setupStream("right", 3) + // Use different schemas to ensure the null row is being generated from the correct side. + val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + + val joined = left.join( + right, + left("key") === right("key") + && left("window") === right("window") + && 'leftValue > 4, + "left_outer") + .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + AddData(leftInput, 1, 2, 3), + AddData(rightInput, 3, 4, 5), + // The left rows with leftValue <= 4 should generate their outer join row now and + // not get added to the state. + CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), + assertNumStateRows(total = 4, updated = 4), + // We shouldn't get more outer join rows when the watermark advances. + AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + AddData(rightInput, 20), + CheckLastBatch((20, 30, 40, "60")) + ) + } + + test("left outer early state exclusion on right") { + val (leftInput, df1) = setupStream("left", 2) + val (rightInput, df2) = setupStream("right", 3) + // Use different schemas to ensure the null row is being generated from the correct side. + val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + + val joined = left.join( + right, + left("key") === right("key") + && left("window") === right("window") + && 'rightValue.cast("int") > 7, + "left_outer") + .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + AddData(leftInput, 3, 4, 5), + AddData(rightInput, 1, 2, 3), + // The right rows with value <= 7 should never be added to the state. + CheckLastBatch(Row(3, 10, 6, "9")), + assertNumStateRows(total = 4, updated = 4), + // When the watermark advances, we get the outer join rows just as we would if they + // were added but didn't match the full join condition. + AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + AddData(rightInput, 20), + CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, 8, null), Row(5, 10, 10, null)) + ) + } + + test("right outer early state exclusion on left") { + val (leftInput, df1) = setupStream("left", 2) + val (rightInput, df2) = setupStream("right", 3) + // Use different schemas to ensure the null row is being generated from the correct side. + val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + + val joined = left.join( + right, + left("key") === right("key") + && left("window") === right("window") + && 'leftValue > 4, + "right_outer") + .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + AddData(leftInput, 1, 2, 3), + AddData(rightInput, 3, 4, 5), + // The left rows with value <= 4 should never be added to the state. + CheckLastBatch(Row(3, 10, 6, "9")), + assertNumStateRows(total = 4, updated = 4), + // When the watermark advances, we get the outer join rows just as we would if they + // were added but didn't match the full join condition. + AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + AddData(rightInput, 20), + CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, null, "12"), Row(5, 10, null, "15")) + ) + } + + test("right outer early state exclusion on right") { + val (leftInput, df1) = setupStream("left", 2) + val (rightInput, df2) = setupStream("right", 3) + // Use different schemas to ensure the null row is being generated from the correct side. + val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + + val joined = left.join( + right, + left("key") === right("key") + && left("window") === right("window") + && 'rightValue.cast("int") > 7, + "right_outer") + .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + AddData(leftInput, 3, 4, 5), + AddData(rightInput, 1, 2, 3), + // The right rows with rightValue <= 7 should generate their outer join row now and + // not get added to the state. + CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), + assertNumStateRows(total = 4, updated = 4), + // We shouldn't get more outer join rows when the watermark advances. + AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + AddData(rightInput, 20), + CheckLastBatch((20, 30, 40, "60")) + ) + } + test("windowed left outer join") { val (leftInput, rightInput, joined) = setupWindowedJoin("left_outer") @@ -495,7 +637,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with // When the join condition isn't true, the outer null rows must be generated, even if the join // keys themselves have a match. - test("left outer join with non-key condition violated on left") { + test("left outer join with non-key condition violated") { val (leftInput, simpleLeftDf) = setupStream("left", 2) val (rightInput, simpleRightDf) = setupStream("right", 3) @@ -513,14 +655,14 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with // leftValue <= 10 should generate outer join rows even though it matches right keys AddData(leftInput, 1, 2, 3), AddData(rightInput, 1, 2, 3), - CheckLastBatch(), + CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), AddData(leftInput, 20), AddData(rightInput, 21), CheckLastBatch(), - assertNumStateRows(total = 8, updated = 2), + assertNumStateRows(total = 5, updated = 2), AddData(rightInput, 20), CheckLastBatch( - Row(20, 30, 40, 60), Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), + Row(20, 30, 40, 60)), assertNumStateRows(total = 3, updated = 1), // leftValue and rightValue both satisfying condition should not generate outer join rows AddData(leftInput, 40, 41), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala new file mode 100644 index 0000000000000..2a854e37bf0df --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.execution.{LeafExecNode, LocalTableScanExec, SparkPlan} +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.JoinConditionSplitPredicates +import org.apache.spark.sql.types._ + +class StreamingSymmetricHashJoinHelperSuite extends StreamTest { + import org.apache.spark.sql.functions._ + + val leftAttributeA = AttributeReference("a", IntegerType)() + val leftAttributeB = AttributeReference("b", IntegerType)() + val rightAttributeC = AttributeReference("c", IntegerType)() + val rightAttributeD = AttributeReference("d", IntegerType)() + val leftColA = new Column(leftAttributeA) + val leftColB = new Column(leftAttributeB) + val rightColC = new Column(rightAttributeC) + val rightColD = new Column(rightAttributeD) + + val left = new LocalTableScanExec(Seq(leftAttributeA, leftAttributeB), Seq()) + val right = new LocalTableScanExec(Seq(rightAttributeC, rightAttributeD), Seq()) + + test("empty") { + val split = JoinConditionSplitPredicates(None, left, right) + assert(split.leftSideOnly.isEmpty) + assert(split.rightSideOnly.isEmpty) + assert(split.bothSides.isEmpty) + assert(split.full.isEmpty) + } + + test("only literals") { + // Literal-only conjuncts end up on the left side because that's the first bucket they fit in. + // There's no semantic reason they couldn't be in any bucket. + val predicate = (lit(1) < lit(5) && lit(6) < lit(7) && lit(0) === lit(-1)).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.contains(predicate)) + assert(split.rightSideOnly.contains(predicate)) + assert(split.bothSides.isEmpty) + assert(split.full.contains(predicate)) + } + + test("only left") { + val predicate = (leftColA > lit(1) && leftColB > lit(5) && leftColA < leftColB).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.contains(predicate)) + assert(split.rightSideOnly.isEmpty) + assert(split.bothSides.isEmpty) + assert(split.full.contains(predicate)) + } + + test("only right") { + val predicate = (rightColC > lit(1) && rightColD > lit(5) && rightColD < rightColC).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.isEmpty) + assert(split.rightSideOnly.contains(predicate)) + assert(split.bothSides.isEmpty) + assert(split.full.contains(predicate)) + } + + test("mixed conjuncts") { + val predicate = + (leftColA > leftColB + && rightColC > rightColD + && leftColA === rightColC + && lit(1) === lit(1)).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.contains((leftColA > leftColB && lit(1) === lit(1)).expr)) + assert(split.rightSideOnly.contains((rightColC > rightColD && lit(1) === lit(1)).expr)) + assert(split.bothSides.contains((leftColA === rightColC).expr)) + assert(split.full.contains(predicate)) + } + + test("conjuncts after nondeterministic") { + // All conjuncts after a nondeterministic conjunct shouldn't be split because they don't + // commute across it. + val predicate = + (rand() > lit(0) + && leftColA > leftColB + && rightColC > rightColD + && leftColA === rightColC + && lit(1) === lit(1)).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.isEmpty) + assert(split.rightSideOnly.isEmpty) + assert(split.bothSides.contains(predicate)) + assert(split.full.contains(predicate)) + } + + + test("conjuncts before nondeterministic") { + val randCol = rand() + val predicate = + (leftColA > leftColB + && rightColC > rightColD + && leftColA === rightColC + && lit(1) === lit(1) + && randCol > lit(0)).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.contains((leftColA > leftColB && lit(1) === lit(1)).expr)) + assert(split.rightSideOnly.contains((rightColC > rightColD && lit(1) === lit(1)).expr)) + assert(split.bothSides.contains((leftColA === rightColC && randCol > lit(0)).expr)) + assert(split.full.contains(predicate)) + } +} From 28f9f3f22511e9f2f900764d9bd5b90d2eeee773 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 17 Oct 2017 12:50:41 -0700 Subject: [PATCH 725/779] [SPARK-22271][SQL] mean overflows and returns null for some decimal variables ## What changes were proposed in this pull request? In Average.scala, it has ``` override lazy val evaluateExpression = child.dataType match { case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss val dt = DecimalType.bounded(p + 14, s + 4) Cast(Cast(sum, dt) / Cast(count, dt), resultType) case _ => Cast(sum, resultType) / Cast(count, resultType) } def setChild (newchild: Expression) = { child = newchild } ``` It is possible that Cast(count, dt), resultType) will make the precision of the decimal number bigger than 38, and this causes over flow. Since count is an integer and doesn't need a scale, I will cast it using DecimalType.bounded(38,0) ## How was this patch tested? In DataFrameSuite, I will add a test case. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Huaxin Gao Closes #19496 from huaxingao/spark-22271. --- .../sql/catalyst/expressions/aggregate/Average.scala | 3 ++- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index c423e17169e85..708bdbfc36058 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -80,7 +80,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss val dt = DecimalType.bounded(p + 14, s + 4) - Cast(Cast(sum, dt) / Cast(count, dt), resultType) + Cast(Cast(sum, dt) / Cast(count, DecimalType.bounded(DecimalType.MAX_PRECISION, 0)), + resultType) case _ => Cast(sum, resultType) / Cast(count, resultType) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 50de2fd3bca8d..473c355cf3c7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2105,4 +2105,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)), Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2))) } + + test("SPARK-22271: mean overflows and returns null for some decimal variables") { + val d = 0.034567890 + val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol") + val result = df.select('DecimalCol cast DecimalType(38, 33)) + .select(col("DecimalCol")).describe() + val mean = result.select("DecimalCol").where($"summary" === "mean") + assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000"))) + } } From 1437e344ec0c29a44a19f4513986f5f184c44695 Mon Sep 17 00:00:00 2001 From: Michael Mior Date: Tue, 17 Oct 2017 14:30:52 -0700 Subject: [PATCH 726/779] [SPARK-22050][CORE] Allow BlockUpdated events to be optionally logged to the event log ## What changes were proposed in this pull request? I see that block updates are not logged to the event log. This makes sense as a default for performance reasons. However, I find it helpful when trying to get a better understanding of caching for a job to be able to log these updates. This PR adds a configuration setting `spark.eventLog.blockUpdates` (defaulting to false) which allows block updates to be recorded in the log. This contribution is original work which is licensed to the Apache Spark project. ## How was this patch tested? Current and additional unit tests. Author: Michael Mior Closes #19263 from michaelmior/log-block-updates. --- .../spark/internal/config/package.scala | 23 +++++++++++++ .../scheduler/EventLoggingListener.scala | 18 ++++++---- .../org/apache/spark/util/JsonProtocol.scala | 34 +++++++++++++++++-- .../scheduler/EventLoggingListenerSuite.scala | 2 ++ .../apache/spark/util/JsonProtocolSuite.scala | 27 +++++++++++++++ docs/configuration.md | 8 +++++ 6 files changed, 104 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index e7b406af8d9b1..0c36bdcdd2904 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -41,6 +41,29 @@ package object config { .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") + private[spark] val EVENT_LOG_COMPRESS = + ConfigBuilder("spark.eventLog.compress") + .booleanConf + .createWithDefault(false) + + private[spark] val EVENT_LOG_BLOCK_UPDATES = + ConfigBuilder("spark.eventLog.logBlockUpdates.enabled") + .booleanConf + .createWithDefault(false) + + private[spark] val EVENT_LOG_TESTING = + ConfigBuilder("spark.eventLog.testing") + .internal() + .booleanConf + .createWithDefault(false) + + private[spark] val EVENT_LOG_OUTPUT_BUFFER_SIZE = ConfigBuilder("spark.eventLog.buffer.kb") + .bytesConf(ByteUnit.KiB) + .createWithDefaultString("100k") + + private[spark] val EVENT_LOG_OVERWRITE = + ConfigBuilder("spark.eventLog.overwrite").booleanConf.createWithDefault(false) + private[spark] val EXECUTOR_CLASS_PATH = ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.createOptional diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 9dafa0b7646bf..a77adc5ff3545 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -37,6 +37,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SPARK_VERSION, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{JsonProtocol, Utils} @@ -45,6 +46,7 @@ import org.apache.spark.util.{JsonProtocol, Utils} * * Event logging is specified by the following configurable parameters: * spark.eventLog.enabled - Whether event logging is enabled. + * spark.eventLog.logBlockUpdates.enabled - Whether to log block updates * spark.eventLog.compress - Whether to compress logged events * spark.eventLog.overwrite - Whether to overwrite any existing files. * spark.eventLog.dir - Path to the directory in which events are logged. @@ -64,10 +66,11 @@ private[spark] class EventLoggingListener( this(appId, appAttemptId, logBaseDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf)) - private val shouldCompress = sparkConf.getBoolean("spark.eventLog.compress", false) - private val shouldOverwrite = sparkConf.getBoolean("spark.eventLog.overwrite", false) - private val testing = sparkConf.getBoolean("spark.eventLog.testing", false) - private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024 + private val shouldCompress = sparkConf.get(EVENT_LOG_COMPRESS) + private val shouldOverwrite = sparkConf.get(EVENT_LOG_OVERWRITE) + private val shouldLogBlockUpdates = sparkConf.get(EVENT_LOG_BLOCK_UPDATES) + private val testing = sparkConf.get(EVENT_LOG_TESTING) + private val outputBufferSize = sparkConf.get(EVENT_LOG_OUTPUT_BUFFER_SIZE).toInt private val fileSystem = Utils.getHadoopFileSystem(logBaseDir, hadoopConf) private val compressionCodec = if (shouldCompress) { @@ -216,8 +219,11 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) } - // No-op because logging every update would be overkill - override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = {} + override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = { + if (shouldLogBlockUpdates) { + logEvent(event, flushLogger = true) + } + } // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 8406826a228db..5e60218c5740b 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -98,8 +98,8 @@ private[spark] object JsonProtocol { logStartToJson(logStart) case metricsUpdate: SparkListenerExecutorMetricsUpdate => executorMetricsUpdateToJson(metricsUpdate) - case blockUpdated: SparkListenerBlockUpdated => - throw new MatchError(blockUpdated) // TODO(ekl) implement this + case blockUpdate: SparkListenerBlockUpdated => + blockUpdateToJson(blockUpdate) case _ => parse(mapper.writeValueAsString(event)) } } @@ -246,6 +246,12 @@ private[spark] object JsonProtocol { }) } + def blockUpdateToJson(blockUpdate: SparkListenerBlockUpdated): JValue = { + val blockUpdatedInfo = blockUpdatedInfoToJson(blockUpdate.blockUpdatedInfo) + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.blockUpdate) ~ + ("Block Updated Info" -> blockUpdatedInfo) + } + /** ------------------------------------------------------------------- * * JSON serialization methods for classes SparkListenerEvents depend on | * -------------------------------------------------------------------- */ @@ -458,6 +464,14 @@ private[spark] object JsonProtocol { ("Log Urls" -> mapToJson(executorInfo.logUrlMap)) } + def blockUpdatedInfoToJson(blockUpdatedInfo: BlockUpdatedInfo): JValue = { + ("Block Manager ID" -> blockManagerIdToJson(blockUpdatedInfo.blockManagerId)) ~ + ("Block ID" -> blockUpdatedInfo.blockId.toString) ~ + ("Storage Level" -> storageLevelToJson(blockUpdatedInfo.storageLevel)) ~ + ("Memory Size" -> blockUpdatedInfo.memSize) ~ + ("Disk Size" -> blockUpdatedInfo.diskSize) + } + /** ------------------------------ * * Util JSON serialization methods | * ------------------------------- */ @@ -515,6 +529,7 @@ private[spark] object JsonProtocol { val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved) val logStart = Utils.getFormattedClassName(SparkListenerLogStart) val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate) + val blockUpdate = Utils.getFormattedClassName(SparkListenerBlockUpdated) } def sparkEventFromJson(json: JValue): SparkListenerEvent = { @@ -538,6 +553,7 @@ private[spark] object JsonProtocol { case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) case `metricsUpdate` => executorMetricsUpdateFromJson(json) + case `blockUpdate` => blockUpdateFromJson(json) case other => mapper.readValue(compact(render(json)), Utils.classForName(other)) .asInstanceOf[SparkListenerEvent] } @@ -676,6 +692,11 @@ private[spark] object JsonProtocol { SparkListenerExecutorMetricsUpdate(execInfo, accumUpdates) } + def blockUpdateFromJson(json: JValue): SparkListenerBlockUpdated = { + val blockUpdatedInfo = blockUpdatedInfoFromJson(json \ "Block Updated Info") + SparkListenerBlockUpdated(blockUpdatedInfo) + } + /** --------------------------------------------------------------------- * * JSON deserialization methods for classes SparkListenerEvents depend on | * ---------------------------------------------------------------------- */ @@ -989,6 +1010,15 @@ private[spark] object JsonProtocol { new ExecutorInfo(executorHost, totalCores, logUrls) } + def blockUpdatedInfoFromJson(json: JValue): BlockUpdatedInfo = { + val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") + val blockId = BlockId((json \ "Block ID").extract[String]) + val storageLevel = storageLevelFromJson(json \ "Storage Level") + val memorySize = (json \ "Memory Size").extract[Long] + val diskSize = (json \ "Disk Size").extract[Long] + BlockUpdatedInfo(blockManagerId, blockId, storageLevel, memorySize, diskSize) + } + /** -------------------------------- * * Util JSON deserialization methods | * --------------------------------- */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 6b42775ccb0f6..a9e92fa07b9dd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -228,6 +228,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit SparkListenerStageCompleted, SparkListenerTaskStart, SparkListenerTaskEnd, + SparkListenerBlockUpdated, SparkListenerApplicationEnd).map(Utils.getFormattedClassName) Utils.tryWithSafeFinally { val logStart = SparkListenerLogStart(SPARK_VERSION) @@ -291,6 +292,7 @@ object EventLoggingListenerSuite { def getLoggingConf(logDir: Path, compressionCodec: Option[String] = None): SparkConf = { val conf = new SparkConf conf.set("spark.eventLog.enabled", "true") + conf.set("spark.eventLog.logBlockUpdates.enabled", "true") conf.set("spark.eventLog.testing", "true") conf.set("spark.eventLog.dir", logDir.toString) compressionCodec.foreach { codec => diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index a1a858765a7d4..4abbb8e7894f5 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -96,6 +96,9 @@ class JsonProtocolSuite extends SparkFunSuite { .zipWithIndex.map { case (a, i) => a.copy(id = i) } SparkListenerExecutorMetricsUpdate("exec3", Seq((1L, 2, 3, accumUpdates))) } + val blockUpdated = + SparkListenerBlockUpdated(BlockUpdatedInfo(BlockManagerId("Stars", + "In your multitude...", 300), RDDBlockId(0, 0), StorageLevel.MEMORY_ONLY, 100L, 0L)) testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -120,6 +123,7 @@ class JsonProtocolSuite extends SparkFunSuite { testEvent(nodeBlacklisted, nodeBlacklistedJsonString) testEvent(nodeUnblacklisted, nodeUnblacklistedJsonString) testEvent(executorMetricsUpdate, executorMetricsUpdateJsonString) + testEvent(blockUpdated, blockUpdatedJsonString) } test("Dependent Classes") { @@ -2007,6 +2011,29 @@ private[spark] object JsonProtocolSuite extends Assertions { |} """.stripMargin + private val blockUpdatedJsonString = + """ + |{ + | "Event": "SparkListenerBlockUpdated", + | "Block Updated Info": { + | "Block Manager ID": { + | "Executor ID": "Stars", + | "Host": "In your multitude...", + | "Port": 300 + | }, + | "Block ID": "rdd_0_0", + | "Storage Level": { + | "Use Disk": false, + | "Use Memory": true, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Memory Size": 100, + | "Disk Size": 0 + | } + |} + """.stripMargin + private val executorBlacklistedJsonString = s""" |{ diff --git a/docs/configuration.md b/docs/configuration.md index bb06c8faaaed7..7b9e16a382449 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -714,6 +714,14 @@ Apart from these, the following properties are also available, and may be useful + + + + + From f3137feecd30c74c47dbddb0e22b4ddf8cf2f912 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 17 Oct 2017 20:09:12 -0700 Subject: [PATCH 727/779] [SPARK-22278][SS] Expose current event time watermark and current processing time in GroupState ## What changes were proposed in this pull request? Complex state-updating and/or timeout-handling logic in mapGroupsWithState functions may require taking decisions based on the current event-time watermark and/or processing time. Currently, you can use the SQL function `current_timestamp` to get the current processing time, but it needs to be passed inserted in every row with a select, and then passed through the encoder, which isn't efficient. Furthermore, there is no way to get the current watermark. This PR exposes both of them through the GroupState API. Additionally, it also cleans up some of the GroupState docs. ## How was this patch tested? New unit tests Author: Tathagata Das Closes #19495 from tdas/SPARK-22278. --- .../apache/spark/sql/execution/objects.scala | 8 +- .../FlatMapGroupsWithStateExec.scala | 7 +- .../execution/streaming/GroupStateImpl.scala | 50 +++--- .../spark/sql/streaming/GroupState.scala | 92 ++++++---- .../FlatMapGroupsWithStateSuite.scala | 160 +++++++++++++++--- 5 files changed, 238 insertions(+), 79 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index c68975bea490f..d861109436a08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke -import org.apache.spark.sql.catalyst.plans.logical.{FunctionUtils, LogicalGroupState} +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, FunctionUtils, LogicalGroupState} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.streaming.GroupStateImpl import org.apache.spark.sql.streaming.GroupStateTimeout @@ -361,8 +361,12 @@ object MapGroupsExec { outputObjAttr: Attribute, timeoutConf: GroupStateTimeout, child: SparkPlan): MapGroupsExec = { + val watermarkPresent = child.output.exists { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true + case _ => false + } val f = (key: Any, values: Iterator[Any]) => { - func(key, values, GroupStateImpl.createForBatch(timeoutConf)) + func(key, values, GroupStateImpl.createForBatch(timeoutConf, watermarkPresent)) } new MapGroupsExec(f, keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr, child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index c81f1a8142784..29f38fab3f896 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -61,6 +61,10 @@ case class FlatMapGroupsWithStateExec( private val isTimeoutEnabled = timeoutConf != NoTimeout val stateManager = new FlatMapGroupsWithState_StateManager(stateEncoder, isTimeoutEnabled) + val watermarkPresent = child.output.exists { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true + case _ => false + } /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = @@ -190,7 +194,8 @@ case class FlatMapGroupsWithStateExec( batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), timeoutConf, - hasTimedOut) + hasTimedOut, + watermarkPresent) // Call function, get the returned objects and convert them to rows val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala index 4401e86936af9..7f65e3ea9dd5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -43,7 +43,8 @@ private[sql] class GroupStateImpl[S] private( batchProcessingTimeMs: Long, eventTimeWatermarkMs: Long, timeoutConf: GroupStateTimeout, - override val hasTimedOut: Boolean) extends GroupState[S] { + override val hasTimedOut: Boolean, + watermarkPresent: Boolean) extends GroupState[S] { private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) private var defined: Boolean = optionalValue.isDefined @@ -90,7 +91,7 @@ private[sql] class GroupStateImpl[S] private( if (timeoutConf != ProcessingTimeTimeout) { throw new UnsupportedOperationException( "Cannot set timeout duration without enabling processing time timeout in " + - "map/flatMapGroupsWithState") + "[map|flatMap]GroupsWithState") } if (durationMs <= 0) { throw new IllegalArgumentException("Timeout duration must be positive") @@ -102,10 +103,6 @@ private[sql] class GroupStateImpl[S] private( setTimeoutDuration(parseDuration(duration)) } - @throws[IllegalArgumentException]("if 'timestampMs' is not positive") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") override def setTimeoutTimestamp(timestampMs: Long): Unit = { checkTimeoutTimestampAllowed() if (timestampMs <= 0) { @@ -119,32 +116,34 @@ private[sql] class GroupStateImpl[S] private( timeoutTimestamp = timestampMs } - @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") override def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit = { checkTimeoutTimestampAllowed() setTimeoutTimestamp(parseDuration(additionalDuration) + timestampMs) } - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") override def setTimeoutTimestamp(timestamp: Date): Unit = { checkTimeoutTimestampAllowed() setTimeoutTimestamp(timestamp.getTime) } - @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") override def setTimeoutTimestamp(timestamp: Date, additionalDuration: String): Unit = { checkTimeoutTimestampAllowed() setTimeoutTimestamp(timestamp.getTime + parseDuration(additionalDuration)) } + override def getCurrentWatermarkMs(): Long = { + if (!watermarkPresent) { + throw new UnsupportedOperationException( + "Cannot get event time watermark timestamp without setting watermark before " + + "[map|flatMap]GroupsWithState") + } + eventTimeWatermarkMs + } + + override def getCurrentProcessingTimeMs(): Long = { + batchProcessingTimeMs + } + override def toString: String = { s"GroupState(${getOption.map(_.toString).getOrElse("")})" } @@ -187,7 +186,7 @@ private[sql] class GroupStateImpl[S] private( if (timeoutConf != EventTimeTimeout) { throw new UnsupportedOperationException( "Cannot set timeout timestamp without enabling event time timeout in " + - "map/flatMapGroupsWithState") + "[map|flatMapGroupsWithState") } } } @@ -202,17 +201,22 @@ private[sql] object GroupStateImpl { batchProcessingTimeMs: Long, eventTimeWatermarkMs: Long, timeoutConf: GroupStateTimeout, - hasTimedOut: Boolean): GroupStateImpl[S] = { + hasTimedOut: Boolean, + watermarkPresent: Boolean): GroupStateImpl[S] = { new GroupStateImpl[S]( - optionalValue, batchProcessingTimeMs, eventTimeWatermarkMs, timeoutConf, hasTimedOut) + optionalValue, batchProcessingTimeMs, eventTimeWatermarkMs, + timeoutConf, hasTimedOut, watermarkPresent) } - def createForBatch(timeoutConf: GroupStateTimeout): GroupStateImpl[Any] = { + def createForBatch( + timeoutConf: GroupStateTimeout, + watermarkPresent: Boolean): GroupStateImpl[Any] = { new GroupStateImpl[Any]( optionalValue = None, - batchProcessingTimeMs = NO_TIMESTAMP, + batchProcessingTimeMs = System.currentTimeMillis, eventTimeWatermarkMs = NO_TIMESTAMP, timeoutConf, - hasTimedOut = false) + hasTimedOut = false, + watermarkPresent) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index 04a956b70b022..e9510c903acae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -205,11 +205,7 @@ trait GroupState[S] extends LogicalGroupState[S] { /** Get the state value as a scala Option. */ def getOption: Option[S] - /** - * Update the value of the state. Note that `null` is not a valid value, and it throws - * IllegalArgumentException. - */ - @throws[IllegalArgumentException]("when updating with null") + /** Update the value of the state. */ def update(newState: S): Unit /** Remove this state. */ @@ -217,80 +213,114 @@ trait GroupState[S] extends LogicalGroupState[S] { /** * Whether the function has been called because the key has timed out. - * @note This can return true only when timeouts are enabled in `[map/flatmap]GroupsWithStates`. + * @note This can return true only when timeouts are enabled in `[map/flatMap]GroupsWithState`. */ def hasTimedOut: Boolean + /** * Set the timeout duration in ms for this key. * - * @note ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Processing time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no effect when used in a batch query. */ @throws[IllegalArgumentException]("if 'durationMs' is not positive") - @throws[IllegalStateException]("when state is either not initialized, or already removed") @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutDuration(durationMs: Long): Unit + /** * Set the timeout duration for this key as a string. For example, "1 hour", "2 days", etc. * - * @note ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Processing time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no effect when used in a batch query. */ @throws[IllegalArgumentException]("if 'duration' is not a valid duration") - @throws[IllegalStateException]("when state is either not initialized, or already removed") @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutDuration(duration: String): Unit - @throws[IllegalArgumentException]("if 'timestampMs' is not positive") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** * Set the timeout timestamp for this key as milliseconds in epoch time. * This timestamp cannot be older than the current watermark. * - * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Event time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no effect when used in a batch query. */ + @throws[IllegalArgumentException]( + "if 'timestampMs' is not positive or less than the current watermark in a streaming query") + @throws[UnsupportedOperationException]( + "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestampMs: Long): Unit - @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** * Set the timeout timestamp for this key as milliseconds in epoch time and an additional * duration as a string (e.g. "1 hour", "2 days", etc.). * The final timestamp (including the additional duration) cannot be older than the * current watermark. * - * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Event time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no side effect when used in a batch query. */ + @throws[IllegalArgumentException]( + "if 'additionalDuration' is invalid or the final timeout timestamp is less than " + + "the current watermark in a streaming query") + @throws[UnsupportedOperationException]( + "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** * Set the timeout timestamp for this key as a java.sql.Date. * This timestamp cannot be older than the current watermark. * - * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Event time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no side effect when used in a batch query. */ + @throws[UnsupportedOperationException]( + "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestamp: java.sql.Date): Unit - @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** * Set the timeout timestamp for this key as a java.sql.Date and an additional * duration as a string (e.g. "1 hour", "2 days", etc.). * The final timestamp (including the additional duration) cannot be older than the * current watermark. * - * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Event time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no side effect when used in a batch query. */ + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[UnsupportedOperationException]( + "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestamp: java.sql.Date, additionalDuration: String): Unit + + + /** + * Get the current event time watermark as milliseconds in epoch time. + * + * @note In a streaming query, this can be called only when watermark is set before calling + * `[map/flatMap]GroupsWithState`. In a batch query, this method always returns -1. + */ + @throws[UnsupportedOperationException]( + "if watermark has not been set before in [map|flatMap]GroupsWithState") + def getCurrentWatermarkMs(): Long + + + /** + * Get the current processing time as milliseconds in epoch time. + * @note In a streaming query, this will return a constant value throughout the duration of a + * trigger, even if the trigger is re-executed. + */ + def getCurrentProcessingTimeMs(): Long } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index aeb83835f981a..af08186aadbb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -21,6 +21,7 @@ import java.sql.Date import java.util.concurrent.ConcurrentHashMap import org.scalatest.BeforeAndAfterAll +import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkException import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction @@ -48,6 +49,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest import testImplicits._ import GroupStateImpl._ import GroupStateTimeout._ + import FlatMapGroupsWithStateSuite._ override def afterAll(): Unit = { super.afterAll() @@ -77,13 +79,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest // === Tests for state in streaming queries === // Updating empty state - state = GroupStateImpl.createForStreaming(None, 1, 1, NoTimeout, hasTimedOut = false) + state = GroupStateImpl.createForStreaming( + None, 1, 1, NoTimeout, hasTimedOut = false, watermarkPresent = false) testState(None) state.update("") testState(Some(""), shouldBeUpdated = true) // Updating exiting state - state = GroupStateImpl.createForStreaming(Some("2"), 1, 1, NoTimeout, hasTimedOut = false) + state = GroupStateImpl.createForStreaming( + Some("2"), 1, 1, NoTimeout, hasTimedOut = false, watermarkPresent = false) testState(Some("2")) state.update("3") testState(Some("3"), shouldBeUpdated = true) @@ -104,8 +108,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest test("GroupState - setTimeout - with NoTimeout") { for (initValue <- Seq(None, Some(5))) { val states = Seq( - GroupStateImpl.createForStreaming(initValue, 1000, 1000, NoTimeout, hasTimedOut = false), - GroupStateImpl.createForBatch(NoTimeout) + GroupStateImpl.createForStreaming( + initValue, 1000, 1000, NoTimeout, hasTimedOut = false, watermarkPresent = false), + GroupStateImpl.createForBatch(NoTimeout, watermarkPresent = false) ) for (state <- states) { // for streaming queries @@ -122,7 +127,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest test("GroupState - setTimeout - with ProcessingTimeTimeout") { // for streaming queries var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming( - None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false, watermarkPresent = false) assert(state.getTimeoutTimestamp === NO_TIMESTAMP) state.setTimeoutDuration(500) assert(state.getTimeoutTimestamp === 1500) // can be set without initializing state @@ -143,7 +148,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) // for batch queries - state = GroupStateImpl.createForBatch(ProcessingTimeTimeout).asInstanceOf[GroupStateImpl[Int]] + state = GroupStateImpl.createForBatch( + ProcessingTimeTimeout, watermarkPresent = false).asInstanceOf[GroupStateImpl[Int]] assert(state.getTimeoutTimestamp === NO_TIMESTAMP) state.setTimeoutDuration(500) testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) @@ -160,7 +166,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest test("GroupState - setTimeout - with EventTimeTimeout") { var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming( - None, 1000, 1000, EventTimeTimeout, false) + None, 1000, 1000, EventTimeTimeout, false, watermarkPresent = true) assert(state.getTimeoutTimestamp === NO_TIMESTAMP) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) @@ -182,7 +188,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testTimeoutDurationNotAllowed[UnsupportedOperationException](state) // for batch queries - state = GroupStateImpl.createForBatch(EventTimeTimeout).asInstanceOf[GroupStateImpl[Int]] + state = GroupStateImpl.createForBatch(EventTimeTimeout, watermarkPresent = false) + .asInstanceOf[GroupStateImpl[Int]] assert(state.getTimeoutTimestamp === NO_TIMESTAMP) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) state.setTimeoutTimestamp(5000) @@ -209,7 +216,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } state = GroupStateImpl.createForStreaming( - Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false, watermarkPresent = false) testIllegalTimeout { state.setTimeoutDuration(-1000) } @@ -227,7 +234,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } state = GroupStateImpl.createForStreaming( - Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) + Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false, watermarkPresent = false) testIllegalTimeout { state.setTimeoutTimestamp(-10000) } @@ -259,29 +266,92 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest // for streaming queries for (initState <- Seq(None, Some(5))) { val state1 = GroupStateImpl.createForStreaming( - initState, 1000, 1000, timeoutConf, hasTimedOut = false) + initState, 1000, 1000, timeoutConf, hasTimedOut = false, watermarkPresent = false) assert(state1.hasTimedOut === false) val state2 = GroupStateImpl.createForStreaming( - initState, 1000, 1000, timeoutConf, hasTimedOut = true) + initState, 1000, 1000, timeoutConf, hasTimedOut = true, watermarkPresent = false) assert(state2.hasTimedOut === true) } // for batch queries - assert(GroupStateImpl.createForBatch(timeoutConf).hasTimedOut === false) + assert( + GroupStateImpl.createForBatch(timeoutConf, watermarkPresent = false).hasTimedOut === false) + } + } + + test("GroupState - getCurrentWatermarkMs") { + def streamingState(timeoutConf: GroupStateTimeout, watermark: Option[Long]): GroupState[Int] = { + GroupStateImpl.createForStreaming( + None, 1000, watermark.getOrElse(-1), timeoutConf, + hasTimedOut = false, watermark.nonEmpty) + } + + def batchState(timeoutConf: GroupStateTimeout, watermarkPresent: Boolean): GroupState[Any] = { + GroupStateImpl.createForBatch(timeoutConf, watermarkPresent) + } + + def assertWrongTimeoutError(test: => Unit): Unit = { + val e = intercept[UnsupportedOperationException] { test } + assert(e.getMessage.contains( + "Cannot get event time watermark timestamp without setting watermark")) + } + + for (timeoutConf <- Seq(NoTimeout, EventTimeTimeout, ProcessingTimeTimeout)) { + // Tests for getCurrentWatermarkMs in streaming queries + assertWrongTimeoutError { streamingState(timeoutConf, None).getCurrentWatermarkMs() } + assert(streamingState(timeoutConf, Some(1000)).getCurrentWatermarkMs() === 1000) + assert(streamingState(timeoutConf, Some(2000)).getCurrentWatermarkMs() === 2000) + + // Tests for getCurrentWatermarkMs in batch queries + assertWrongTimeoutError { + batchState(timeoutConf, watermarkPresent = false).getCurrentWatermarkMs() + } + assert(batchState(timeoutConf, watermarkPresent = true).getCurrentWatermarkMs() === -1) + } + } + + test("GroupState - getCurrentProcessingTimeMs") { + def streamingState( + timeoutConf: GroupStateTimeout, + procTime: Long, + watermarkPresent: Boolean): GroupState[Int] = { + GroupStateImpl.createForStreaming( + None, procTime, -1, timeoutConf, hasTimedOut = false, watermarkPresent = false) + } + + def batchState(timeoutConf: GroupStateTimeout, watermarkPresent: Boolean): GroupState[Any] = { + GroupStateImpl.createForBatch(timeoutConf, watermarkPresent) + } + + for (timeoutConf <- Seq(NoTimeout, EventTimeTimeout, ProcessingTimeTimeout)) { + for (watermarkPresent <- Seq(false, true)) { + // Tests for getCurrentProcessingTimeMs in streaming queries + assert(streamingState(timeoutConf, NO_TIMESTAMP, watermarkPresent) + .getCurrentProcessingTimeMs() === -1) + assert(streamingState(timeoutConf, 1000, watermarkPresent) + .getCurrentProcessingTimeMs() === 1000) + assert(streamingState(timeoutConf, 2000, watermarkPresent) + .getCurrentProcessingTimeMs() === 2000) + + // Tests for getCurrentProcessingTimeMs in batch queries + val currentTime = System.currentTimeMillis() + assert(batchState(timeoutConf, watermarkPresent).getCurrentProcessingTimeMs >= currentTime) + } } } + test("GroupState - primitive type") { var intState = GroupStateImpl.createForStreaming[Int]( - None, 1000, 1000, NoTimeout, hasTimedOut = false) + None, 1000, 1000, NoTimeout, hasTimedOut = false, watermarkPresent = false) intercept[NoSuchElementException] { intState.get } assert(intState.getOption === None) intState = GroupStateImpl.createForStreaming[Int]( - Some(10), 1000, 1000, NoTimeout, hasTimedOut = false) + Some(10), 1000, 1000, NoTimeout, hasTimedOut = false, watermarkPresent = false) assert(intState.get == 10) intState.update(0) assert(intState.get == 0) @@ -304,7 +374,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStateUpdateWithData( testName + "no update", - stateUpdates = state => { /* do nothing */ }, + stateUpdates = state => { + assert(state.getCurrentProcessingTimeMs() === currentBatchTimestamp) + intercept[Exception] { state.getCurrentWatermarkMs() } // watermark not specified + /* no updates */ + }, timeoutConf = GroupStateTimeout.NoTimeout, priorState = priorState, expectedState = priorState) // should not change @@ -342,7 +416,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStateUpdateWithData( s"$timeoutConf - $testName - no update", - stateUpdates = state => { /* do nothing */ }, + stateUpdates = state => { + assert(state.getCurrentProcessingTimeMs() === currentBatchTimestamp) + intercept[Exception] { state.getCurrentWatermarkMs() } // watermark not specified + /* no updates */ + }, timeoutConf = timeoutConf, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, @@ -466,7 +544,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStateUpdateWithTimeout( s"$timeoutConf - should timeout - no update/remove", - stateUpdates = state => { /* do nothing */ }, + stateUpdates = state => { + assert(state.getCurrentProcessingTimeMs() === currentBatchTimestamp) + intercept[Exception] { state.getCurrentWatermarkMs() } // watermark not specified + /* no updates */ + }, timeoutConf = timeoutConf, priorTimeoutTimestamp = beforeTimeoutThreshold, expectedState = preTimeoutState, // state should not change @@ -525,6 +607,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest // Function to maintain running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise does not return anything val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { @@ -647,6 +731,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest test("flatMapGroupsWithState - batch") { // Function that returns running count only if its even, otherwise does not return val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() > 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } + if (state.exists) throw new IllegalArgumentException("state.exists should be false") Iterator((key, values.size)) } @@ -660,6 +747,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } + if (state.hasTimedOut) { state.remove() Iterator((key, "-1")) @@ -713,10 +803,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest test("flatMapGroupsWithState - streaming with event time timeout + watermark") { // Function to maintain the max event time // Returns the max event time in the state, or -1 if the state was removed by timeout - val stateFunc = ( - key: String, - values: Iterator[(String, Long)], - state: GroupState[Long]) => { + val stateFunc = (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 } + val timeoutDelay = 5 if (key != "a") { Iterator.empty @@ -760,6 +850,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { @@ -802,7 +894,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest // - no initial state // - timeouts operations work, does not throw any error [SPARK-20792] // - works with primitive state type + // - can get processing time val stateFunc = (key: String, values: Iterator[String], state: GroupState[Int]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() > 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } + if (state.exists) throw new IllegalArgumentException("state.exists should be false") state.setTimeoutTimestamp(0, "1 hour") state.update(10) @@ -1090,4 +1186,24 @@ object FlatMapGroupsWithStateSuite { override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty) override def hasCommitted: Boolean = true } + + def assertCanGetProcessingTime(predicate: => Boolean): Unit = { + if (!predicate) throw new TestFailedException("Could not get processing time", 20) + } + + def assertCanGetWatermark(predicate: => Boolean): Unit = { + if (!predicate) throw new TestFailedException("Could not get processing time", 20) + } + + def assertCannotGetWatermark(func: => Unit): Unit = { + try { + func + } catch { + case u: UnsupportedOperationException => + return + case _ => + throw new TestFailedException("Unexpected exception when trying to get watermark", 20) + } + throw new TestFailedException("Could get watermark when not expected", 20) + } } From 72561ecf4b611d68f8bf695ddd0c4c2cce3a29d9 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Wed, 18 Oct 2017 20:59:40 +0800 Subject: [PATCH 728/779] [SPARK-22266][SQL] The same aggregate function was evaluated multiple times ## What changes were proposed in this pull request? To let the same aggregate function that appear multiple times in an Aggregate be evaluated only once, we need to deduplicate the aggregate expressions. The original code was trying to use a "distinct" call to get a set of aggregate expressions, but did not work, since the "distinct" did not compare semantic equality. And even if it did, further work should be done in result expression rewriting. In this PR, I changed the "set" to a map mapping the semantic identity of a aggregate expression to itself. Thus, later on, when rewriting result expressions (i.e., output expressions), the aggregate expression reference can be fixed. ## How was this patch tested? Added a new test in SQLQuerySuite Author: maryannxue Closes #19488 from maryannxue/spark-22266. --- .../sql/catalyst/planning/patterns.scala | 16 +++++++----- .../org/apache/spark/sql/SQLQuerySuite.scala | 26 +++++++++++++++++++ 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 8d034c21a4960..cc391aae55787 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -205,14 +205,17 @@ object PhysicalAggregation { case logical.Aggregate(groupingExpressions, resultExpressions, child) => // A single aggregate expression might appear multiple times in resultExpressions. // In order to avoid evaluating an individual aggregate function multiple times, we'll - // build a set of the distinct aggregate expressions and build a function which can - // be used to re-write expressions so that they reference the single copy of the - // aggregate function which actually gets computed. + // build a set of semantically distinct aggregate expressions and re-write expressions so + // that they reference the single copy of the aggregate function which actually gets computed. + // Non-deterministic aggregate expressions are not deduplicated. + val equivalentAggregateExpressions = new EquivalentExpressions val aggregateExpressions = resultExpressions.flatMap { expr => expr.collect { - case agg: AggregateExpression => agg + // addExpr() always returns false for non-deterministic expressions and do not add them. + case agg: AggregateExpression + if (!equivalentAggregateExpressions.addExpr(agg)) => agg } - }.distinct + } val namedGroupingExpressions = groupingExpressions.map { case ne: NamedExpression => ne -> ne @@ -236,7 +239,8 @@ object PhysicalAggregation { case ae: AggregateExpression => // The final aggregation buffer's attributes will be `finalAggregationAttributes`, // so replace each aggregate expression by its corresponding attribute in the set: - ae.resultAttribute + equivalentAggregateExpressions.getEquivalentExprs(ae).headOption + .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute case expression => // Since we're using `namedGroupingAttributes` to extract the grouping key // columns, we need to replace grouping key expressions with their corresponding diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f0c58e2e5bf45..caf332d050d7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.{AccumulatorSuite, SparkException} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -2715,4 +2716,29 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(1, 1, 1)) } } + + test("SRARK-22266: the same aggregate function was calculated multiple times") { + val query = "SELECT a, max(b+1), max(b+1) + 1 FROM testData2 GROUP BY a" + val df = sql(query) + val physical = df.queryExecution.sparkPlan + val aggregateExpressions = physical.collectFirst { + case agg : HashAggregateExec => agg.aggregateExpressions + case agg : SortAggregateExec => agg.aggregateExpressions + } + assert (aggregateExpressions.isDefined) + assert (aggregateExpressions.get.size == 1) + checkAnswer(df, Row(1, 3, 4) :: Row(2, 3, 4) :: Row(3, 3, 4) :: Nil) + } + + test("Non-deterministic aggregate functions should not be deduplicated") { + val query = "SELECT a, first_value(b), first_value(b) + 1 FROM testData2 GROUP BY a" + val df = sql(query) + val physical = df.queryExecution.sparkPlan + val aggregateExpressions = physical.collectFirst { + case agg : HashAggregateExec => agg.aggregateExpressions + case agg : SortAggregateExec => agg.aggregateExpressions + } + assert (aggregateExpressions.isDefined) + assert (aggregateExpressions.get.size == 2) + } } From 1f25d8683a84a479fd7fc77b5a1ea980289b681b Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 18 Oct 2017 09:14:46 -0700 Subject: [PATCH 729/779] [SPARK-22249][FOLLOWUP][SQL] Check if list of value for IN is empty in the optimizer ## What changes were proposed in this pull request? This PR addresses the comments by gatorsmile on [the previous PR](https://github.com/apache/spark/pull/19494). ## How was this patch tested? Previous UT and added UT. Author: Marco Gaido Closes #19522 from mgaido91/SPARK-22249_FOLLOWUP. --- .../execution/columnar/InMemoryTableScanExec.scala | 4 ++-- .../columnar/InMemoryColumnarQuerySuite.scala | 12 +++++++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 846ec03e46a12..139da1c519da2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -102,8 +102,8 @@ case class InMemoryTableScanExec( case IsNull(a: Attribute) => statsFor(a).nullCount > 0 case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 - case In(_: AttributeReference, list: Seq[Expression]) if list.isEmpty => Literal.FalseLiteral - case In(a: AttributeReference, list: Seq[Expression]) if list.forall(_.isInstanceOf[Literal]) => + case In(a: AttributeReference, list: Seq[Expression]) + if list.forall(_.isInstanceOf[Literal]) && list.nonEmpty => list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] && l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 75d17bc79477d..2f249c850a088 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -21,8 +21,9 @@ import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import org.apache.spark.sql.{DataFrame, QueryTest, Row} -import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.execution.LocalTableScanExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -444,4 +445,13 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { assert(dfNulls.filter($"id".isin(2, 3)).count() == 0) dfNulls.unpersist() } + + test("SPARK-22249: buildFilter should not throw exception when In contains an empty list") { + val attribute = AttributeReference("a", IntegerType)() + val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, + LocalTableScanExec(Seq(attribute), Nil), None) + val tableScanExec = InMemoryTableScanExec(Seq(attribute), + Seq(In(attribute, Nil)), testRelation) + assert(tableScanExec.partitionFilters.isEmpty) + } } From 52facb0062a4253fa45ac0c633d0510a9b684a62 Mon Sep 17 00:00:00 2001 From: Valeriy Avanesov Date: Wed, 18 Oct 2017 10:46:46 -0700 Subject: [PATCH 730/779] [SPARK-14371][MLLIB] OnlineLDAOptimizer should not collect stats for each doc in mini-batch to driver Hi, # What changes were proposed in this pull request? as it was proposed by jkbradley , ```gammat``` are not collected to the driver anymore. # How was this patch tested? existing test suite. Author: Valeriy Avanesov Author: Valeriy Avanesov Closes #18924 from akopich/master. --- .../spark/mllib/clustering/LDAOptimizer.scala | 82 +++++++++++++------ 1 file changed, 57 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index d633893e55f55..693a2a31f026b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -26,6 +26,7 @@ import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.graphx._ import org.apache.spark.graphx.util.PeriodicGraphCheckpointer +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -259,7 +260,7 @@ final class EMLDAOptimizer extends LDAOptimizer { */ @Since("1.4.0") @DeveloperApi -final class OnlineLDAOptimizer extends LDAOptimizer { +final class OnlineLDAOptimizer extends LDAOptimizer with Logging { // LDA common parameters private var k: Int = 0 @@ -462,31 +463,61 @@ final class OnlineLDAOptimizer extends LDAOptimizer { val expElogbetaBc = batch.sparkContext.broadcast(expElogbeta) val alpha = this.alpha.asBreeze val gammaShape = this.gammaShape - - val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs => + val optimizeDocConcentration = this.optimizeDocConcentration + // If and only if optimizeDocConcentration is set true, + // we calculate logphat in the same pass as other statistics. + // No calculation of loghat happens otherwise. + val logphatPartOptionBase = () => if (optimizeDocConcentration) { + Some(BDV.zeros[Double](k)) + } else { + None + } + + val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitions { docs => val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) val stat = BDM.zeros[Double](k, vocabSize) - var gammaPart = List[BDV[Double]]() + val logphatPartOption = logphatPartOptionBase() + var nonEmptyDocCount: Long = 0L nonEmptyDocs.foreach { case (_, termCounts: Vector) => + nonEmptyDocCount += 1 val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference( termCounts, expElogbetaBc.value, alpha, gammaShape, k) - stat(::, ids) := stat(::, ids).toDenseMatrix + sstats - gammaPart = gammad :: gammaPart + stat(::, ids) := stat(::, ids) + sstats + logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad)) } - Iterator((stat, gammaPart)) - }.persist(StorageLevel.MEMORY_AND_DISK) - val statsSum: BDM[Double] = stats.map(_._1).treeAggregate(BDM.zeros[Double](k, vocabSize))( - _ += _, _ += _) - val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat( - stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*) - stats.unpersist() + Iterator((stat, logphatPartOption, nonEmptyDocCount)) + } + + val elementWiseSum = ( + u: (BDM[Double], Option[BDV[Double]], Long), + v: (BDM[Double], Option[BDV[Double]], Long)) => { + u._1 += v._1 + u._2.foreach(_ += v._2.get) + (u._1, u._2, u._3 + v._3) + } + + val (statsSum: BDM[Double], logphatOption: Option[BDV[Double]], nonEmptyDocsN: Long) = stats + .treeAggregate((BDM.zeros[Double](k, vocabSize), logphatPartOptionBase(), 0L))( + elementWiseSum, elementWiseSum + ) + expElogbetaBc.destroy(false) - val batchResult = statsSum *:* expElogbeta.t + if (nonEmptyDocsN == 0) { + logWarning("No non-empty documents were submitted in the batch.") + // Therefore, there is no need to update any of the model parameters + return this + } + + val batchResult = statsSum *:* expElogbeta.t // Note that this is an optimization to avoid batch.count - updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt) - if (optimizeDocConcentration) updateAlpha(gammat) + val batchSize = (miniBatchFraction * corpusSize).ceil.toInt + updateLambda(batchResult, batchSize) + + logphatOption.foreach(_ /= nonEmptyDocsN.toDouble) + logphatOption.foreach(updateAlpha(_, nonEmptyDocsN)) + this } @@ -503,21 +534,22 @@ final class OnlineLDAOptimizer extends LDAOptimizer { } /** - * Update alpha based on `gammat`, the inferred topic distributions for documents in the - * current mini-batch. Uses Newton-Rhapson method. + * Update alpha based on `logphat`. + * Uses Newton-Rhapson method. * @see Section 3.3, Huang: Maximum Likelihood Estimation of Dirichlet Distribution Parameters * (http://jonathan-huang.org/research/dirichlet/dirichlet.pdf) + * @param logphat Expectation of estimated log-posterior distribution of + * topics in a document averaged over the batch. + * @param nonEmptyDocsN number of non-empty documents */ - private def updateAlpha(gammat: BDM[Double]): Unit = { + private def updateAlpha(logphat: BDV[Double], nonEmptyDocsN: Double): Unit = { val weight = rho() - val N = gammat.rows.toDouble val alpha = this.alpha.asBreeze.toDenseVector - val logphat: BDV[Double] = - sum(LDAUtils.dirichletExpectation(gammat)(::, breeze.linalg.*)).t / N - val gradf = N * (-LDAUtils.dirichletExpectation(alpha) + logphat) - val c = N * trigamma(sum(alpha)) - val q = -N * trigamma(alpha) + val gradf = nonEmptyDocsN * (-LDAUtils.dirichletExpectation(alpha) + logphat) + + val c = nonEmptyDocsN * trigamma(sum(alpha)) + val q = -nonEmptyDocsN * trigamma(alpha) val b = sum(gradf / q) / (1D / c + sum(1D / q)) val dalpha = -(gradf - b) / q From 6f1d0dea1cdda558c998179789b386f6e52b9e36 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 19 Oct 2017 13:30:55 +0800 Subject: [PATCH 731/779] [SPARK-22300][BUILD] Update ORC to 1.4.1 ## What changes were proposed in this pull request? Apache ORC 1.4.1 is released yesterday. - https://orc.apache.org/news/2017/10/16/ORC-1.4.1/ Like ORC-233 (Allow `orc.include.columns` to be empty), there are several important fixes. This PR updates Apache ORC dependency to use the latest one, 1.4.1. ## How was this patch tested? Pass the Jenkins. Author: Dongjoon Hyun Closes #19521 from dongjoon-hyun/SPARK-22300. --- dev/deps/spark-deps-hadoop-2.6 | 6 +++--- dev/deps/spark-deps-hadoop-2.7 | 6 +++--- pom.xml | 6 +++++- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 76fcbd15869f1..6e2fc63d67108 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -2,7 +2,7 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -aircompressor-0.3.jar +aircompressor-0.8.jar antlr-2.7.7.jar antlr-runtime-3.4.jar antlr4-runtime-4.7.jar @@ -149,8 +149,8 @@ netty-3.9.9.Final.jar netty-all-4.0.47.Final.jar objenesis-2.1.jar opencsv-2.3.jar -orc-core-1.4.0-nohive.jar -orc-mapreduce-1.4.0-nohive.jar +orc-core-1.4.1-nohive.jar +orc-mapreduce-1.4.1-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index cb20072bf8b30..c2bbc253d723a 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -2,7 +2,7 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -aircompressor-0.3.jar +aircompressor-0.8.jar antlr-2.7.7.jar antlr-runtime-3.4.jar antlr4-runtime-4.7.jar @@ -150,8 +150,8 @@ netty-3.9.9.Final.jar netty-all-4.0.47.Final.jar objenesis-2.1.jar opencsv-2.3.jar -orc-core-1.4.0-nohive.jar -orc-mapreduce-1.4.0-nohive.jar +orc-core-1.4.1-nohive.jar +orc-mapreduce-1.4.1-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/pom.xml b/pom.xml index 9fac8b1e53788..b9c972855204a 100644 --- a/pom.xml +++ b/pom.xml @@ -128,7 +128,7 @@ 1.2.1 10.12.1.1 1.8.2 - 1.4.0 + 1.4.1 nohive 1.6.0 9.3.20.v20170531 @@ -1712,6 +1712,10 @@ org.apache.hive hive-storage-api + + io.airlift + slice + From dc2714da50ecba1bf1fdf555a82a4314f763a76e Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 19 Oct 2017 14:56:48 +0800 Subject: [PATCH 732/779] [SPARK-22290][CORE] Avoid creating Hive delegation tokens when not necessary. Hive delegation tokens are only needed when the Spark driver has no access to the kerberos TGT. That happens only in two situations: - when using a proxy user - when using cluster mode without a keytab This change modifies the Hive provider so that it only generates delegation tokens in those situations, and tweaks the YARN AM so that it makes the proper user visible to the Hive code when running with keytabs, so that the TGT can be used instead of a delegation token. The effect of this change is that now it's possible to initialize multiple, non-concurrent SparkContext instances in the same JVM. Before, the second invocation would fail to fetch a new Hive delegation token, which then could make the second (or third or...) application fail once the token expired. With this change, the TGT will be used to authenticate to the HMS instead. This change also avoids polluting the current logged in user's credentials when launching applications. The credentials are copied only when running applications as a proxy user. This makes it possible to implement SPARK-11035 later, where multiple threads might be launching applications, and each app should have its own set of credentials. Tested by verifying HDFS and Hive access in following scenarios: - client and cluster mode - client and cluster mode with proxy user - client and cluster mode with principal / keytab - long-running cluster app with principal / keytab - pyspark app that creates (and stops) multiple SparkContext instances through its lifetime Author: Marcelo Vanzin Closes #19509 from vanzin/SPARK-22290. --- .../apache/spark/deploy/SparkHadoopUtil.scala | 17 +++-- .../HBaseDelegationTokenProvider.scala | 4 +- .../HadoopDelegationTokenManager.scala | 2 +- .../HadoopDelegationTokenProvider.scala | 2 +- .../HadoopFSDelegationTokenProvider.scala | 4 +- .../HiveDelegationTokenProvider.scala | 20 +++++- docs/running-on-yarn.md | 9 +++ .../spark/deploy/yarn/ApplicationMaster.scala | 69 +++++++++++++++---- .../org/apache/spark/deploy/yarn/Client.scala | 5 +- .../org/apache/spark/deploy/yarn/config.scala | 4 ++ .../sql/hive/client/HiveClientImpl.scala | 6 -- 11 files changed, 110 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 53775db251bc6..1fa10ab943f34 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -61,13 +61,17 @@ class SparkHadoopUtil extends Logging { * do a FileSystem.closeAllForUGI in order to avoid leaking Filesystems */ def runAsSparkUser(func: () => Unit) { + createSparkUser().doAs(new PrivilegedExceptionAction[Unit] { + def run: Unit = func() + }) + } + + def createSparkUser(): UserGroupInformation = { val user = Utils.getCurrentUserName() - logDebug("running as user: " + user) + logDebug("creating UGI for user: " + user) val ugi = UserGroupInformation.createRemoteUser(user) transferCredentials(UserGroupInformation.getCurrentUser(), ugi) - ugi.doAs(new PrivilegedExceptionAction[Unit] { - def run: Unit = func() - }) + ugi } def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) { @@ -417,6 +421,11 @@ class SparkHadoopUtil extends Logging { creds.readTokenStorageStream(new DataInputStream(tokensBuf)) creds } + + def isProxyUser(ugi: UserGroupInformation): Boolean = { + ugi.getAuthenticationMethod() == UserGroupInformation.AuthenticationMethod.PROXY + } + } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala index 78b0e6b2cbf39..5dcde4ec3a8a4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala @@ -56,7 +56,9 @@ private[security] class HBaseDelegationTokenProvider None } - override def delegationTokensRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired( + sparkConf: SparkConf, + hadoopConf: Configuration): Boolean = { hbaseConf(hadoopConf).get("hbase.security.authentication") == "kerberos" } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index c134b7ebe38fa..483d0deec8070 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -115,7 +115,7 @@ private[spark] class HadoopDelegationTokenManager( hadoopConf: Configuration, creds: Credentials): Long = { delegationTokenProviders.values.flatMap { provider => - if (provider.delegationTokensRequired(hadoopConf)) { + if (provider.delegationTokensRequired(sparkConf, hadoopConf)) { provider.obtainDelegationTokens(hadoopConf, sparkConf, creds) } else { logDebug(s"Service ${provider.serviceName} does not require a token." + diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala index 1ba245e84af4b..ed0905088ab25 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala @@ -37,7 +37,7 @@ private[spark] trait HadoopDelegationTokenProvider { * Returns true if delegation tokens are required for this service. By default, it is based on * whether Hadoop security is enabled. */ - def delegationTokensRequired(hadoopConf: Configuration): Boolean + def delegationTokensRequired(sparkConf: SparkConf, hadoopConf: Configuration): Boolean /** * Obtain delegation tokens for this service and get the time of the next renewal. diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala index 300773c58b183..21ca669ea98f0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala @@ -69,7 +69,9 @@ private[deploy] class HadoopFSDelegationTokenProvider(fileSystems: Configuration nextRenewalDate } - def delegationTokensRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired( + sparkConf: SparkConf, + hadoopConf: Configuration): Boolean = { UserGroupInformation.isSecurityEnabled } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala index b31cc595ed83b..ece5ce79c650d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala @@ -31,7 +31,9 @@ import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.Token import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.KEYTAB import org.apache.spark.util.Utils private[security] class HiveDelegationTokenProvider @@ -55,9 +57,21 @@ private[security] class HiveDelegationTokenProvider } } - override def delegationTokensRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired( + sparkConf: SparkConf, + hadoopConf: Configuration): Boolean = { + // Delegation tokens are needed only when: + // - trying to connect to a secure metastore + // - either deploying in cluster mode without a keytab, or impersonating another user + // + // Other modes (such as client with or without keytab, or cluster mode with keytab) do not need + // a delegation token, since there's a valid kerberos TGT for the right user available to the + // driver, which is the only process that connects to the HMS. + val deployMode = sparkConf.get("spark.submit.deployMode", "client") UserGroupInformation.isSecurityEnabled && - hiveConf(hadoopConf).getTrimmed("hive.metastore.uris", "").nonEmpty + hiveConf(hadoopConf).getTrimmed("hive.metastore.uris", "").nonEmpty && + (SparkHadoopUtil.get.isProxyUser(UserGroupInformation.getCurrentUser()) || + (deployMode == "cluster" && !sparkConf.contains(KEYTAB))) } override def obtainDelegationTokens( @@ -83,7 +97,7 @@ private[security] class HiveDelegationTokenProvider val hive2Token = new Token[DelegationTokenIdentifier]() hive2Token.decodeFromUrlString(tokenStr) - logInfo(s"Get Token from hive metastore: ${hive2Token.toString}") + logDebug(s"Get Token from hive metastore: ${hive2Token.toString}") creds.addToken(new Text("hive.server2.delegation.token"), hive2Token) } diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 432639588cc2b..9599d40c545b2 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -401,6 +401,15 @@ To use a custom metrics.properties for the application master and executors, upd Principal to be used to login to KDC, while running on secure HDFS. (Works also with the "local" master) + + + + + diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index e227bff88f71d..f6167235f89e4 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -20,6 +20,7 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} import java.lang.reflect.InvocationTargetException import java.net.{Socket, URI, URL} +import java.security.PrivilegedExceptionAction import java.util.concurrent.{TimeoutException, TimeUnit} import scala.collection.mutable.HashMap @@ -28,6 +29,7 @@ import scala.concurrent.duration.Duration import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -49,10 +51,7 @@ import org.apache.spark.util._ /** * Common application master functionality for Spark on Yarn. */ -private[spark] class ApplicationMaster( - args: ApplicationMasterArguments, - client: YarnRMClient) - extends Logging { +private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends Logging { // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. @@ -62,6 +61,46 @@ private[spark] class ApplicationMaster( .asInstanceOf[YarnConfiguration] private val isClusterMode = args.userClass != null + private val ugi = { + val original = UserGroupInformation.getCurrentUser() + + // If a principal and keytab were provided, log in to kerberos, and set up a thread to + // renew the kerberos ticket when needed. Because the UGI API does not expose the TTL + // of the TGT, use a configuration to define how often to check that a relogin is necessary. + // checkTGTAndReloginFromKeytab() is a no-op if the relogin is not yet needed. + val principal = sparkConf.get(PRINCIPAL).orNull + val keytab = sparkConf.get(KEYTAB).orNull + if (principal != null && keytab != null) { + UserGroupInformation.loginUserFromKeytab(principal, keytab) + + val renewer = new Thread() { + override def run(): Unit = Utils.tryLogNonFatalError { + while (true) { + TimeUnit.SECONDS.sleep(sparkConf.get(KERBEROS_RELOGIN_PERIOD)) + UserGroupInformation.getCurrentUser().checkTGTAndReloginFromKeytab() + } + } + } + renewer.setName("am-kerberos-renewer") + renewer.setDaemon(true) + renewer.start() + + // Transfer the original user's tokens to the new user, since that's needed to connect to + // YARN. It also copies over any delegation tokens that might have been created by the + // client, which will then be transferred over when starting executors (until new ones + // are created by the periodic task). + val newUser = UserGroupInformation.getCurrentUser() + SparkHadoopUtil.get.transferCredentials(original, newUser) + newUser + } else { + SparkHadoopUtil.get.createSparkUser() + } + } + + private val client = ugi.doAs(new PrivilegedExceptionAction[YarnRMClient]() { + def run: YarnRMClient = new YarnRMClient() + }) + // Default to twice the number of executors (twice the maximum number of executors if dynamic // allocation is enabled), with a minimum of 3. @@ -201,6 +240,13 @@ private[spark] class ApplicationMaster( } final def run(): Int = { + ugi.doAs(new PrivilegedExceptionAction[Unit]() { + def run: Unit = runImpl() + }) + exitCode + } + + private def runImpl(): Unit = { try { val appAttemptId = client.getAttemptId() @@ -254,11 +300,6 @@ private[spark] class ApplicationMaster( } } - // Call this to force generation of secret so it gets populated into the - // Hadoop UGI. This has to happen before the startUserApplication which does a - // doAs in order for the credentials to be passed on to the executor containers. - val securityMgr = new SecurityManager(sparkConf) - // If the credentials file config is present, we must periodically renew tokens. So create // a new AMDelegationTokenRenewer if (sparkConf.contains(CREDENTIALS_FILE_PATH)) { @@ -284,6 +325,9 @@ private[spark] class ApplicationMaster( credentialRenewerThread.join() } + // Call this to force generation of secret so it gets populated into the Hadoop UGI. + val securityMgr = new SecurityManager(sparkConf) + if (isClusterMode) { runDriver(securityMgr) } else { @@ -297,7 +341,6 @@ private[spark] class ApplicationMaster( ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, "Uncaught exception: " + e) } - exitCode } /** @@ -775,10 +818,8 @@ object ApplicationMaster extends Logging { sys.props(k) = v } } - SparkHadoopUtil.get.runAsSparkUser { () => - master = new ApplicationMaster(amArgs, new YarnRMClient) - System.exit(master.run()) - } + master = new ApplicationMaster(amArgs) + System.exit(master.run()) } private[spark] def sparkContextInitialized(sc: SparkContext): Unit = { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 64b2b4d4db549..1fe25c4ddaabf 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -394,7 +394,10 @@ private[spark] class Client( if (credentials != null) { // Add credentials to current user's UGI, so that following operations don't need to use the // Kerberos tgt to get delegations again in the client side. - UserGroupInformation.getCurrentUser.addCredentials(credentials) + val currentUser = UserGroupInformation.getCurrentUser() + if (SparkHadoopUtil.get.isProxyUser(currentUser)) { + currentUser.addCredentials(credentials) + } logDebug(YarnSparkHadoopUtil.get.dumpTokens(credentials).mkString("\n")) } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 187803cc6050b..e1af8ba087d6e 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -347,6 +347,10 @@ package object config { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(Long.MaxValue) + private[spark] val KERBEROS_RELOGIN_PERIOD = ConfigBuilder("spark.yarn.kerberos.relogin.period") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("1m") + // The list of cache-related config entries. This is used by Client and the AM to clean // up the environment so that these settings do not appear on the web UI. private[yarn] val CACHE_CONFIGS = Seq( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index a01c312d5e497..16c95c53b4201 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -111,12 +111,6 @@ private[hive] class HiveClientImpl( if (clientLoader.isolationOn) { // Switch to the initClassLoader. Thread.currentThread().setContextClassLoader(initClassLoader) - // Set up kerberos credentials for UserGroupInformation.loginUser within current class loader - if (sparkConf.contains("spark.yarn.principal") && sparkConf.contains("spark.yarn.keytab")) { - val principal = sparkConf.get("spark.yarn.principal") - val keytab = sparkConf.get("spark.yarn.keytab") - SparkHadoopUtil.get.loginUserFromKeytab(principal, keytab) - } try { newState() } finally { From 5a07aca4d464e96d75ea17bf6768e24b829872ec Mon Sep 17 00:00:00 2001 From: krishna-pandey Date: Thu, 19 Oct 2017 08:33:14 +0100 Subject: [PATCH 733/779] [SPARK-22188][CORE] Adding security headers for preventing XSS, MitM and MIME sniffing ## What changes were proposed in this pull request? The HTTP Strict-Transport-Security response header (often abbreviated as HSTS) is a security feature that lets a web site tell browsers that it should only be communicated with using HTTPS, instead of using HTTP. Note: The Strict-Transport-Security header is ignored by the browser when your site is accessed using HTTP; this is because an attacker may intercept HTTP connections and inject the header or remove it. When your site is accessed over HTTPS with no certificate errors, the browser knows your site is HTTPS capable and will honor the Strict-Transport-Security header. The HTTP X-XSS-Protection response header is a feature of Internet Explorer, Chrome and Safari that stops pages from loading when they detect reflected cross-site scripting (XSS) attacks. The HTTP X-Content-Type-Options response header is used to protect against MIME sniffing vulnerabilities. ## How was this patch tested? Checked on my system locally. screen shot 2017-10-03 at 6 49 20 pm Author: krishna-pandey Author: Krishna Pandey Closes #19419 from krishna-pandey/SPARK-22188. --- .../spark/internal/config/package.scala | 18 +++++++ .../org/apache/spark/ui/JettyUtils.scala | 9 ++++ docs/security.md | 47 +++++++++++++++++++ 3 files changed, 74 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 0c36bdcdd2904..6f0247b73070d 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -452,6 +452,24 @@ package object config { .toSequence .createWithDefault(Nil) + private[spark] val UI_X_XSS_PROTECTION = + ConfigBuilder("spark.ui.xXssProtection") + .doc("Value for HTTP X-XSS-Protection response header") + .stringConf + .createWithDefaultString("1; mode=block") + + private[spark] val UI_X_CONTENT_TYPE_OPTIONS = + ConfigBuilder("spark.ui.xContentTypeOptions.enabled") + .doc("Set to 'true' for setting X-Content-Type-Options HTTP response header to 'nosniff'") + .booleanConf + .createWithDefault(true) + + private[spark] val UI_STRICT_TRANSPORT_SECURITY = + ConfigBuilder("spark.ui.strictTransportSecurity") + .doc("Value for HTTP Strict Transport Security Response Header") + .stringConf + .createOptional + private[spark] val EXTRA_LISTENERS = ConfigBuilder("spark.extraListeners") .doc("Class names of listeners to add to SparkContext during initialization.") .stringConf diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 5ee04dad6ed4d..0adeb4058b6e4 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -39,6 +39,7 @@ import org.json4s.jackson.JsonMethods.{pretty, render} import org.apache.spark.{SecurityManager, SparkConf, SSLOptions} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.util.Utils /** @@ -89,6 +90,14 @@ private[spark] object JettyUtils extends Logging { val result = servletParams.responder(request) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") response.setHeader("X-Frame-Options", xFrameOptionsValue) + response.setHeader("X-XSS-Protection", conf.get(UI_X_XSS_PROTECTION)) + if (conf.get(UI_X_CONTENT_TYPE_OPTIONS)) { + response.setHeader("X-Content-Type-Options", "nosniff") + } + if (request.getScheme == "https") { + conf.get(UI_STRICT_TRANSPORT_SECURITY).foreach( + response.setHeader("Strict-Transport-Security", _)) + } response.getWriter.print(servletParams.extractFn(result)) } else { response.setStatus(HttpServletResponse.SC_FORBIDDEN) diff --git a/docs/security.md b/docs/security.md index 1d004003f9a32..15aadf07cf873 100644 --- a/docs/security.md +++ b/docs/security.md @@ -186,7 +186,54 @@ configure those ports.
    Property NameDefaultMeaning
    spark.eventLog.logBlockUpdates.enabledfalse + Whether to log events for every block update, if spark.eventLog.enabled is true. + *Warning*: This will increase the size of the event log considerably. +
    spark.eventLog.compress false
    spark.yarn.kerberos.relogin.period1m + How often to check whether the kerberos TGT should be renewed. This should be set to a value + that is shorter than the TGT renewal period (or the TGT lifetime if TGT renewal is not enabled). + The default value should be enough for most deployments. +
    spark.yarn.config.gatewayPath (none)
    +### HTTP Security Headers + +Apache Spark can be configured to include HTTP Headers which aids in preventing Cross +Site Scripting (XSS), Cross-Frame Scripting (XFS), MIME-Sniffing and also enforces HTTP +Strict Transport Security. + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.ui.xXssProtection1; mode=block + Value for HTTP X-XSS-Protection response header. You can choose appropriate value + from below: +
      +
    • 0 (Disables XSS filtering)
    • +
    • 1 (Enables XSS filtering. If a cross-site scripting attack is detected, + the browser will sanitize the page.)
    • +
    • 1; mode=block (Enables XSS filtering. The browser will prevent rendering + of the page if an attack is detected.)
    • +
    +
    spark.ui.xContentTypeOptions.enabledtrue + When value is set to "true", X-Content-Type-Options HTTP response header will be set + to "nosniff". Set "false" to disable. +
    spark.ui.strictTransportSecurityNone + Value for HTTP Strict Transport Security (HSTS) Response Header. You can choose appropriate + value from below and set expire-time accordingly, when Spark is SSL/TLS enabled. +
      +
    • max-age=<expire-time>
    • +
    • max-age=<expire-time>; includeSubDomains
    • +
    • max-age=<expire-time>; preload
    • +
    +
    + See the [configuration page](configuration.html) for more details on the security configuration parameters, and org.apache.spark.SecurityManager for implementation details about security. + From 7fae7995ba05e0333d1decb7ca74ddb7c1b448d7 Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Fri, 20 Oct 2017 09:40:00 +0900 Subject: [PATCH 734/779] [SPARK-22268][BUILD] Fix lint-java ## What changes were proposed in this pull request? Fix java style issues ## How was this patch tested? Run `./dev/lint-java` locally since it's not run on Jenkins Author: Andrew Ash Closes #19486 from ash211/aash/fix-lint-java. --- .../unsafe/sort/UnsafeInMemorySorter.java | 9 ++++---- .../sort/UnsafeExternalSorterSuite.java | 21 +++++++++++-------- .../sort/UnsafeInMemorySorterSuite.java | 3 ++- .../v2/reader/SupportsPushDownFilters.java | 1 - 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 869ec908be1fb..3bb87a6ed653d 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -172,10 +172,11 @@ public void free() { public void reset() { if (consumer != null) { consumer.freeArray(array); - // the call to consumer.allocateArray may trigger a spill - // which in turn access this instance and eventually re-enter this method and try to free the array again. - // by setting the array to null and its length to 0 we effectively make the spill code-path a no-op. - // setting the array to null also indicates that it has already been de-allocated which prevents a double de-allocation in free(). + // the call to consumer.allocateArray may trigger a spill which in turn access this instance + // and eventually re-enter this method and try to free the array again. by setting the array + // to null and its length to 0 we effectively make the spill code-path a no-op. setting the + // array to null also indicates that it has already been de-allocated which prevents a double + // de-allocation in free(). array = null; usableCapacity = 0; pos = 0; diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 6c5451d0fd2a5..d0d0334add0bf 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -516,12 +516,13 @@ public void testOOMDuringSpill() throws Exception { for (int i = 0; sorter.hasSpaceForAnotherRecord(); ++i) { insertNumber(sorter, i); } - // we expect the next insert to attempt growing the pointerssArray - // first allocation is expected to fail, then a spill is triggered which attempts another allocation - // which also fails and we expect to see this OOM here. - // the original code messed with a released array within the spill code - // and ended up with a failed assertion. - // we also expect the location of the OOM to be org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset + // we expect the next insert to attempt growing the pointerssArray first + // allocation is expected to fail, then a spill is triggered which + // attempts another allocation which also fails and we expect to see this + // OOM here. the original code messed with a released array within the + // spill code and ended up with a failed assertion. we also expect the + // location of the OOM to be + // org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset memoryManager.markconsequentOOM(2); try { insertNumber(sorter, 1024); @@ -530,9 +531,11 @@ public void testOOMDuringSpill() throws Exception { // we expect an OutOfMemoryError here, anything else (i.e the original NPE is a failure) catch (OutOfMemoryError oom){ String oomStackTrace = Utils.exceptionString(oom); - assertThat("expected OutOfMemoryError in org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset", - oomStackTrace, - Matchers.containsString("org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset")); + assertThat("expected OutOfMemoryError in " + + "org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset", + oomStackTrace, + Matchers.containsString( + "org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset")); } } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 1a3e11efe9787..594f07dd780f9 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -179,7 +179,8 @@ public int compare( } catch (OutOfMemoryError oom) { // as expected } - // [SPARK-21907] this failed on NPE at org.apache.spark.memory.MemoryConsumer.freeArray(MemoryConsumer.java:108) + // [SPARK-21907] this failed on NPE at + // org.apache.spark.memory.MemoryConsumer.freeArray(MemoryConsumer.java:108) sorter.free(); // simulate a 'back to back' free. sorter.free(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index d6f297c013375..6b0c9d417eeae 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.sources.v2.reader; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.sources.Filter; /** From b034f2565f72aa73c9f0be1e49d148bb4cf05153 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 19 Oct 2017 20:24:51 -0700 Subject: [PATCH 735/779] [SPARK-22026][SQL] data source v2 write path ## What changes were proposed in this pull request? A working prototype for data source v2 write path. The writing framework is similar to the reading framework. i.e. `WriteSupport` -> `DataSourceV2Writer` -> `DataWriterFactory` -> `DataWriter`. Similar to the `FileCommitPotocol`, the writing API has job and task level commit/abort to support the transaction. ## How was this patch tested? new tests Author: Wenchen Fan Closes #19269 from cloud-fan/data-source-v2-write. --- .../spark/sql/sources/v2/WriteSupport.java | 49 ++++ .../sources/v2/writer/DataSourceV2Writer.java | 88 +++++++ .../sql/sources/v2/writer/DataWriter.java | 92 +++++++ .../sources/v2/writer/DataWriterFactory.java | 50 ++++ .../v2/writer/SupportsWriteInternalRow.java | 44 ++++ .../v2/writer/WriterCommitMessage.java | 33 +++ .../apache/spark/sql/DataFrameWriter.scala | 38 ++- .../datasources/v2/DataSourceV2Strategy.scala | 11 +- .../datasources/v2/WriteToDataSourceV2.scala | 133 ++++++++++ .../sql/sources/v2/DataSourceV2Suite.scala | 69 +++++ .../sources/v2/SimpleWritableDataSource.scala | 249 ++++++++++++++++++ 11 files changed, 842 insertions(+), 14 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java new file mode 100644 index 0000000000000..a8a961598bde3 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import java.util.Optional; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data writing ability and save the data to the data source. + */ +@InterfaceStability.Evolving +public interface WriteSupport { + + /** + * Creates an optional {@link DataSourceV2Writer} to save the data to this data source. Data + * sources can return None if there is no writing needed to be done according to the save mode. + * + * @param jobId A unique string for the writing job. It's possible that there are many writing + * jobs running at the same time, and the returned {@link DataSourceV2Writer} should + * use this job id to distinguish itself with writers of other jobs. + * @param schema the schema of the data to be written. + * @param mode the save mode which determines what to do when the data are already in this data + * source, please refer to {@link SaveMode} for more details. + * @param options the options for the returned data source writer, which is an immutable + * case-insensitive string-to-string map. + */ + Optional createWriter( + String jobId, StructType schema, SaveMode mode, DataSourceV2Options options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java new file mode 100644 index 0000000000000..8d8e33633fb0d --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.WriteSupport; +import org.apache.spark.sql.types.StructType; + +/** + * A data source writer that is returned by + * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}. + * It can mix in various writing optimization interfaces to speed up the data saving. The actual + * writing logic is delegated to {@link DataWriter}. + * + * The writing procedure is: + * 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the + * partitions of the input data(RDD). + * 2. For each partition, create the data writer, and write the data of the partition with this + * writer. If all the data are written successfully, call {@link DataWriter#commit()}. If + * exception happens during the writing, call {@link DataWriter#abort()}. + * 3. If all writers are successfully committed, call {@link #commit(WriterCommitMessage[])}. If + * some writers are aborted, or the job failed with an unknown reason, call + * {@link #abort(WriterCommitMessage[])}. + * + * Spark won't retry failed writing jobs, users should do it manually in their Spark applications if + * they want to retry. + * + * Please refer to the document of commit/abort methods for detailed specifications. + * + * Note that, this interface provides a protocol between Spark and data sources for transactional + * data writing, but the transaction here is Spark-level transaction, which may not be the + * underlying storage transaction. For example, Spark successfully writes data to a Cassandra data + * source, but Cassandra may need some more time to reach consistency at storage level. + */ +@InterfaceStability.Evolving +public interface DataSourceV2Writer { + + /** + * Creates a writer factory which will be serialized and sent to executors. + */ + DataWriterFactory createWriterFactory(); + + /** + * Commits this writing job with a list of commit messages. The commit messages are collected from + * successful data writers and are produced by {@link DataWriter#commit()}. If this method + * fails(throw exception), this writing job is considered to be failed, and + * {@link #abort(WriterCommitMessage[])} will be called. The written data should only be visible + * to data source readers if this method succeeds. + * + * Note that, one partition may have multiple committed data writers because of speculative tasks. + * Spark will pick the first successful one and get its commit message. Implementations should be + * aware of this and handle it correctly, e.g., have a mechanism to make sure only one data writer + * can commit successfully, or have a way to clean up the data of already-committed writers. + */ + void commit(WriterCommitMessage[] messages); + + /** + * Aborts this writing job because some data writers are failed to write the records and aborted, + * or the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} + * fails. If this method fails(throw exception), the underlying data source may have garbage that + * need to be cleaned manually, but these garbage should not be visible to data source readers. + * + * Unless the abort is triggered by the failure of commit, the given messages should have some + * null slots as there maybe only a few data writers that are committed before the abort + * happens, or some data writers were committed but their commit messages haven't reached the + * driver when the abort is triggered. So this is just a "best effort" for data sources to + * clean up the data left by data writers. + */ + void abort(WriterCommitMessage[] messages); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java new file mode 100644 index 0000000000000..14261419af6f6 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A data writer returned by {@link DataWriterFactory#createWriter(int, int)} and is + * responsible for writing data for an input RDD partition. + * + * One Spark task has one exclusive data writer, so there is no thread-safe concern. + * + * {@link #write(Object)} is called for each record in the input RDD partition. If one record fails + * the {@link #write(Object)}, {@link #abort()} is called afterwards and the remaining records will + * not be processed. If all records are successfully written, {@link #commit()} is called. + * + * If this data writer succeeds(all records are successfully written and {@link #commit()} + * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to + * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} with commit messages from other data + * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an + * exception will be sent to the driver side, and Spark will retry this writing task for some times, + * each time {@link DataWriterFactory#createWriter(int, int)} gets a different `attemptNumber`, + * and finally call {@link DataSourceV2Writer#abort(WriterCommitMessage[])} if all retry fail. + * + * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task + * takes too long to finish. Different from retried tasks, which are launched one by one after the + * previous one fails, speculative tasks are running simultaneously. It's possible that one input + * RDD partition has multiple data writers with different `attemptNumber` running at the same time, + * and data sources should guarantee that these data writers don't conflict and can work together. + * Implementations can coordinate with driver during {@link #commit()} to make sure only one of + * these data writers can commit successfully. Or implementations can allow all of them to commit + * successfully, and have a way to revert committed data writers without the commit message, because + * Spark only accepts the commit message that arrives first and ignore others. + * + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data + * source writers, or {@link org.apache.spark.sql.catalyst.InternalRow} for data source writers + * that mix in {@link SupportsWriteInternalRow}. + */ +@InterfaceStability.Evolving +public interface DataWriter { + + /** + * Writes one record. + * + * If this method fails(throw exception), {@link #abort()} will be called and this data writer is + * considered to be failed. + */ + void write(T record); + + /** + * Commits this writer after all records are written successfully, returns a commit message which + * will be send back to driver side and pass to + * {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. + * + * The written data should only be visible to data source readers after + * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} succeeds, which means this method + * should still "hide" the written data and ask the {@link DataSourceV2Writer} at driver side to + * do the final commitment via {@link WriterCommitMessage}. + * + * If this method fails(throw exception), {@link #abort()} will be called and this data writer is + * considered to be failed. + */ + WriterCommitMessage commit(); + + /** + * Aborts this writer if it is failed. Implementations should clean up the data for already + * written records. + * + * This method will only be called if there is one record failed to write, or {@link #commit()} + * failed. + * + * If this method fails(throw exception), the underlying data source may have garbage that need + * to be cleaned by {@link DataSourceV2Writer#abort(WriterCommitMessage[])} or manually, but + * these garbage should not be visible to data source readers. + */ + void abort(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java new file mode 100644 index 0000000000000..f812d102bda1a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import java.io.Serializable; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A factory of {@link DataWriter} returned by {@link DataSourceV2Writer#createWriterFactory()}, + * which is responsible for creating and initializing the actual data writer at executor side. + * + * Note that, the writer factory will be serialized and sent to executors, then the data writer + * will be created on executors and do the actual writing. So {@link DataWriterFactory} must be + * serializable and {@link DataWriter} doesn't need to be. + */ +@InterfaceStability.Evolving +public interface DataWriterFactory extends Serializable { + + /** + * Returns a data writer to do the actual writing work. + * + * @param partitionId A unique id of the RDD partition that the returned writer will process. + * Usually Spark processes many RDD partitions at the same time, + * implementations should use the partition id to distinguish writers for + * different partitions. + * @param attemptNumber Spark may launch multiple tasks with the same task id. For example, a task + * failed, Spark launches a new task wth the same task id but different + * attempt number. Or a task is too slow, Spark launches new tasks wth the + * same task id but different attempt number, which means there are multiple + * tasks with the same task id running at the same time. Implementations can + * use this attempt number to distinguish writers of different task attempts. + */ + DataWriter createWriter(int partitionId, int attemptNumber); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java new file mode 100644 index 0000000000000..a8e95901f3b07 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; + +/** + * A mix-in interface for {@link DataSourceV2Writer}. Data source writers can implement this + * interface to write {@link InternalRow} directly and avoid the row conversion at Spark side. + * This is an experimental and unstable interface, as {@link InternalRow} is not public and may get + * changed in the future Spark versions. + */ + +@InterfaceStability.Evolving +@Experimental +@InterfaceStability.Unstable +public interface SupportsWriteInternalRow extends DataSourceV2Writer { + + @Override + default DataWriterFactory createWriterFactory() { + throw new IllegalStateException( + "createWriterFactory should not be called with SupportsWriteInternalRow."); + } + + DataWriterFactory createInternalRowWriterFactory(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java new file mode 100644 index 0000000000000..082d6b5dc409f --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import java.io.Serializable; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side + * as the input parameter of {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. + * + * This is an empty interface, data sources should define their own message class and use it in + * their {@link DataWriter#commit()} and {@link DataSourceV2Writer#commit(WriterCommitMessage[])} + * implementations. + */ +@InterfaceStability.Evolving +public interface WriterCommitMessage extends Serializable {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index c9e45436ed42f..8d95b24c00619 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql -import java.util.{Locale, Properties} +import java.text.SimpleDateFormat +import java.util.{Date, Locale, Properties, UUID} import scala.collection.JavaConverters._ @@ -29,7 +30,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options, WriteSupport} import org.apache.spark.sql.types.StructType /** @@ -231,12 +234,33 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { assertNotBucketed("save") - runCommand(df.sparkSession, "save") { - DataSource( - sparkSession = df.sparkSession, - className = source, - partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) + val cls = DataSource.lookupDataSource(source) + if (classOf[DataSourceV2].isAssignableFrom(cls)) { + cls.newInstance() match { + case ds: WriteSupport => + val options = new DataSourceV2Options(extraOptions.asJava) + // Using a timestamp and a random UUID to distinguish different writing jobs. This is good + // enough as there won't be tons of writing jobs created at the same second. + val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) + .format(new Date()) + "-" + UUID.randomUUID() + val writer = ds.createWriter(jobId, df.logicalPlan.schema, mode, options) + if (writer.isPresent) { + runCommand(df.sparkSession, "save") { + WriteToDataSourceV2(writer.get(), df.logicalPlan) + } + } + + case _ => throw new AnalysisException(s"$cls does not support data writing.") + } + } else { + // Code path for data source v1. + runCommand(df.sparkSession, "save") { + DataSource( + sparkSession = df.sparkSession, + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index f2cda002245e8..df5b524485f54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -18,20 +18,17 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.Strategy -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { - // TODO: write path override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case DataSourceV2Relation(output, reader) => DataSourceV2ScanExec(output, reader) :: Nil + case WriteToDataSourceV2(writer, query) => + WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala new file mode 100644 index 0000000000000..92c1e1f4a3383 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + +/** + * The logical plan for writing data into data source v2. + */ +case class WriteToDataSourceV2(writer: DataSourceV2Writer, query: LogicalPlan) extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Nil +} + +/** + * The physical plan for writing data into data source v2. + */ +case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) extends SparkPlan { + override def children: Seq[SparkPlan] = Seq(query) + override def output: Seq[Attribute] = Nil + + override protected def doExecute(): RDD[InternalRow] = { + val writeTask = writer match { + case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() + case _ => new RowToInternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) + } + + val rdd = query.execute() + val messages = new Array[WriterCommitMessage](rdd.partitions.length) + + logInfo(s"Start processing data source writer: $writer. " + + s"The input RDD has ${messages.length} partitions.") + + try { + sparkContext.runJob( + rdd, + (context: TaskContext, iter: Iterator[InternalRow]) => + DataWritingSparkTask.run(writeTask, context, iter), + rdd.partitions.indices, + (index, message: WriterCommitMessage) => messages(index) = message + ) + + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") + } catch { + case cause: Throwable => + logError(s"Data source writer $writer is aborting.") + try { + writer.abort(messages) + } catch { + case t: Throwable => + logError(s"Data source writer $writer failed to abort.") + cause.addSuppressed(t) + throw new SparkException("Writing job failed.", cause) + } + logError(s"Data source writer $writer aborted.") + throw new SparkException("Writing job aborted.", cause) + } + + sparkContext.emptyRDD + } +} + +object DataWritingSparkTask extends Logging { + def run( + writeTask: DataWriterFactory[InternalRow], + context: TaskContext, + iter: Iterator[InternalRow]): WriterCommitMessage = { + val dataWriter = writeTask.createWriter(context.partitionId(), context.attemptNumber()) + + // write the data and commit this writer. + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + iter.foreach(dataWriter.write) + logInfo(s"Writer for partition ${context.partitionId()} is committing.") + val msg = dataWriter.commit() + logInfo(s"Writer for partition ${context.partitionId()} committed.") + msg + })(catchBlock = { + // If there is an error, abort this writer + logError(s"Writer for partition ${context.partitionId()} is aborting.") + dataWriter.abort() + logError(s"Writer for partition ${context.partitionId()} aborted.") + }) + } +} + +class RowToInternalRowDataWriterFactory( + rowWriterFactory: DataWriterFactory[Row], + schema: StructType) extends DataWriterFactory[InternalRow] { + + override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + new RowToInternalRowDataWriter( + rowWriterFactory.createWriter(partitionId, attemptNumber), + RowEncoder.apply(schema).resolveAndBind()) + } +} + +class RowToInternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row]) + extends DataWriter[InternalRow] { + + override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record)) + + override def commit(): WriterCommitMessage = rowWriter.commit() + + override def abort(): Unit = rowWriter.abort() +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index f238e565dc2fc..092702a1d5173 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -21,6 +21,7 @@ import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ +import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.sources.{Filter, GreaterThan} @@ -80,6 +81,74 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } } + + test("simple writable data source") { + // TODO: java implementation. + Seq(classOf[SimpleWritableDataSource]).foreach { cls => + withTempPath { file => + val path = file.getCanonicalPath + assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) + + spark.range(10).select('id, -'id).write.format(cls.getName) + .option("path", path).save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(10).select('id, -'id)) + + // test with different save modes + spark.range(10).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("append").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(10).union(spark.range(10)).select('id, -'id)) + + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("overwrite").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(5).select('id, -'id)) + + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("ignore").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(5).select('id, -'id)) + + val e = intercept[Exception] { + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("error").save() + } + assert(e.getMessage.contains("data already exists")) + + // test transaction + val failingUdf = org.apache.spark.sql.functions.udf { + var count = 0 + (id: Long) => { + if (count > 5) { + throw new RuntimeException("testing error") + } + count += 1 + id + } + } + // this input data will fail to read middle way. + val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i) + val e2 = intercept[SparkException] { + input.write.format(cls.getName).option("path", path).mode("overwrite").save() + } + assert(e2.getMessage.contains("Writing job aborted")) + // make sure we don't have partial data. + assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) + + // test internal row writer + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).option("internal", "true").mode("overwrite").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(5).select('id, -'id)) + } + } + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala new file mode 100644 index 0000000000000..6fb60f4d848d7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2 + +import java.io.{BufferedReader, InputStreamReader, IOException} +import java.text.SimpleDateFormat +import java.util.{Collections, Date, List => JList, Locale, Optional, UUID} + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataSourceV2Reader, ReadTask} +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.SerializableConfiguration + +/** + * A HDFS based transactional writable data source. + * Each task writes data to `target/_temporary/jobId/$jobId-$partitionId-$attemptNumber`. + * Each job moves files from `target/_temporary/jobId/` to `target`. + */ +class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteSupport { + + private val schema = new StructType().add("i", "long").add("j", "long") + + class Reader(path: String, conf: Configuration) extends DataSourceV2Reader { + override def readSchema(): StructType = schema + + override def createReadTasks(): JList[ReadTask[Row]] = { + val dataPath = new Path(path) + val fs = dataPath.getFileSystem(conf) + if (fs.exists(dataPath)) { + fs.listStatus(dataPath).filterNot { status => + val name = status.getPath.getName + name.startsWith("_") || name.startsWith(".") + }.map { f => + val serializableConf = new SerializableConfiguration(conf) + new SimpleCSVReadTask(f.getPath.toUri.toString, serializableConf): ReadTask[Row] + }.toList.asJava + } else { + Collections.emptyList() + } + } + } + + class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceV2Writer { + override def createWriterFactory(): DataWriterFactory[Row] = { + new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) + } + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + val finalPath = new Path(path) + val jobPath = new Path(new Path(finalPath, "_temporary"), jobId) + val fs = jobPath.getFileSystem(conf) + try { + for (file <- fs.listStatus(jobPath).map(_.getPath)) { + val dest = new Path(finalPath, file.getName) + if(!fs.rename(file, dest)) { + throw new IOException(s"failed to rename($file, $dest)") + } + } + } finally { + fs.delete(jobPath, true) + } + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = { + val jobPath = new Path(new Path(path, "_temporary"), jobId) + val fs = jobPath.getFileSystem(conf) + fs.delete(jobPath, true) + } + } + + class InternalRowWriter(jobId: String, path: String, conf: Configuration) + extends Writer(jobId, path, conf) with SupportsWriteInternalRow { + + override def createWriterFactory(): DataWriterFactory[Row] = { + throw new IllegalArgumentException("not expected!") + } + + override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { + new InternalRowCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = { + val path = new Path(options.get("path").get()) + val conf = SparkContext.getActive.get.hadoopConfiguration + new Reader(path.toUri.toString, conf) + } + + override def createWriter( + jobId: String, + schema: StructType, + mode: SaveMode, + options: DataSourceV2Options): Optional[DataSourceV2Writer] = { + assert(DataType.equalsStructurally(schema.asNullable, this.schema.asNullable)) + assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false)) + + val path = new Path(options.get("path").get()) + val internal = options.get("internal").isPresent + val conf = SparkContext.getActive.get.hadoopConfiguration + val fs = path.getFileSystem(conf) + + if (mode == SaveMode.ErrorIfExists) { + if (fs.exists(path)) { + throw new RuntimeException("data already exists.") + } + } + if (mode == SaveMode.Ignore) { + if (fs.exists(path)) { + return Optional.empty() + } + } + if (mode == SaveMode.Overwrite) { + fs.delete(path, true) + } + + Optional.of(createWriter(jobId, path, conf, internal)) + } + + private def createWriter( + jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceV2Writer = { + val pathStr = path.toUri.toString + if (internal) { + new InternalRowWriter(jobId, pathStr, conf) + } else { + new Writer(jobId, pathStr, conf) + } + } +} + +class SimpleCSVReadTask(path: String, conf: SerializableConfiguration) + extends ReadTask[Row] with DataReader[Row] { + + @transient private var lines: Iterator[String] = _ + @transient private var currentLine: String = _ + @transient private var inputStream: FSDataInputStream = _ + + override def createReader(): DataReader[Row] = { + val filePath = new Path(path) + val fs = filePath.getFileSystem(conf.value) + inputStream = fs.open(filePath) + lines = new BufferedReader(new InputStreamReader(inputStream)) + .lines().iterator().asScala + this + } + + override def next(): Boolean = { + if (lines.hasNext) { + currentLine = lines.next() + true + } else { + false + } + } + + override def get(): Row = Row(currentLine.split(",").map(_.trim.toLong): _*) + + override def close(): Unit = { + inputStream.close() + } +} + +class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) + extends DataWriterFactory[Row] { + + override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + val jobPath = new Path(new Path(path, "_temporary"), jobId) + val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val fs = filePath.getFileSystem(conf.value) + new SimpleCSVDataWriter(fs, filePath) + } +} + +class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] { + + private val out = fs.create(file) + + override def write(record: Row): Unit = { + out.writeBytes(s"${record.getLong(0)},${record.getLong(1)}\n") + } + + override def commit(): WriterCommitMessage = { + out.close() + null + } + + override def abort(): Unit = { + try { + out.close() + } finally { + fs.delete(file, false) + } + } +} + +class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) + extends DataWriterFactory[InternalRow] { + + override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + val jobPath = new Path(new Path(path, "_temporary"), jobId) + val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val fs = filePath.getFileSystem(conf.value) + new InternalRowCSVDataWriter(fs, filePath) + } +} + +class InternalRowCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] { + + private val out = fs.create(file) + + override def write(record: InternalRow): Unit = { + out.writeBytes(s"${record.getLong(0)},${record.getLong(1)}\n") + } + + override def commit(): WriterCommitMessage = { + out.close() + null + } + + override def abort(): Unit = { + try { + out.close() + } finally { + fs.delete(file, false) + } + } +} From b84f61cd79a365edd4cc893a1de416c628d9906b Mon Sep 17 00:00:00 2001 From: Eric Perry Date: Thu, 19 Oct 2017 23:57:41 -0700 Subject: [PATCH 736/779] [SQL] Mark strategies with override for clarity. ## What changes were proposed in this pull request? This is a very trivial PR, simply marking `strategies` in `SparkPlanner` with the `override` keyword for clarity since it is overriding `strategies` in `QueryPlanner` two levels up in the class hierarchy. I was reading through the code to learn a bit and got stuck on this fact for a little while, so I figured this may be helpful so that another developer new to the project doesn't get stuck where I was. I did not make a JIRA ticket for this because it is so trivial, but I'm happy to do so to adhere to the contribution guidelines if required. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Eric Perry Closes #19537 from ericjperry/override-strategies. --- .../scala/org/apache/spark/sql/execution/SparkPlanner.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index b143d44eae17b..74048871f8d42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -33,7 +33,7 @@ class SparkPlanner( def numPartitions: Int = conf.numShufflePartitions - def strategies: Seq[Strategy] = + override def strategies: Seq[Strategy] = experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ ( DataSourceV2Strategy :: From 673876b7eadc6f382afc26fc654b0e7916c9ac5c Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 20 Oct 2017 08:28:05 +0100 Subject: [PATCH 737/779] [SPARK-22309][ML] Remove unused param in `LDAModel.getTopicDistributionMethod` ## What changes were proposed in this pull request? Remove unused param in `LDAModel.getTopicDistributionMethod` ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #19530 from zhengruifeng/lda_bc. --- mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala | 2 +- .../main/scala/org/apache/spark/mllib/clustering/LDAModel.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 3da29b1c816b1..4bab670cc159f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -458,7 +458,7 @@ abstract class LDAModel private[ml] ( if ($(topicDistributionCol).nonEmpty) { // TODO: Make the transformer natively in ml framework to avoid extra conversion. - val transformer = oldLocalModel.getTopicDistributionMethod(sparkSession.sparkContext) + val transformer = oldLocalModel.getTopicDistributionMethod val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML } dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 4ab420058f33d..b8a6e94248421 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -371,7 +371,7 @@ class LocalLDAModel private[spark] ( /** * Get a method usable as a UDF for `topicDistributions()` */ - private[spark] def getTopicDistributionMethod(sc: SparkContext): Vector => Vector = { + private[spark] def getTopicDistributionMethod: Vector => Vector = { val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t) val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape From e2fea8cd6058a807ff4841b496ea345ff0553044 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Fri, 20 Oct 2017 09:43:46 +0100 Subject: [PATCH 738/779] [CORE][DOC] Add event log conf. ## What changes were proposed in this pull request? Event Log Server has a total of five configuration parameters, and now the description of the other two configuration parameters on the doc, user-friendly access and use. ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #19242 from guoxiaolongzte/addEventLogConf. --- docs/configuration.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 7b9e16a382449..d3c358bb74173 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -748,6 +748,20 @@ Apart from these, the following properties are also available, and may be useful finished. + + spark.eventLog.overwrite + false + + Whether to overwrite any existing files. + + + + spark.eventLog.buffer.kb + 100k + + Buffer size in KB to use when writing to output streams. + + spark.ui.enabled true From 16c9cc68c5a70fd50e214f6deba591f0a9ae5cca Mon Sep 17 00:00:00 2001 From: CenYuhai Date: Fri, 20 Oct 2017 09:27:39 -0700 Subject: [PATCH 739/779] [SPARK-21055][SQL] replace grouping__id with grouping_id() ## What changes were proposed in this pull request? spark does not support grouping__id, it has grouping_id() instead. But it is not convenient for hive user to change to spark-sql so this pr is to replace grouping__id with grouping_id() hive user need not to alter their scripts ## How was this patch tested? test with SQLQuerySuite.scala Author: CenYuhai Closes #18270 from cenyuhai/SPARK-21055. --- .../sql/catalyst/analysis/Analyzer.scala | 15 +-- .../sql-tests/inputs/group-analytics.sql | 6 +- .../sql-tests/results/group-analytics.sql.out | 43 ++++--- .../sql/hive/execution/SQLQuerySuite.scala | 110 ++++++++++++++++++ 4 files changed, 148 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8edf575db7969..d6a962a14dc9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects} import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.util.toPrettySQL @@ -293,12 +293,6 @@ class Analyzer( Seq(Seq.empty) } - private def hasGroupingAttribute(expr: Expression): Boolean = { - expr.collectFirst { - case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => u - }.isDefined - } - private[analysis] def hasGroupingFunction(e: Expression): Boolean = { e.collectFirst { case g: Grouping => g @@ -452,9 +446,6 @@ class Analyzer( // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case a if !a.childrenResolved => a // be sure all of the children are resolved. - case p if p.expressions.exists(hasGroupingAttribute) => - failAnalysis( - s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead") // Ensure group by expressions and aggregate expressions have been resolved. case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) @@ -1174,6 +1165,10 @@ class Analyzer( case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. + case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => + withPosition(u) { + Alias(GroupingID(Nil), VirtualColumn.hiveGroupingIdName)() + } case u @ UnresolvedGenerator(name, children) => withPosition(u) { catalog.lookupFunction(name, children) match { diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql index 8aff4cb524199..9721f8c60ebce 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql @@ -38,11 +38,11 @@ SELECT course, year, GROUPING(course), GROUPING(year), GROUPING_ID(course, year) GROUP BY CUBE(course, year); SELECT course, year, GROUPING(course) FROM courseSales GROUP BY course, year; SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY course, year; -SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year); +SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, course, year; -- GROUPING/GROUPING_ID in having clause SELECT course, year FROM courseSales GROUP BY CUBE(course, year) -HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0; +HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 ORDER BY course, year; SELECT course, year FROM courseSales GROUP BY course, year HAVING GROUPING(course) > 0; SELECT course, year FROM courseSales GROUP BY course, year HAVING GROUPING_ID(course) > 0; SELECT course, year FROM courseSales GROUP BY CUBE(course, year) HAVING grouping__id > 0; @@ -54,7 +54,7 @@ SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(co ORDER BY GROUPING(course), GROUPING(year), course, year; SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING(course); SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING_ID(course); -SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id; +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, course, year; -- Aliases in SELECT could be used in ROLLUP/CUBE/GROUPING SETS SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2); diff --git a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out index ce7a16a4d0c81..3439a05727f95 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out @@ -223,22 +223,29 @@ grouping_id() can only be used with GroupingSets/Cube/Rollup; -- !query 16 -SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year) +SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, course, year -- !query 16 schema -struct<> +struct -- !query 16 output -org.apache.spark.sql.AnalysisException -grouping__id is deprecated; use grouping_id() instead; +Java 2012 0 +Java 2013 0 +dotNET 2012 0 +dotNET 2013 0 +Java NULL 1 +dotNET NULL 1 +NULL 2012 2 +NULL 2013 2 +NULL NULL 3 -- !query 17 SELECT course, year FROM courseSales GROUP BY CUBE(course, year) -HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 +HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 ORDER BY course, year -- !query 17 schema struct -- !query 17 output -Java NULL NULL NULL +Java NULL dotNET NULL @@ -263,10 +270,13 @@ grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; -- !query 20 SELECT course, year FROM courseSales GROUP BY CUBE(course, year) HAVING grouping__id > 0 -- !query 20 schema -struct<> +struct -- !query 20 output -org.apache.spark.sql.AnalysisException -grouping__id is deprecated; use grouping_id() instead; +Java NULL +NULL 2012 +NULL 2013 +NULL NULL +dotNET NULL -- !query 21 @@ -322,12 +332,19 @@ grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; -- !query 25 -SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, course, year -- !query 25 schema -struct<> +struct -- !query 25 output -org.apache.spark.sql.AnalysisException -grouping__id is deprecated; use grouping_id() instead; +Java 2012 +Java 2013 +dotNET 2012 +dotNET 2013 +Java NULL +dotNET NULL +NULL 2012 +NULL 2013 +NULL NULL -- !query 26 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 60935c3e85c43..2476a440ad82c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1388,6 +1388,19 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3))) } + test("SPARK-21055 replace grouping__id: Wrong Result for Rollup #1") { + checkAnswer(sql( + "SELECT count(*) AS cnt, key % 5, grouping__id FROM src GROUP BY key%5 WITH ROLLUP"), + Seq( + (113, 3, 0), + (91, 0, 0), + (500, null, 1), + (84, 1, 0), + (105, 2, 0), + (107, 4, 0) + ).map(i => Row(i._1, i._2, i._3))) + } + test("SPARK-8976 Wrong Result for Rollup #2") { checkAnswer(sql( """ @@ -1409,6 +1422,27 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("SPARK-21055 replace grouping__id: Wrong Result for Rollup #2") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 + |FROM src GROUP BY key%5, key-5 + |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, 0, 5, 0), + (1, 0, 15, 0), + (1, 0, 25, 0), + (1, 0, 60, 0), + (1, 0, 75, 0), + (1, 0, 80, 0), + (1, 0, 100, 0), + (1, 0, 140, 0), + (1, 0, 145, 0), + (1, 0, 150, 0) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + test("SPARK-8976 Wrong Result for Rollup #3") { checkAnswer(sql( """ @@ -1430,6 +1464,27 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("SPARK-21055 replace grouping__id: Wrong Result for Rollup #3") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, 0, 5, 0), + (1, 0, 15, 0), + (1, 0, 25, 0), + (1, 0, 60, 0), + (1, 0, 75, 0), + (1, 0, 80, 0), + (1, 0, 100, 0), + (1, 0, 140, 0), + (1, 0, 145, 0), + (1, 0, 150, 0) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + test("SPARK-8976 Wrong Result for CUBE #1") { checkAnswer(sql( "SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH CUBE"), @@ -1443,6 +1498,19 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3))) } + test("SPARK-21055 replace grouping__id: Wrong Result for CUBE #1") { + checkAnswer(sql( + "SELECT count(*) AS cnt, key % 5, grouping__id FROM src GROUP BY key%5 WITH CUBE"), + Seq( + (113, 3, 0), + (91, 0, 0), + (500, null, 1), + (84, 1, 0), + (105, 2, 0), + (107, 4, 0) + ).map(i => Row(i._1, i._2, i._3))) + } + test("SPARK-8976 Wrong Result for CUBE #2") { checkAnswer(sql( """ @@ -1464,6 +1532,27 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("SPARK-21055 replace grouping__id: Wrong Result for CUBE #2") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, null, -3, 2), + (1, null, -1, 2), + (1, null, 3, 2), + (1, null, 4, 2), + (1, null, 5, 2), + (1, null, 6, 2), + (1, null, 12, 2), + (1, null, 14, 2), + (1, null, 15, 2), + (1, null, 22, 2) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + test("SPARK-8976 Wrong Result for GroupingSet") { checkAnswer(sql( """ @@ -1485,6 +1574,27 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("SPARK-21055 replace grouping__id: Wrong Result for GroupingSet") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, null, -3, 2), + (1, null, -1, 2), + (1, null, 3, 2), + (1, null, 4, 2), + (1, null, 5, 2), + (1, null, 6, 2), + (1, null, 12, 2), + (1, null, 14, 2), + (1, null, 15, 2), + (1, null, 22, 2) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + ignore("SPARK-10562: partition by column with mixed case name") { withTable("tbl10562") { val df = Seq(2012 -> "a").toDF("Year", "val") From 568763bafb7acfcf5921d6492034d1f6f87875e2 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 20 Oct 2017 12:32:45 -0700 Subject: [PATCH 740/779] [INFRA] Close stale PRs. Closes #19541 Closes #19542 From b8624b06e5d531ebc14acb05da286f96f4bc9515 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 20 Oct 2017 12:44:30 -0700 Subject: [PATCH 741/779] [SPARK-20396][SQL][PYSPARK][FOLLOW-UP] groupby().apply() with pandas udf ## What changes were proposed in this pull request? This is a follow-up of #18732. This pr modifies `GroupedData.apply()` method to convert pandas udf to grouped udf implicitly. ## How was this patch tested? Exisiting tests. Author: Takuya UESHIN Closes #19517 from ueshin/issues/SPARK-20396/fup2. --- .../spark/api/python/PythonRunner.scala | 1 + python/pyspark/serializers.py | 1 + python/pyspark/sql/functions.py | 33 ++++++++++------ python/pyspark/sql/group.py | 14 ++++--- python/pyspark/sql/tests.py | 37 ++++++++++++++++++ python/pyspark/worker.py | 39 ++++++++----------- .../logical/pythonLogicalOperators.scala | 9 +++-- .../spark/sql/RelationalGroupedDataset.scala | 7 ++-- .../execution/python/ExtractPythonUDFs.scala | 6 ++- .../python/FlatMapGroupsInPandasExec.scala | 2 +- .../sql/execution/python/PythonUDF.scala | 2 +- .../python/UserDefinedPythonFunction.scala | 13 ++++++- .../python/BatchEvalPythonExecSuite.scala | 2 +- 13 files changed, 114 insertions(+), 52 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 3688a149443c1..d417303bb147d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -36,6 +36,7 @@ private[spark] object PythonEvalType { val NON_UDF = 0 val SQL_BATCHED_UDF = 1 val SQL_PANDAS_UDF = 2 + val SQL_PANDAS_GROUPED_UDF = 3 } /** diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ad18bd0c81eaa..a0adeed994456 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -86,6 +86,7 @@ class PythonEvalType(object): NON_UDF = 0 SQL_BATCHED_UDF = 1 SQL_PANDAS_UDF = 2 + SQL_PANDAS_GROUPED_UDF = 3 class Serializer(object): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9bc12c3b7a162..9bc374b93a433 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2038,13 +2038,22 @@ def _wrap_function(sc, func, returnType): sc.pythonVer, broadcast_vars, sc._javaAccumulator) +class PythonUdfType(object): + # row-at-a-time UDFs + NORMAL_UDF = 0 + # scalar vectorized UDFs + PANDAS_UDF = 1 + # grouped vectorized UDFs + PANDAS_GROUPED_UDF = 2 + + class UserDefinedFunction(object): """ User defined function in Python .. versionadded:: 1.3 """ - def __init__(self, func, returnType, name=None, vectorized=False): + def __init__(self, func, returnType, name=None, pythonUdfType=PythonUdfType.NORMAL_UDF): if not callable(func): raise TypeError( "Not a function or callable (__call__ is not defined): " @@ -2058,7 +2067,7 @@ def __init__(self, func, returnType, name=None, vectorized=False): self._name = name or ( func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) - self.vectorized = vectorized + self.pythonUdfType = pythonUdfType @property def returnType(self): @@ -2090,7 +2099,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self.vectorized) + self._name, wrapped_func, jdt, self.pythonUdfType) return judf def __call__(self, *cols): @@ -2121,15 +2130,15 @@ def wrapper(*args): wrapper.func = self.func wrapper.returnType = self.returnType - wrapper.vectorized = self.vectorized + wrapper.pythonUdfType = self.pythonUdfType return wrapper -def _create_udf(f, returnType, vectorized): +def _create_udf(f, returnType, pythonUdfType): - def _udf(f, returnType=StringType(), vectorized=vectorized): - if vectorized: + def _udf(f, returnType=StringType(), pythonUdfType=pythonUdfType): + if pythonUdfType == PythonUdfType.PANDAS_UDF: import inspect argspec = inspect.getargspec(f) if len(argspec.args) == 0 and argspec.varargs is None: @@ -2137,7 +2146,7 @@ def _udf(f, returnType=StringType(), vectorized=vectorized): "0-arg pandas_udfs are not supported. " "Instead, create a 1-arg pandas_udf and ignore the arg in your function." ) - udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized) + udf_obj = UserDefinedFunction(f, returnType, pythonUdfType=pythonUdfType) return udf_obj._wrapped() # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf @@ -2145,9 +2154,9 @@ def _udf(f, returnType=StringType(), vectorized=vectorized): # If DataType has been passed as a positional argument # for decorator use it as a returnType return_type = f or returnType - return functools.partial(_udf, returnType=return_type, vectorized=vectorized) + return functools.partial(_udf, returnType=return_type, pythonUdfType=pythonUdfType) else: - return _udf(f=f, returnType=returnType, vectorized=vectorized) + return _udf(f=f, returnType=returnType, pythonUdfType=pythonUdfType) @since(1.3) @@ -2181,7 +2190,7 @@ def udf(f=None, returnType=StringType()): | 8| JOHN DOE| 22| +----------+--------------+------------+ """ - return _create_udf(f, returnType=returnType, vectorized=False) + return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.NORMAL_UDF) @since(2.3) @@ -2252,7 +2261,7 @@ def pandas_udf(f=None, returnType=StringType()): .. note:: The user-defined function must be deterministic. """ - return _create_udf(f, returnType=returnType, vectorized=True) + return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.PANDAS_UDF) blacklist = ['map', 'since', 'ignore_unicode_prefix'] diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 817d0bc83bb77..e11388d604312 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -19,6 +19,7 @@ from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal from pyspark.sql.dataframe import DataFrame +from pyspark.sql.functions import PythonUdfType, UserDefinedFunction from pyspark.sql.types import * __all__ = ["GroupedData"] @@ -235,11 +236,13 @@ def apply(self, udf): .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` """ - from pyspark.sql.functions import pandas_udf + import inspect # Columns are special because hasattr always return True - if isinstance(udf, Column) or not hasattr(udf, 'func') or not udf.vectorized: - raise ValueError("The argument to apply must be a pandas_udf") + if isinstance(udf, Column) or not hasattr(udf, 'func') \ + or udf.pythonUdfType != PythonUdfType.PANDAS_UDF \ + or len(inspect.getargspec(udf.func).args) != 1: + raise ValueError("The argument to apply must be a 1-arg pandas_udf") if not isinstance(udf.returnType, StructType): raise ValueError("The returnType of the pandas_udf must be a StructType") @@ -268,8 +271,9 @@ def wrapped(*cols): return [(result[result.columns[i]], arrow_type) for i, arrow_type in enumerate(arrow_return_types)] - wrapped_udf_obj = pandas_udf(wrapped, returnType) - udf_column = wrapped_udf_obj(*[df[col] for col in df.columns]) + udf_obj = UserDefinedFunction( + wrapped, returnType, name=udf.__name__, pythonUdfType=PythonUdfType.PANDAS_GROUPED_UDF) + udf_column = udf_obj(*[df[col] for col in df.columns]) jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) return DataFrame(jdf, self.sql_ctx) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index bac2ef84ae7a7..685eebcafefba 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3383,6 +3383,15 @@ def test_vectorized_udf_varargs(self): res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_unsupported_types(self): + from pyspark.sql.functions import pandas_udf, col + schema = StructType([StructField("dt", DateType(), True)]) + df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema) + f = pandas_udf(lambda x: x, DateType()) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + df.select(f(col('dt'))).collect() + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedPySparkTestCase): @@ -3492,6 +3501,18 @@ def normalize(pdf): expected = expected.assign(norm=expected.norm.astype('float64')) self.assertFramesEqual(expected, result) + def test_datatype_string(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + foo_udf = pandas_udf( + lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + "id long, v int, v1 double, v2 long") + + result = df.groupby('id').apply(foo_udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) + self.assertFramesEqual(expected, result) + def test_wrong_return_type(self): from pyspark.sql.functions import pandas_udf df = self.data @@ -3517,9 +3538,25 @@ def test_wrong_args(self): df.groupby('id').apply(sum(df.v)) with self.assertRaisesRegexp(ValueError, 'pandas_udf'): df.groupby('id').apply(df.v + 1) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply( + pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply( + pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]))) with self.assertRaisesRegexp(ValueError, 'returnType'): df.groupby('id').apply(pandas_udf(lambda x: x, DoubleType())) + def test_unsupported_types(self): + from pyspark.sql.functions import pandas_udf, col + schema = StructType( + [StructField("id", LongType(), True), StructField("dt", DateType(), True)]) + df = self.spark.createDataFrame([(1, datetime.date(1970, 1, 1),)], schema=schema) + f = pandas_udf(lambda x: x, df.schema) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + df.groupby('id').apply(f).collect() + if __name__ == "__main__": from pyspark.sql.tests import * diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index eb6d48688dc0a..5e100e0a9a95d 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -32,7 +32,7 @@ from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer -from pyspark.sql.types import to_arrow_type, StructType +from pyspark.sql.types import to_arrow_type from pyspark import shuffle pickleSer = PickleSerializer() @@ -74,28 +74,19 @@ def wrap_udf(f, return_type): def wrap_pandas_udf(f, return_type): - # If the return_type is a StructType, it indicates this is a groupby apply udf, - # and has already been wrapped under apply(), otherwise, it's a vectorized column udf. - # We can distinguish these two by return type because in groupby apply, we always specify - # returnType as a StructType, and in vectorized column udf, StructType is not supported. - # - # TODO: Look into refactoring use of StructType to be more flexible for future pandas_udfs - if isinstance(return_type, StructType): - return lambda *a: f(*a) - else: - arrow_return_type = to_arrow_type(return_type) + arrow_return_type = to_arrow_type(return_type) - def verify_result_length(*a): - result = f(*a) - if not hasattr(result, "__len__"): - raise TypeError("Return type of the user-defined functon should be " - "Pandas.Series, but is {}".format(type(result))) - if len(result) != len(a[0]): - raise RuntimeError("Result vector from pandas_udf was not the required length: " - "expected %d, got %d" % (len(a[0]), len(result))) - return result + def verify_result_length(*a): + result = f(*a) + if not hasattr(result, "__len__"): + raise TypeError("Return type of the user-defined functon should be " + "Pandas.Series, but is {}".format(type(result))) + if len(result) != len(a[0]): + raise RuntimeError("Result vector from pandas_udf was not the required length: " + "expected %d, got %d" % (len(a[0]), len(result))) + return result - return lambda *a: (verify_result_length(*a), arrow_return_type) + return lambda *a: (verify_result_length(*a), arrow_return_type) def read_single_udf(pickleSer, infile, eval_type): @@ -111,6 +102,9 @@ def read_single_udf(pickleSer, infile, eval_type): # the last returnType will be the return type of UDF if eval_type == PythonEvalType.SQL_PANDAS_UDF: return arg_offsets, wrap_pandas_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: + # a groupby apply udf has already been wrapped under apply() + return arg_offsets, row_func else: return arg_offsets, wrap_udf(row_func, return_type) @@ -133,7 +127,8 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) - if eval_type == PythonEvalType.SQL_PANDAS_UDF: + if eval_type == PythonEvalType.SQL_PANDAS_UDF \ + or eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: ser = ArrowStreamPandasSerializer() else: ser = BatchedSerializer(PickleSerializer(), 100) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index 8abab24bc9b44..254687ec00880 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -24,10 +24,11 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expre * This is used by DataFrame.groupby().apply(). */ case class FlatMapGroupsInPandas( - groupingAttributes: Seq[Attribute], - functionExpr: Expression, - output: Seq[Attribute], - child: LogicalPlan) extends UnaryNode { + groupingAttributes: Seq[Attribute], + functionExpr: Expression, + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + /** * This is needed because output attributes are considered `references` when * passed through the constructor. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 33ec3a27110a8..6b45790d5ff6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.execution.python.PythonUDF +import org.apache.spark.sql.execution.python.{PythonUDF, PythonUdfType} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NumericType, StructType} @@ -437,7 +437,7 @@ class RelationalGroupedDataset protected[sql]( } /** - * Applies a vectorized python user-defined function to each group of data. + * Applies a grouped vectorized python user-defined function to each group of data. * The user-defined function defines a transformation: `pandas.DataFrame` -> `pandas.DataFrame`. * For each group, all elements in the group are passed as a `pandas.DataFrame` and the results * for all groups are combined into a new [[DataFrame]]. @@ -449,7 +449,8 @@ class RelationalGroupedDataset protected[sql]( * workers. */ private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { - require(expr.vectorized, "Must pass a vectorized python udf") + require(expr.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF, + "Must pass a grouped vectorized python udf") require(expr.dataType.isInstanceOf[StructType], "The returnType of the vectorized python udf must be a StructType") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index e3f952e221d53..d6825369f7378 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -137,11 +137,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { udf.references.subsetOf(child.outputSet) } if (validUdfs.nonEmpty) { + if (validUdfs.exists(_.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF)) { + throw new IllegalArgumentException("Can not use grouped vectorized UDFs") + } + val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() } - val evaluation = validUdfs.partition(_.vectorized) match { + val evaluation = validUdfs.partition(_.pythonUdfType == PythonUdfType.PANDAS_UDF) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index b996b5bb38ba5..5ed88ada428cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -94,7 +94,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) + PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema) .compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 84a6d9e5be59c..9c07c7638de57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -29,7 +29,7 @@ case class PythonUDF( func: PythonFunction, dataType: DataType, children: Seq[Expression], - vectorized: Boolean) + pythonUdfType: Int) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override def toString: String = s"$name(${children.mkString(", ")})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index a30a80acf5c23..b2fe6c300846a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -22,6 +22,15 @@ import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.DataType +private[spark] object PythonUdfType { + // row-at-a-time UDFs + val NORMAL_UDF = 0 + // scalar vectorized UDFs + val PANDAS_UDF = 1 + // grouped vectorized UDFs + val PANDAS_GROUPED_UDF = 2 +} + /** * A user-defined Python function. This is used by the Python API. */ @@ -29,10 +38,10 @@ case class UserDefinedPythonFunction( name: String, func: PythonFunction, dataType: DataType, - vectorized: Boolean) { + pythonUdfType: Int) { def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, func, dataType, e, vectorized) + PythonUDF(name, func, dataType, e, pythonUdfType) } /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 153e6e1f88c70..95b21fc9f16ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -109,4 +109,4 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction( name = "dummyUDF", func = new DummyUDF, dataType = BooleanType, - vectorized = false) + pythonUdfType = PythonUdfType.NORMAL_UDF) From d9f286d261c6ee9e8dcb46e78d4666318ea25af2 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Fri, 20 Oct 2017 20:58:55 -0700 Subject: [PATCH 742/779] [SPARK-22326][SQL] Remove unnecessary hashCode and equals methods ## What changes were proposed in this pull request? Plan equality should be computed by `canonicalized`, so we can remove unnecessary `hashCode` and `equals` methods. ## How was this patch tested? Existing tests. Author: Zhenhua Wang Closes #19539 from wzhfy/remove_equals. --- .../apache/spark/sql/catalyst/catalog/interface.scala | 11 ----------- .../sql/execution/datasources/LogicalRelation.scala | 11 ----------- 2 files changed, 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 975b084aa6188..1dbae4d37d8f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -22,8 +22,6 @@ import java.util.Date import scala.collection.mutable -import com.google.common.base.Objects - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation @@ -440,15 +438,6 @@ case class HiveTableRelation( def isPartitioned: Boolean = partitionCols.nonEmpty - override def equals(relation: Any): Boolean = relation match { - case other: HiveTableRelation => tableMeta == other.tableMeta && output == other.output - case _ => false - } - - override def hashCode(): Int = { - Objects.hashCode(tableMeta.identifier, output) - } - override lazy val canonicalized: HiveTableRelation = copy( tableMeta = tableMeta.copy( storage = CatalogStorageFormat.empty, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 17a61074d3b5c..3e98cb28453a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -34,17 +34,6 @@ case class LogicalRelation( override val isStreaming: Boolean) extends LeafNode with MultiInstanceRelation { - // Logical Relations are distinct if they have different output for the sake of transformations. - override def equals(other: Any): Boolean = other match { - case l @ LogicalRelation(otherRelation, _, _, isStreaming) => - relation == otherRelation && output == l.output && isStreaming == l.isStreaming - case _ => false - } - - override def hashCode: Int = { - com.google.common.base.Objects.hashCode(relation, output) - } - // Only care about relation when canonicalizing. override lazy val canonicalized: LogicalPlan = copy( output = output.map(QueryPlan.normalizeExprId(_, output)), From d8cada8d1d3fce979a4bc1f9879593206722a3b9 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 21 Oct 2017 10:05:45 -0700 Subject: [PATCH 743/779] [SPARK-20331][SQL][FOLLOW-UP] Add a SQLConf for enhanced Hive partition pruning predicate pushdown ## What changes were proposed in this pull request? This is a follow-up PR of https://github.com/apache/spark/pull/17633. This PR is to add a conf `spark.sql.hive.advancedPartitionPredicatePushdown.enabled`, which can be used to turn the enhancement off. ## How was this patch tested? Add a test case Author: gatorsmile Closes #19547 from gatorsmile/Spark20331FollowUp. --- .../apache/spark/sql/internal/SQLConf.scala | 10 +++++++ .../spark/sql/hive/client/HiveShim.scala | 29 ++++++++++++++++++ .../spark/sql/hive/client/FiltersSuite.scala | 30 +++++++++++++++---- 3 files changed, 64 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 618d4a0d6148a..4cfe53b2c115b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -173,6 +173,13 @@ object SQLConf { .intConf .createWithDefault(4) + val ADVANCED_PARTITION_PREDICATE_PUSHDOWN = + buildConf("spark.sql.hive.advancedPartitionPredicatePushdown.enabled") + .internal() + .doc("When true, advanced partition predicate pushdown into Hive metastore is enabled.") + .booleanConf + .createWithDefault(true) + val ENABLE_FALL_BACK_TO_HDFS_FOR_STATS = buildConf("spark.sql.statistics.fallBackToHdfs") .doc("If the table statistics are not available from table metadata enable fall back to hdfs." + @@ -1092,6 +1099,9 @@ class SQLConf extends Serializable with Logging { def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR) + def advancedPartitionPredicatePushdownEnabled: Boolean = + getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN) + def fallBackToHdfsForStatsEnabled: Boolean = getConf(ENABLE_FALL_BACK_TO_HDFS_FOR_STATS) def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index cde20da186acd..5c1ff2b76fdaa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -585,6 +585,35 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { * Unsupported predicates are skipped. */ def convertFilters(table: Table, filters: Seq[Expression]): String = { + if (SQLConf.get.advancedPartitionPredicatePushdownEnabled) { + convertComplexFilters(table, filters) + } else { + convertBasicFilters(table, filters) + } + } + + private def convertBasicFilters(table: Table, filters: Seq[Expression]): String = { + // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. + lazy val varcharKeys = table.getPartitionKeys.asScala + .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || + col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) + .map(col => col.getName).toSet + + filters.collect { + case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) => + s"${a.name} ${op.symbol} $v" + case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) => + s"$v ${op.symbol} ${a.name}" + case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType)) + if !varcharKeys.contains(a.name) => + s"""${a.name} ${op.symbol} ${quoteStringLiteral(v.toString)}""" + case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute) + if !varcharKeys.contains(a.name) => + s"""${quoteStringLiteral(v.toString)} ${op.symbol} ${a.name}""" + }.mkString(" and ") + } + + private def convertComplexFilters(table: Table, filters: Seq[Expression]): String = { // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. lazy val varcharKeys = table.getPartitionKeys.asScala .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index 031c1a5ec0ec3..19765695fbcb4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -26,13 +26,15 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** * A set of tests for the filter conversion logic used when pushing partition pruning into the * metastore */ -class FiltersSuite extends SparkFunSuite with Logging { +class FiltersSuite extends SparkFunSuite with Logging with PlanTest { private val shim = new Shim_v0_13 private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test") @@ -72,10 +74,28 @@ class FiltersSuite extends SparkFunSuite with Logging { private def filterTest(name: String, filters: Seq[Expression], result: String) = { test(name) { - val converted = shim.convertFilters(testTable, filters) - if (converted != result) { - fail( - s"Expected filters ${filters.mkString(",")} to convert to '$result' but got '$converted'") + withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> "true") { + val converted = shim.convertFilters(testTable, filters) + if (converted != result) { + fail(s"Expected ${filters.mkString(",")} to convert to '$result' but got '$converted'") + } + } + } + } + + test("turn on/off ADVANCED_PARTITION_PREDICATE_PUSHDOWN") { + import org.apache.spark.sql.catalyst.dsl.expressions._ + Seq(true, false).foreach { enabled => + withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> enabled.toString) { + val filters = + (Literal(1) === a("intcol", IntegerType) || + Literal(2) === a("intcol", IntegerType)) :: Nil + val converted = shim.convertFilters(testTable, filters) + if (enabled) { + assert(converted == "(1 = intcol or 2 = intcol)") + } else { + assert(converted.isEmpty) + } } } } From a763607e4fc24f4dc0f455b67a63acba5be1c80a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 21 Oct 2017 10:07:31 -0700 Subject: [PATCH 744/779] [SPARK-21055][SQL][FOLLOW-UP] replace grouping__id with grouping_id() ## What changes were proposed in this pull request? Simplifies the test cases that were added in the PR https://github.com/apache/spark/pull/18270. ## How was this patch tested? N/A Author: gatorsmile Closes #19546 from gatorsmile/backportSPARK-21055. --- .../sql/hive/execution/SQLQuerySuite.scala | 306 ++++++------------ 1 file changed, 104 insertions(+), 202 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 2476a440ad82c..1cf1c5cd5a472 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1376,223 +1376,125 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("SPARK-8976 Wrong Result for Rollup #1") { - checkAnswer(sql( - "SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH ROLLUP"), - Seq( - (113, 3, 0), - (91, 0, 0), - (500, null, 1), - (84, 1, 0), - (105, 2, 0), - (107, 4, 0) - ).map(i => Row(i._1, i._2, i._3))) - } - - test("SPARK-21055 replace grouping__id: Wrong Result for Rollup #1") { - checkAnswer(sql( - "SELECT count(*) AS cnt, key % 5, grouping__id FROM src GROUP BY key%5 WITH ROLLUP"), - Seq( - (113, 3, 0), - (91, 0, 0), - (500, null, 1), - (84, 1, 0), - (105, 2, 0), - (107, 4, 0) - ).map(i => Row(i._1, i._2, i._3))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s"SELECT count(*) AS cnt, key % 5, $gid FROM src GROUP BY key%5 WITH ROLLUP"), + Seq( + (113, 3, 0), + (91, 0, 0), + (500, null, 1), + (84, 1, 0), + (105, 2, 0), + (107, 4, 0) + ).map(i => Row(i._1, i._2, i._3))) + } } test("SPARK-8976 Wrong Result for Rollup #2") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 - |FROM src GROUP BY key%5, key-5 - |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, 0, 5, 0), - (1, 0, 15, 0), - (1, 0, 25, 0), - (1, 0, 60, 0), - (1, 0, 75, 0), - (1, 0, 80, 0), - (1, 0, 100, 0), - (1, 0, 140, 0), - (1, 0, 145, 0), - (1, 0, 150, 0) - ).map(i => Row(i._1, i._2, i._3, i._4))) - } - - test("SPARK-21055 replace grouping__id: Wrong Result for Rollup #2") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 - |FROM src GROUP BY key%5, key-5 - |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, 0, 5, 0), - (1, 0, 15, 0), - (1, 0, 25, 0), - (1, 0, 60, 0), - (1, 0, 75, 0), - (1, 0, 80, 0), - (1, 0, 100, 0), - (1, 0, 140, 0), - (1, 0, 145, 0), - (1, 0, 150, 0) - ).map(i => Row(i._1, i._2, i._3, i._4))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, $gid AS k3 + |FROM src GROUP BY key%5, key-5 + |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, 0, 5, 0), + (1, 0, 15, 0), + (1, 0, 25, 0), + (1, 0, 60, 0), + (1, 0, 75, 0), + (1, 0, 80, 0), + (1, 0, 100, 0), + (1, 0, 140, 0), + (1, 0, 145, 0), + (1, 0, 150, 0) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } } test("SPARK-8976 Wrong Result for Rollup #3") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, 0, 5, 0), - (1, 0, 15, 0), - (1, 0, 25, 0), - (1, 0, 60, 0), - (1, 0, 75, 0), - (1, 0, 80, 0), - (1, 0, 100, 0), - (1, 0, 140, 0), - (1, 0, 145, 0), - (1, 0, 150, 0) - ).map(i => Row(i._1, i._2, i._3, i._4))) - } - - test("SPARK-21055 replace grouping__id: Wrong Result for Rollup #3") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, 0, 5, 0), - (1, 0, 15, 0), - (1, 0, 25, 0), - (1, 0, 60, 0), - (1, 0, 75, 0), - (1, 0, 80, 0), - (1, 0, 100, 0), - (1, 0, 140, 0), - (1, 0, 145, 0), - (1, 0, 150, 0) - ).map(i => Row(i._1, i._2, i._3, i._4))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, $gid AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, 0, 5, 0), + (1, 0, 15, 0), + (1, 0, 25, 0), + (1, 0, 60, 0), + (1, 0, 75, 0), + (1, 0, 80, 0), + (1, 0, 100, 0), + (1, 0, 140, 0), + (1, 0, 145, 0), + (1, 0, 150, 0) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } } test("SPARK-8976 Wrong Result for CUBE #1") { - checkAnswer(sql( - "SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH CUBE"), - Seq( - (113, 3, 0), - (91, 0, 0), - (500, null, 1), - (84, 1, 0), - (105, 2, 0), - (107, 4, 0) - ).map(i => Row(i._1, i._2, i._3))) - } - - test("SPARK-21055 replace grouping__id: Wrong Result for CUBE #1") { - checkAnswer(sql( - "SELECT count(*) AS cnt, key % 5, grouping__id FROM src GROUP BY key%5 WITH CUBE"), - Seq( - (113, 3, 0), - (91, 0, 0), - (500, null, 1), - (84, 1, 0), - (105, 2, 0), - (107, 4, 0) - ).map(i => Row(i._1, i._2, i._3))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s"SELECT count(*) AS cnt, key % 5, $gid FROM src GROUP BY key%5 WITH CUBE"), + Seq( + (113, 3, 0), + (91, 0, 0), + (500, null, 1), + (84, 1, 0), + (105, 2, 0), + (107, 4, 0) + ).map(i => Row(i._1, i._2, i._3))) + } } test("SPARK-8976 Wrong Result for CUBE #2") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, null, -3, 2), - (1, null, -1, 2), - (1, null, 3, 2), - (1, null, 4, 2), - (1, null, 5, 2), - (1, null, 6, 2), - (1, null, 12, 2), - (1, null, 14, 2), - (1, null, 15, 2), - (1, null, 22, 2) - ).map(i => Row(i._1, i._2, i._3, i._4))) - } - - test("SPARK-21055 replace grouping__id: Wrong Result for CUBE #2") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, null, -3, 2), - (1, null, -1, 2), - (1, null, 3, 2), - (1, null, 4, 2), - (1, null, 5, 2), - (1, null, 6, 2), - (1, null, 12, 2), - (1, null, 14, 2), - (1, null, 15, 2), - (1, null, 22, 2) - ).map(i => Row(i._1, i._2, i._3, i._4))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, $gid AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, null, -3, 2), + (1, null, -1, 2), + (1, null, 3, 2), + (1, null, 4, 2), + (1, null, 5, 2), + (1, null, 6, 2), + (1, null, 12, 2), + (1, null, 14, 2), + (1, null, 15, 2), + (1, null, 22, 2) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } } test("SPARK-8976 Wrong Result for GroupingSet") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, null, -3, 2), - (1, null, -1, 2), - (1, null, 3, 2), - (1, null, 4, 2), - (1, null, 5, 2), - (1, null, 6, 2), - (1, null, 12, 2), - (1, null, 14, 2), - (1, null, 15, 2), - (1, null, 22, 2) - ).map(i => Row(i._1, i._2, i._3, i._4))) - } - - test("SPARK-21055 replace grouping__id: Wrong Result for GroupingSet") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, null, -3, 2), - (1, null, -1, 2), - (1, null, 3, 2), - (1, null, 4, 2), - (1, null, 5, 2), - (1, null, 6, 2), - (1, null, 12, 2), - (1, null, 14, 2), - (1, null, 15, 2), - (1, null, 22, 2) - ).map(i => Row(i._1, i._2, i._3, i._4))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, $gid AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, null, -3, 2), + (1, null, -1, 2), + (1, null, 3, 2), + (1, null, 4, 2), + (1, null, 5, 2), + (1, null, 6, 2), + (1, null, 12, 2), + (1, null, 14, 2), + (1, null, 15, 2), + (1, null, 22, 2) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } } ignore("SPARK-10562: partition by column with mixed case name") { From ff8de99a1c7b4a291e661cd0ad12748f4321e43d Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 22 Oct 2017 02:22:35 +0900 Subject: [PATCH 745/779] [SPARK-22302][INFRA] Remove manual backports for subprocess and print explicit message for < Python 2.7 ## What changes were proposed in this pull request? Seems there was a mistake - missing import for `subprocess.call`, while refactoring this script a long ago, which should be used for backports of some missing functions in `subprocess`, specifically in < Python 2.7. Reproduction is: ``` cd dev && python2.6 ``` ``` >>> from sparktestsupport import shellutils >>> shellutils.subprocess_check_call("ls") Traceback (most recent call last): File "", line 1, in File "sparktestsupport/shellutils.py", line 46, in subprocess_check_call retcode = call(*popenargs, **kwargs) NameError: global name 'call' is not defined ``` For Jenkins logs, please see https://amplab.cs.berkeley.edu/jenkins/job/NewSparkPullRequestBuilder/3950/console Since we dropped the Python 2.6.x support, looks better we remove those workarounds and print out explicit error messages in order to reduce the efforts to find out the root causes for such cases, for example, `https://github.com/apache/spark/pull/19513#issuecomment-337406734`. ## How was this patch tested? Manually tested: ``` ./dev/run-tests ``` ``` Python versions prior to 2.7 are not supported. ``` ``` ./dev/run-tests-jenkins ``` ``` Python versions prior to 2.7 are not supported. ``` Author: hyukjinkwon Closes #19524 from HyukjinKwon/SPARK-22302. --- dev/run-tests | 6 ++++++ dev/run-tests-jenkins | 7 ++++++- dev/sparktestsupport/shellutils.py | 31 ++---------------------------- 3 files changed, 14 insertions(+), 30 deletions(-) diff --git a/dev/run-tests b/dev/run-tests index 257d1e8d50bb4..9cf93d000d0ea 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -20,4 +20,10 @@ FWDIR="$(cd "`dirname $0`"/..; pwd)" cd "$FWDIR" +PYTHON_VERSION_CHECK=$(python -c 'import sys; print(sys.version_info < (2, 7, 0))') +if [[ "$PYTHON_VERSION_CHECK" == "True" ]]; then + echo "Python versions prior to 2.7 are not supported." + exit -1 +fi + exec python -u ./dev/run-tests.py "$@" diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index f41f1ac79e381..03fd6ff0fba40 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -25,5 +25,10 @@ FWDIR="$( cd "$( dirname "$0" )/.." && pwd )" cd "$FWDIR" -export PATH=/home/anaconda/bin:$PATH +PYTHON_VERSION_CHECK=$(python -c 'import sys; print(sys.version_info < (2, 7, 0))') +if [[ "$PYTHON_VERSION_CHECK" == "True" ]]; then + echo "Python versions prior to 2.7 are not supported." + exit -1 +fi + exec python -u ./dev/run-tests-jenkins.py "$@" diff --git a/dev/sparktestsupport/shellutils.py b/dev/sparktestsupport/shellutils.py index 05af87189b18d..c7644da88f770 100644 --- a/dev/sparktestsupport/shellutils.py +++ b/dev/sparktestsupport/shellutils.py @@ -21,35 +21,8 @@ import subprocess import sys - -if sys.version_info >= (2, 7): - subprocess_check_output = subprocess.check_output - subprocess_check_call = subprocess.check_call -else: - # SPARK-8763 - # backported from subprocess module in Python 2.7 - def subprocess_check_output(*popenargs, **kwargs): - if 'stdout' in kwargs: - raise ValueError('stdout argument not allowed, it will be overridden.') - process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) - output, unused_err = process.communicate() - retcode = process.poll() - if retcode: - cmd = kwargs.get("args") - if cmd is None: - cmd = popenargs[0] - raise subprocess.CalledProcessError(retcode, cmd, output=output) - return output - - # backported from subprocess module in Python 2.7 - def subprocess_check_call(*popenargs, **kwargs): - retcode = call(*popenargs, **kwargs) - if retcode: - cmd = kwargs.get("args") - if cmd is None: - cmd = popenargs[0] - raise CalledProcessError(retcode, cmd) - return 0 +subprocess_check_output = subprocess.check_output +subprocess_check_call = subprocess.check_call def exit_from_command_with_retcode(cmd, retcode): From ca2a780e7c4c4df2488ef933241c6e65264f8d3c Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 21 Oct 2017 18:01:45 -0700 Subject: [PATCH 746/779] [SPARK-21929][SQL] Support `ALTER TABLE table_name ADD COLUMNS(..)` for ORC data source ## What changes were proposed in this pull request? When [SPARK-19261](https://issues.apache.org/jira/browse/SPARK-19261) implements `ALTER TABLE ADD COLUMNS`, ORC data source is omitted due to SPARK-14387, SPARK-16628, and SPARK-18355. Now, those issues are fixed and Spark 2.3 is [using Spark schema to read ORC table instead of ORC file schema](https://github.com/apache/spark/commit/e6e36004afc3f9fc8abea98542248e9de11b4435). This PR enables `ALTER TABLE ADD COLUMNS` for ORC data source. ## How was this patch tested? Pass the updated and added test cases. Author: Dongjoon Hyun Closes #19545 from dongjoon-hyun/SPARK-21929. --- .../spark/sql/execution/command/tables.scala | 3 +- .../sql/execution/command/DDLSuite.scala | 90 ++++++++++--------- .../sql/hive/execution/HiveDDLSuite.scala | 8 ++ 3 files changed, 58 insertions(+), 43 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 8d95ca6921cf8..38f91639c0422 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -235,11 +235,10 @@ case class AlterTableAddColumnsCommand( DataSource.lookupDataSource(catalogTable.provider.get).newInstance() match { // For datasource table, this command can only support the following File format. // TextFileFormat only default to one column "value" - // OrcFileFormat can not handle difference between user-specified schema and - // inferred schema yet. TODO, once this issue is resolved , we can add Orc back. // Hive type is already considered as hive serde table, so the logic will not // come in here. case _: JsonFileFormat | _: CSVFileFormat | _: ParquetFileFormat => + case s if s.getClass.getCanonicalName.endsWith("OrcFileFormat") => case s => throw new AnalysisException( s""" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 4ed2cecc5faff..21a2c62929146 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2202,56 +2202,64 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + protected def testAddColumn(provider: String): Unit = { + withTable("t1") { + sql(s"CREATE TABLE t1 (c1 int) USING $provider") + sql("INSERT INTO t1 VALUES (1)") + sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") + checkAnswer( + spark.table("t1"), + Seq(Row(1, null)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 is null"), + Seq(Row(1, null)) + ) + + sql("INSERT INTO t1 VALUES (3, 2)") + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 = 2"), + Seq(Row(3, 2)) + ) + } + } + + protected def testAddColumnPartitioned(provider: String): Unit = { + withTable("t1") { + sql(s"CREATE TABLE t1 (c1 int, c2 int) USING $provider PARTITIONED BY (c2)") + sql("INSERT INTO t1 PARTITION(c2 = 2) VALUES (1)") + sql("ALTER TABLE t1 ADD COLUMNS (c3 int)") + checkAnswer( + spark.table("t1"), + Seq(Row(1, null, 2)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c3 is null"), + Seq(Row(1, null, 2)) + ) + sql("INSERT INTO t1 PARTITION(c2 =1) VALUES (2, 3)") + checkAnswer( + sql("SELECT * FROM t1 WHERE c3 = 3"), + Seq(Row(2, 3, 1)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 = 1"), + Seq(Row(2, 3, 1)) + ) + } + } + val supportedNativeFileFormatsForAlterTableAddColumns = Seq("parquet", "json", "csv") supportedNativeFileFormatsForAlterTableAddColumns.foreach { provider => test(s"alter datasource table add columns - $provider") { - withTable("t1") { - sql(s"CREATE TABLE t1 (c1 int) USING $provider") - sql("INSERT INTO t1 VALUES (1)") - sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") - checkAnswer( - spark.table("t1"), - Seq(Row(1, null)) - ) - checkAnswer( - sql("SELECT * FROM t1 WHERE c2 is null"), - Seq(Row(1, null)) - ) - - sql("INSERT INTO t1 VALUES (3, 2)") - checkAnswer( - sql("SELECT * FROM t1 WHERE c2 = 2"), - Seq(Row(3, 2)) - ) - } + testAddColumn(provider) } } supportedNativeFileFormatsForAlterTableAddColumns.foreach { provider => test(s"alter datasource table add columns - partitioned - $provider") { - withTable("t1") { - sql(s"CREATE TABLE t1 (c1 int, c2 int) USING $provider PARTITIONED BY (c2)") - sql("INSERT INTO t1 PARTITION(c2 = 2) VALUES (1)") - sql("ALTER TABLE t1 ADD COLUMNS (c3 int)") - checkAnswer( - spark.table("t1"), - Seq(Row(1, null, 2)) - ) - checkAnswer( - sql("SELECT * FROM t1 WHERE c3 is null"), - Seq(Row(1, null, 2)) - ) - sql("INSERT INTO t1 PARTITION(c2 =1) VALUES (2, 3)") - checkAnswer( - sql("SELECT * FROM t1 WHERE c3 = 3"), - Seq(Row(2, 3, 1)) - ) - checkAnswer( - sql("SELECT * FROM t1 WHERE c2 = 1"), - Seq(Row(2, 3, 1)) - ) - } + testAddColumnPartitioned(provider) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 02e26bbe876a0..d3465a641a1a4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -166,6 +166,14 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA test("drop table") { testDropTable(isDatasourceTable = false) } + + test("alter datasource table add columns - orc") { + testAddColumn("orc") + } + + test("alter datasource table add columns - partitioned - orc") { + testAddColumnPartitioned("orc") + } } class HiveDDLSuite From 57accf6e3965ff69adc4408623916c5003918235 Mon Sep 17 00:00:00 2001 From: Steven Rand Date: Mon, 23 Oct 2017 09:43:45 +0800 Subject: [PATCH 747/779] [SPARK-22319][CORE] call loginUserFromKeytab before accessing hdfs In `SparkSubmit`, call `loginUserFromKeytab` before attempting to make RPC calls to the NameNode. I manually tested this patch by: 1. Confirming that my Spark application failed to launch with the error reported in https://issues.apache.org/jira/browse/SPARK-22319. 2. Applying this patch and confirming that the app no longer fails to launch, even when I have not manually run `kinit` on the host. Presumably we also want integration tests for secure clusters so that we catch this sort of thing. I'm happy to take a shot at this if it's feasible and someone can point me in the right direction. Author: Steven Rand Closes #19540 from sjrand/SPARK-22319. Change-Id: Ic306bfe7181107fbcf92f61d75856afcb5b6f761 --- .../org/apache/spark/deploy/SparkSubmit.scala | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 135bbe93bf28e..b7e6d0ea021a4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -342,6 +342,22 @@ object SparkSubmit extends CommandLineUtils with Logging { val hadoopConf = conf.getOrElse(SparkHadoopUtil.newConfiguration(sparkConf)) val targetDir = Utils.createTempDir() + // assure a keytab is available from any place in a JVM + if (clusterManager == YARN || clusterManager == LOCAL || clusterManager == MESOS) { + if (args.principal != null) { + if (args.keytab != null) { + require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") + // Add keytab and principal configurations in sysProps to make them available + // for later use; e.g. in spark sql, the isolated class loader used to talk + // to HiveMetastore will use these settings. They will be set as Java system + // properties and then loaded by SparkConf + sysProps.put("spark.yarn.keytab", args.keytab) + sysProps.put("spark.yarn.principal", args.principal) + UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) + } + } + } + // Resolve glob path for different resources. args.jars = Option(args.jars).map(resolveGlobPaths(_, hadoopConf)).orNull args.files = Option(args.files).map(resolveGlobPaths(_, hadoopConf)).orNull @@ -641,22 +657,6 @@ object SparkSubmit extends CommandLineUtils with Logging { } } - // assure a keytab is available from any place in a JVM - if (clusterManager == YARN || clusterManager == LOCAL || clusterManager == MESOS) { - if (args.principal != null) { - if (args.keytab != null) { - require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") - // Add keytab and principal configurations in sysProps to make them available - // for later use; e.g. in spark sql, the isolated class loader used to talk - // to HiveMetastore will use these settings. They will be set as Java system - // properties and then loaded by SparkConf - sysProps.put("spark.yarn.keytab", args.keytab) - sysProps.put("spark.yarn.principal", args.principal) - UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) - } - } - } - if (clusterManager == MESOS && UserGroupInformation.isSecurityEnabled) { setRMPrincipal(sysProps) } From 5a5b6b78517b526771ee5b579d56aa1daa4b3ef1 Mon Sep 17 00:00:00 2001 From: Kohki Nishio Date: Mon, 23 Oct 2017 09:55:46 -0700 Subject: [PATCH 748/779] [SPARK-22303][SQL] Handle Oracle specific jdbc types in OracleDialect TIMESTAMP (-101), BINARY_DOUBLE (101) and BINARY_FLOAT (100) are handled in OracleDialect ## What changes were proposed in this pull request? When a oracle table contains columns whose type is BINARY_FLOAT or BINARY_DOUBLE, spark sql fails to load a table with SQLException ``` java.sql.SQLException: Unsupported type 101 at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$.org$apache$spark$sql$execution$datasources$jdbc$JdbcUtils$$getCatalystType(JdbcUtils.scala:235) at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$$anonfun$8.apply(JdbcUtils.scala:292) at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$$anonfun$8.apply(JdbcUtils.scala:292) at scala.Option.getOrElse(Option.scala:121) at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$.getSchema(JdbcUtils.scala:291) at org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD$.resolveTable(JDBCRDD.scala:64) at org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation.(JDBCRelation.scala:113) at org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider.createRelation(JdbcRelationProvider.scala:47) at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:306) at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:178) at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:146) ``` ## How was this patch tested? I updated a UT which covers type conversion test for types (-101, 100, 101), on top of that I tested this change against actual table with those columns and it was able to read and write to the table. Author: Kohki Nishio Closes #19548 from taroplus/oracle_sql_types_101. --- .../sql/jdbc/OracleIntegrationSuite.scala | 43 +++++++++++++++--- .../datasources/jdbc/JdbcUtils.scala | 1 - .../apache/spark/sql/jdbc/OracleDialect.scala | 44 +++++++++++-------- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 6 +++ 4 files changed, 68 insertions(+), 26 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 7680ae3835132..90343182712ed 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -21,7 +21,7 @@ import java.sql.{Connection, Date, Timestamp} import java.util.Properties import java.math.BigDecimal -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row, SaveMode} import org.apache.spark.sql.execution.{WholeStageCodegenExec, RowDataSourceScanExec} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -52,7 +52,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo import testImplicits._ override val db = new DatabaseOnDocker { - override val imageName = "wnameless/oracle-xe-11g:14.04.4" + override val imageName = "wnameless/oracle-xe-11g:16.04" override val env = Map( "ORACLE_ROOT_PASSWORD" -> "oracle" ) @@ -104,15 +104,18 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate(); + conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate() conn.prepareStatement( - "INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate(); - conn.commit(); + "INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate() + conn.commit() + + conn.prepareStatement("CREATE TABLE oracle_types (d BINARY_DOUBLE, f BINARY_FLOAT)").executeUpdate() + conn.commit() } test("SPARK-16625 : Importing Oracle numeric types") { - val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties); + val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties) val rows = df.collect() assert(rows.size == 1) val row = rows(0) @@ -307,4 +310,32 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo assert(values.getInt(1).equals(1)) assert(values.getBoolean(2).equals(false)) } + + test("SPARK-22303: handle BINARY_DOUBLE and BINARY_FLOAT as DoubleType and FloatType") { + val tableName = "oracle_types" + val schema = StructType(Seq( + StructField("d", DoubleType, true), + StructField("f", FloatType, true))) + val props = new Properties() + + // write it back to the table (append mode) + val data = spark.sparkContext.parallelize(Seq(Row(1.1, 2.2f))) + val dfWrite = spark.createDataFrame(data, schema) + dfWrite.write.mode(SaveMode.Append).jdbc(jdbcUrl, tableName, props) + + // read records from oracle_types + val dfRead = sqlContext.read.jdbc(jdbcUrl, tableName, new Properties) + val rows = dfRead.collect() + assert(rows.size == 1) + + // check data types + val types = dfRead.schema.map(field => field.dataType) + assert(types(0).equals(DoubleType)) + assert(types(1).equals(FloatType)) + + // check values + val values = rows(0) + assert(values.getDouble(0) === 1.1) + assert(values.getFloat(1) === 2.2f) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 71133666b3249..9debc4ff82748 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -230,7 +230,6 @@ object JdbcUtils extends Logging { case java.sql.Types.TIMESTAMP => TimestampType case java.sql.Types.TIMESTAMP_WITH_TIMEZONE => TimestampType - case -101 => TimestampType // Value for Timestamp with Time Zone in Oracle case java.sql.Types.TINYINT => IntegerType case java.sql.Types.VARBINARY => BinaryType case java.sql.Types.VARCHAR => StringType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 3b44c1de93a61..e3f106c41c7ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -23,30 +23,36 @@ import org.apache.spark.sql.types._ private case object OracleDialect extends JdbcDialect { + private[jdbc] val BINARY_FLOAT = 100 + private[jdbc] val BINARY_DOUBLE = 101 + private[jdbc] val TIMESTAMPTZ = -101 override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.NUMERIC) { - val scale = if (null != md) md.build().getLong("scale") else 0L - size match { - // Handle NUMBER fields that have no precision/scale in special way - // because JDBC ResultSetMetaData converts this to 0 precision and -127 scale - // For more details, please see - // https://github.com/apache/spark/pull/8780#issuecomment-145598968 - // and - // https://github.com/apache/spark/pull/8780#issuecomment-144541760 - case 0 => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) - // Handle FLOAT fields in a special way because JDBC ResultSetMetaData converts - // this to NUMERIC with -127 scale - // Not sure if there is a more robust way to identify the field as a float (or other - // numeric types that do not specify a scale. - case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) - case _ => None - } - } else { - None + sqlType match { + case Types.NUMERIC => + val scale = if (null != md) md.build().getLong("scale") else 0L + size match { + // Handle NUMBER fields that have no precision/scale in special way + // because JDBC ResultSetMetaData converts this to 0 precision and -127 scale + // For more details, please see + // https://github.com/apache/spark/pull/8780#issuecomment-145598968 + // and + // https://github.com/apache/spark/pull/8780#issuecomment-144541760 + case 0 => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) + // Handle FLOAT fields in a special way because JDBC ResultSetMetaData converts + // this to NUMERIC with -127 scale + // Not sure if there is a more robust way to identify the field as a float (or other + // numeric types that do not specify a scale. + case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) + case _ => None + } + case TIMESTAMPTZ => Some(TimestampType) // Value for Timestamp with Time Zone in Oracle + case BINARY_FLOAT => Some(FloatType) // Value for OracleTypes.BINARY_FLOAT + case BINARY_DOUBLE => Some(DoubleType) // Value for OracleTypes.BINARY_DOUBLE + case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 34205e0b2bf08..167b3e0190026 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -815,6 +815,12 @@ class JDBCSuite extends SparkFunSuite Some(DecimalType(DecimalType.MAX_PRECISION, 10))) assert(oracleDialect.getCatalystType(java.sql.Types.NUMERIC, "numeric", 0, null) == Some(DecimalType(DecimalType.MAX_PRECISION, 10))) + assert(oracleDialect.getCatalystType(OracleDialect.BINARY_FLOAT, "BINARY_FLOAT", 0, null) == + Some(FloatType)) + assert(oracleDialect.getCatalystType(OracleDialect.BINARY_DOUBLE, "BINARY_DOUBLE", 0, null) == + Some(DoubleType)) + assert(oracleDialect.getCatalystType(OracleDialect.TIMESTAMPTZ, "TIMESTAMP", 0, null) == + Some(TimestampType)) } test("table exists query by jdbc dialect") { From f6290aea24efeb238db88bdaef4e24d50740ca4c Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Mon, 23 Oct 2017 23:02:36 +0100 Subject: [PATCH 749/779] [SPARK-22285][SQL] Change implementation of ApproxCountDistinctForIntervals to TypedImperativeAggregate ## What changes were proposed in this pull request? The current implementation of `ApproxCountDistinctForIntervals` is `ImperativeAggregate`. The number of `aggBufferAttributes` is the number of total words in the hllppHelper array. Each hllppHelper has 52 words by default relativeSD. Since this aggregate function is used in equi-height histogram generation, and the number of buckets in histogram is usually hundreds, the number of `aggBufferAttributes` can easily reach tens of thousands or even more. This leads to a huge method in codegen and causes error: ``` org.codehaus.janino.JaninoRuntimeException: Code of method "apply(Lorg/apache/spark/sql/catalyst/InternalRow;)Lorg/apache/spark/sql/catalyst/expressions/UnsafeRow;" of class "org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection" grows beyond 64 KB. ``` Besides, huge generated methods also result in performance regression. In this PR, we change its implementation to `TypedImperativeAggregate`. After the fix, `ApproxCountDistinctForIntervals` can deal with more than thousands endpoints without throwing codegen error, and improve performance from `20 sec` to `2 sec` in a test case of 500 endpoints. ## How was this patch tested? Test by an added test case and existing tests. Author: Zhenhua Wang Closes #19506 from wzhfy/change_forIntervals_typedAgg. --- .../ApproxCountDistinctForIntervals.scala | 97 ++++++++++--------- ...ApproxCountDistinctForIntervalsSuite.scala | 34 +++---- ...xCountDistinctForIntervalsQuerySuite.scala | 61 ++++++++++++ 3 files changed, 130 insertions(+), 62 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala index 096d1b35a8620..d4421ca20a9bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala @@ -22,9 +22,10 @@ import java.util import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExpectsInputTypes, Expression} +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, GenericInternalRow} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, HyperLogLogPlusPlusHelper} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform /** * This function counts the approximate number of distinct values (ndv) in @@ -46,16 +47,7 @@ case class ApproxCountDistinctForIntervals( relativeSD: Double = 0.05, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends ImperativeAggregate with ExpectsInputTypes { - - def this(child: Expression, endpointsExpression: Expression) = { - this( - child = child, - endpointsExpression = endpointsExpression, - relativeSD = 0.05, - mutableAggBufferOffset = 0, - inputAggBufferOffset = 0) - } + extends TypedImperativeAggregate[Array[Long]] with ExpectsInputTypes { def this(child: Expression, endpointsExpression: Expression, relativeSD: Expression) = { this( @@ -114,29 +106,11 @@ case class ApproxCountDistinctForIntervals( private lazy val totalNumWords = numWordsPerHllpp * hllppArray.length /** Allocate enough words to store all registers. */ - override lazy val aggBufferAttributes: Seq[AttributeReference] = { - Seq.tabulate(totalNumWords) { i => - AttributeReference(s"MS[$i]", LongType)() - } + override def createAggregationBuffer(): Array[Long] = { + Array.fill(totalNumWords)(0L) } - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - - // Note: although this simply copies aggBufferAttributes, this common code can not be placed - // in the superclass because that will lead to initialization ordering issues. - override lazy val inputAggBufferAttributes: Seq[AttributeReference] = - aggBufferAttributes.map(_.newInstance()) - - /** Fill all words with zeros. */ - override def initialize(buffer: InternalRow): Unit = { - var word = 0 - while (word < totalNumWords) { - buffer.setLong(mutableAggBufferOffset + word, 0) - word += 1 - } - } - - override def update(buffer: InternalRow, input: InternalRow): Unit = { + override def update(buffer: Array[Long], input: InternalRow): Array[Long] = { val value = child.eval(input) // Ignore empty rows if (value != null) { @@ -153,13 +127,14 @@ case class ApproxCountDistinctForIntervals( // endpoints are sorted into ascending order already if (endpoints.head > doubleValue || endpoints.last < doubleValue) { // ignore if the value is out of the whole range - return + return buffer } val hllppIndex = findHllppIndex(doubleValue) - val offset = mutableAggBufferOffset + hllppIndex * numWordsPerHllpp - hllppArray(hllppIndex).update(buffer, offset, value, child.dataType) + val offset = hllppIndex * numWordsPerHllpp + hllppArray(hllppIndex).update(LongArrayInternalRow(buffer), offset, value, child.dataType) } + buffer } // Find which interval (HyperLogLogPlusPlusHelper) should receive the given value. @@ -196,17 +171,18 @@ case class ApproxCountDistinctForIntervals( } } - override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = { + override def merge(buffer1: Array[Long], buffer2: Array[Long]): Array[Long] = { for (i <- hllppArray.indices) { hllppArray(i).merge( - buffer1 = buffer1, - buffer2 = buffer2, - offset1 = mutableAggBufferOffset + i * numWordsPerHllpp, - offset2 = inputAggBufferOffset + i * numWordsPerHllpp) + buffer1 = LongArrayInternalRow(buffer1), + buffer2 = LongArrayInternalRow(buffer2), + offset1 = i * numWordsPerHllpp, + offset2 = i * numWordsPerHllpp) } + buffer1 } - override def eval(buffer: InternalRow): Any = { + override def eval(buffer: Array[Long]): Any = { val ndvArray = hllppResults(buffer) // If the endpoints contains multiple elements with the same value, // we set ndv=1 for intervals between these elements. @@ -218,19 +194,23 @@ case class ApproxCountDistinctForIntervals( new GenericArrayData(ndvArray) } - def hllppResults(buffer: InternalRow): Array[Long] = { + def hllppResults(buffer: Array[Long]): Array[Long] = { val ndvArray = new Array[Long](hllppArray.length) for (i <- ndvArray.indices) { - ndvArray(i) = hllppArray(i).query(buffer, mutableAggBufferOffset + i * numWordsPerHllpp) + ndvArray(i) = hllppArray(i).query(LongArrayInternalRow(buffer), i * numWordsPerHllpp) } ndvArray } - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int) + : ApproxCountDistinctForIntervals = { copy(mutableAggBufferOffset = newMutableAggBufferOffset) + } - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int) + : ApproxCountDistinctForIntervals = { copy(inputAggBufferOffset = newInputAggBufferOffset) + } override def children: Seq[Expression] = Seq(child, endpointsExpression) @@ -239,4 +219,31 @@ case class ApproxCountDistinctForIntervals( override def dataType: DataType = ArrayType(LongType) override def prettyName: String = "approx_count_distinct_for_intervals" + + override def serialize(obj: Array[Long]): Array[Byte] = { + val byteArray = new Array[Byte](obj.length * 8) + var i = 0 + while (i < obj.length) { + Platform.putLong(byteArray, Platform.BYTE_ARRAY_OFFSET + i * 8, obj(i)) + i += 1 + } + byteArray + } + + override def deserialize(bytes: Array[Byte]): Array[Long] = { + assert(bytes.length % 8 == 0) + val length = bytes.length / 8 + val longArray = new Array[Long](length) + var i = 0 + while (i < length) { + longArray(i) = Platform.getLong(bytes, Platform.BYTE_ARRAY_OFFSET + i * 8) + i += 1 + } + longArray + } + + private case class LongArrayInternalRow(array: Array[Long]) extends GenericInternalRow { + override def getLong(offset: Int): Long = array(offset) + override def setLong(offset: Int, value: Long): Unit = { array(offset) = value } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala index d6c38c3608bf8..73f18d4feef3f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala @@ -32,7 +32,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { val wrongColumnTypes = Seq(BinaryType, BooleanType, StringType, ArrayType(IntegerType), MapType(IntegerType, IntegerType), StructType(Seq(StructField("s", IntegerType)))) wrongColumnTypes.foreach { dataType => - val wrongColumn = new ApproxCountDistinctForIntervals( + val wrongColumn = ApproxCountDistinctForIntervals( AttributeReference("a", dataType)(), endpointsExpression = CreateArray(Seq(1, 10).map(Literal(_)))) assert( @@ -43,7 +43,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { }) } - var wrongEndpoints = new ApproxCountDistinctForIntervals( + var wrongEndpoints = ApproxCountDistinctForIntervals( AttributeReference("a", DoubleType)(), endpointsExpression = Literal(0.5d)) assert( @@ -52,19 +52,19 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { case _ => false }) - wrongEndpoints = new ApproxCountDistinctForIntervals( + wrongEndpoints = ApproxCountDistinctForIntervals( AttributeReference("a", DoubleType)(), endpointsExpression = CreateArray(Seq(AttributeReference("b", DoubleType)()))) assert(wrongEndpoints.checkInputDataTypes() == TypeCheckFailure("The endpoints provided must be constant literals")) - wrongEndpoints = new ApproxCountDistinctForIntervals( + wrongEndpoints = ApproxCountDistinctForIntervals( AttributeReference("a", DoubleType)(), endpointsExpression = CreateArray(Array(10L).map(Literal(_)))) assert(wrongEndpoints.checkInputDataTypes() == TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals")) - wrongEndpoints = new ApproxCountDistinctForIntervals( + wrongEndpoints = ApproxCountDistinctForIntervals( AttributeReference("a", DoubleType)(), endpointsExpression = CreateArray(Array("foobar").map(Literal(_)))) assert(wrongEndpoints.checkInputDataTypes() == @@ -75,25 +75,18 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { private def createEstimator[T]( endpoints: Array[T], dt: DataType, - rsd: Double = 0.05): (ApproxCountDistinctForIntervals, InternalRow, InternalRow) = { + rsd: Double = 0.05): (ApproxCountDistinctForIntervals, InternalRow, Array[Long]) = { val input = new SpecificInternalRow(Seq(dt)) val aggFunc = ApproxCountDistinctForIntervals( BoundReference(0, dt, nullable = true), CreateArray(endpoints.map(Literal(_))), rsd) - val buffer = createBuffer(aggFunc) - (aggFunc, input, buffer) - } - - private def createBuffer(aggFunc: ApproxCountDistinctForIntervals): InternalRow = { - val buffer = new SpecificInternalRow(aggFunc.aggBufferAttributes.map(_.dataType)) - aggFunc.initialize(buffer) - buffer + (aggFunc, input, aggFunc.createAggregationBuffer()) } test("merging ApproxCountDistinctForIntervals instances") { val (aggFunc, input, buffer1a) = createEstimator(Array[Int](0, 10, 2000, 345678, 1000000), IntegerType) - val buffer1b = createBuffer(aggFunc) - val buffer2 = createBuffer(aggFunc) + val buffer1b = aggFunc.createAggregationBuffer() + val buffer2 = aggFunc.createAggregationBuffer() // Add the lower half to `buffer1a`. var i = 0 @@ -123,7 +116,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { } // Check if the buffers are equal. - assert(buffer2 == buffer1a, "Buffers should be equal") + assert(buffer2.sameElements(buffer1a), "Buffers should be equal") } test("test findHllppIndex(value) for values in the range") { @@ -152,6 +145,13 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { checkHllppIndex(endpoints = Array(1, 3, 5, 7, 7, 9), value = 7, expectedIntervalIndex = 2) } + test("round trip serialization") { + val (aggFunc, _, _) = createEstimator(Array(1, 2), DoubleType) + val longArray = (1L to 100L).toArray + val roundtrip = aggFunc.deserialize(aggFunc.serialize(longArray)) + assert(roundtrip.sameElements(longArray)) + } + test("basic operations: update, merge, eval...") { val endpoints = Array[Double](0, 0.33, 0.6, 0.6, 0.6, 1.0) val data: Seq[Double] = Seq(0, 0.6, 0.3, 1, 0.6, 0.5, 0.6, 0.33) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala new file mode 100644 index 0000000000000..c7d86bc955d67 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals +import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.test.SharedSQLContext + +class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + // ApproxCountDistinctForIntervals is used in equi-height histogram generation. An equi-height + // histogram usually contains hundreds of buckets. So we need to test + // ApproxCountDistinctForIntervals with large number of endpoints + // (the number of endpoints == the number of buckets + 1). + test("test ApproxCountDistinctForIntervals with large number of endpoints") { + val table = "approx_count_distinct_for_intervals_tbl" + withTable(table) { + (1 to 100000).toDF("col").createOrReplaceTempView(table) + // percentiles of 0, 0.001, 0.002 ... 0.999, 1 + val endpoints = (0 to 1000).map(_ * 100000 / 1000) + + // Since approx_count_distinct_for_intervals is not a public function, here we do + // the computation by constructing logical plan. + val relation = spark.table(table).logicalPlan + val attr = relation.output.find(_.name == "col").get + val aggFunc = ApproxCountDistinctForIntervals(attr, CreateArray(endpoints.map(Literal(_)))) + val aggExpr = aggFunc.toAggregateExpression() + val namedExpr = Alias(aggExpr, aggExpr.toString)() + val ndvsRow = new QueryExecution(spark, Aggregate(Nil, Seq(namedExpr), relation)) + .executedPlan.executeTake(1).head + val ndvArray = ndvsRow.getArray(0).toLongArray() + assert(endpoints.length == ndvArray.length + 1) + + // Each bucket has 100 distinct values. + val expectedNdv = 100 + for (i <- ndvArray.indices) { + val ndv = ndvArray(i) + val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d) + assert(error <= aggFunc.relativeSD * 3.0d, "Error should be within 3 std. errors.") + } + } + } +} From 884d4f95f7ebfaa9d8c57cf770d10a2c6ab82d62 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 23 Oct 2017 17:21:49 -0700 Subject: [PATCH 750/779] [SPARK-21912][SQL][FOLLOW-UP] ORC/Parquet table should not create invalid column names ## What changes were proposed in this pull request? During [SPARK-21912](https://issues.apache.org/jira/browse/SPARK-21912), we skipped testing 'ADD COLUMNS' on ORC tables due to ORC limitation. Since [SPARK-21929](https://issues.apache.org/jira/browse/SPARK-21929) is resolved now, we can test both `ORC` and `PARQUET` completely. ## How was this patch tested? Pass the updated test case. Author: Dongjoon Hyun Closes #19562 from dongjoon-hyun/SPARK-21912-2. --- .../spark/sql/hive/execution/SQLQuerySuite.scala | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1cf1c5cd5a472..39e918c3d5209 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2031,8 +2031,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("SPARK-21912 ORC/Parquet table should not create invalid column names") { Seq(" ", ",", ";", "{", "}", "(", ")", "\n", "\t", "=").foreach { name => - withTable("t21912") { - Seq("ORC", "PARQUET").foreach { source => + Seq("ORC", "PARQUET").foreach { source => + withTable("t21912") { val m = intercept[AnalysisException] { sql(s"CREATE TABLE t21912(`col$name` INT) USING $source") }.getMessage @@ -2049,15 +2049,12 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { }.getMessage assert(m3.contains(s"contains invalid character(s)")) } - } - // TODO: After SPARK-21929, we need to check ORC, too. - Seq("PARQUET").foreach { source => sql(s"CREATE TABLE t21912(`col` INT) USING $source") - val m = intercept[AnalysisException] { + val m4 = intercept[AnalysisException] { sql(s"ALTER TABLE t21912 ADD COLUMNS(`col$name` INT)") }.getMessage - assert(m.contains(s"contains invalid character(s)")) + assert(m4.contains(s"contains invalid character(s)")) } } } From d9798c834f3fed060cfd18a8d38c398cb2efcc82 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 24 Oct 2017 12:44:47 +0900 Subject: [PATCH 751/779] [SPARK-22313][PYTHON] Mark/print deprecation warnings as DeprecationWarning for deprecated APIs ## What changes were proposed in this pull request? This PR proposes to mark the existing warnings as `DeprecationWarning` and print out warnings for deprecated functions. This could be actually useful for Spark app developers. I use (old) PyCharm and this IDE can detect this specific `DeprecationWarning` in some cases: **Before** **After** For console usage, `DeprecationWarning` is usually disabled (see https://docs.python.org/2/library/warnings.html#warning-categories and https://docs.python.org/3/library/warnings.html#warning-categories): ``` >>> import warnings >>> filter(lambda f: f[2] == DeprecationWarning, warnings.filters) [('ignore', <_sre.SRE_Pattern object at 0x10ba58c00>, , <_sre.SRE_Pattern object at 0x10bb04138>, 0), ('ignore', None, , None, 0)] ``` so, it won't actually mess up the terminal much unless it is intended. If this is intendedly enabled, it'd should as below: ``` >>> import warnings >>> warnings.simplefilter('always', DeprecationWarning) >>> >>> from pyspark.sql import functions >>> functions.approxCountDistinct("a") .../spark/python/pyspark/sql/functions.py:232: DeprecationWarning: Deprecated in 2.1, use approx_count_distinct instead. "Deprecated in 2.1, use approx_count_distinct instead.", DeprecationWarning) ... ``` These instances were found by: ``` cd python/pyspark grep -r "Deprecated" . grep -r "deprecated" . grep -r "deprecate" . ``` ## How was this patch tested? Manually tested. Author: hyukjinkwon Closes #19535 from HyukjinKwon/deprecated-warning. --- python/pyspark/ml/util.py | 8 ++- python/pyspark/mllib/classification.py | 2 +- python/pyspark/mllib/evaluation.py | 6 +-- python/pyspark/mllib/regression.py | 8 +-- python/pyspark/sql/dataframe.py | 3 ++ python/pyspark/sql/functions.py | 18 +++++++ python/pyspark/streaming/flume.py | 14 ++++- python/pyspark/streaming/kafka.py | 72 ++++++++++++++++++++++---- 8 files changed, 110 insertions(+), 21 deletions(-) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 67772910c0d38..c3c47bd79459a 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -175,7 +175,9 @@ def context(self, sqlContext): .. note:: Deprecated in 2.1 and will be removed in 3.0, use session instead. """ - warnings.warn("Deprecated in 2.1 and will be removed in 3.0, use session instead.") + warnings.warn( + "Deprecated in 2.1 and will be removed in 3.0, use session instead.", + DeprecationWarning) self._jwrite.context(sqlContext._ssql_ctx) return self @@ -256,7 +258,9 @@ def context(self, sqlContext): .. note:: Deprecated in 2.1 and will be removed in 3.0, use session instead. """ - warnings.warn("Deprecated in 2.1 and will be removed in 3.0, use session instead.") + warnings.warn( + "Deprecated in 2.1 and will be removed in 3.0, use session instead.", + DeprecationWarning) self._jread.context(sqlContext._ssql_ctx) return self diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index e04eeb2b60d71..cce703d432b5a 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -311,7 +311,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, """ warnings.warn( "Deprecated in 2.0.0. Use ml.classification.LogisticRegression or " - "LogisticRegressionWithLBFGS.") + "LogisticRegressionWithLBFGS.", DeprecationWarning) def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index fc2a0b3b5038a..2cd1da3fbf9aa 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -234,7 +234,7 @@ def precision(self, label=None): """ if label is None: # note:: Deprecated in 2.0.0. Use accuracy. - warnings.warn("Deprecated in 2.0.0. Use accuracy.") + warnings.warn("Deprecated in 2.0.0. Use accuracy.", DeprecationWarning) return self.call("precision") else: return self.call("precision", float(label)) @@ -246,7 +246,7 @@ def recall(self, label=None): """ if label is None: # note:: Deprecated in 2.0.0. Use accuracy. - warnings.warn("Deprecated in 2.0.0. Use accuracy.") + warnings.warn("Deprecated in 2.0.0. Use accuracy.", DeprecationWarning) return self.call("recall") else: return self.call("recall", float(label)) @@ -259,7 +259,7 @@ def fMeasure(self, label=None, beta=None): if beta is None: if label is None: # note:: Deprecated in 2.0.0. Use accuracy. - warnings.warn("Deprecated in 2.0.0. Use accuracy.") + warnings.warn("Deprecated in 2.0.0. Use accuracy.", DeprecationWarning) return self.call("fMeasure") else: return self.call("fMeasure", label) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 1b66f5b51044b..ea107d400621d 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -278,7 +278,8 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, A condition which decides iteration termination. (default: 0.001) """ - warnings.warn("Deprecated in 2.0.0. Use ml.regression.LinearRegression.") + warnings.warn( + "Deprecated in 2.0.0. Use ml.regression.LinearRegression.", DeprecationWarning) def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), @@ -421,7 +422,8 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, """ warnings.warn( "Deprecated in 2.0.0. Use ml.regression.LinearRegression with elasticNetParam = 1.0. " - "Note the default regParam is 0.01 for LassoWithSGD, but is 0.0 for LinearRegression.") + "Note the default regParam is 0.01 for LassoWithSGD, but is 0.0 for LinearRegression.", + DeprecationWarning) def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), @@ -566,7 +568,7 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, warnings.warn( "Deprecated in 2.0.0. Use ml.regression.LinearRegression with elasticNetParam = 0.0. " "Note the default regParam is 0.01 for RidgeRegressionWithSGD, but is 0.0 for " - "LinearRegression.") + "LinearRegression.", DeprecationWarning) def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 38b01f0011671..c0b574e2b93a1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -130,6 +130,8 @@ def registerTempTable(self, name): .. note:: Deprecated in 2.0, use createOrReplaceTempView instead. """ + warnings.warn( + "Deprecated in 2.0, use createOrReplaceTempView instead.", DeprecationWarning) self._jdf.createOrReplaceTempView(name) @since(2.0) @@ -1308,6 +1310,7 @@ def unionAll(self, other): .. note:: Deprecated in 2.0, use :func:`union` instead. """ + warnings.warn("Deprecated in 2.0, use union instead.", DeprecationWarning) return self.union(other) @since(2.3) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9bc374b93a433..0d40368c9cd6e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -21,6 +21,7 @@ import math import sys import functools +import warnings if sys.version < "3": from itertools import imap as map @@ -44,6 +45,14 @@ def _(col): return _ +def _wrap_deprecated_function(func, message): + """ Wrap the deprecated function to print out deprecation warnings""" + def _(col): + warnings.warn(message, DeprecationWarning) + return func(col) + return functools.wraps(func)(_) + + def _create_binary_mathfunction(name, doc=""): """ Create a binary mathfunction by name""" def _(col1, col2): @@ -207,6 +216,12 @@ def _(): """returns the relative rank (i.e. percentile) of rows within a window partition.""", } +# Wraps deprecated functions (keys) with the messages (values). +_functions_deprecated = { + 'toDegrees': 'Deprecated in 2.1, use degrees instead.', + 'toRadians': 'Deprecated in 2.1, use radians instead.', +} + for _name, _doc in _functions.items(): globals()[_name] = since(1.3)(_create_function(_name, _doc)) for _name, _doc in _functions_1_4.items(): @@ -219,6 +234,8 @@ def _(): globals()[_name] = since(1.6)(_create_function(_name, _doc)) for _name, _doc in _functions_2_1.items(): globals()[_name] = since(2.1)(_create_function(_name, _doc)) +for _name, _message in _functions_deprecated.items(): + globals()[_name] = _wrap_deprecated_function(globals()[_name], _message) del _name, _doc @@ -227,6 +244,7 @@ def approxCountDistinct(col, rsd=None): """ .. note:: Deprecated in 2.1, use :func:`approx_count_distinct` instead. """ + warnings.warn("Deprecated in 2.1, use approx_count_distinct instead.", DeprecationWarning) return approx_count_distinct(col, rsd) diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py index 2fed5940b31ea..5a975d050b0d8 100644 --- a/python/pyspark/streaming/flume.py +++ b/python/pyspark/streaming/flume.py @@ -54,8 +54,13 @@ def createStream(ssc, hostname, port, :param bodyDecoder: A function used to decode body (default is utf8_decoder) :return: A DStream object - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. + See SPARK-22142. """ + warnings.warn( + "Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. " + "See SPARK-22142.", + DeprecationWarning) jlevel = ssc._sc._getJavaStorageLevel(storageLevel) helper = FlumeUtils._get_helper(ssc._sc) jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression) @@ -82,8 +87,13 @@ def createPollingStream(ssc, addresses, :param bodyDecoder: A function used to decode body (default is utf8_decoder) :return: A DStream object - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. + See SPARK-22142. """ + warnings.warn( + "Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. " + "See SPARK-22142.", + DeprecationWarning) jlevel = ssc._sc._getJavaStorageLevel(storageLevel) hosts = [] ports = [] diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 4af4135c81958..fdb9308604489 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -15,6 +15,8 @@ # limitations under the License. # +import warnings + from py4j.protocol import Py4JJavaError from pyspark.rdd import RDD @@ -56,8 +58,13 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams=None, :param valueDecoder: A function used to decode value (default is utf8_decoder) :return: A DStream object - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) if kafkaParams is None: kafkaParams = dict() kafkaParams.update({ @@ -105,8 +112,13 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, :return: A DStream object .. note:: Experimental - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) if fromOffsets is None: fromOffsets = dict() if not isinstance(topics, list): @@ -159,8 +171,13 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders=None, :return: An RDD object .. note:: Experimental - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) if leaders is None: leaders = dict() if not isinstance(kafkaParams, dict): @@ -229,7 +246,8 @@ class OffsetRange(object): """ Represents a range of offsets from a single Kafka TopicAndPartition. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, topic, partition, fromOffset, untilOffset): @@ -240,6 +258,10 @@ def __init__(self, topic, partition, fromOffset, untilOffset): :param fromOffset: Inclusive starting offset. :param untilOffset: Exclusive ending offset. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) self.topic = topic self.partition = partition self.fromOffset = fromOffset @@ -270,7 +292,8 @@ class TopicAndPartition(object): """ Represents a specific topic and partition for Kafka. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, topic, partition): @@ -279,6 +302,10 @@ def __init__(self, topic, partition): :param topic: Kafka topic name. :param partition: Kafka partition id. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) self._topic = topic self._partition = partition @@ -303,7 +330,8 @@ class Broker(object): """ Represent the host and port info for a Kafka broker. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, host, port): @@ -312,6 +340,10 @@ def __init__(self, host, port): :param host: Broker's hostname. :param port: Broker's port. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) self._host = host self._port = port @@ -323,10 +355,15 @@ class KafkaRDD(RDD): """ A Python wrapper of KafkaRDD, to provide additional information on normal RDD. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, jrdd, ctx, jrdd_deserializer): + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) RDD.__init__(self, jrdd, ctx, jrdd_deserializer) def offsetRanges(self): @@ -345,10 +382,15 @@ class KafkaDStream(DStream): """ A Python wrapper of KafkaDStream - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, jdstream, ssc, jrdd_deserializer): + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) DStream.__init__(self, jdstream, ssc, jrdd_deserializer) def foreachRDD(self, func): @@ -383,10 +425,15 @@ class KafkaTransformedDStream(TransformedDStream): """ Kafka specific wrapper of TransformedDStream to transform on Kafka RDD. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, prev, func): + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) TransformedDStream.__init__(self, prev, func) @property @@ -405,7 +452,8 @@ class KafkaMessageAndMetadata(object): """ Kafka message and metadata information. Including topic, partition, offset and message - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, topic, partition, offset, key, message): @@ -419,6 +467,10 @@ def __init__(self, topic, partition, offset, key, message): :param message: actual message payload of this Kafka message, the return data is undecoded bytearray. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) self.topic = topic self.partition = partition self.offset = offset From c30d5cfc7117bdadd63bf730e88398139e0f65f4 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 24 Oct 2017 08:46:22 +0100 Subject: [PATCH 752/779] [SPARK-20822][SQL] Generate code to directly get value from ColumnVector for table cache ## What changes were proposed in this pull request? This PR generates the Java code to directly get a value for a column in `ColumnVector` without using an iterator (e.g. at lines 54-69 in the generated code example) for table cache (e.g. `dataframe.cache`). This PR improves runtime performance by eliminating data copy from column-oriented storage to `InternalRow` in a `SpecificColumnarIterator` iterator for primitive type. Another PR will support primitive type array. Benchmark result: **1.2x** ``` OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-22-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Int Sum with IntDelta cache: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ InternalRow codegen 731 / 812 43.0 23.2 1.0X ColumnVector codegen 616 / 772 51.0 19.6 1.2X ``` Benchmark program ``` intSumBenchmark(sqlContext, 1024 * 1024 * 30) def intSumBenchmark(sqlContext: SQLContext, values: Int): Unit = { import sqlContext.implicits._ val benchmarkPT = new Benchmark("Int Sum with IntDelta cache", values, 20) Seq(("InternalRow", "false"), ("ColumnVector", "true")).foreach { case (str, value) => withSQLConf(sqlContext, SQLConf. COLUMN_VECTOR_CODEGEN.key -> value) { // tentatively added for benchmarking val dfPassThrough = sqlContext.sparkContext.parallelize(0 to values - 1, 1).toDF().cache() dfPassThrough.count() // force to create df.cache() benchmarkPT.addCase(s"$str codegen") { iter => dfPassThrough.agg(sum("value")).collect } dfPassThrough.unpersist(true) } } benchmarkPT.run() } ``` Motivating example ``` val dsInt = spark.range(3).cache dsInt.count // force to build cache dsInt.filter(_ > 0).collect ``` Generated code ``` /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } /* 004 */ /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inmemorytablescan_input; /* 009 */ private org.apache.spark.sql.execution.metric.SQLMetric inmemorytablescan_numOutputRows; /* 010 */ private org.apache.spark.sql.execution.metric.SQLMetric inmemorytablescan_scanTime; /* 011 */ private long inmemorytablescan_scanTime1; /* 012 */ private org.apache.spark.sql.execution.vectorized.ColumnarBatch inmemorytablescan_batch; /* 013 */ private int inmemorytablescan_batchIdx; /* 014 */ private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector inmemorytablescan_colInstance0; /* 015 */ private UnsafeRow inmemorytablescan_result; /* 016 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder inmemorytablescan_holder; /* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter inmemorytablescan_rowWriter; /* 018 */ private org.apache.spark.sql.execution.metric.SQLMetric filter_numOutputRows; /* 019 */ private UnsafeRow filter_result; /* 020 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder filter_holder; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter filter_rowWriter; /* 022 */ /* 023 */ public GeneratedIterator(Object[] references) { /* 024 */ this.references = references; /* 025 */ } /* 026 */ /* 027 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 028 */ partitionIndex = index; /* 029 */ this.inputs = inputs; /* 030 */ inmemorytablescan_input = inputs[0]; /* 031 */ inmemorytablescan_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0]; /* 032 */ inmemorytablescan_scanTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[1]; /* 033 */ inmemorytablescan_scanTime1 = 0; /* 034 */ inmemorytablescan_batch = null; /* 035 */ inmemorytablescan_batchIdx = 0; /* 036 */ inmemorytablescan_colInstance0 = null; /* 037 */ inmemorytablescan_result = new UnsafeRow(1); /* 038 */ inmemorytablescan_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(inmemorytablescan_result, 0); /* 039 */ inmemorytablescan_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(inmemorytablescan_holder, 1); /* 040 */ filter_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2]; /* 041 */ filter_result = new UnsafeRow(1); /* 042 */ filter_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(filter_result, 0); /* 043 */ filter_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(filter_holder, 1); /* 044 */ /* 045 */ } /* 046 */ /* 047 */ protected void processNext() throws java.io.IOException { /* 048 */ if (inmemorytablescan_batch == null) { /* 049 */ inmemorytablescan_nextBatch(); /* 050 */ } /* 051 */ while (inmemorytablescan_batch != null) { /* 052 */ int inmemorytablescan_numRows = inmemorytablescan_batch.numRows(); /* 053 */ int inmemorytablescan_localEnd = inmemorytablescan_numRows - inmemorytablescan_batchIdx; /* 054 */ for (int inmemorytablescan_localIdx = 0; inmemorytablescan_localIdx < inmemorytablescan_localEnd; inmemorytablescan_localIdx++) { /* 055 */ int inmemorytablescan_rowIdx = inmemorytablescan_batchIdx + inmemorytablescan_localIdx; /* 056 */ int inmemorytablescan_value = inmemorytablescan_colInstance0.getInt(inmemorytablescan_rowIdx); /* 057 */ /* 058 */ boolean filter_isNull = false; /* 059 */ /* 060 */ boolean filter_value = false; /* 061 */ filter_value = inmemorytablescan_value > 1; /* 062 */ if (!filter_value) continue; /* 063 */ /* 064 */ filter_numOutputRows.add(1); /* 065 */ /* 066 */ filter_rowWriter.write(0, inmemorytablescan_value); /* 067 */ append(filter_result); /* 068 */ if (shouldStop()) { inmemorytablescan_batchIdx = inmemorytablescan_rowIdx + 1; return; } /* 069 */ } /* 070 */ inmemorytablescan_batchIdx = inmemorytablescan_numRows; /* 071 */ inmemorytablescan_batch = null; /* 072 */ inmemorytablescan_nextBatch(); /* 073 */ } /* 074 */ inmemorytablescan_scanTime.add(inmemorytablescan_scanTime1 / (1000 * 1000)); /* 075 */ inmemorytablescan_scanTime1 = 0; /* 076 */ } /* 077 */ /* 078 */ private void inmemorytablescan_nextBatch() throws java.io.IOException { /* 079 */ long getBatchStart = System.nanoTime(); /* 080 */ if (inmemorytablescan_input.hasNext()) { /* 081 */ org.apache.spark.sql.execution.columnar.CachedBatch inmemorytablescan_cachedBatch = (org.apache.spark.sql.execution.columnar.CachedBatch)inmemorytablescan_input.next(); /* 082 */ inmemorytablescan_batch = org.apache.spark.sql.execution.columnar.InMemoryRelation$.MODULE$.createColumn(inmemorytablescan_cachedBatch); /* 083 */ /* 084 */ inmemorytablescan_numOutputRows.add(inmemorytablescan_batch.numRows()); /* 085 */ inmemorytablescan_batchIdx = 0; /* 086 */ inmemorytablescan_colInstance0 = (org.apache.spark.sql.execution.vectorized.OnHeapColumnVector) inmemorytablescan_batch.column(0); org.apache.spark.sql.execution.columnar.ColumnAccessor$.MODULE$.decompress(inmemorytablescan_cachedBatch.buffers()[0], (org.apache.spark.sql.execution.vectorized.WritableColumnVector) inmemorytablescan_colInstance0, org.apache.spark.sql.types.DataTypes.IntegerType, inmemorytablescan_cachedBatch.numRows()); /* 087 */ /* 088 */ } /* 089 */ inmemorytablescan_scanTime1 += System.nanoTime() - getBatchStart; /* 090 */ } /* 091 */ } ``` ## How was this patch tested? Add test cases into `DataFrameTungstenSuite` and `WholeStageCodegenSuite` Author: Kazuaki Ishizaki Closes #18747 from kiszk/SPARK-20822a. --- .../sql/execution/ColumnarBatchScan.scala | 3 - .../sql/execution/WholeStageCodegenExec.scala | 24 ++++---- .../execution/columnar/ColumnAccessor.scala | 8 +++ .../columnar/InMemoryTableScanExec.scala | 57 +++++++++++++++++-- .../spark/sql/DataFrameTungstenSuite.scala | 36 ++++++++++++ .../execution/WholeStageCodegenSuite.scala | 32 +++++++++++ 6 files changed, 141 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 1afe83ea3539e..eb01e126bcbef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.sql.types.DataType @@ -31,8 +30,6 @@ import org.apache.spark.sql.types.DataType */ private[sql] trait ColumnarBatchScan extends CodegenSupport { - val inMemoryTableScan: InMemoryTableScanExec = null - def vectorTypes: Option[Seq[String]] = None override lazy val metrics = Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 1aaaf896692d1..e37d133ff336a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -282,6 +282,18 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp object WholeStageCodegenExec { val PIPELINE_DURATION_METRIC = "duration" + + private def numOfNestedFields(dataType: DataType): Int = dataType match { + case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum + case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType) + case a: ArrayType => numOfNestedFields(a.elementType) + case u: UserDefinedType[_] => numOfNestedFields(u.sqlType) + case _ => 1 + } + + def isTooManyFields(conf: SQLConf, dataType: DataType): Boolean = { + numOfNestedFields(dataType) > conf.wholeStageMaxNumFields + } } /** @@ -490,22 +502,14 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { case _ => true } - private def numOfNestedFields(dataType: DataType): Int = dataType match { - case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum - case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType) - case a: ArrayType => numOfNestedFields(a.elementType) - case u: UserDefinedType[_] => numOfNestedFields(u.sqlType) - case _ => 1 - } - private def supportCodegen(plan: SparkPlan): Boolean = plan match { case plan: CodegenSupport if plan.supportCodegen => val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined) // the generated code will be huge if there are too many columns val hasTooManyOutputFields = - numOfNestedFields(plan.schema) > conf.wholeStageMaxNumFields + WholeStageCodegenExec.isTooManyFields(conf, plan.schema) val hasTooManyInputFields = - plan.children.map(p => numOfNestedFields(p.schema)).exists(_ > conf.wholeStageMaxNumFields) + plan.children.exists(p => WholeStageCodegenExec.isTooManyFields(conf, p.schema)) !willFallback && !hasTooManyOutputFields && !hasTooManyInputFields case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 24c8ac81420cb..445933d98e9d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -163,4 +163,12 @@ private[sql] object ColumnAccessor { throw new RuntimeException("Not support non-primitive type now") } } + + def decompress( + array: Array[Byte], columnVector: WritableColumnVector, dataType: DataType, numRows: Int): + Unit = { + val byteBuffer = ByteBuffer.wrap(array) + val columnAccessor = ColumnAccessor(dataType, byteBuffer) + decompress(columnAccessor, columnVector, numRows) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 139da1c519da2..43386e7a03c32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -23,21 +23,66 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} -import org.apache.spark.sql.execution.LeafExecNode -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.UserDefinedType +import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} +import org.apache.spark.sql.execution.vectorized._ +import org.apache.spark.sql.types._ case class InMemoryTableScanExec( attributes: Seq[Attribute], predicates: Seq[Expression], @transient relation: InMemoryRelation) - extends LeafExecNode { + extends LeafExecNode with ColumnarBatchScan { override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + override def vectorTypes: Option[Seq[String]] = + Option(Seq.fill(attributes.length)(classOf[OnHeapColumnVector].getName)) + + /** + * If true, get data from ColumnVector in ColumnarBatch, which are generally faster. + * If false, get data from UnsafeRow build from ColumnVector + */ + override val supportCodegen: Boolean = { + // In the initial implementation, for ease of review + // support only primitive data types and # of fields is less than wholeStageMaxNumFields + relation.schema.fields.forall(f => f.dataType match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType => true + case _ => false + }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema) + } + + private val columnIndices = + attributes.map(a => relation.output.map(o => o.exprId).indexOf(a.exprId)).toArray + + private val relationSchema = relation.schema.toArray + + private lazy val columnarBatchSchema = new StructType(columnIndices.map(i => relationSchema(i))) + + private def createAndDecompressColumn(cachedColumnarBatch: CachedBatch): ColumnarBatch = { + val rowCount = cachedColumnarBatch.numRows + val columnVectors = OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) + val columnarBatch = new ColumnarBatch( + columnarBatchSchema, columnVectors.asInstanceOf[Array[ColumnVector]], rowCount) + columnarBatch.setNumRows(rowCount) + + for (i <- 0 until attributes.length) { + ColumnAccessor.decompress( + cachedColumnarBatch.buffers(columnIndices(i)), + columnarBatch.column(i).asInstanceOf[WritableColumnVector], + columnarBatchSchema.fields(i).dataType, rowCount) + } + columnarBatch + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + assert(supportCodegen) + val buffers = relation.cachedColumnBuffers + // HACK ALERT: This is actually an RDD[ColumnarBatch]. + // We're taking advantage of Scala's type erasure here to pass these batches along. + Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]]) + } override def output: Seq[Attribute] = attributes diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index fe6ba83b4cbfb..0881212a64de8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -73,4 +73,40 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { val df = spark.createDataFrame(data, schema) assert(df.select("b").first() === Row(outerStruct)) } + + test("primitive data type accesses in persist data") { + val data = Seq(true, 1.toByte, 3.toShort, 7, 15.toLong, + 31.25.toFloat, 63.75, null) + val dataTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, IntegerType) + val schemas = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, true) + } + val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data))) + val df = spark.createDataFrame(rdd, StructType(schemas)) + val row = df.persist.take(1).apply(0) + checkAnswer(df, row) + } + + test("access cache multiple times") { + val df0 = sparkContext.parallelize(Seq(1, 2, 3), 1).toDF("x").cache + df0.count + val df1 = df0.filter("x > 1") + checkAnswer(df1, Seq(Row(2), Row(3))) + val df2 = df0.filter("x > 2") + checkAnswer(df2, Row(3)) + + val df10 = sparkContext.parallelize(Seq(3, 4, 5, 6), 1).toDF("x").cache + for (_ <- 0 to 2) { + val df11 = df10.filter("x > 5") + checkAnswer(df11, Row(6)) + } + } + + test("access only some column of the all of columns") { + val df = spark.range(1, 10).map(i => (i, (i + 1).toDouble)).toDF("l", "d") + df.cache + df.count + assert(df.filter("d < 3").count == 1) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 098e4cfeb15b2..bc05dca578c47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.expressions.scalalang.typed @@ -117,6 +118,37 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0))) } + test("cache for primitive type should be in WholeStageCodegen with InMemoryTableScanExec") { + import testImplicits._ + + val dsInt = spark.range(3).cache + dsInt.count + val dsIntFilter = dsInt.filter(_ > 0) + val planInt = dsIntFilter.queryExecution.executedPlan + assert(planInt.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && + p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child + .isInstanceOf[InMemoryTableScanExec] && + p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child + .asInstanceOf[InMemoryTableScanExec].supportCodegen).isDefined + ) + assert(dsIntFilter.collect() === Array(1, 2)) + + // cache for string type is not supported for InMemoryTableScanExec + val dsString = spark.range(3).map(_.toString).cache + dsString.count + val dsStringFilter = dsString.filter(_ == "1") + val planString = dsStringFilter.queryExecution.executedPlan + assert(planString.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && + !p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child + .isInstanceOf[InMemoryTableScanExec]).isDefined + ) + assert(dsStringFilter.collect() === Array("1")) + } + test("SPARK-19512 codegen for comparing structs is incorrect") { // this would raise CompileException before the fix spark.range(10) From 8beeaed66bde0ace44495b38dc967816e16b3464 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 24 Oct 2017 13:56:10 +0100 Subject: [PATCH 753/779] [SPARK-21936][SQL][FOLLOW-UP] backward compatibility test framework for HiveExternalCatalog ## What changes were proposed in this pull request? Adjust Spark download in test to use Apache mirrors and respect its load balancer, and use Spark 2.1.2. This follows on a recent PMC list thread about removing the cloudfront download rather than update it further. ## How was this patch tested? Existing tests. Author: Sean Owen Closes #19564 from srowen/SPARK-21936.2. --- .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 305f5b533d592..5f8c9d5799662 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -53,7 +53,9 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private def downloadSpark(version: String): Unit = { import scala.sys.process._ - val url = s"https://d3kbcqa49mib13.cloudfront.net/spark-$version-bin-hadoop2.7.tgz" + val preferredMirror = + Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim + val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz" Seq("wget", url, "-q", "-P", sparkTestingDir.getCanonicalPath).! @@ -142,7 +144,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.1", "2.2.0") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0") protected var spark: SparkSession = _ From 3f5ba968c5af7911a2f6c452500b6a629a3de8db Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 24 Oct 2017 09:11:52 -0700 Subject: [PATCH 754/779] [SPARK-22301][SQL] Add rule to Optimizer for In with not-nullable value and empty list ## What changes were proposed in this pull request? For performance reason, we should resolve in operation on an empty list as false in the optimizations phase, ad discussed in #19522. ## How was this patch tested? Added UT cc gatorsmile Author: Marco Gaido Author: Marco Gaido Closes #19523 from mgaido91/SPARK-22301. --- .../sql/catalyst/optimizer/expressions.scala | 7 +++++-- .../sql/catalyst/optimizer/OptimizeInSuite.scala | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 273bc6ce27c5d..523b53b39d6b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -169,13 +169,16 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { /** * Optimize IN predicates: - * 1. Removes literal repetitions. - * 2. Replaces [[In (value, seq[Literal])]] with optimized version + * 1. Converts the predicate to false when the list is empty and + * the value is not nullable. + * 2. Removes literal repetitions. + * 3. Replaces [[In (value, seq[Literal])]] with optimized version * [[InSet (value, HashSet[Literal])]] which is much faster. */ object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { + case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index eaad1e32a8aba..d7acd139225cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -175,4 +175,20 @@ class OptimizeInSuite extends PlanTest { } } } + + test("OptimizedIn test: In empty list gets transformed to FalseLiteral " + + "when value is not nullable") { + val originalQuery = + testRelation + .where(In(Literal("a"), Nil)) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(Literal(false)) + .analyze + + comparePlans(optimized, correctAnswer) + } } From bc1e76632ddec8fc64726086905183d1f312bca4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 25 Oct 2017 06:33:44 +0100 Subject: [PATCH 755/779] [SPARK-22348][SQL] The table cache providing ColumnarBatch should also do partition batch pruning ## What changes were proposed in this pull request? We enable table cache `InMemoryTableScanExec` to provide `ColumnarBatch` now. But the cached batches are retrieved without pruning. In this case, we still need to do partition batch pruning. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #19569 from viirya/SPARK-22348. --- .../columnar/InMemoryTableScanExec.scala | 70 ++++++++++--------- .../columnar/InMemoryColumnarQuerySuite.scala | 27 ++++++- 2 files changed, 64 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 43386e7a03c32..2ae3f35eb1da1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -78,7 +78,7 @@ case class InMemoryTableScanExec( override def inputRDDs(): Seq[RDD[InternalRow]] = { assert(supportCodegen) - val buffers = relation.cachedColumnBuffers + val buffers = filteredCachedBatches() // HACK ALERT: This is actually an RDD[ColumnarBatch]. // We're taking advantage of Scala's type erasure here to pass these batches along. Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]]) @@ -180,19 +180,11 @@ case class InMemoryTableScanExec( private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning - protected override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - if (enableAccumulators) { - readPartitions.setValue(0) - readBatches.setValue(0) - } - + private def filteredCachedBatches(): RDD[CachedBatch] = { // Using these variables here to avoid serialization of entire objects (if referenced directly) // within the map Partitions closure. val schema = relation.partitionStatistics.schema val schemaIndex = schema.zipWithIndex - val relOutput: AttributeSeq = relation.output val buffers = relation.cachedColumnBuffers buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) => @@ -201,35 +193,49 @@ case class InMemoryTableScanExec( schema) partitionFilter.initialize(index) + // Do partition batch pruning if enabled + if (inMemoryPartitionPruningEnabled) { + cachedBatchIterator.filter { cachedBatch => + if (!partitionFilter.eval(cachedBatch.stats)) { + logDebug { + val statsString = schemaIndex.map { case (a, i) => + val value = cachedBatch.stats.get(i, a.dataType) + s"${a.name}: $value" + }.mkString(", ") + s"Skipping partition based on stats $statsString" + } + false + } else { + true + } + } + } else { + cachedBatchIterator + } + } + } + + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + + if (enableAccumulators) { + readPartitions.setValue(0) + readBatches.setValue(0) + } + + // Using these variables here to avoid serialization of entire objects (if referenced directly) + // within the map Partitions closure. + val relOutput: AttributeSeq = relation.output + + filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator => // Find the ordinals and data types of the requested columns. val (requestedColumnIndices, requestedColumnDataTypes) = attributes.map { a => relOutput.indexOf(a.exprId) -> a.dataType }.unzip - // Do partition batch pruning if enabled - val cachedBatchesToScan = - if (inMemoryPartitionPruningEnabled) { - cachedBatchIterator.filter { cachedBatch => - if (!partitionFilter.eval(cachedBatch.stats)) { - logDebug { - val statsString = schemaIndex.map { case (a, i) => - val value = cachedBatch.stats.get(i, a.dataType) - s"${a.name}: $value" - }.mkString(", ") - s"Skipping partition based on stats $statsString" - } - false - } else { - true - } - } - } else { - cachedBatchIterator - } - // update SQL metrics - val withMetrics = cachedBatchesToScan.map { batch => + val withMetrics = cachedBatchIterator.map { batch => if (enableAccumulators) { readBatches.add(1) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 2f249c850a088..e662e294228db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -23,7 +23,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning -import org.apache.spark.sql.execution.LocalTableScanExec +import org.apache.spark.sql.execution.{FilterExec, LocalTableScanExec, WholeStageCodegenExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -454,4 +454,29 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { Seq(In(attribute, Nil)), testRelation) assert(tableScanExec.partitionFilters.isEmpty) } + + test("SPARK-22348: table cache should do partition batch pruning") { + Seq("true", "false").foreach { enabled => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> enabled) { + val df1 = Seq((1, 1), (1, 1), (2, 2)).toDF("x", "y") + df1.unpersist() + df1.cache() + + // Push predicate to the cached table. + val df2 = df1.where("y = 3") + + val planBeforeFilter = df2.queryExecution.executedPlan.collect { + case f: FilterExec => f.child + } + assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec]) + + val execPlan = if (enabled == "true") { + WholeStageCodegenExec(planBeforeFilter.head) + } else { + planBeforeFilter.head + } + assert(execPlan.executeCollectPublic().length == 0) + } + } + } } From 524abb996abc9970d699623c13469ea3b6d2d3fc Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 24 Oct 2017 22:59:46 -0700 Subject: [PATCH 756/779] [SPARK-21101][SQL] Catch IllegalStateException when CREATE TEMPORARY FUNCTION ## What changes were proposed in this pull request? It must `override` [`public StructObjectInspector initialize(ObjectInspector[] argOIs)`](https://github.com/apache/hive/blob/release-2.0.0/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDTF.java#L70) when create a UDTF. If you `override` [`public StructObjectInspector initialize(StructObjectInspector argOIs)`](https://github.com/apache/hive/blob/release-2.0.0/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDTF.java#L49), `IllegalStateException` will throw. per: [HIVE-12377](https://issues.apache.org/jira/browse/HIVE-12377). This PR catch `IllegalStateException` and point user to `override` `public StructObjectInspector initialize(ObjectInspector[] argOIs)`. ## How was this patch tested? unit tests Source code and binary jar: [SPARK-21101.zip](https://github.com/apache/spark/files/1123763/SPARK-21101.zip) These two source code copy from : https://github.com/apache/hive/blob/release-2.0.0/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDTFStack.java Author: Yuming Wang Closes #18527 from wangyum/SPARK-21101. --- .../spark/sql/hive/HiveSessionCatalog.scala | 11 ++++++-- .../src/test/resources/SPARK-21101-1.0.jar | Bin 0 -> 7439 bytes .../sql/hive/execution/SQLQuerySuite.scala | 26 ++++++++++++++++++ 3 files changed, 35 insertions(+), 2 deletions(-) create mode 100644 sql/hive/src/test/resources/SPARK-21101-1.0.jar diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index b256ffc27b199..1f11adbd4f62e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -94,8 +94,15 @@ private[sql] class HiveSessionCatalog( } } catch { case NonFatal(e) => - val analysisException = - new AnalysisException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e") + val noHandlerMsg = s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e" + val errorMsg = + if (classOf[GenericUDTF].isAssignableFrom(clazz)) { + s"$noHandlerMsg\nPlease make sure your function overrides " + + "`public StructObjectInspector initialize(ObjectInspector[] args)`." + } else { + noHandlerMsg + } + val analysisException = new AnalysisException(errorMsg) analysisException.setStackTrace(e.getStackTrace) throw analysisException } diff --git a/sql/hive/src/test/resources/SPARK-21101-1.0.jar b/sql/hive/src/test/resources/SPARK-21101-1.0.jar new file mode 100644 index 0000000000000000000000000000000000000000..768b2334db5c3aa8b1e4186af5fde86120d3231b GIT binary patch literal 7439 zcmb7J1z1#D*XAOFgmkw^#~@uINJ=*%LyqJK2t#+*05TE|HH4DV-Q7rsgi4oO8m=&+ zzz3gKz2EQkKKEbq%sJB?WmMT^==sBelT+RTUu57@i7{b9iv1 zQk!>DU~$cfTY0#TTLmbCb$vDaK>|5f8?#3}GD@37MO()ujkB1P7MD0)K%2~mWI+4q z@{Y2AvvS+A{HISWgnp4F`pUw6j<|P&J)E6QcuaWEzJ>L3^ca_6IXGE=5Bz5j+&?|Q zj$m^e%YSer`d>$9N3fIaKe&_qox3yC?jIo3zk=96-2N#t=6}RldRUsfxk;BBJ-@-KWJ^N#rinTf~4%X6gKbOjmDORleAB(l_w?F9>Q9~XJ@2E!6-9*Gl0>o zg@v=KaIdBFt>}1BJs8`8E}_Q2xeK9bsT^54w)1BiX(aZN9c3kyC&D@yWku->4&O&^ zx9Y)^B^Wwt*H%N72iv2hn@APtT1YzFJ5lr|h@q2U6qqs!tfFKRI|LCs)54NMP?+K^ zu`wbEma1?1ak6-@gkSwBZR>tVemCtCEIn{*|?j8tz4td~>TBI^Z{YnJb7)X8!xb6LfD8)fHC)(G&T}nuonup|g_zYt-?nkFHOP|`O zfvKM~21pUp5n{Q3u)SV+6@q`in;hHO&@4+#oIm@xyLvLfipDA3^hTX5Cdchm_K5hY z?WH8yyLWv*l6WC=%LiwMpWbOWH3mw)SFW)FO?fJUa? zRDl4`LpMl4mmk}`%&^1*T;;c5j2YI} z)X#G#nWwHePY{!4PmxK(i6CAt-Lw`0flCv%;yudCGi^^SMjw&ieURn7*2~|i`hg9@ z2b1piI~_@-C-X3J&@AVT_nuN_zFXMTr@cFy9L`a|ocY4Ajo_Fv|=$ z#qki151<*P>H+{-J83ZTf)-ZP0&5x%VL`R~HEm4pw4ylY0Z}-|;MTUt+htq>vf*ii z(I0XR2NTg%KD)cPNRq<}PO+8dwS@H5Hd9s#cA+Y74k?i2kP;PDLjjWNcpi75Nv62#&m!$q!VL|RtgL`K$BiZI;7HSewdKMiDEJ$;k)r{wfZ#+6s6d{u+OK)Ur7xt^9f^}8S74LE8dI#P=XjW0@j4$ zl4wfUIsu74G+j4Hhc{;Lx@BvVpqFScSl}1JJ=DpFBg3 z7{lL_#?Zwkljpz6c_qv*WI0e}Qj#;#kW6TkzT1|N<0JQ41hf13^VVLKZd_3=Ic3I{ z>rMuqPSuqDA{w3o6*oL#P41eA7fK4${(522x0dwWNtIlJ5sC?v*_s>2<2RE9+vN^v ztWFh>hXpsQ8~VKxsG%#y&=2Dl-f{g=Q`U~j*nS1?(W@V1lOl7DlBI&4bUtf%!%HC` zQa=}bx53BIMqAsLx^Us+iaoUp;E3T#?fwy(l%W28DiwPtmB9?-KH0DjigTmrfHou6 zvZd}^sHO<@%%A(zZlJz*8vPy5aMr*>a!VGwG+NKMM{B=dJ_ ztMX@QxDmJ+7nLu>qJCGQ--H@Z#jT3}J|5;(rIRHLf;O{_EALJ&g6(=&nxtzymsV}a zPT4_JzM~p<(MzPFXyHu8%AJ-MjAQIipu$djgXfB_kO9bY3IQ6H=jh|3Liy zn)yT3V*;y#Q67J9-}8O;;i(Nc-Wxg`+BhCalVoIelU8DP9L>Y`4{2=g za=02VYLcti-qL1_>D7y@c!fuLqL^9thPPx#Olr=1842IHCu4k#1QJ!m?m>eS)Cw2fWfI2RaRPPTyA>(kG(nou=B-+B>8VYe`h<$#uM)lD;;N_~`b~|8ujmd!_gZuLVfkdGH5}IT$ zV>woe4WrA0N3{=wrON?I<#e*k*vHTtQQ+*5U?M1Ht_L2XuJ6EdhCb#bE9j3fCADP~ z-l2F~EW_<3d)n1LHu@a9>rG>PR`6?LO;P3O^J3z1cDPQn_)o`{nM#48gIf4|M|E-U zq9{&?-@n`G_IQ^tSt)154eS_(uj2Dx0V8zbbp6?b8r$Va;##99*F+$ceTeTdgC)!PrGp^_^8=^;_pQDV%mqvI-l zg^t>#M~)>=dn)U63zPKuv&+K@R2qlwe^R49VkJFqfTTH%Jkd*k`q~qkYR%$DzB=J0 z;GHacCRCiLti63iqFp34G6ZLu=eQN#%q-NAyxAu6wp~m78kVywpI@7g{(l)?2V+oaKej1+S+9we<#yOXLqZb~ORj@l zX6V5XZwm|IIPwq8ak7WgQdvW`Fr>8=@mcQH4p}}8-vnZPjaGIM7P;hmRuivoV*nJg zXU-LxtEq%RoaYtm!m?-@iH}#KV-HmahAJx3x}Gg)AWpp->-}9`yA9jK4D$uzvEfql zr#$BwUQhTKBYifd=``cGGuX652$>*ezZ;X42JEn>;~8u=)`7eTh^yl6+#Iqgv2{#U zt+wB_nc7snm8Zcnj#V5AFS< zsy|u_Id5w%>f1@&>KGHd9wt%in``MchaULQ$LQva?UqDN^;xdvjrDa_JZs%2%%d|7 zc~ygT>q!-+q?pZ+_&aO}bZBLYK&m4o&!M#3Ec*lXJLPCP>K&Scs!8tZsa*b=0lb&t z6xs?8@EgHR{Vf>gkExCf@LP!Q)GIr6>DT#nIXiKbI#F~80p}hP5-J5!PgwBD<^q&F zLH=gd>@0z5Oj!hvo#Ql!T333LDJ6C@#$9=v56L`9paN zXy}B{tD4>lJB=eM=22*3e!QG2E&uxj^V5*5mznbLBNyZPyIY>Xc=Of`-Gv#my;n(Z zUgT1W85og106p&HW852*`0`N<>-+OQBYgBTY%a}Q@wbKnIQyJ&>?v6oZ)A;Fbu|xDX3y25n zX|J>5>1DCudaLNaw#)7rugT*?;kOhndW$NtI`)0Mp#~>8*AacS_!$)#a0>C%XeT&8 z+|$poRrMQ)B~m@{bG6o|q#$j2%D{o*o5xvULua}OXD(UHLnULNoB2jG1xD6ro#kgp3_HpoB8QGXM<03M2k2VPu7OY{btVwLN+&$ zFF2BU;|~f%T)McqJW>A3Odo92Z3(RHN2QnqdqJyAJILOwgc*ieeVcCk&w9L6*M$5<} z#PnHn)|WVsqB8j*U`hU3w(ZWe?d_LK)W53>DwJlPkL&~*`v*GxX0PpnhbL&Jw(?eD z3cAQSg*B%t zGcwZ{$F-V-{pMTpO`Q)uP^Y{)PnQ72TyHRpiodO=C8Gq)6c=$qN7bW`t#Ys8CK@?QUPv@LUYV8SM#lY8kF zK75w5*G^C)A2%t}2WR9!Lc@90bK#^BTS8rm%%_#S*B5@>CP`<%POxy*x+W~iK4h1uxFTnlXz4`g;+B)A9&BC&p=~^ zYHCZ=EwUM#f|}Jhd*4$zD?xZgWzfg<0`P60@^f{K;Buexy?SU5cC-9Zh5Ff<{AXu# z2Ya5!7T+sZKVxqF5~HIjt*rv&=i}q$`;Ny#dvIB|^zto}eT$BU=5kr_Fm{;WCH37ajo>sf7umUIbI8w#LHvls9MDxr(fF!yGR7`%Cj_;JVp5M61kbIU z>SlQFejKW2z;2MVi?X?&^C-wvjf6Kzp$X7*^W7$h6xu%XBRIzGO&zhq{^ILhDHvl&D^Mc1 zPL-zqa$g`O@wI`2Tv+wM?yO-JUyn2qK~K!}LC=(wfzbzqpf88;3~`zuy1_Unj)Fs6 zC4q7PO^}(WAq^s>*fDnq>UpwsZY!Y?XAoqDTU!NM5-g|K29>MDz(qwK=h1KzNzfO^ zt^m~%SrG8C!_R8EvCkuSeOL%xdSFqE@=8O#x;g4?n!tU{@*p%8&g3x-S3MnlUFV4l z(K+5YL`TSm7~yjsuDj%q6E-^S_wj?X2z$zdJ;td?7++%|gWb)@tu4;TjRF8ccb#`S z-Kxk$H@&K!J^;=<$DJ)An~JuVmPSRnO96_93;3_t!$S}D?mY|sPJc1Ilivb{jk-W(9_T~hN< zTp`lzjbJm=P(A}3Eh0{+-B8}LPtnlE){4^2E`Zf`3X~snre3odX9v_)K7a+}V6%Ex zq27Y`#RVgcv50O$1YG+G`oS(6A!~cx^AVD4P!Q^Wr`o6^v+2`BT2*UWF&Vu#`Mjp>jiMM9Z74wRyi>*SEKsC&Yy$ zS6`M7zLIo(b>N5ZCL3v+L|j=WMGHc#a}| zor^K;%e8Ma`0JO3S*gU@qsxh5arx){*Av9i$-&Xm$pvERtmCFJqykjtR#MYaEYRIo z<5n43(Z;`i*re=CdP9}$23}D>-=t={n5bNPU)w0Mkh?lxR6uHEN^>NDtC_vEs!gC< zSWHU)HdiQDc|NlMkgpztotSo!a<$mj!p6p|4{%xW{|%k=+OHLh%b+if5N${L>im5L z`T_fU1>?&3%CZC5E|1?~f6+63F#Nn0U5O1}?W*CH%ge9Q!>>cV5+A)Vy!gSB z^^ZJ%l`_6D{=Jy-jWauN?O!tfUuomVK>z*M_m373m!tTXVf)(qSKasDXa5oN_ZGf> zi~nfg=QrspZGJUEw2KSLOL6V{yZ9}MetZ6(!B=VY2iW#!@b{$pYmTe5`hz3mmmL3_ zWWVwKeVYBoxD&YWulW8 true, "udtf_stack2" -> true) { + sql( + s""" + |CREATE TEMPORARY FUNCTION udtf_stack1 + |AS 'org.apache.spark.sql.hive.execution.UDTFStack' + |USING JAR '${hiveContext.getHiveFile("SPARK-21101-1.0.jar").toURI}' + """.stripMargin) + val cnt = + sql("SELECT udtf_stack1(2, 'A', 10, date '2015-01-01', 'B', 20, date '2016-01-01')").count() + assert(cnt === 2) + + sql( + s""" + |CREATE TEMPORARY FUNCTION udtf_stack2 + |AS 'org.apache.spark.sql.hive.execution.UDTFStack2' + |USING JAR '${hiveContext.getHiveFile("SPARK-21101-1.0.jar").toURI}' + """.stripMargin) + val e = intercept[org.apache.spark.sql.AnalysisException] { + sql("SELECT udtf_stack2(2, 'A', 10, date '2015-01-01', 'B', 20, date '2016-01-01')") + } + assert( + e.getMessage.contains("public StructObjectInspector initialize(ObjectInspector[] args)")) + } + } + test("SPARK-21721: Clear FileSystem deleterOnExit cache if path is successfully removed") { val table = "test21721" withTable(table) { From 427359f077ad469d78c97972d021535f30a1e418 Mon Sep 17 00:00:00 2001 From: Ruben Berenguel Montoro Date: Tue, 24 Oct 2017 23:02:11 -0700 Subject: [PATCH 757/779] [SPARK-13947][SQL] The error message from using an invalid column reference is not clear ## What changes were proposed in this pull request? Rewritten error message for clarity. Added extra information in case of attribute name collision, hinting the user to double-check referencing two different tables ## How was this patch tested? No functional changes, only final message has changed. It has been tested manually against the situation proposed in the JIRA ticket. Automated tests in repository pass. This PR is original work from me and I license this work to the Spark project Author: Ruben Berenguel Montoro Author: Ruben Berenguel Montoro Author: Ruben Berenguel Closes #17100 from rberenguel/SPARK-13947-error-message. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 19 ++++++++++++--- .../analysis/AnalysisErrorSuite.scala | 23 +++++++++++++------ .../invalid-correlation.sql.out | 2 +- 3 files changed, 33 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index d9906bb6e6ede..b5e8bdd79869e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -272,10 +272,23 @@ trait CheckAnalysis extends PredicateHelper { case o if o.children.nonEmpty && o.missingInput.nonEmpty => val missingAttributes = o.missingInput.mkString(",") val input = o.inputSet.mkString(",") + val msgForMissingAttributes = s"Resolved attribute(s) $missingAttributes missing " + + s"from $input in operator ${operator.simpleString}." - failAnalysis( - s"resolved attribute(s) $missingAttributes missing from $input " + - s"in operator ${operator.simpleString}") + val resolver = plan.conf.resolver + val attrsWithSameName = o.missingInput.filter { missing => + o.inputSet.exists(input => resolver(missing.name, input.name)) + } + + val msg = if (attrsWithSameName.nonEmpty) { + val sameNames = attrsWithSameName.map(_.name).mkString(",") + s"$msgForMissingAttributes Attribute(s) with the same name appear in the " + + s"operation: $sameNames. Please check if the right attribute(s) are used." + } else { + msgForMissingAttributes + } + + failAnalysis(msg) case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => failAnalysis( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 884e113537c93..5d2f8e735e3d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -408,16 +408,25 @@ class AnalysisErrorSuite extends AnalysisTest { // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) // Since we manually construct the logical plan at here and Sum only accept // LongType, DoubleType, and DecimalType. We use LongType as the type of a. - val plan = - Aggregate( - Nil, - Alias(sum(AttributeReference("a", LongType)(exprId = ExprId(1))), "b")() :: Nil, - LocalRelation( - AttributeReference("a", LongType)(exprId = ExprId(2)))) + val attrA = AttributeReference("a", LongType)(exprId = ExprId(1)) + val otherA = AttributeReference("a", LongType)(exprId = ExprId(2)) + val attrC = AttributeReference("c", LongType)(exprId = ExprId(3)) + val aliases = Alias(sum(attrA), "b")() :: Alias(sum(attrC), "d")() :: Nil + val plan = Aggregate( + Nil, + aliases, + LocalRelation(otherA)) assert(plan.resolved) - assertAnalysisError(plan, "resolved attribute(s) a#1L missing from a#2L" :: Nil) + val resolved = s"${attrA.toString},${attrC.toString}" + + val errorMsg = s"Resolved attribute(s) $resolved missing from ${otherA.toString} " + + s"in operator !Aggregate [${aliases.mkString(", ")}]. " + + s"Attribute(s) with the same name appear in the operation: a. " + + "Please check if the right attribute(s) are used." + + assertAnalysisError(plan, errorMsg :: Nil) } test("error test for self-join") { diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out index e4b1a2dbc675c..2586f26f71c35 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out @@ -63,7 +63,7 @@ WHERE t1a IN (SELECT min(t2a) struct<> -- !query 4 output org.apache.spark.sql.AnalysisException -resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2c#x IN (list#x [t2b#x]); +Resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2c#x IN (list#x [t2b#x]).; -- !query 5 From 6c6950839da991bd41accdb8fb03fbc3b588c1e4 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 25 Oct 2017 12:51:20 +0100 Subject: [PATCH 758/779] [SPARK-22322][CORE] Update FutureAction for compatibility with Scala 2.12 Future ## What changes were proposed in this pull request? Scala 2.12's `Future` defines two new methods to implement, `transform` and `transformWith`. These can be implemented naturally in Spark's `FutureAction` extension and subclasses, but, only in terms of the new methods that don't exist in Scala 2.11. To support both at the same time, reflection is used to implement these. ## How was this patch tested? Existing tests. Author: Sean Owen Closes #19561 from srowen/SPARK-22322. --- .../scala/org/apache/spark/FutureAction.scala | 59 ++++++++++++++++++- pom.xml | 2 +- .../FlatMapGroupsWithStateSuite.scala | 3 +- .../sql/streaming/StreamingQuerySuite.scala | 2 +- .../util/FileBasedWriteAheadLog.scala | 2 +- 5 files changed, 62 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 1034fdcae8e8c..036c9a60630ea 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -89,7 +89,11 @@ trait FutureAction[T] extends Future[T] { */ override def value: Option[Try[T]] - // These two methods must be implemented in Scala 2.12, but won't be used by Spark + // These two methods must be implemented in Scala 2.12. They're implemented as a no-op here + // and then filled in with a real implementation in the two subclasses below. The no-op exists + // here so that those implementations can declare "override", necessary in 2.12, while working + // in 2.11, where the method doesn't exist in the superclass. + // After 2.11 support goes away, remove these two: def transform[S](f: (Try[T]) => Try[S])(implicit executor: ExecutionContext): Future[S] = throw new UnsupportedOperationException() @@ -113,6 +117,42 @@ trait FutureAction[T] extends Future[T] { } +/** + * Scala 2.12 defines the two new transform/transformWith methods mentioned above. Impementing + * these for 2.12 in the Spark class here requires delegating to these same methods in an + * underlying Future object. But that only exists in 2.12. But these methods are only called + * in 2.12. So define helper shims to access these methods on a Future by reflection. + */ +private[spark] object FutureAction { + + private val transformTryMethod = + try { + classOf[Future[_]].getMethod("transform", classOf[(_) => _], classOf[ExecutionContext]) + } catch { + case _: NoSuchMethodException => null // Would fail later in 2.11, but not called in 2.11 + } + + private val transformWithTryMethod = + try { + classOf[Future[_]].getMethod("transformWith", classOf[(_) => _], classOf[ExecutionContext]) + } catch { + case _: NoSuchMethodException => null // Would fail later in 2.11, but not called in 2.11 + } + + private[spark] def transform[T, S]( + future: Future[T], + f: (Try[T]) => Try[S], + executor: ExecutionContext): Future[S] = + transformTryMethod.invoke(future, f, executor).asInstanceOf[Future[S]] + + private[spark] def transformWith[T, S]( + future: Future[T], + f: (Try[T]) => Future[S], + executor: ExecutionContext): Future[S] = + transformWithTryMethod.invoke(future, f, executor).asInstanceOf[Future[S]] + +} + /** * A [[FutureAction]] holding the result of an action that triggers a single job. Examples include @@ -153,6 +193,18 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: jobWaiter.completionFuture.value.map {res => res.map(_ => resultFunc)} def jobIds: Seq[Int] = Seq(jobWaiter.jobId) + + override def transform[S](f: (Try[T]) => Try[S])(implicit e: ExecutionContext): Future[S] = + FutureAction.transform( + jobWaiter.completionFuture, + (u: Try[Unit]) => f(u.map(_ => resultFunc)), + e) + + override def transformWith[S](f: (Try[T]) => Future[S])(implicit e: ExecutionContext): Future[S] = + FutureAction.transformWith( + jobWaiter.completionFuture, + (u: Try[Unit]) => f(u.map(_ => resultFunc)), + e) } @@ -246,6 +298,11 @@ class ComplexFutureAction[T](run : JobSubmitter => Future[T]) def jobIds: Seq[Int] = subActions.flatMap(_.jobIds) + override def transform[S](f: (Try[T]) => Try[S])(implicit e: ExecutionContext): Future[S] = + FutureAction.transform(p.future, f, e) + + override def transformWith[S](f: (Try[T]) => Future[S])(implicit e: ExecutionContext): Future[S] = + FutureAction.transformWith(p.future, f, e) } diff --git a/pom.xml b/pom.xml index b9c972855204a..2d59f06811a82 100644 --- a/pom.xml +++ b/pom.xml @@ -2692,7 +2692,7 @@ scala-2.12 - 2.12.3 + 2.12.4 2.12 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index af08186aadbb0..b906393a379ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} -import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} @@ -1201,7 +1200,7 @@ object FlatMapGroupsWithStateSuite { } catch { case u: UnsupportedOperationException => return - case _ => + case _: Throwable => throw new TestFailedException("Unexpected exception when trying to get watermark", 20) } throw new TestFailedException("Could get watermark when not expected", 20) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index c53889bb8566c..cc693909270f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -744,7 +744,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(returnedValue === expectedReturnValue, "Returned value does not match expected") } } - AwaitTerminationTester.test(expectedBehavior, awaitTermFunc) + AwaitTerminationTester.test(expectedBehavior, () => awaitTermFunc()) true // If the control reached here, then everything worked as expected } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index d6e15cfdd2723..ab7c8558321c8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -139,7 +139,7 @@ private[streaming] class FileBasedWriteAheadLog( def readFile(file: String): Iterator[ByteBuffer] = { logDebug(s"Creating log reader with $file") val reader = new FileBasedWriteAheadLogReader(file, hadoopConf) - CompletionIterator[ByteBuffer, Iterator[ByteBuffer]](reader, reader.close _) + CompletionIterator[ByteBuffer, Iterator[ByteBuffer]](reader, () => reader.close()) } if (!closeFileAfterWrite) { logFilesToRead.iterator.map(readFile).flatten.asJava From 1051ebec70bf05971ddc80819d112626b1f1614f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 25 Oct 2017 16:31:58 +0100 Subject: [PATCH 759/779] [SPARK-20783][SQL][FOLLOW-UP] Create ColumnVector to abstract existing compressed column ## What changes were proposed in this pull request? Removed one unused method. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #19508 from viirya/SPARK-20783-followup. --- .../apache/spark/sql/execution/columnar/ColumnAccessor.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 445933d98e9d4..85c36b7da9498 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -63,9 +63,6 @@ private[columnar] abstract class BasicColumnAccessor[JvmType]( } protected def underlyingBuffer = buffer - - def getByteBuffer: ByteBuffer = - buffer.duplicate.order(ByteOrder.nativeOrder()) } private[columnar] class NullColumnAccessor(buffer: ByteBuffer) From 3d43a9f939764ec265de945921d1ecf2323ca230 Mon Sep 17 00:00:00 2001 From: liuxian Date: Wed, 25 Oct 2017 21:34:00 +0530 Subject: [PATCH 760/779] [SPARK-22349] In on-heap mode, when allocating memory from pool,we should fill memory with `MEMORY_DEBUG_FILL_CLEAN_VALUE` ## What changes were proposed in this pull request? In on-heap mode, when allocating memory from pool,we should fill memory with `MEMORY_DEBUG_FILL_CLEAN_VALUE` ## How was this patch tested? added unit tests Author: liuxian Closes #19572 from 10110346/MEMORY_DEBUG. --- .../spark/unsafe/memory/HeapMemoryAllocator.java | 3 +++ .../org/apache/spark/unsafe/PlatformUtilSuite.java | 13 ++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index 355748238540b..cc9cc429643ad 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -56,6 +56,9 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { final MemoryBlock memory = blockReference.get(); if (memory != null) { assert (memory.size() == size); + if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { + memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); + } return memory; } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 4ae49d82efa29..4b141339ec816 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -66,10 +66,21 @@ public void overlappingCopyMemory() { public void memoryDebugFillEnabledInTest() { Assert.assertTrue(MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED); MemoryBlock onheap = MemoryAllocator.HEAP.allocate(1); - MemoryBlock offheap = MemoryAllocator.UNSAFE.allocate(1); Assert.assertEquals( Platform.getByte(onheap.getBaseObject(), onheap.getBaseOffset()), MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); + + MemoryBlock onheap1 = MemoryAllocator.HEAP.allocate(1024 * 1024); + MemoryAllocator.HEAP.free(onheap1); + Assert.assertEquals( + Platform.getByte(onheap1.getBaseObject(), onheap1.getBaseOffset()), + MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); + MemoryBlock onheap2 = MemoryAllocator.HEAP.allocate(1024 * 1024); + Assert.assertEquals( + Platform.getByte(onheap2.getBaseObject(), onheap2.getBaseOffset()), + MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); + + MemoryBlock offheap = MemoryAllocator.UNSAFE.allocate(1); Assert.assertEquals( Platform.getByte(offheap.getBaseObject(), offheap.getBaseOffset()), MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); From 6ea8a56ca26a7e02e6574f5f763bb91059119a80 Mon Sep 17 00:00:00 2001 From: Andrea zito Date: Wed, 25 Oct 2017 10:10:24 -0700 Subject: [PATCH 761/779] [SPARK-21991][LAUNCHER] Fix race condition in LauncherServer#acceptConnections ## What changes were proposed in this pull request? This patch changes the order in which _acceptConnections_ starts the client thread and schedules the client timeout action ensuring that the latter has been scheduled before the former get a chance to cancel it. ## How was this patch tested? Due to the non-deterministic nature of the patch I wasn't able to add a new test for this issue. Author: Andrea zito Closes #19217 from nivox/SPARK-21991. --- .../apache/spark/launcher/LauncherServer.java | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index 865d4926da6a9..454bc7a7f924d 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -232,20 +232,20 @@ public void run() { }; ServerConnection clientConnection = new ServerConnection(client, timeout); Thread clientThread = factory.newThread(clientConnection); - synchronized (timeout) { - clientThread.start(); - synchronized (clients) { - clients.add(clientConnection); - } - long timeoutMs = getConnectionTimeout(); - // 0 is used for testing to avoid issues with clock resolution / thread scheduling, - // and force an immediate timeout. - if (timeoutMs > 0) { - timeoutTimer.schedule(timeout, getConnectionTimeout()); - } else { - timeout.run(); - } + synchronized (clients) { + clients.add(clientConnection); + } + + long timeoutMs = getConnectionTimeout(); + // 0 is used for testing to avoid issues with clock resolution / thread scheduling, + // and force an immediate timeout. + if (timeoutMs > 0) { + timeoutTimer.schedule(timeout, timeoutMs); + } else { + timeout.run(); } + + clientThread.start(); } } catch (IOException ioe) { if (running) { From d212ef14be7c2864cc529e48a02a47584e46f7a5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 25 Oct 2017 13:53:01 -0700 Subject: [PATCH 762/779] [SPARK-22341][YARN] Impersonate correct user when preparing resources. The bug was introduced in SPARK-22290, which changed how the app's user is impersonated in the AM. The changed missed an initialization function that needs to be run as the app owner (who has the right credentials to read from HDFS). Author: Marcelo Vanzin Closes #19566 from vanzin/SPARK-22341. --- .../spark/deploy/yarn/ApplicationMaster.scala | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index f6167235f89e4..244d912b9f3aa 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -97,9 +97,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } - private val client = ugi.doAs(new PrivilegedExceptionAction[YarnRMClient]() { - def run: YarnRMClient = new YarnRMClient() - }) + private val client = doAsUser { new YarnRMClient() } // Default to twice the number of executors (twice the maximum number of executors if dynamic // allocation is enabled), with a minimum of 3. @@ -178,7 +176,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // Load the list of localized files set by the client. This is used when launching executors, // and is loaded here so that these configs don't pollute the Web UI's environment page in // cluster mode. - private val localResources = { + private val localResources = doAsUser { logInfo("Preparing Local resources") val resources = HashMap[String, LocalResource]() @@ -240,9 +238,9 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } final def run(): Int = { - ugi.doAs(new PrivilegedExceptionAction[Unit]() { - def run: Unit = runImpl() - }) + doAsUser { + runImpl() + } exitCode } @@ -790,6 +788,12 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } + private def doAsUser[T](fn: => T): T = { + ugi.doAs(new PrivilegedExceptionAction[T]() { + override def run: T = fn + }) + } + } object ApplicationMaster extends Logging { From b377ef133cdc38d49b460b2cc6ece0b5892804cc Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 25 Oct 2017 22:15:44 +0100 Subject: [PATCH 763/779] [SPARK-22227][CORE] DiskBlockManager.getAllBlocks now tolerates temp files ## What changes were proposed in this pull request? Prior to this commit getAllBlocks implicitly assumed that the directories managed by the DiskBlockManager contain only the files corresponding to valid block IDs. In reality, this assumption was violated during shuffle, which produces temporary files in the same directory as the resulting blocks. As a result, calls to getAllBlocks during shuffle were unreliable. The fix could be made more efficient, but this is probably good enough. ## How was this patch tested? `DiskBlockManagerSuite` Author: Sergei Lebedev Closes #19458 from superbobry/block-id-option. --- .../scala/org/apache/spark/storage/BlockId.scala | 16 +++++++++++++--- .../apache/spark/storage/DiskBlockManager.scala | 11 ++++++++++- .../org/apache/spark/storage/BlockIdSuite.scala | 9 +++------ .../spark/storage/DiskBlockManagerSuite.scala | 7 +++++++ 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index a441baed2800e..7ac2c71c18eb3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.util.UUID +import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi /** @@ -95,6 +96,10 @@ private[spark] case class TestBlockId(id: String) extends BlockId { override def name: String = "test_" + id } +@DeveloperApi +class UnrecognizedBlockId(name: String) + extends SparkException(s"Failed to parse $name into a block ID") + @DeveloperApi object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r @@ -104,10 +109,11 @@ object BlockId { val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r + val TEMP_LOCAL = "temp_local_([-A-Fa-f0-9]+)".r + val TEMP_SHUFFLE = "temp_shuffle_([-A-Fa-f0-9]+)".r val TEST = "test_(.*)".r - /** Converts a BlockId "name" String back into a BlockId. */ - def apply(id: String): BlockId = id match { + def apply(name: String): BlockId = name match { case RDD(rddId, splitIndex) => RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => @@ -122,9 +128,13 @@ object BlockId { TaskResultBlockId(taskId.toLong) case STREAM(streamId, uniqueId) => StreamBlockId(streamId.toInt, uniqueId.toLong) + case TEMP_LOCAL(uuid) => + TempLocalBlockId(UUID.fromString(uuid)) + case TEMP_SHUFFLE(uuid) => + TempShuffleBlockId(UUID.fromString(uuid)) case TEST(value) => TestBlockId(value) case _ => - throw new IllegalStateException("Unrecognized BlockId: " + id) + throw new UnrecognizedBlockId(name) } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 3d43e3c367aac..a69bcc9259995 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -100,7 +100,16 @@ private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolea /** List all the blocks currently stored on disk by the disk manager. */ def getAllBlocks(): Seq[BlockId] = { - getAllFiles().map(f => BlockId(f.getName)) + getAllFiles().flatMap { f => + try { + Some(BlockId(f.getName)) + } catch { + case _: UnrecognizedBlockId => + // Skip files which do not correspond to blocks, for example temporary + // files created by [[SortShuffleWriter]]. + None + } + } } /** Produces a unique block id and File suitable for storing local intermediate results. */ diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala index f0c521b00b583..ff4755833a916 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -35,13 +35,8 @@ class BlockIdSuite extends SparkFunSuite { } test("test-bad-deserialization") { - try { - // Try to deserialize an invalid block id. + intercept[UnrecognizedBlockId] { BlockId("myblock") - fail() - } catch { - case e: IllegalStateException => // OK - case _: Throwable => fail() } } @@ -139,6 +134,7 @@ class BlockIdSuite extends SparkFunSuite { assert(id.id.getMostSignificantBits() === 5) assert(id.id.getLeastSignificantBits() === 2) assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) } test("temp shuffle") { @@ -151,6 +147,7 @@ class BlockIdSuite extends SparkFunSuite { assert(id.id.getMostSignificantBits() === 1) assert(id.id.getLeastSignificantBits() === 2) assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) } test("test") { diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 7859b0bba2b48..0c4f3c48ef802 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import java.io.{File, FileWriter} +import java.util.UUID import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} @@ -79,6 +80,12 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B assert(diskBlockManager.getAllBlocks.toSet === ids.toSet) } + test("SPARK-22227: non-block files are skipped") { + val file = diskBlockManager.getFile("unmanaged_file") + writeToFile(file, 10) + assert(diskBlockManager.getAllBlocks().isEmpty) + } + def writeToFile(file: File, numBytes: Int) { val writer = new FileWriter(file, true) for (i <- 0 until numBytes) writer.write(i) From 841f1d776f420424c20d99cf7110d06c73f9ca20 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 25 Oct 2017 14:31:36 -0700 Subject: [PATCH 764/779] [SPARK-22332][ML][TEST] Fix NaiveBayes unit test occasionly fail (cause by test dataset not deterministic) ## What changes were proposed in this pull request? Fix NaiveBayes unit test occasionly fail: Set seed for `BrzMultinomial.sample`, make `generateNaiveBayesInput` output deterministic dataset. (If we do not set seed, the generated dataset will be random, and the model will be possible to exceed the tolerance in the test, which trigger this failure) ## How was this patch tested? Manually run tests multiple times and check each time output models contains the same values. Author: WeichenXu Closes #19558 from WeichenXu123/fix_nb_test_seed. --- .../org/apache/spark/ml/classification/NaiveBayesSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 9730dd68a3b27..0d3adf993383f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import scala.util.Random import breeze.linalg.{DenseVector => BDV, Vector => BV} -import breeze.stats.distributions.{Multinomial => BrzMultinomial} +import breeze.stats.distributions.{Multinomial => BrzMultinomial, RandBasis => BrzRandBasis} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.classification.NaiveBayes.{Bernoulli, Multinomial} @@ -335,6 +335,7 @@ object NaiveBayesSuite { val _pi = pi.map(math.exp) val _theta = theta.map(row => row.map(math.exp)) + implicit val rngForBrzMultinomial = BrzRandBasis.withSeed(seed) for (i <- 0 until nPoints) yield { val y = calcLabel(rnd.nextDouble(), _pi) val xi = modelType match { From 5433be44caecaeef45ed1fdae10b223c698a9d14 Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Wed, 25 Oct 2017 14:41:02 -0700 Subject: [PATCH 765/779] [SPARK-21991][LAUNCHER][FOLLOWUP] Fix java lint ## What changes were proposed in this pull request? Fix java lint ## How was this patch tested? Run `./dev/lint-java` Author: Andrew Ash Closes #19574 from ash211/aash/fix-java-lint. --- .../main/java/org/apache/spark/launcher/LauncherServer.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index 454bc7a7f924d..4353e3f263c51 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -235,7 +235,7 @@ public void run() { synchronized (clients) { clients.add(clientConnection); } - + long timeoutMs = getConnectionTimeout(); // 0 is used for testing to avoid issues with clock resolution / thread scheduling, // and force an immediate timeout. @@ -244,7 +244,7 @@ public void run() { } else { timeout.run(); } - + clientThread.start(); } } catch (IOException ioe) { From 592cfeab9caeff955d115a1ca5014ede7d402907 Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Thu, 26 Oct 2017 00:29:49 -0700 Subject: [PATCH 766/779] [SPARK-22308] Support alternative unit testing styles in external applications ## What changes were proposed in this pull request? Support unit tests of external code (i.e., applications that use spark) using scalatest that don't want to use FunSuite. SharedSparkContext already supports this, but SharedSQLContext does not. I've introduced SharedSparkSession as a parent to SharedSQLContext, written in a way that it does support all scalatest styles. ## How was this patch tested? There are three new unit test suites added that just test using FunSpec, FlatSpec, and WordSpec. Author: Nathan Kronenfeld Closes #19529 from nkronenfeld/alternative-style-tests-2. --- .../org/apache/spark/SharedSparkContext.scala | 17 +- .../spark/sql/catalyst/plans/PlanTest.scala | 10 +- .../spark/sql/test/GenericFlatSpecSuite.scala | 45 +++++ .../spark/sql/test/GenericFunSpecSuite.scala | 47 +++++ .../spark/sql/test/GenericWordSpecSuite.scala | 51 ++++++ .../apache/spark/sql/test/SQLTestUtils.scala | 173 ++++++++++-------- .../spark/sql/test/SharedSQLContext.scala | 84 +-------- .../spark/sql/test/SharedSparkSession.scala | 119 ++++++++++++ 8 files changed, 381 insertions(+), 165 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 6aedcb1271ff6..1aa1c421d792e 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -29,10 +29,23 @@ trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { sel var conf = new SparkConf(false) + /** + * Initialize the [[SparkContext]]. Generally, this is just called from beforeAll; however, in + * test using styles other than FunSuite, there is often code that relies on the session between + * test group constructs and the actual tests, which may need this session. It is purely a + * semantic difference, but semantically, it makes more sense to call 'initializeContext' between + * a 'describe' and an 'it' call than it does to call 'beforeAll'. + */ + protected def initializeContext(): Unit = { + if (null == _sc) { + _sc = new SparkContext( + "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + } + } + override def beforeAll() { super.beforeAll() - _sc = new SparkContext( - "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + initializeContext() } override def afterAll() { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 10bdfafd6f933..82c5307d54360 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans +import org.scalatest.Suite + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer @@ -29,7 +31,13 @@ import org.apache.spark.sql.internal.SQLConf /** * Provides helper methods for comparing plans. */ -trait PlanTest extends SparkFunSuite with PredicateHelper { +trait PlanTest extends SparkFunSuite with PlanTestBase + +/** + * Provides helper methods for comparing plans, but without the overhead of + * mandating a FunSuite. + */ +trait PlanTestBase extends PredicateHelper { self: Suite => // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules protected def conf = SQLConf.get diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala new file mode 100644 index 0000000000000..6179585a0d39a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.scalatest.FlatSpec + +/** + * The purpose of this suite is to make sure that generic FlatSpec-based scala + * tests work with a shared spark session + */ +class GenericFlatSpecSuite extends FlatSpec with SharedSparkSession { + import testImplicits._ + initializeSession() + val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + "A Simple Dataset" should "have the specified number of elements" in { + assert(8 === ds.count) + } + it should "have the specified number of unique elements" in { + assert(8 === ds.distinct.count) + } + it should "have the specified number of elements in each column" in { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + it should "have the correct number of distinct elements in each column" in { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala new file mode 100644 index 0000000000000..15139ee8b3047 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.scalatest.FunSpec + +/** + * The purpose of this suite is to make sure that generic FunSpec-based scala + * tests work with a shared spark session + */ +class GenericFunSpecSuite extends FunSpec with SharedSparkSession { + import testImplicits._ + initializeSession() + val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + describe("Simple Dataset") { + it("should have the specified number of elements") { + assert(8 === ds.count) + } + it("should have the specified number of unique elements") { + assert(8 === ds.distinct.count) + } + it("should have the specified number of elements in each column") { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + it("should have the correct number of distinct elements in each column") { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala new file mode 100644 index 0000000000000..b6548bf95fec8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.scalatest.WordSpec + +/** + * The purpose of this suite is to make sure that generic WordSpec-based scala + * tests work with a shared spark session + */ +class GenericWordSpecSuite extends WordSpec with SharedSparkSession { + import testImplicits._ + initializeSession() + val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + "A Simple Dataset" when { + "looked at as complete rows" should { + "have the specified number of elements" in { + assert(8 === ds.count) + } + "have the specified number of unique elements" in { + assert(8 === ds.distinct.count) + } + } + "refined to specific columns" should { + "have the specified number of elements in each column" in { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + "have the correct number of distinct elements in each column" in { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index a14a1441a4313..b4248b74f50ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -27,7 +27,7 @@ import scala.language.implicitConversions import scala.util.control.NonFatal import org.apache.hadoop.fs.Path -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, Suite} import org.scalatest.concurrent.Eventually import org.apache.spark.SparkFunSuite @@ -36,14 +36,17 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.FilterExec import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.{UninterruptibleThread, Utils} +import org.apache.spark.util.UninterruptibleThread +import org.apache.spark.util.Utils /** - * Helper trait that should be extended by all SQL test suites. + * Helper trait that should be extended by all SQL test suites within the Spark + * code base. * * This allows subclasses to plugin a custom `SQLContext`. It comes with test data * prepared in advance as well as all implicit conversions used extensively by dataframes. @@ -52,17 +55,99 @@ import org.apache.spark.util.{UninterruptibleThread, Utils} * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. */ -private[sql] trait SQLTestUtils - extends SparkFunSuite with Eventually +private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with PlanTest { + // Whether to materialize all test data before the first test is run + private var loadTestDataBeforeTests = false + + protected override def beforeAll(): Unit = { + super.beforeAll() + if (loadTestDataBeforeTests) { + loadTestData() + } + } + + /** + * Materialize the test data immediately after the `SQLContext` is set up. + * This is necessary if the data is accessed by name but not through direct reference. + */ + protected def setupTestData(): Unit = { + loadTestDataBeforeTests = true + } + + /** + * Disable stdout and stderr when running the test. To not output the logs to the console, + * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of + * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if + * we change System.out and System.err. + */ + protected def testQuietly(name: String)(f: => Unit): Unit = { + test(name) { + quietly { + f + } + } + } + + /** + * Run a test on a separate `UninterruptibleThread`. + */ + protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) + (body: => Unit): Unit = { + val timeoutMillis = 10000 + @transient var ex: Throwable = null + + def runOnThread(): Unit = { + val thread = new UninterruptibleThread(s"Testing thread for test $name") { + override def run(): Unit = { + try { + body + } catch { + case NonFatal(e) => + ex = e + } + } + } + thread.setDaemon(true) + thread.start() + thread.join(timeoutMillis) + if (thread.isAlive) { + thread.interrupt() + // If this interrupt does not work, then this thread is most likely running something that + // is not interruptible. There is not much point to wait for the thread to termniate, and + // we rather let the JVM terminate the thread on exit. + fail( + s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + + s" $timeoutMillis ms") + } else if (ex != null) { + throw ex + } + } + + if (quietly) { + testQuietly(name) { runOnThread() } + } else { + test(name) { runOnThread() } + } + } +} + +/** + * Helper trait that can be extended by all external SQL test suites. + * + * This allows subclasses to plugin a custom `SQLContext`. + * To use implicit methods, import `testImplicits._` instead of through the `SQLContext`. + * + * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is + * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. + */ +private[sql] trait SQLTestUtilsBase + extends Eventually with BeforeAndAfterAll with SQLTestData - with PlanTest { self => + with PlanTestBase { self: Suite => protected def sparkContext = spark.sparkContext - // Whether to materialize all test data before the first test is run - private var loadTestDataBeforeTests = false - // Shorthand for running a query using our SQLContext protected lazy val sql = spark.sql _ @@ -77,21 +162,6 @@ private[sql] trait SQLTestUtils protected override def _sqlContext: SQLContext = self.spark.sqlContext } - /** - * Materialize the test data immediately after the `SQLContext` is set up. - * This is necessary if the data is accessed by name but not through direct reference. - */ - protected def setupTestData(): Unit = { - loadTestDataBeforeTests = true - } - - protected override def beforeAll(): Unit = { - super.beforeAll() - if (loadTestDataBeforeTests) { - loadTestData() - } - } - protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { SparkSession.setActiveSession(spark) super.withSQLConf(pairs: _*)(f) @@ -297,61 +367,6 @@ private[sql] trait SQLTestUtils Dataset.ofRows(spark, plan) } - /** - * Disable stdout and stderr when running the test. To not output the logs to the console, - * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of - * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if - * we change System.out and System.err. - */ - protected def testQuietly(name: String)(f: => Unit): Unit = { - test(name) { - quietly { - f - } - } - } - - /** - * Run a test on a separate `UninterruptibleThread`. - */ - protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) - (body: => Unit): Unit = { - val timeoutMillis = 10000 - @transient var ex: Throwable = null - - def runOnThread(): Unit = { - val thread = new UninterruptibleThread(s"Testing thread for test $name") { - override def run(): Unit = { - try { - body - } catch { - case NonFatal(e) => - ex = e - } - } - } - thread.setDaemon(true) - thread.start() - thread.join(timeoutMillis) - if (thread.isAlive) { - thread.interrupt() - // If this interrupt does not work, then this thread is most likely running something that - // is not interruptible. There is not much point to wait for the thread to termniate, and - // we rather let the JVM terminate the thread on exit. - fail( - s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + - s" $timeoutMillis ms") - } else if (ex != null) { - throw ex - } - } - - if (quietly) { - testQuietly(name) { runOnThread() } - } else { - test(name) { runOnThread() } - } - } /** * This method is used to make the given path qualified, when a path diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index cd8d0708d8a32..4d578e21f5494 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,86 +17,4 @@ package org.apache.spark.sql.test -import scala.concurrent.duration._ - -import org.scalatest.BeforeAndAfterEach -import org.scalatest.concurrent.Eventually - -import org.apache.spark.{DebugFilesystem, SparkConf} -import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.internal.SQLConf - -/** - * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. - */ -trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually { - - protected def sparkConf = { - new SparkConf() - .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) - .set("spark.unsafe.exceptionOnMemoryLeak", "true") - .set(SQLConf.CODEGEN_FALLBACK.key, "false") - } - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - * - * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local - * mode with the default test configurations. - */ - private var _spark: TestSparkSession = null - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - */ - protected implicit def spark: SparkSession = _spark - - /** - * The [[TestSQLContext]] to use for all tests in this suite. - */ - protected implicit def sqlContext: SQLContext = _spark.sqlContext - - protected def createSparkSession: TestSparkSession = { - new TestSparkSession(sparkConf) - } - - /** - * Initialize the [[TestSparkSession]]. - */ - protected override def beforeAll(): Unit = { - SparkSession.sqlListener.set(null) - if (_spark == null) { - _spark = createSparkSession - } - // Ensure we have initialized the context before calling parent code - super.beforeAll() - } - - /** - * Stop the underlying [[org.apache.spark.SparkContext]], if any. - */ - protected override def afterAll(): Unit = { - super.afterAll() - if (_spark != null) { - _spark.sessionState.catalog.reset() - _spark.stop() - _spark = null - } - } - - protected override def beforeEach(): Unit = { - super.beforeEach() - DebugFilesystem.clearOpenStreams() - } - - protected override def afterEach(): Unit = { - super.afterEach() - // Clear all persistent datasets after each test - spark.sharedState.cacheManager.clearCache() - // files can be closed from other threads, so wait a bit - // normally this doesn't take more than 1s - eventually(timeout(10.seconds)) { - DebugFilesystem.assertNoOpenStreams() - } - } -} +trait SharedSQLContext extends SQLTestUtils with SharedSparkSession diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala new file mode 100644 index 0000000000000..e0568a3c5c99f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import scala.concurrent.duration._ + +import org.scalatest.{BeforeAndAfterEach, Suite} +import org.scalatest.concurrent.Eventually + +import org.apache.spark.{DebugFilesystem, SparkConf} +import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.internal.SQLConf + +/** + * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. + */ +trait SharedSparkSession + extends SQLTestUtilsBase + with BeforeAndAfterEach + with Eventually { self: Suite => + + protected def sparkConf = { + new SparkConf() + .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + .set("spark.unsafe.exceptionOnMemoryLeak", "true") + .set(SQLConf.CODEGEN_FALLBACK.key, "false") + } + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. + */ + private var _spark: TestSparkSession = null + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + */ + protected implicit def spark: SparkSession = _spark + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + */ + protected implicit def sqlContext: SQLContext = _spark.sqlContext + + protected def createSparkSession: TestSparkSession = { + new TestSparkSession(sparkConf) + } + + /** + * Initialize the [[TestSparkSession]]. Generally, this is just called from + * beforeAll; however, in test using styles other than FunSuite, there is + * often code that relies on the session between test group constructs and + * the actual tests, which may need this session. It is purely a semantic + * difference, but semantically, it makes more sense to call + * 'initializeSession' between a 'describe' and an 'it' call than it does to + * call 'beforeAll'. + */ + protected def initializeSession(): Unit = { + SparkSession.sqlListener.set(null) + if (_spark == null) { + _spark = createSparkSession + } + } + + /** + * Make sure the [[TestSparkSession]] is initialized before any tests are run. + */ + protected override def beforeAll(): Unit = { + initializeSession() + + // Ensure we have initialized the context before calling parent code + super.beforeAll() + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ + protected override def afterAll(): Unit = { + super.afterAll() + if (_spark != null) { + _spark.sessionState.catalog.reset() + _spark.stop() + _spark = null + } + } + + protected override def beforeEach(): Unit = { + super.beforeEach() + DebugFilesystem.clearOpenStreams() + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Clear all persistent datasets after each test + spark.sharedState.cacheManager.clearCache() + // files can be closed from other threads, so wait a bit + // normally this doesn't take more than 1s + eventually(timeout(10.seconds)) { + DebugFilesystem.assertNoOpenStreams() + } + } +} From 3073344a2551fb198d63f2114a519ab97904cb55 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 26 Oct 2017 15:50:27 +0800 Subject: [PATCH 767/779] [SPARK-21840][CORE] Add trait that allows conf to be directly set in application. Currently SparkSubmit uses system properties to propagate configuration to applications. This makes it hard to implement features such as SPARK-11035, which would allow multiple applications to be started in the same JVM. The current code would cause the config data from multiple apps to get mixed up. This change introduces a new trait, currently internal to Spark, that allows the app configuration to be passed directly to the application, without having to use system properties. The current "call main() method" behavior is maintained as an implementation of this new trait. This will be useful to allow multiple cluster mode apps to be submitted from the same JVM. As part of this, SparkSubmit was modified to collect all configuration directly into a SparkConf instance. Most of the changes are to tests so they use SparkConf instead of an opaque map. Tested with existing and added unit tests. Author: Marcelo Vanzin Closes #19519 from vanzin/SPARK-21840. --- .../spark/deploy/SparkApplication.scala | 55 +++++ .../org/apache/spark/deploy/SparkSubmit.scala | 160 +++++++------ .../spark/deploy/SparkSubmitSuite.scala | 213 ++++++++++-------- .../rest/StandaloneRestSubmitSuite.scala | 4 +- 4 files changed, 257 insertions(+), 175 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/deploy/SparkApplication.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkApplication.scala b/core/src/main/scala/org/apache/spark/deploy/SparkApplication.scala new file mode 100644 index 0000000000000..118b4605675b0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/SparkApplication.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.lang.reflect.Modifier + +import org.apache.spark.SparkConf + +/** + * Entry point for a Spark application. Implementations must provide a no-argument constructor. + */ +private[spark] trait SparkApplication { + + def start(args: Array[String], conf: SparkConf): Unit + +} + +/** + * Implementation of SparkApplication that wraps a standard Java class with a "main" method. + * + * Configuration is propagated to the application via system properties, so running multiple + * of these in the same JVM may lead to undefined behavior due to configuration leaks. + */ +private[deploy] class JavaMainApplication(klass: Class[_]) extends SparkApplication { + + override def start(args: Array[String], conf: SparkConf): Unit = { + val mainMethod = klass.getMethod("main", new Array[String](0).getClass) + if (!Modifier.isStatic(mainMethod.getModifiers)) { + throw new IllegalStateException("The main method in the given main class must be static") + } + + val sysProps = conf.getAll.toMap + sysProps.foreach { case (k, v) => + sys.props(k) = v + } + + mainMethod.invoke(null, args) + } + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index b7e6d0ea021a4..73b956ef3e470 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -158,7 +158,7 @@ object SparkSubmit extends CommandLineUtils with Logging { */ @tailrec private def submit(args: SparkSubmitArguments, uninitLog: Boolean): Unit = { - val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args) + val (childArgs, childClasspath, sparkConf, childMainClass) = prepareSubmitEnvironment(args) def doRunMain(): Unit = { if (args.proxyUser != null) { @@ -167,7 +167,7 @@ object SparkSubmit extends CommandLineUtils with Logging { try { proxyUser.doAs(new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { - runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose) + runMain(childArgs, childClasspath, sparkConf, childMainClass, args.verbose) } }) } catch { @@ -185,7 +185,7 @@ object SparkSubmit extends CommandLineUtils with Logging { } } } else { - runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose) + runMain(childArgs, childClasspath, sparkConf, childMainClass, args.verbose) } } @@ -235,11 +235,11 @@ object SparkSubmit extends CommandLineUtils with Logging { private[deploy] def prepareSubmitEnvironment( args: SparkSubmitArguments, conf: Option[HadoopConfiguration] = None) - : (Seq[String], Seq[String], Map[String, String], String) = { + : (Seq[String], Seq[String], SparkConf, String) = { // Return values val childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() - val sysProps = new HashMap[String, String]() + val sparkConf = new SparkConf() var childMainClass = "" // Set the cluster manager @@ -337,7 +337,6 @@ object SparkSubmit extends CommandLineUtils with Logging { } } - val sparkConf = new SparkConf(false) args.sparkProperties.foreach { case (k, v) => sparkConf.set(k, v) } val hadoopConf = conf.getOrElse(SparkHadoopUtil.newConfiguration(sparkConf)) val targetDir = Utils.createTempDir() @@ -351,8 +350,8 @@ object SparkSubmit extends CommandLineUtils with Logging { // for later use; e.g. in spark sql, the isolated class loader used to talk // to HiveMetastore will use these settings. They will be set as Java system // properties and then loaded by SparkConf - sysProps.put("spark.yarn.keytab", args.keytab) - sysProps.put("spark.yarn.principal", args.principal) + sparkConf.set(KEYTAB, args.keytab) + sparkConf.set(PRINCIPAL, args.principal) UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) } } @@ -364,23 +363,24 @@ object SparkSubmit extends CommandLineUtils with Logging { args.pyFiles = Option(args.pyFiles).map(resolveGlobPaths(_, hadoopConf)).orNull args.archives = Option(args.archives).map(resolveGlobPaths(_, hadoopConf)).orNull + // This security manager will not need an auth secret, but set a dummy value in case + // spark.authenticate is enabled, otherwise an exception is thrown. + lazy val downloadConf = sparkConf.clone().set(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused") + lazy val secMgr = new SecurityManager(downloadConf) + // In client mode, download remote files. var localPrimaryResource: String = null var localJars: String = null var localPyFiles: String = null if (deployMode == CLIENT) { - // This security manager will not need an auth secret, but set a dummy value in case - // spark.authenticate is enabled, otherwise an exception is thrown. - sparkConf.set(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused") - val secMgr = new SecurityManager(sparkConf) localPrimaryResource = Option(args.primaryResource).map { - downloadFile(_, targetDir, sparkConf, hadoopConf, secMgr) + downloadFile(_, targetDir, downloadConf, hadoopConf, secMgr) }.orNull localJars = Option(args.jars).map { - downloadFileList(_, targetDir, sparkConf, hadoopConf, secMgr) + downloadFileList(_, targetDir, downloadConf, hadoopConf, secMgr) }.orNull localPyFiles = Option(args.pyFiles).map { - downloadFileList(_, targetDir, sparkConf, hadoopConf, secMgr) + downloadFileList(_, targetDir, downloadConf, hadoopConf, secMgr) }.orNull } @@ -409,7 +409,7 @@ object SparkSubmit extends CommandLineUtils with Logging { if (file.exists()) { file.toURI.toString } else { - downloadFile(resource, targetDir, sparkConf, hadoopConf, secMgr) + downloadFile(resource, targetDir, downloadConf, hadoopConf, secMgr) } case _ => uri.toString } @@ -449,7 +449,7 @@ object SparkSubmit extends CommandLineUtils with Logging { args.files = mergeFileLists(args.files, args.pyFiles) } if (localPyFiles != null) { - sysProps("spark.submit.pyFiles") = localPyFiles + sparkConf.set("spark.submit.pyFiles", localPyFiles) } } @@ -515,69 +515,69 @@ object SparkSubmit extends CommandLineUtils with Logging { } // Special flag to avoid deprecation warnings at the client - sysProps("SPARK_SUBMIT") = "true" + sys.props("SPARK_SUBMIT") = "true" // A list of rules to map each argument to system properties or command-line options in // each deploy mode; we iterate through these below val options = List[OptionAssigner]( // All cluster managers - OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"), + OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.master"), OptionAssigner(args.deployMode, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, - sysProp = "spark.submit.deployMode"), - OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"), - OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"), + confKey = "spark.submit.deployMode"), + OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.app.name"), + OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, confKey = "spark.jars.ivy"), OptionAssigner(args.driverMemory, ALL_CLUSTER_MGRS, CLIENT, - sysProp = "spark.driver.memory"), + confKey = "spark.driver.memory"), OptionAssigner(args.driverExtraClassPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, - sysProp = "spark.driver.extraClassPath"), + confKey = "spark.driver.extraClassPath"), OptionAssigner(args.driverExtraJavaOptions, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, - sysProp = "spark.driver.extraJavaOptions"), + confKey = "spark.driver.extraJavaOptions"), OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, - sysProp = "spark.driver.extraLibraryPath"), + confKey = "spark.driver.extraLibraryPath"), // Propagate attributes for dependency resolution at the driver side - OptionAssigner(args.packages, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars.packages"), + OptionAssigner(args.packages, STANDALONE | MESOS, CLUSTER, confKey = "spark.jars.packages"), OptionAssigner(args.repositories, STANDALONE | MESOS, CLUSTER, - sysProp = "spark.jars.repositories"), - OptionAssigner(args.ivyRepoPath, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars.ivy"), + confKey = "spark.jars.repositories"), + OptionAssigner(args.ivyRepoPath, STANDALONE | MESOS, CLUSTER, confKey = "spark.jars.ivy"), OptionAssigner(args.packagesExclusions, STANDALONE | MESOS, - CLUSTER, sysProp = "spark.jars.excludes"), + CLUSTER, confKey = "spark.jars.excludes"), // Yarn only - OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.queue"), + OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, ALL_DEPLOY_MODES, - sysProp = "spark.executor.instances"), - OptionAssigner(args.pyFiles, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.pyFiles"), - OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.jars"), - OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.files"), - OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.archives"), - OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.principal"), - OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.keytab"), + confKey = "spark.executor.instances"), + OptionAssigner(args.pyFiles, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.pyFiles"), + OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.jars"), + OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.files"), + OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.archives"), + OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.principal"), + OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.keytab"), // Other options OptionAssigner(args.executorCores, STANDALONE | YARN, ALL_DEPLOY_MODES, - sysProp = "spark.executor.cores"), + confKey = "spark.executor.cores"), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, - sysProp = "spark.executor.memory"), + confKey = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, - sysProp = "spark.cores.max"), + confKey = "spark.cores.max"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, - sysProp = "spark.files"), - OptionAssigner(args.jars, LOCAL, CLIENT, sysProp = "spark.jars"), - OptionAssigner(args.jars, STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.jars"), + confKey = "spark.files"), + OptionAssigner(args.jars, LOCAL, CLIENT, confKey = "spark.jars"), + OptionAssigner(args.jars, STANDALONE | MESOS, ALL_DEPLOY_MODES, confKey = "spark.jars"), OptionAssigner(args.driverMemory, STANDALONE | MESOS | YARN, CLUSTER, - sysProp = "spark.driver.memory"), + confKey = "spark.driver.memory"), OptionAssigner(args.driverCores, STANDALONE | MESOS | YARN, CLUSTER, - sysProp = "spark.driver.cores"), + confKey = "spark.driver.cores"), OptionAssigner(args.supervise.toString, STANDALONE | MESOS, CLUSTER, - sysProp = "spark.driver.supervise"), - OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy"), + confKey = "spark.driver.supervise"), + OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, confKey = "spark.jars.ivy"), // An internal option used only for spark-shell to add user jars to repl's classloader, // previously it uses "spark.jars" or "spark.yarn.dist.jars" which now may be pointed to // remote jars, so adding a new option to only specify local jars for spark-shell internally. - OptionAssigner(localJars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.repl.local.jars") + OptionAssigner(localJars, ALL_CLUSTER_MGRS, CLIENT, confKey = "spark.repl.local.jars") ) // In client mode, launch the application main class directly @@ -610,24 +610,24 @@ object SparkSubmit extends CommandLineUtils with Logging { (deployMode & opt.deployMode) != 0 && (clusterManager & opt.clusterManager) != 0) { if (opt.clOption != null) { childArgs += (opt.clOption, opt.value) } - if (opt.sysProp != null) { sysProps.put(opt.sysProp, opt.value) } + if (opt.confKey != null) { sparkConf.set(opt.confKey, opt.value) } } } // In case of shells, spark.ui.showConsoleProgress can be true by default or by user. if (isShell(args.primaryResource) && !sparkConf.contains(UI_SHOW_CONSOLE_PROGRESS)) { - sysProps(UI_SHOW_CONSOLE_PROGRESS.key) = "true" + sparkConf.set(UI_SHOW_CONSOLE_PROGRESS, true) } // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" // For python and R files, the primary resource is already distributed as a regular file if (!isYarnCluster && !args.isPython && !args.isR) { - var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) + var jars = sparkConf.getOption("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) if (isUserJar(args.primaryResource)) { jars = jars ++ Seq(args.primaryResource) } - sysProps.put("spark.jars", jars.mkString(",")) + sparkConf.set("spark.jars", jars.mkString(",")) } // In standalone cluster mode, use the REST client to submit the application (Spark 1.3+). @@ -653,12 +653,12 @@ object SparkSubmit extends CommandLineUtils with Logging { // Let YARN know it's a pyspark app, so it distributes needed libraries. if (clusterManager == YARN) { if (args.isPython) { - sysProps.put("spark.yarn.isPython", "true") + sparkConf.set("spark.yarn.isPython", "true") } } if (clusterManager == MESOS && UserGroupInformation.isSecurityEnabled) { - setRMPrincipal(sysProps) + setRMPrincipal(sparkConf) } // In yarn-cluster mode, use yarn.Client as a wrapper around the user class @@ -689,7 +689,7 @@ object SparkSubmit extends CommandLineUtils with Logging { // Second argument is main class childArgs += (args.primaryResource, "") if (args.pyFiles != null) { - sysProps("spark.submit.pyFiles") = args.pyFiles + sparkConf.set("spark.submit.pyFiles", args.pyFiles) } } else if (args.isR) { // Second argument is main class @@ -704,12 +704,12 @@ object SparkSubmit extends CommandLineUtils with Logging { // Load any properties specified through --conf and the default properties file for ((k, v) <- args.sparkProperties) { - sysProps.getOrElseUpdate(k, v) + sparkConf.setIfMissing(k, v) } // Ignore invalid spark.driver.host in cluster modes. if (deployMode == CLUSTER) { - sysProps -= "spark.driver.host" + sparkConf.remove("spark.driver.host") } // Resolve paths in certain spark properties @@ -721,15 +721,15 @@ object SparkSubmit extends CommandLineUtils with Logging { "spark.yarn.dist.jars") pathConfigs.foreach { config => // Replace old URIs with resolved URIs, if they exist - sysProps.get(config).foreach { oldValue => - sysProps(config) = Utils.resolveURIs(oldValue) + sparkConf.getOption(config).foreach { oldValue => + sparkConf.set(config, Utils.resolveURIs(oldValue)) } } // Resolve and format python file paths properly before adding them to the PYTHONPATH. // The resolving part is redundant in the case of --py-files, but necessary if the user // explicitly sets `spark.submit.pyFiles` in his/her default properties file. - sysProps.get("spark.submit.pyFiles").foreach { pyFiles => + sparkConf.getOption("spark.submit.pyFiles").foreach { pyFiles => val resolvedPyFiles = Utils.resolveURIs(pyFiles) val formattedPyFiles = if (!isYarnCluster && !isMesosCluster) { PythonRunner.formatPaths(resolvedPyFiles).mkString(",") @@ -739,22 +739,22 @@ object SparkSubmit extends CommandLineUtils with Logging { // locally. resolvedPyFiles } - sysProps("spark.submit.pyFiles") = formattedPyFiles + sparkConf.set("spark.submit.pyFiles", formattedPyFiles) } - (childArgs, childClasspath, sysProps, childMainClass) + (childArgs, childClasspath, sparkConf, childMainClass) } // [SPARK-20328]. HadoopRDD calls into a Hadoop library that fetches delegation tokens with // renewer set to the YARN ResourceManager. Since YARN isn't configured in Mesos mode, we // must trick it into thinking we're YARN. - private def setRMPrincipal(sysProps: HashMap[String, String]): Unit = { + private def setRMPrincipal(sparkConf: SparkConf): Unit = { val shortUserName = UserGroupInformation.getCurrentUser.getShortUserName val key = s"spark.hadoop.${YarnConfiguration.RM_PRINCIPAL}" // scalastyle:off println printStream.println(s"Setting ${key} to ${shortUserName}") // scalastyle:off println - sysProps.put(key, shortUserName) + sparkConf.set(key, shortUserName) } /** @@ -766,7 +766,7 @@ object SparkSubmit extends CommandLineUtils with Logging { private def runMain( childArgs: Seq[String], childClasspath: Seq[String], - sysProps: Map[String, String], + sparkConf: SparkConf, childMainClass: String, verbose: Boolean): Unit = { // scalastyle:off println @@ -774,14 +774,14 @@ object SparkSubmit extends CommandLineUtils with Logging { printStream.println(s"Main class:\n$childMainClass") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") // sysProps may contain sensitive information, so redact before printing - printStream.println(s"System properties:\n${Utils.redact(sysProps).mkString("\n")}") + printStream.println(s"Spark config:\n${Utils.redact(sparkConf.getAll.toMap).mkString("\n")}") printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") printStream.println("\n") } // scalastyle:on println val loader = - if (sysProps.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) { + if (sparkConf.get(DRIVER_USER_CLASS_PATH_FIRST)) { new ChildFirstURLClassLoader(new Array[URL](0), Thread.currentThread.getContextClassLoader) } else { @@ -794,10 +794,6 @@ object SparkSubmit extends CommandLineUtils with Logging { addJarToClasspath(jar, loader) } - for ((key, value) <- sysProps) { - System.setProperty(key, value) - } - var mainClass: Class[_] = null try { @@ -823,14 +819,14 @@ object SparkSubmit extends CommandLineUtils with Logging { System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } - // SPARK-4170 - if (classOf[scala.App].isAssignableFrom(mainClass)) { - printWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.") - } - - val mainMethod = mainClass.getMethod("main", new Array[String](0).getClass) - if (!Modifier.isStatic(mainMethod.getModifiers)) { - throw new IllegalStateException("The main method in the given main class must be static") + val app: SparkApplication = if (classOf[SparkApplication].isAssignableFrom(mainClass)) { + mainClass.newInstance().asInstanceOf[SparkApplication] + } else { + // SPARK-4170 + if (classOf[scala.App].isAssignableFrom(mainClass)) { + printWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.") + } + new JavaMainApplication(mainClass) } @tailrec @@ -844,7 +840,7 @@ object SparkSubmit extends CommandLineUtils with Logging { } try { - mainMethod.invoke(null, childArgs.toArray) + app.start(childArgs.toArray, sparkConf) } catch { case t: Throwable => findCause(t) match { @@ -1271,4 +1267,4 @@ private case class OptionAssigner( clusterManager: Int, deployMode: Int, clOption: String = null, - sysProp: String = null) + confKey: String = null) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index b52da4c0c8bc3..cfbf56fb8c369 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -176,10 +176,10 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, sysProps, _) = prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = prepareSubmitEnvironment(appArgs) appArgs.deployMode should be ("client") - sysProps("spark.submit.deployMode") should be ("client") + conf.get("spark.submit.deployMode") should be ("client") // Both cmd line and configuration are specified, cmdline option takes the priority val clArgs1 = Seq( @@ -190,10 +190,10 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs1 = new SparkSubmitArguments(clArgs1) - val (_, _, sysProps1, _) = prepareSubmitEnvironment(appArgs1) + val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1) appArgs1.deployMode should be ("cluster") - sysProps1("spark.submit.deployMode") should be ("cluster") + conf1.get("spark.submit.deployMode") should be ("cluster") // Neither cmdline nor configuration are specified, client mode is the default choice val clArgs2 = Seq( @@ -204,9 +204,9 @@ class SparkSubmitSuite val appArgs2 = new SparkSubmitArguments(clArgs2) appArgs2.deployMode should be (null) - val (_, _, sysProps2, _) = prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2) appArgs2.deployMode should be ("client") - sysProps2("spark.submit.deployMode") should be ("client") + conf2.get("spark.submit.deployMode") should be ("client") } test("handles YARN cluster mode") { @@ -227,7 +227,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") childArgsStr should include ("--class org.SomeClass") childArgsStr should include ("--arg arg1 --arg arg2") @@ -240,16 +240,16 @@ class SparkSubmitSuite classpath(2) should endWith ("two.jar") classpath(3) should endWith ("three.jar") - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.driver.memory") should be ("4g") - sysProps("spark.executor.cores") should be ("5") - sysProps("spark.yarn.queue") should be ("thequeue") - sysProps("spark.yarn.dist.jars") should include regex (".*one.jar,.*two.jar,.*three.jar") - sysProps("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") - sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") - sysProps("spark.app.name") should be ("beauty") - sysProps("spark.ui.enabled") should be ("false") - sysProps("SPARK_SUBMIT") should be ("true") + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.driver.memory") should be ("4g") + conf.get("spark.executor.cores") should be ("5") + conf.get("spark.yarn.queue") should be ("thequeue") + conf.get("spark.yarn.dist.jars") should include regex (".*one.jar,.*two.jar,.*three.jar") + conf.get("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") + conf.get("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") + conf.get("spark.app.name") should be ("beauty") + conf.get("spark.ui.enabled") should be ("false") + sys.props("SPARK_SUBMIT") should be ("true") } test("handles YARN client mode") { @@ -270,7 +270,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (4) @@ -278,17 +278,17 @@ class SparkSubmitSuite classpath(1) should endWith ("one.jar") classpath(2) should endWith ("two.jar") classpath(3) should endWith ("three.jar") - sysProps("spark.app.name") should be ("trill") - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.executor.cores") should be ("5") - sysProps("spark.yarn.queue") should be ("thequeue") - sysProps("spark.executor.instances") should be ("6") - sysProps("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") - sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") - sysProps("spark.yarn.dist.jars") should include + conf.get("spark.app.name") should be ("trill") + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.executor.cores") should be ("5") + conf.get("spark.yarn.queue") should be ("thequeue") + conf.get("spark.executor.instances") should be ("6") + conf.get("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") + conf.get("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") + conf.get("spark.yarn.dist.jars") should include regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar") - sysProps("SPARK_SUBMIT") should be ("true") - sysProps("spark.ui.enabled") should be ("false") + conf.get("spark.ui.enabled") should be ("false") + sys.props("SPARK_SUBMIT") should be ("true") } test("handles standalone cluster mode") { @@ -316,7 +316,7 @@ class SparkSubmitSuite "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) appArgs.useRest = useRest - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") if (useRest) { childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2") @@ -327,17 +327,18 @@ class SparkSubmitSuite mainClass should be ("org.apache.spark.deploy.Client") } classpath should have size 0 - sysProps should have size 9 - sysProps.keys should contain ("SPARK_SUBMIT") - sysProps.keys should contain ("spark.master") - sysProps.keys should contain ("spark.app.name") - sysProps.keys should contain ("spark.jars") - sysProps.keys should contain ("spark.driver.memory") - sysProps.keys should contain ("spark.driver.cores") - sysProps.keys should contain ("spark.driver.supervise") - sysProps.keys should contain ("spark.ui.enabled") - sysProps.keys should contain ("spark.submit.deployMode") - sysProps("spark.ui.enabled") should be ("false") + sys.props("SPARK_SUBMIT") should be ("true") + + val confMap = conf.getAll.toMap + confMap.keys should contain ("spark.master") + confMap.keys should contain ("spark.app.name") + confMap.keys should contain ("spark.jars") + confMap.keys should contain ("spark.driver.memory") + confMap.keys should contain ("spark.driver.cores") + confMap.keys should contain ("spark.driver.supervise") + confMap.keys should contain ("spark.ui.enabled") + confMap.keys should contain ("spark.submit.deployMode") + conf.get("spark.ui.enabled") should be ("false") } test("handles standalone client mode") { @@ -352,14 +353,14 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) classpath(0) should endWith ("thejar.jar") - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.cores.max") should be ("5") - sysProps("spark.ui.enabled") should be ("false") + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.cores.max") should be ("5") + conf.get("spark.ui.enabled") should be ("false") } test("handles mesos client mode") { @@ -374,14 +375,14 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) classpath(0) should endWith ("thejar.jar") - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.cores.max") should be ("5") - sysProps("spark.ui.enabled") should be ("false") + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.cores.max") should be ("5") + conf.get("spark.ui.enabled") should be ("false") } test("handles confs with flag equivalents") { @@ -394,23 +395,26 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.master") should be ("yarn") - sysProps("spark.submit.deployMode") should be ("cluster") + val (_, _, conf, mainClass) = prepareSubmitEnvironment(appArgs) + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.master") should be ("yarn") + conf.get("spark.submit.deployMode") should be ("cluster") mainClass should be ("org.apache.spark.deploy.yarn.Client") } test("SPARK-21568 ConsoleProgressBar should be enabled only in shells") { + // Unset from system properties since this config is defined in the root pom's test config. + sys.props -= UI_SHOW_CONSOLE_PROGRESS.key + val clArgs1 = Seq("--class", "org.apache.spark.repl.Main", "spark-shell") val appArgs1 = new SparkSubmitArguments(clArgs1) - val (_, _, sysProps1, _) = prepareSubmitEnvironment(appArgs1) - sysProps1(UI_SHOW_CONSOLE_PROGRESS.key) should be ("true") + val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1) + conf1.get(UI_SHOW_CONSOLE_PROGRESS) should be (true) val clArgs2 = Seq("--class", "org.SomeClass", "thejar.jar") val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, sysProps2, _) = prepareSubmitEnvironment(appArgs2) - sysProps2.keys should not contain UI_SHOW_CONSOLE_PROGRESS.key + val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2) + assert(!conf2.contains(UI_SHOW_CONSOLE_PROGRESS)) } test("launch simple application with spark-submit") { @@ -585,11 +589,11 @@ class SparkSubmitSuite "--files", files, "thejar.jar") val appArgs = new SparkSubmitArguments(clArgs) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) appArgs.jars should be (Utils.resolveURIs(jars)) appArgs.files should be (Utils.resolveURIs(files)) - sysProps("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar")) - sysProps("spark.files") should be (Utils.resolveURIs(files)) + conf.get("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar")) + conf.get("spark.files") should be (Utils.resolveURIs(files)) // Test files and archives (Yarn) val clArgs2 = Seq( @@ -600,11 +604,11 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val sysProps2 = SparkSubmit.prepareSubmitEnvironment(appArgs2)._3 + val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) appArgs2.files should be (Utils.resolveURIs(files)) appArgs2.archives should be (Utils.resolveURIs(archives)) - sysProps2("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) - sysProps2("spark.yarn.dist.archives") should be (Utils.resolveURIs(archives)) + conf2.get("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) + conf2.get("spark.yarn.dist.archives") should be (Utils.resolveURIs(archives)) // Test python files val clArgs3 = Seq( @@ -615,12 +619,12 @@ class SparkSubmitSuite "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val sysProps3 = SparkSubmit.prepareSubmitEnvironment(appArgs3)._3 + val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3) appArgs3.pyFiles should be (Utils.resolveURIs(pyFiles)) - sysProps3("spark.submit.pyFiles") should be ( + conf3.get("spark.submit.pyFiles") should be ( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) - sysProps3(PYSPARK_DRIVER_PYTHON.key) should be ("python3.4") - sysProps3(PYSPARK_PYTHON.key) should be ("python3.5") + conf3.get(PYSPARK_DRIVER_PYTHON.key) should be ("python3.4") + conf3.get(PYSPARK_PYTHON.key) should be ("python3.5") } test("resolves config paths correctly") { @@ -644,9 +648,9 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 - sysProps("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) - sysProps("spark.files") should be(Utils.resolveURIs(files)) + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + conf.get("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) + conf.get("spark.files") should be(Utils.resolveURIs(files)) // Test files and archives (Yarn) val f2 = File.createTempFile("test-submit-files-archives", "", tmpDir) @@ -661,9 +665,9 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val sysProps2 = SparkSubmit.prepareSubmitEnvironment(appArgs2)._3 - sysProps2("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) - sysProps2("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) + val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) + conf2.get("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) + conf2.get("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) // Test python files val f3 = File.createTempFile("test-submit-python-files", "", tmpDir) @@ -676,8 +680,8 @@ class SparkSubmitSuite "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val sysProps3 = SparkSubmit.prepareSubmitEnvironment(appArgs3)._3 - sysProps3("spark.submit.pyFiles") should be( + val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3) + conf3.get("spark.submit.pyFiles") should be( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) // Test remote python files @@ -693,11 +697,9 @@ class SparkSubmitSuite "hdfs:///tmp/mister.py" ) val appArgs4 = new SparkSubmitArguments(clArgs4) - val sysProps4 = SparkSubmit.prepareSubmitEnvironment(appArgs4)._3 + val (_, _, conf4, _) = SparkSubmit.prepareSubmitEnvironment(appArgs4) // Should not format python path for yarn cluster mode - sysProps4("spark.submit.pyFiles") should be( - Utils.resolveURIs(remotePyFiles) - ) + conf4.get("spark.submit.pyFiles") should be(Utils.resolveURIs(remotePyFiles)) } test("user classpath first in driver") { @@ -771,14 +773,14 @@ class SparkSubmitSuite jar2.toString) val appArgs = new SparkSubmitArguments(args) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 - sysProps("spark.yarn.dist.jars").split(",").toSet should be + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + conf.get("spark.yarn.dist.jars").split(",").toSet should be (Set(jar1.toURI.toString, jar2.toURI.toString)) - sysProps("spark.yarn.dist.files").split(",").toSet should be + conf.get("spark.yarn.dist.files").split(",").toSet should be (Set(file1.toURI.toString, file2.toURI.toString)) - sysProps("spark.yarn.dist.pyFiles").split(",").toSet should be + conf.get("spark.yarn.dist.pyFiles").split(",").toSet should be (Set(pyFile1.getAbsolutePath, pyFile2.getAbsolutePath)) - sysProps("spark.yarn.dist.archives").split(",").toSet should be + conf.get("spark.yarn.dist.archives").split(",").toSet should be (Set(archive1.toURI.toString, archive2.toURI.toString)) } @@ -897,18 +899,18 @@ class SparkSubmitSuite ) val appArgs = new SparkSubmitArguments(args) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf))._3 + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) // All the resources should still be remote paths, so that YARN client will not upload again. - sysProps("spark.yarn.dist.jars") should be (tmpJarPath) - sysProps("spark.yarn.dist.files") should be (s"s3a://${file.getAbsolutePath}") - sysProps("spark.yarn.dist.pyFiles") should be (s"s3a://${pyFile.getAbsolutePath}") + conf.get("spark.yarn.dist.jars") should be (tmpJarPath) + conf.get("spark.yarn.dist.files") should be (s"s3a://${file.getAbsolutePath}") + conf.get("spark.yarn.dist.pyFiles") should be (s"s3a://${pyFile.getAbsolutePath}") // Local repl jars should be a local path. - sysProps("spark.repl.local.jars") should (startWith("file:")) + conf.get("spark.repl.local.jars") should (startWith("file:")) // local py files should not be a URI format. - sysProps("spark.submit.pyFiles") should (startWith("/")) + conf.get("spark.submit.pyFiles") should (startWith("/")) } test("download remote resource if it is not supported by yarn service") { @@ -955,9 +957,9 @@ class SparkSubmitSuite ) val appArgs = new SparkSubmitArguments(args) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf))._3 + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) - val jars = sysProps("spark.yarn.dist.jars").split(",").toSet + val jars = conf.get("spark.yarn.dist.jars").split(",").toSet // The URI of remote S3 resource should still be remote. assert(jars.contains(tmpS3JarPath)) @@ -996,6 +998,21 @@ class SparkSubmitSuite conf.set("fs.s3a.impl", classOf[TestFileSystem].getCanonicalName) conf.set("fs.s3a.impl.disable.cache", "true") } + + test("start SparkApplication without modifying system properties") { + val args = Array( + "--class", classOf[TestSparkApplication].getName(), + "--master", "local", + "--conf", "spark.test.hello=world", + "spark-internal", + "hello") + + val exception = intercept[SparkException] { + SparkSubmit.main(args) + } + + assert(exception.getMessage() === "hello") + } } object SparkSubmitSuite extends SparkFunSuite with TimeLimits { @@ -1115,3 +1132,17 @@ class TestFileSystem extends org.apache.hadoop.fs.LocalFileSystem { override def open(path: Path): FSDataInputStream = super.open(local(path)) } + +class TestSparkApplication extends SparkApplication with Matchers { + + override def start(args: Array[String], conf: SparkConf): Unit = { + assert(args.size === 1) + assert(args(0) === "hello") + assert(conf.get("spark.test.hello") === "world") + assert(sys.props.get("spark.test.hello") === None) + + // This is how the test verifies the application was actually run. + throw new SparkException(args(0)) + } + +} diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 70887dc5dd97a..490baf040491f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -445,9 +445,9 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { "--class", mainClass, mainJar) ++ appArgs val args = new SparkSubmitArguments(commandLineArgs) - val (_, _, sparkProperties, _) = SparkSubmit.prepareSubmitEnvironment(args) + val (_, _, sparkConf, _) = SparkSubmit.prepareSubmitEnvironment(args) new RestSubmissionClient("spark://host:port").constructSubmitRequest( - mainJar, mainClass, appArgs, sparkProperties.toMap, Map.empty) + mainJar, mainClass, appArgs, sparkConf.getAll.toMap, Map.empty) } /** Return the response as a submit response, or fail with error otherwise. */ From a83d8d5adcb4e0061e43105767242ba9770dda96 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 26 Oct 2017 20:54:36 +0900 Subject: [PATCH 768/779] [SPARK-17902][R] Revive stringsAsFactors option for collect() in SparkR ## What changes were proposed in this pull request? This PR proposes to revive `stringsAsFactors` option in collect API, which was mistakenly removed in https://github.com/apache/spark/commit/71a138cd0e0a14e8426f97877e3b52a562bbd02c. Simply, it casts `charactor` to `factor` if it meets the condition, `stringsAsFactors && is.character(vec)` in primitive type conversion. ## How was this patch tested? Unit test in `R/pkg/tests/fulltests/test_sparkSQL.R`. Author: hyukjinkwon Closes #19551 from HyukjinKwon/SPARK-17902. --- R/pkg/R/DataFrame.R | 3 +++ R/pkg/tests/fulltests/test_sparkSQL.R | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 176bb3b8a8d0c..aaa3349d57506 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1191,6 +1191,9 @@ setMethod("collect", vec <- do.call(c, col) stopifnot(class(vec) != "list") class(vec) <- PRIMITIVE_TYPES[[colType]] + if (is.character(vec) && stringsAsFactors) { + vec <- as.factor(vec) + } df[[colIndex]] <- vec } else { df[[colIndex]] <- col diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 4382ef2ed4525..0c8118a7c73f3 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -499,6 +499,12 @@ test_that("create DataFrame with different data types", { expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) }) +test_that("SPARK-17902: collect() with stringsAsFactors enabled", { + df <- suppressWarnings(collect(createDataFrame(iris), stringsAsFactors = TRUE)) + expect_equal(class(iris$Species), class(df$Species)) + expect_equal(iris$Species, df$Species) +}) + test_that("SPARK-17811: can create DataFrame containing NA as date and time", { df <- data.frame( id = 1:2, From 0e9a750a8d389b3a17834584d31c204c77c6970d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 26 Oct 2017 11:05:16 -0500 Subject: [PATCH 769/779] [SPARK-20643][CORE] Add listener implementation to collect app state. The initial listener code is based on the existing JobProgressListener (and others), and tries to mimic their behavior as much as possible. The change also includes some minor code movement so that some types and methods from the initial history server code code can be reused. The code introduces a few mutable versions of public API types, used internally, to make it easier to update information without ugly copy methods, and also to make certain updates cheaper. Note the code here is not 100% correct. This is meant as a building ground for the UI integration in the next milestones. As different parts of the UI are ported, fixes will be made to the different parts of this code to account for the needed behavior. I also added annotations to API types so that Jackson is able to correctly deserialize options, sequences and maps that store primitive types. Author: Marcelo Vanzin Closes #19383 from vanzin/SPARK-20643. --- .../apache/spark/util/kvstore/KVTypeInfo.java | 2 + .../apache/spark/util/kvstore/LevelDB.java | 2 +- .../spark/status/api/v1/StageStatus.java | 3 +- .../deploy/history/FsHistoryProvider.scala | 37 +- .../apache/spark/deploy/history/config.scala | 6 - .../spark/status/AppStatusListener.scala | 531 ++++++++++++++ .../org/apache/spark/status/KVUtils.scala | 73 ++ .../org/apache/spark/status/LiveEntity.scala | 526 +++++++++++++ .../status/api/v1/AllStagesResource.scala | 4 +- .../org/apache/spark/status/api/v1/api.scala | 11 +- .../org/apache/spark/status/storeTypes.scala | 98 +++ .../history/FsHistoryProviderSuite.scala | 2 +- .../spark/status/AppStatusListenerSuite.scala | 690 ++++++++++++++++++ project/MimaExcludes.scala | 2 + 14 files changed, 1942 insertions(+), 45 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/status/AppStatusListener.scala create mode 100644 core/src/main/scala/org/apache/spark/status/KVUtils.scala create mode 100644 core/src/main/scala/org/apache/spark/status/LiveEntity.scala create mode 100644 core/src/main/scala/org/apache/spark/status/storeTypes.scala create mode 100644 core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java index a2b077e4531ee..870b484f99068 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java @@ -46,6 +46,7 @@ public KVTypeInfo(Class type) throws Exception { KVIndex idx = f.getAnnotation(KVIndex.class); if (idx != null) { checkIndex(idx, indices); + f.setAccessible(true); indices.put(idx.value(), idx); f.setAccessible(true); accessors.put(idx.value(), new FieldAccessor(f)); @@ -58,6 +59,7 @@ public KVTypeInfo(Class type) throws Exception { checkIndex(idx, indices); Preconditions.checkArgument(m.getParameterTypes().length == 0, "Annotated method %s::%s should not have any parameters.", type.getName(), m.getName()); + m.setAccessible(true); indices.put(idx.value(), idx); m.setAccessible(true); accessors.put(idx.value(), new MethodAccessor(m)); diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java index ff48b155fab31..4f9e10ca20066 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java @@ -76,7 +76,7 @@ public LevelDB(File path, KVStoreSerializer serializer) throws Exception { this.types = new ConcurrentHashMap<>(); Options options = new Options(); - options.createIfMissing(!path.exists()); + options.createIfMissing(true); this._db = new AtomicReference<>(JniDBFactory.factory.open(path, options)); byte[] versionData = db().get(STORE_VERSION_KEY); diff --git a/core/src/main/java/org/apache/spark/status/api/v1/StageStatus.java b/core/src/main/java/org/apache/spark/status/api/v1/StageStatus.java index 9dbb565aab707..40b5f627369d5 100644 --- a/core/src/main/java/org/apache/spark/status/api/v1/StageStatus.java +++ b/core/src/main/java/org/apache/spark/status/api/v1/StageStatus.java @@ -23,7 +23,8 @@ public enum StageStatus { ACTIVE, COMPLETE, FAILED, - PENDING; + PENDING, + SKIPPED; public static StageStatus fromString(String str) { return EnumUtil.parseIgnoreCase(StageStatus.class, str); diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 3889dd097ee59..cf97597b484d8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -42,6 +42,7 @@ import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ReplayListenerBus._ +import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1 import org.apache.spark.ui.SparkUI import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} @@ -129,29 +130,15 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Visible for testing. private[history] val listing: KVStore = storePath.map { path => val dbPath = new File(path, "listing.ldb") - - def openDB(): LevelDB = new LevelDB(dbPath, new KVStoreScalaSerializer()) + val metadata = new FsHistoryProviderMetadata(CURRENT_LISTING_VERSION, logDir.toString()) try { - val db = openDB() - val meta = db.getMetadata(classOf[KVStoreMetadata]) - - if (meta == null) { - db.setMetadata(new KVStoreMetadata(CURRENT_LISTING_VERSION, logDir)) - db - } else if (meta.version != CURRENT_LISTING_VERSION || !logDir.equals(meta.logDir)) { - logInfo("Detected mismatched config in existing DB, deleting...") - db.close() - Utils.deleteRecursively(dbPath) - openDB() - } else { - db - } + open(new File(path, "listing.ldb"), metadata) } catch { - case _: UnsupportedStoreVersionException => + case _: UnsupportedStoreVersionException | _: MetadataMismatchException => logInfo("Detected incompatible DB versions, deleting...") Utils.deleteRecursively(dbPath) - openDB() + open(new File(path, "listing.ldb"), metadata) } }.getOrElse(new InMemoryStore()) @@ -720,19 +707,7 @@ private[history] object FsHistoryProvider { private[history] val CURRENT_LISTING_VERSION = 1L } -/** - * A KVStoreSerializer that provides Scala types serialization too, and uses the same options as - * the API serializer. - */ -private class KVStoreScalaSerializer extends KVStoreSerializer { - - mapper.registerModule(DefaultScalaModule) - mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL) - mapper.setDateFormat(v1.JacksonMessageWriter.makeISODateFormat) - -} - -private[history] case class KVStoreMetadata( +private[history] case class FsHistoryProviderMetadata( version: Long, logDir: String) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/config.scala b/core/src/main/scala/org/apache/spark/deploy/history/config.scala index fb9e997def0dd..52dedc1a2ed41 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/config.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/config.scala @@ -19,16 +19,10 @@ package org.apache.spark.deploy.history import java.util.concurrent.TimeUnit -import scala.annotation.meta.getter - import org.apache.spark.internal.config.ConfigBuilder -import org.apache.spark.util.kvstore.KVIndex private[spark] object config { - /** Use this to annotate constructor params to be used as KVStore indices. */ - type KVIndexParam = KVIndex @getter - val DEFAULT_LOG_DIR = "file:/tmp/spark-events" val EVENT_LOG_DIR = ConfigBuilder("spark.history.fs.logDirectory") diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala new file mode 100644 index 0000000000000..f120685c941df --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -0,0 +1,531 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.util.Date + +import scala.collection.mutable.HashMap + +import org.apache.spark._ +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler._ +import org.apache.spark.status.api.v1 +import org.apache.spark.storage._ +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.kvstore.KVStore + +/** + * A Spark listener that writes application information to a data store. The types written to the + * store are defined in the `storeTypes.scala` file and are based on the public REST API. + */ +private class AppStatusListener(kvstore: KVStore) extends SparkListener with Logging { + + private var sparkVersion = SPARK_VERSION + private var appInfo: v1.ApplicationInfo = null + private var coresPerTask: Int = 1 + + // Keep track of live entities, so that task metrics can be efficiently updated (without + // causing too many writes to the underlying store, and other expensive operations). + private val liveStages = new HashMap[(Int, Int), LiveStage]() + private val liveJobs = new HashMap[Int, LiveJob]() + private val liveExecutors = new HashMap[String, LiveExecutor]() + private val liveTasks = new HashMap[Long, LiveTask]() + private val liveRDDs = new HashMap[Int, LiveRDD]() + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerLogStart(version) => sparkVersion = version + case _ => + } + + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { + assert(event.appId.isDefined, "Application without IDs are not supported.") + + val attempt = new v1.ApplicationAttemptInfo( + event.appAttemptId, + new Date(event.time), + new Date(-1), + new Date(event.time), + -1L, + event.sparkUser, + false, + sparkVersion) + + appInfo = new v1.ApplicationInfo( + event.appId.get, + event.appName, + None, + None, + None, + None, + Seq(attempt)) + + kvstore.write(new ApplicationInfoWrapper(appInfo)) + } + + override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = { + val old = appInfo.attempts.head + val attempt = new v1.ApplicationAttemptInfo( + old.attemptId, + old.startTime, + new Date(event.time), + new Date(event.time), + event.time - old.startTime.getTime(), + old.sparkUser, + true, + old.appSparkVersion) + + appInfo = new v1.ApplicationInfo( + appInfo.id, + appInfo.name, + None, + None, + None, + None, + Seq(attempt)) + kvstore.write(new ApplicationInfoWrapper(appInfo)) + } + + override def onExecutorAdded(event: SparkListenerExecutorAdded): Unit = { + // This needs to be an update in case an executor re-registers after the driver has + // marked it as "dead". + val exec = getOrCreateExecutor(event.executorId) + exec.host = event.executorInfo.executorHost + exec.isActive = true + exec.totalCores = event.executorInfo.totalCores + exec.maxTasks = event.executorInfo.totalCores / coresPerTask + exec.executorLogs = event.executorInfo.logUrlMap + update(exec) + } + + override def onExecutorRemoved(event: SparkListenerExecutorRemoved): Unit = { + liveExecutors.remove(event.executorId).foreach { exec => + exec.isActive = false + update(exec) + } + } + + override def onExecutorBlacklisted(event: SparkListenerExecutorBlacklisted): Unit = { + updateBlackListStatus(event.executorId, true) + } + + override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = { + updateBlackListStatus(event.executorId, false) + } + + override def onNodeBlacklisted(event: SparkListenerNodeBlacklisted): Unit = { + updateNodeBlackList(event.hostId, true) + } + + override def onNodeUnblacklisted(event: SparkListenerNodeUnblacklisted): Unit = { + updateNodeBlackList(event.hostId, false) + } + + private def updateBlackListStatus(execId: String, blacklisted: Boolean): Unit = { + liveExecutors.get(execId).foreach { exec => + exec.isBlacklisted = blacklisted + update(exec) + } + } + + private def updateNodeBlackList(host: String, blacklisted: Boolean): Unit = { + // Implicitly (un)blacklist every executor associated with the node. + liveExecutors.values.foreach { exec => + if (exec.hostname == host) { + exec.isBlacklisted = blacklisted + update(exec) + } + } + } + + override def onJobStart(event: SparkListenerJobStart): Unit = { + // Compute (a potential over-estimate of) the number of tasks that will be run by this job. + // This may be an over-estimate because the job start event references all of the result + // stages' transitive stage dependencies, but some of these stages might be skipped if their + // output is available from earlier runs. + // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. + val numTasks = { + val missingStages = event.stageInfos.filter(_.completionTime.isEmpty) + missingStages.map(_.numTasks).sum + } + + val lastStageInfo = event.stageInfos.lastOption + val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") + + val jobGroup = Option(event.properties) + .flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) } + + val job = new LiveJob( + event.jobId, + lastStageName, + Some(new Date(event.time)), + event.stageIds, + jobGroup, + numTasks) + liveJobs.put(event.jobId, job) + update(job) + + event.stageInfos.foreach { stageInfo => + // A new job submission may re-use an existing stage, so this code needs to do an update + // instead of just a write. + val stage = getOrCreateStage(stageInfo) + stage.jobs :+= job + stage.jobIds += event.jobId + update(stage) + } + } + + override def onJobEnd(event: SparkListenerJobEnd): Unit = { + liveJobs.remove(event.jobId).foreach { job => + job.status = event.jobResult match { + case JobSucceeded => JobExecutionStatus.SUCCEEDED + case JobFailed(_) => JobExecutionStatus.FAILED + } + + job.completionTime = Some(new Date(event.time)) + update(job) + } + } + + override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = { + val stage = getOrCreateStage(event.stageInfo) + stage.status = v1.StageStatus.ACTIVE + stage.schedulingPool = Option(event.properties).flatMap { p => + Option(p.getProperty("spark.scheduler.pool")) + }.getOrElse(SparkUI.DEFAULT_POOL_NAME) + + // Look at all active jobs to find the ones that mention this stage. + stage.jobs = liveJobs.values + .filter(_.stageIds.contains(event.stageInfo.stageId)) + .toSeq + stage.jobIds = stage.jobs.map(_.jobId).toSet + + stage.jobs.foreach { job => + job.completedStages = job.completedStages - event.stageInfo.stageId + job.activeStages += 1 + update(job) + } + + event.stageInfo.rddInfos.foreach { info => + if (info.storageLevel.isValid) { + update(liveRDDs.getOrElseUpdate(info.id, new LiveRDD(info))) + } + } + + update(stage) + } + + override def onTaskStart(event: SparkListenerTaskStart): Unit = { + val task = new LiveTask(event.taskInfo, event.stageId, event.stageAttemptId) + liveTasks.put(event.taskInfo.taskId, task) + update(task) + + liveStages.get((event.stageId, event.stageAttemptId)).foreach { stage => + stage.activeTasks += 1 + stage.firstLaunchTime = math.min(stage.firstLaunchTime, event.taskInfo.launchTime) + update(stage) + + stage.jobs.foreach { job => + job.activeTasks += 1 + update(job) + } + } + + liveExecutors.get(event.taskInfo.executorId).foreach { exec => + exec.activeTasks += 1 + exec.totalTasks += 1 + update(exec) + } + } + + override def onTaskGettingResult(event: SparkListenerTaskGettingResult): Unit = { + // Call update on the task so that the "getting result" time is written to the store; the + // value is part of the mutable TaskInfo state that the live entity already references. + liveTasks.get(event.taskInfo.taskId).foreach { task => + update(task) + } + } + + override def onTaskEnd(event: SparkListenerTaskEnd): Unit = { + // TODO: can this really happen? + if (event.taskInfo == null) { + return + } + + val metricsDelta = liveTasks.remove(event.taskInfo.taskId).map { task => + val errorMessage = event.reason match { + case Success => + None + case k: TaskKilled => + Some(k.reason) + case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates + Some(e.toErrorString) + case e: TaskFailedReason => // All other failure cases + Some(e.toErrorString) + case other => + logInfo(s"Unhandled task end reason: $other") + None + } + task.errorMessage = errorMessage + val delta = task.updateMetrics(event.taskMetrics) + update(task) + delta + }.orNull + + val (completedDelta, failedDelta) = event.reason match { + case Success => + (1, 0) + case _ => + (0, 1) + } + + liveStages.get((event.stageId, event.stageAttemptId)).foreach { stage => + if (metricsDelta != null) { + stage.metrics.update(metricsDelta) + } + stage.activeTasks -= 1 + stage.completedTasks += completedDelta + stage.failedTasks += failedDelta + update(stage) + + stage.jobs.foreach { job => + job.activeTasks -= 1 + job.completedTasks += completedDelta + job.failedTasks += failedDelta + update(job) + } + + val esummary = stage.executorSummary(event.taskInfo.executorId) + esummary.taskTime += event.taskInfo.duration + esummary.succeededTasks += completedDelta + esummary.failedTasks += failedDelta + if (metricsDelta != null) { + esummary.metrics.update(metricsDelta) + } + update(esummary) + } + + liveExecutors.get(event.taskInfo.executorId).foreach { exec => + if (event.taskMetrics != null) { + val readMetrics = event.taskMetrics.shuffleReadMetrics + exec.totalGcTime += event.taskMetrics.jvmGCTime + exec.totalInputBytes += event.taskMetrics.inputMetrics.bytesRead + exec.totalShuffleRead += readMetrics.localBytesRead + readMetrics.remoteBytesRead + exec.totalShuffleWrite += event.taskMetrics.shuffleWriteMetrics.bytesWritten + } + + exec.activeTasks -= 1 + exec.completedTasks += completedDelta + exec.failedTasks += failedDelta + exec.totalDuration += event.taskInfo.duration + update(exec) + } + } + + override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { + liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptId)).foreach { stage => + stage.info = event.stageInfo + + // Because of SPARK-20205, old event logs may contain valid stages without a submission time + // in their start event. In those cases, we can only detect whether a stage was skipped by + // waiting until the completion event, at which point the field would have been set. + stage.status = event.stageInfo.failureReason match { + case Some(_) => v1.StageStatus.FAILED + case _ if event.stageInfo.submissionTime.isDefined => v1.StageStatus.COMPLETE + case _ => v1.StageStatus.SKIPPED + } + update(stage) + + stage.jobs.foreach { job => + stage.status match { + case v1.StageStatus.COMPLETE => + job.completedStages += event.stageInfo.stageId + case v1.StageStatus.SKIPPED => + job.skippedStages += event.stageInfo.stageId + job.skippedTasks += event.stageInfo.numTasks + case _ => + job.failedStages += 1 + } + job.activeStages -= 1 + update(job) + } + + stage.executorSummaries.values.foreach(update) + update(stage) + } + } + + override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = { + // This needs to set fields that are already set by onExecutorAdded because the driver is + // considered an "executor" in the UI, but does not have a SparkListenerExecutorAdded event. + val exec = getOrCreateExecutor(event.blockManagerId.executorId) + exec.hostPort = event.blockManagerId.hostPort + event.maxOnHeapMem.foreach { _ => + exec.totalOnHeap = event.maxOnHeapMem.get + exec.totalOffHeap = event.maxOffHeapMem.get + } + exec.isActive = true + exec.maxMemory = event.maxMem + update(exec) + } + + override def onBlockManagerRemoved(event: SparkListenerBlockManagerRemoved): Unit = { + // Nothing to do here. Covered by onExecutorRemoved. + } + + override def onUnpersistRDD(event: SparkListenerUnpersistRDD): Unit = { + liveRDDs.remove(event.rddId) + kvstore.delete(classOf[RDDStorageInfoWrapper], event.rddId) + } + + override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { + event.accumUpdates.foreach { case (taskId, sid, sAttempt, accumUpdates) => + liveTasks.get(taskId).foreach { task => + val metrics = TaskMetrics.fromAccumulatorInfos(accumUpdates) + val delta = task.updateMetrics(metrics) + update(task) + + liveStages.get((sid, sAttempt)).foreach { stage => + stage.metrics.update(delta) + update(stage) + + val esummary = stage.executorSummary(event.execId) + esummary.metrics.update(delta) + update(esummary) + } + } + } + } + + override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = { + event.blockUpdatedInfo.blockId match { + case block: RDDBlockId => updateRDDBlock(event, block) + case _ => // TODO: API only covers RDD storage. + } + } + + private def updateRDDBlock(event: SparkListenerBlockUpdated, block: RDDBlockId): Unit = { + val executorId = event.blockUpdatedInfo.blockManagerId.executorId + + // Whether values are being added to or removed from the existing accounting. + val storageLevel = event.blockUpdatedInfo.storageLevel + val diskDelta = event.blockUpdatedInfo.diskSize * (if (storageLevel.useDisk) 1 else -1) + val memoryDelta = event.blockUpdatedInfo.memSize * (if (storageLevel.useMemory) 1 else -1) + + // Function to apply a delta to a value, but ensure that it doesn't go negative. + def newValue(old: Long, delta: Long): Long = math.max(0, old + delta) + + val updatedStorageLevel = if (storageLevel.isValid) { + Some(storageLevel.description) + } else { + None + } + + // We need information about the executor to update some memory accounting values in the + // RDD info, so read that beforehand. + val maybeExec = liveExecutors.get(executorId) + var rddBlocksDelta = 0 + + // Update the block entry in the RDD info, keeping track of the deltas above so that we + // can update the executor information too. + liveRDDs.get(block.rddId).foreach { rdd => + val partition = rdd.partition(block.name) + + val executors = if (updatedStorageLevel.isDefined) { + if (!partition.executors.contains(executorId)) { + rddBlocksDelta = 1 + } + partition.executors + executorId + } else { + rddBlocksDelta = -1 + partition.executors - executorId + } + + // Only update the partition if it's still stored in some executor, otherwise get rid of it. + if (executors.nonEmpty) { + if (updatedStorageLevel.isDefined) { + partition.storageLevel = updatedStorageLevel.get + } + partition.memoryUsed = newValue(partition.memoryUsed, memoryDelta) + partition.diskUsed = newValue(partition.diskUsed, diskDelta) + partition.executors = executors + } else { + rdd.removePartition(block.name) + } + + maybeExec.foreach { exec => + if (exec.rddBlocks + rddBlocksDelta > 0) { + val dist = rdd.distribution(exec) + dist.memoryRemaining = newValue(dist.memoryRemaining, -memoryDelta) + dist.memoryUsed = newValue(dist.memoryUsed, memoryDelta) + dist.diskUsed = newValue(dist.diskUsed, diskDelta) + + if (exec.hasMemoryInfo) { + if (storageLevel.useOffHeap) { + dist.offHeapUsed = newValue(dist.offHeapUsed, memoryDelta) + dist.offHeapRemaining = newValue(dist.offHeapRemaining, -memoryDelta) + } else { + dist.onHeapUsed = newValue(dist.onHeapUsed, memoryDelta) + dist.onHeapRemaining = newValue(dist.onHeapRemaining, -memoryDelta) + } + } + } else { + rdd.removeDistribution(exec) + } + } + + if (updatedStorageLevel.isDefined) { + rdd.storageLevel = updatedStorageLevel.get + } + rdd.memoryUsed = newValue(rdd.memoryUsed, memoryDelta) + rdd.diskUsed = newValue(rdd.diskUsed, diskDelta) + update(rdd) + } + + maybeExec.foreach { exec => + if (exec.hasMemoryInfo) { + if (storageLevel.useOffHeap) { + exec.usedOffHeap = newValue(exec.usedOffHeap, memoryDelta) + } else { + exec.usedOnHeap = newValue(exec.usedOnHeap, memoryDelta) + } + } + exec.memoryUsed = newValue(exec.memoryUsed, memoryDelta) + exec.diskUsed = newValue(exec.diskUsed, diskDelta) + exec.rddBlocks += rddBlocksDelta + if (exec.hasMemoryInfo || rddBlocksDelta != 0) { + update(exec) + } + } + } + + private def getOrCreateExecutor(executorId: String): LiveExecutor = { + liveExecutors.getOrElseUpdate(executorId, new LiveExecutor(executorId)) + } + + private def getOrCreateStage(info: StageInfo): LiveStage = { + val stage = liveStages.getOrElseUpdate((info.stageId, info.attemptId), new LiveStage()) + stage.info = info + stage + } + + private def update(entity: LiveEntity): Unit = { + entity.write(kvstore) + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/KVUtils.scala b/core/src/main/scala/org/apache/spark/status/KVUtils.scala new file mode 100644 index 0000000000000..4638511944c61 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/KVUtils.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.io.File + +import scala.annotation.meta.getter +import scala.language.implicitConversions +import scala.reflect.{classTag, ClassTag} + +import com.fasterxml.jackson.annotation.JsonInclude +import com.fasterxml.jackson.module.scala.DefaultScalaModule + +import org.apache.spark.internal.Logging +import org.apache.spark.util.kvstore._ + +private[spark] object KVUtils extends Logging { + + /** Use this to annotate constructor params to be used as KVStore indices. */ + type KVIndexParam = KVIndex @getter + + /** + * A KVStoreSerializer that provides Scala types serialization too, and uses the same options as + * the API serializer. + */ + private[spark] class KVStoreScalaSerializer extends KVStoreSerializer { + + mapper.registerModule(DefaultScalaModule) + mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL) + + } + + /** + * Open or create a LevelDB store. + * + * @param path Location of the store. + * @param metadata Metadata value to compare to the data in the store. If the store does not + * contain any metadata (e.g. it's a new store), this value is written as + * the store's metadata. + */ + def open[M: ClassTag](path: File, metadata: M): LevelDB = { + require(metadata != null, "Metadata is required.") + + val db = new LevelDB(path, new KVStoreScalaSerializer()) + val dbMeta = db.getMetadata(classTag[M].runtimeClass) + if (dbMeta == null) { + db.setMetadata(metadata) + } else if (dbMeta != metadata) { + db.close() + throw new MetadataMismatchException() + } + + db + } + + private[spark] class MetadataMismatchException extends Exception + +} diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala new file mode 100644 index 0000000000000..63fa36580bc7d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -0,0 +1,526 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.util.Date + +import scala.collection.mutable.HashMap + +import org.apache.spark.JobExecutionStatus +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.{AccumulableInfo, StageInfo, TaskInfo} +import org.apache.spark.status.api.v1 +import org.apache.spark.storage.RDDInfo +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.AccumulatorContext +import org.apache.spark.util.kvstore.KVStore + +/** + * A mutable representation of a live entity in Spark (jobs, stages, tasks, et al). Every live + * entity uses one of these instances to keep track of their evolving state, and periodically + * flush an immutable view of the entity to the app state store. + */ +private[spark] abstract class LiveEntity { + + def write(store: KVStore): Unit = { + store.write(doUpdate()) + } + + /** + * Returns an updated view of entity data, to be stored in the status store, reflecting the + * latest information collected by the listener. + */ + protected def doUpdate(): Any + +} + +private class LiveJob( + val jobId: Int, + name: String, + submissionTime: Option[Date], + val stageIds: Seq[Int], + jobGroup: Option[String], + numTasks: Int) extends LiveEntity { + + var activeTasks = 0 + var completedTasks = 0 + var failedTasks = 0 + + var skippedTasks = 0 + var skippedStages = Set[Int]() + + var status = JobExecutionStatus.RUNNING + var completionTime: Option[Date] = None + + var completedStages: Set[Int] = Set() + var activeStages = 0 + var failedStages = 0 + + override protected def doUpdate(): Any = { + val info = new v1.JobData( + jobId, + name, + None, // description is always None? + submissionTime, + completionTime, + stageIds, + jobGroup, + status, + numTasks, + activeTasks, + completedTasks, + skippedTasks, + failedTasks, + activeStages, + completedStages.size, + skippedStages.size, + failedStages) + new JobDataWrapper(info, skippedStages) + } + +} + +private class LiveTask( + info: TaskInfo, + stageId: Int, + stageAttemptId: Int) extends LiveEntity { + + import LiveEntityHelpers._ + + private var recordedMetrics: v1.TaskMetrics = null + + var errorMessage: Option[String] = None + + /** + * Update the metrics for the task and return the difference between the previous and new + * values. + */ + def updateMetrics(metrics: TaskMetrics): v1.TaskMetrics = { + if (metrics != null) { + val old = recordedMetrics + recordedMetrics = new v1.TaskMetrics( + metrics.executorDeserializeTime, + metrics.executorDeserializeCpuTime, + metrics.executorRunTime, + metrics.executorCpuTime, + metrics.resultSize, + metrics.jvmGCTime, + metrics.resultSerializationTime, + metrics.memoryBytesSpilled, + metrics.diskBytesSpilled, + new v1.InputMetrics( + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead), + new v1.OutputMetrics( + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten), + new v1.ShuffleReadMetrics( + metrics.shuffleReadMetrics.remoteBlocksFetched, + metrics.shuffleReadMetrics.localBlocksFetched, + metrics.shuffleReadMetrics.fetchWaitTime, + metrics.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.remoteBytesReadToDisk, + metrics.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead), + new v1.ShuffleWriteMetrics( + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.writeTime, + metrics.shuffleWriteMetrics.recordsWritten)) + if (old != null) calculateMetricsDelta(recordedMetrics, old) else recordedMetrics + } else { + null + } + } + + /** + * Return a new TaskMetrics object containing the delta of the various fields of the given + * metrics objects. This is currently targeted at updating stage data, so it does not + * necessarily calculate deltas for all the fields. + */ + private def calculateMetricsDelta( + metrics: v1.TaskMetrics, + old: v1.TaskMetrics): v1.TaskMetrics = { + val shuffleWriteDelta = new v1.ShuffleWriteMetrics( + metrics.shuffleWriteMetrics.bytesWritten - old.shuffleWriteMetrics.bytesWritten, + 0L, + metrics.shuffleWriteMetrics.recordsWritten - old.shuffleWriteMetrics.recordsWritten) + + val shuffleReadDelta = new v1.ShuffleReadMetrics( + 0L, 0L, 0L, + metrics.shuffleReadMetrics.remoteBytesRead - old.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.remoteBytesReadToDisk - + old.shuffleReadMetrics.remoteBytesReadToDisk, + metrics.shuffleReadMetrics.localBytesRead - old.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead - old.shuffleReadMetrics.recordsRead) + + val inputDelta = new v1.InputMetrics( + metrics.inputMetrics.bytesRead - old.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead - old.inputMetrics.recordsRead) + + val outputDelta = new v1.OutputMetrics( + metrics.outputMetrics.bytesWritten - old.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten - old.outputMetrics.recordsWritten) + + new v1.TaskMetrics( + 0L, 0L, + metrics.executorRunTime - old.executorRunTime, + metrics.executorCpuTime - old.executorCpuTime, + 0L, 0L, 0L, + metrics.memoryBytesSpilled - old.memoryBytesSpilled, + metrics.diskBytesSpilled - old.diskBytesSpilled, + inputDelta, + outputDelta, + shuffleReadDelta, + shuffleWriteDelta) + } + + override protected def doUpdate(): Any = { + val task = new v1.TaskData( + info.taskId, + info.index, + info.attemptNumber, + new Date(info.launchTime), + if (info.finished) Some(info.duration) else None, + info.executorId, + info.host, + info.status, + info.taskLocality.toString(), + info.speculative, + newAccumulatorInfos(info.accumulables), + errorMessage, + Option(recordedMetrics)) + new TaskDataWrapper(task) + } + +} + +private class LiveExecutor(val executorId: String) extends LiveEntity { + + var hostPort: String = null + var host: String = null + var isActive = true + var totalCores = 0 + + var rddBlocks = 0 + var memoryUsed = 0L + var diskUsed = 0L + var maxTasks = 0 + var maxMemory = 0L + + var totalTasks = 0 + var activeTasks = 0 + var completedTasks = 0 + var failedTasks = 0 + var totalDuration = 0L + var totalGcTime = 0L + var totalInputBytes = 0L + var totalShuffleRead = 0L + var totalShuffleWrite = 0L + var isBlacklisted = false + + var executorLogs = Map[String, String]() + + // Memory metrics. They may not be recorded (e.g. old event logs) so if totalOnHeap is not + // initialized, the store will not contain this information. + var totalOnHeap = -1L + var totalOffHeap = 0L + var usedOnHeap = 0L + var usedOffHeap = 0L + + def hasMemoryInfo: Boolean = totalOnHeap >= 0L + + def hostname: String = if (host != null) host else hostPort.split(":")(0) + + override protected def doUpdate(): Any = { + val memoryMetrics = if (totalOnHeap >= 0) { + Some(new v1.MemoryMetrics(usedOnHeap, usedOffHeap, totalOnHeap, totalOffHeap)) + } else { + None + } + + val info = new v1.ExecutorSummary( + executorId, + if (hostPort != null) hostPort else host, + isActive, + rddBlocks, + memoryUsed, + diskUsed, + totalCores, + maxTasks, + activeTasks, + failedTasks, + completedTasks, + totalTasks, + totalDuration, + totalGcTime, + totalInputBytes, + totalShuffleRead, + totalShuffleWrite, + isBlacklisted, + maxMemory, + executorLogs, + memoryMetrics) + new ExecutorSummaryWrapper(info) + } + +} + +/** Metrics tracked per stage (both total and per executor). */ +private class MetricsTracker { + var executorRunTime = 0L + var executorCpuTime = 0L + var inputBytes = 0L + var inputRecords = 0L + var outputBytes = 0L + var outputRecords = 0L + var shuffleReadBytes = 0L + var shuffleReadRecords = 0L + var shuffleWriteBytes = 0L + var shuffleWriteRecords = 0L + var memoryBytesSpilled = 0L + var diskBytesSpilled = 0L + + def update(delta: v1.TaskMetrics): Unit = { + executorRunTime += delta.executorRunTime + executorCpuTime += delta.executorCpuTime + inputBytes += delta.inputMetrics.bytesRead + inputRecords += delta.inputMetrics.recordsRead + outputBytes += delta.outputMetrics.bytesWritten + outputRecords += delta.outputMetrics.recordsWritten + shuffleReadBytes += delta.shuffleReadMetrics.localBytesRead + + delta.shuffleReadMetrics.remoteBytesRead + shuffleReadRecords += delta.shuffleReadMetrics.recordsRead + shuffleWriteBytes += delta.shuffleWriteMetrics.bytesWritten + shuffleWriteRecords += delta.shuffleWriteMetrics.recordsWritten + memoryBytesSpilled += delta.memoryBytesSpilled + diskBytesSpilled += delta.diskBytesSpilled + } + +} + +private class LiveExecutorStageSummary( + stageId: Int, + attemptId: Int, + executorId: String) extends LiveEntity { + + var taskTime = 0L + var succeededTasks = 0 + var failedTasks = 0 + var killedTasks = 0 + + val metrics = new MetricsTracker() + + override protected def doUpdate(): Any = { + val info = new v1.ExecutorStageSummary( + taskTime, + failedTasks, + succeededTasks, + metrics.inputBytes, + metrics.outputBytes, + metrics.shuffleReadBytes, + metrics.shuffleWriteBytes, + metrics.memoryBytesSpilled, + metrics.diskBytesSpilled) + new ExecutorStageSummaryWrapper(stageId, attemptId, executorId, info) + } + +} + +private class LiveStage extends LiveEntity { + + import LiveEntityHelpers._ + + var jobs = Seq[LiveJob]() + var jobIds = Set[Int]() + + var info: StageInfo = null + var status = v1.StageStatus.PENDING + + var schedulingPool: String = SparkUI.DEFAULT_POOL_NAME + + var activeTasks = 0 + var completedTasks = 0 + var failedTasks = 0 + + var firstLaunchTime = Long.MaxValue + + val metrics = new MetricsTracker() + + val executorSummaries = new HashMap[String, LiveExecutorStageSummary]() + + def executorSummary(executorId: String): LiveExecutorStageSummary = { + executorSummaries.getOrElseUpdate(executorId, + new LiveExecutorStageSummary(info.stageId, info.attemptId, executorId)) + } + + override protected def doUpdate(): Any = { + val update = new v1.StageData( + status, + info.stageId, + info.attemptId, + + activeTasks, + completedTasks, + failedTasks, + + metrics.executorRunTime, + metrics.executorCpuTime, + info.submissionTime.map(new Date(_)), + if (firstLaunchTime < Long.MaxValue) Some(new Date(firstLaunchTime)) else None, + info.completionTime.map(new Date(_)), + + metrics.inputBytes, + metrics.inputRecords, + metrics.outputBytes, + metrics.outputRecords, + metrics.shuffleReadBytes, + metrics.shuffleReadRecords, + metrics.shuffleWriteBytes, + metrics.shuffleWriteRecords, + metrics.memoryBytesSpilled, + metrics.diskBytesSpilled, + + info.name, + info.details, + schedulingPool, + + newAccumulatorInfos(info.accumulables.values), + None, + None) + + new StageDataWrapper(update, jobIds) + } + +} + +private class LiveRDDPartition(val blockName: String) { + + var executors = Set[String]() + var storageLevel: String = null + var memoryUsed = 0L + var diskUsed = 0L + + def toApi(): v1.RDDPartitionInfo = { + new v1.RDDPartitionInfo( + blockName, + storageLevel, + memoryUsed, + diskUsed, + executors.toSeq.sorted) + } + +} + +private class LiveRDDDistribution(val exec: LiveExecutor) { + + var memoryRemaining = exec.maxMemory + var memoryUsed = 0L + var diskUsed = 0L + + var onHeapUsed = 0L + var offHeapUsed = 0L + var onHeapRemaining = 0L + var offHeapRemaining = 0L + + def toApi(): v1.RDDDataDistribution = { + new v1.RDDDataDistribution( + exec.hostPort, + memoryUsed, + memoryRemaining, + diskUsed, + if (exec.hasMemoryInfo) Some(onHeapUsed) else None, + if (exec.hasMemoryInfo) Some(offHeapUsed) else None, + if (exec.hasMemoryInfo) Some(onHeapRemaining) else None, + if (exec.hasMemoryInfo) Some(offHeapRemaining) else None) + } + +} + +private class LiveRDD(info: RDDInfo) extends LiveEntity { + + var storageLevel: String = info.storageLevel.description + var memoryUsed = 0L + var diskUsed = 0L + + private val partitions = new HashMap[String, LiveRDDPartition]() + private val distributions = new HashMap[String, LiveRDDDistribution]() + + def partition(blockName: String): LiveRDDPartition = { + partitions.getOrElseUpdate(blockName, new LiveRDDPartition(blockName)) + } + + def removePartition(blockName: String): Unit = partitions.remove(blockName) + + def distribution(exec: LiveExecutor): LiveRDDDistribution = { + distributions.getOrElseUpdate(exec.hostPort, new LiveRDDDistribution(exec)) + } + + def removeDistribution(exec: LiveExecutor): Unit = { + distributions.remove(exec.hostPort) + } + + override protected def doUpdate(): Any = { + val parts = if (partitions.nonEmpty) { + Some(partitions.values.toList.sortBy(_.blockName).map(_.toApi())) + } else { + None + } + + val dists = if (distributions.nonEmpty) { + Some(distributions.values.toList.sortBy(_.exec.executorId).map(_.toApi())) + } else { + None + } + + val rdd = new v1.RDDStorageInfo( + info.id, + info.name, + info.numPartitions, + partitions.size, + storageLevel, + memoryUsed, + diskUsed, + dists, + parts) + + new RDDStorageInfoWrapper(rdd) + } + +} + +private object LiveEntityHelpers { + + def newAccumulatorInfos(accums: Iterable[AccumulableInfo]): Seq[v1.AccumulableInfo] = { + accums + .filter { acc => + // We don't need to store internal or SQL accumulables as their values will be shown in + // other places, so drop them to reduce the memory usage. + !acc.internal && (!acc.metadata.isDefined || + acc.metadata.get != Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)) + } + .map { acc => + new v1.AccumulableInfo( + acc.id, + acc.name.map(_.intern()).orNull, + acc.update.map(_.toString()), + acc.value.map(_.toString()).orNull) + } + .toSeq + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 4a4ed954d689e..5f69949c618fd 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -71,7 +71,7 @@ private[v1] object AllStagesResource { val taskData = if (includeDetails) { Some(stageUiData.taskData.map { case (k, v) => - k -> convertTaskData(v, stageUiData.lastUpdateTime) }) + k -> convertTaskData(v, stageUiData.lastUpdateTime) }.toMap) } else { None } @@ -88,7 +88,7 @@ private[v1] object AllStagesResource { memoryBytesSpilled = summary.memoryBytesSpilled, diskBytesSpilled = summary.diskBytesSpilled ) - }) + }.toMap) } else { None } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 31659b25db318..bff6f90823f40 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -16,11 +16,11 @@ */ package org.apache.spark.status.api.v1 +import java.lang.{Long => JLong} import java.util.Date -import scala.collection.Map - import com.fasterxml.jackson.annotation.JsonIgnoreProperties +import com.fasterxml.jackson.databind.annotation.JsonDeserialize import org.apache.spark.JobExecutionStatus @@ -129,9 +129,13 @@ class RDDDataDistribution private[spark]( val memoryUsed: Long, val memoryRemaining: Long, val diskUsed: Long, + @JsonDeserialize(contentAs = classOf[JLong]) val onHeapMemoryUsed: Option[Long], + @JsonDeserialize(contentAs = classOf[JLong]) val offHeapMemoryUsed: Option[Long], + @JsonDeserialize(contentAs = classOf[JLong]) val onHeapMemoryRemaining: Option[Long], + @JsonDeserialize(contentAs = classOf[JLong]) val offHeapMemoryRemaining: Option[Long]) class RDDPartitionInfo private[spark]( @@ -179,7 +183,8 @@ class TaskData private[spark]( val index: Int, val attempt: Int, val launchTime: Date, - val duration: Option[Long] = None, + @JsonDeserialize(contentAs = classOf[JLong]) + val duration: Option[Long], val executorId: String, val host: String, val status: String, diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala new file mode 100644 index 0000000000000..9579accd2cba7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import com.fasterxml.jackson.annotation.JsonIgnore + +import org.apache.spark.status.KVUtils._ +import org.apache.spark.status.api.v1._ +import org.apache.spark.util.kvstore.KVIndex + +private[spark] class ApplicationInfoWrapper(val info: ApplicationInfo) { + + @JsonIgnore @KVIndex + def id: String = info.id + +} + +private[spark] class ExecutorSummaryWrapper(val info: ExecutorSummary) { + + @JsonIgnore @KVIndex + private[this] val id: String = info.id + + @JsonIgnore @KVIndex("active") + private[this] val active: Boolean = info.isActive + + @JsonIgnore @KVIndex("host") + val host: String = info.hostPort.split(":")(0) + +} + +/** + * Keep track of the existing stages when the job was submitted, and those that were + * completed during the job's execution. This allows a more accurate acounting of how + * many tasks were skipped for the job. + */ +private[spark] class JobDataWrapper( + val info: JobData, + val skippedStages: Set[Int]) { + + @JsonIgnore @KVIndex + private[this] val id: Int = info.jobId + +} + +private[spark] class StageDataWrapper( + val info: StageData, + val jobIds: Set[Int]) { + + @JsonIgnore @KVIndex + def id: Array[Int] = Array(info.stageId, info.attemptId) + +} + +private[spark] class TaskDataWrapper(val info: TaskData) { + + @JsonIgnore @KVIndex + def id: Long = info.taskId + +} + +private[spark] class RDDStorageInfoWrapper(val info: RDDStorageInfo) { + + @JsonIgnore @KVIndex + def id: Int = info.id + + @JsonIgnore @KVIndex("cached") + def cached: Boolean = info.numCachedPartitions > 0 + +} + +private[spark] class ExecutorStageSummaryWrapper( + val stageId: Int, + val stageAttemptId: Int, + val executorId: String, + val info: ExecutorStageSummary) { + + @JsonIgnore @KVIndex + val id: Array[Any] = Array(stageId, stageAttemptId, executorId) + + @JsonIgnore @KVIndex("stage") + private[this] val stage: Array[Int] = Array(stageId, stageAttemptId) + +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 2141934c92640..03bd3eaf579f3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -611,7 +611,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Manually overwrite the version in the listing db; this should cause the new provider to // discard all data because the versions don't match. - val meta = new KVStoreMetadata(FsHistoryProvider.CURRENT_LISTING_VERSION + 1, + val meta = new FsHistoryProviderMetadata(FsHistoryProvider.CURRENT_LISTING_VERSION + 1, conf.get(LOCAL_STORE_DIR).get) oldProvider.listing.setMetadata(meta) oldProvider.stop() diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala new file mode 100644 index 0000000000000..6f7a0c14dd684 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -0,0 +1,690 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.io.File +import java.util.{Date, Properties} + +import scala.collection.JavaConverters._ +import scala.reflect.{classTag, ClassTag} + +import org.scalatest.BeforeAndAfter + +import org.apache.spark._ +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster._ +import org.apache.spark.status.api.v1 +import org.apache.spark.storage._ +import org.apache.spark.util.Utils +import org.apache.spark.util.kvstore._ + +class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { + + private var time: Long = _ + private var testDir: File = _ + private var store: KVStore = _ + + before { + time = 0L + testDir = Utils.createTempDir() + store = KVUtils.open(testDir, getClass().getName()) + } + + after { + store.close() + Utils.deleteRecursively(testDir) + } + + test("scheduler events") { + val listener = new AppStatusListener(store) + + // Start the application. + time += 1 + listener.onApplicationStart(SparkListenerApplicationStart( + "name", + Some("id"), + time, + "user", + Some("attempt"), + None)) + + check[ApplicationInfoWrapper]("id") { app => + assert(app.info.name === "name") + assert(app.info.id === "id") + assert(app.info.attempts.size === 1) + + val attempt = app.info.attempts.head + assert(attempt.attemptId === Some("attempt")) + assert(attempt.startTime === new Date(time)) + assert(attempt.lastUpdated === new Date(time)) + assert(attempt.endTime.getTime() === -1L) + assert(attempt.sparkUser === "user") + assert(!attempt.completed) + } + + // Start a couple of executors. + time += 1 + val execIds = Array("1", "2") + + execIds.foreach { id => + listener.onExecutorAdded(SparkListenerExecutorAdded(time, id, + new ExecutorInfo(s"$id.example.com", 1, Map()))) + } + + execIds.foreach { id => + check[ExecutorSummaryWrapper](id) { exec => + assert(exec.info.id === id) + assert(exec.info.hostPort === s"$id.example.com") + assert(exec.info.isActive) + } + } + + // Start a job with 2 stages / 4 tasks each + time += 1 + val stages = Seq( + new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1"), + new StageInfo(2, 0, "stage2", 4, Nil, Seq(1), "details2")) + + val jobProps = new Properties() + jobProps.setProperty(SparkContext.SPARK_JOB_GROUP_ID, "jobGroup") + jobProps.setProperty("spark.scheduler.pool", "schedPool") + + listener.onJobStart(SparkListenerJobStart(1, time, stages, jobProps)) + + check[JobDataWrapper](1) { job => + assert(job.info.jobId === 1) + assert(job.info.name === stages.last.name) + assert(job.info.description === None) + assert(job.info.status === JobExecutionStatus.RUNNING) + assert(job.info.submissionTime === Some(new Date(time))) + assert(job.info.jobGroup === Some("jobGroup")) + } + + stages.foreach { info => + check[StageDataWrapper](key(info)) { stage => + assert(stage.info.status === v1.StageStatus.PENDING) + assert(stage.jobIds === Set(1)) + } + } + + // Submit stage 1 + time += 1 + stages.head.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stages.head, jobProps)) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveStages === 1) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.status === v1.StageStatus.ACTIVE) + assert(stage.info.submissionTime === Some(new Date(stages.head.submissionTime.get))) + assert(stage.info.schedulingPool === "schedPool") + } + + // Start tasks from stage 1 + time += 1 + var _taskIdTracker = -1L + def nextTaskId(): Long = { + _taskIdTracker += 1 + _taskIdTracker + } + + def createTasks(count: Int, time: Long): Seq[TaskInfo] = { + (1 to count).map { id => + val exec = execIds(id.toInt % execIds.length) + val taskId = nextTaskId() + new TaskInfo(taskId, taskId.toInt, 1, time, exec, s"$exec.example.com", + TaskLocality.PROCESS_LOCAL, id % 2 == 0) + } + } + + val s1Tasks = createTasks(4, time) + s1Tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, task)) + } + + assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveTasks === s1Tasks.size) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.numActiveTasks === s1Tasks.size) + assert(stage.info.firstTaskLaunchedTime === Some(new Date(s1Tasks.head.launchTime))) + } + + s1Tasks.foreach { task => + check[TaskDataWrapper](task.taskId) { wrapper => + assert(wrapper.info.taskId === task.taskId) + assert(wrapper.info.index === task.index) + assert(wrapper.info.attempt === task.attemptNumber) + assert(wrapper.info.launchTime === new Date(task.launchTime)) + assert(wrapper.info.executorId === task.executorId) + assert(wrapper.info.host === task.host) + assert(wrapper.info.status === task.status) + assert(wrapper.info.taskLocality === task.taskLocality.toString()) + assert(wrapper.info.speculative === task.speculative) + } + } + + // Send executor metrics update. Only update one metric to avoid a lot of boilerplate code. + s1Tasks.foreach { task => + val accum = new AccumulableInfo(1L, Some(InternalAccumulator.MEMORY_BYTES_SPILLED), + Some(1L), None, true, false, None) + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate( + task.executorId, + Seq((task.taskId, stages.head.stageId, stages.head.attemptId, Seq(accum))))) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.memoryBytesSpilled === s1Tasks.size) + } + + val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") + .first(key(stages.head)).last(key(stages.head)).asScala.toSeq + assert(execs.size > 0) + execs.foreach { exec => + assert(exec.info.memoryBytesSpilled === s1Tasks.size / 2) + } + + // Fail one of the tasks, re-start it. + time += 1 + s1Tasks.head.markFinished(TaskState.FAILED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + "taskType", TaskResultLost, s1Tasks.head, null)) + + time += 1 + val reattempt = { + val orig = s1Tasks.head + // Task reattempts have a different ID, but the same index as the original. + new TaskInfo(nextTaskId(), orig.index, orig.attemptNumber + 1, time, orig.executorId, + s"${orig.executorId}.example.com", TaskLocality.PROCESS_LOCAL, orig.speculative) + } + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + reattempt)) + + assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size + 1) + + check[JobDataWrapper](1) { job => + assert(job.info.numFailedTasks === 1) + assert(job.info.numActiveTasks === s1Tasks.size) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.numFailedTasks === 1) + assert(stage.info.numActiveTasks === s1Tasks.size) + } + + check[TaskDataWrapper](s1Tasks.head.taskId) { task => + assert(task.info.status === s1Tasks.head.status) + assert(task.info.duration === Some(s1Tasks.head.duration)) + assert(task.info.errorMessage == Some(TaskResultLost.toErrorString)) + } + + check[TaskDataWrapper](reattempt.taskId) { task => + assert(task.info.index === s1Tasks.head.index) + assert(task.info.attempt === reattempt.attemptNumber) + } + + // Succeed all tasks in stage 1. + val pending = s1Tasks.drop(1) ++ Seq(reattempt) + + val s1Metrics = TaskMetrics.empty + s1Metrics.setExecutorCpuTime(2L) + s1Metrics.setExecutorRunTime(4L) + + time += 1 + pending.foreach { task => + task.markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + "taskType", Success, task, s1Metrics)) + } + + check[JobDataWrapper](1) { job => + assert(job.info.numFailedTasks === 1) + assert(job.info.numActiveTasks === 0) + assert(job.info.numCompletedTasks === pending.size) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.numFailedTasks === 1) + assert(stage.info.numActiveTasks === 0) + assert(stage.info.numCompleteTasks === pending.size) + } + + pending.foreach { task => + check[TaskDataWrapper](task.taskId) { wrapper => + assert(wrapper.info.errorMessage === None) + assert(wrapper.info.taskMetrics.get.executorCpuTime === 2L) + assert(wrapper.info.taskMetrics.get.executorRunTime === 4L) + } + } + + assert(store.count(classOf[TaskDataWrapper]) === pending.size + 1) + + // End stage 1. + time += 1 + stages.head.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stages.head)) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveStages === 0) + assert(job.info.numCompletedStages === 1) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.status === v1.StageStatus.COMPLETE) + assert(stage.info.numFailedTasks === 1) + assert(stage.info.numActiveTasks === 0) + assert(stage.info.numCompleteTasks === pending.size) + } + + // Submit stage 2. + time += 1 + stages.last.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stages.last, jobProps)) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveStages === 1) + } + + check[StageDataWrapper](key(stages.last)) { stage => + assert(stage.info.status === v1.StageStatus.ACTIVE) + assert(stage.info.submissionTime === Some(new Date(stages.last.submissionTime.get))) + } + + // Start and fail all tasks of stage 2. + time += 1 + val s2Tasks = createTasks(4, time) + s2Tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stages.last.stageId, stages.last.attemptId, task)) + } + + time += 1 + s2Tasks.foreach { task => + task.markFinished(TaskState.FAILED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stages.last.stageId, stages.last.attemptId, + "taskType", TaskResultLost, task, null)) + } + + check[JobDataWrapper](1) { job => + assert(job.info.numFailedTasks === 1 + s2Tasks.size) + assert(job.info.numActiveTasks === 0) + } + + check[StageDataWrapper](key(stages.last)) { stage => + assert(stage.info.numFailedTasks === s2Tasks.size) + assert(stage.info.numActiveTasks === 0) + } + + // Fail stage 2. + time += 1 + stages.last.completionTime = Some(time) + stages.last.failureReason = Some("uh oh") + listener.onStageCompleted(SparkListenerStageCompleted(stages.last)) + + check[JobDataWrapper](1) { job => + assert(job.info.numCompletedStages === 1) + assert(job.info.numFailedStages === 1) + } + + check[StageDataWrapper](key(stages.last)) { stage => + assert(stage.info.status === v1.StageStatus.FAILED) + assert(stage.info.numFailedTasks === s2Tasks.size) + assert(stage.info.numActiveTasks === 0) + assert(stage.info.numCompleteTasks === 0) + } + + // - Re-submit stage 2, all tasks, and succeed them and the stage. + val oldS2 = stages.last + val newS2 = new StageInfo(oldS2.stageId, oldS2.attemptId + 1, oldS2.name, oldS2.numTasks, + oldS2.rddInfos, oldS2.parentIds, oldS2.details, oldS2.taskMetrics) + + time += 1 + newS2.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(newS2, jobProps)) + assert(store.count(classOf[StageDataWrapper]) === 3) + + val newS2Tasks = createTasks(4, time) + + newS2Tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(newS2.stageId, newS2.attemptId, task)) + } + + time += 1 + newS2Tasks.foreach { task => + task.markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(newS2.stageId, newS2.attemptId, "taskType", Success, + task, null)) + } + + time += 1 + newS2.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(newS2)) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveStages === 0) + assert(job.info.numFailedStages === 1) + assert(job.info.numCompletedStages === 2) + } + + check[StageDataWrapper](key(newS2)) { stage => + assert(stage.info.status === v1.StageStatus.COMPLETE) + assert(stage.info.numActiveTasks === 0) + assert(stage.info.numCompleteTasks === newS2Tasks.size) + } + + // End job. + time += 1 + listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded)) + + check[JobDataWrapper](1) { job => + assert(job.info.status === JobExecutionStatus.SUCCEEDED) + } + + // Submit a second job that re-uses stage 1 and stage 2. Stage 1 won't be re-run, but + // stage 2 will. In any case, the DAGScheduler creates new info structures that are copies + // of the old stages, so mimic that behavior here. The "new" stage 1 is submitted without + // a submission time, which means it is "skipped", and the stage 2 re-execution should not + // change the stats of the already finished job. + time += 1 + val j2Stages = Seq( + new StageInfo(3, 0, "stage1", 4, Nil, Nil, "details1"), + new StageInfo(4, 0, "stage2", 4, Nil, Seq(3), "details2")) + j2Stages.last.submissionTime = Some(time) + listener.onJobStart(SparkListenerJobStart(2, time, j2Stages, null)) + assert(store.count(classOf[JobDataWrapper]) === 2) + + listener.onStageSubmitted(SparkListenerStageSubmitted(j2Stages.head, jobProps)) + listener.onStageCompleted(SparkListenerStageCompleted(j2Stages.head)) + listener.onStageSubmitted(SparkListenerStageSubmitted(j2Stages.last, jobProps)) + assert(store.count(classOf[StageDataWrapper]) === 5) + + time += 1 + val j2s2Tasks = createTasks(4, time) + + j2s2Tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(j2Stages.last.stageId, j2Stages.last.attemptId, + task)) + } + + time += 1 + j2s2Tasks.foreach { task => + task.markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(j2Stages.last.stageId, j2Stages.last.attemptId, + "taskType", Success, task, null)) + } + + time += 1 + j2Stages.last.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(j2Stages.last)) + + time += 1 + listener.onJobEnd(SparkListenerJobEnd(2, time, JobSucceeded)) + + check[JobDataWrapper](1) { job => + assert(job.info.numCompletedStages === 2) + assert(job.info.numCompletedTasks === s1Tasks.size + s2Tasks.size) + } + + check[JobDataWrapper](2) { job => + assert(job.info.status === JobExecutionStatus.SUCCEEDED) + assert(job.info.numCompletedStages === 1) + assert(job.info.numCompletedTasks === j2s2Tasks.size) + assert(job.info.numSkippedStages === 1) + assert(job.info.numSkippedTasks === s1Tasks.size) + } + + // Blacklist an executor. + time += 1 + listener.onExecutorBlacklisted(SparkListenerExecutorBlacklisted(time, "1", 42)) + check[ExecutorSummaryWrapper]("1") { exec => + assert(exec.info.isBlacklisted) + } + + time += 1 + listener.onExecutorUnblacklisted(SparkListenerExecutorUnblacklisted(time, "1")) + check[ExecutorSummaryWrapper]("1") { exec => + assert(!exec.info.isBlacklisted) + } + + // Blacklist a node. + time += 1 + listener.onNodeBlacklisted(SparkListenerNodeBlacklisted(time, "1.example.com", 2)) + check[ExecutorSummaryWrapper]("1") { exec => + assert(exec.info.isBlacklisted) + } + + time += 1 + listener.onNodeUnblacklisted(SparkListenerNodeUnblacklisted(time, "1.example.com")) + check[ExecutorSummaryWrapper]("1") { exec => + assert(!exec.info.isBlacklisted) + } + + // Stop executors. + listener.onExecutorRemoved(SparkListenerExecutorRemoved(41L, "1", "Test")) + listener.onExecutorRemoved(SparkListenerExecutorRemoved(41L, "2", "Test")) + + Seq("1", "2").foreach { id => + check[ExecutorSummaryWrapper](id) { exec => + assert(exec.info.id === id) + assert(!exec.info.isActive) + } + } + + // End the application. + listener.onApplicationEnd(SparkListenerApplicationEnd(42L)) + + check[ApplicationInfoWrapper]("id") { app => + assert(app.info.name === "name") + assert(app.info.id === "id") + assert(app.info.attempts.size === 1) + + val attempt = app.info.attempts.head + assert(attempt.attemptId === Some("attempt")) + assert(attempt.startTime === new Date(1L)) + assert(attempt.lastUpdated === new Date(42L)) + assert(attempt.endTime === new Date(42L)) + assert(attempt.duration === 41L) + assert(attempt.sparkUser === "user") + assert(attempt.completed) + } + } + + test("storage events") { + val listener = new AppStatusListener(store) + val maxMemory = 42L + + // Register a couple of block managers. + val bm1 = BlockManagerId("1", "1.example.com", 42) + val bm2 = BlockManagerId("2", "2.example.com", 84) + Seq(bm1, bm2).foreach { bm => + listener.onExecutorAdded(SparkListenerExecutorAdded(1L, bm.executorId, + new ExecutorInfo(bm.host, 1, Map()))) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm, maxMemory)) + check[ExecutorSummaryWrapper](bm.executorId) { exec => + assert(exec.info.maxMemory === maxMemory) + } + } + + val rdd1b1 = RDDBlockId(1, 1) + val level = StorageLevel.MEMORY_AND_DISK + + // Submit a stage and make sure the RDD is recorded. + val rddInfo = new RDDInfo(rdd1b1.rddId, "rdd1", 2, level, Nil) + val stage = new StageInfo(1, 0, "stage1", 4, Seq(rddInfo), Nil, "details1") + listener.onStageSubmitted(SparkListenerStageSubmitted(stage, new Properties())) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.name === rddInfo.name) + assert(wrapper.info.numPartitions === rddInfo.numPartitions) + assert(wrapper.info.storageLevel === rddInfo.storageLevel.description) + } + + // Add partition 1 replicated on two block managers. + listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm1, rdd1b1, level, 1L, 1L))) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.memoryUsed === 1L) + assert(wrapper.info.diskUsed === 1L) + + assert(wrapper.info.dataDistribution.isDefined) + assert(wrapper.info.dataDistribution.get.size === 1) + + val dist = wrapper.info.dataDistribution.get.head + assert(dist.address === bm1.hostPort) + assert(dist.memoryUsed === 1L) + assert(dist.diskUsed === 1L) + assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) + + assert(wrapper.info.partitions.isDefined) + assert(wrapper.info.partitions.get.size === 1) + + val part = wrapper.info.partitions.get.head + assert(part.blockName === rdd1b1.name) + assert(part.storageLevel === level.description) + assert(part.memoryUsed === 1L) + assert(part.diskUsed === 1L) + assert(part.executors === Seq(bm1.executorId)) + } + + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 1L) + assert(exec.info.memoryUsed === 1L) + assert(exec.info.diskUsed === 1L) + } + + listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm2, rdd1b1, level, 1L, 1L))) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.memoryUsed === 2L) + assert(wrapper.info.diskUsed === 2L) + assert(wrapper.info.dataDistribution.get.size === 2L) + assert(wrapper.info.partitions.get.size === 1L) + + val dist = wrapper.info.dataDistribution.get.find(_.address == bm2.hostPort).get + assert(dist.memoryUsed === 1L) + assert(dist.diskUsed === 1L) + assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) + + val part = wrapper.info.partitions.get(0) + assert(part.memoryUsed === 2L) + assert(part.diskUsed === 2L) + assert(part.executors === Seq(bm1.executorId, bm2.executorId)) + } + + check[ExecutorSummaryWrapper](bm2.executorId) { exec => + assert(exec.info.rddBlocks === 1L) + assert(exec.info.memoryUsed === 1L) + assert(exec.info.diskUsed === 1L) + } + + // Add a second partition only to bm 1. + val rdd1b2 = RDDBlockId(1, 2) + listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm1, rdd1b2, level, + 3L, 3L))) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.memoryUsed === 5L) + assert(wrapper.info.diskUsed === 5L) + assert(wrapper.info.dataDistribution.get.size === 2L) + assert(wrapper.info.partitions.get.size === 2L) + + val dist = wrapper.info.dataDistribution.get.find(_.address == bm1.hostPort).get + assert(dist.memoryUsed === 4L) + assert(dist.diskUsed === 4L) + assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) + + val part = wrapper.info.partitions.get.find(_.blockName === rdd1b2.name).get + assert(part.storageLevel === level.description) + assert(part.memoryUsed === 3L) + assert(part.diskUsed === 3L) + assert(part.executors === Seq(bm1.executorId)) + } + + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 2L) + assert(exec.info.memoryUsed === 4L) + assert(exec.info.diskUsed === 4L) + } + + // Remove block 1 from bm 1. + listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm1, rdd1b1, + StorageLevel.NONE, 1L, 1L))) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.memoryUsed === 4L) + assert(wrapper.info.diskUsed === 4L) + assert(wrapper.info.dataDistribution.get.size === 2L) + assert(wrapper.info.partitions.get.size === 2L) + + val dist = wrapper.info.dataDistribution.get.find(_.address == bm1.hostPort).get + assert(dist.memoryUsed === 3L) + assert(dist.diskUsed === 3L) + assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) + + val part = wrapper.info.partitions.get.find(_.blockName === rdd1b1.name).get + assert(part.storageLevel === level.description) + assert(part.memoryUsed === 1L) + assert(part.diskUsed === 1L) + assert(part.executors === Seq(bm2.executorId)) + } + + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 1L) + assert(exec.info.memoryUsed === 3L) + assert(exec.info.diskUsed === 3L) + } + + // Remove block 2 from bm 2. This should leave only block 2 info in the store. + listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm2, rdd1b1, + StorageLevel.NONE, 1L, 1L))) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.memoryUsed === 3L) + assert(wrapper.info.diskUsed === 3L) + assert(wrapper.info.dataDistribution.get.size === 1L) + assert(wrapper.info.partitions.get.size === 1L) + assert(wrapper.info.partitions.get(0).blockName === rdd1b2.name) + } + + check[ExecutorSummaryWrapper](bm2.executorId) { exec => + assert(exec.info.rddBlocks === 0L) + assert(exec.info.memoryUsed === 0L) + assert(exec.info.diskUsed === 0L) + } + + // Unpersist RDD1. + listener.onUnpersistRDD(SparkListenerUnpersistRDD(rdd1b1.rddId)) + intercept[NoSuchElementException] { + check[RDDStorageInfoWrapper](rdd1b1.rddId) { _ => () } + } + + } + + private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptId) + + private def check[T: ClassTag](key: Any)(fn: T => Unit): Unit = { + val value = store.read(classTag[T].runtimeClass, key).asInstanceOf[T] + fn(value) + } + +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index dd299e074535e..45b8870f3b62f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,8 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( + // SPARK-18085: Better History Server scalability for many / large applications + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ExecutorSummary.executorLogs"), // [SPARK-20495][SQL] Add StorageLevel to cacheTable API ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"), From 4f8dc6b01ea787243a38678ea8199fbb0814cffc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 26 Oct 2017 21:41:45 +0100 Subject: [PATCH 770/779] [SPARK-22328][CORE] ClosureCleaner should not miss referenced superclass fields ## What changes were proposed in this pull request? When the given closure uses some fields defined in super class, `ClosureCleaner` can't figure them and don't set it properly. Those fields will be in null values. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #19556 from viirya/SPARK-22328. --- .../apache/spark/util/ClosureCleaner.scala | 73 ++++++++++++++++--- .../spark/util/ClosureCleanerSuite.scala | 72 ++++++++++++++++++ 2 files changed, 133 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 48a1d7b84b61b..dfece5dd0670b 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -91,6 +91,54 @@ private[spark] object ClosureCleaner extends Logging { (seen - obj.getClass).toList } + /** Initializes the accessed fields for outer classes and their super classes. */ + private def initAccessedFields( + accessedFields: Map[Class[_], Set[String]], + outerClasses: Seq[Class[_]]): Unit = { + for (cls <- outerClasses) { + var currentClass = cls + assert(currentClass != null, "The outer class can't be null.") + + while (currentClass != null) { + accessedFields(currentClass) = Set.empty[String] + currentClass = currentClass.getSuperclass() + } + } + } + + /** Sets accessed fields for given class in clone object based on given object. */ + private def setAccessedFields( + outerClass: Class[_], + clone: AnyRef, + obj: AnyRef, + accessedFields: Map[Class[_], Set[String]]): Unit = { + for (fieldName <- accessedFields(outerClass)) { + val field = outerClass.getDeclaredField(fieldName) + field.setAccessible(true) + val value = field.get(obj) + field.set(clone, value) + } + } + + /** Clones a given object and sets accessed fields in cloned object. */ + private def cloneAndSetFields( + parent: AnyRef, + obj: AnyRef, + outerClass: Class[_], + accessedFields: Map[Class[_], Set[String]]): AnyRef = { + val clone = instantiateClass(outerClass, parent) + + var currentClass = outerClass + assert(currentClass != null, "The outer class can't be null.") + + while (currentClass != null) { + setAccessedFields(currentClass, clone, obj, accessedFields) + currentClass = currentClass.getSuperclass() + } + + clone + } + /** * Clean the given closure in place. * @@ -202,9 +250,8 @@ private[spark] object ClosureCleaner extends Logging { logDebug(s" + populating accessed fields because this is the starting closure") // Initialize accessed fields with the outer classes first // This step is needed to associate the fields to the correct classes later - for (cls <- outerClasses) { - accessedFields(cls) = Set.empty[String] - } + initAccessedFields(accessedFields, outerClasses) + // Populate accessed fields by visiting all fields and methods accessed by this and // all of its inner closures. If transitive cleaning is enabled, this may recursively // visits methods that belong to other classes in search of transitively referenced fields. @@ -250,13 +297,8 @@ private[spark] object ClosureCleaner extends Logging { // required fields from the original object. We need the parent here because the Java // language specification requires the first constructor parameter of any closure to be // its enclosing object. - val clone = instantiateClass(cls, parent) - for (fieldName <- accessedFields(cls)) { - val field = cls.getDeclaredField(fieldName) - field.setAccessible(true) - val value = field.get(obj) - field.set(clone, value) - } + val clone = cloneAndSetFields(parent, obj, cls, accessedFields) + // If transitive cleaning is enabled, we recursively clean any enclosing closure using // the already populated accessed fields map of the starting closure if (cleanTransitively && isClosure(clone.getClass)) { @@ -395,8 +437,15 @@ private[util] class FieldAccessFinder( if (!visitedMethods.contains(m)) { // Keep track of visited methods to avoid potential infinite cycles visitedMethods += m - ClosureCleaner.getClassReader(cl).accept( - new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0) + + var currentClass = cl + assert(currentClass != null, "The outer class can't be null.") + + while (currentClass != null) { + ClosureCleaner.getClassReader(currentClass).accept( + new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0) + currentClass = currentClass.getSuperclass() + } } } } diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 4920b7ee8bfb4..9a19baee9569e 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -119,6 +119,63 @@ class ClosureCleanerSuite extends SparkFunSuite { test("createNullValue") { new TestCreateNullValue().run() } + + test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 1") { + val concreteObject = new TestAbstractClass { + val n2 = 222 + val s2 = "bbb" + val d2 = 2.0d + + def run(): Seq[(Int, Int, String, String, Double, Double)] = { + withSpark(new SparkContext("local", "test")) { sc => + val rdd = sc.parallelize(1 to 1) + body(rdd) + } + } + + def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)] = rdd.map { _ => + (n1, n2, s1, s2, d1, d2) + }.collect() + } + assert(concreteObject.run() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d))) + } + + test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 2") { + val concreteObject = new TestAbstractClass2 { + val n2 = 222 + val s2 = "bbb" + val d2 = 2.0d + def getData: Int => (Int, Int, String, String, Double, Double) = _ => (n1, n2, s1, s2, d1, d2) + } + withSpark(new SparkContext("local", "test")) { sc => + val rdd = sc.parallelize(1 to 1).map(concreteObject.getData) + assert(rdd.collect() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d))) + } + } + + test("SPARK-22328: multiple outer classes have the same parent class") { + val concreteObject = new TestAbstractClass2 { + + val innerObject = new TestAbstractClass2 { + override val n1 = 222 + override val s1 = "bbb" + } + + val innerObject2 = new TestAbstractClass2 { + override val n1 = 444 + val n3 = 333 + val s3 = "ccc" + val d3 = 3.0d + + def getData: Int => (Int, Int, String, String, Double, Double, Int, String) = + _ => (n1, n3, s1, s3, d1, d3, innerObject.n1, innerObject.s1) + } + } + withSpark(new SparkContext("local", "test")) { sc => + val rdd = sc.parallelize(1 to 1).map(concreteObject.innerObject2.getData) + assert(rdd.collect() === Seq((444, 333, "aaa", "ccc", 1.0d, 3.0d, 222, "bbb"))) + } + } } // A non-serializable class we create in closures to make sure that we aren't @@ -377,3 +434,18 @@ class TestCreateNullValue { nestedClosure() } } + +abstract class TestAbstractClass extends Serializable { + val n1 = 111 + val s1 = "aaa" + protected val d1 = 1.0d + + def run(): Seq[(Int, Int, String, String, Double, Double)] + def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)] +} + +abstract class TestAbstractClass2 extends Serializable { + val n1 = 111 + val s1 = "aaa" + protected val d1 = 1.0d +} From 5415963d2caaf95604211419ffc4e29fff38e1d7 Mon Sep 17 00:00:00 2001 From: "Susan X. Huynh" Date: Thu, 26 Oct 2017 16:13:48 -0700 Subject: [PATCH 771/779] [SPARK-22131][MESOS] Mesos driver secrets ## Background In #18837 , ArtRand added Mesos secrets support to the dispatcher. **This PR is to add the same secrets support to the drivers.** This means if the secret configs are set, the driver will launch executors that have access to either env or file-based secrets. One use case for this is to support TLS in the driver <=> executor communication. ## What changes were proposed in this pull request? Most of the changes are a refactor of the dispatcher secrets support (#18837) - moving it to a common place that can be used by both the dispatcher and drivers. The same goes for the unit tests. ## How was this patch tested? There are four config combinations: [env or file-based] x [value or reference secret]. For each combination: - Added a unit test. - Tested in DC/OS. Author: Susan X. Huynh Closes #19437 from susanxhuynh/sh-mesos-driver-secret. --- docs/running-on-mesos.md | 111 ++++++++++--- .../apache/spark/deploy/mesos/config.scala | 64 ++++---- .../cluster/mesos/MesosClusterScheduler.scala | 138 +++------------- .../MesosCoarseGrainedSchedulerBackend.scala | 31 +++- .../MesosFineGrainedSchedulerBackend.scala | 4 +- .../mesos/MesosSchedulerBackendUtil.scala | 92 ++++++++++- .../mesos/MesosClusterSchedulerSuite.scala | 150 +++--------------- ...osCoarseGrainedSchedulerBackendSuite.scala | 34 +++- .../MesosSchedulerBackendUtilSuite.scala | 7 +- .../spark/scheduler/cluster/mesos/Utils.scala | 107 +++++++++++++ 10 files changed, 434 insertions(+), 304 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index e0944bc9f5f86..b7e3e6473c338 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -485,39 +485,106 @@ See the [configuration page](configuration.html) for information on Spark config - spark.mesos.driver.secret.envkeys - (none) - A comma-separated list that, if set, the contents of the secret referenced - by spark.mesos.driver.secret.names or spark.mesos.driver.secret.values will be - set to the provided environment variable in the driver's process. + spark.mesos.driver.secret.values, + spark.mesos.driver.secret.names, + spark.mesos.executor.secret.values, + spark.mesos.executor.secret.names, - - -spark.mesos.driver.secret.filenames (none) - A comma-separated list that, if set, the contents of the secret referenced by - spark.mesos.driver.secret.names or spark.mesos.driver.secret.values will be - written to the provided file. Paths are relative to the container's work - directory. Absolute paths must already exist. Consult the Mesos Secret - protobuf for more information. +

    + A secret is specified by its contents and destination. These properties + specify a secret's contents. To specify a secret's destination, see the cell below. +

    +

    + You can specify a secret's contents either (1) by value or (2) by reference. +

    +

    + (1) To specify a secret by value, set the + spark.mesos.[driver|executor].secret.values + property, to make the secret available in the driver or executors. + For example, to make a secret password "guessme" available to the driver process, set: + +

    spark.mesos.driver.secret.values=guessme
    +

    +

    + (2) To specify a secret that has been placed in a secret store + by reference, specify its name within the secret store + by setting the spark.mesos.[driver|executor].secret.names + property. For example, to make a secret password named "password" in a secret store + available to the driver process, set: + +

    spark.mesos.driver.secret.names=password
    +

    +

    + Note: To use a secret store, make sure one has been integrated with Mesos via a custom + SecretResolver + module. +

    +

    + To specify multiple secrets, provide a comma-separated list: + +

    spark.mesos.driver.secret.values=guessme,passwd123
    + + or + +
    spark.mesos.driver.secret.names=password1,password2
    +

    + - spark.mesos.driver.secret.names - (none) - A comma-separated list of secret references. Consult the Mesos Secret - protobuf for more information. + spark.mesos.driver.secret.envkeys, + spark.mesos.driver.secret.filenames, + spark.mesos.executor.secret.envkeys, + spark.mesos.executor.secret.filenames, - - - spark.mesos.driver.secret.values (none) - A comma-separated list of secret values. Consult the Mesos Secret - protobuf for more information. +

    + A secret is specified by its contents and destination. These properties + specify a secret's destination. To specify a secret's contents, see the cell above. +

    +

    + You can specify a secret's destination in the driver or + executors as either (1) an environment variable or (2) as a file. +

    +

    + (1) To make an environment-based secret, set the + spark.mesos.[driver|executor].secret.envkeys property. + The secret will appear as an environment variable with the + given name in the driver or executors. For example, to make a secret password available + to the driver process as $PASSWORD, set: + +

    spark.mesos.driver.secret.envkeys=PASSWORD
    +

    +

    + (2) To make a file-based secret, set the + spark.mesos.[driver|executor].secret.filenames property. + The secret will appear in the contents of a file with the given file name in + the driver or executors. For example, to make a secret password available in a + file named "pwdfile" in the driver process, set: + +

    spark.mesos.driver.secret.filenames=pwdfile
    +

    +

    + Paths are relative to the container's work directory. Absolute paths must + already exist. Note: File-based secrets require a custom + SecretResolver + module. +

    +

    + To specify env vars or file names corresponding to multiple secrets, + provide a comma-separated list: + +

    spark.mesos.driver.secret.envkeys=PASSWORD1,PASSWORD2
    + + or + +
    spark.mesos.driver.secret.filenames=pwdfile1,pwdfile2
    +

    diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala index 7e85de91c5d36..821534eb4fc38 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala @@ -23,6 +23,39 @@ import org.apache.spark.internal.config.ConfigBuilder package object config { + private[spark] class MesosSecretConfig private[config](taskType: String) { + private[spark] val SECRET_NAMES = + ConfigBuilder(s"spark.mesos.$taskType.secret.names") + .doc("A comma-separated list of secret reference names. Consult the Mesos Secret " + + "protobuf for more information.") + .stringConf + .toSequence + .createOptional + + private[spark] val SECRET_VALUES = + ConfigBuilder(s"spark.mesos.$taskType.secret.values") + .doc("A comma-separated list of secret values.") + .stringConf + .toSequence + .createOptional + + private[spark] val SECRET_ENVKEYS = + ConfigBuilder(s"spark.mesos.$taskType.secret.envkeys") + .doc("A comma-separated list of the environment variables to contain the secrets." + + "The environment variable will be set on the driver.") + .stringConf + .toSequence + .createOptional + + private[spark] val SECRET_FILENAMES = + ConfigBuilder(s"spark.mesos.$taskType.secret.filenames") + .doc("A comma-separated list of file paths secret will be written to. Consult the Mesos " + + "Secret protobuf for more information.") + .stringConf + .toSequence + .createOptional + } + /* Common app configuration. */ private[spark] val SHUFFLE_CLEANER_INTERVAL_S = @@ -64,36 +97,9 @@ package object config { .stringConf .createOptional - private[spark] val SECRET_NAME = - ConfigBuilder("spark.mesos.driver.secret.names") - .doc("A comma-separated list of secret reference names. Consult the Mesos Secret protobuf " + - "for more information.") - .stringConf - .toSequence - .createOptional - - private[spark] val SECRET_VALUE = - ConfigBuilder("spark.mesos.driver.secret.values") - .doc("A comma-separated list of secret values.") - .stringConf - .toSequence - .createOptional + private[spark] val driverSecretConfig = new MesosSecretConfig("driver") - private[spark] val SECRET_ENVKEY = - ConfigBuilder("spark.mesos.driver.secret.envkeys") - .doc("A comma-separated list of the environment variables to contain the secrets." + - "The environment variable will be set on the driver.") - .stringConf - .toSequence - .createOptional - - private[spark] val SECRET_FILENAME = - ConfigBuilder("spark.mesos.driver.secret.filenames") - .doc("A comma-seperated list of file paths secret will be written to. Consult the Mesos " + - "Secret protobuf for more information.") - .stringConf - .toSequence - .createOptional + private[spark] val executorSecretConfig = new MesosSecretConfig("executor") private[spark] val DRIVER_FAILOVER_TIMEOUT = ConfigBuilder("spark.mesos.driver.failoverTimeout") diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index ec533f91474f2..82470264f2a4a 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -28,7 +28,6 @@ import org.apache.mesos.{Scheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason -import org.apache.mesos.protobuf.ByteString import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} import org.apache.spark.deploy.mesos.MesosDriverDescription @@ -394,39 +393,20 @@ private[spark] class MesosClusterScheduler( } // add secret environment variables - getSecretEnvVar(desc).foreach { variable => - if (variable.getSecret.getReference.isInitialized) { - logInfo(s"Setting reference secret ${variable.getSecret.getReference.getName}" + - s"on file ${variable.getName}") - } else { - logInfo(s"Setting secret on environment variable name=${variable.getName}") - } - envBuilder.addVariables(variable) + MesosSchedulerBackendUtil.getSecretEnvVar(desc.conf, config.driverSecretConfig) + .foreach { variable => + if (variable.getSecret.getReference.isInitialized) { + logInfo(s"Setting reference secret ${variable.getSecret.getReference.getName} " + + s"on file ${variable.getName}") + } else { + logInfo(s"Setting secret on environment variable name=${variable.getName}") + } + envBuilder.addVariables(variable) } envBuilder.build() } - private def getSecretEnvVar(desc: MesosDriverDescription): List[Variable] = { - val secrets = getSecrets(desc) - val secretEnvKeys = desc.conf.get(config.SECRET_ENVKEY).getOrElse(Nil) - if (illegalSecretInput(secretEnvKeys, secrets)) { - throw new SparkException( - s"Need to give equal numbers of secrets and environment keys " + - s"for environment-based reference secrets got secrets $secrets, " + - s"and keys $secretEnvKeys") - } - - secrets.zip(secretEnvKeys).map { - case (s, k) => - Variable.newBuilder() - .setName(k) - .setType(Variable.Type.SECRET) - .setSecret(s) - .build - }.toList - } - private def getDriverUris(desc: MesosDriverDescription): List[CommandInfo.URI] = { val confUris = List(conf.getOption("spark.mesos.uris"), desc.conf.getOption("spark.mesos.uris"), @@ -440,6 +420,23 @@ private[spark] class MesosClusterScheduler( CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) } + private def getContainerInfo(desc: MesosDriverDescription): ContainerInfo.Builder = { + val containerInfo = MesosSchedulerBackendUtil.buildContainerInfo(desc.conf) + + MesosSchedulerBackendUtil.getSecretVolume(desc.conf, config.driverSecretConfig) + .foreach { volume => + if (volume.getSource.getSecret.getReference.isInitialized) { + logInfo(s"Setting reference secret ${volume.getSource.getSecret.getReference.getName} " + + s"on file ${volume.getContainerPath}") + } else { + logInfo(s"Setting secret on file name=${volume.getContainerPath}") + } + containerInfo.addVolumes(volume) + } + + containerInfo + } + private def getDriverCommandValue(desc: MesosDriverDescription): String = { val dockerDefined = desc.conf.contains("spark.mesos.executor.docker.image") val executorUri = getDriverExecutorURI(desc) @@ -579,89 +576,6 @@ private[spark] class MesosClusterScheduler( .build } - private def getContainerInfo(desc: MesosDriverDescription): ContainerInfo.Builder = { - val containerInfo = MesosSchedulerBackendUtil.containerInfo(desc.conf) - - getSecretVolume(desc).foreach { volume => - if (volume.getSource.getSecret.getReference.isInitialized) { - logInfo(s"Setting reference secret ${volume.getSource.getSecret.getReference.getName}" + - s"on file ${volume.getContainerPath}") - } else { - logInfo(s"Setting secret on file name=${volume.getContainerPath}") - } - containerInfo.addVolumes(volume) - } - - containerInfo - } - - - private def getSecrets(desc: MesosDriverDescription): Seq[Secret] = { - def createValueSecret(data: String): Secret = { - Secret.newBuilder() - .setType(Secret.Type.VALUE) - .setValue(Secret.Value.newBuilder().setData(ByteString.copyFrom(data.getBytes))) - .build() - } - - def createReferenceSecret(name: String): Secret = { - Secret.newBuilder() - .setReference(Secret.Reference.newBuilder().setName(name)) - .setType(Secret.Type.REFERENCE) - .build() - } - - val referenceSecrets: Seq[Secret] = - desc.conf.get(config.SECRET_NAME).getOrElse(Nil).map(s => createReferenceSecret(s)) - - val valueSecrets: Seq[Secret] = { - desc.conf.get(config.SECRET_VALUE).getOrElse(Nil).map(s => createValueSecret(s)) - } - - if (valueSecrets.nonEmpty && referenceSecrets.nonEmpty) { - throw new SparkException("Cannot specify VALUE type secrets and REFERENCE types ones") - } - - if (referenceSecrets.nonEmpty) referenceSecrets else valueSecrets - } - - private def illegalSecretInput(dest: Seq[String], s: Seq[Secret]): Boolean = { - if (dest.isEmpty) { // no destination set (ie not using secrets of this type - return false - } - if (dest.nonEmpty && s.nonEmpty) { - // make sure there is a destination for each secret of this type - if (dest.length != s.length) { - return true - } - } - false - } - - private def getSecretVolume(desc: MesosDriverDescription): List[Volume] = { - val secrets = getSecrets(desc) - val secretPaths: Seq[String] = - desc.conf.get(config.SECRET_FILENAME).getOrElse(Nil) - - if (illegalSecretInput(secretPaths, secrets)) { - throw new SparkException( - s"Need to give equal numbers of secrets and file paths for file-based " + - s"reference secrets got secrets $secrets, and paths $secretPaths") - } - - secrets.zip(secretPaths).map { - case (s, p) => - val source = Volume.Source.newBuilder() - .setType(Volume.Source.Type.SECRET) - .setSecret(s) - Volume.newBuilder() - .setContainerPath(p) - .setSource(source) - .setMode(Volume.Mode.RO) - .build - }.toList - } - /** * This method takes all the possible candidates and attempt to schedule them with Mesos offers. * Every time a new task is scheduled, the afterLaunchCallback is called to perform post scheduled diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 603c980cb268d..104ed01d293ce 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -28,7 +28,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.concurrent.Future -import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState} +import org.apache.spark.{SecurityManager, SparkConf, SparkContext, SparkException, TaskState} import org.apache.spark.deploy.mesos.config._ import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.config @@ -244,6 +244,17 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( .setValue(value) .build()) } + + MesosSchedulerBackendUtil.getSecretEnvVar(conf, executorSecretConfig).foreach { variable => + if (variable.getSecret.getReference.isInitialized) { + logInfo(s"Setting reference secret ${variable.getSecret.getReference.getName} " + + s"on file ${variable.getName}") + } else { + logInfo(s"Setting secret on environment variable name=${variable.getName}") + } + environment.addVariables(variable) + } + val command = CommandInfo.newBuilder() .setEnvironment(environment) @@ -424,6 +435,22 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } } + private def getContainerInfo(conf: SparkConf): ContainerInfo.Builder = { + val containerInfo = MesosSchedulerBackendUtil.buildContainerInfo(conf) + + MesosSchedulerBackendUtil.getSecretVolume(conf, executorSecretConfig).foreach { volume => + if (volume.getSource.getSecret.getReference.isInitialized) { + logInfo(s"Setting reference secret ${volume.getSource.getSecret.getReference.getName} " + + s"on file ${volume.getContainerPath}") + } else { + logInfo(s"Setting secret on file name=${volume.getContainerPath}") + } + containerInfo.addVolumes(volume) + } + + containerInfo + } + /** * Returns a map from OfferIDs to the tasks to launch on those offers. In order to maximize * per-task memory and IO, tasks are round-robin assigned to offers. @@ -475,7 +502,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( .setName(s"${sc.appName} $taskId") .setLabels(MesosProtoUtils.mesosLabels(taskLabels)) .addAllResources(resourcesToUse.asJava) - .setContainer(MesosSchedulerBackendUtil.containerInfo(sc.conf)) + .setContainer(getContainerInfo(sc.conf)) tasks(offer.getId) ::= taskBuilder.build() remainingResources(offerId) = resourcesLeft.asJava diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 66b8e0a640121..d6d939d246109 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -28,6 +28,7 @@ import org.apache.mesos.SchedulerDriver import org.apache.mesos.protobuf.ByteString import org.apache.spark.{SparkContext, SparkException, TaskState} +import org.apache.spark.deploy.mesos.config import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -159,7 +160,8 @@ private[spark] class MesosFineGrainedSchedulerBackend( .setCommand(command) .setData(ByteString.copyFrom(createExecArg())) - executorInfo.setContainer(MesosSchedulerBackendUtil.containerInfo(sc.conf)) + executorInfo.setContainer( + MesosSchedulerBackendUtil.buildContainerInfo(sc.conf)) (executorInfo.build(), resourcesAfterMem.asJava) } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index f29e541addf23..bfb73611f0530 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -17,11 +17,15 @@ package org.apache.spark.scheduler.cluster.mesos -import org.apache.mesos.Protos.{ContainerInfo, Image, NetworkInfo, Parameter, Volume} +import org.apache.mesos.Protos.{ContainerInfo, Environment, Image, NetworkInfo, Parameter, Secret, Volume} import org.apache.mesos.Protos.ContainerInfo.{DockerInfo, MesosInfo} +import org.apache.mesos.Protos.Environment.Variable +import org.apache.mesos.protobuf.ByteString -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.SparkConf +import org.apache.spark.SparkException import org.apache.spark.deploy.mesos.config.{NETWORK_LABELS, NETWORK_NAME} +import org.apache.spark.deploy.mesos.config.MesosSecretConfig import org.apache.spark.internal.Logging /** @@ -122,7 +126,7 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { .toList } - def containerInfo(conf: SparkConf): ContainerInfo.Builder = { + def buildContainerInfo(conf: SparkConf): ContainerInfo.Builder = { val containerType = if (conf.contains("spark.mesos.executor.docker.image") && conf.get("spark.mesos.containerizer", "docker") == "docker") { ContainerInfo.Type.DOCKER @@ -173,6 +177,88 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { containerInfo } + private def getSecrets(conf: SparkConf, secretConfig: MesosSecretConfig): Seq[Secret] = { + def createValueSecret(data: String): Secret = { + Secret.newBuilder() + .setType(Secret.Type.VALUE) + .setValue(Secret.Value.newBuilder().setData(ByteString.copyFrom(data.getBytes))) + .build() + } + + def createReferenceSecret(name: String): Secret = { + Secret.newBuilder() + .setReference(Secret.Reference.newBuilder().setName(name)) + .setType(Secret.Type.REFERENCE) + .build() + } + + val referenceSecrets: Seq[Secret] = + conf.get(secretConfig.SECRET_NAMES).getOrElse(Nil).map { s => createReferenceSecret(s) } + + val valueSecrets: Seq[Secret] = { + conf.get(secretConfig.SECRET_VALUES).getOrElse(Nil).map { s => createValueSecret(s) } + } + + if (valueSecrets.nonEmpty && referenceSecrets.nonEmpty) { + throw new SparkException("Cannot specify both value-type and reference-type secrets.") + } + + if (referenceSecrets.nonEmpty) referenceSecrets else valueSecrets + } + + private def illegalSecretInput(dest: Seq[String], secrets: Seq[Secret]): Boolean = { + if (dest.nonEmpty) { + // make sure there is a one-to-one correspondence between destinations and secrets + if (dest.length != secrets.length) { + return true + } + } + false + } + + def getSecretVolume(conf: SparkConf, secretConfig: MesosSecretConfig): List[Volume] = { + val secrets = getSecrets(conf, secretConfig) + val secretPaths: Seq[String] = + conf.get(secretConfig.SECRET_FILENAMES).getOrElse(Nil) + + if (illegalSecretInput(secretPaths, secrets)) { + throw new SparkException( + s"Need to give equal numbers of secrets and file paths for file-based " + + s"reference secrets got secrets $secrets, and paths $secretPaths") + } + + secrets.zip(secretPaths).map { case (s, p) => + val source = Volume.Source.newBuilder() + .setType(Volume.Source.Type.SECRET) + .setSecret(s) + Volume.newBuilder() + .setContainerPath(p) + .setSource(source) + .setMode(Volume.Mode.RO) + .build + }.toList + } + + def getSecretEnvVar(conf: SparkConf, secretConfig: MesosSecretConfig): + List[Variable] = { + val secrets = getSecrets(conf, secretConfig) + val secretEnvKeys = conf.get(secretConfig.SECRET_ENVKEYS).getOrElse(Nil) + if (illegalSecretInput(secretEnvKeys, secrets)) { + throw new SparkException( + s"Need to give equal numbers of secrets and environment keys " + + s"for environment-based reference secrets got secrets $secrets, " + + s"and keys $secretEnvKeys") + } + + secrets.zip(secretEnvKeys).map { case (s, k) => + Variable.newBuilder() + .setName(k) + .setType(Variable.Type.SECRET) + .setSecret(s) + .build + }.toList + } + private def dockerInfo( image: String, forcePullImage: Boolean, diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index ff63e3f4ccfc3..77acee608f25f 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConverters._ import org.apache.mesos.Protos.{Environment, Secret, TaskState => MesosTaskState, _} import org.apache.mesos.Protos.Value.{Scalar, Type} import org.apache.mesos.SchedulerDriver -import org.apache.mesos.protobuf.ByteString import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Mockito._ import org.scalatest.mockito.MockitoSugar @@ -32,6 +31,7 @@ import org.scalatest.mockito.MockitoSugar import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.deploy.Command import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.deploy.mesos.config class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { @@ -341,132 +341,33 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi } test("Creates an env-based reference secrets.") { - setScheduler() - - val mem = 1000 - val cpu = 1 - val secretName = "/path/to/secret,/anothersecret" - val envKey = "SECRET_ENV_KEY,PASSWORD" - val driverDesc = new MesosDriverDescription( - "d1", - "jar", - mem, - cpu, - true, - command, - Map("spark.mesos.executor.home" -> "test", - "spark.app.name" -> "test", - "spark.mesos.driver.secret.names" -> secretName, - "spark.mesos.driver.secret.envkeys" -> envKey), - "s1", - new Date()) - val response = scheduler.submitDriver(driverDesc) - assert(response.success) - val offer = Utils.createOffer("o1", "s1", mem, cpu) - scheduler.resourceOffers(driver, Collections.singletonList(offer)) - val launchedTasks = Utils.verifyTaskLaunched(driver, "o1") - assert(launchedTasks.head - .getCommand - .getEnvironment - .getVariablesCount == 3) // SPARK_SUBMIT_OPS and the secret - val variableOne = launchedTasks.head.getCommand.getEnvironment - .getVariablesList.asScala.filter(_.getName == "SECRET_ENV_KEY").head - assert(variableOne.getSecret.isInitialized) - assert(variableOne.getSecret.getType == Secret.Type.REFERENCE) - assert(variableOne.getSecret.getReference.getName == "/path/to/secret") - assert(variableOne.getType == Environment.Variable.Type.SECRET) - val variableTwo = launchedTasks.head.getCommand.getEnvironment - .getVariablesList.asScala.filter(_.getName == "PASSWORD").head - assert(variableTwo.getSecret.isInitialized) - assert(variableTwo.getSecret.getType == Secret.Type.REFERENCE) - assert(variableTwo.getSecret.getReference.getName == "/anothersecret") - assert(variableTwo.getType == Environment.Variable.Type.SECRET) + val launchedTasks = launchDriverTask( + Utils.configEnvBasedRefSecrets(config.driverSecretConfig)) + Utils.verifyEnvBasedRefSecrets(launchedTasks) } test("Creates an env-based value secrets.") { - setScheduler() - val mem = 1000 - val cpu = 1 - val secretValues = "user,password" - val envKeys = "USER,PASSWORD" - val driverDesc = new MesosDriverDescription( - "d1", - "jar", - mem, - cpu, - true, - command, - Map("spark.mesos.executor.home" -> "test", - "spark.app.name" -> "test", - "spark.mesos.driver.secret.values" -> secretValues, - "spark.mesos.driver.secret.envkeys" -> envKeys), - "s1", - new Date()) - val response = scheduler.submitDriver(driverDesc) - assert(response.success) - val offer = Utils.createOffer("o1", "s1", mem, cpu) - scheduler.resourceOffers(driver, Collections.singletonList(offer)) - val launchedTasks = Utils.verifyTaskLaunched(driver, "o1") - assert(launchedTasks.head - .getCommand - .getEnvironment - .getVariablesCount == 3) // SPARK_SUBMIT_OPS and the secret - val variableOne = launchedTasks.head.getCommand.getEnvironment - .getVariablesList.asScala.filter(_.getName == "USER").head - assert(variableOne.getSecret.isInitialized) - assert(variableOne.getSecret.getType == Secret.Type.VALUE) - assert(variableOne.getSecret.getValue.getData == ByteString.copyFrom("user".getBytes)) - assert(variableOne.getType == Environment.Variable.Type.SECRET) - val variableTwo = launchedTasks.head.getCommand.getEnvironment - .getVariablesList.asScala.filter(_.getName == "PASSWORD").head - assert(variableTwo.getSecret.isInitialized) - assert(variableTwo.getSecret.getType == Secret.Type.VALUE) - assert(variableTwo.getSecret.getValue.getData == ByteString.copyFrom("password".getBytes)) - assert(variableTwo.getType == Environment.Variable.Type.SECRET) + val launchedTasks = launchDriverTask( + Utils.configEnvBasedValueSecrets(config.driverSecretConfig)) + Utils.verifyEnvBasedValueSecrets(launchedTasks) } test("Creates file-based reference secrets.") { - setScheduler() - val mem = 1000 - val cpu = 1 - val secretName = "/path/to/secret,/anothersecret" - val secretPath = "/topsecret,/mypassword" - val driverDesc = new MesosDriverDescription( - "d1", - "jar", - mem, - cpu, - true, - command, - Map("spark.mesos.executor.home" -> "test", - "spark.app.name" -> "test", - "spark.mesos.driver.secret.names" -> secretName, - "spark.mesos.driver.secret.filenames" -> secretPath), - "s1", - new Date()) - val response = scheduler.submitDriver(driverDesc) - assert(response.success) - val offer = Utils.createOffer("o1", "s1", mem, cpu) - scheduler.resourceOffers(driver, Collections.singletonList(offer)) - val launchedTasks = Utils.verifyTaskLaunched(driver, "o1") - val volumes = launchedTasks.head.getContainer.getVolumesList - assert(volumes.size() == 2) - val secretVolOne = volumes.get(0) - assert(secretVolOne.getContainerPath == "/topsecret") - assert(secretVolOne.getSource.getSecret.getType == Secret.Type.REFERENCE) - assert(secretVolOne.getSource.getSecret.getReference.getName == "/path/to/secret") - val secretVolTwo = volumes.get(1) - assert(secretVolTwo.getContainerPath == "/mypassword") - assert(secretVolTwo.getSource.getSecret.getType == Secret.Type.REFERENCE) - assert(secretVolTwo.getSource.getSecret.getReference.getName == "/anothersecret") + val launchedTasks = launchDriverTask( + Utils.configFileBasedRefSecrets(config.driverSecretConfig)) + Utils.verifyFileBasedRefSecrets(launchedTasks) } test("Creates a file-based value secrets.") { + val launchedTasks = launchDriverTask( + Utils.configFileBasedValueSecrets(config.driverSecretConfig)) + Utils.verifyFileBasedValueSecrets(launchedTasks) + } + + private def launchDriverTask(addlSparkConfVars: Map[String, String]): List[TaskInfo] = { setScheduler() val mem = 1000 val cpu = 1 - val secretValues = "user,password" - val secretPath = "/whoami,/mypassword" val driverDesc = new MesosDriverDescription( "d1", "jar", @@ -475,27 +376,14 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi true, command, Map("spark.mesos.executor.home" -> "test", - "spark.app.name" -> "test", - "spark.mesos.driver.secret.values" -> secretValues, - "spark.mesos.driver.secret.filenames" -> secretPath), + "spark.app.name" -> "test") ++ + addlSparkConfVars, "s1", new Date()) val response = scheduler.submitDriver(driverDesc) assert(response.success) val offer = Utils.createOffer("o1", "s1", mem, cpu) scheduler.resourceOffers(driver, Collections.singletonList(offer)) - val launchedTasks = Utils.verifyTaskLaunched(driver, "o1") - val volumes = launchedTasks.head.getContainer.getVolumesList - assert(volumes.size() == 2) - val secretVolOne = volumes.get(0) - assert(secretVolOne.getContainerPath == "/whoami") - assert(secretVolOne.getSource.getSecret.getType == Secret.Type.VALUE) - assert(secretVolOne.getSource.getSecret.getValue.getData == - ByteString.copyFrom("user".getBytes)) - val secretVolTwo = volumes.get(1) - assert(secretVolTwo.getContainerPath == "/mypassword") - assert(secretVolTwo.getSource.getSecret.getType == Secret.Type.VALUE) - assert(secretVolTwo.getSource.getSecret.getValue.getData == - ByteString.copyFrom("password".getBytes)) + Utils.verifyTaskLaunched(driver, "o1") } } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index 6c40792112f49..f4bd1ee9da6f7 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -21,7 +21,6 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.concurrent.duration._ -import scala.reflect.ClassTag import org.apache.mesos.{Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos._ @@ -38,7 +37,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor} import org.apache.spark.scheduler.cluster.mesos.Utils._ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite @@ -653,6 +652,37 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite offerResourcesAndVerify(2, true) } + test("Creates an env-based reference secrets.") { + val launchedTasks = launchExecutorTasks(configEnvBasedRefSecrets(executorSecretConfig)) + verifyEnvBasedRefSecrets(launchedTasks) + } + + test("Creates an env-based value secrets.") { + val launchedTasks = launchExecutorTasks(configEnvBasedValueSecrets(executorSecretConfig)) + verifyEnvBasedValueSecrets(launchedTasks) + } + + test("Creates file-based reference secrets.") { + val launchedTasks = launchExecutorTasks(configFileBasedRefSecrets(executorSecretConfig)) + verifyFileBasedRefSecrets(launchedTasks) + } + + test("Creates a file-based value secrets.") { + val launchedTasks = launchExecutorTasks(configFileBasedValueSecrets(executorSecretConfig)) + verifyFileBasedValueSecrets(launchedTasks) + } + + private def launchExecutorTasks(sparkConfVars: Map[String, String]): List[TaskInfo] = { + setBackend(sparkConfVars) + + val (mem, cpu) = (backend.executorMemory(sc), 4) + + val offer1 = createOffer("o1", "s1", mem, cpu) + backend.resourceOffers(driver, List(offer1).asJava) + + verifyTaskLaunched(driver, "o1") + } + private case class Resources(mem: Int, cpus: Int, gpus: Int = 0) private def registerMockExecutor(executorId: String, slaveId: String, cores: Integer) = { diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala index f49d7c29eda49..442c43960ec1f 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler.cluster.mesos import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.mesos.config class MesosSchedulerBackendUtilSuite extends SparkFunSuite { @@ -26,7 +27,8 @@ class MesosSchedulerBackendUtilSuite extends SparkFunSuite { conf.set("spark.mesos.executor.docker.parameters", "a,b") conf.set("spark.mesos.executor.docker.image", "test") - val containerInfo = MesosSchedulerBackendUtil.containerInfo(conf) + val containerInfo = MesosSchedulerBackendUtil.buildContainerInfo( + conf) val params = containerInfo.getDocker.getParametersList assert(params.size() == 0) @@ -37,7 +39,8 @@ class MesosSchedulerBackendUtilSuite extends SparkFunSuite { conf.set("spark.mesos.executor.docker.parameters", "a=1,b=2,c=3") conf.set("spark.mesos.executor.docker.image", "test") - val containerInfo = MesosSchedulerBackendUtil.containerInfo(conf) + val containerInfo = MesosSchedulerBackendUtil.buildContainerInfo( + conf) val params = containerInfo.getDocker.getParametersList assert(params.size() == 3) assert(params.get(0).getKey == "a") diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala index 833db0c1ff334..5636ac52bd4a7 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala @@ -24,9 +24,12 @@ import scala.collection.JavaConverters._ import org.apache.mesos.Protos._ import org.apache.mesos.Protos.Value.{Range => MesosRange, Ranges, Scalar} import org.apache.mesos.SchedulerDriver +import org.apache.mesos.protobuf.ByteString import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Mockito._ +import org.apache.spark.deploy.mesos.config.MesosSecretConfig + object Utils { val TEST_FRAMEWORK_ID = FrameworkID.newBuilder() @@ -105,4 +108,108 @@ object Utils { def createTaskId(taskId: String): TaskID = { TaskID.newBuilder().setValue(taskId).build() } + + def configEnvBasedRefSecrets(secretConfig: MesosSecretConfig): Map[String, String] = { + val secretName = "/path/to/secret,/anothersecret" + val envKey = "SECRET_ENV_KEY,PASSWORD" + Map( + secretConfig.SECRET_NAMES.key -> secretName, + secretConfig.SECRET_ENVKEYS.key -> envKey + ) + } + + def verifyEnvBasedRefSecrets(launchedTasks: List[TaskInfo]): Unit = { + val envVars = launchedTasks.head + .getCommand + .getEnvironment + .getVariablesList + .asScala + assert(envVars + .count(!_.getName.startsWith("SPARK_")) == 2) // user-defined secret env vars + val variableOne = envVars.filter(_.getName == "SECRET_ENV_KEY").head + assert(variableOne.getSecret.isInitialized) + assert(variableOne.getSecret.getType == Secret.Type.REFERENCE) + assert(variableOne.getSecret.getReference.getName == "/path/to/secret") + assert(variableOne.getType == Environment.Variable.Type.SECRET) + val variableTwo = envVars.filter(_.getName == "PASSWORD").head + assert(variableTwo.getSecret.isInitialized) + assert(variableTwo.getSecret.getType == Secret.Type.REFERENCE) + assert(variableTwo.getSecret.getReference.getName == "/anothersecret") + assert(variableTwo.getType == Environment.Variable.Type.SECRET) + } + + def configEnvBasedValueSecrets(secretConfig: MesosSecretConfig): Map[String, String] = { + val secretValues = "user,password" + val envKeys = "USER,PASSWORD" + Map( + secretConfig.SECRET_VALUES.key -> secretValues, + secretConfig.SECRET_ENVKEYS.key -> envKeys + ) + } + + def verifyEnvBasedValueSecrets(launchedTasks: List[TaskInfo]): Unit = { + val envVars = launchedTasks.head + .getCommand + .getEnvironment + .getVariablesList + .asScala + assert(envVars + .count(!_.getName.startsWith("SPARK_")) == 2) // user-defined secret env vars + val variableOne = envVars.filter(_.getName == "USER").head + assert(variableOne.getSecret.isInitialized) + assert(variableOne.getSecret.getType == Secret.Type.VALUE) + assert(variableOne.getSecret.getValue.getData == ByteString.copyFrom("user".getBytes)) + assert(variableOne.getType == Environment.Variable.Type.SECRET) + val variableTwo = envVars.filter(_.getName == "PASSWORD").head + assert(variableTwo.getSecret.isInitialized) + assert(variableTwo.getSecret.getType == Secret.Type.VALUE) + assert(variableTwo.getSecret.getValue.getData == ByteString.copyFrom("password".getBytes)) + assert(variableTwo.getType == Environment.Variable.Type.SECRET) + } + + def configFileBasedRefSecrets(secretConfig: MesosSecretConfig): Map[String, String] = { + val secretName = "/path/to/secret,/anothersecret" + val secretPath = "/topsecret,/mypassword" + Map( + secretConfig.SECRET_NAMES.key -> secretName, + secretConfig.SECRET_FILENAMES.key -> secretPath + ) + } + + def verifyFileBasedRefSecrets(launchedTasks: List[TaskInfo]): Unit = { + val volumes = launchedTasks.head.getContainer.getVolumesList + assert(volumes.size() == 2) + val secretVolOne = volumes.get(0) + assert(secretVolOne.getContainerPath == "/topsecret") + assert(secretVolOne.getSource.getSecret.getType == Secret.Type.REFERENCE) + assert(secretVolOne.getSource.getSecret.getReference.getName == "/path/to/secret") + val secretVolTwo = volumes.get(1) + assert(secretVolTwo.getContainerPath == "/mypassword") + assert(secretVolTwo.getSource.getSecret.getType == Secret.Type.REFERENCE) + assert(secretVolTwo.getSource.getSecret.getReference.getName == "/anothersecret") + } + + def configFileBasedValueSecrets(secretConfig: MesosSecretConfig): Map[String, String] = { + val secretValues = "user,password" + val secretPath = "/whoami,/mypassword" + Map( + secretConfig.SECRET_VALUES.key -> secretValues, + secretConfig.SECRET_FILENAMES.key -> secretPath + ) + } + + def verifyFileBasedValueSecrets(launchedTasks: List[TaskInfo]): Unit = { + val volumes = launchedTasks.head.getContainer.getVolumesList + assert(volumes.size() == 2) + val secretVolOne = volumes.get(0) + assert(secretVolOne.getContainerPath == "/whoami") + assert(secretVolOne.getSource.getSecret.getType == Secret.Type.VALUE) + assert(secretVolOne.getSource.getSecret.getValue.getData == + ByteString.copyFrom("user".getBytes)) + val secretVolTwo = volumes.get(1) + assert(secretVolTwo.getContainerPath == "/mypassword") + assert(secretVolTwo.getSource.getSecret.getType == Secret.Type.VALUE) + assert(secretVolTwo.getSource.getSecret.getValue.getData == + ByteString.copyFrom("password".getBytes)) + } } From 8e9863531bebbd4d83eafcbc2b359b8bd0ac5734 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 26 Oct 2017 16:55:30 -0700 Subject: [PATCH 772/779] [SPARK-22366] Support ignoring missing files ## What changes were proposed in this pull request? Add a flag "spark.sql.files.ignoreMissingFiles" to parallel the existing flag "spark.sql.files.ignoreCorruptFiles". ## How was this patch tested? new unit test Author: Jose Torres Closes #19581 from joseph-torres/SPARK-22366. --- .../apache/spark/sql/internal/SQLConf.scala | 8 +++++ .../execution/datasources/FileScanRDD.scala | 13 +++++--- .../parquet/ParquetQuerySuite.scala | 33 +++++++++++++++++++ 3 files changed, 50 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4cfe53b2c115b..21e4685fcc456 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -614,6 +614,12 @@ object SQLConf { .booleanConf .createWithDefault(false) + val IGNORE_MISSING_FILES = buildConf("spark.sql.files.ignoreMissingFiles") + .doc("Whether to ignore missing files. If true, the Spark jobs will continue to run when " + + "encountering missing files and the contents that have been read will still be returned.") + .booleanConf + .createWithDefault(false) + val MAX_RECORDS_PER_FILE = buildConf("spark.sql.files.maxRecordsPerFile") .doc("Maximum number of records to write out to a single file. " + "If this value is zero or negative, there is no limit.") @@ -1014,6 +1020,8 @@ class SQLConf extends Serializable with Logging { def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES) + def ignoreMissingFiles: Boolean = getConf(IGNORE_MISSING_FILES) + def maxRecordsPerFile: Long = getConf(MAX_RECORDS_PER_FILE) def useCompression: Boolean = getConf(COMPRESS_CACHED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 9df20731c71d5..8731ee88f87f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -66,6 +66,7 @@ class FileScanRDD( extends RDD[InternalRow](sparkSession.sparkContext, Nil) { private val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles + private val ignoreMissingFiles = sparkSession.sessionState.conf.ignoreMissingFiles override def compute(split: RDDPartition, context: TaskContext): Iterator[InternalRow] = { val iterator = new Iterator[Object] with AutoCloseable { @@ -142,7 +143,7 @@ class FileScanRDD( // Sets InputFileBlockHolder for the file block's information InputFileBlockHolder.set(currentFile.filePath, currentFile.start, currentFile.length) - if (ignoreCorruptFiles) { + if (ignoreMissingFiles || ignoreCorruptFiles) { currentIterator = new NextIterator[Object] { // The readFunction may read some bytes before consuming the iterator, e.g., // vectorized Parquet reader. Here we use lazy val to delay the creation of @@ -158,9 +159,13 @@ class FileScanRDD( null } } catch { - // Throw FileNotFoundException even `ignoreCorruptFiles` is true - case e: FileNotFoundException => throw e - case e @ (_: RuntimeException | _: IOException) => + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: $currentFile", e) + finished = true + null + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e + case e @ (_: RuntimeException | _: IOException) if ignoreCorruptFiles => logWarning( s"Skipped the rest of the content in the corrupted file: $currentFile", e) finished = true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 2efff3f57d7d3..e822e40b146ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -316,6 +316,39 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } + testQuietly("Enabling/disabling ignoreMissingFiles") { + def testIgnoreMissingFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.parquet(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.parquet(new Path(basePath, "second").toString) + val thirdPath = new Path(basePath, "third") + spark.range(2, 3).toDF("a").write.parquet(thirdPath.toString) + val df = spark.read.parquet( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + + val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) + fs.delete(thirdPath, true) + checkAnswer( + df, + Seq(Row(0), Row(1))) + } + } + + withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "true") { + testIgnoreMissingFiles() + } + + withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "false") { + val exception = intercept[SparkException] { + testIgnoreMissingFiles() + } + assert(exception.getMessage().contains("does not exist")) + } + } + /** * this is part of test 'Enabling/disabling ignoreCorruptFiles' but run in a loop * to increase the chance of failure From 9b262f6a08c0c1b474d920d49b9fdd574c401d39 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 26 Oct 2017 17:39:53 -0700 Subject: [PATCH 773/779] [SPARK-22356][SQL] data source table should support overlapped columns between data and partition schema ## What changes were proposed in this pull request? This is a regression introduced by #14207. After Spark 2.1, we store the inferred schema when creating the table, to avoid inferring schema again at read path. However, there is one special case: overlapped columns between data and partition. For this case, it breaks the assumption of table schema that there is on ovelap between data and partition schema, and partition columns should be at the end. The result is, for Spark 2.1, the table scan has incorrect schema that puts partition columns at the end. For Spark 2.2, we add a check in CatalogTable to validate table schema, which fails at this case. To fix this issue, a simple and safe approach is to fallback to old behavior when overlapeed columns detected, i.e. store empty schema in metastore. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #19579 from cloud-fan/bug2. --- .../command/createDataSourceTables.scala | 35 +++++++++++++---- .../datasources/HadoopFsRelation.scala | 25 ++++++++---- .../org/apache/spark/sql/SQLQuerySuite.scala | 16 ++++++++ .../HiveExternalCatalogVersionsSuite.scala | 38 ++++++++++++++----- 4 files changed, 89 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 9e3907996995c..306f43dc4214a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.types.StructType /** * A command used to create a data source table. @@ -85,14 +86,32 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo } } - val newTable = table.copy( - schema = dataSource.schema, - partitionColumnNames = partitionColumnNames, - // If metastore partition management for file source tables is enabled, we start off with - // partition provider hive, but no partitions in the metastore. The user has to call - // `msck repair table` to populate the table partitions. - tracksPartitionsInCatalog = partitionColumnNames.nonEmpty && - sessionState.conf.manageFilesourcePartitions) + val newTable = dataSource match { + // Since Spark 2.1, we store the inferred schema of data source in metastore, to avoid + // inferring the schema again at read path. However if the data source has overlapped columns + // between data and partition schema, we can't store it in metastore as it breaks the + // assumption of table schema. Here we fallback to the behavior of Spark prior to 2.1, store + // empty schema in metastore and infer it at runtime. Note that this also means the new + // scalable partitioning handling feature(introduced at Spark 2.1) is disabled in this case. + case r: HadoopFsRelation if r.overlappedPartCols.nonEmpty => + logWarning("It is not recommended to create a table with overlapped data and partition " + + "columns, as Spark cannot store a valid table schema and has to infer it at runtime, " + + "which hurts performance. Please check your data files and remove the partition " + + "columns in it.") + table.copy(schema = new StructType(), partitionColumnNames = Nil) + + case _ => + table.copy( + schema = dataSource.schema, + partitionColumnNames = partitionColumnNames, + // If metastore partition management for file source tables is enabled, we start off with + // partition provider hive, but no partitions in the metastore. The user has to call + // `msck repair table` to populate the table partitions. + tracksPartitionsInCatalog = partitionColumnNames.nonEmpty && + sessionState.conf.manageFilesourcePartitions) + + } + // We will return Nil or throw exception at the beginning if the table already exists, so when // we reach here, the table should not exist and we should set `ignoreIfExists` to false. sessionState.catalog.createTable(newTable, ignoreIfExists = false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index 9a08524476baa..89d8a85a9cbd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.util.Locale + import scala.collection.mutable import org.apache.spark.sql.{SparkSession, SQLContext} @@ -50,15 +52,22 @@ case class HadoopFsRelation( override def sqlContext: SQLContext = sparkSession.sqlContext - val schema: StructType = { - val getColName: (StructField => String) = - if (sparkSession.sessionState.conf.caseSensitiveAnalysis) _.name else _.name.toLowerCase - val overlappedPartCols = mutable.Map.empty[String, StructField] - partitionSchema.foreach { partitionField => - if (dataSchema.exists(getColName(_) == getColName(partitionField))) { - overlappedPartCols += getColName(partitionField) -> partitionField - } + private def getColName(f: StructField): String = { + if (sparkSession.sessionState.conf.caseSensitiveAnalysis) { + f.name + } else { + f.name.toLowerCase(Locale.ROOT) + } + } + + val overlappedPartCols = mutable.Map.empty[String, StructField] + partitionSchema.foreach { partitionField => + if (dataSchema.exists(getColName(_) == getColName(partitionField))) { + overlappedPartCols += getColName(partitionField) -> partitionField } + } + + val schema: StructType = { StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++ partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index caf332d050d7b..5d0bba69daca1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2741,4 +2741,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { assert (aggregateExpressions.isDefined) assert (aggregateExpressions.get.size == 2) } + + test("SPARK-22356: overlapped columns between data and partition schema in data source tables") { + withTempPath { path => + Seq((1, 1, 1), (1, 2, 1)).toDF("i", "p", "j") + .write.mode("overwrite").parquet(new File(path, "p=1").getCanonicalPath) + withTable("t") { + sql(s"create table t using parquet options(path='${path.getCanonicalPath}')") + // We should respect the column order in data schema. + assert(spark.table("t").columns === Array("i", "p", "j")) + checkAnswer(spark.table("t"), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) + // The DESC TABLE should report same schema as table scan. + assert(sql("desc t").select("col_name") + .as[String].collect().mkString(",").contains("i,p,j")) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 5f8c9d5799662..6859432c406a9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -40,7 +40,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private val tmpDataDir = Utils.createTempDir(namePrefix = "test-data") // For local test, you can set `sparkTestingDir` to a static value like `/tmp/test-spark`, to // avoid downloading Spark of different versions in each run. - private val sparkTestingDir = Utils.createTempDir(namePrefix = "test-spark") + private val sparkTestingDir = new File("/tmp/test-spark") private val unusedJar = TestUtils.createJarWithClasses(Seq.empty) override def afterAll(): Unit = { @@ -77,35 +77,38 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { super.beforeAll() val tempPyFile = File.createTempFile("test", ".py") + // scalastyle:off line.size.limit Files.write(tempPyFile.toPath, s""" |from pyspark.sql import SparkSession + |import os | |spark = SparkSession.builder.enableHiveSupport().getOrCreate() |version_index = spark.conf.get("spark.sql.test.version.index", None) | |spark.sql("create table data_source_tbl_{} using json as select 1 i".format(version_index)) | - |spark.sql("create table hive_compatible_data_source_tbl_" + version_index + \\ - | " using parquet as select 1 i") + |spark.sql("create table hive_compatible_data_source_tbl_{} using parquet as select 1 i".format(version_index)) | |json_file = "${genDataDir("json_")}" + str(version_index) |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file) - |spark.sql("create table external_data_source_tbl_" + version_index + \\ - | "(i int) using json options (path '{}')".format(json_file)) + |spark.sql("create table external_data_source_tbl_{}(i int) using json options (path '{}')".format(version_index, json_file)) | |parquet_file = "${genDataDir("parquet_")}" + str(version_index) |spark.range(1, 2).selectExpr("cast(id as int) as i").write.parquet(parquet_file) - |spark.sql("create table hive_compatible_external_data_source_tbl_" + version_index + \\ - | "(i int) using parquet options (path '{}')".format(parquet_file)) + |spark.sql("create table hive_compatible_external_data_source_tbl_{}(i int) using parquet options (path '{}')".format(version_index, parquet_file)) | |json_file2 = "${genDataDir("json2_")}" + str(version_index) |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file2) - |spark.sql("create table external_table_without_schema_" + version_index + \\ - | " using json options (path '{}')".format(json_file2)) + |spark.sql("create table external_table_without_schema_{} using json options (path '{}')".format(version_index, json_file2)) + | + |parquet_file2 = "${genDataDir("parquet2_")}" + str(version_index) + |spark.range(1, 3).selectExpr("1 as i", "cast(id as int) as p", "1 as j").write.parquet(os.path.join(parquet_file2, "p=1")) + |spark.sql("create table tbl_with_col_overlap_{} using parquet options(path '{}')".format(version_index, parquet_file2)) | |spark.sql("create view v_{} as select 1 i".format(version_index)) """.stripMargin.getBytes("utf8")) + // scalastyle:on line.size.limit PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) => val sparkHome = new File(sparkTestingDir, s"spark-$version") @@ -153,6 +156,7 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { .enableHiveSupport() .getOrCreate() spark = session + import session.implicits._ testingVersions.indices.foreach { index => Seq( @@ -194,6 +198,22 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { // test permanent view checkAnswer(sql(s"select i from v_$index"), Row(1)) + + // SPARK-22356: overlapped columns between data and partition schema in data source tables + val tbl_with_col_overlap = s"tbl_with_col_overlap_$index" + // For Spark 2.2.0 and 2.1.x, the behavior is different from Spark 2.0. + if (testingVersions(index).startsWith("2.1") || testingVersions(index) == "2.2.0") { + spark.sql("msck repair table " + tbl_with_col_overlap) + assert(spark.table(tbl_with_col_overlap).columns === Array("i", "j", "p")) + checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) + assert(sql("desc " + tbl_with_col_overlap).select("col_name") + .as[String].collect().mkString(",").contains("i,j,p")) + } else { + assert(spark.table(tbl_with_col_overlap).columns === Array("i", "p", "j")) + checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) + assert(sql("desc " + tbl_with_col_overlap).select("col_name") + .as[String].collect().mkString(",").contains("i,p,j")) + } } } } From 5c3a1f3fad695317c2fff1243cdb9b3ceb25c317 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 26 Oct 2017 17:51:16 -0700 Subject: [PATCH 774/779] [SPARK-22355][SQL] Dataset.collect is not threadsafe ## What changes were proposed in this pull request? It's possible that users create a `Dataset`, and call `collect` of this `Dataset` in many threads at the same time. Currently `Dataset#collect` just call `encoder.fromRow` to convert spark rows to objects of type T, and this encoder is per-dataset. This means `Dataset#collect` is not thread-safe, because the encoder uses a projection to output the object to a re-usable row. This PR fixes this problem, by creating a new projection when calling `Dataset#collect`, so that we have the re-usable row for each method call, instead of each Dataset. ## How was this patch tested? N/A Author: Wenchen Fan Closes #19577 from cloud-fan/encoder. --- .../scala/org/apache/spark/sql/Dataset.scala | 33 ++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b70dfc05330f8..0e23983786b08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} @@ -198,15 +199,10 @@ class Dataset[T] private[sql]( */ private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder) - /** - * Encoder is used mostly as a container of serde expressions in Dataset. We build logical - * plans by these serde expressions and execute it within the query framework. However, for - * performance reasons we may want to use encoder as a function to deserialize internal rows to - * custom objects, e.g. collect. Here we resolve and bind the encoder so that we can call its - * `fromRow` method later. - */ - private val boundEnc = - exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer) + // The deserializer expression which can be used to build a projection and turn rows to objects + // of type T, after collecting rows to the driver side. + private val deserializer = + exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer).deserializer private implicit def classTag = exprEnc.clsTag @@ -2661,7 +2657,15 @@ class Dataset[T] private[sql]( */ def toLocalIterator(): java.util.Iterator[T] = { withAction("toLocalIterator", queryExecution) { plan => - plan.executeToIterator().map(boundEnc.fromRow).asJava + // This projection writes output to a `InternalRow`, which means applying this projection is + // not thread-safe. Here we create the projection inside this method to make `Dataset` + // thread-safe. + val objProj = GenerateSafeProjection.generate(deserializer :: Nil) + plan.executeToIterator().map { row => + // The row returned by SafeProjection is `SpecificInternalRow`, which ignore the data type + // parameter of its `get` method, so it's safe to use null here. + objProj(row).get(0, null).asInstanceOf[T] + }.asJava } } @@ -3102,7 +3106,14 @@ class Dataset[T] private[sql]( * Collect all elements from a spark plan. */ private def collectFromPlan(plan: SparkPlan): Array[T] = { - plan.executeCollect().map(boundEnc.fromRow) + // This projection writes output to a `InternalRow`, which means applying this projection is not + // thread-safe. Here we create the projection inside this method to make `Dataset` thread-safe. + val objProj = GenerateSafeProjection.generate(deserializer :: Nil) + plan.executeCollect().map { row => + // The row returned by SafeProjection is `SpecificInternalRow`, which ignore the data type + // parameter of its `get` method, so it's safe to use null here. + objProj(row).get(0, null).asInstanceOf[T] + } } private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = { From 17af727e38c3faaeab5b91a8cdab5f2181cf3fc4 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 26 Oct 2017 23:02:46 -0700 Subject: [PATCH 775/779] [SPARK-21375][PYSPARK][SQL] Add Date and Timestamp support to ArrowConverters for toPandas() Conversion ## What changes were proposed in this pull request? Adding date and timestamp support with Arrow for `toPandas()` and `pandas_udf`s. Timestamps are stored in Arrow as UTC and manifested to the user as timezone-naive localized to the Python system timezone. ## How was this patch tested? Added Scala tests for date and timestamp types under ArrowConverters, ArrowUtils, and ArrowWriter suites. Added Python tests for `toPandas()` and `pandas_udf`s with date and timestamp types. Author: Bryan Cutler Author: Takuya UESHIN Closes #18664 from BryanCutler/arrow-date-timestamp-SPARK-21375. --- python/pyspark/serializers.py | 24 +++- python/pyspark/sql/dataframe.py | 7 +- python/pyspark/sql/tests.py | 106 ++++++++++++++-- python/pyspark/sql/types.py | 36 ++++++ .../vectorized/ArrowColumnVector.java | 34 +++++ .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../sql/execution/arrow/ArrowConverters.scala | 3 +- .../sql/execution/arrow/ArrowUtils.scala | 30 +++-- .../sql/execution/arrow/ArrowWriter.scala | 39 +++++- .../python/ArrowEvalPythonExec.scala | 2 +- .../execution/python/ArrowPythonRunner.scala | 5 +- .../python/FlatMapGroupsInPandasExec.scala | 4 +- .../arrow/ArrowConvertersSuite.scala | 120 ++++++++++++++++-- .../sql/execution/arrow/ArrowUtilsSuite.scala | 26 +++- .../execution/arrow/ArrowWriterSuite.scala | 24 ++-- .../vectorized/ArrowColumnVectorSuite.scala | 22 ++-- .../vectorized/ColumnarBatchSuite.scala | 4 +- 17 files changed, 417 insertions(+), 73 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index a0adeed994456..d7979f095da76 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -214,6 +214,7 @@ def __repr__(self): def _create_batch(series): + from pyspark.sql.types import _check_series_convert_timestamps_internal import pyarrow as pa # Make input conform to [(series1, type1), (series2, type2), ...] if not isinstance(series, (list, tuple)) or \ @@ -224,12 +225,25 @@ def _create_batch(series): # If a nullable integer series has been promoted to floating point with NaNs, need to cast # NOTE: this is not necessary with Arrow >= 0.7 def cast_series(s, t): - if t is None or s.dtype == t.to_pandas_dtype(): + if type(t) == pa.TimestampType: + # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680 + return _check_series_convert_timestamps_internal(s.fillna(0))\ + .values.astype('datetime64[us]', copy=False) + elif t == pa.date32(): + # TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8 + return s.dt.date + elif t is None or s.dtype == t.to_pandas_dtype(): return s else: return s.fillna(0).astype(t.to_pandas_dtype(), copy=False) - arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series] + # Some object types don't support masks in Arrow, see ARROW-1721 + def create_array(s, t): + casted = cast_series(s, t) + mask = None if casted.dtype == 'object' else s.isnull() + return pa.Array.from_pandas(casted, mask=mask, type=t) + + arrs = [create_array(s, t) for s, t in series] return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) @@ -260,11 +274,13 @@ def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. """ + from pyspark.sql.types import _check_dataframe_localize_timestamps import pyarrow as pa reader = pa.open_stream(stream) for batch in reader: - table = pa.Table.from_batches([batch]) - yield [c.to_pandas() for c in table.itercolumns()] + # NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1 + pdf = _check_dataframe_localize_timestamps(batch.to_pandas()) + yield [c for _, c in pdf.iteritems()] def __repr__(self): return "ArrowStreamPandasSerializer" diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index c0b574e2b93a1..406686e6df724 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1883,11 +1883,13 @@ def toPandas(self): import pandas as pd if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: + from pyspark.sql.types import _check_dataframe_localize_timestamps import pyarrow tables = self._collectAsArrow() if tables: table = pyarrow.concat_tables(tables) - return table.to_pandas() + pdf = table.to_pandas() + return _check_dataframe_localize_timestamps(pdf) else: return pd.DataFrame.from_records([], columns=self.columns) except ImportError as e: @@ -1955,6 +1957,7 @@ def _to_corrected_pandas_type(dt): """ When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong. This method gets the corrected data type for Pandas if that type may be inferred uncorrectly. + NOTE: DateType is inferred incorrectly as 'object', TimestampType is correct with datetime64[ns] """ import numpy as np if type(dt) == ByteType: @@ -1965,6 +1968,8 @@ def _to_corrected_pandas_type(dt): return np.int32 elif type(dt) == FloatType: return np.float32 + elif type(dt) == DateType: + return 'datetime64[ns]' else: return None diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 685eebcafefba..98afae662b42d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3086,18 +3086,38 @@ class ArrowTests(ReusedPySparkTestCase): @classmethod def setUpClass(cls): + from datetime import datetime ReusedPySparkTestCase.setUpClass() + + # Synchronize default timezone between Python and Java + cls.tz_prev = os.environ.get("TZ", None) # save current tz if set + tz = "America/Los_Angeles" + os.environ["TZ"] = tz + time.tzset() + cls.spark = SparkSession(cls.sc) + cls.spark.conf.set("spark.sql.session.timeZone", tz) cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true") cls.schema = StructType([ StructField("1_str_t", StringType(), True), StructField("2_int_t", IntegerType(), True), StructField("3_long_t", LongType(), True), StructField("4_float_t", FloatType(), True), - StructField("5_double_t", DoubleType(), True)]) - cls.data = [("a", 1, 10, 0.2, 2.0), - ("b", 2, 20, 0.4, 4.0), - ("c", 3, 30, 0.8, 6.0)] + StructField("5_double_t", DoubleType(), True), + StructField("6_date_t", DateType(), True), + StructField("7_timestamp_t", TimestampType(), True)]) + cls.data = [("a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), + ("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), + ("c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + + @classmethod + def tearDownClass(cls): + del os.environ["TZ"] + if cls.tz_prev is not None: + os.environ["TZ"] = cls.tz_prev + time.tzset() + ReusedPySparkTestCase.tearDownClass() + cls.spark.stop() def assertFramesEqual(self, df_with_arrow, df_without): msg = ("DataFrame from Arrow is not equal" + @@ -3106,8 +3126,8 @@ def assertFramesEqual(self, df_with_arrow, df_without): self.assertTrue(df_without.equals(df_with_arrow), msg=msg) def test_unsupported_datatype(self): - schema = StructType([StructField("dt", DateType(), True)]) - df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema) + schema = StructType([StructField("decimal", DecimalType(), True)]) + df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): self.assertRaises(Exception, lambda: df.toPandas()) @@ -3385,13 +3405,77 @@ def test_vectorized_udf_varargs(self): def test_vectorized_udf_unsupported_types(self): from pyspark.sql.functions import pandas_udf, col - schema = StructType([StructField("dt", DateType(), True)]) - df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema) - f = pandas_udf(lambda x: x, DateType()) + schema = StructType([StructField("dt", DecimalType(), True)]) + df = self.spark.createDataFrame([(None,)], schema=schema) + f = pandas_udf(lambda x: x, DecimalType()) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): df.select(f(col('dt'))).collect() + def test_vectorized_udf_null_date(self): + from pyspark.sql.functions import pandas_udf, col + from datetime import date + schema = StructType().add("date", DateType()) + data = [(date(1969, 1, 1),), + (date(2012, 2, 2),), + (None,), + (date(2100, 4, 4),)] + df = self.spark.createDataFrame(data, schema=schema) + date_f = pandas_udf(lambda t: t, returnType=DateType()) + res = df.select(date_f(col("date"))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_timestamps(self): + from pyspark.sql.functions import pandas_udf, col + from datetime import datetime + schema = StructType([ + StructField("idx", LongType(), True), + StructField("timestamp", TimestampType(), True)]) + data = [(0, datetime(1969, 1, 1, 1, 1, 1)), + (1, datetime(2012, 2, 2, 2, 2, 2)), + (2, None), + (3, datetime(2100, 4, 4, 4, 4, 4))] + df = self.spark.createDataFrame(data, schema=schema) + + # Check that a timestamp passed through a pandas_udf will not be altered by timezone calc + f_timestamp_copy = pandas_udf(lambda t: t, returnType=TimestampType()) + df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp"))) + + @pandas_udf(returnType=BooleanType()) + def check_data(idx, timestamp, timestamp_copy): + is_equal = timestamp.isnull() # use this array to check values are equal + for i in range(len(idx)): + # Check that timestamps are as expected in the UDF + is_equal[i] = (is_equal[i] and data[idx[i]][1] is None) or \ + timestamp[i].to_pydatetime() == data[idx[i]][1] + return is_equal + + result = df.withColumn("is_equal", check_data(col("idx"), col("timestamp"), + col("timestamp_copy"))).collect() + # Check that collection values are correct + self.assertEquals(len(data), len(result)) + for i in range(len(result)): + self.assertEquals(data[i][1], result[i][1]) # "timestamp" col + self.assertTrue(result[i][3]) # "is_equal" data in udf was as expected + + def test_vectorized_udf_return_timestamp_tz(self): + from pyspark.sql.functions import pandas_udf, col + import pandas as pd + df = self.spark.range(10) + + @pandas_udf(returnType=TimestampType()) + def gen_timestamps(id): + ts = [pd.Timestamp(i, unit='D', tz='America/Los_Angeles') for i in id] + return pd.Series(ts) + + result = df.withColumn("ts", gen_timestamps(col("id"))).collect() + spark_ts_t = TimestampType() + for r in result: + i, ts = r + ts_tz = pd.Timestamp(i, unit='D', tz='America/Los_Angeles').to_pydatetime() + expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz)) + self.assertEquals(expected, ts) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedPySparkTestCase): @@ -3550,8 +3634,8 @@ def test_wrong_args(self): def test_unsupported_types(self): from pyspark.sql.functions import pandas_udf, col schema = StructType( - [StructField("id", LongType(), True), StructField("dt", DateType(), True)]) - df = self.spark.createDataFrame([(1, datetime.date(1970, 1, 1),)], schema=schema) + [StructField("id", LongType(), True), StructField("dt", DecimalType(), True)]) + df = self.spark.createDataFrame([(1, None,)], schema=schema) f = pandas_udf(lambda x: x, df.schema) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index f65273d5f0b6c..7dd8fa04160e0 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1619,11 +1619,47 @@ def to_arrow_type(dt): arrow_type = pa.decimal(dt.precision, dt.scale) elif type(dt) == StringType: arrow_type = pa.string() + elif type(dt) == DateType: + arrow_type = pa.date32() + elif type(dt) == TimestampType: + # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read + arrow_type = pa.timestamp('us', tz='UTC') else: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) return arrow_type +def _check_dataframe_localize_timestamps(pdf): + """ + Convert timezone aware timestamps to timezone-naive in local time + + :param pdf: pandas.DataFrame + :return pandas.DataFrame where any timezone aware columns have be converted to tz-naive + """ + from pandas.api.types import is_datetime64tz_dtype + for column, series in pdf.iteritems(): + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if is_datetime64tz_dtype(series.dtype): + pdf[column] = series.dt.tz_convert('tzlocal()').dt.tz_localize(None) + return pdf + + +def _check_series_convert_timestamps_internal(s): + """ + Convert a tz-naive timestamp in local tz to UTC normalized for Spark internal storage + :param s: a pandas.Series + :return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone + """ + from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if is_datetime64_dtype(s.dtype): + return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC') + elif is_datetime64tz_dtype(s.dtype): + return s.dt.tz_convert('UTC') + else: + return s + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 1f171049820b2..51ea719f8c4a6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -320,6 +320,10 @@ public ArrowColumnVector(ValueVector vector) { accessor = new StringAccessor((NullableVarCharVector) vector); } else if (vector instanceof NullableVarBinaryVector) { accessor = new BinaryAccessor((NullableVarBinaryVector) vector); + } else if (vector instanceof NullableDateDayVector) { + accessor = new DateAccessor((NullableDateDayVector) vector); + } else if (vector instanceof NullableTimeStampMicroTZVector) { + accessor = new TimestampAccessor((NullableTimeStampMicroTZVector) vector); } else if (vector instanceof ListVector) { ListVector listVector = (ListVector) vector; accessor = new ArrayAccessor(listVector); @@ -575,6 +579,36 @@ final byte[] getBinary(int rowId) { } } + private static class DateAccessor extends ArrowVectorAccessor { + + private final NullableDateDayVector.Accessor accessor; + + DateAccessor(NullableDateDayVector vector) { + super(vector); + this.accessor = vector.getAccessor(); + } + + @Override + final int getInt(int rowId) { + return accessor.get(rowId); + } + } + + private static class TimestampAccessor extends ArrowVectorAccessor { + + private final NullableTimeStampMicroTZVector.Accessor accessor; + + TimestampAccessor(NullableTimeStampMicroTZVector vector) { + super(vector); + this.accessor = vector.getAccessor(); + } + + @Override + final long getLong(int rowId) { + return accessor.get(rowId); + } + } + private static class ArrayAccessor extends ArrowVectorAccessor { private final UInt4Vector.Accessor accessor; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 0e23983786b08..fe4e192e43dfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3154,9 +3154,11 @@ class Dataset[T] private[sql]( private[sql] def toArrowPayload: RDD[ArrowPayload] = { val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch + val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone queryExecution.toRdd.mapPartitionsInternal { iter => val context = TaskContext.get() - ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch, context) + ArrowConverters.toPayloadIterator( + iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 561a067a2f81f..05ea1517fcac9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -74,9 +74,10 @@ private[sql] object ArrowConverters { rowIter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: Int, + timeZoneId: String, context: TaskContext): Iterator[ArrowPayload] = { - val arrowSchema = ArrowUtils.toArrowSchema(schema) + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 2caf1ef02909a..6ad11bda84bf6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.arrow import scala.collection.JavaConverters._ import org.apache.arrow.memory.RootAllocator -import org.apache.arrow.vector.types.FloatingPointPrecision +import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import org.apache.spark.sql.types._ @@ -31,7 +31,8 @@ object ArrowUtils { // todo: support more types. - def toArrowType(dt: DataType): ArrowType = dt match { + /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */ + def toArrowType(dt: DataType, timeZoneId: String): ArrowType = dt match { case BooleanType => ArrowType.Bool.INSTANCE case ByteType => new ArrowType.Int(8, true) case ShortType => new ArrowType.Int(8 * 2, true) @@ -42,6 +43,13 @@ object ArrowUtils { case StringType => ArrowType.Utf8.INSTANCE case BinaryType => ArrowType.Binary.INSTANCE case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) + case DateType => new ArrowType.Date(DateUnit.DAY) + case TimestampType => + if (timeZoneId == null) { + throw new UnsupportedOperationException("TimestampType must supply timeZoneId parameter") + } else { + new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) + } case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") } @@ -58,22 +66,27 @@ object ArrowUtils { case ArrowType.Utf8.INSTANCE => StringType case ArrowType.Binary.INSTANCE => BinaryType case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) + case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType + case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt") } - def toArrowField(name: String, dt: DataType, nullable: Boolean): Field = { + /** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */ + def toArrowField( + name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = { dt match { case ArrayType(elementType, containsNull) => val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) - new Field(name, fieldType, Seq(toArrowField("element", elementType, containsNull)).asJava) + new Field(name, fieldType, + Seq(toArrowField("element", elementType, containsNull, timeZoneId)).asJava) case StructType(fields) => val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) new Field(name, fieldType, fields.map { field => - toArrowField(field.name, field.dataType, field.nullable) + toArrowField(field.name, field.dataType, field.nullable, timeZoneId) }.toSeq.asJava) case dataType => - val fieldType = new FieldType(nullable, toArrowType(dataType), null) + val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null) new Field(name, fieldType, Seq.empty[Field].asJava) } } @@ -94,9 +107,10 @@ object ArrowUtils { } } - def toArrowSchema(schema: StructType): Schema = { + /** Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType */ + def toArrowSchema(schema: StructType, timeZoneId: String): Schema = { new Schema(schema.map { field => - toArrowField(field.name, field.dataType, field.nullable) + toArrowField(field.name, field.dataType, field.nullable, timeZoneId) }.asJava) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 0b740735ffe19..e4af4f65da127 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.arrow.vector._ import org.apache.arrow.vector.complex._ -import org.apache.arrow.vector.util.DecimalUtility +import org.apache.arrow.vector.types.pojo.ArrowType import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters @@ -29,8 +29,8 @@ import org.apache.spark.sql.types._ object ArrowWriter { - def create(schema: StructType): ArrowWriter = { - val arrowSchema = ArrowUtils.toArrowSchema(schema) + def create(schema: StructType, timeZoneId: String): ArrowWriter = { + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) create(root) } @@ -55,6 +55,8 @@ object ArrowWriter { case (DoubleType, vector: NullableFloat8Vector) => new DoubleWriter(vector) case (StringType, vector: NullableVarCharVector) => new StringWriter(vector) case (BinaryType, vector: NullableVarBinaryVector) => new BinaryWriter(vector) + case (DateType, vector: NullableDateDayVector) => new DateWriter(vector) + case (TimestampType, vector: NullableTimeStampMicroTZVector) => new TimestampWriter(vector) case (ArrayType(_, _), vector: ListVector) => val elementVector = createFieldWriter(vector.getDataVector()) new ArrayWriter(vector, elementVector) @@ -69,9 +71,7 @@ object ArrowWriter { } } -class ArrowWriter( - val root: VectorSchemaRoot, - fields: Array[ArrowFieldWriter]) { +class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { def schema: StructType = StructType(fields.map { f => StructField(f.name, f.dataType, f.nullable) @@ -255,6 +255,33 @@ private[arrow] class BinaryWriter( } } +private[arrow] class DateWriter(val valueVector: NullableDateDayVector) extends ArrowFieldWriter { + + override def valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getInt(ordinal)) + } +} + +private[arrow] class TimestampWriter( + val valueVector: NullableTimeStampMicroTZVector) extends ArrowFieldWriter { + + override def valueMutator: NullableTimeStampMicroTZVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getLong(ordinal)) + } +} + private[arrow] class ArrayWriter( val valueVector: ListVector, val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 81896187ecc46..0db463a5fbd89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -79,7 +79,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val columnarBatchIter = new ArrowPythonRunner( funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) + PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, conf.sessionLocalTimeZone) .compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index f6c03c415dc66..94c05b9b5e49f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -43,7 +43,8 @@ class ArrowPythonRunner( reuseWorker: Boolean, evalType: Int, argOffsets: Array[Array[Int]], - schema: StructType) + schema: StructType, + timeZoneId: String) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( funcs, bufferSize, reuseWorker, evalType, argOffsets) { @@ -60,7 +61,7 @@ class ArrowPythonRunner( } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - val arrowSchema = ArrowUtils.toArrowSchema(schema) + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"stdout writer for $pythonExec", 0, Long.MaxValue) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 5ed88ada428cb..cc93fda9f81da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -94,8 +94,8 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema) - .compute(grouped, context.partitionId(), context) + PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema, conf.sessionLocalTimeZone) + .compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 30422b657742c..ba2903babbba8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -32,6 +32,8 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -793,6 +795,103 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { collectAndValidate(df, json, "binaryData.json") } + test("date type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "date", + | "type" : { + | "name" : "date", + | "unit" : "DAY" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "date", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ -1, 0, 16533, 382607 ] + | } ] + | } ] + |} + """.stripMargin + + val d1 = DateTimeUtils.toJavaDate(-1) // "1969-12-31" + val d2 = DateTimeUtils.toJavaDate(0) // "1970-01-01" + val d3 = Date.valueOf("2015-04-08") + val d4 = Date.valueOf("3017-07-18") + + val df = Seq(d1, d2, d3, d4).toDF("date") + + collectAndValidate(df, json, "dateData.json") + } + + test("timestamp type conversion") { + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Los_Angeles") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "timestamp", + | "type" : { + | "name" : "timestamp", + | "unit" : "MICROSECOND", + | "timezone" : "America/Los_Angeles" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "timestamp", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ -1234, 0, 1365383415567000, 33057298500000000 ] + | } ] + | } ] + |} + """.stripMargin + + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) + val ts1 = DateTimeUtils.toJavaTimestamp(-1234L) + val ts2 = DateTimeUtils.toJavaTimestamp(0L) + val ts3 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) + val ts4 = new Timestamp(sdf.parse("3017-07-18 14:55:00.000 UTC").getTime) + val data = Seq(ts1, ts2, ts3, ts4) + + val df = data.toDF("timestamp") + + collectAndValidate(df, json, "timestampData.json", "America/Los_Angeles") + } + } + test("floating-point NaN") { val json = s""" @@ -1486,15 +1585,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { runUnsupported { decimalData.toArrowPayload.collect() } runUnsupported { mapData.toDF().toArrowPayload.collect() } runUnsupported { complexData.toArrowPayload.collect() } - - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) - val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) - val d2 = new Date(sdf.parse("2016-05-09 13:10:15.000 UTC").getTime) - runUnsupported { Seq(d1, d2).toDF("date").toArrowPayload.collect() } - - val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) - val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) - runUnsupported { Seq(ts1, ts2).toDF("timestamp").toArrowPayload.collect() } } test("test Arrow Validator") { @@ -1638,7 +1728,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) val ctx = TaskContext.empty() - val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, ctx) + val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, null, ctx) val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx) assert(schema.equals(outputRowIter.schema)) @@ -1657,22 +1747,24 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ - private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = { + private def collectAndValidate( + df: DataFrame, json: String, file: String, timeZoneId: String = null): Unit = { // NOTE: coalesce to single partition because can only load 1 batch in validator val arrowPayload = df.coalesce(1).toArrowPayload.collect().head val tempFile = new File(tempDataPath, file) Files.write(json, tempFile, StandardCharsets.UTF_8) - validateConversion(df.schema, arrowPayload, tempFile) + validateConversion(df.schema, arrowPayload, tempFile, timeZoneId) } private def validateConversion( sparkSchema: StructType, arrowPayload: ArrowPayload, - jsonFile: File): Unit = { + jsonFile: File, + timeZoneId: String = null): Unit = { val allocator = new RootAllocator(Long.MaxValue) val jsonReader = new JsonFileReader(jsonFile, allocator) - val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema) + val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId) val jsonSchema = jsonReader.start() Validator.compareSchemas(arrowSchema, jsonSchema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala index 638619fd39d06..d801f62b62323 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.execution.arrow +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ class ArrowUtilsSuite extends SparkFunSuite { @@ -25,7 +28,7 @@ class ArrowUtilsSuite extends SparkFunSuite { def roundtrip(dt: DataType): Unit = { dt match { case schema: StructType => - assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema)) === schema) + assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema, null)) === schema) case _ => roundtrip(new StructType().add("value", dt)) } @@ -42,6 +45,27 @@ class ArrowUtilsSuite extends SparkFunSuite { roundtrip(StringType) roundtrip(BinaryType) roundtrip(DecimalType.SYSTEM_DEFAULT) + roundtrip(DateType) + val tsExMsg = intercept[UnsupportedOperationException] { + roundtrip(TimestampType) + } + assert(tsExMsg.getMessage.contains("timeZoneId")) + } + + test("timestamp") { + + def roundtripWithTz(timeZoneId: String): Unit = { + val schema = new StructType().add("value", TimestampType) + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val fieldType = arrowSchema.findField("value").getType.asInstanceOf[ArrowType.Timestamp] + assert(fieldType.getTimezone() === timeZoneId) + assert(ArrowUtils.fromArrowSchema(arrowSchema) === schema) + } + + roundtripWithTz(DateTimeUtils.defaultTimeZone().getID) + roundtripWithTz("Asia/Tokyo") + roundtripWithTz("UTC") + roundtripWithTz("America/Los_Angeles") } test("array") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index e9a629315f5f4..a71e30aa3ca96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -27,9 +27,9 @@ import org.apache.spark.unsafe.types.UTF8String class ArrowWriterSuite extends SparkFunSuite { test("simple") { - def check(dt: DataType, data: Seq[Any]): Unit = { + def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = { val schema = new StructType().add("value", dt, nullable = true) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, timeZoneId) assert(writer.schema === schema) data.foreach { datum => @@ -51,6 +51,8 @@ class ArrowWriterSuite extends SparkFunSuite { case DoubleType => reader.getDouble(rowId) case StringType => reader.getUTF8String(rowId) case BinaryType => reader.getBinary(rowId) + case DateType => reader.getInt(rowId) + case TimestampType => reader.getLong(rowId) } assert(value === datum) } @@ -66,12 +68,14 @@ class ArrowWriterSuite extends SparkFunSuite { check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d)) check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString)) check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes())) + check(DateType, Seq(0, 1, 2, null, 4)) + check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), "America/Los_Angeles") } test("get multiple") { - def check(dt: DataType, data: Seq[Any]): Unit = { + def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = { val schema = new StructType().add("value", dt, nullable = false) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, timeZoneId) assert(writer.schema === schema) data.foreach { datum => @@ -88,6 +92,8 @@ class ArrowWriterSuite extends SparkFunSuite { case LongType => reader.getLongs(0, data.size) case FloatType => reader.getFloats(0, data.size) case DoubleType => reader.getDoubles(0, data.size) + case DateType => reader.getInts(0, data.size) + case TimestampType => reader.getLongs(0, data.size) } assert(values === data) @@ -100,12 +106,14 @@ class ArrowWriterSuite extends SparkFunSuite { check(LongType, (0 until 10).map(_.toLong)) check(FloatType, (0 until 10).map(_.toFloat)) check(DoubleType, (0 until 10).map(_.toDouble)) + check(DateType, (0 until 10)) + check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), "America/Los_Angeles") } test("array") { val schema = new StructType() .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, null) assert(writer.schema === schema) writer.write(InternalRow(ArrayData.toArrayData(Array(1, 2, 3)))) @@ -144,7 +152,7 @@ class ArrowWriterSuite extends SparkFunSuite { test("nested array") { val schema = new StructType().add("nested", ArrayType(ArrayType(IntegerType))) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, null) assert(writer.schema === schema) writer.write(InternalRow(ArrayData.toArrayData(Array( @@ -195,7 +203,7 @@ class ArrowWriterSuite extends SparkFunSuite { test("struct") { val schema = new StructType() .add("struct", new StructType().add("i", IntegerType).add("str", StringType)) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, null) assert(writer.schema === schema) writer.write(InternalRow(InternalRow(1, UTF8String.fromString("str1")))) @@ -231,7 +239,7 @@ class ArrowWriterSuite extends SparkFunSuite { test("nested struct") { val schema = new StructType().add("struct", new StructType().add("nested", new StructType().add("i", IntegerType).add("str", StringType))) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, null) assert(writer.schema === schema) writer.write(InternalRow(InternalRow(InternalRow(1, UTF8String.fromString("str1"))))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index d24a9e1f4bd16..068a17bf772e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -29,7 +29,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("boolean") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("boolean", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("boolean", BooleanType, nullable = true) + val vector = ArrowUtils.toArrowField("boolean", BooleanType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableBitVector] vector.allocateNew() val mutator = vector.getMutator() @@ -58,7 +58,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("byte") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("byte", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("byte", ByteType, nullable = true) + val vector = ArrowUtils.toArrowField("byte", ByteType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableTinyIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -87,7 +87,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("short") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("short", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("short", ShortType, nullable = true) + val vector = ArrowUtils.toArrowField("short", ShortType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableSmallIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -116,7 +116,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("int") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true) + val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -145,7 +145,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("long") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("long", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("long", LongType, nullable = true) + val vector = ArrowUtils.toArrowField("long", LongType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableBigIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -174,7 +174,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("float") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("float", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("float", FloatType, nullable = true) + val vector = ArrowUtils.toArrowField("float", FloatType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableFloat4Vector] vector.allocateNew() val mutator = vector.getMutator() @@ -203,7 +203,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("double") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("double", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("double", DoubleType, nullable = true) + val vector = ArrowUtils.toArrowField("double", DoubleType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableFloat8Vector] vector.allocateNew() val mutator = vector.getMutator() @@ -232,7 +232,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("string") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("string", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("string", StringType, nullable = true) + val vector = ArrowUtils.toArrowField("string", StringType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableVarCharVector] vector.allocateNew() val mutator = vector.getMutator() @@ -260,7 +260,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("binary") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("binary", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true) + val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableVarBinaryVector] vector.allocateNew() val mutator = vector.getMutator() @@ -288,7 +288,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("array") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("array", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("array", ArrayType(IntegerType), nullable = true) + val vector = ArrowUtils.toArrowField("array", ArrayType(IntegerType), nullable = true, null) .createVector(allocator).asInstanceOf[ListVector] vector.allocateNew() val mutator = vector.getMutator() @@ -345,7 +345,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("struct") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) val schema = new StructType().add("int", IntegerType).add("long", LongType) - val vector = ArrowUtils.toArrowField("struct", schema, nullable = true) + val vector = ArrowUtils.toArrowField("struct", schema, nullable = true, null) .createVector(allocator).asInstanceOf[NullableMapVector] vector.allocateNew() val mutator = vector.getMutator() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 0b179aa97c479..4cfc776e51db1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -1249,11 +1249,11 @@ class ColumnarBatchSuite extends SparkFunSuite { test("create columnar batch from Arrow column vectors") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) - val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true) + val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableIntVector] vector1.allocateNew() val mutator1 = vector1.getMutator() - val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true) + val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableIntVector] vector2.allocateNew() val mutator2 = vector2.getMutator() From 36b826f5d17ae7be89135cb2c43ff797f9e7fe48 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 27 Oct 2017 07:52:10 -0700 Subject: [PATCH 776/779] [TRIVIAL][SQL] Code cleaning in ResolveReferences ## What changes were proposed in this pull request? This PR is to clean the related codes majorly based on the today's code review on https://github.com/apache/spark/pull/19559 ## How was this patch tested? N/A Author: gatorsmile Closes #19585 from gatorsmile/trivialFixes. --- .../sql/catalyst/analysis/Analyzer.scala | 21 +++++++++++-------- .../scala/org/apache/spark/sql/Column.scala | 10 ++++----- .../spark/sql/RelationalGroupedDataset.scala | 4 ++-- .../sql/execution/WholeStageCodegenExec.scala | 5 ++--- 4 files changed, 21 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d6a962a14dc9c..6384a141e83b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -783,6 +783,17 @@ class Analyzer( } } + private def resolve(e: Expression, q: LogicalPlan): Expression = e match { + case u @ UnresolvedAttribute(nameParts) => + // Leave unchanged if resolution fails. Hopefully will be resolved next round. + val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } + logDebug(s"Resolving $u to $result") + result + case UnresolvedExtractValue(child, fieldExpr) if child.resolved => + ExtractValue(child, fieldExpr, resolver) + case _ => e.mapChildren(resolve(_, q)) + } + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if !p.childrenResolved => p @@ -841,15 +852,7 @@ class Analyzer( case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") - q.transformExpressionsUp { - case u @ UnresolvedAttribute(nameParts) => - // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } - logDebug(s"Resolving $u to $result") - result - case UnresolvedExtractValue(child, fieldExpr) if child.resolved => - ExtractValue(child, fieldExpr, resolver) - } + q.mapExpressions(resolve(_, q)) } def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 8468a8a96349a..92988680871a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.util.usePrettyExpression +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit @@ -44,7 +44,7 @@ private[sql] object Column { e match { case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => a.aggregateFunction.toString - case expr => usePrettyExpression(expr).sql + case expr => toPrettySQL(expr) } } } @@ -137,7 +137,7 @@ class Column(val expr: Expression) extends Logging { case _ => UnresolvedAttribute.quotedString(name) }) - override def toString: String = usePrettyExpression(expr).sql + override def toString: String = toPrettySQL(expr) override def equals(that: Any): Boolean = that match { case that: Column => that.expr.equals(this.expr) @@ -175,7 +175,7 @@ class Column(val expr: Expression) extends Logging { case c @ Cast(_: NamedExpression, _, _) => UnresolvedAlias(c) } match { case ne: NamedExpression => ne - case other => Alias(expr, usePrettyExpression(expr).sql)() + case _ => Alias(expr, toPrettySQL(expr))() } case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => @@ -184,7 +184,7 @@ class Column(val expr: Expression) extends Logging { // Wait until the struct is resolved. This will generate a nicer looking alias. case struct: CreateNamedStructLike => UnresolvedAlias(struct) - case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() + case expr: Expression => Alias(expr, toPrettySQL(expr))() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 6b45790d5ff6e..21e94fa8bb0b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, Unresolved import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.usePrettyExpression +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.execution.python.{PythonUDF, PythonUdfType} import org.apache.spark.sql.internal.SQLConf @@ -85,7 +85,7 @@ class RelationalGroupedDataset protected[sql]( case expr: NamedExpression => expr case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => UnresolvedAlias(a, Some(Column.generateAlias)) - case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() + case expr: Expression => Alias(expr, toPrettySQL(expr))() } private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index e37d133ff336a..286cb3bb0767c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -521,10 +521,9 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { case p if !supportCodegen(p) => // collapse them recursively InputAdapter(insertWholeStageCodegen(p)) - case j @ SortMergeJoinExec(_, _, _, _, left, right) => + case j: SortMergeJoinExec => // The children of SortMergeJoin should do codegen separately. - j.copy(left = InputAdapter(insertWholeStageCodegen(left)), - right = InputAdapter(insertWholeStageCodegen(right))) + j.withNewChildren(j.children.map(child => InputAdapter(insertWholeStageCodegen(child)))) case p => p.withNewChildren(p.children.map(insertInputAdapter)) } From b3d8fc3dc458d42cf11d961762ce99f551f68548 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 27 Oct 2017 13:43:09 -0700 Subject: [PATCH 777/779] [SPARK-22226][SQL] splitExpression can create too many method calls in the outer class ## What changes were proposed in this pull request? SPARK-18016 introduced `NestedClass` to avoid that the many methods generated by `splitExpressions` contribute to the outer class' constant pool, making it growing too much. Unfortunately, despite their definition is stored in the `NestedClass`, they all are invoked in the outer class and for each method invocation, there are two entries added to the constant pool: a `Methodref` and a `Utf8` entry (you can easily check this compiling a simple sample class with `janinoc` and looking at its Constant Pool). This limits the scalability of the solution with very large methods which are split in a lot of small ones. This means that currently we are generating classes like this one: ``` class SpecificUnsafeProjection extends org.apache.spark.sql.catalyst.expressions.UnsafeProjection { ... public UnsafeRow apply(InternalRow i) { rowWriter.zeroOutNullBytes(); apply_0(i); apply_1(i); ... nestedClassInstance.apply_862(i); nestedClassInstance.apply_863(i); ... nestedClassInstance1.apply_1612(i); nestedClassInstance1.apply_1613(i); ... } ... private class NestedClass { private void apply_862(InternalRow i) { ... } private void apply_863(InternalRow i) { ... } ... } private class NestedClass1 { private void apply_1612(InternalRow i) { ... } private void apply_1613(InternalRow i) { ... } ... } } ``` This PR reduce the Constant Pool size of the outer class by adding a new method to each nested class: in this method we invoke all the small methods generated by `splitExpression` in that nested class. In this way, in the outer class there is only one method invocation per nested class, reducing by orders of magnitude the entries in its constant pool because of method invocations. This means that after the patch the generated code becomes: ``` class SpecificUnsafeProjection extends org.apache.spark.sql.catalyst.expressions.UnsafeProjection { ... public UnsafeRow apply(InternalRow i) { rowWriter.zeroOutNullBytes(); apply_0(i); apply_1(i); ... nestedClassInstance.apply(i); nestedClassInstance1.apply(i); ... } ... private class NestedClass { private void apply_862(InternalRow i) { ... } private void apply_863(InternalRow i) { ... } ... private void apply(InternalRow i) { apply_862(i); apply_863(i); ... } } private class NestedClass1 { private void apply_1612(InternalRow i) { ... } private void apply_1613(InternalRow i) { ... } ... private void apply(InternalRow i) { apply_1612(i); apply_1613(i); ... } } } ``` ## How was this patch tested? Added UT and existing UTs Author: Marco Gaido Author: Marco Gaido Closes #19480 from mgaido91/SPARK-22226. --- .../expressions/codegen/CodeGenerator.scala | 156 ++++++++++++++++-- .../expressions/CodeGenerationSuite.scala | 17 ++ .../org/apache/spark/sql/DataFrameSuite.scala | 12 ++ 3 files changed, 167 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2cb66599076a9..58738b52b299f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -77,6 +77,22 @@ case class SubExprEliminationState(isNull: String, value: String) */ case class SubExprCodes(codes: Seq[String], states: Map[Expression, SubExprEliminationState]) +/** + * The main information about a new added function. + * + * @param functionName String representing the name of the function + * @param innerClassName Optional value which is empty if the function is added to + * the outer class, otherwise it contains the name of the + * inner class in which the function has been added. + * @param innerClassInstance Optional value which is empty if the function is added to + * the outer class, otherwise it contains the name of the + * instance of the inner class in the outer class. + */ +private[codegen] case class NewFunctionSpec( + functionName: String, + innerClassName: Option[String], + innerClassInstance: Option[String]) + /** * A context for codegen, tracking a list of objects that could be passed into generated Java * function. @@ -228,8 +244,8 @@ class CodegenContext { /** * Holds the class and instance names to be generated, where `OuterClass` is a placeholder * standing for whichever class is generated as the outermost class and which will contain any - * nested sub-classes. All other classes and instance names in this list will represent private, - * nested sub-classes. + * inner sub-classes. All other classes and instance names in this list will represent private, + * inner sub-classes. */ private val classes: mutable.ListBuffer[(String, String)] = mutable.ListBuffer[(String, String)](outerClassName -> null) @@ -260,8 +276,8 @@ class CodegenContext { /** * Adds a function to the generated class. If the code for the `OuterClass` grows too large, the - * function will be inlined into a new private, nested class, and a class-qualified name for the - * function will be returned. Otherwise, the function will be inined to the `OuterClass` the + * function will be inlined into a new private, inner class, and a class-qualified name for the + * function will be returned. Otherwise, the function will be inlined to the `OuterClass` the * simple `funcName` will be returned. * * @param funcName the class-unqualified name of the function @@ -271,19 +287,27 @@ class CodegenContext { * it is eventually referenced and a returned qualified function name * cannot otherwise be accessed. * @return the name of the function, qualified by class if it will be inlined to a private, - * nested sub-class + * inner class */ def addNewFunction( funcName: String, funcCode: String, inlineToOuterClass: Boolean = false): String = { - // The number of named constants that can exist in the class is limited by the Constant Pool - // limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a - // threshold of 1600k bytes to determine when a function should be inlined to a private, nested - // sub-class. + val newFunction = addNewFunctionInternal(funcName, funcCode, inlineToOuterClass) + newFunction match { + case NewFunctionSpec(functionName, None, None) => functionName + case NewFunctionSpec(functionName, Some(_), Some(innerClassInstance)) => + innerClassInstance + "." + functionName + } + } + + private[this] def addNewFunctionInternal( + funcName: String, + funcCode: String, + inlineToOuterClass: Boolean): NewFunctionSpec = { val (className, classInstance) = if (inlineToOuterClass) { outerClassName -> "" - } else if (currClassSize > 1600000) { + } else if (currClassSize > CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD) { val className = freshName("NestedClass") val classInstance = freshName("nestedClassInstance") @@ -294,17 +318,23 @@ class CodegenContext { currClass() } - classSize(className) += funcCode.length - classFunctions(className) += funcName -> funcCode + addNewFunctionToClass(funcName, funcCode, className) if (className == outerClassName) { - funcName + NewFunctionSpec(funcName, None, None) } else { - - s"$classInstance.$funcName" + NewFunctionSpec(funcName, Some(className), Some(classInstance)) } } + private[this] def addNewFunctionToClass( + funcName: String, + funcCode: String, + className: String) = { + classSize(className) += funcCode.length + classFunctions(className) += funcName -> funcCode + } + /** * Declares all function code. If the added functions are too many, split them into nested * sub-classes to avoid hitting Java compiler constant pool limitation. @@ -738,7 +768,7 @@ class CodegenContext { /** * Splits the generated code of expressions into multiple functions, because function has * 64kb code size limit in JVM. If the class to which the function would be inlined would grow - * beyond 1600kb, we declare a private, nested sub-class, and the function is inlined to it + * beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it * instead, because classes have a constant pool limit of 65,536 named values. * * @param row the variable name of row that is used by expressions @@ -801,10 +831,90 @@ class CodegenContext { | ${makeSplitFunction(body)} |} """.stripMargin - addNewFunction(name, code) + addNewFunctionInternal(name, code, inlineToOuterClass = false) } - foldFunctions(functions.map(name => s"$name(${arguments.map(_._2).mkString(", ")})")) + val (outerClassFunctions, innerClassFunctions) = functions.partition(_.innerClassName.isEmpty) + + val argsString = arguments.map(_._2).mkString(", ") + val outerClassFunctionCalls = outerClassFunctions.map(f => s"${f.functionName}($argsString)") + + val innerClassFunctionCalls = generateInnerClassesFunctionCalls( + innerClassFunctions, + func, + arguments, + returnType, + makeSplitFunction, + foldFunctions) + + foldFunctions(outerClassFunctionCalls ++ innerClassFunctionCalls) + } + } + + /** + * Here we handle all the methods which have been added to the inner classes and + * not to the outer class. + * Since they can be many, their direct invocation in the outer class adds many entries + * to the outer class' constant pool. This can cause the constant pool to past JVM limit. + * Moreover, this can cause also the outer class method where all the invocations are + * performed to grow beyond the 64k limit. + * To avoid these problems, we group them and we call only the grouping methods in the + * outer class. + * + * @param functions a [[Seq]] of [[NewFunctionSpec]] defined in the inner classes + * @param funcName the split function name base. + * @param arguments the list of (type, name) of the arguments of the split function. + * @param returnType the return type of the split function. + * @param makeSplitFunction makes split function body, e.g. add preparation or cleanup. + * @param foldFunctions folds the split function calls. + * @return an [[Iterable]] containing the methods' invocations + */ + private def generateInnerClassesFunctionCalls( + functions: Seq[NewFunctionSpec], + funcName: String, + arguments: Seq[(String, String)], + returnType: String, + makeSplitFunction: String => String, + foldFunctions: Seq[String] => String): Iterable[String] = { + val innerClassToFunctions = mutable.LinkedHashMap.empty[(String, String), Seq[String]] + functions.foreach(f => { + val key = (f.innerClassName.get, f.innerClassInstance.get) + val value = f.functionName +: innerClassToFunctions.getOrElse(key, Seq.empty[String]) + innerClassToFunctions.put(key, value) + }) + + val argDefinitionString = arguments.map { case (t, name) => s"$t $name" }.mkString(", ") + val argInvocationString = arguments.map(_._2).mkString(", ") + + innerClassToFunctions.flatMap { + case ((innerClassName, innerClassInstance), innerClassFunctions) => + // for performance reasons, the functions are prepended, instead of appended, + // thus here they are in reversed order + val orderedFunctions = innerClassFunctions.reverse + if (orderedFunctions.size > CodeGenerator.MERGE_SPLIT_METHODS_THRESHOLD) { + // Adding a new function to each inner class which contains the invocation of all the + // ones which have been added to that inner class. For example, + // private class NestedClass { + // private void apply_862(InternalRow i) { ... } + // private void apply_863(InternalRow i) { ... } + // ... + // private void apply(InternalRow i) { + // apply_862(i); + // apply_863(i); + // ... + // } + // } + val body = foldFunctions(orderedFunctions.map(name => s"$name($argInvocationString)")) + val code = s""" + |private $returnType $funcName($argDefinitionString) { + | ${makeSplitFunction(body)} + |} + """.stripMargin + addNewFunctionToClass(funcName, code, innerClassName) + Seq(s"$innerClassInstance.$funcName($argInvocationString)") + } else { + orderedFunctions.map(f => s"$innerClassInstance.$f($argInvocationString)") + } } } @@ -1013,6 +1123,16 @@ object CodeGenerator extends Logging { // This is the value of HugeMethodLimit in the OpenJDK JVM settings val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000 + // This is the threshold over which the methods in an inner class are grouped in a single + // method which is going to be called by the outer class instead of the many small ones + val MERGE_SPLIT_METHODS_THRESHOLD = 3 + + // The number of named constants that can exist in the class is limited by the Constant Pool + // limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a + // threshold of 1000k bytes to determine when a function should be inlined to a private, inner + // class. + val GENERATED_CLASS_SIZE_THRESHOLD = 1000000 + /** * Compile the Java source code into a Java class, using Janino. * diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 7ea0bec145481..1e6f7b65e7e72 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -201,6 +201,23 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-22226: group splitted expressions into one method per nested class") { + val length = 10000 + val expressions = Seq.fill(length) { + ToUTCTimestamp( + Literal.create(Timestamp.valueOf("2017-10-10 00:00:00"), TimestampType), + Literal.create("PST", StringType)) + } + val plan = GenerateMutableProjection.generate(expressions) + val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) + val expected = Seq.fill(length)( + DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2017-10-10 07:00:00"))) + + if (actual != expected) { + fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") + } + } + test("test generated safe and unsafe projection") { val schema = new StructType(Array( StructField("a", StringType, true), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 473c355cf3c7f..17c88b0690800 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2106,6 +2106,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2))) } + test("SPARK-22226: splitExpressions should not generate codes beyond 64KB") { + val colNumber = 10000 + val input = spark.range(2).rdd.map(_ => Row(1 to colNumber: _*)) + val df = sqlContext.createDataFrame(input, StructType( + (1 to colNumber).map(colIndex => StructField(s"_$colIndex", IntegerType, false)))) + val newCols = (1 to colNumber).flatMap { colIndex => + Seq(expr(s"if(1000 < _$colIndex, 1000, _$colIndex)"), + expr(s"sqrt(_$colIndex)")) + } + df.select(newCols: _*).collect() + } + test("SPARK-22271: mean overflows and returns null for some decimal variables") { val d = 0.034567890 val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol") From 20eb95e5e9c562261b44e4e47cad67a31390fa59 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 27 Oct 2017 15:19:27 -0700 Subject: [PATCH 778/779] [SPARK-21911][ML][PYSPARK] Parallel Model Evaluation for ML Tuning in PySpark ## What changes were proposed in this pull request? Add parallelism support for ML tuning in pyspark. ## How was this patch tested? Test updated. Author: WeichenXu Closes #19122 from WeichenXu123/par-ml-tuning-py. --- .../spark/ml/tuning/CrossValidatorSuite.scala | 4 +- .../ml/tuning/TrainValidationSplitSuite.scala | 4 +- python/pyspark/ml/tests.py | 39 +++++++++ python/pyspark/ml/tuning.py | 86 ++++++++++++------- 4 files changed, 96 insertions(+), 37 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index a01744f7b67fd..853eeb39bf8df 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -137,8 +137,8 @@ class CrossValidatorSuite cv.setParallelism(2) val cvParallelModel = cv.fit(dataset) - val serialMetrics = cvSerialModel.avgMetrics.sorted - val parallelMetrics = cvParallelModel.avgMetrics.sorted + val serialMetrics = cvSerialModel.avgMetrics + val parallelMetrics = cvParallelModel.avgMetrics assert(serialMetrics === parallelMetrics) val parentSerial = cvSerialModel.bestModel.parent.asInstanceOf[LogisticRegression] diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 2ed4fbb601b61..f8d9c66be2c40 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -138,8 +138,8 @@ class TrainValidationSplitSuite cv.setParallelism(2) val cvParallelModel = cv.fit(dataset) - val serialMetrics = cvSerialModel.validationMetrics.sorted - val parallelMetrics = cvParallelModel.validationMetrics.sorted + val serialMetrics = cvSerialModel.validationMetrics + val parallelMetrics = cvParallelModel.validationMetrics assert(serialMetrics === parallelMetrics) val parentSerial = cvSerialModel.bestModel.parent.asInstanceOf[LogisticRegression] diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 8b8bcc7b13a38..2f1f3af957e4d 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -836,6 +836,27 @@ def test_save_load_simple_estimator(self): loadedModel = CrossValidatorModel.load(cvModelPath) self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) + def test_parallel_evaluation(self): + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() + evaluator = BinaryClassificationEvaluator() + + # test save/load of CrossValidator + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + cv.setParallelism(1) + cvSerialModel = cv.fit(dataset) + cv.setParallelism(2) + cvParallelModel = cv.fit(dataset) + self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics) + def test_save_load_nested_estimator(self): temp_path = tempfile.mkdtemp() dataset = self.spark.createDataFrame( @@ -986,6 +1007,24 @@ def test_save_load_simple_estimator(self): loadedModel = TrainValidationSplitModel.load(tvsModelPath) self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) + def test_parallel_evaluation(self): + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + tvs.setParallelism(1) + tvsSerialModel = tvs.fit(dataset) + tvs.setParallelism(2) + tvsParallelModel = tvs.fit(dataset) + self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics) + def test_save_load_nested_estimator(self): # This tests saving and loading the trained model only. # Save/load for TrainValidationSplit will be added later: SPARK-13786 diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 00c348aa9f7de..47351133524e7 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -14,15 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import itertools import numpy as np +from multiprocessing.pool import ThreadPool from pyspark import since, keyword_only from pyspark.ml import Estimator, Model from pyspark.ml.common import _py2java from pyspark.ml.param import Params, Param, TypeConverters -from pyspark.ml.param.shared import HasSeed +from pyspark.ml.param.shared import HasParallelism, HasSeed from pyspark.ml.util import * from pyspark.ml.wrapper import JavaParams from pyspark.sql.functions import rand @@ -170,7 +170,7 @@ def _to_java_impl(self): return java_estimator, java_epms, java_evaluator -class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable): +class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): """ K-fold cross validation performs model selection by splitting the dataset into a set of @@ -193,7 +193,8 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable): >>> lr = LogisticRegression() >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() >>> evaluator = BinaryClassificationEvaluator() - >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + ... parallelism=2) >>> cvModel = cv.fit(dataset) >>> cvModel.avgMetrics[0] 0.5 @@ -208,23 +209,23 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable): @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, - seed=None): + seed=None, parallelism=1): """ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ - seed=None) + seed=None, parallelism=1) """ super(CrossValidator, self).__init__() - self._setDefault(numFolds=3) + self._setDefault(numFolds=3, parallelism=1) kwargs = self._input_kwargs self._set(**kwargs) @keyword_only @since("1.4.0") def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, - seed=None): + seed=None, parallelism=1): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ - seed=None): + seed=None, parallelism=1): Sets params for cross validator. """ kwargs = self._input_kwargs @@ -255,18 +256,27 @@ def _fit(self, dataset): randCol = self.uid + "_rand" df = dataset.select("*", rand(seed).alias(randCol)) metrics = [0.0] * numModels + + pool = ThreadPool(processes=min(self.getParallelism(), numModels)) + for i in range(nFolds): validateLB = i * h validateUB = (i + 1) * h condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB) - validation = df.filter(condition) - train = df.filter(~condition) - models = est.fit(train, epm) - for j in range(numModels): - model = models[j] + validation = df.filter(condition).cache() + train = df.filter(~condition).cache() + + def singleTrain(paramMap): + model = est.fit(train, paramMap) # TODO: duplicate evaluator to take extra params from input - metric = eva.evaluate(model.transform(validation, epm[j])) - metrics[j] += metric/nFolds + metric = eva.evaluate(model.transform(validation, paramMap)) + return metric + + currentFoldMetrics = pool.map(singleTrain, epm) + for j in range(numModels): + metrics[j] += (currentFoldMetrics[j] / nFolds) + validation.unpersist() + train.unpersist() if eva.isLargerBetter(): bestIndex = np.argmax(metrics) @@ -316,9 +326,10 @@ def _from_java(cls, java_stage): estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage) numFolds = java_stage.getNumFolds() seed = java_stage.getSeed() + parallelism = java_stage.getParallelism() # Create a new instance of this stage. py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, - numFolds=numFolds, seed=seed) + numFolds=numFolds, seed=seed, parallelism=parallelism) py_stage._resetUid(java_stage.uid()) return py_stage @@ -337,6 +348,7 @@ def _to_java(self): _java_obj.setEstimator(estimator) _java_obj.setSeed(self.getSeed()) _java_obj.setNumFolds(self.getNumFolds()) + _java_obj.setParallelism(self.getParallelism()) return _java_obj @@ -427,7 +439,7 @@ def _to_java(self): return _java_obj -class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable): +class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): """ .. note:: Experimental @@ -448,7 +460,8 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable): >>> lr = LogisticRegression() >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() >>> evaluator = BinaryClassificationEvaluator() - >>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + >>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + ... parallelism=2) >>> tvsModel = tvs.fit(dataset) >>> evaluator.evaluate(tvsModel.transform(dataset)) 0.8333... @@ -461,23 +474,23 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable): @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, - seed=None): + parallelism=1, seed=None): """ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ - seed=None) + parallelism=1, seed=None) """ super(TrainValidationSplit, self).__init__() - self._setDefault(trainRatio=0.75) + self._setDefault(trainRatio=0.75, parallelism=1) kwargs = self._input_kwargs self._set(**kwargs) @since("2.0.0") @keyword_only def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, - seed=None): + parallelism=1, seed=None): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ - seed=None): + parallelism=1, seed=None): Sets params for the train validation split. """ kwargs = self._input_kwargs @@ -506,15 +519,20 @@ def _fit(self, dataset): seed = self.getOrDefault(self.seed) randCol = self.uid + "_rand" df = dataset.select("*", rand(seed).alias(randCol)) - metrics = [0.0] * numModels condition = (df[randCol] >= tRatio) - validation = df.filter(condition) - train = df.filter(~condition) - models = est.fit(train, epm) - for j in range(numModels): - model = models[j] - metric = eva.evaluate(model.transform(validation, epm[j])) - metrics[j] += metric + validation = df.filter(condition).cache() + train = df.filter(~condition).cache() + + def singleTrain(paramMap): + model = est.fit(train, paramMap) + metric = eva.evaluate(model.transform(validation, paramMap)) + return metric + + pool = ThreadPool(processes=min(self.getParallelism(), numModels)) + metrics = pool.map(singleTrain, epm) + train.unpersist() + validation.unpersist() + if eva.isLargerBetter(): bestIndex = np.argmax(metrics) else: @@ -563,9 +581,10 @@ def _from_java(cls, java_stage): estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage) trainRatio = java_stage.getTrainRatio() seed = java_stage.getSeed() + parallelism = java_stage.getParallelism() # Create a new instance of this stage. py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, - trainRatio=trainRatio, seed=seed) + trainRatio=trainRatio, seed=seed, parallelism=parallelism) py_stage._resetUid(java_stage.uid()) return py_stage @@ -584,6 +603,7 @@ def _to_java(self): _java_obj.setEstimator(estimator) _java_obj.setTrainRatio(self.getTrainRatio()) _java_obj.setSeed(self.getSeed()) + _java_obj.setParallelism(self.getParallelism()) return _java_obj From 01f6ba0e7a12ef818d56e7d5b1bd889b79f2b57c Mon Sep 17 00:00:00 2001 From: Sathiya Date: Fri, 27 Oct 2017 18:57:08 -0700 Subject: [PATCH 779/779] [SPARK-22181][SQL] Adds ReplaceExceptWithFilter rule ## What changes were proposed in this pull request? Adds a new optimisation rule 'ReplaceExceptWithNotFilter' that replaces Except logical with Filter operator and schedule it before applying 'ReplaceExceptWithAntiJoin' rule. This way we can avoid expensive join operation if one or both of the datasets of the Except operation are fully derived out of Filters from a same parent. ## How was this patch tested? The patch is tested locally using spark-shell + unit test. Author: Sathiya Closes #19451 from sathiyapk/SPARK-22181-optimize-exceptWithFilter. --- .../sql/catalyst/expressions/subquery.scala | 10 ++ .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../optimizer/ReplaceExceptWithFilter.scala | 101 +++++++++++++++++ .../apache/spark/sql/internal/SQLConf.scala | 15 +++ .../optimizer/ReplaceOperatorSuite.scala | 106 +++++++++++++++++- .../resources/sql-tests/inputs/except.sql | 57 ++++++++++ .../sql-tests/results/except.sql.out | 105 +++++++++++++++++ 7 files changed, 394 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala create mode 100644 sql/core/src/test/resources/sql-tests/inputs/except.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/except.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index c6146042ef1a6..6acc87a3e7367 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -89,6 +89,16 @@ object SubqueryExpression { case _ => false }.isDefined } + + /** + * Returns true when an expression contains a subquery + */ + def hasSubquery(e: Expression): Boolean = { + e.find { + case _: SubqueryExpression => true + case _ => false + }.isDefined + } } object SubExprUtils extends PredicateHelper { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d829e01441dcc..3273a61dc7b35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -76,6 +76,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) OptimizeSubqueries) :: Batch("Replace Operators", fixedPoint, ReplaceIntersectWithSemiJoin, + ReplaceExceptWithFilter, ReplaceExceptWithAntiJoin, ReplaceDistinctWithAggregate) :: Batch("Aggregate", fixedPoint, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala new file mode 100644 index 0000000000000..89bfcee078fba --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.annotation.tailrec + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule + + +/** + * If one or both of the datasets in the logical [[Except]] operator are purely transformed using + * [[Filter]], this rule will replace logical [[Except]] operator with a [[Filter]] operator by + * flipping the filter condition of the right child. + * {{{ + * SELECT a1, a2 FROM Tab1 WHERE a2 = 12 EXCEPT SELECT a1, a2 FROM Tab1 WHERE a1 = 5 + * ==> SELECT DISTINCT a1, a2 FROM Tab1 WHERE a2 = 12 AND (a1 is null OR a1 <> 5) + * }}} + * + * Note: + * Before flipping the filter condition of the right node, we should: + * 1. Combine all it's [[Filter]]. + * 2. Apply InferFiltersFromConstraints rule (to take into account of NULL values in the condition). + */ +object ReplaceExceptWithFilter extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = { + if (!plan.conf.replaceExceptWithFilter) { + return plan + } + + plan.transform { + case Except(left, right) if isEligible(left, right) => + Distinct(Filter(Not(transformCondition(left, skipProject(right))), left)) + } + } + + private def transformCondition(left: LogicalPlan, right: LogicalPlan): Expression = { + val filterCondition = + InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition + + val attributeNameMap: Map[String, Attribute] = left.output.map(x => (x.name, x)).toMap + + filterCondition.transform { case a : AttributeReference => attributeNameMap(a.name) } + } + + // TODO: This can be further extended in the future. + private def isEligible(left: LogicalPlan, right: LogicalPlan): Boolean = (left, right) match { + case (_, right @ (Project(_, _: Filter) | Filter(_, _))) => verifyConditions(left, right) + case _ => false + } + + private def verifyConditions(left: LogicalPlan, right: LogicalPlan): Boolean = { + val leftProjectList = projectList(left) + val rightProjectList = projectList(right) + + left.output.size == left.output.map(_.name).distinct.size && + left.find(_.expressions.exists(SubqueryExpression.hasSubquery)).isEmpty && + right.find(_.expressions.exists(SubqueryExpression.hasSubquery)).isEmpty && + Project(leftProjectList, nonFilterChild(skipProject(left))).sameResult( + Project(rightProjectList, nonFilterChild(skipProject(right)))) + } + + private def projectList(node: LogicalPlan): Seq[NamedExpression] = node match { + case p: Project => p.projectList + case x => x.output + } + + private def skipProject(node: LogicalPlan): LogicalPlan = node match { + case p: Project => p.child + case x => x + } + + private def nonFilterChild(plan: LogicalPlan) = plan.find(!_.isInstanceOf[Filter]).getOrElse { + throw new IllegalStateException("Leaf node is expected") + } + + private def combineFilters(plan: LogicalPlan): LogicalPlan = { + @tailrec + def iterate(plan: LogicalPlan, acc: LogicalPlan): LogicalPlan = { + if (acc.fastEquals(plan)) acc else iterate(acc, CombineFilters(acc)) + } + iterate(plan, CombineFilters(plan)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 21e4685fcc456..5203e8833fbbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -948,6 +948,19 @@ object SQLConf { .intConf .createWithDefault(10000) + val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter") + .internal() + .doc("When true, the apply function of the rule verifies whether the right node of the" + + " except operation is of type Filter or Project followed by Filter. If yes, the rule" + + " further verifies 1) Excluding the filter operations from the right (as well as the" + + " left node, if any) on the top, whether both the nodes evaluates to a same result." + + " 2) The left and right nodes don't contain any SubqueryExpressions. 3) The output" + + " column names of the left node are distinct. If all the conditions are met, the" + + " rule will replace the except operation with a Filter by flipping the filter" + + " condition(s) of the right node.") + .booleanConf + .createWithDefault(true) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1233,6 +1246,8 @@ class SQLConf extends Serializable with Logging { def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) + def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index 85988d2fb948c..0fa1aaeb9e164 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.{Alias, Not} import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ @@ -31,6 +32,7 @@ class ReplaceOperatorSuite extends PlanTest { val batches = Batch("Replace Operators", FixedPoint(100), ReplaceDistinctWithAggregate, + ReplaceExceptWithFilter, ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, ReplaceDeduplicateWithAggregate) :: Nil @@ -50,6 +52,108 @@ class ReplaceOperatorSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("replace Except with Filter while both the nodes are of type Filter") { + val attributeA = 'a.int + val attributeB = 'b.int + + val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) + val table2 = Filter(attributeB === 2, Filter(attributeA === 1, table1)) + val table3 = Filter(attributeB < 1, Filter(attributeA >= 2, table1)) + + val query = Except(table2, table3) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && + (attributeA >= 2 && attributeB < 1)), + Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze + + comparePlans(optimized, correctAnswer) + } + + test("replace Except with Filter while only right node is of type Filter") { + val attributeA = 'a.int + val attributeB = 'b.int + + val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) + val table2 = Filter(attributeB < 1, Filter(attributeA >= 2, table1)) + + val query = Except(table1, table2) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && + (attributeA >= 2 && attributeB < 1)), table1)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("replace Except with Filter while both the nodes are of type Project") { + val attributeA = 'a.int + val attributeB = 'b.int + + val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) + val table2 = Project(Seq(attributeA, attributeB), table1) + val table3 = Project(Seq(attributeA, attributeB), + Filter(attributeB < 1, Filter(attributeA >= 2, table1))) + + val query = Except(table2, table3) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && + (attributeA >= 2 && attributeB < 1)), + Project(Seq(attributeA, attributeB), table1))).analyze + + comparePlans(optimized, correctAnswer) + } + + test("replace Except with Filter while only right node is of type Project") { + val attributeA = 'a.int + val attributeB = 'b.int + + val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) + val table2 = Filter(attributeB === 2, Filter(attributeA === 1, table1)) + val table3 = Project(Seq(attributeA, attributeB), + Filter(attributeB < 1, Filter(attributeA >= 2, table1))) + + val query = Except(table2, table3) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && + (attributeA >= 2 && attributeB < 1)), + Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze + + comparePlans(optimized, correctAnswer) + } + + test("replace Except with Filter while left node is Project and right node is Filter") { + val attributeA = 'a.int + val attributeB = 'b.int + + val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) + val table2 = Project(Seq(attributeA, attributeB), + Filter(attributeB < 1, Filter(attributeA >= 2, table1))) + val table3 = Filter(attributeB === 2, Filter(attributeA === 1, table1)) + + val query = Except(table2, table3) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && + (attributeA === 1 && attributeB === 2)), + Project(Seq(attributeA, attributeB), + Filter(attributeB < 1, Filter(attributeA >= 2, table1))))).analyze + + comparePlans(optimized, correctAnswer) + } + test("replace Except with Left-anti Join") { val table1 = LocalRelation('a.int, 'b.int) val table2 = LocalRelation('c.int, 'd.int) diff --git a/sql/core/src/test/resources/sql-tests/inputs/except.sql b/sql/core/src/test/resources/sql-tests/inputs/except.sql new file mode 100644 index 0000000000000..1d579e65f3473 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/except.sql @@ -0,0 +1,57 @@ +-- Tests different scenarios of except operation +create temporary view t1 as select * from values + ("one", 1), + ("two", 2), + ("three", 3), + ("one", NULL) + as t1(k, v); + +create temporary view t2 as select * from values + ("one", 1), + ("two", 22), + ("one", 5), + ("one", NULL), + (NULL, 5) + as t2(k, v); + + +-- Except operation that will be replaced by left anti join +SELECT * FROM t1 EXCEPT SELECT * FROM t2; + + +-- Except operation that will be replaced by Filter: SPARK-22181 +SELECT * FROM t1 EXCEPT SELECT * FROM t1 where v <> 1 and v <> 2; + + +-- Except operation that will be replaced by Filter: SPARK-22181 +SELECT * FROM t1 where v <> 1 and v <> 22 EXCEPT SELECT * FROM t1 where v <> 2 and v >= 3; + + +-- Except operation that will be replaced by Filter: SPARK-22181 +SELECT t1.* FROM t1, t2 where t1.k = t2.k +EXCEPT +SELECT t1.* FROM t1, t2 where t1.k = t2.k and t1.k != 'one'; + + +-- Except operation that will be replaced by left anti join +SELECT * FROM t2 where v >= 1 and v <> 22 EXCEPT SELECT * FROM t1; + + +-- Except operation that will be replaced by left anti join +SELECT (SELECT min(k) FROM t2 WHERE t2.k = t1.k) min_t2 FROM t1 +MINUS +SELECT (SELECT min(k) FROM t2) abs_min_t2 FROM t1 WHERE t1.k = 'one'; + + +-- Except operation that will be replaced by left anti join +SELECT t1.k +FROM t1 +WHERE t1.v <= (SELECT max(t2.v) + FROM t2 + WHERE t2.k = t1.k) +MINUS +SELECT t1.k +FROM t1 +WHERE t1.v >= (SELECT min(t2.v) + FROM t2 + WHERE t2.k = t1.k); diff --git a/sql/core/src/test/resources/sql-tests/results/except.sql.out b/sql/core/src/test/resources/sql-tests/results/except.sql.out new file mode 100644 index 0000000000000..c9b712d4d2949 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/except.sql.out @@ -0,0 +1,105 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 9 + + +-- !query 0 +create temporary view t1 as select * from values + ("one", 1), + ("two", 2), + ("three", 3), + ("one", NULL) + as t1(k, v) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view t2 as select * from values + ("one", 1), + ("two", 22), + ("one", 5), + ("one", NULL), + (NULL, 5) + as t2(k, v) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * FROM t1 EXCEPT SELECT * FROM t2 +-- !query 2 schema +struct +-- !query 2 output +three 3 +two 2 + + +-- !query 3 +SELECT * FROM t1 EXCEPT SELECT * FROM t1 where v <> 1 and v <> 2 +-- !query 3 schema +struct +-- !query 3 output +one 1 +one NULL +two 2 + + +-- !query 4 +SELECT * FROM t1 where v <> 1 and v <> 22 EXCEPT SELECT * FROM t1 where v <> 2 and v >= 3 +-- !query 4 schema +struct +-- !query 4 output +two 2 + + +-- !query 5 +SELECT t1.* FROM t1, t2 where t1.k = t2.k +EXCEPT +SELECT t1.* FROM t1, t2 where t1.k = t2.k and t1.k != 'one' +-- !query 5 schema +struct +-- !query 5 output +one 1 +one NULL + + +-- !query 6 +SELECT * FROM t2 where v >= 1 and v <> 22 EXCEPT SELECT * FROM t1 +-- !query 6 schema +struct +-- !query 6 output +NULL 5 +one 5 + + +-- !query 7 +SELECT (SELECT min(k) FROM t2 WHERE t2.k = t1.k) min_t2 FROM t1 +MINUS +SELECT (SELECT min(k) FROM t2) abs_min_t2 FROM t1 WHERE t1.k = 'one' +-- !query 7 schema +struct +-- !query 7 output +NULL +two + + +-- !query 8 +SELECT t1.k +FROM t1 +WHERE t1.v <= (SELECT max(t2.v) + FROM t2 + WHERE t2.k = t1.k) +MINUS +SELECT t1.k +FROM t1 +WHERE t1.v >= (SELECT min(t2.v) + FROM t2 + WHERE t2.k = t1.k) +-- !query 8 schema +struct +-- !query 8 output +two