diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala index ab319c860ee69..fac834a70b893 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -33,7 +33,8 @@ private[deploy] object DependencyUtils { packagesExclusions: String, packages: String, repositories: String, - ivyRepoPath: String): String = { + ivyRepoPath: String, + ivySettingsPath: Option[String]): String = { val exclusions: Seq[String] = if (!StringUtils.isBlank(packagesExclusions)) { packagesExclusions.split(",") @@ -41,10 +42,12 @@ private[deploy] object DependencyUtils { Nil } // Create the IvySettings, either load from file or build defaults - val ivySettings = sys.props.get("spark.jars.ivySettings").map { ivySettingsFile => - SparkSubmitUtils.loadIvySettings(ivySettingsFile, Option(repositories), Option(ivyRepoPath)) - }.getOrElse { - SparkSubmitUtils.buildIvySettings(Option(repositories), Option(ivyRepoPath)) + val ivySettings = ivySettingsPath match { + case Some(path) => + SparkSubmitUtils.loadIvySettings(path, Option(repositories), Option(ivyRepoPath)) + + case None => + SparkSubmitUtils.buildIvySettings(Option(repositories), Option(ivyRepoPath)) } SparkSubmitUtils.resolveMavenCoordinates(packages, ivySettings, exclusions = exclusions) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 3965f17f4b56e..eddbedeb1024d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -359,7 +359,8 @@ object SparkSubmit extends CommandLineUtils with Logging { // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files // too for packages that include Python code val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies( - args.packagesExclusions, args.packages, args.repositories, args.ivyRepoPath) + args.packagesExclusions, args.packages, args.repositories, args.ivyRepoPath, + args.ivySettingsPath) if (!StringUtils.isBlank(resolvedMavenCoordinates)) { args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index e7796d4ddbe34..8e7070593687b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -63,6 +63,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var packages: String = null var repositories: String = null var ivyRepoPath: String = null + var ivySettingsPath: Option[String] = None var packagesExclusions: String = null var verbose: Boolean = false var isPython: Boolean = false @@ -184,6 +185,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull files = Option(files).orElse(sparkProperties.get("spark.files")).orNull ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull + ivySettingsPath = sparkProperties.get("spark.jars.ivySettings") packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull packagesExclusions = Option(packagesExclusions) .orElse(sparkProperties.get("spark.jars.excludes")).orNull diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index b19c9904d5982..3f71237164a15 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -79,12 +79,17 @@ object DriverWrapper extends Logging { val secMgr = new SecurityManager(sparkConf) val hadoopConf = SparkHadoopUtil.newConfiguration(sparkConf) - val Seq(packagesExclusions, packages, repositories, ivyRepoPath) = - Seq("spark.jars.excludes", "spark.jars.packages", "spark.jars.repositories", "spark.jars.ivy") - .map(sys.props.get(_).orNull) + val Seq(packagesExclusions, packages, repositories, ivyRepoPath, ivySettingsPath) = + Seq( + "spark.jars.excludes", + "spark.jars.packages", + "spark.jars.repositories", + "spark.jars.ivy", + "spark.jars.ivySettings" + ).map(sys.props.get(_).orNull) val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies(packagesExclusions, - packages, repositories, ivyRepoPath) + packages, repositories, ivyRepoPath, Option(ivySettingsPath)) val jars = { val jarsProp = sys.props.get("spark.jars").orNull if (!StringUtils.isBlank(resolvedMavenCoordinates)) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 56b95c31eb4c3..8e8f7d197c9ef 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -164,7 +164,8 @@ class BlockManagerMasterEndpoint( val futures = blockManagerInfo.values.map { bm => bm.slaveEndpoint.ask[Int](removeMsg).recover { case e: IOException => - logWarning(s"Error trying to remove RDD $rddId", e) + logWarning(s"Error trying to remove RDD $rddId from block manager ${bm.blockManagerId}", + e) 0 // zero blocks were removed } }.toSeq @@ -195,7 +196,8 @@ class BlockManagerMasterEndpoint( val futures = requiredBlockManagers.map { bm => bm.slaveEndpoint.ask[Int](removeMsg).recover { case e: IOException => - logWarning(s"Error trying to remove broadcast $broadcastId", e) + logWarning(s"Error trying to remove broadcast $broadcastId from block manager " + + s"${bm.blockManagerId}", e) 0 // zero blocks were removed } }.toSeq diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index d86ef907b4492..0d7c342a5eacd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -106,6 +106,9 @@ class SparkSubmitSuite // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x implicit val defaultSignaler: Signaler = ThreadSignaler + private val emptyIvySettings = File.createTempFile("ivy", ".xml") + FileUtils.write(emptyIvySettings, "", StandardCharsets.UTF_8) + override def beforeEach() { super.beforeEach() } @@ -520,6 +523,7 @@ class SparkSubmitSuite "--repositories", repo, "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}", unusedJar.toString, "my.great.lib.MyLib", "my.great.dep.MyLib") runSparkSubmit(args) @@ -530,7 +534,6 @@ class SparkSubmitSuite val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val dep = MavenCoordinate("my.great.dep", "mylib", "0.1") - // Test using "spark.jars.packages" and "spark.jars.repositories" configurations. IvyTestUtils.withRepository(main, Some(dep.toString), None) { repo => val args = Seq( "--class", JarCreationTest.getClass.getName.stripSuffix("$"), @@ -540,6 +543,7 @@ class SparkSubmitSuite "--conf", s"spark.jars.repositories=$repo", "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}", unusedJar.toString, "my.great.lib.MyLib", "my.great.dep.MyLib") runSparkSubmit(args) @@ -550,7 +554,6 @@ class SparkSubmitSuite // See https://gist.github.com/shivaram/3a2fecce60768a603dac for a error log ignore("correctly builds R packages included in a jar with --packages") { assume(RUtils.isRInstalled, "R isn't installed on this machine.") - // Check if the SparkR package is installed assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) @@ -563,6 +566,7 @@ class SparkSubmitSuite "--master", "local-cluster[2,1,1024]", "--packages", main.toString, "--repositories", repo, + "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}", "--verbose", "--conf", "spark.ui.enabled=false", rScriptDir) @@ -573,7 +577,6 @@ class SparkSubmitSuite test("include an external JAR in SparkR") { assume(RUtils.isRInstalled, "R isn't installed on this machine.") val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - // Check if the SparkR package is installed assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") val rScriptDir = Seq(sparkHome, "R", "pkg", "tests", "fulltests", "jarTest.R").mkString(File.separator) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 702bcf748fc74..aea07be34cb86 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -111,7 +111,7 @@ and the migration guide below will explain all changes between releases. * The class and trait hierarchy for logistic regression model summaries was changed to be cleaner and better accommodate the addition of the multi-class summary. This is a breaking change for user code that casts a `LogisticRegressionTrainingSummary` to a -` BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary` +`BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary` method. See [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139) for more detail (_note_ this is an `Experimental` API). This _does not_ affect the Python `summary` method, which will still work correctly for both multinomial and binary cases. diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 75aea70601875..8b89296b14cdd 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -278,8 +278,8 @@ for details on the API. multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) between the input vector, `v` and transforming vector, `scalingVec`, to yield a result vector. -Qu8T948*1# -Denoting the `scalingVec` as "`w`," this transformation may be written as: + +Denoting the `scalingVec` as "`w`", this transformation may be written as: `\[ \begin{pmatrix} v_1 \\ diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md index d3530908706d0..f567565437927 100644 --- a/docs/mllib-pmml-model-export.md +++ b/docs/mllib-pmml-model-export.md @@ -7,7 +7,7 @@ displayTitle: PMML model export - RDD-based API * Table of contents {:toc} -## `spark.mllib` supported models +## spark.mllib supported models `spark.mllib` supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). @@ -15,7 +15,7 @@ The table below outlines the `spark.mllib` models that can be exported to PMML a - + diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 975b28de47e20..9c4644947c911 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -549,14 +549,23 @@ specific to Spark on Kubernetes. + + + + @@ -593,4 +602,4 @@ specific to Spark on Kubernetes. spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. -
`spark.mllib` modelPMML model
spark.mllib modelPMML model
spark.kubernetes.driver.limit.cores (none) - Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod. + Specify a hard cpu [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod.
spark.kubernetes.executor.request.cores(none) + Specify the cpu request for each executor pod. Values conform to the Kubernetes [convention](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#meaning-of-cpu). + Example values include 0.1, 500m, 1.5, 5, etc., with the definition of cpu units documented in [CPU units](https://kubernetes.io/docs/tasks/configure-pod-container/assign-cpu-resource/#cpu-units). + This is distinct from spark.executor.cores: it is only used and takes precedence over spark.executor.cores for specifying the executor pod cpu request if set. Task + parallelism, e.g., number of tasks an executor can run concurrently is not affected by this. +
spark.kubernetes.executor.limit.cores (none) - Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application. + Specify a hard cpu [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application.
\ No newline at end of file + diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index e7e27876088f3..f26c134c2f6e9 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -27,13 +27,10 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.types.UTF8String /** * A [[ContinuousReader]] for data from kafka. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala index 1acdd56125741..f35a143e00374 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala @@ -20,18 +20,16 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.clients.consumer.ConsumerRecord import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.unsafe.types.UTF8String /** A simple class for converting Kafka ConsumerRecord to UnsafeRow */ private[kafka010] class KafkaRecordToUnsafeRowConverter { - private val sharedRow = new UnsafeRow(7) - private val bufferHolder = new BufferHolder(sharedRow) - private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + private val rowWriter = new UnsafeRowWriter(7) def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = { - bufferHolder.reset() + rowWriter.reset() if (record.key == null) { rowWriter.setNullAt(0) @@ -46,7 +44,6 @@ private[kafka010] class KafkaRecordToUnsafeRowConverter { 5, DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp))) rowWriter.write(6, record.timestampType.id) - sharedRow.setTotalSize(bufferHolder.totalSize) - sharedRow + rowWriter.getRow() } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 5413d3a416545..f8dc0ec7a0bf6 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -196,6 +196,14 @@ public void testAppHandleDisconnect() throws Exception { Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort()); client = new TestClient(s); client.send(new Hello(secret, "1.4.0")); + client.send(new SetAppId("someId")); + + // Wait until we know the server has received the messages and matched the handle to the + // connection before disconnecting. + eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> { + assertEquals("someId", handle.getAppId()); + }); + handle.disconnect(); waitForError(client, secret); } finally { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 1cdcdfcaeab78..67cdb097217a2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -234,7 +234,7 @@ class StringIndexerModel ( val metadata = NominalAttribute.defaultAttr .withName($(outputCol)).withValues(filteredLabels).toMetadata() // If we are skipping invalid records, filter them out. - val (filteredDataset, keepInvalid) = getHandleInvalid match { + val (filteredDataset, keepInvalid) = $(handleInvalid) match { case StringIndexer.SKIP_INVALID => val filterer = udf { label: String => labelToIndex.contains(label) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index b373ae921ed38..6bf4aa38b1fcb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -17,14 +17,17 @@ package org.apache.spark.ml.feature -import scala.collection.mutable.ArrayBuilder +import java.util.NoSuchElementException + +import scala.collection.mutable +import scala.language.existentials import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset, Row} @@ -33,10 +36,14 @@ import org.apache.spark.sql.types._ /** * A feature transformer that merges multiple columns into a vector column. + * + * This requires one pass over the entire dataset. In case we need to infer column lengths from the + * data we require an additional call to the 'first' Dataset method, see 'handleInvalid' parameter. */ @Since("1.4.0") class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable { + extends Transformer with HasInputCols with HasOutputCol with HasHandleInvalid + with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("vecAssembler")) @@ -49,32 +56,63 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + @Since("2.4.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + /** + * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with + * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the + * output). Column lengths are taken from the size of ML Attribute Group, which can be set using + * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred + * from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'. + * Default: "error" + * @group param + */ + @Since("2.4.0") + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + """Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with + |invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the + |output). Column lengths are taken from the size of ML Attribute Group, which can be set using + |`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred + |from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'. + |""".stripMargin.replaceAll("\n", " "), + ParamValidators.inArray(VectorAssembler.supportedHandleInvalids)) + + setDefault(handleInvalid, VectorAssembler.ERROR_INVALID) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) // Schema transformation. val schema = dataset.schema - lazy val first = dataset.toDF.first() - val attrs = $(inputCols).flatMap { c => + + val vectorCols = $(inputCols).filter { c => + schema(c).dataType match { + case _: VectorUDT => true + case _ => false + } + } + val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid)) + + val featureAttributesMap = $(inputCols).map { c => val field = schema(c) - val index = schema.fieldIndex(c) field.dataType match { case DoubleType => - val attr = Attribute.fromStructField(field) - // If the input column doesn't have ML attribute, assume numeric. - if (attr == UnresolvedAttribute) { - Some(NumericAttribute.defaultAttr.withName(c)) - } else { - Some(attr.withName(c)) + val attribute = Attribute.fromStructField(field) + attribute match { + case UnresolvedAttribute => + Seq(NumericAttribute.defaultAttr.withName(c)) + case _ => + Seq(attribute.withName(c)) } case _: NumericType | BooleanType => // If the input column type is a compatible scalar type, assume numeric. - Some(NumericAttribute.defaultAttr.withName(c)) + Seq(NumericAttribute.defaultAttr.withName(c)) case _: VectorUDT => - val group = AttributeGroup.fromStructField(field) - if (group.attributes.isDefined) { - // If attributes are defined, copy them with updated names. - group.attributes.get.zipWithIndex.map { case (attr, i) => + val attributeGroup = AttributeGroup.fromStructField(field) + if (attributeGroup.attributes.isDefined) { + attributeGroup.attributes.get.zipWithIndex.toSeq.map { case (attr, i) => if (attr.name.isDefined) { // TODO: Define a rigorous naming scheme. attr.withName(c + "_" + attr.name.get) @@ -85,18 +123,25 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) } else { // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes // from metadata, check the first row. - val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) - Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i)) + (0 until vectorColsLengths(c)).map { i => + NumericAttribute.defaultAttr.withName(c + "_" + i) + } } case otherType => throw new SparkException(s"VectorAssembler does not support the $otherType type") } } - val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() - + val featureAttributes = featureAttributesMap.flatten[Attribute].toArray + val lengths = featureAttributesMap.map(a => a.length).toArray + val metadata = new AttributeGroup($(outputCol), featureAttributes).toMetadata() + val (filteredDataset, keepInvalid) = $(handleInvalid) match { + case VectorAssembler.SKIP_INVALID => (dataset.na.drop($(inputCols)), false) + case VectorAssembler.KEEP_INVALID => (dataset, true) + case VectorAssembler.ERROR_INVALID => (dataset, false) + } // Data transformation. val assembleFunc = udf { r: Row => - VectorAssembler.assemble(r.toSeq: _*) + VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*) }.asNondeterministic() val args = $(inputCols).map { c => schema(c).dataType match { @@ -106,7 +151,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) } } - dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) + filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) } @Since("1.4.0") @@ -136,34 +181,117 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("1.6.0") object VectorAssembler extends DefaultParamsReadable[VectorAssembler] { + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val supportedHandleInvalids: Array[String] = + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) + + /** + * Infers lengths of vector columns from the first row of the dataset + * @param dataset the dataset + * @param columns name of vector columns whose lengths need to be inferred + * @return map of column names to lengths + */ + private[feature] def getVectorLengthsFromFirstRow( + dataset: Dataset[_], + columns: Seq[String]): Map[String, Int] = { + try { + val first_row = dataset.toDF().select(columns.map(col): _*).first() + columns.zip(first_row.toSeq).map { + case (c, x) => c -> x.asInstanceOf[Vector].size + }.toMap + } catch { + case e: NullPointerException => throw new NullPointerException( + s"""Encountered null value while inferring lengths from the first row. Consider using + |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """ + .stripMargin.replaceAll("\n", " ") + e.toString) + case e: NoSuchElementException => throw new NoSuchElementException( + s"""Encountered empty dataframe while inferring lengths from the first row. Consider using + |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """ + .stripMargin.replaceAll("\n", " ") + e.toString) + } + } + + private[feature] def getLengths( + dataset: Dataset[_], + columns: Seq[String], + handleInvalid: String): Map[String, Int] = { + val groupSizes = columns.map { c => + c -> AttributeGroup.fromStructField(dataset.schema(c)).size + }.toMap + val missingColumns = groupSizes.filter(_._2 == -1).keys.toSeq + val firstSizes = (missingColumns.nonEmpty, handleInvalid) match { + case (true, VectorAssembler.ERROR_INVALID) => + getVectorLengthsFromFirstRow(dataset, missingColumns) + case (true, VectorAssembler.SKIP_INVALID) => + getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns) + case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException( + s"""Can not infer column lengths with handleInvalid = "keep". Consider using VectorSizeHint + |to add metadata for columns: ${columns.mkString("[", ", ", "]")}.""" + .stripMargin.replaceAll("\n", " ")) + case (_, _) => Map.empty + } + groupSizes ++ firstSizes + } + + @Since("1.6.0") override def load(path: String): VectorAssembler = super.load(path) - private[feature] def assemble(vv: Any*): Vector = { - val indices = ArrayBuilder.make[Int] - val values = ArrayBuilder.make[Double] - var cur = 0 + /** + * Returns a function that has the required information to assemble each row. + * @param lengths an array of lengths of input columns, whose size should be equal to the number + * of cells in the row (vv) + * @param keepInvalid indicate whether to throw an error or not on seeing a null in the rows + * @return a udf that can be applied on each row + */ + private[feature] def assemble(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): Vector = { + val indices = mutable.ArrayBuilder.make[Int] + val values = mutable.ArrayBuilder.make[Double] + var featureIndex = 0 + + var inputColumnIndex = 0 vv.foreach { case v: Double => - if (v != 0.0) { - indices += cur + if (v.isNaN && !keepInvalid) { + throw new SparkException( + s"""Encountered NaN while assembling a row with handleInvalid = "error". Consider + |removing NaNs from dataset or using handleInvalid = "keep" or "skip".""" + .stripMargin) + } else if (v != 0.0) { + indices += featureIndex values += v } - cur += 1 + inputColumnIndex += 1 + featureIndex += 1 case vec: Vector => vec.foreachActive { case (i, v) => if (v != 0.0) { - indices += cur + i + indices += featureIndex + i values += v } } - cur += vec.size + inputColumnIndex += 1 + featureIndex += vec.size case null => - // TODO: output Double.NaN? - throw new SparkException("Values to assemble cannot be null.") + if (keepInvalid) { + val length: Int = lengths(inputColumnIndex) + Array.range(0, length).foreach { i => + indices += featureIndex + i + values += Double.NaN + } + inputColumnIndex += 1 + featureIndex += length + } else { + throw new SparkException( + s"""Encountered null while assembling a row with handleInvalid = "keep". Consider + |removing nulls from dataset or using handleInvalid = "keep" or "skip".""" + .stripMargin) + } case o => throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") } - Vectors.sparse(cur, indices.result(), values.result()).compressed + Vectors.sparse(featureIndex, indices.result(), values.result()).compressed } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index eca065f7e775d..91fb24a268b8c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.ml.feature import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.{col, udf} class VectorAssemblerSuite @@ -31,30 +31,49 @@ class VectorAssemblerSuite import testImplicits._ + @transient var dfWithNullsAndNaNs: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val sv = Vectors.sparse(2, Array(1), Array(3.0)) + dfWithNullsAndNaNs = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, Long, String)]( + (1, 2, 0.0, Vectors.dense(1.0, 2.0), "a", sv, 7L, null), + (2, 1, 0.0, null, "a", sv, 6L, null), + (3, 3, null, Vectors.dense(1.0, 2.0), "a", sv, 8L, null), + (4, 4, null, null, "a", sv, 9L, null), + (5, 5, java.lang.Double.NaN, Vectors.dense(1.0, 2.0), "a", sv, 7L, null), + (6, 6, java.lang.Double.NaN, null, "a", sv, 8L, null)) + .toDF("id1", "id2", "x", "y", "name", "z", "n", "nulls") + } + test("params") { ParamsSuite.checkParams(new VectorAssembler) } test("assemble") { import org.apache.spark.ml.feature.VectorAssembler.assemble - assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty)) - assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0))) + assert(assemble(Array(1), keepInvalid = true)(0.0) + === Vectors.sparse(1, Array.empty, Array.empty)) + assert(assemble(Array(1, 1), keepInvalid = true)(0.0, 1.0) + === Vectors.sparse(2, Array(1), Array(1.0))) val dv = Vectors.dense(2.0, 0.0) - assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0))) + assert(assemble(Array(1, 2, 1), keepInvalid = true)(0.0, dv, 1.0) === + Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0))) val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0)) - assert(assemble(0.0, dv, 1.0, sv) === + assert(assemble(Array(1, 2, 1, 2), keepInvalid = true)(0.0, dv, 1.0, sv) === Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0))) - for (v <- Seq(1, "a", null)) { - intercept[SparkException](assemble(v)) - intercept[SparkException](assemble(1.0, v)) + for (v <- Seq(1, "a")) { + intercept[SparkException](assemble(Array(1), keepInvalid = true)(v)) + intercept[SparkException](assemble(Array(1, 1), keepInvalid = true)(1.0, v)) } } test("assemble should compress vectors") { import org.apache.spark.ml.feature.VectorAssembler.assemble - val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0)) + val v1 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(0.0, 0.0, 0.0, Vectors.dense(4.0)) assert(v1.isInstanceOf[SparseVector]) - val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0))) + val sv = Vectors.sparse(1, Array(0), Array(4.0)) + val v2 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(1.0, 2.0, 3.0, sv) assert(v2.isInstanceOf[DenseVector]) } @@ -147,4 +166,94 @@ class VectorAssemblerSuite .filter(vectorUDF($"features") > 1) .count() == 1) } + + test("assemble should keep nulls when keepInvalid is true") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + assert(assemble(Array(1, 1), keepInvalid = true)(1.0, null) === Vectors.dense(1.0, Double.NaN)) + assert(assemble(Array(1, 2), keepInvalid = true)(1.0, null) + === Vectors.dense(1.0, Double.NaN, Double.NaN)) + assert(assemble(Array(1), keepInvalid = true)(null) === Vectors.dense(Double.NaN)) + assert(assemble(Array(2), keepInvalid = true)(null) === Vectors.dense(Double.NaN, Double.NaN)) + } + + test("assemble should throw errors when keepInvalid is false") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + intercept[SparkException](assemble(Array(1, 1), keepInvalid = false)(1.0, null)) + intercept[SparkException](assemble(Array(1, 2), keepInvalid = false)(1.0, null)) + intercept[SparkException](assemble(Array(1), keepInvalid = false)(null)) + intercept[SparkException](assemble(Array(2), keepInvalid = false)(null)) + } + + test("get lengths functions") { + import org.apache.spark.ml.feature.VectorAssembler._ + val df = dfWithNullsAndNaNs + assert(getVectorLengthsFromFirstRow(df, Seq("y")) === Map("y" -> 2)) + assert(intercept[NullPointerException](getVectorLengthsFromFirstRow(df.sort("id2"), Seq("y"))) + .getMessage.contains("VectorSizeHint")) + assert(intercept[NoSuchElementException](getVectorLengthsFromFirstRow(df.filter("id1 > 6"), + Seq("y"))).getMessage.contains("VectorSizeHint")) + + assert(getLengths(df.sort("id2"), Seq("y"), SKIP_INVALID).exists(_ == "y" -> 2)) + assert(intercept[NullPointerException](getLengths(df.sort("id2"), Seq("y"), ERROR_INVALID)) + .getMessage.contains("VectorSizeHint")) + assert(intercept[RuntimeException](getLengths(df.sort("id2"), Seq("y"), KEEP_INVALID)) + .getMessage.contains("VectorSizeHint")) + } + + test("Handle Invalid should behave properly") { + val assembler = new VectorAssembler() + .setInputCols(Array("x", "y", "z", "n")) + .setOutputCol("features") + + def runWithMetadata(mode: String, additional_filter: String = "true"): Dataset[_] = { + val attributeY = new AttributeGroup("y", 2) + val attributeZ = new AttributeGroup( + "z", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) + val dfWithMetadata = dfWithNullsAndNaNs.withColumn("y", col("y"), attributeY.toMetadata()) + .withColumn("z", col("z"), attributeZ.toMetadata()).filter(additional_filter) + val output = assembler.setHandleInvalid(mode).transform(dfWithMetadata) + output.collect() + output + } + + def runWithFirstRow(mode: String): Dataset[_] = { + val output = assembler.setHandleInvalid(mode).transform(dfWithNullsAndNaNs) + output.collect() + output + } + + def runWithAllNullVectors(mode: String): Dataset[_] = { + val output = assembler.setHandleInvalid(mode) + .transform(dfWithNullsAndNaNs.filter("0 == id1 % 2")) + output.collect() + output + } + + // behavior when vector size hint is given + assert(runWithMetadata("keep").count() == 6, "should keep all rows") + assert(runWithMetadata("skip").count() == 1, "should skip rows with nulls") + // should throw error with nulls + intercept[SparkException](runWithMetadata("error")) + // should throw error with NaNs + intercept[SparkException](runWithMetadata("error", additional_filter = "id1 > 4")) + + // behavior when first row has information + assert(intercept[RuntimeException](runWithFirstRow("keep").count()) + .getMessage.contains("VectorSizeHint"), "should suggest to use metadata") + assert(runWithFirstRow("skip").count() == 1, "should infer size and skip rows with nulls") + intercept[SparkException](runWithFirstRow("error")) + + // behavior when vector column is all null + assert(intercept[RuntimeException](runWithAllNullVectors("skip")) + .getMessage.contains("VectorSizeHint"), "should suggest to use metadata") + assert(intercept[NullPointerException](runWithAllNullVectors("error")) + .getMessage.contains("VectorSizeHint"), "should suggest to use metadata") + + // behavior when scalar column is all null + assert(runWithMetadata("keep", additional_filter = "id1 > 2").count() == 4) + } + } diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 78c071da1b7a0..4c0787f767c0b 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -2094,6 +2094,11 @@ def test_java_params(self): # NOTE: disable check_params_exist until there is parity with Scala API ParamTests.check_params(self, cls(), check_params_exist=False) + # Additional classes that need explicit construction + from pyspark.ml.feature import CountVectorizerModel + ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'), + check_params_exist=False) + def _squared_distance(a, b): if isinstance(a, Vector): diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index d9332765f5f9d..5b9af952c7ca4 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -91,6 +91,12 @@ private[spark] object Config extends Logging { .stringConf .createOptional + val KUBERNETES_EXECUTOR_REQUEST_CORES = + ConfigBuilder("spark.kubernetes.executor.request.cores") + .doc("Specify the cpu request for each executor pod") + .stringConf + .createOptional + val KUBERNETES_DRIVER_POD_NAME = ConfigBuilder("spark.kubernetes.driver.pod.name") .doc("Name of the driver pod.") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index 347c4d2d66826..b811db324108c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -93,9 +93,6 @@ private[spark] class BasicDriverConfigurationStep( .withAmount(driverCpuCores) .build() val driverMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${driverMemoryMiB}Mi") - .build() - val driverMemoryLimitQuantity = new QuantityBuilder(false) .withAmount(s"${driverMemoryWithOverheadMiB}Mi") .build() val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => @@ -117,7 +114,7 @@ private[spark] class BasicDriverConfigurationStep( .withNewResources() .addToRequests("cpu", driverCpuQuantity) .addToRequests("memory", driverMemoryQuantity) - .addToLimits("memory", driverMemoryLimitQuantity) + .addToLimits("memory", driverMemoryQuantity) .addToLimits(maybeCpuLimitQuantity.toMap.asJava) .endResources() .addToArgs("driver") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 9d09dbda0c990..d9f20e18cefe1 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -85,7 +85,12 @@ private[spark] class ExecutorPodFactory( MEMORY_OVERHEAD_MIN_MIB)) private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB - private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) + private val executorCores = sparkConf.getInt("spark.executor.cores", 1) + private val executorCoresRequest = if (sparkConf.contains(KUBERNETES_EXECUTOR_REQUEST_CORES)) { + sparkConf.get(KUBERNETES_EXECUTOR_REQUEST_CORES).get + } else { + executorCores.toString + } private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) /** @@ -110,13 +115,10 @@ private[spark] class ExecutorPodFactory( SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ executorLabels val executorMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${executorMemoryMiB}Mi") - .build() - val executorMemoryLimitQuantity = new QuantityBuilder(false) .withAmount(s"${executorMemoryWithOverhead}Mi") .build() val executorCpuQuantity = new QuantityBuilder(false) - .withAmount(executorCores.toString) + .withAmount(executorCoresRequest) .build() val executorExtraClasspathEnv = executorExtraClasspath.map { cp => new EnvVarBuilder() @@ -135,8 +137,7 @@ private[spark] class ExecutorPodFactory( }.getOrElse(Seq.empty[EnvVar]) val executorEnv = (Seq( (ENV_DRIVER_URL, driverUrl), - // Executor backend expects integral value for executor cores, so round it up to an int. - (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString), + (ENV_EXECUTOR_CORES, executorCores.toString), (ENV_EXECUTOR_MEMORY, executorMemoryString), (ENV_APPLICATION_ID, applicationId), // This is to set the SPARK_CONF_DIR to be /opt/spark/conf @@ -169,7 +170,7 @@ private[spark] class ExecutorPodFactory( .withImagePullPolicy(imagePullPolicy) .withNewResources() .addToRequests("memory", executorMemoryQuantity) - .addToLimits("memory", executorMemoryLimitQuantity) + .addToLimits("memory", executorMemoryQuantity) .addToRequests("cpu", executorCpuQuantity) .endResources() .addAllToEnv(executorEnv.asJava) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index ce068531c7673..e59c6d28a8cc2 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -91,7 +91,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { val resourceRequirements = preparedDriverSpec.driverContainer.getResources val requests = resourceRequirements.getRequests.asScala assert(requests("cpu").getAmount === "2") - assert(requests("memory").getAmount === "256Mi") + assert(requests("memory").getAmount === "456Mi") val limits = resourceRequirements.getLimits.asScala assert(limits("memory").getAmount === "456Mi") assert(limits("cpu").getAmount === "4") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 50b97d923ca2d..e8996216cefa7 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -67,12 +67,14 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef assert(executor.getMetadata.getLabels.size() === 3) assert(executor.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) === "1") - // There is exactly 1 container with no volume mounts and default memory limits. - // Default memory limit is 1024M + 384M (minimum overhead constant). + // There is exactly 1 container with no volume mounts and default memory limits and requests. + // Default memory limit/request is 1024M + 384M (minimum overhead constant). assert(executor.getSpec.getContainers.size() === 1) assert(executor.getSpec.getContainers.get(0).getImage === executorImage) assert(executor.getSpec.getContainers.get(0).getVolumeMounts.isEmpty) assert(executor.getSpec.getContainers.get(0).getResources.getLimits.size() === 1) + assert(executor.getSpec.getContainers.get(0).getResources + .getRequests.get("memory").getAmount === "1408Mi") assert(executor.getSpec.getContainers.get(0).getResources .getLimits.get("memory").getAmount === "1408Mi") @@ -84,6 +86,33 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef checkOwnerReferences(executor, driverPodUid) } + test("executor core request specification") { + var factory = new ExecutorPodFactory(baseConf, None, None) + var executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getContainers.size() === 1) + assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount + === "1") + + val conf = baseConf.clone() + + conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "0.1") + factory = new ExecutorPodFactory(conf, None, None) + executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getContainers.size() === 1) + assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount + === "0.1") + + conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") + factory = new ExecutorPodFactory(conf, None, None) + conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") + executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount + === "100m") + } + test("executor pod hostnames get truncated to 63 characters") { val conf = baseConf.clone() conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 259976118c12f..537ef244b7e81 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -30,25 +30,21 @@ * this class per writing program, so that the memory segment/data buffer can be reused. Note that * for each incoming record, we should call `reset` of BufferHolder instance before write the record * and reuse the data buffer. - * - * Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update - * the size of the result row, after writing a record to the buffer. However, we can skip this step - * if the fields of row are all fixed-length, as the size of result row is also fixed. */ -public class BufferHolder { +final class BufferHolder { private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; - public byte[] buffer; - public int cursor = Platform.BYTE_ARRAY_OFFSET; + private byte[] buffer; + private int cursor = Platform.BYTE_ARRAY_OFFSET; private final UnsafeRow row; private final int fixedSize; - public BufferHolder(UnsafeRow row) { + BufferHolder(UnsafeRow row) { this(row, 64); } - public BufferHolder(UnsafeRow row, int initialSize) { + BufferHolder(UnsafeRow row, int initialSize) { int bitsetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()); if (row.numFields() > (ARRAY_MAX - initialSize - bitsetWidthInBytes) / 8) { throw new UnsupportedOperationException( @@ -64,7 +60,7 @@ public BufferHolder(UnsafeRow row, int initialSize) { /** * Grows the buffer by at least neededSize and points the row to the buffer. */ - public void grow(int neededSize) { + void grow(int neededSize) { if (neededSize > ARRAY_MAX - totalSize()) { throw new UnsupportedOperationException( "Cannot grow BufferHolder by size " + neededSize + " because the size after growing " + @@ -86,11 +82,23 @@ public void grow(int neededSize) { } } - public void reset() { + byte[] getBuffer() { + return buffer; + } + + int getCursor() { + return cursor; + } + + void increaseCursor(int val) { + cursor += val; + } + + void reset() { cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize; } - public int totalSize() { + int totalSize() { return cursor - Platform.BYTE_ARRAY_OFFSET; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 82cd1b24607e1..a78dd970d23e4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -21,8 +21,6 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes; @@ -32,14 +30,12 @@ */ public final class UnsafeArrayWriter extends UnsafeWriter { - private BufferHolder holder; - - // The offset of the global buffer where we start to write this array. - private int startingOffset; - // The number of elements in this array private int numElements; + // The element size in this array + private int elementSize; + private int headerInBytes; private void assertIndexIsValid(int index) { @@ -47,13 +43,17 @@ private void assertIndexIsValid(int index) { assert index < numElements : "index (" + index + ") should < " + numElements; } - public void initialize(BufferHolder holder, int numElements, int elementSize) { + public UnsafeArrayWriter(UnsafeWriter writer, int elementSize) { + super(writer.getBufferHolder()); + this.elementSize = elementSize; + } + + public void initialize(int numElements) { // We need 8 bytes to store numElements in header this.numElements = numElements; this.headerInBytes = calculateHeaderPortionInBytes(numElements); - this.holder = holder; - this.startingOffset = holder.cursor; + this.startingOffset = cursor(); // Grows the global buffer ahead for header and fixed size data. int fixedPartInBytes = @@ -61,112 +61,92 @@ public void initialize(BufferHolder holder, int numElements, int elementSize) { holder.grow(headerInBytes + fixedPartInBytes); // Write numElements and clear out null bits to header - Platform.putLong(holder.buffer, startingOffset, numElements); + Platform.putLong(getBuffer(), startingOffset, numElements); for (int i = 8; i < headerInBytes; i += 8) { - Platform.putLong(holder.buffer, startingOffset + i, 0L); + Platform.putLong(getBuffer(), startingOffset + i, 0L); } // fill 0 into reminder part of 8-bytes alignment in unsafe array for (int i = elementSize * numElements; i < fixedPartInBytes; i++) { - Platform.putByte(holder.buffer, startingOffset + headerInBytes + i, (byte) 0); + Platform.putByte(getBuffer(), startingOffset + headerInBytes + i, (byte) 0); } - holder.cursor += (headerInBytes + fixedPartInBytes); + increaseCursor(headerInBytes + fixedPartInBytes); } - private void zeroOutPaddingBytes(int numBytes) { - if ((numBytes & 0x07) > 0) { - Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); - } - } - - private long getElementOffset(int ordinal, int elementSize) { + private long getElementOffset(int ordinal) { return startingOffset + headerInBytes + ordinal * elementSize; } - public void setOffsetAndSize(int ordinal, int currentCursor, int size) { - assertIndexIsValid(ordinal); - final long relativeOffset = currentCursor - startingOffset; - final long offsetAndSize = (relativeOffset << 32) | (long)size; - - write(ordinal, offsetAndSize); - } - private void setNullBit(int ordinal) { assertIndexIsValid(ordinal); - BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal); + BitSetMethods.set(getBuffer(), startingOffset + 8, ordinal); } public void setNull1Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0); + writeByte(getElementOffset(ordinal), (byte)0); } public void setNull2Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0); + writeShort(getElementOffset(ordinal), (short)0); } public void setNull4Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), 0); + writeInt(getElementOffset(ordinal), 0); } public void setNull8Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0); + writeLong(getElementOffset(ordinal), 0); } public void setNull(int ordinal) { setNull8Bytes(ordinal); } public void write(int ordinal, boolean value) { assertIndexIsValid(ordinal); - Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), value); + writeBoolean(getElementOffset(ordinal), value); } public void write(int ordinal, byte value) { assertIndexIsValid(ordinal); - Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), value); + writeByte(getElementOffset(ordinal), value); } public void write(int ordinal, short value) { assertIndexIsValid(ordinal); - Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), value); + writeShort(getElementOffset(ordinal), value); } public void write(int ordinal, int value) { assertIndexIsValid(ordinal); - Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), value); + writeInt(getElementOffset(ordinal), value); } public void write(int ordinal, long value) { assertIndexIsValid(ordinal); - Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), value); + writeLong(getElementOffset(ordinal), value); } public void write(int ordinal, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } assertIndexIsValid(ordinal); - Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), value); + writeFloat(getElementOffset(ordinal), value); } public void write(int ordinal, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } assertIndexIsValid(ordinal); - Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), value); + writeDouble(getElementOffset(ordinal), value); } public void write(int ordinal, Decimal input, int precision, int scale) { // make sure Decimal object has the same scale as DecimalType assertIndexIsValid(ordinal); - if (input.changePrecision(precision, scale)) { + if (input != null && input.changePrecision(precision, scale)) { if (precision <= Decimal.MAX_LONG_DIGITS()) { write(ordinal, input.toUnscaledLong()); } else { @@ -180,65 +160,14 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // Write the bytes to the variable length portion. Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes); - setOffsetAndSize(ordinal, holder.cursor, numBytes); + bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes); + setOffsetAndSize(ordinal, numBytes); // move the cursor forward with 8-bytes boundary - holder.cursor += roundedSize; + increaseCursor(roundedSize); } } else { setNull(ordinal); } } - - public void write(int ordinal, UTF8String input) { - final int numBytes = input.numBytes(); - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - input.writeToMemory(holder.buffer, holder.cursor); - - setOffsetAndSize(ordinal, holder.cursor, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, byte[] input) { - final int numBytes = input.length; - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - Platform.copyMemory( - input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes); - - setOffsetAndSize(ordinal, holder.cursor, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, CalendarInterval input) { - // grow the global buffer before writing data. - holder.grow(16); - - // Write the months and microseconds fields of Interval to the variable length portion. - Platform.putLong(holder.buffer, holder.cursor, input.months); - Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds); - - setOffsetAndSize(ordinal, holder.cursor, 16); - - // move the cursor forward. - holder.cursor += 16; - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 2620bbcfb87a2..71c49d8ed0177 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -20,10 +20,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; /** * A helper class to write data into global row buffer using `UnsafeRow` format. @@ -31,7 +28,7 @@ * It will remember the offset of row buffer which it starts to write, and move the cursor of row * buffer while writing. If new data(can be the input record if this is the outermost writer, or * nested struct if this is an inner writer) comes, the starting cursor of row buffer may be - * changed, so we need to call `UnsafeRowWriter.reset` before writing, to update the + * changed, so we need to call `UnsafeRowWriter.resetRowWriter` before writing, to update the * `startingOffset` and clear out null bits. * * Note that if this is the outermost writer, which means we will always write from the very @@ -40,29 +37,58 @@ */ public final class UnsafeRowWriter extends UnsafeWriter { - private final BufferHolder holder; - // The offset of the global buffer where we start to write this row. - private int startingOffset; + private final UnsafeRow row; + private final int nullBitsSize; private final int fixedSize; - public UnsafeRowWriter(BufferHolder holder, int numFields) { - this.holder = holder; + public UnsafeRowWriter(int numFields) { + this(new UnsafeRow(numFields)); + } + + public UnsafeRowWriter(int numFields, int initialBufferSize) { + this(new UnsafeRow(numFields), initialBufferSize); + } + + public UnsafeRowWriter(UnsafeWriter writer, int numFields) { + this(null, writer.getBufferHolder(), numFields); + } + + private UnsafeRowWriter(UnsafeRow row) { + this(row, new BufferHolder(row), row.numFields()); + } + + private UnsafeRowWriter(UnsafeRow row, int initialBufferSize) { + this(row, new BufferHolder(row, initialBufferSize), row.numFields()); + } + + private UnsafeRowWriter(UnsafeRow row, BufferHolder holder, int numFields) { + super(holder); + this.row = row; this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields); this.fixedSize = nullBitsSize + 8 * numFields; - this.startingOffset = holder.cursor; + this.startingOffset = cursor(); + } + + /** + * Updates total size of the UnsafeRow using the size collected by BufferHolder, and returns + * the UnsafeRow created at a constructor + */ + public UnsafeRow getRow() { + row.setTotalSize(totalSize()); + return row; } /** * Resets the `startingOffset` according to the current cursor of row buffer, and clear out null * bits. This should be called before we write a new nested struct to the row buffer. */ - public void reset() { - this.startingOffset = holder.cursor; + public void resetRowWriter() { + this.startingOffset = cursor(); // grow the global buffer to make sure it has enough space to write fixed-length data. - holder.grow(fixedSize); - holder.cursor += fixedSize; + grow(fixedSize); + increaseCursor(fixedSize); zeroOutNullBytes(); } @@ -72,25 +98,17 @@ public void reset() { */ public void zeroOutNullBytes() { for (int i = 0; i < nullBitsSize; i += 8) { - Platform.putLong(holder.buffer, startingOffset + i, 0L); - } - } - - private void zeroOutPaddingBytes(int numBytes) { - if ((numBytes & 0x07) > 0) { - Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); + Platform.putLong(getBuffer(), startingOffset + i, 0L); } } - public BufferHolder holder() { return holder; } - public boolean isNullAt(int ordinal) { - return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal); + return BitSetMethods.isSet(getBuffer(), startingOffset, ordinal); } public void setNullAt(int ordinal) { - BitSetMethods.set(holder.buffer, startingOffset, ordinal); - Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L); + BitSetMethods.set(getBuffer(), startingOffset, ordinal); + write(ordinal, 0L); } @Override @@ -117,67 +135,49 @@ public long getFieldOffset(int ordinal) { return startingOffset + nullBitsSize + 8 * ordinal; } - public void setOffsetAndSize(int ordinal, int size) { - setOffsetAndSize(ordinal, holder.cursor, size); - } - - public void setOffsetAndSize(int ordinal, int currentCursor, int size) { - final long relativeOffset = currentCursor - startingOffset; - final long fieldOffset = getFieldOffset(ordinal); - final long offsetAndSize = (relativeOffset << 32) | (long) size; - - Platform.putLong(holder.buffer, fieldOffset, offsetAndSize); - } - public void write(int ordinal, boolean value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putBoolean(holder.buffer, offset, value); + writeLong(offset, 0L); + writeBoolean(offset, value); } public void write(int ordinal, byte value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putByte(holder.buffer, offset, value); + writeLong(offset, 0L); + writeByte(offset, value); } public void write(int ordinal, short value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putShort(holder.buffer, offset, value); + writeLong(offset, 0L); + writeShort(offset, value); } public void write(int ordinal, int value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putInt(holder.buffer, offset, value); + writeLong(offset, 0L); + writeInt(offset, value); } public void write(int ordinal, long value) { - Platform.putLong(holder.buffer, getFieldOffset(ordinal), value); + writeLong(getFieldOffset(ordinal), value); } public void write(int ordinal, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putFloat(holder.buffer, offset, value); + writeLong(offset, 0); + writeFloat(offset, value); } public void write(int ordinal, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } - Platform.putDouble(holder.buffer, getFieldOffset(ordinal), value); + writeDouble(getFieldOffset(ordinal), value); } public void write(int ordinal, Decimal input, int precision, int scale) { if (precision <= Decimal.MAX_LONG_DIGITS()) { // make sure Decimal object has the same scale as DecimalType - if (input.changePrecision(precision, scale)) { - Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong()); + if (input != null && input.changePrecision(precision, scale)) { + write(ordinal, input.toUnscaledLong()); } else { setNullAt(ordinal); } @@ -185,82 +185,31 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // grow the global buffer before writing data. holder.grow(16); - // zero-out the bytes - Platform.putLong(holder.buffer, holder.cursor, 0L); - Platform.putLong(holder.buffer, holder.cursor + 8, 0L); - // Make sure Decimal object has the same scale as DecimalType. // Note that we may pass in null Decimal object to set null for it. if (input == null || !input.changePrecision(precision, scale)) { - BitSetMethods.set(holder.buffer, startingOffset, ordinal); + // zero-out the bytes + Platform.putLong(getBuffer(), cursor(), 0L); + Platform.putLong(getBuffer(), cursor() + 8, 0L); + + BitSetMethods.set(getBuffer(), startingOffset, ordinal); // keep the offset for future update setOffsetAndSize(ordinal, 0); } else { final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); - assert bytes.length <= 16; + final int numBytes = bytes.length; + assert numBytes <= 16; + + zeroOutPaddingBytes(numBytes); // Write the bytes to the variable length portion. Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length); + bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes); setOffsetAndSize(ordinal, bytes.length); } // move the cursor forward. - holder.cursor += 16; + increaseCursor(16); } } - - public void write(int ordinal, UTF8String input) { - final int numBytes = input.numBytes(); - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - input.writeToMemory(holder.buffer, holder.cursor); - - setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, byte[] input) { - write(ordinal, input, 0, input.length); - } - - public void write(int ordinal, byte[] input, int offset, int numBytes) { - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET + offset, - holder.buffer, holder.cursor, numBytes); - - setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, CalendarInterval input) { - // grow the global buffer before writing data. - holder.grow(16); - - // Write the months and microseconds fields of Interval to the variable length portion. - Platform.putLong(holder.buffer, holder.cursor, input.months); - Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds); - - setOffsetAndSize(ordinal, 16); - - // move the cursor forward. - holder.cursor += 16; - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index c94b5c7a367ef..de0eb6dbb76be 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen; import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -24,10 +26,73 @@ * Base class for writing Unsafe* structures. */ public abstract class UnsafeWriter { + // Keep internal buffer holder + protected final BufferHolder holder; + + // The offset of the global buffer where we start to write this structure. + protected int startingOffset; + + protected UnsafeWriter(BufferHolder holder) { + this.holder = holder; + } + + /** + * Accessor methods are delegated from BufferHolder class + */ + public final BufferHolder getBufferHolder() { + return holder; + } + + public final byte[] getBuffer() { + return holder.getBuffer(); + } + + public final void reset() { + holder.reset(); + } + + public final int totalSize() { + return holder.totalSize(); + } + + public final void grow(int neededSize) { + holder.grow(neededSize); + } + + public final int cursor() { + return holder.getCursor(); + } + + public final void increaseCursor(int val) { + holder.increaseCursor(val); + } + + public final void setOffsetAndSizeFromPreviousCursor(int ordinal, int previousCursor) { + setOffsetAndSize(ordinal, previousCursor, cursor() - previousCursor); + } + + protected void setOffsetAndSize(int ordinal, int size) { + setOffsetAndSize(ordinal, cursor(), size); + } + + protected void setOffsetAndSize(int ordinal, int currentCursor, int size) { + final long relativeOffset = currentCursor - startingOffset; + final long offsetAndSize = (relativeOffset << 32) | (long)size; + + write(ordinal, offsetAndSize); + } + + protected final void zeroOutPaddingBytes(int numBytes) { + if ((numBytes & 0x07) > 0) { + Platform.putLong(getBuffer(), cursor() + ((numBytes >> 3) << 3), 0L); + } + } + public abstract void setNull1Bytes(int ordinal); public abstract void setNull2Bytes(int ordinal); public abstract void setNull4Bytes(int ordinal); public abstract void setNull8Bytes(int ordinal); + public abstract void write(int ordinal, boolean value); public abstract void write(int ordinal, byte value); public abstract void write(int ordinal, short value); @@ -36,8 +101,92 @@ public abstract class UnsafeWriter { public abstract void write(int ordinal, float value); public abstract void write(int ordinal, double value); public abstract void write(int ordinal, Decimal input, int precision, int scale); - public abstract void write(int ordinal, UTF8String input); - public abstract void write(int ordinal, byte[] input); - public abstract void write(int ordinal, CalendarInterval input); - public abstract void setOffsetAndSize(int ordinal, int currentCursor, int size); + + public final void write(int ordinal, UTF8String input) { + final int numBytes = input.numBytes(); + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + + // grow the global buffer before writing data. + grow(roundedSize); + + zeroOutPaddingBytes(numBytes); + + // Write the bytes to the variable length portion. + input.writeToMemory(getBuffer(), cursor()); + + setOffsetAndSize(ordinal, numBytes); + + // move the cursor forward. + increaseCursor(roundedSize); + } + + public final void write(int ordinal, byte[] input) { + write(ordinal, input, 0, input.length); + } + + public final void write(int ordinal, byte[] input, int offset, int numBytes) { + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); + + // grow the global buffer before writing data. + grow(roundedSize); + + zeroOutPaddingBytes(numBytes); + + // Write the bytes to the variable length portion. + Platform.copyMemory( + input, Platform.BYTE_ARRAY_OFFSET + offset, getBuffer(), cursor(), numBytes); + + setOffsetAndSize(ordinal, numBytes); + + // move the cursor forward. + increaseCursor(roundedSize); + } + + public final void write(int ordinal, CalendarInterval input) { + // grow the global buffer before writing data. + grow(16); + + // Write the months and microseconds fields of Interval to the variable length portion. + Platform.putLong(getBuffer(), cursor(), input.months); + Platform.putLong(getBuffer(), cursor() + 8, input.microseconds); + + setOffsetAndSize(ordinal, 16); + + // move the cursor forward. + increaseCursor(16); + } + + protected final void writeBoolean(long offset, boolean value) { + Platform.putBoolean(getBuffer(), offset, value); + } + + protected final void writeByte(long offset, byte value) { + Platform.putByte(getBuffer(), offset, value); + } + + protected final void writeShort(long offset, short value) { + Platform.putShort(getBuffer(), offset, value); + } + + protected final void writeInt(long offset, int value) { + Platform.putInt(getBuffer(), offset, value); + } + + protected final void writeLong(long offset, long value) { + Platform.putLong(getBuffer(), offset, value); + } + + protected final void writeFloat(long offset, float value) { + if (Float.isNaN(value)) { + value = Float.NaN; + } + Platform.putFloat(getBuffer(), offset, value); + } + + protected final void writeDouble(long offset, double value) { + if (Double.isNaN(value)) { + value = Double.NaN; + } + Platform.putDouble(getBuffer(), offset, value); + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 0da5ece7e47fe..b31466f5c92d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter} import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types.{UserDefinedType, _} import org.apache.spark.unsafe.Platform @@ -42,17 +42,12 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe /** The row representing the expression results. */ private[this] val intermediate = new GenericInternalRow(values) - /** The row returned by the projection. */ - private[this] val result = new UnsafeRow(numFields) - - /** The buffer which holds the resulting row's backing data. */ - private[this] val holder = new BufferHolder(result, numFields * 32) + /* The row writer for UnsafeRow result */ + private[this] val rowWriter = new UnsafeRowWriter(numFields, numFields * 32) /** The writer that writes the intermediate result to the result row. */ private[this] val writer: InternalRow => Unit = { - val rowWriter = new UnsafeRowWriter(holder, numFields) val baseWriter = generateStructWriter( - holder, rowWriter, expressions.map(e => StructField("", e.dataType, e.nullable))) if (!expressions.exists(_.nullable)) { @@ -83,10 +78,9 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe } // Write the intermediate row to an unsafe row. - holder.reset() + rowWriter.reset() writer(intermediate) - result.setTotalSize(holder.totalSize()) - result + rowWriter.getRow() } } @@ -111,14 +105,13 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * given buffer using the given [[UnsafeRowWriter]]. */ private def generateStructWriter( - bufferHolder: BufferHolder, rowWriter: UnsafeRowWriter, fields: Array[StructField]): InternalRow => Unit = { val numFields = fields.length // Create field writers. val fieldWriters = fields.map { field => - generateFieldWriter(bufferHolder, rowWriter, field.dataType, field.nullable) + generateFieldWriter(rowWriter, field.dataType, field.nullable) } // Create basic writer. row => { @@ -136,7 +129,6 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * or array) to the given buffer using the given [[UnsafeWriter]]. */ private def generateFieldWriter( - bufferHolder: BufferHolder, writer: UnsafeWriter, dt: DataType, nullable: Boolean): (SpecializedGetters, Int) => Unit = { @@ -178,81 +170,79 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { case StructType(fields) => val numFields = fields.length - val rowWriter = new UnsafeRowWriter(bufferHolder, numFields) - val structWriter = generateStructWriter(bufferHolder, rowWriter, fields) + val rowWriter = new UnsafeRowWriter(writer, numFields) + val structWriter = generateStructWriter(rowWriter, fields) (v, i) => { - val tmpCursor = bufferHolder.cursor + val previousCursor = writer.cursor() v.getStruct(i, fields.length) match { case row: UnsafeRow => writeUnsafeData( - bufferHolder, + rowWriter, row.getBaseObject, row.getBaseOffset, row.getSizeInBytes) case row => // Nested struct. We don't know where this will start because a row can be // variable length, so we need to update the offsets and zero out the bit mask. - rowWriter.reset() + rowWriter.resetRowWriter() structWriter.apply(row) } - writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case ArrayType(elementType, containsNull) => - val arrayWriter = new UnsafeArrayWriter - val elementSize = getElementSize(elementType) + val arrayWriter = new UnsafeArrayWriter(writer, getElementSize(elementType)) val elementWriter = generateFieldWriter( - bufferHolder, arrayWriter, elementType, containsNull) (v, i) => { - val tmpCursor = bufferHolder.cursor - writeArray(bufferHolder, arrayWriter, elementWriter, v.getArray(i), elementSize) - writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + val previousCursor = writer.cursor() + writeArray(arrayWriter, elementWriter, v.getArray(i)) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case MapType(keyType, valueType, valueContainsNull) => - val keyArrayWriter = new UnsafeArrayWriter - val keySize = getElementSize(keyType) + val keyArrayWriter = new UnsafeArrayWriter(writer, getElementSize(keyType)) val keyWriter = generateFieldWriter( - bufferHolder, keyArrayWriter, keyType, nullable = false) - val valueArrayWriter = new UnsafeArrayWriter - val valueSize = getElementSize(valueType) + val valueArrayWriter = new UnsafeArrayWriter(writer, getElementSize(valueType)) val valueWriter = generateFieldWriter( - bufferHolder, valueArrayWriter, valueType, valueContainsNull) (v, i) => { - val tmpCursor = bufferHolder.cursor + val previousCursor = writer.cursor() v.getMap(i) match { case map: UnsafeMapData => writeUnsafeData( - bufferHolder, + valueArrayWriter, map.getBaseObject, map.getBaseOffset, map.getSizeInBytes) case map => // preserve 8 bytes to write the key array numBytes later. - bufferHolder.grow(8) - bufferHolder.cursor += 8 + valueArrayWriter.grow(8) + valueArrayWriter.increaseCursor(8) // Write the keys and write the numBytes of key array into the first 8 bytes. - writeArray(bufferHolder, keyArrayWriter, keyWriter, map.keyArray(), keySize) - Platform.putLong(bufferHolder.buffer, tmpCursor, bufferHolder.cursor - tmpCursor - 8) + writeArray(keyArrayWriter, keyWriter, map.keyArray()) + Platform.putLong( + valueArrayWriter.getBuffer, + previousCursor, + valueArrayWriter.cursor - previousCursor - 8 + ) // Write the values. - writeArray(bufferHolder, valueArrayWriter, valueWriter, map.valueArray(), valueSize) + writeArray(valueArrayWriter, valueWriter, map.valueArray()) } - writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case udt: UserDefinedType[_] => - generateFieldWriter(bufferHolder, writer, udt.sqlType, nullable) + generateFieldWriter(writer, udt.sqlType, nullable) case NullType => (_, _) => {} @@ -324,20 +314,18 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * copy. */ private def writeArray( - bufferHolder: BufferHolder, arrayWriter: UnsafeArrayWriter, elementWriter: (SpecializedGetters, Int) => Unit, - array: ArrayData, - elementSize: Int): Unit = array match { + array: ArrayData): Unit = array match { case unsafe: UnsafeArrayData => writeUnsafeData( - bufferHolder, + arrayWriter, unsafe.getBaseObject, unsafe.getBaseOffset, unsafe.getSizeInBytes) case _ => val numElements = array.numElements() - arrayWriter.initialize(bufferHolder, numElements, elementSize) + arrayWriter.initialize(numElements) var i = 0 while (i < numElements) { elementWriter.apply(array, i) @@ -350,17 +338,17 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects. */ private def writeUnsafeData( - bufferHolder: BufferHolder, + writer: UnsafeWriter, baseObject: AnyRef, baseOffset: Long, sizeInBytes: Int) : Unit = { - bufferHolder.grow(sizeInBytes) + writer.grow(sizeInBytes) Platform.copyMemory( baseObject, baseOffset, - bufferHolder.buffer, - bufferHolder.cursor, + writer.getBuffer, + writer.cursor, sizeInBytes) - bufferHolder.cursor += sizeInBytes + writer.increaseCursor(sizeInBytes) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 6682ba55b18b1..ab2254cd9f70a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -48,19 +48,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, input: String, fieldTypes: Seq[DataType], - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString)) } + val rowWriterClass = classOf[UnsafeRowWriter].getName + val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", + v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});") + s""" final InternalRow $tmpInput = $input; if ($tmpInput instanceof UnsafeRow) { - ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", bufferHolder)} + ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", structRowWriter)} } else { - ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, bufferHolder)} + ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} } """ } @@ -70,12 +74,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro row: String, inputs: Seq[ExprCode], inputTypes: Seq[DataType], - bufferHolder: String, + rowWriter: String, isTopLevel: Boolean = false): String = { - val rowWriterClass = classOf[UnsafeRowWriter].getName - val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", - v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});") - val resetWriter = if (isTopLevel) { // For top level row writer, it always writes to the beginning of the global buffer holder, // which means its fixed-size region always in the same position, so we don't need to call @@ -88,7 +88,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.zeroOutNullBytes();" } } else { - s"$rowWriter.reset();" + s"$rowWriter.resetRowWriter();" } val writeFields = inputs.zip(inputTypes).zipWithIndex.map { @@ -97,7 +97,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case udt: UserDefinedType[_] => udt.sqlType case other => other } - val tmpCursor = ctx.freshName("tmpCursor") val setNull = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => @@ -105,33 +104,34 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});" case _ => s"$rowWriter.setNullAt($index);" } + val previousCursor = ctx.freshName("previousCursor") val writeField = dt match { case t: StructType => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $tmpCursor = $bufferHolder.cursor; - ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $rowWriter.cursor(); + ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), rowWriter)} + $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case a @ ArrayType(et, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, input.value, et, bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $rowWriter.cursor(); + ${writeArrayToBuffer(ctx, input.value, et, rowWriter)} + $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case m @ MapType(kt, vt, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $tmpCursor = $bufferHolder.cursor; - ${writeMapToBuffer(ctx, input.value, kt, vt, bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $rowWriter.cursor(); + ${writeMapToBuffer(ctx, input.value, kt, vt, rowWriter)} + $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case t: DecimalType => @@ -181,12 +181,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, input: String, elementType: DataType, - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") - val arrayWriterClass = classOf[UnsafeArrayWriter].getName - val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", - v => s"$v = new $arrayWriterClass();") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") @@ -203,28 +200,32 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => 8 // we need 8 bytes to store offset and length } - val tmpCursor = ctx.freshName("tmpCursor") + val arrayWriterClass = classOf[UnsafeArrayWriter].getName + val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", + v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);") + val previousCursor = ctx.freshName("previousCursor") + val element = CodeGenerator.getValue(tmpInput, et, index) val writeElement = et match { case t: StructType => s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $arrayWriter.cursor(); + ${writeStructToBuffer(ctx, element, t.map(_.dataType), arrayWriter)} + $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case a @ ArrayType(et, _) => s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, element, et, bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $arrayWriter.cursor(); + ${writeArrayToBuffer(ctx, element, et, arrayWriter)} + $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case m @ MapType(kt, vt, _) => s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $arrayWriter.cursor(); + ${writeMapToBuffer(ctx, element, kt, vt, arrayWriter)} + $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case t: DecimalType => @@ -240,10 +241,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" final ArrayData $tmpInput = $input; if ($tmpInput instanceof UnsafeArrayData) { - ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", bufferHolder)} + ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", arrayWriter)} } else { final int $numElements = $tmpInput.numElements(); - $arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize); + $arrayWriter.initialize($numElements); for (int $index = 0; $index < $numElements; $index++) { if ($tmpInput.isNullAt($index)) { @@ -262,7 +263,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro input: String, keyType: DataType, valueType: DataType, - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val tmpCursor = ctx.freshName("tmpCursor") @@ -271,20 +272,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" final MapData $tmpInput = $input; if ($tmpInput instanceof UnsafeMapData) { - ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", bufferHolder)} + ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", rowWriter)} } else { // preserve 8 bytes to write the key array numBytes later. - $bufferHolder.grow(8); - $bufferHolder.cursor += 8; + $rowWriter.grow(8); + $rowWriter.increaseCursor(8); // Remember the current cursor so that we can write numBytes of key array later. - final int $tmpCursor = $bufferHolder.cursor; + final int $tmpCursor = $rowWriter.cursor(); - ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, bufferHolder)} + ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} // Write the numBytes of key array into the first 8 bytes. - Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor); + Platform.putLong($rowWriter.getBuffer(), $tmpCursor - 8, $rowWriter.cursor() - $tmpCursor); - ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, bufferHolder)} + ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} } """ } @@ -293,14 +294,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro * If the input is already in unsafe format, we don't need to go through all elements/fields, * we can directly write it. */ - private def writeUnsafeData(ctx: CodegenContext, input: String, bufferHolder: String) = { + private def writeUnsafeData(ctx: CodegenContext, input: String, rowWriter: String) = { val sizeInBytes = ctx.freshName("sizeInBytes") s""" final int $sizeInBytes = $input.getSizeInBytes(); // grow the global buffer before writing data. - $bufferHolder.grow($sizeInBytes); - $input.writeToMemory($bufferHolder.buffer, $bufferHolder.cursor); - $bufferHolder.cursor += $sizeInBytes; + $rowWriter.grow($sizeInBytes); + $input.writeToMemory($rowWriter.getBuffer(), $rowWriter.cursor()); + $rowWriter.increaseCursor($sizeInBytes); """ } @@ -317,38 +318,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => true } - val result = ctx.addMutableState("UnsafeRow", "result", - v => s"$v = new UnsafeRow(${expressions.length});") - - val holderClass = classOf[BufferHolder].getName - val holder = ctx.addMutableState(holderClass, "holder", - v => s"$v = new $holderClass($result, ${numVarLenFields * 32});") - - val resetBufferHolder = if (numVarLenFields == 0) { - "" - } else { - s"$holder.reset();" - } - val updateRowSize = if (numVarLenFields == 0) { - "" - } else { - s"$result.setTotalSize($holder.totalSize());" - } + val rowWriterClass = classOf[UnsafeRowWriter].getName + val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", + v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});") // Evaluate all the subexpression. val evalSubexpr = ctx.subexprFunctions.mkString("\n") - val writeExpressions = - writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true) + val writeExpressions = writeExpressionsToBuffer( + ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) val code = s""" - $resetBufferHolder + $rowWriter.reset(); $evalSubexpr $writeExpressions - $updateRowSize """ - ExprCode(code, "false", result) + ExprCode(code, "false", s"$rowWriter.getRow()") } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index adf9ddf327c96..0e9d357c19c63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects import java.lang.reflect.Modifier +import scala.collection.JavaConverters._ import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag @@ -501,12 +502,22 @@ case class LambdaVariable( value: String, isNull: String, dataType: DataType, - nullable: Boolean = true) extends LeafExpression - with Unevaluable with NonSQLExpression { + nullable: Boolean = true) extends LeafExpression with NonSQLExpression { + + // Interpreted execution of `LambdaVariable` always get the 0-index element from input row. + override def eval(input: InternalRow): Any = { + assert(input.numFields == 1, + "The input row of interpreted LambdaVariable should have only 1 field.") + input.get(0, dataType) + } override def genCode(ctx: CodegenContext): ExprCode = { ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false") } + + // This won't be called as `genCode` is overrided, just overriding it to make + // `LambdaVariable` non-abstract. + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev } /** @@ -599,8 +610,92 @@ case class MapObjects private( override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + // The data with UserDefinedType are actually stored with the data type of its sqlType. + // When we want to apply MapObjects on it, we have to use it. + lazy private val inputDataType = inputData.dataType match { + case u: UserDefinedType[_] => u.sqlType + case _ => inputData.dataType + } + + private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = { + val row = new GenericInternalRow(1) + inputCollection.toIterator.map { element => + row.update(0, element) + lambdaFunction.eval(row) + } + } + + private lazy val convertToSeq: Any => Seq[_] = inputDataType match { + case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + _.asInstanceOf[Seq[_]] + case ObjectType(cls) if cls.isArray => + _.asInstanceOf[Array[_]].toSeq + case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + _.asInstanceOf[java.util.List[_]].asScala + case ObjectType(cls) if cls == classOf[Object] => + (inputCollection) => { + if (inputCollection.getClass.isArray) { + inputCollection.asInstanceOf[Array[_]].toSeq + } else { + inputCollection.asInstanceOf[Seq[_]] + } + } + case ArrayType(et, _) => + _.asInstanceOf[ArrayData].array + } + + private lazy val mapElements: Seq[_] => Any = customCollectionCls match { + case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + // Scala sequence + executeFuncOnCollection(_).toSeq + case Some(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) => + // Scala set + executeFuncOnCollection(_).toSet + case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + // Java list + if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] || + cls == classOf[java.util.AbstractSequentialList[_]]) { + // Specifying non concrete implementations of `java.util.List` + executeFuncOnCollection(_).toSeq.asJava + } else { + val constructors = cls.getConstructors() + val intParamConstructor = constructors.find { constructor => + constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int] + } + val noParamConstructor = constructors.find { constructor => + constructor.getParameterCount == 0 + } + + val constructor = intParamConstructor.map { intConstructor => + (len: Int) => intConstructor.newInstance(len.asInstanceOf[Object]) + }.getOrElse { + (_: Int) => noParamConstructor.get.newInstance() + } + + // Specifying concrete implementations of `java.util.List` + (inputs) => { + val results = executeFuncOnCollection(inputs) + val builder = constructor(inputs.length).asInstanceOf[java.util.List[Any]] + results.foreach(builder.add(_)) + builder + } + } + case None => + // array + x => new GenericArrayData(executeFuncOnCollection(x).toArray) + case Some(cls) => + throw new RuntimeException(s"class `${cls.getName}` is not supported by `MapObjects` as " + + "resulting collection.") + } + + override def eval(input: InternalRow): Any = { + val inputCollection = inputData.eval(input) + + if (inputCollection == null) { + return null + } + mapElements(convertToSeq(inputCollection)) + } override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse( @@ -647,13 +742,6 @@ case class MapObjects private( case _ => "" } - // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. - // When we want to apply MapObjects on it, we have to use it. - val inputDataType = inputData.dataType match { - case p: PythonUserDefinedType => p.sqlType - case _ => inputData.dataType - } - // `MapObjects` generates a while loop to traverse the elements of the input collection. We // need to take care of Seq and List because they may have O(n) complexity for indexed accessing // like `list.get(1)`. Here we use Iterator to traverse Seq and List. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2829d1d81eb1a..9a1bbc675e397 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -153,7 +153,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) RewritePredicateSubquery, ColumnPruning, CollapseProject, - RemoveRedundantProject) + RemoveRedundantProject) :+ + Batch("UpdateAttributeReferences", Once, + UpdateNullabilityInAttributeReferences) } /** @@ -1309,3 +1311,18 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] { } } } + +/** + * Updates nullability in [[AttributeReference]]s if nullability is different between + * non-leaf plan's expressions and the children output. + */ +object UpdateNullabilityInAttributeReferences extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case p if !p.isInstanceOf[LeafNode] => + val nullabilityMap = AttributeMap(p.children.flatMap(_.output).map { x => x -> x.nullable }) + p transformExpressions { + case ar: AttributeReference if nullabilityMap.contains(ar) => + ar.withNullability(nullabilityMap(ar)) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index a6e5aa6daca65..c3fdb924243df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf /** * Collapse plans consisting empty local relations generated by [[PruneFilters]]. @@ -32,7 +34,7 @@ import org.apache.spark.sql.catalyst.rules._ * - Aggregate with all empty children and at least one grouping expression. * - Generate(Explode) with all empty children. Others like Hive UDTF may return results. */ -object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { +object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper with CastSupport { private def isEmptyLocalRelation(plan: LogicalPlan): Boolean = plan match { case p: LocalRelation => p.data.isEmpty case _ => false @@ -43,7 +45,9 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { // Construct a project list from plan's output, while the value is always NULL. private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] = - plan.output.map{ a => Alias(Literal(null), a.name)(a.exprId) } + plan.output.map{ a => Alias(cast(Literal(null), a.dataType), a.name)(a.exprId) } + + override def conf: SQLConf = SQLConf.get def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p: Union if p.children.forall(isEmptyLocalRelation) => diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java index fb3dbe8ed1996..2da87113c6229 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java @@ -27,7 +27,6 @@ import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; import org.apache.spark.unsafe.types.UTF8String; @@ -55,36 +54,27 @@ private String getRandomString(int length) { } private UnsafeRow makeKeyRow(long k1, String k2) { - UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 32); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(2); + writer.reset(); writer.write(0, k1); writer.write(1, UTF8String.fromString(k2)); - row.setTotalSize(holder.totalSize()); - return row; + return writer.getRow(); } private UnsafeRow makeKeyRow(long k1, long k2) { - UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 0); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(2); + writer.reset(); writer.write(0, k1); writer.write(1, k2); - row.setTotalSize(holder.totalSize()); - return row; + return writer.getRow(); } private UnsafeRow makeValueRow(long v1, long v2) { - UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 0); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(2); + writer.reset(); writer.write(0, v1); writer.write(1, v2); - row.setTotalSize(holder.totalSize()); - return row; + return writer.getRow(); } private UnsafeRow appendRow(RowBasedKeyValueBatch batch, UnsafeRow key, UnsafeRow value) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 1f6964dfef598..0edd27c8241e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.{SparkConf, SparkFunSuite} @@ -25,7 +26,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -135,6 +136,70 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-23587: MapObjects should support interpreted execution") { + def testMapObjects(collection: Any, collectionCls: Class[_], inputType: DataType): Unit = { + val function = (lambda: Expression) => Add(lambda, Literal(1)) + val elementType = IntegerType + val expected = Seq(2, 3, 4) + + val inputObject = BoundReference(0, inputType, nullable = true) + val optClass = Option(collectionCls) + val mapObj = MapObjects(function, inputObject, elementType, true, optClass) + val row = InternalRow.fromSeq(Seq(collection)) + val result = mapObj.eval(row) + + collectionCls match { + case null => + assert(result.asInstanceOf[ArrayData].array.toSeq == expected) + case l if classOf[java.util.List[_]].isAssignableFrom(l) => + assert(result.asInstanceOf[java.util.List[_]].asScala.toSeq == expected) + case s if classOf[Seq[_]].isAssignableFrom(s) => + assert(result.asInstanceOf[Seq[_]].toSeq == expected) + case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) => + assert(result.asInstanceOf[scala.collection.Set[_]] == expected.toSet) + } + } + + val customCollectionClasses = Seq(classOf[Seq[Int]], classOf[scala.collection.Set[Int]], + classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]], + classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]], + classOf[java.util.Stack[Int]], null) + + val list = new java.util.ArrayList[Int]() + list.add(1) + list.add(2) + list.add(3) + val arrayData = new GenericArrayData(Array(1, 2, 3)) + val vector = new java.util.Vector[Int]() + vector.add(1) + vector.add(2) + vector.add(3) + val stack = new java.util.Stack[Int]() + stack.add(1) + stack.add(2) + stack.add(3) + + Seq( + (Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])), + (Array(1, 2, 3), ObjectType(classOf[Array[Int]])), + (Seq(1, 2, 3), ObjectType(classOf[Object])), + (Array(1, 2, 3), ObjectType(classOf[Object])), + (list, ObjectType(classOf[java.util.List[Int]])), + (vector, ObjectType(classOf[java.util.Vector[Int]])), + (stack, ObjectType(classOf[java.util.Stack[Int]])), + (arrayData, ArrayType(IntegerType)) + ).foreach { case (collection, inputType) => + customCollectionClasses.foreach(testMapObjects(collection, _, inputType)) + + // Unsupported custom collection class + val errMsg = intercept[RuntimeException] { + testMapObjects(collection, classOf[scala.collection.Map[Int, Int]], inputType) + }.getMessage() + assert(errMsg.contains("`scala.collection.Map` is not supported by `MapObjects` " + + "as resulting collection.")) + } + } + test("SPARK-23592: DecodeUsingSerializer should support interpreted execution") { val cls = classOf[java.lang.Integer] val inputObject = BoundReference(0, ObjectType(classOf[Array[Byte]]), nullable = true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index 3964508e3a55e..f1ce7543ffdc1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{IntegerType, StructType} class PropagateEmptyRelationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -37,7 +37,8 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceIntersectWithSemiJoin, PushDownPredicate, PruneFilters, - PropagateEmptyRelation) :: Nil + PropagateEmptyRelation, + CollapseProject) :: Nil } object OptimizeWithoutPropagateEmptyRelation extends RuleExecutor[LogicalPlan] { @@ -48,7 +49,8 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters) :: Nil + PruneFilters, + CollapseProject) :: Nil } val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1))) @@ -79,9 +81,11 @@ class PropagateEmptyRelationSuite extends PlanTest { (true, false, Inner, Some(LocalRelation('a.int, 'b.int))), (true, false, Cross, Some(LocalRelation('a.int, 'b.int))), - (true, false, LeftOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)), + (true, false, LeftOuter, + Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)), (true, false, RightOuter, Some(LocalRelation('a.int, 'b.int))), - (true, false, FullOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)), + (true, false, FullOuter, + Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)), (true, false, LeftAnti, Some(testRelation1)), (true, false, LeftSemi, Some(LocalRelation('a.int))), @@ -89,8 +93,9 @@ class PropagateEmptyRelationSuite extends PlanTest { (false, true, Cross, Some(LocalRelation('a.int, 'b.int))), (false, true, LeftOuter, Some(LocalRelation('a.int, 'b.int))), (false, true, RightOuter, - Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)), - (false, true, FullOuter, Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)), + Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)), + (false, true, FullOuter, + Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)), (false, true, LeftAnti, Some(LocalRelation('a.int))), (false, true, LeftSemi, Some(LocalRelation('a.int))), @@ -209,4 +214,11 @@ class PropagateEmptyRelationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("propagate empty relation keeps the plan resolved") { + val query = testRelation1.join( + LocalRelation('a.int, 'b.int), UsingJoin(FullOuter, "a" :: Nil), None) + val optimized = Optimize.execute(query.analyze) + assert(optimized.resolved) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala new file mode 100644 index 0000000000000..09b11f5aba2a0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{CreateArray, GetArrayItem} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class UpdateNullabilityInAttributeReferencesSuite extends PlanTest { + + object Optimizer extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Constant Folding", FixedPoint(10), + NullPropagation, + ConstantFolding, + BooleanSimplification, + SimplifyConditionals, + SimplifyBinaryComparison, + SimplifyExtractValueOps) :: + Batch("UpdateAttributeReferences", Once, + UpdateNullabilityInAttributeReferences) :: Nil + } + + test("update nullability in AttributeReference") { + val rel = LocalRelation('a.long.notNull) + // In the 'original' plans below, the Aggregate node produced by groupBy() has a + // nullable AttributeReference to `b`, because both array indexing and map lookup are + // nullable expressions. After optimization, the same attribute is now non-nullable, + // but the AttributeReference is not updated to reflect this. So, we need to update nullability + // by the `UpdateNullabilityInAttributeReferences` rule. + val original = rel + .select(GetArrayItem(CreateArray(Seq('a, 'a + 1L)), 0) as "b") + .groupBy($"b")("1") + val expected = rel.select('a as "b").groupBy($"b")("1").analyze + val optimized = Optimizer.execute(original.analyze) + comparePlans(optimized, expected) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 21ed987627b3b..633d86d495581 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -378,15 +378,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .groupBy($"foo")("1") checkRule(structRel, structExpected) - // These tests must use nullable attributes from the base relation for the following reason: - // in the 'original' plans below, the Aggregate node produced by groupBy() has a - // nullable AttributeReference to a1, because both array indexing and map lookup are - // nullable expressions. After optimization, the same attribute is now non-nullable, - // but the AttributeReference is not updated to reflect this. In the 'expected' plans, - // the grouping expressions have the same nullability as the original attribute in the - // relation. If that attribute is non-nullable, the tests will fail as the plans will - // compare differently, so for these tests we must use a nullable attribute. See - // SPARK-23634. val arrayRel = relation .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1") .groupBy($"a1")("1") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 11bfaa0a726a7..4f0e2f58843a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -952,7 +952,8 @@ object SparkSession { session = new SparkSession(sparkContext, None, None, extensions) options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) } - defaultSession.set(session) + setDefaultSession(session) + setActiveSession(session) // Register a successfully instantiated context to the singleton. This should be at the // end of the class definition so that the singleton is updated only if there is no @@ -1028,6 +1029,17 @@ object SparkSession { */ def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + /** + * Returns the currently active SparkSession, otherwise the default one. If there is no default + * SparkSession, throws an exception. + * + * @since 2.4.0 + */ + def active: SparkSession = { + getActiveSession.getOrElse(getDefaultSession.getOrElse( + throw new IllegalStateException("No active or default Spark session found"))) + } + //////////////////////////////////////////////////////////////////////////////////////// // Private methods from now on //////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index 8617be88f3570..d5508275c48c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -165,18 +165,14 @@ class RowBasedHashMapGenerator( | if (buckets[idx] == -1) { | if (numRows < capacity && !isBatchFull) { | // creating the unsafe for new entry - | UnsafeRow agg_result = new UnsafeRow(${groupingKeySchema.length}); - | org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder - | = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, - | ${numVarLenFields * 32}); | org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter | = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter( - | agg_holder, - | ${groupingKeySchema.length}); - | agg_holder.reset(); //TODO: investigate if reset or zeroout are actually needed + | ${groupingKeySchema.length}, ${numVarLenFields * 32}); + | agg_rowWriter.reset(); //TODO: investigate if reset or zeroout are actually needed | agg_rowWriter.zeroOutNullBytes(); | ${createUnsafeRowForKey}; - | agg_result.setTotalSize(agg_holder.totalSize()); + | org.apache.spark.sql.catalyst.expressions.UnsafeRow agg_result + | = agg_rowWriter.getRow(); | Object kbase = agg_result.getBaseObject(); | long koff = agg_result.getBaseOffset(); | int klen = agg_result.getSizeInBytes(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 3b5655ba0582e..2d699e8a9d088 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -165,9 +165,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private ByteOrder nativeOrder = null; private byte[][] buffers = null; - private UnsafeRow unsafeRow = new UnsafeRow($numFields); - private BufferHolder bufferHolder = new BufferHolder(unsafeRow); - private UnsafeRowWriter rowWriter = new UnsafeRowWriter(bufferHolder, $numFields); + private UnsafeRowWriter rowWriter = new UnsafeRowWriter($numFields); private MutableUnsafeRow mutableRow = null; private int currentRow = 0; @@ -212,11 +210,10 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera public InternalRow next() { currentRow += 1; - bufferHolder.reset(); + rowWriter.reset(); rowWriter.zeroOutNullBytes(); ${extractorCalls} - unsafeRow.setTotalSize(bufferHolder.totalSize()); - return unsafeRow; + return rowWriter.getRow(); } ${ctx.declareAddedFunctions()} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index 9647f09867643..e93908da43535 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -26,7 +26,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ @@ -130,16 +130,13 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { val emptyUnsafeRow = new UnsafeRow(0) reader.map(_ => emptyUnsafeRow) } else { - val unsafeRow = new UnsafeRow(1) - val bufferHolder = new BufferHolder(unsafeRow) - val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) + val unsafeRowWriter = new UnsafeRowWriter(1) reader.map { line => // Writes to an UnsafeRow directly - bufferHolder.reset() + unsafeRowWriter.reset() unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) - unsafeRow.setTotalSize(bufferHolder.totalSize()) - unsafeRow + unsafeRowWriter.getRow() } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala deleted file mode 100644 index 2cc54107f8b83..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import org.apache.spark.TaskContext -import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter} -import org.apache.spark.sql.catalyst.encoders.encoderFor - -/** - * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by - * [[ForeachWriter]]. - * - * @param writer The [[ForeachWriter]] to process all data. - * @tparam T The expected type of the sink. - */ -class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable { - - override def addBatch(batchId: Long, data: DataFrame): Unit = { - // This logic should've been as simple as: - // ``` - // data.as[T].foreachPartition { iter => ... } - // ``` - // - // Unfortunately, doing that would just break the incremental planing. The reason is, - // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` will - // create a new plan. Because StreamExecution uses the existing plan to collect metrics and - // update watermark, we should never create a new plan. Otherwise, metrics and watermark are - // updated in the new plan, and StreamExecution cannot retrieval them. - // - // Hence, we need to manually convert internal rows to objects using encoder. - val encoder = encoderFor[T].resolveAndBind( - data.logicalPlan.output, - data.sparkSession.sessionState.analyzer) - data.queryExecution.toRdd.foreachPartition { iter => - if (writer.open(TaskContext.getPartitionId(), batchId)) { - try { - while (iter.hasNext) { - writer.process(encoder.fromRow(iter.next())) - } - } catch { - case e: Throwable => - writer.close(e) - throw e - } - writer.close(null) - } else { - writer.close(null) - } - } - } - - override def toString(): String = "ForeachSink" -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala new file mode 100644 index 0000000000000..df5d69d57e36f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType + +/** + * A [[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified + * [[ForeachWriter]]. + * + * @param writer The [[ForeachWriter]] to process all data. + * @tparam T The expected type of the sink. + */ +case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport { + override def createStreamWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceOptions): StreamWriter = { + new StreamWriter with SupportsWriteInternalRow { + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + + override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { + val encoder = encoderFor[T].resolveAndBind( + schema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + ForeachWriterFactory(writer, encoder) + } + + override def toString: String = "ForeachSink" + } + } +} + +case class ForeachWriterFactory[T: Encoder]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T]) + extends DataWriterFactory[InternalRow] { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): ForeachDataWriter[T] = { + new ForeachDataWriter(writer, encoder, partitionId, epochId) + } +} + +/** + * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]]. + * @param writer The [[ForeachWriter]] to process all data. + * @param encoder An encoder which can convert [[InternalRow]] to the required type [[T]] + * @param partitionId + * @param epochId + * @tparam T The type expected by the writer. + */ +class ForeachDataWriter[T : Encoder]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T], + partitionId: Int, + epochId: Long) + extends DataWriter[InternalRow] { + + // If open returns false, we should skip writing rows. + private val opened = writer.open(partitionId, epochId) + + override def write(record: InternalRow): Unit = { + if (!opened) return + + try { + writer.process(encoder.fromRow(record)) + } catch { + case t: Throwable => + writer.close(t) + throw t + } + } + + override def commit(): WriterCommitMessage = { + writer.close(null) + ForeachWriterCommitMessage + } + + override def abort(): Unit = {} +} + +/** + * An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination. + */ +case object ForeachWriterCommitMessage extends WriterCommitMessage diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 2fc903168cfa0..effc1471e8e12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger -import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} import org.apache.spark.sql.sources.v2.StreamWriteSupport /** @@ -269,7 +269,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = new ForeachSink[T](foreachWriter)(ds.exprEnc) + val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6f43f18ffd0d4..36e25428a1b93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2056,11 +2056,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { expr: String, expectedNonNullableColumns: Seq[String]): Unit = { val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr) - // In the logical plan, all the output columns of input dataframe are nullable - dfWithFilter.queryExecution.optimizedPlan.collect { - case e: Filter => assert(e.output.forall(_.nullable)) - } - dfWithFilter.queryExecution.executedPlan.collect { // When the child expression in isnotnull is null-intolerant (i.e. any null input will // result in null output), the involved columns are converted to not nullable; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index c0301f2ce2d66..44bf8624a6bcd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -50,6 +50,24 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach { assert(SparkSession.builder().getOrCreate() == session) } + test("sets default and active session") { + assert(SparkSession.getDefaultSession == None) + assert(SparkSession.getActiveSession == None) + val session = SparkSession.builder().master("local").getOrCreate() + assert(SparkSession.getDefaultSession == Some(session)) + assert(SparkSession.getActiveSession == Some(session)) + } + + test("get active or default session") { + val session = SparkSession.builder().master("local").getOrCreate() + assert(SparkSession.active == session) + SparkSession.clearActiveSession() + assert(SparkSession.active == session) + SparkSession.clearDefaultSession() + intercept[IllegalStateException](SparkSession.active) + session.stop() + } + test("config options are propagated to existing SparkSession") { val session1 = SparkSession.builder().master("local").config("spark-config1", "a").getOrCreate() assert(session1.conf.get("spark-config1") == "a") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala similarity index 77% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala index b249dd41a84a6..03bf71b3f4b78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.streaming.sources import java.util.concurrent.ConcurrentLinkedQueue @@ -25,11 +25,12 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext -class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { +class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { import testImplicits._ @@ -47,9 +48,9 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf .start() def verifyOutput(expectedVersion: Int, expectedData: Seq[Int]): Unit = { - import ForeachSinkSuite._ + import ForeachWriterSuite._ - val events = ForeachSinkSuite.allEvents() + val events = ForeachWriterSuite.allEvents() assert(events.size === 2) // one seq of events for each of the 2 partitions // Verify both seq of events have an Open event as the first event @@ -64,13 +65,13 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf } // -- batch 0 --------------------------------------- - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() input.addData(1, 2, 3, 4) query.processAllAvailable() verifyOutput(expectedVersion = 0, expectedData = 1 to 4) // -- batch 1 --------------------------------------- - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() input.addData(5, 6, 7, 8) query.processAllAvailable() verifyOutput(expectedVersion = 1, expectedData = 5 to 8) @@ -95,27 +96,27 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf input.addData(1, 2, 3, 4) query.processAllAvailable() - var allEvents = ForeachSinkSuite.allEvents() + var allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) var expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 4), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Process(value = 4), + ForeachWriterSuite.Close(None) ) assert(allEvents === Seq(expectedEvents)) - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() // -- batch 1 --------------------------------------- input.addData(5, 6, 7, 8) query.processAllAvailable() - allEvents = ForeachSinkSuite.allEvents() + allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 1), - ForeachSinkSuite.Process(value = 8), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 1), + ForeachWriterSuite.Process(value = 8), + ForeachWriterSuite.Close(None) ) assert(allEvents === Seq(expectedEvents)) @@ -131,7 +132,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf .foreach(new TestForeachWriter() { override def process(value: Int): Unit = { super.process(value) - throw new RuntimeException("error") + throw new RuntimeException("ForeachSinkSuite error") } }).start() input.addData(1, 2, 3, 4) @@ -141,18 +142,18 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf query.processAllAvailable() } assert(e.getCause.isInstanceOf[SparkException]) - assert(e.getCause.getCause.getMessage === "error") + assert(e.getCause.getCause.getCause.getMessage === "ForeachSinkSuite error") assert(query.isActive === false) - val allEvents = ForeachSinkSuite.allEvents() + val allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) - assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0)) - assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) + assert(allEvents(0)(0) === ForeachWriterSuite.Open(partition = 0, version = 0)) + assert(allEvents(0)(1) === ForeachWriterSuite.Process(value = 1)) // `close` should be called with the error - val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close] + val errorEvent = allEvents(0)(2).asInstanceOf[ForeachWriterSuite.Close] assert(errorEvent.error.get.isInstanceOf[RuntimeException]) - assert(errorEvent.error.get.getMessage === "error") + assert(errorEvent.error.get.getMessage === "ForeachSinkSuite error") } } @@ -177,12 +178,12 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf inputData.addData(10, 11, 12) query.processAllAvailable() - val allEvents = ForeachSinkSuite.allEvents() + val allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) val expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 3), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Process(value = 3), + ForeachWriterSuite.Close(None) ) assert(allEvents === Seq(expectedEvents)) } finally { @@ -216,21 +217,21 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf query.processAllAvailable() // There should be 3 batches and only does the last batch contain a value. - val allEvents = ForeachSinkSuite.allEvents() + val allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 3) val expectedEvents = Seq( Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Close(None) ), Seq( - ForeachSinkSuite.Open(partition = 0, version = 1), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 1), + ForeachWriterSuite.Close(None) ), Seq( - ForeachSinkSuite.Open(partition = 0, version = 2), - ForeachSinkSuite.Process(value = 3), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 2), + ForeachWriterSuite.Process(value = 3), + ForeachWriterSuite.Close(None) ) ) assert(allEvents === expectedEvents) @@ -258,7 +259,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf } /** A global object to collect events in the executor */ -object ForeachSinkSuite { +object ForeachWriterSuite { trait Event @@ -285,21 +286,21 @@ object ForeachSinkSuite { /** A [[ForeachWriter]] that writes collected events to ForeachSinkSuite */ class TestForeachWriter extends ForeachWriter[Int] { - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() - private val events = mutable.ArrayBuffer[ForeachSinkSuite.Event]() + private val events = mutable.ArrayBuffer[ForeachWriterSuite.Event]() override def open(partitionId: Long, version: Long): Boolean = { - events += ForeachSinkSuite.Open(partition = partitionId, version = version) + events += ForeachWriterSuite.Open(partition = partitionId, version = version) true } override def process(value: Int): Unit = { - events += ForeachSinkSuite.Process(value) + events += ForeachWriterSuite.Process(value) } override def close(errorOrNull: Throwable): Unit = { - events += ForeachSinkSuite.Close(error = Option(errorOrNull)) - ForeachSinkSuite.addEvents(events) + events += ForeachWriterSuite.Close(error = Option(errorOrNull)) + ForeachWriterSuite.addEvents(events) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 08749b49997e0..20942ed93897c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.reader.DataReaderFactory diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 3038b822beb4a..17603deacdcdd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -35,6 +35,7 @@ private[spark] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) } SparkSession.setDefaultSession(this) + SparkSession.setActiveSession(this) @transient override lazy val sessionState: SessionState = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 814038d4ef7af..965aea2b61456 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -159,10 +159,6 @@ private[hive] class TestHiveSparkSession( private val loadTestTables: Boolean) extends SparkSession(sc) with Logging { self => - // TODO(SPARK-23826): TestHiveSparkSession should set default session the same way as - // TestSparkSession, but doing this the same way breaks many tests in the package. We need - // to investigate and find a different strategy. - def this(sc: SparkContext, loadTestTables: Boolean) { this( sc, @@ -179,6 +175,9 @@ private[hive] class TestHiveSparkSession( loadTestTables) } + SparkSession.setDefaultSession(this) + SparkSession.setActiveSession(this) + { // set the metastore temporary configuration val metastoreTempConf = HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false) ++ Map( ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true",