diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala
index ab319c860ee69..fac834a70b893 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala
@@ -33,7 +33,8 @@ private[deploy] object DependencyUtils {
packagesExclusions: String,
packages: String,
repositories: String,
- ivyRepoPath: String): String = {
+ ivyRepoPath: String,
+ ivySettingsPath: Option[String]): String = {
val exclusions: Seq[String] =
if (!StringUtils.isBlank(packagesExclusions)) {
packagesExclusions.split(",")
@@ -41,10 +42,12 @@ private[deploy] object DependencyUtils {
Nil
}
// Create the IvySettings, either load from file or build defaults
- val ivySettings = sys.props.get("spark.jars.ivySettings").map { ivySettingsFile =>
- SparkSubmitUtils.loadIvySettings(ivySettingsFile, Option(repositories), Option(ivyRepoPath))
- }.getOrElse {
- SparkSubmitUtils.buildIvySettings(Option(repositories), Option(ivyRepoPath))
+ val ivySettings = ivySettingsPath match {
+ case Some(path) =>
+ SparkSubmitUtils.loadIvySettings(path, Option(repositories), Option(ivyRepoPath))
+
+ case None =>
+ SparkSubmitUtils.buildIvySettings(Option(repositories), Option(ivyRepoPath))
}
SparkSubmitUtils.resolveMavenCoordinates(packages, ivySettings, exclusions = exclusions)
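For context, a hedged sketch of how a user supplies the settings file that this change threads through: `spark.jars.ivySettings` is read from the Spark properties (for example via `--conf` or `spark-defaults.conf`); the coordinates and paths below are illustrative, not part of this patch.

```scala
import org.apache.spark.SparkConf

// Illustrative values only; any resolvable artifact and readable ivysettings.xml path would do.
val conf = new SparkConf()
  .set("spark.jars.packages", "org.apache.commons:commons-lang3:3.5") // artifacts to resolve
  .set("spark.jars.ivySettings", "/etc/spark/ivysettings.xml")        // custom resolvers / mirrors
  .set("spark.jars.ivy", "/tmp/.ivy2")                                // local Ivy cache and repo path
```

When `spark.jars.ivySettings` is set, resolution goes through `SparkSubmitUtils.loadIvySettings`; otherwise default settings are built from `spark.jars.repositories` and `spark.jars.ivy`, as in the match above.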
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 3965f17f4b56e..eddbedeb1024d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -359,7 +359,8 @@ object SparkSubmit extends CommandLineUtils with Logging {
// Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files
// too for packages that include Python code
val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies(
- args.packagesExclusions, args.packages, args.repositories, args.ivyRepoPath)
+ args.packagesExclusions, args.packages, args.repositories, args.ivyRepoPath,
+ args.ivySettingsPath)
if (!StringUtils.isBlank(resolvedMavenCoordinates)) {
args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates)
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index e7796d4ddbe34..8e7070593687b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -63,6 +63,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
var packages: String = null
var repositories: String = null
var ivyRepoPath: String = null
+ var ivySettingsPath: Option[String] = None
var packagesExclusions: String = null
var verbose: Boolean = false
var isPython: Boolean = false
@@ -184,6 +185,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull
files = Option(files).orElse(sparkProperties.get("spark.files")).orNull
ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull
+ ivySettingsPath = sparkProperties.get("spark.jars.ivySettings")
packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull
packagesExclusions = Option(packagesExclusions)
.orElse(sparkProperties.get("spark.jars.excludes")).orNull
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
index b19c9904d5982..3f71237164a15 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
@@ -79,12 +79,17 @@ object DriverWrapper extends Logging {
val secMgr = new SecurityManager(sparkConf)
val hadoopConf = SparkHadoopUtil.newConfiguration(sparkConf)
- val Seq(packagesExclusions, packages, repositories, ivyRepoPath) =
- Seq("spark.jars.excludes", "spark.jars.packages", "spark.jars.repositories", "spark.jars.ivy")
- .map(sys.props.get(_).orNull)
+ val Seq(packagesExclusions, packages, repositories, ivyRepoPath, ivySettingsPath) =
+ Seq(
+ "spark.jars.excludes",
+ "spark.jars.packages",
+ "spark.jars.repositories",
+ "spark.jars.ivy",
+ "spark.jars.ivySettings"
+ ).map(sys.props.get(_).orNull)
val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies(packagesExclusions,
- packages, repositories, ivyRepoPath)
+ packages, repositories, ivyRepoPath, Option(ivySettingsPath))
val jars = {
val jarsProp = sys.props.get("spark.jars").orNull
if (!StringUtils.isBlank(resolvedMavenCoordinates)) {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index 56b95c31eb4c3..8e8f7d197c9ef 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -164,7 +164,8 @@ class BlockManagerMasterEndpoint(
val futures = blockManagerInfo.values.map { bm =>
bm.slaveEndpoint.ask[Int](removeMsg).recover {
case e: IOException =>
- logWarning(s"Error trying to remove RDD $rddId", e)
+ logWarning(s"Error trying to remove RDD $rddId from block manager ${bm.blockManagerId}",
+ e)
0 // zero blocks were removed
}
}.toSeq
@@ -195,7 +196,8 @@ class BlockManagerMasterEndpoint(
val futures = requiredBlockManagers.map { bm =>
bm.slaveEndpoint.ask[Int](removeMsg).recover {
case e: IOException =>
- logWarning(s"Error trying to remove broadcast $broadcastId", e)
+ logWarning(s"Error trying to remove broadcast $broadcastId from block manager " +
+ s"${bm.blockManagerId}", e)
0 // zero blocks were removed
}
}.toSeq
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index d86ef907b4492..0d7c342a5eacd 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -106,6 +106,9 @@ class SparkSubmitSuite
// Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x
implicit val defaultSignaler: Signaler = ThreadSignaler
+ private val emptyIvySettings = File.createTempFile("ivy", ".xml")
+ FileUtils.write(emptyIvySettings, "", StandardCharsets.UTF_8)
+
override def beforeEach() {
super.beforeEach()
}
@@ -520,6 +523,7 @@ class SparkSubmitSuite
"--repositories", repo,
"--conf", "spark.ui.enabled=false",
"--conf", "spark.master.rest.enabled=false",
+ "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}",
unusedJar.toString,
"my.great.lib.MyLib", "my.great.dep.MyLib")
runSparkSubmit(args)
@@ -530,7 +534,6 @@ class SparkSubmitSuite
val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
val main = MavenCoordinate("my.great.lib", "mylib", "0.1")
val dep = MavenCoordinate("my.great.dep", "mylib", "0.1")
- // Test using "spark.jars.packages" and "spark.jars.repositories" configurations.
IvyTestUtils.withRepository(main, Some(dep.toString), None) { repo =>
val args = Seq(
"--class", JarCreationTest.getClass.getName.stripSuffix("$"),
@@ -540,6 +543,7 @@ class SparkSubmitSuite
"--conf", s"spark.jars.repositories=$repo",
"--conf", "spark.ui.enabled=false",
"--conf", "spark.master.rest.enabled=false",
+ "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}",
unusedJar.toString,
"my.great.lib.MyLib", "my.great.dep.MyLib")
runSparkSubmit(args)
@@ -550,7 +554,6 @@ class SparkSubmitSuite
  // See https://gist.github.com/shivaram/3a2fecce60768a603dac for an error log
ignore("correctly builds R packages included in a jar with --packages") {
assume(RUtils.isRInstalled, "R isn't installed on this machine.")
- // Check if the SparkR package is installed
assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.")
val main = MavenCoordinate("my.great.lib", "mylib", "0.1")
val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
@@ -563,6 +566,7 @@ class SparkSubmitSuite
"--master", "local-cluster[2,1,1024]",
"--packages", main.toString,
"--repositories", repo,
+ "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}",
"--verbose",
"--conf", "spark.ui.enabled=false",
rScriptDir)
@@ -573,7 +577,6 @@ class SparkSubmitSuite
test("include an external JAR in SparkR") {
assume(RUtils.isRInstalled, "R isn't installed on this machine.")
val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
- // Check if the SparkR package is installed
assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.")
val rScriptDir =
Seq(sparkHome, "R", "pkg", "tests", "fulltests", "jarTest.R").mkString(File.separator)
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index 702bcf748fc74..aea07be34cb86 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -111,7 +111,7 @@ and the migration guide below will explain all changes between releases.
* The class and trait hierarchy for logistic regression model summaries was changed to be cleaner
and better accommodate the addition of the multi-class summary. This is a breaking change for user
code that casts a `LogisticRegressionTrainingSummary` to a
-` BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary`
+`BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary`
method. See [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139) for more detail
(_note_ this is an `Experimental` API). This _does not_ affect the Python `summary` method, which
will still work correctly for both multinomial and binary cases.
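To make the migration concrete, a minimal hedged sketch of the recommended access pattern (assumes an existing DataFrame `training` with a binary label column):

```scala
import org.apache.spark.ml.classification.LogisticRegression

val model = new LogisticRegression().fit(training)

// Preferred over casting the training summary to BinaryLogisticRegressionTrainingSummary:
val binarySummary = model.binarySummary
println(binarySummary.areaUnderROC)
```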
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md
index 75aea70601875..8b89296b14cdd 100644
--- a/docs/mllib-feature-extraction.md
+++ b/docs/mllib-feature-extraction.md
@@ -278,8 +278,8 @@ for details on the API.
multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This
represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29)
between the input vector, `v` and transforming vector, `scalingVec`, to yield a result vector.
-Qu8T948*1#
-Denoting the `scalingVec` as "`w`," this transformation may be written as:
+
+Denoting the `scalingVec` as "`w`", this transformation may be written as:
`\[ \begin{pmatrix}
v_1 \\
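For reference, a short sketch of the elementwise (Hadamard) product described above, using the `spark.mllib` `ElementwiseProduct` transformer (the vectors are illustrative):

```scala
import org.apache.spark.mllib.feature.ElementwiseProduct
import org.apache.spark.mllib.linalg.Vectors

val scalingVec = Vectors.dense(0.0, 1.0, 2.0)         // the "w" vector
val transformer = new ElementwiseProduct(scalingVec)

// Each component of v is multiplied by the matching component of w.
val transformed = transformer.transform(Vectors.dense(1.0, 2.0, 3.0))
// transformed == [0.0, 2.0, 6.0]
```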
diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md
index d3530908706d0..f567565437927 100644
--- a/docs/mllib-pmml-model-export.md
+++ b/docs/mllib-pmml-model-export.md
@@ -7,7 +7,7 @@ displayTitle: PMML model export - RDD-based API
* Table of contents
{:toc}
-## `spark.mllib` supported models
+## spark.mllib supported models
`spark.mllib` supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)).
@@ -15,7 +15,7 @@ The table below outlines the `spark.mllib` models that can be exported to PMML a
- `spark.mllib` model | PMML model |
+ spark.mllib model | PMML model |
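As a reminder of how the export is invoked, a hedged sketch with a k-means model (assumes an active `SparkContext` named `sc`; the data and output path are illustrative):

```scala
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

val data = sc.parallelize(Seq(Vectors.dense(0.0, 0.0), Vectors.dense(9.0, 9.0)))
val model = KMeans.train(data, k = 2, maxIterations = 10)

model.toPMML("/tmp/kmeans.xml")    // export the model as a PMML document to a local path
val pmmlAsString = model.toPMML()  // or obtain the PMML XML as a String
```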
diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md
index 975b28de47e20..9c4644947c911 100644
--- a/docs/running-on-kubernetes.md
+++ b/docs/running-on-kubernetes.md
@@ -549,14 +549,23 @@ specific to Spark on Kubernetes.
spark.kubernetes.driver.limit.cores |
(none) |
- Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod.
+ Specify a hard cpu [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod.
|
+
+ spark.kubernetes.executor.request.cores |
+ (none) |
+
+ Specify the cpu request for each executor pod. Values conform to the Kubernetes [convention](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#meaning-of-cpu).
+ Example values include 0.1, 500m, 1.5, 5, etc., with the definition of cpu units documented in [CPU units](https://kubernetes.io/docs/tasks/configure-pod-container/assign-cpu-resource/#cpu-units).
+ This is distinct from spark.executor.cores: it is only used to set the cpu request of the executor pod, and it
+ takes precedence over spark.executor.cores when set. Task parallelism, i.e. the number of tasks an executor can
+ run concurrently, is not affected by this.
+ |
spark.kubernetes.executor.limit.cores |
(none) |
- Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application.
+ Specify a hard cpu [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application.
|
@@ -593,4 +602,4 @@ specific to Spark on Kubernetes.
spark.kubernetes.executor.secrets.spark-secret=/etc/secrets
.
-
\ No newline at end of file
+
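For illustration, the new request setting can be combined with the existing core settings as below (a hedged sketch; the values are arbitrary):

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.executor.cores", "4")                        // task slots per executor (unchanged semantics)
  .set("spark.kubernetes.executor.request.cores", "500m")  // pod-level cpu request (new)
  .set("spark.kubernetes.executor.limit.cores", "2")       // pod-level hard cpu limit
```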
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
index e7e27876088f3..f26c134c2f6e9 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
@@ -27,13 +27,10 @@ import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
import org.apache.spark.sql.types.StructType
-import org.apache.spark.unsafe.types.UTF8String
/**
* A [[ContinuousReader]] for data from kafka.
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala
index 1acdd56125741..f35a143e00374 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala
@@ -20,18 +20,16 @@ package org.apache.spark.sql.kafka010
import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
+import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.unsafe.types.UTF8String
/** A simple class for converting Kafka ConsumerRecord to UnsafeRow */
private[kafka010] class KafkaRecordToUnsafeRowConverter {
- private val sharedRow = new UnsafeRow(7)
- private val bufferHolder = new BufferHolder(sharedRow)
- private val rowWriter = new UnsafeRowWriter(bufferHolder, 7)
+ private val rowWriter = new UnsafeRowWriter(7)
def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = {
- bufferHolder.reset()
+ rowWriter.reset()
if (record.key == null) {
rowWriter.setNullAt(0)
@@ -46,7 +44,6 @@ private[kafka010] class KafkaRecordToUnsafeRowConverter {
5,
DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp)))
rowWriter.write(6, record.timestampType.id)
- sharedRow.setTotalSize(bufferHolder.totalSize)
- sharedRow
+ rowWriter.getRow()
}
}
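For context, a hedged sketch of the simplified `UnsafeRowWriter` API relied on above: the writer now manages its own buffer, so there is no separate `BufferHolder` to wire up (field ordinals and values are illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
import org.apache.spark.unsafe.types.UTF8String

val rowWriter = new UnsafeRowWriter(2)            // 2 fields; the buffer is created internally
rowWriter.reset()                                 // call before writing each record
rowWriter.write(0, 42L)                           // fixed-length field
rowWriter.write(1, UTF8String.fromString("key"))  // variable-length field; the buffer grows as needed
val row: UnsafeRow = rowWriter.getRow()           // the total size is tracked internally
```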
diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
index 5413d3a416545..f8dc0ec7a0bf6 100644
--- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
+++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
@@ -196,6 +196,14 @@ public void testAppHandleDisconnect() throws Exception {
Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort());
client = new TestClient(s);
client.send(new Hello(secret, "1.4.0"));
+ client.send(new SetAppId("someId"));
+
+ // Wait until we know the server has received the messages and matched the handle to the
+ // connection before disconnecting.
+ eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> {
+ assertEquals("someId", handle.getAppId());
+ });
+
handle.disconnect();
waitForError(client, secret);
} finally {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 1cdcdfcaeab78..67cdb097217a2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -234,7 +234,7 @@ class StringIndexerModel (
val metadata = NominalAttribute.defaultAttr
.withName($(outputCol)).withValues(filteredLabels).toMetadata()
// If we are skipping invalid records, filter them out.
- val (filteredDataset, keepInvalid) = getHandleInvalid match {
+ val (filteredDataset, keepInvalid) = $(handleInvalid) match {
case StringIndexer.SKIP_INVALID =>
val filterer = udf { label: String =>
labelToIndex.contains(label)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index b373ae921ed38..6bf4aa38b1fcb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -17,14 +17,17 @@
package org.apache.spark.ml.feature
-import scala.collection.mutable.ArrayBuilder
+import java.util.NoSuchElementException
+
+import scala.collection.mutable
+import scala.language.existentials
import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -33,10 +36,14 @@ import org.apache.spark.sql.types._
/**
* A feature transformer that merges multiple columns into a vector column.
+ *
+ * This requires one pass over the entire dataset. When column lengths need to be inferred from the
+ * data, an additional call to the 'first' Dataset method is required; see the 'handleInvalid' parameter.
*/
@Since("1.4.0")
class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
- extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable {
+ extends Transformer with HasInputCols with HasOutputCol with HasHandleInvalid
+ with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("vecAssembler"))
@@ -49,32 +56,63 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)
+ /** @group setParam */
+ @Since("2.4.0")
+ def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+ /**
+ * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
+ * invalid data), 'error' (throw an error), or 'keep' (output the relevant number of NaN values).
+ * Column lengths are taken from the size of the ML Attribute Group, which can be set using
+ * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
+ * from the first row of the data, but only with 'error' or 'skip', since doing so is safe there.
+ * Default: "error"
+ * @group param
+ */
+ @Since("2.4.0")
+ override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+ """Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
+ |invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
+ |output). Column lengths are taken from the size of ML Attribute Group, which can be set using
+ |`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
+ |from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
+ |""".stripMargin.replaceAll("\n", " "),
+ ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))
+
+ setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)
+
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
// Schema transformation.
val schema = dataset.schema
- lazy val first = dataset.toDF.first()
- val attrs = $(inputCols).flatMap { c =>
+
+ val vectorCols = $(inputCols).filter { c =>
+ schema(c).dataType match {
+ case _: VectorUDT => true
+ case _ => false
+ }
+ }
+ val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid))
+
+ val featureAttributesMap = $(inputCols).map { c =>
val field = schema(c)
- val index = schema.fieldIndex(c)
field.dataType match {
case DoubleType =>
- val attr = Attribute.fromStructField(field)
- // If the input column doesn't have ML attribute, assume numeric.
- if (attr == UnresolvedAttribute) {
- Some(NumericAttribute.defaultAttr.withName(c))
- } else {
- Some(attr.withName(c))
+ val attribute = Attribute.fromStructField(field)
+ attribute match {
+ case UnresolvedAttribute =>
+ Seq(NumericAttribute.defaultAttr.withName(c))
+ case _ =>
+ Seq(attribute.withName(c))
}
case _: NumericType | BooleanType =>
// If the input column type is a compatible scalar type, assume numeric.
- Some(NumericAttribute.defaultAttr.withName(c))
+ Seq(NumericAttribute.defaultAttr.withName(c))
case _: VectorUDT =>
- val group = AttributeGroup.fromStructField(field)
- if (group.attributes.isDefined) {
- // If attributes are defined, copy them with updated names.
- group.attributes.get.zipWithIndex.map { case (attr, i) =>
+ val attributeGroup = AttributeGroup.fromStructField(field)
+ if (attributeGroup.attributes.isDefined) {
+ attributeGroup.attributes.get.zipWithIndex.toSeq.map { case (attr, i) =>
if (attr.name.isDefined) {
// TODO: Define a rigorous naming scheme.
attr.withName(c + "_" + attr.name.get)
@@ -85,18 +123,25 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
} else {
// Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
// from metadata, check the first row.
- val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
- Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i))
+ (0 until vectorColsLengths(c)).map { i =>
+ NumericAttribute.defaultAttr.withName(c + "_" + i)
+ }
}
case otherType =>
throw new SparkException(s"VectorAssembler does not support the $otherType type")
}
}
- val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
-
+ val featureAttributes = featureAttributesMap.flatten[Attribute].toArray
+ val lengths = featureAttributesMap.map(a => a.length).toArray
+ val metadata = new AttributeGroup($(outputCol), featureAttributes).toMetadata()
+ val (filteredDataset, keepInvalid) = $(handleInvalid) match {
+ case VectorAssembler.SKIP_INVALID => (dataset.na.drop($(inputCols)), false)
+ case VectorAssembler.KEEP_INVALID => (dataset, true)
+ case VectorAssembler.ERROR_INVALID => (dataset, false)
+ }
// Data transformation.
val assembleFunc = udf { r: Row =>
- VectorAssembler.assemble(r.toSeq: _*)
+ VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*)
}.asNondeterministic()
val args = $(inputCols).map { c =>
schema(c).dataType match {
@@ -106,7 +151,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
}
}
- dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
+ filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
}
@Since("1.4.0")
@@ -136,34 +181,117 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.6.0")
object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {
+ private[feature] val SKIP_INVALID: String = "skip"
+ private[feature] val ERROR_INVALID: String = "error"
+ private[feature] val KEEP_INVALID: String = "keep"
+ private[feature] val supportedHandleInvalids: Array[String] =
+ Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
+
+ /**
+ * Infers lengths of vector columns from the first row of the dataset.
+ * @param dataset the dataset
+ * @param columns names of the vector columns whose lengths need to be inferred
+ * @return map of column names to lengths
+ */
+ private[feature] def getVectorLengthsFromFirstRow(
+ dataset: Dataset[_],
+ columns: Seq[String]): Map[String, Int] = {
+ try {
+ val first_row = dataset.toDF().select(columns.map(col): _*).first()
+ columns.zip(first_row.toSeq).map {
+ case (c, x) => c -> x.asInstanceOf[Vector].size
+ }.toMap
+ } catch {
+ case e: NullPointerException => throw new NullPointerException(
+ s"""Encountered null value while inferring lengths from the first row. Consider using
+ |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
+ .stripMargin.replaceAll("\n", " ") + e.toString)
+ case e: NoSuchElementException => throw new NoSuchElementException(
+ s"""Encountered empty dataframe while inferring lengths from the first row. Consider using
+ |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
+ .stripMargin.replaceAll("\n", " ") + e.toString)
+ }
+ }
+
+ private[feature] def getLengths(
+ dataset: Dataset[_],
+ columns: Seq[String],
+ handleInvalid: String): Map[String, Int] = {
+ val groupSizes = columns.map { c =>
+ c -> AttributeGroup.fromStructField(dataset.schema(c)).size
+ }.toMap
+ val missingColumns = groupSizes.filter(_._2 == -1).keys.toSeq
+ val firstSizes = (missingColumns.nonEmpty, handleInvalid) match {
+ case (true, VectorAssembler.ERROR_INVALID) =>
+ getVectorLengthsFromFirstRow(dataset, missingColumns)
+ case (true, VectorAssembler.SKIP_INVALID) =>
+ getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns)
+ case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException(
+ s"""Can not infer column lengths with handleInvalid = "keep". Consider using VectorSizeHint
+ |to add metadata for columns: ${columns.mkString("[", ", ", "]")}."""
+ .stripMargin.replaceAll("\n", " "))
+ case (_, _) => Map.empty
+ }
+ groupSizes ++ firstSizes
+ }
+
+
@Since("1.6.0")
override def load(path: String): VectorAssembler = super.load(path)
- private[feature] def assemble(vv: Any*): Vector = {
- val indices = ArrayBuilder.make[Int]
- val values = ArrayBuilder.make[Double]
- var cur = 0
+ /**
+ * Returns a function that has the required information to assemble each row.
+ * @param lengths an array of lengths of input columns, whose size should be equal to the number
+ * of cells in the row (vv)
+ * @param keepInvalid whether to fill invalid entries (nulls) with NaN (true) or throw an error (false)
+ * @return a udf that can be applied on each row
+ */
+ private[feature] def assemble(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): Vector = {
+ val indices = mutable.ArrayBuilder.make[Int]
+ val values = mutable.ArrayBuilder.make[Double]
+ var featureIndex = 0
+
+ var inputColumnIndex = 0
vv.foreach {
case v: Double =>
- if (v != 0.0) {
- indices += cur
+ if (v.isNaN && !keepInvalid) {
+ throw new SparkException(
+ s"""Encountered NaN while assembling a row with handleInvalid = "error". Consider
+ |removing NaNs from dataset or using handleInvalid = "keep" or "skip"."""
+ .stripMargin)
+ } else if (v != 0.0) {
+ indices += featureIndex
values += v
}
- cur += 1
+ inputColumnIndex += 1
+ featureIndex += 1
case vec: Vector =>
vec.foreachActive { case (i, v) =>
if (v != 0.0) {
- indices += cur + i
+ indices += featureIndex + i
values += v
}
}
- cur += vec.size
+ inputColumnIndex += 1
+ featureIndex += vec.size
case null =>
- // TODO: output Double.NaN?
- throw new SparkException("Values to assemble cannot be null.")
+ if (keepInvalid) {
+ val length: Int = lengths(inputColumnIndex)
+ Array.range(0, length).foreach { i =>
+ indices += featureIndex + i
+ values += Double.NaN
+ }
+ inputColumnIndex += 1
+ featureIndex += length
+ } else {
+ throw new SparkException(
+ s"""Encountered null while assembling a row with handleInvalid = "keep". Consider
+ |removing nulls from dataset or using handleInvalid = "keep" or "skip"."""
+ .stripMargin)
+ }
case o =>
throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
}
- Vectors.sparse(cur, indices.result(), values.result()).compressed
+ Vectors.sparse(featureIndex, indices.result(), values.result()).compressed
}
}
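A short, hedged usage sketch of the new parameter (assumes an active `SparkSession` named `spark`; column names and data are illustrative):

```scala
import org.apache.spark.ml.feature.VectorAssembler

import spark.implicits._
val df = Seq[(java.lang.Double, Double)]((1.0, 2.0), (null, 3.0)).toDF("x", "y")

val assembler = new VectorAssembler()
  .setInputCols(Array("x", "y"))
  .setOutputCol("features")
  .setHandleInvalid("keep")  // the null in "x" becomes NaN; "skip" drops the row, "error" (default) throws

assembler.transform(df).show(false)
```

With scalar input columns the lengths are known, so no `VectorSizeHint` is needed; for vector columns without metadata the inference rules implemented above apply.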
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index eca065f7e775d..91fb24a268b8c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -18,12 +18,12 @@
package org.apache.spark.ml.feature
import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.{col, udf}
class VectorAssemblerSuite
@@ -31,30 +31,49 @@ class VectorAssemblerSuite
import testImplicits._
+ @transient var dfWithNullsAndNaNs: Dataset[_] = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ val sv = Vectors.sparse(2, Array(1), Array(3.0))
+ dfWithNullsAndNaNs = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, Long, String)](
+ (1, 2, 0.0, Vectors.dense(1.0, 2.0), "a", sv, 7L, null),
+ (2, 1, 0.0, null, "a", sv, 6L, null),
+ (3, 3, null, Vectors.dense(1.0, 2.0), "a", sv, 8L, null),
+ (4, 4, null, null, "a", sv, 9L, null),
+ (5, 5, java.lang.Double.NaN, Vectors.dense(1.0, 2.0), "a", sv, 7L, null),
+ (6, 6, java.lang.Double.NaN, null, "a", sv, 8L, null))
+ .toDF("id1", "id2", "x", "y", "name", "z", "n", "nulls")
+ }
+
test("params") {
ParamsSuite.checkParams(new VectorAssembler)
}
test("assemble") {
import org.apache.spark.ml.feature.VectorAssembler.assemble
- assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
- assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0)))
+ assert(assemble(Array(1), keepInvalid = true)(0.0)
+ === Vectors.sparse(1, Array.empty, Array.empty))
+ assert(assemble(Array(1, 1), keepInvalid = true)(0.0, 1.0)
+ === Vectors.sparse(2, Array(1), Array(1.0)))
val dv = Vectors.dense(2.0, 0.0)
- assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
+ assert(assemble(Array(1, 2, 1), keepInvalid = true)(0.0, dv, 1.0) ===
+ Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0))
- assert(assemble(0.0, dv, 1.0, sv) ===
+ assert(assemble(Array(1, 2, 1, 2), keepInvalid = true)(0.0, dv, 1.0, sv) ===
Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0)))
- for (v <- Seq(1, "a", null)) {
- intercept[SparkException](assemble(v))
- intercept[SparkException](assemble(1.0, v))
+ for (v <- Seq(1, "a")) {
+ intercept[SparkException](assemble(Array(1), keepInvalid = true)(v))
+ intercept[SparkException](assemble(Array(1, 1), keepInvalid = true)(1.0, v))
}
}
test("assemble should compress vectors") {
import org.apache.spark.ml.feature.VectorAssembler.assemble
- val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0))
+ val v1 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(0.0, 0.0, 0.0, Vectors.dense(4.0))
assert(v1.isInstanceOf[SparseVector])
- val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0)))
+ val sv = Vectors.sparse(1, Array(0), Array(4.0))
+ val v2 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(1.0, 2.0, 3.0, sv)
assert(v2.isInstanceOf[DenseVector])
}
@@ -147,4 +166,94 @@ class VectorAssemblerSuite
.filter(vectorUDF($"features") > 1)
.count() == 1)
}
+
+ test("assemble should keep nulls when keepInvalid is true") {
+ import org.apache.spark.ml.feature.VectorAssembler.assemble
+ assert(assemble(Array(1, 1), keepInvalid = true)(1.0, null) === Vectors.dense(1.0, Double.NaN))
+ assert(assemble(Array(1, 2), keepInvalid = true)(1.0, null)
+ === Vectors.dense(1.0, Double.NaN, Double.NaN))
+ assert(assemble(Array(1), keepInvalid = true)(null) === Vectors.dense(Double.NaN))
+ assert(assemble(Array(2), keepInvalid = true)(null) === Vectors.dense(Double.NaN, Double.NaN))
+ }
+
+ test("assemble should throw errors when keepInvalid is false") {
+ import org.apache.spark.ml.feature.VectorAssembler.assemble
+ intercept[SparkException](assemble(Array(1, 1), keepInvalid = false)(1.0, null))
+ intercept[SparkException](assemble(Array(1, 2), keepInvalid = false)(1.0, null))
+ intercept[SparkException](assemble(Array(1), keepInvalid = false)(null))
+ intercept[SparkException](assemble(Array(2), keepInvalid = false)(null))
+ }
+
+ test("get lengths functions") {
+ import org.apache.spark.ml.feature.VectorAssembler._
+ val df = dfWithNullsAndNaNs
+ assert(getVectorLengthsFromFirstRow(df, Seq("y")) === Map("y" -> 2))
+ assert(intercept[NullPointerException](getVectorLengthsFromFirstRow(df.sort("id2"), Seq("y")))
+ .getMessage.contains("VectorSizeHint"))
+ assert(intercept[NoSuchElementException](getVectorLengthsFromFirstRow(df.filter("id1 > 6"),
+ Seq("y"))).getMessage.contains("VectorSizeHint"))
+
+ assert(getLengths(df.sort("id2"), Seq("y"), SKIP_INVALID).exists(_ == "y" -> 2))
+ assert(intercept[NullPointerException](getLengths(df.sort("id2"), Seq("y"), ERROR_INVALID))
+ .getMessage.contains("VectorSizeHint"))
+ assert(intercept[RuntimeException](getLengths(df.sort("id2"), Seq("y"), KEEP_INVALID))
+ .getMessage.contains("VectorSizeHint"))
+ }
+
+ test("Handle Invalid should behave properly") {
+ val assembler = new VectorAssembler()
+ .setInputCols(Array("x", "y", "z", "n"))
+ .setOutputCol("features")
+
+ def runWithMetadata(mode: String, additional_filter: String = "true"): Dataset[_] = {
+ val attributeY = new AttributeGroup("y", 2)
+ val attributeZ = new AttributeGroup(
+ "z",
+ Array[Attribute](
+ NumericAttribute.defaultAttr.withName("foo"),
+ NumericAttribute.defaultAttr.withName("bar")))
+ val dfWithMetadata = dfWithNullsAndNaNs.withColumn("y", col("y"), attributeY.toMetadata())
+ .withColumn("z", col("z"), attributeZ.toMetadata()).filter(additional_filter)
+ val output = assembler.setHandleInvalid(mode).transform(dfWithMetadata)
+ output.collect()
+ output
+ }
+
+ def runWithFirstRow(mode: String): Dataset[_] = {
+ val output = assembler.setHandleInvalid(mode).transform(dfWithNullsAndNaNs)
+ output.collect()
+ output
+ }
+
+ def runWithAllNullVectors(mode: String): Dataset[_] = {
+ val output = assembler.setHandleInvalid(mode)
+ .transform(dfWithNullsAndNaNs.filter("0 == id1 % 2"))
+ output.collect()
+ output
+ }
+
+ // behavior when vector size hint is given
+ assert(runWithMetadata("keep").count() == 6, "should keep all rows")
+ assert(runWithMetadata("skip").count() == 1, "should skip rows with nulls")
+ // should throw error with nulls
+ intercept[SparkException](runWithMetadata("error"))
+ // should throw error with NaNs
+ intercept[SparkException](runWithMetadata("error", additional_filter = "id1 > 4"))
+
+ // behavior when first row has information
+ assert(intercept[RuntimeException](runWithFirstRow("keep").count())
+ .getMessage.contains("VectorSizeHint"), "should suggest to use metadata")
+ assert(runWithFirstRow("skip").count() == 1, "should infer size and skip rows with nulls")
+ intercept[SparkException](runWithFirstRow("error"))
+
+ // behavior when vector column is all null
+ assert(intercept[RuntimeException](runWithAllNullVectors("skip"))
+ .getMessage.contains("VectorSizeHint"), "should suggest to use metadata")
+ assert(intercept[NullPointerException](runWithAllNullVectors("error"))
+ .getMessage.contains("VectorSizeHint"), "should suggest to use metadata")
+
+ // behavior when scalar column is all null
+ assert(runWithMetadata("keep", additional_filter = "id1 > 2").count() == 4)
+ }
+
}
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 78c071da1b7a0..4c0787f767c0b 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -2094,6 +2094,11 @@ def test_java_params(self):
# NOTE: disable check_params_exist until there is parity with Scala API
ParamTests.check_params(self, cls(), check_params_exist=False)
+ # Additional classes that need explicit construction
+ from pyspark.ml.feature import CountVectorizerModel
+ ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'),
+ check_params_exist=False)
+
def _squared_distance(a, b):
if isinstance(a, Vector):
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
index d9332765f5f9d..5b9af952c7ca4 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
@@ -91,6 +91,12 @@ private[spark] object Config extends Logging {
.stringConf
.createOptional
+ val KUBERNETES_EXECUTOR_REQUEST_CORES =
+ ConfigBuilder("spark.kubernetes.executor.request.cores")
+ .doc("Specify the cpu request for each executor pod")
+ .stringConf
+ .createOptional
+
val KUBERNETES_DRIVER_POD_NAME =
ConfigBuilder("spark.kubernetes.driver.pod.name")
.doc("Name of the driver pod.")
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala
index 347c4d2d66826..b811db324108c 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala
@@ -93,9 +93,6 @@ private[spark] class BasicDriverConfigurationStep(
.withAmount(driverCpuCores)
.build()
val driverMemoryQuantity = new QuantityBuilder(false)
- .withAmount(s"${driverMemoryMiB}Mi")
- .build()
- val driverMemoryLimitQuantity = new QuantityBuilder(false)
.withAmount(s"${driverMemoryWithOverheadMiB}Mi")
.build()
val maybeCpuLimitQuantity = driverLimitCores.map { limitCores =>
@@ -117,7 +114,7 @@ private[spark] class BasicDriverConfigurationStep(
.withNewResources()
.addToRequests("cpu", driverCpuQuantity)
.addToRequests("memory", driverMemoryQuantity)
- .addToLimits("memory", driverMemoryLimitQuantity)
+ .addToLimits("memory", driverMemoryQuantity)
.addToLimits(maybeCpuLimitQuantity.toMap.asJava)
.endResources()
.addToArgs("driver")
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala
index 9d09dbda0c990..d9f20e18cefe1 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala
@@ -85,7 +85,12 @@ private[spark] class ExecutorPodFactory(
MEMORY_OVERHEAD_MIN_MIB))
private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB
- private val executorCores = sparkConf.getDouble("spark.executor.cores", 1)
+ private val executorCores = sparkConf.getInt("spark.executor.cores", 1)
+ private val executorCoresRequest = if (sparkConf.contains(KUBERNETES_EXECUTOR_REQUEST_CORES)) {
+ sparkConf.get(KUBERNETES_EXECUTOR_REQUEST_CORES).get
+ } else {
+ executorCores.toString
+ }
private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES)
/**
@@ -110,13 +115,10 @@ private[spark] class ExecutorPodFactory(
SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++
executorLabels
val executorMemoryQuantity = new QuantityBuilder(false)
- .withAmount(s"${executorMemoryMiB}Mi")
- .build()
- val executorMemoryLimitQuantity = new QuantityBuilder(false)
.withAmount(s"${executorMemoryWithOverhead}Mi")
.build()
val executorCpuQuantity = new QuantityBuilder(false)
- .withAmount(executorCores.toString)
+ .withAmount(executorCoresRequest)
.build()
val executorExtraClasspathEnv = executorExtraClasspath.map { cp =>
new EnvVarBuilder()
@@ -135,8 +137,7 @@ private[spark] class ExecutorPodFactory(
}.getOrElse(Seq.empty[EnvVar])
val executorEnv = (Seq(
(ENV_DRIVER_URL, driverUrl),
- // Executor backend expects integral value for executor cores, so round it up to an int.
- (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString),
+ (ENV_EXECUTOR_CORES, executorCores.toString),
(ENV_EXECUTOR_MEMORY, executorMemoryString),
(ENV_APPLICATION_ID, applicationId),
// This is to set the SPARK_CONF_DIR to be /opt/spark/conf
@@ -169,7 +170,7 @@ private[spark] class ExecutorPodFactory(
.withImagePullPolicy(imagePullPolicy)
.withNewResources()
.addToRequests("memory", executorMemoryQuantity)
- .addToLimits("memory", executorMemoryLimitQuantity)
+ .addToLimits("memory", executorMemoryQuantity)
.addToRequests("cpu", executorCpuQuantity)
.endResources()
.addAllToEnv(executorEnv.asJava)
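The precedence implemented above can be pictured with plain `SparkConf` calls (a hedged illustration, not the factory code itself):

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.executor.cores", "2")
  .set("spark.kubernetes.executor.request.cores", "500m")

// spark.kubernetes.executor.request.cores wins for the pod cpu request when set;
// otherwise the integer spark.executor.cores value is used as the request.
val executorCoresRequest = conf.getOption("spark.kubernetes.executor.request.cores")
  .getOrElse(conf.getInt("spark.executor.cores", 1).toString)
// executorCoresRequest == "500m"; ENV_EXECUTOR_CORES still carries spark.executor.cores for task slots
```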
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala
index ce068531c7673..e59c6d28a8cc2 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala
@@ -91,7 +91,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite {
val resourceRequirements = preparedDriverSpec.driverContainer.getResources
val requests = resourceRequirements.getRequests.asScala
assert(requests("cpu").getAmount === "2")
- assert(requests("memory").getAmount === "256Mi")
+ assert(requests("memory").getAmount === "456Mi")
val limits = resourceRequirements.getLimits.asScala
assert(limits("memory").getAmount === "456Mi")
assert(limits("cpu").getAmount === "4")
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala
index 50b97d923ca2d..e8996216cefa7 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala
@@ -67,12 +67,14 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef
assert(executor.getMetadata.getLabels.size() === 3)
assert(executor.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) === "1")
- // There is exactly 1 container with no volume mounts and default memory limits.
- // Default memory limit is 1024M + 384M (minimum overhead constant).
+ // There is exactly 1 container with no volume mounts and default memory limits and requests.
+ // Default memory limit/request is 1024M + 384M (minimum overhead constant).
assert(executor.getSpec.getContainers.size() === 1)
assert(executor.getSpec.getContainers.get(0).getImage === executorImage)
assert(executor.getSpec.getContainers.get(0).getVolumeMounts.isEmpty)
assert(executor.getSpec.getContainers.get(0).getResources.getLimits.size() === 1)
+ assert(executor.getSpec.getContainers.get(0).getResources
+ .getRequests.get("memory").getAmount === "1408Mi")
assert(executor.getSpec.getContainers.get(0).getResources
.getLimits.get("memory").getAmount === "1408Mi")
@@ -84,6 +86,33 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef
checkOwnerReferences(executor, driverPodUid)
}
+ test("executor core request specification") {
+ var factory = new ExecutorPodFactory(baseConf, None, None)
+ var executor = factory.createExecutorPod(
+ "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]())
+ assert(executor.getSpec.getContainers.size() === 1)
+ assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount
+ === "1")
+
+ val conf = baseConf.clone()
+
+ conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "0.1")
+ factory = new ExecutorPodFactory(conf, None, None)
+ executor = factory.createExecutorPod(
+ "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]())
+ assert(executor.getSpec.getContainers.size() === 1)
+ assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount
+ === "0.1")
+
+ conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m")
+ factory = new ExecutorPodFactory(conf, None, None)
+ executor = factory.createExecutorPod(
+ "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]())
+ assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount
+ === "100m")
+ }
+
test("executor pod hostnames get truncated to 63 characters") {
val conf = baseConf.clone()
conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX,
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
index 259976118c12f..537ef244b7e81 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
@@ -30,25 +30,21 @@
* this class per writing program, so that the memory segment/data buffer can be reused. Note that
* for each incoming record, we should call `reset` of BufferHolder instance before write the record
* and reuse the data buffer.
- *
- * Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update
- * the size of the result row, after writing a record to the buffer. However, we can skip this step
- * if the fields of row are all fixed-length, as the size of result row is also fixed.
*/
-public class BufferHolder {
+final class BufferHolder {
private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH;
- public byte[] buffer;
- public int cursor = Platform.BYTE_ARRAY_OFFSET;
+ private byte[] buffer;
+ private int cursor = Platform.BYTE_ARRAY_OFFSET;
private final UnsafeRow row;
private final int fixedSize;
- public BufferHolder(UnsafeRow row) {
+ BufferHolder(UnsafeRow row) {
this(row, 64);
}
- public BufferHolder(UnsafeRow row, int initialSize) {
+ BufferHolder(UnsafeRow row, int initialSize) {
int bitsetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields());
if (row.numFields() > (ARRAY_MAX - initialSize - bitsetWidthInBytes) / 8) {
throw new UnsupportedOperationException(
@@ -64,7 +60,7 @@ public BufferHolder(UnsafeRow row, int initialSize) {
/**
* Grows the buffer by at least neededSize and points the row to the buffer.
*/
- public void grow(int neededSize) {
+ void grow(int neededSize) {
if (neededSize > ARRAY_MAX - totalSize()) {
throw new UnsupportedOperationException(
"Cannot grow BufferHolder by size " + neededSize + " because the size after growing " +
@@ -86,11 +82,23 @@ public void grow(int neededSize) {
}
}
- public void reset() {
+ byte[] getBuffer() {
+ return buffer;
+ }
+
+ int getCursor() {
+ return cursor;
+ }
+
+ void increaseCursor(int val) {
+ cursor += val;
+ }
+
+ void reset() {
cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize;
}
- public int totalSize() {
+ int totalSize() {
return cursor - Platform.BYTE_ARRAY_OFFSET;
}
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
index 82cd1b24607e1..a78dd970d23e4 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -21,8 +21,6 @@
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
-import org.apache.spark.unsafe.types.CalendarInterval;
-import org.apache.spark.unsafe.types.UTF8String;
import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes;
@@ -32,14 +30,12 @@
*/
public final class UnsafeArrayWriter extends UnsafeWriter {
- private BufferHolder holder;
-
- // The offset of the global buffer where we start to write this array.
- private int startingOffset;
-
// The number of elements in this array
private int numElements;
+ // The element size in this array
+ private int elementSize;
+
private int headerInBytes;
private void assertIndexIsValid(int index) {
@@ -47,13 +43,17 @@ private void assertIndexIsValid(int index) {
assert index < numElements : "index (" + index + ") should < " + numElements;
}
- public void initialize(BufferHolder holder, int numElements, int elementSize) {
+ public UnsafeArrayWriter(UnsafeWriter writer, int elementSize) {
+ super(writer.getBufferHolder());
+ this.elementSize = elementSize;
+ }
+
+ public void initialize(int numElements) {
// We need 8 bytes to store numElements in header
this.numElements = numElements;
this.headerInBytes = calculateHeaderPortionInBytes(numElements);
- this.holder = holder;
- this.startingOffset = holder.cursor;
+ this.startingOffset = cursor();
// Grows the global buffer ahead for header and fixed size data.
int fixedPartInBytes =
@@ -61,112 +61,92 @@ public void initialize(BufferHolder holder, int numElements, int elementSize) {
holder.grow(headerInBytes + fixedPartInBytes);
// Write numElements and clear out null bits to header
- Platform.putLong(holder.buffer, startingOffset, numElements);
+ Platform.putLong(getBuffer(), startingOffset, numElements);
for (int i = 8; i < headerInBytes; i += 8) {
- Platform.putLong(holder.buffer, startingOffset + i, 0L);
+ Platform.putLong(getBuffer(), startingOffset + i, 0L);
}
// fill 0 into reminder part of 8-bytes alignment in unsafe array
for (int i = elementSize * numElements; i < fixedPartInBytes; i++) {
- Platform.putByte(holder.buffer, startingOffset + headerInBytes + i, (byte) 0);
+ Platform.putByte(getBuffer(), startingOffset + headerInBytes + i, (byte) 0);
}
- holder.cursor += (headerInBytes + fixedPartInBytes);
+ increaseCursor(headerInBytes + fixedPartInBytes);
}
- private void zeroOutPaddingBytes(int numBytes) {
- if ((numBytes & 0x07) > 0) {
- Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
- }
- }
-
- private long getElementOffset(int ordinal, int elementSize) {
+ private long getElementOffset(int ordinal) {
return startingOffset + headerInBytes + ordinal * elementSize;
}
- public void setOffsetAndSize(int ordinal, int currentCursor, int size) {
- assertIndexIsValid(ordinal);
- final long relativeOffset = currentCursor - startingOffset;
- final long offsetAndSize = (relativeOffset << 32) | (long)size;
-
- write(ordinal, offsetAndSize);
- }
-
private void setNullBit(int ordinal) {
assertIndexIsValid(ordinal);
- BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal);
+ BitSetMethods.set(getBuffer(), startingOffset + 8, ordinal);
}
public void setNull1Bytes(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
- Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0);
+ writeByte(getElementOffset(ordinal), (byte)0);
}
public void setNull2Bytes(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
- Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0);
+ writeShort(getElementOffset(ordinal), (short)0);
}
public void setNull4Bytes(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
- Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), 0);
+ writeInt(getElementOffset(ordinal), 0);
}
public void setNull8Bytes(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
- Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0);
+ writeLong(getElementOffset(ordinal), 0);
}
public void setNull(int ordinal) { setNull8Bytes(ordinal); }
public void write(int ordinal, boolean value) {
assertIndexIsValid(ordinal);
- Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), value);
+ writeBoolean(getElementOffset(ordinal), value);
}
public void write(int ordinal, byte value) {
assertIndexIsValid(ordinal);
- Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), value);
+ writeByte(getElementOffset(ordinal), value);
}
public void write(int ordinal, short value) {
assertIndexIsValid(ordinal);
- Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), value);
+ writeShort(getElementOffset(ordinal), value);
}
public void write(int ordinal, int value) {
assertIndexIsValid(ordinal);
- Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), value);
+ writeInt(getElementOffset(ordinal), value);
}
public void write(int ordinal, long value) {
assertIndexIsValid(ordinal);
- Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), value);
+ writeLong(getElementOffset(ordinal), value);
}
public void write(int ordinal, float value) {
- if (Float.isNaN(value)) {
- value = Float.NaN;
- }
assertIndexIsValid(ordinal);
- Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), value);
+ writeFloat(getElementOffset(ordinal), value);
}
public void write(int ordinal, double value) {
- if (Double.isNaN(value)) {
- value = Double.NaN;
- }
assertIndexIsValid(ordinal);
- Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), value);
+ writeDouble(getElementOffset(ordinal), value);
}
public void write(int ordinal, Decimal input, int precision, int scale) {
// make sure Decimal object has the same scale as DecimalType
assertIndexIsValid(ordinal);
- if (input.changePrecision(precision, scale)) {
+ if (input != null && input.changePrecision(precision, scale)) {
if (precision <= Decimal.MAX_LONG_DIGITS()) {
write(ordinal, input.toUnscaledLong());
} else {
@@ -180,65 +160,14 @@ public void write(int ordinal, Decimal input, int precision, int scale) {
// Write the bytes to the variable length portion.
Platform.copyMemory(
- bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
- setOffsetAndSize(ordinal, holder.cursor, numBytes);
+ bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes);
+ setOffsetAndSize(ordinal, numBytes);
// move the cursor forward with 8-bytes boundary
- holder.cursor += roundedSize;
+ increaseCursor(roundedSize);
}
} else {
setNull(ordinal);
}
}
-
- public void write(int ordinal, UTF8String input) {
- final int numBytes = input.numBytes();
- final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
-
- // grow the global buffer before writing data.
- holder.grow(roundedSize);
-
- zeroOutPaddingBytes(numBytes);
-
- // Write the bytes to the variable length portion.
- input.writeToMemory(holder.buffer, holder.cursor);
-
- setOffsetAndSize(ordinal, holder.cursor, numBytes);
-
- // move the cursor forward.
- holder.cursor += roundedSize;
- }
-
- public void write(int ordinal, byte[] input) {
- final int numBytes = input.length;
- final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length);
-
- // grow the global buffer before writing data.
- holder.grow(roundedSize);
-
- zeroOutPaddingBytes(numBytes);
-
- // Write the bytes to the variable length portion.
- Platform.copyMemory(
- input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
-
- setOffsetAndSize(ordinal, holder.cursor, numBytes);
-
- // move the cursor forward.
- holder.cursor += roundedSize;
- }
-
- public void write(int ordinal, CalendarInterval input) {
- // grow the global buffer before writing data.
- holder.grow(16);
-
- // Write the months and microseconds fields of Interval to the variable length portion.
- Platform.putLong(holder.buffer, holder.cursor, input.months);
- Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
-
- setOffsetAndSize(ordinal, holder.cursor, 16);
-
- // move the cursor forward.
- holder.cursor += 16;
- }
}
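Note: the variable-length write methods removed above (UTF8String, byte[], CalendarInterval) are not gone; they are consolidated into the shared UnsafeWriter base class later in this patch. For context, variable-length values are padded up to an 8-byte word boundary before the cursor is advanced. A small Scala sketch of the rounding that ByteArrayMethods.roundNumberOfBytesToNearestWord performs (illustrative only, not part of the patch):

    // Round a byte count up to the next multiple of 8 (one word).
    def roundToWord(numBytes: Int): Int = (numBytes + 7) & ~7

    assert(roundToWord(5) == 8)    // 5 data bytes plus 3 padding bytes
    assert(roundToWord(16) == 16)  // already word-aligned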
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
index 2620bbcfb87a2..71c49d8ed0177 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -20,10 +20,7 @@
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
-import org.apache.spark.unsafe.types.CalendarInterval;
-import org.apache.spark.unsafe.types.UTF8String;
/**
* A helper class to write data into global row buffer using `UnsafeRow` format.
@@ -31,7 +28,7 @@
* It will remember the offset of row buffer which it starts to write, and move the cursor of row
* buffer while writing. If new data(can be the input record if this is the outermost writer, or
* nested struct if this is an inner writer) comes, the starting cursor of row buffer may be
- * changed, so we need to call `UnsafeRowWriter.reset` before writing, to update the
+ * changed, so we need to call `UnsafeRowWriter.resetRowWriter` before writing, to update the
* `startingOffset` and clear out null bits.
*
* Note that if this is the outermost writer, which means we will always write from the very
@@ -40,29 +37,58 @@
*/
public final class UnsafeRowWriter extends UnsafeWriter {
- private final BufferHolder holder;
- // The offset of the global buffer where we start to write this row.
- private int startingOffset;
+ private final UnsafeRow row;
+
private final int nullBitsSize;
private final int fixedSize;
- public UnsafeRowWriter(BufferHolder holder, int numFields) {
- this.holder = holder;
+ public UnsafeRowWriter(int numFields) {
+ this(new UnsafeRow(numFields));
+ }
+
+ public UnsafeRowWriter(int numFields, int initialBufferSize) {
+ this(new UnsafeRow(numFields), initialBufferSize);
+ }
+
+ public UnsafeRowWriter(UnsafeWriter writer, int numFields) {
+ this(null, writer.getBufferHolder(), numFields);
+ }
+
+ private UnsafeRowWriter(UnsafeRow row) {
+ this(row, new BufferHolder(row), row.numFields());
+ }
+
+ private UnsafeRowWriter(UnsafeRow row, int initialBufferSize) {
+ this(row, new BufferHolder(row, initialBufferSize), row.numFields());
+ }
+
+ private UnsafeRowWriter(UnsafeRow row, BufferHolder holder, int numFields) {
+ super(holder);
+ this.row = row;
this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields);
this.fixedSize = nullBitsSize + 8 * numFields;
- this.startingOffset = holder.cursor;
+ this.startingOffset = cursor();
+ }
+
+ /**
+ * Updates the total size of the UnsafeRow using the size collected by the BufferHolder, and
+ * returns the UnsafeRow created in the constructor.
+ */
+ public UnsafeRow getRow() {
+ row.setTotalSize(totalSize());
+ return row;
}
/**
* Resets the `startingOffset` according to the current cursor of row buffer, and clear out null
* bits. This should be called before we write a new nested struct to the row buffer.
*/
- public void reset() {
- this.startingOffset = holder.cursor;
+ public void resetRowWriter() {
+ this.startingOffset = cursor();
// grow the global buffer to make sure it has enough space to write fixed-length data.
- holder.grow(fixedSize);
- holder.cursor += fixedSize;
+ grow(fixedSize);
+ increaseCursor(fixedSize);
zeroOutNullBytes();
}
@@ -72,25 +98,17 @@ public void reset() {
*/
public void zeroOutNullBytes() {
for (int i = 0; i < nullBitsSize; i += 8) {
- Platform.putLong(holder.buffer, startingOffset + i, 0L);
- }
- }
-
- private void zeroOutPaddingBytes(int numBytes) {
- if ((numBytes & 0x07) > 0) {
- Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
+ Platform.putLong(getBuffer(), startingOffset + i, 0L);
}
}
- public BufferHolder holder() { return holder; }
-
public boolean isNullAt(int ordinal) {
- return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal);
+ return BitSetMethods.isSet(getBuffer(), startingOffset, ordinal);
}
public void setNullAt(int ordinal) {
- BitSetMethods.set(holder.buffer, startingOffset, ordinal);
- Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L);
+ BitSetMethods.set(getBuffer(), startingOffset, ordinal);
+ write(ordinal, 0L);
}
@Override
@@ -117,67 +135,49 @@ public long getFieldOffset(int ordinal) {
return startingOffset + nullBitsSize + 8 * ordinal;
}
- public void setOffsetAndSize(int ordinal, int size) {
- setOffsetAndSize(ordinal, holder.cursor, size);
- }
-
- public void setOffsetAndSize(int ordinal, int currentCursor, int size) {
- final long relativeOffset = currentCursor - startingOffset;
- final long fieldOffset = getFieldOffset(ordinal);
- final long offsetAndSize = (relativeOffset << 32) | (long) size;
-
- Platform.putLong(holder.buffer, fieldOffset, offsetAndSize);
- }
-
public void write(int ordinal, boolean value) {
final long offset = getFieldOffset(ordinal);
- Platform.putLong(holder.buffer, offset, 0L);
- Platform.putBoolean(holder.buffer, offset, value);
+ writeLong(offset, 0L);
+ writeBoolean(offset, value);
}
public void write(int ordinal, byte value) {
final long offset = getFieldOffset(ordinal);
- Platform.putLong(holder.buffer, offset, 0L);
- Platform.putByte(holder.buffer, offset, value);
+ writeLong(offset, 0L);
+ writeByte(offset, value);
}
public void write(int ordinal, short value) {
final long offset = getFieldOffset(ordinal);
- Platform.putLong(holder.buffer, offset, 0L);
- Platform.putShort(holder.buffer, offset, value);
+ writeLong(offset, 0L);
+ writeShort(offset, value);
}
public void write(int ordinal, int value) {
final long offset = getFieldOffset(ordinal);
- Platform.putLong(holder.buffer, offset, 0L);
- Platform.putInt(holder.buffer, offset, value);
+ writeLong(offset, 0L);
+ writeInt(offset, value);
}
public void write(int ordinal, long value) {
- Platform.putLong(holder.buffer, getFieldOffset(ordinal), value);
+ writeLong(getFieldOffset(ordinal), value);
}
public void write(int ordinal, float value) {
- if (Float.isNaN(value)) {
- value = Float.NaN;
- }
final long offset = getFieldOffset(ordinal);
- Platform.putLong(holder.buffer, offset, 0L);
- Platform.putFloat(holder.buffer, offset, value);
+ writeLong(offset, 0);
+ writeFloat(offset, value);
}
public void write(int ordinal, double value) {
- if (Double.isNaN(value)) {
- value = Double.NaN;
- }
- Platform.putDouble(holder.buffer, getFieldOffset(ordinal), value);
+ writeDouble(getFieldOffset(ordinal), value);
}
public void write(int ordinal, Decimal input, int precision, int scale) {
if (precision <= Decimal.MAX_LONG_DIGITS()) {
// make sure Decimal object has the same scale as DecimalType
- if (input.changePrecision(precision, scale)) {
- Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong());
+ if (input != null && input.changePrecision(precision, scale)) {
+ write(ordinal, input.toUnscaledLong());
} else {
setNullAt(ordinal);
}
@@ -185,82 +185,31 @@ public void write(int ordinal, Decimal input, int precision, int scale) {
// grow the global buffer before writing data.
holder.grow(16);
- // zero-out the bytes
- Platform.putLong(holder.buffer, holder.cursor, 0L);
- Platform.putLong(holder.buffer, holder.cursor + 8, 0L);
-
// Make sure Decimal object has the same scale as DecimalType.
// Note that we may pass in null Decimal object to set null for it.
if (input == null || !input.changePrecision(precision, scale)) {
- BitSetMethods.set(holder.buffer, startingOffset, ordinal);
+ // zero-out the bytes
+ Platform.putLong(getBuffer(), cursor(), 0L);
+ Platform.putLong(getBuffer(), cursor() + 8, 0L);
+
+ BitSetMethods.set(getBuffer(), startingOffset, ordinal);
// keep the offset for future update
setOffsetAndSize(ordinal, 0);
} else {
final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
- assert bytes.length <= 16;
+ final int numBytes = bytes.length;
+ assert numBytes <= 16;
+
+ zeroOutPaddingBytes(numBytes);
// Write the bytes to the variable length portion.
Platform.copyMemory(
- bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
+ bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes);
setOffsetAndSize(ordinal, bytes.length);
}
// move the cursor forward.
- holder.cursor += 16;
+ increaseCursor(16);
}
}
-
- public void write(int ordinal, UTF8String input) {
- final int numBytes = input.numBytes();
- final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
-
- // grow the global buffer before writing data.
- holder.grow(roundedSize);
-
- zeroOutPaddingBytes(numBytes);
-
- // Write the bytes to the variable length portion.
- input.writeToMemory(holder.buffer, holder.cursor);
-
- setOffsetAndSize(ordinal, numBytes);
-
- // move the cursor forward.
- holder.cursor += roundedSize;
- }
-
- public void write(int ordinal, byte[] input) {
- write(ordinal, input, 0, input.length);
- }
-
- public void write(int ordinal, byte[] input, int offset, int numBytes) {
- final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
-
- // grow the global buffer before writing data.
- holder.grow(roundedSize);
-
- zeroOutPaddingBytes(numBytes);
-
- // Write the bytes to the variable length portion.
- Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET + offset,
- holder.buffer, holder.cursor, numBytes);
-
- setOffsetAndSize(ordinal, numBytes);
-
- // move the cursor forward.
- holder.cursor += roundedSize;
- }
-
- public void write(int ordinal, CalendarInterval input) {
- // grow the global buffer before writing data.
- holder.grow(16);
-
- // Write the months and microseconds fields of Interval to the variable length portion.
- Platform.putLong(holder.buffer, holder.cursor, input.months);
- Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
-
- setOffsetAndSize(ordinal, 16);
-
- // move the cursor forward.
- holder.cursor += 16;
- }
}
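With the new constructors, callers no longer wire up an UnsafeRow and a BufferHolder by hand. A minimal usage sketch of the updated API (Scala for brevity; it mirrors the revised RowBasedKeyValueBatchSuite later in this patch):

    import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
    import org.apache.spark.unsafe.types.UTF8String

    val writer = new UnsafeRowWriter(2)              // 2 fields, default initial buffer
    writer.reset()                                   // rewind the buffer for a new row
    writer.write(0, 42L)                             // fixed-length field at ordinal 0
    writer.write(1, UTF8String.fromString("spark"))  // variable-length field at ordinal 1
    val row = writer.getRow()                        // sets the total size and returns the row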
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
index c94b5c7a367ef..de0eb6dbb76be 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions.codegen;
import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
@@ -24,10 +26,73 @@
* Base class for writing Unsafe* structures.
*/
public abstract class UnsafeWriter {
+ // The internal buffer holder that backs this writer.
+ protected final BufferHolder holder;
+
+ // The offset of the global buffer where we start to write this structure.
+ protected int startingOffset;
+
+ protected UnsafeWriter(BufferHolder holder) {
+ this.holder = holder;
+ }
+
+ /**
+ * Accessor methods that delegate to the underlying BufferHolder.
+ */
+ public final BufferHolder getBufferHolder() {
+ return holder;
+ }
+
+ public final byte[] getBuffer() {
+ return holder.getBuffer();
+ }
+
+ public final void reset() {
+ holder.reset();
+ }
+
+ public final int totalSize() {
+ return holder.totalSize();
+ }
+
+ public final void grow(int neededSize) {
+ holder.grow(neededSize);
+ }
+
+ public final int cursor() {
+ return holder.getCursor();
+ }
+
+ public final void increaseCursor(int val) {
+ holder.increaseCursor(val);
+ }
+
+ public final void setOffsetAndSizeFromPreviousCursor(int ordinal, int previousCursor) {
+ setOffsetAndSize(ordinal, previousCursor, cursor() - previousCursor);
+ }
+
+ protected void setOffsetAndSize(int ordinal, int size) {
+ setOffsetAndSize(ordinal, cursor(), size);
+ }
+
+ protected void setOffsetAndSize(int ordinal, int currentCursor, int size) {
+ final long relativeOffset = currentCursor - startingOffset;
+ final long offsetAndSize = (relativeOffset << 32) | (long)size;
+
+ write(ordinal, offsetAndSize);
+ }
+
+ protected final void zeroOutPaddingBytes(int numBytes) {
+ if ((numBytes & 0x07) > 0) {
+ Platform.putLong(getBuffer(), cursor() + ((numBytes >> 3) << 3), 0L);
+ }
+ }
+
public abstract void setNull1Bytes(int ordinal);
public abstract void setNull2Bytes(int ordinal);
public abstract void setNull4Bytes(int ordinal);
public abstract void setNull8Bytes(int ordinal);
+
public abstract void write(int ordinal, boolean value);
public abstract void write(int ordinal, byte value);
public abstract void write(int ordinal, short value);
@@ -36,8 +101,92 @@ public abstract class UnsafeWriter {
public abstract void write(int ordinal, float value);
public abstract void write(int ordinal, double value);
public abstract void write(int ordinal, Decimal input, int precision, int scale);
- public abstract void write(int ordinal, UTF8String input);
- public abstract void write(int ordinal, byte[] input);
- public abstract void write(int ordinal, CalendarInterval input);
- public abstract void setOffsetAndSize(int ordinal, int currentCursor, int size);
+
+ public final void write(int ordinal, UTF8String input) {
+ final int numBytes = input.numBytes();
+ final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+
+ // grow the global buffer before writing data.
+ grow(roundedSize);
+
+ zeroOutPaddingBytes(numBytes);
+
+ // Write the bytes to the variable length portion.
+ input.writeToMemory(getBuffer(), cursor());
+
+ setOffsetAndSize(ordinal, numBytes);
+
+ // move the cursor forward.
+ increaseCursor(roundedSize);
+ }
+
+ public final void write(int ordinal, byte[] input) {
+ write(ordinal, input, 0, input.length);
+ }
+
+ public final void write(int ordinal, byte[] input, int offset, int numBytes) {
+ final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+
+ // grow the global buffer before writing data.
+ grow(roundedSize);
+
+ zeroOutPaddingBytes(numBytes);
+
+ // Write the bytes to the variable length portion.
+ Platform.copyMemory(
+ input, Platform.BYTE_ARRAY_OFFSET + offset, getBuffer(), cursor(), numBytes);
+
+ setOffsetAndSize(ordinal, numBytes);
+
+ // move the cursor forward.
+ increaseCursor(roundedSize);
+ }
+
+ public final void write(int ordinal, CalendarInterval input) {
+ // grow the global buffer before writing data.
+ grow(16);
+
+ // Write the months and microseconds fields of Interval to the variable length portion.
+ Platform.putLong(getBuffer(), cursor(), input.months);
+ Platform.putLong(getBuffer(), cursor() + 8, input.microseconds);
+
+ setOffsetAndSize(ordinal, 16);
+
+ // move the cursor forward.
+ increaseCursor(16);
+ }
+
+ protected final void writeBoolean(long offset, boolean value) {
+ Platform.putBoolean(getBuffer(), offset, value);
+ }
+
+ protected final void writeByte(long offset, byte value) {
+ Platform.putByte(getBuffer(), offset, value);
+ }
+
+ protected final void writeShort(long offset, short value) {
+ Platform.putShort(getBuffer(), offset, value);
+ }
+
+ protected final void writeInt(long offset, int value) {
+ Platform.putInt(getBuffer(), offset, value);
+ }
+
+ protected final void writeLong(long offset, long value) {
+ Platform.putLong(getBuffer(), offset, value);
+ }
+
+ protected final void writeFloat(long offset, float value) {
+ if (Float.isNaN(value)) {
+ value = Float.NaN;
+ }
+ Platform.putFloat(getBuffer(), offset, value);
+ }
+
+ protected final void writeDouble(long offset, double value) {
+ if (Double.isNaN(value)) {
+ value = Double.NaN;
+ }
+ Platform.putDouble(getBuffer(), offset, value);
+ }
}
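The setOffsetAndSize helper above packs a variable-length field's location into its 8-byte fixed-length slot: the offset relative to startingOffset occupies the upper 32 bits and the byte size the lower 32 bits. A small sketch of that encoding and its decoding (illustrative only):

    def pack(relativeOffset: Int, size: Int): Long =
      (relativeOffset.toLong << 32) | (size & 0xFFFFFFFFL)

    def offsetOf(offsetAndSize: Long): Int = (offsetAndSize >>> 32).toInt
    def sizeOf(offsetAndSize: Long): Int = offsetAndSize.toInt

    assert(offsetOf(pack(24, 5)) == 24 && sizeOf(pack(24, 5)) == 5)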
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
index 0da5ece7e47fe..b31466f5c92d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter}
+import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{UserDefinedType, _}
import org.apache.spark.unsafe.Platform
@@ -42,17 +42,12 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe
/** The row representing the expression results. */
private[this] val intermediate = new GenericInternalRow(values)
- /** The row returned by the projection. */
- private[this] val result = new UnsafeRow(numFields)
-
- /** The buffer which holds the resulting row's backing data. */
- private[this] val holder = new BufferHolder(result, numFields * 32)
+ /** The row writer for the UnsafeRow result. */
+ private[this] val rowWriter = new UnsafeRowWriter(numFields, numFields * 32)
/** The writer that writes the intermediate result to the result row. */
private[this] val writer: InternalRow => Unit = {
- val rowWriter = new UnsafeRowWriter(holder, numFields)
val baseWriter = generateStructWriter(
- holder,
rowWriter,
expressions.map(e => StructField("", e.dataType, e.nullable)))
if (!expressions.exists(_.nullable)) {
@@ -83,10 +78,9 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe
}
// Write the intermediate row to an unsafe row.
- holder.reset()
+ rowWriter.reset()
writer(intermediate)
- result.setTotalSize(holder.totalSize())
- result
+ rowWriter.getRow()
}
}
@@ -111,14 +105,13 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
* given buffer using the given [[UnsafeRowWriter]].
*/
private def generateStructWriter(
- bufferHolder: BufferHolder,
rowWriter: UnsafeRowWriter,
fields: Array[StructField]): InternalRow => Unit = {
val numFields = fields.length
// Create field writers.
val fieldWriters = fields.map { field =>
- generateFieldWriter(bufferHolder, rowWriter, field.dataType, field.nullable)
+ generateFieldWriter(rowWriter, field.dataType, field.nullable)
}
// Create basic writer.
row => {
@@ -136,7 +129,6 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
* or array) to the given buffer using the given [[UnsafeWriter]].
*/
private def generateFieldWriter(
- bufferHolder: BufferHolder,
writer: UnsafeWriter,
dt: DataType,
nullable: Boolean): (SpecializedGetters, Int) => Unit = {
@@ -178,81 +170,79 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
case StructType(fields) =>
val numFields = fields.length
- val rowWriter = new UnsafeRowWriter(bufferHolder, numFields)
- val structWriter = generateStructWriter(bufferHolder, rowWriter, fields)
+ val rowWriter = new UnsafeRowWriter(writer, numFields)
+ val structWriter = generateStructWriter(rowWriter, fields)
(v, i) => {
- val tmpCursor = bufferHolder.cursor
+ val previousCursor = writer.cursor()
v.getStruct(i, fields.length) match {
case row: UnsafeRow =>
writeUnsafeData(
- bufferHolder,
+ rowWriter,
row.getBaseObject,
row.getBaseOffset,
row.getSizeInBytes)
case row =>
// Nested struct. We don't know where this will start because a row can be
// variable length, so we need to update the offsets and zero out the bit mask.
- rowWriter.reset()
+ rowWriter.resetRowWriter()
structWriter.apply(row)
}
- writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor)
+ writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
}
case ArrayType(elementType, containsNull) =>
- val arrayWriter = new UnsafeArrayWriter
- val elementSize = getElementSize(elementType)
+ val arrayWriter = new UnsafeArrayWriter(writer, getElementSize(elementType))
val elementWriter = generateFieldWriter(
- bufferHolder,
arrayWriter,
elementType,
containsNull)
(v, i) => {
- val tmpCursor = bufferHolder.cursor
- writeArray(bufferHolder, arrayWriter, elementWriter, v.getArray(i), elementSize)
- writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor)
+ val previousCursor = writer.cursor()
+ writeArray(arrayWriter, elementWriter, v.getArray(i))
+ writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
}
case MapType(keyType, valueType, valueContainsNull) =>
- val keyArrayWriter = new UnsafeArrayWriter
- val keySize = getElementSize(keyType)
+ val keyArrayWriter = new UnsafeArrayWriter(writer, getElementSize(keyType))
val keyWriter = generateFieldWriter(
- bufferHolder,
keyArrayWriter,
keyType,
nullable = false)
- val valueArrayWriter = new UnsafeArrayWriter
- val valueSize = getElementSize(valueType)
+ val valueArrayWriter = new UnsafeArrayWriter(writer, getElementSize(valueType))
val valueWriter = generateFieldWriter(
- bufferHolder,
valueArrayWriter,
valueType,
valueContainsNull)
(v, i) => {
- val tmpCursor = bufferHolder.cursor
+ val previousCursor = writer.cursor()
v.getMap(i) match {
case map: UnsafeMapData =>
writeUnsafeData(
- bufferHolder,
+ valueArrayWriter,
map.getBaseObject,
map.getBaseOffset,
map.getSizeInBytes)
case map =>
// preserve 8 bytes to write the key array numBytes later.
- bufferHolder.grow(8)
- bufferHolder.cursor += 8
+ valueArrayWriter.grow(8)
+ valueArrayWriter.increaseCursor(8)
// Write the keys and write the numBytes of key array into the first 8 bytes.
- writeArray(bufferHolder, keyArrayWriter, keyWriter, map.keyArray(), keySize)
- Platform.putLong(bufferHolder.buffer, tmpCursor, bufferHolder.cursor - tmpCursor - 8)
+ writeArray(keyArrayWriter, keyWriter, map.keyArray())
+ Platform.putLong(
+ valueArrayWriter.getBuffer,
+ previousCursor,
+ valueArrayWriter.cursor - previousCursor - 8
+ )
// Write the values.
- writeArray(bufferHolder, valueArrayWriter, valueWriter, map.valueArray(), valueSize)
+ writeArray(valueArrayWriter, valueWriter, map.valueArray())
}
- writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor)
+ writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
}
case udt: UserDefinedType[_] =>
- generateFieldWriter(bufferHolder, writer, udt.sqlType, nullable)
+ generateFieldWriter(writer, udt.sqlType, nullable)
case NullType =>
(_, _) => {}
@@ -324,20 +314,18 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
* copy.
*/
private def writeArray(
- bufferHolder: BufferHolder,
arrayWriter: UnsafeArrayWriter,
elementWriter: (SpecializedGetters, Int) => Unit,
- array: ArrayData,
- elementSize: Int): Unit = array match {
+ array: ArrayData): Unit = array match {
case unsafe: UnsafeArrayData =>
writeUnsafeData(
- bufferHolder,
+ arrayWriter,
unsafe.getBaseObject,
unsafe.getBaseOffset,
unsafe.getSizeInBytes)
case _ =>
val numElements = array.numElements()
- arrayWriter.initialize(bufferHolder, numElements, elementSize)
+ arrayWriter.initialize(numElements)
var i = 0
while (i < numElements) {
elementWriter.apply(array, i)
@@ -350,17 +338,17 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
* [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects.
*/
private def writeUnsafeData(
- bufferHolder: BufferHolder,
+ writer: UnsafeWriter,
baseObject: AnyRef,
baseOffset: Long,
sizeInBytes: Int) : Unit = {
- bufferHolder.grow(sizeInBytes)
+ writer.grow(sizeInBytes)
Platform.copyMemory(
baseObject,
baseOffset,
- bufferHolder.buffer,
- bufferHolder.cursor,
+ writer.getBuffer,
+ writer.cursor,
sizeInBytes)
- bufferHolder.cursor += sizeInBytes
+ writer.increaseCursor(sizeInBytes)
}
}
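For reference, the map branch above follows the UnsafeMapData layout: an 8-byte prefix holding the key array's size in bytes, then the key array, then the value array. The prefix is reserved up front (grow(8) / increaseCursor(8)) and back-filled once the keys are written; a trivial sketch of that arithmetic (illustrative only):

    // Value written into the reserved 8-byte slot at previousCursor.
    def keyArrayNumBytes(cursorAfterKeys: Int, previousCursor: Int): Int =
      cursorAfterKeys - previousCursor - 8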
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 6682ba55b18b1..ab2254cd9f70a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -48,19 +48,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx: CodegenContext,
input: String,
fieldTypes: Seq[DataType],
- bufferHolder: String): String = {
+ rowWriter: String): String = {
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
val tmpInput = ctx.freshName("tmpInput")
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString))
}
+ val rowWriterClass = classOf[UnsafeRowWriter].getName
+ val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter",
+ v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});")
+
s"""
final InternalRow $tmpInput = $input;
if ($tmpInput instanceof UnsafeRow) {
- ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", bufferHolder)}
+ ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", structRowWriter)}
} else {
- ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, bufferHolder)}
+ ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)}
}
"""
}
@@ -70,12 +74,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
row: String,
inputs: Seq[ExprCode],
inputTypes: Seq[DataType],
- bufferHolder: String,
+ rowWriter: String,
isTopLevel: Boolean = false): String = {
- val rowWriterClass = classOf[UnsafeRowWriter].getName
- val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter",
- v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});")
-
val resetWriter = if (isTopLevel) {
// For top level row writer, it always writes to the beginning of the global buffer holder,
// which means its fixed-size region always in the same position, so we don't need to call
@@ -88,7 +88,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"$rowWriter.zeroOutNullBytes();"
}
} else {
- s"$rowWriter.reset();"
+ s"$rowWriter.resetRowWriter();"
}
val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
@@ -97,7 +97,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case udt: UserDefinedType[_] => udt.sqlType
case other => other
}
- val tmpCursor = ctx.freshName("tmpCursor")
val setNull = dt match {
case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
@@ -105,33 +104,34 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});"
case _ => s"$rowWriter.setNullAt($index);"
}
+ val previousCursor = ctx.freshName("previousCursor")
val writeField = dt match {
case t: StructType =>
s"""
// Remember the current cursor so that we can calculate how many bytes are
// written later.
- final int $tmpCursor = $bufferHolder.cursor;
- ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), bufferHolder)}
- $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+ final int $previousCursor = $rowWriter.cursor();
+ ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), rowWriter)}
+ $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
"""
case a @ ArrayType(et, _) =>
s"""
// Remember the current cursor so that we can calculate how many bytes are
// written later.
- final int $tmpCursor = $bufferHolder.cursor;
- ${writeArrayToBuffer(ctx, input.value, et, bufferHolder)}
- $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+ final int $previousCursor = $rowWriter.cursor();
+ ${writeArrayToBuffer(ctx, input.value, et, rowWriter)}
+ $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
"""
case m @ MapType(kt, vt, _) =>
s"""
// Remember the current cursor so that we can calculate how many bytes are
// written later.
- final int $tmpCursor = $bufferHolder.cursor;
- ${writeMapToBuffer(ctx, input.value, kt, vt, bufferHolder)}
- $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+ final int $previousCursor = $rowWriter.cursor();
+ ${writeMapToBuffer(ctx, input.value, kt, vt, rowWriter)}
+ $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
"""
case t: DecimalType =>
@@ -181,12 +181,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx: CodegenContext,
input: String,
elementType: DataType,
- bufferHolder: String): String = {
+ rowWriter: String): String = {
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
val tmpInput = ctx.freshName("tmpInput")
- val arrayWriterClass = classOf[UnsafeArrayWriter].getName
- val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter",
- v => s"$v = new $arrayWriterClass();")
val numElements = ctx.freshName("numElements")
val index = ctx.freshName("index")
@@ -203,28 +200,32 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => 8 // we need 8 bytes to store offset and length
}
- val tmpCursor = ctx.freshName("tmpCursor")
+ val arrayWriterClass = classOf[UnsafeArrayWriter].getName
+ val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter",
+ v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);")
+ val previousCursor = ctx.freshName("previousCursor")
+
val element = CodeGenerator.getValue(tmpInput, et, index)
val writeElement = et match {
case t: StructType =>
s"""
- final int $tmpCursor = $bufferHolder.cursor;
- ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)}
- $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+ final int $previousCursor = $arrayWriter.cursor();
+ ${writeStructToBuffer(ctx, element, t.map(_.dataType), arrayWriter)}
+ $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
"""
case a @ ArrayType(et, _) =>
s"""
- final int $tmpCursor = $bufferHolder.cursor;
- ${writeArrayToBuffer(ctx, element, et, bufferHolder)}
- $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+ final int $previousCursor = $arrayWriter.cursor();
+ ${writeArrayToBuffer(ctx, element, et, arrayWriter)}
+ $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
"""
case m @ MapType(kt, vt, _) =>
s"""
- final int $tmpCursor = $bufferHolder.cursor;
- ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)}
- $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+ final int $previousCursor = $arrayWriter.cursor();
+ ${writeMapToBuffer(ctx, element, kt, vt, arrayWriter)}
+ $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
"""
case t: DecimalType =>
@@ -240,10 +241,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"""
final ArrayData $tmpInput = $input;
if ($tmpInput instanceof UnsafeArrayData) {
- ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", bufferHolder)}
+ ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", arrayWriter)}
} else {
final int $numElements = $tmpInput.numElements();
- $arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize);
+ $arrayWriter.initialize($numElements);
for (int $index = 0; $index < $numElements; $index++) {
if ($tmpInput.isNullAt($index)) {
@@ -262,7 +263,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
input: String,
keyType: DataType,
valueType: DataType,
- bufferHolder: String): String = {
+ rowWriter: String): String = {
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
val tmpInput = ctx.freshName("tmpInput")
val tmpCursor = ctx.freshName("tmpCursor")
@@ -271,20 +272,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"""
final MapData $tmpInput = $input;
if ($tmpInput instanceof UnsafeMapData) {
- ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", bufferHolder)}
+ ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", rowWriter)}
} else {
// preserve 8 bytes to write the key array numBytes later.
- $bufferHolder.grow(8);
- $bufferHolder.cursor += 8;
+ $rowWriter.grow(8);
+ $rowWriter.increaseCursor(8);
// Remember the current cursor so that we can write numBytes of key array later.
- final int $tmpCursor = $bufferHolder.cursor;
+ final int $tmpCursor = $rowWriter.cursor();
- ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, bufferHolder)}
+ ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)}
// Write the numBytes of key array into the first 8 bytes.
- Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor);
+ Platform.putLong($rowWriter.getBuffer(), $tmpCursor - 8, $rowWriter.cursor() - $tmpCursor);
- ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, bufferHolder)}
+ ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)}
}
"""
}
@@ -293,14 +294,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
* If the input is already in unsafe format, we don't need to go through all elements/fields,
* we can directly write it.
*/
- private def writeUnsafeData(ctx: CodegenContext, input: String, bufferHolder: String) = {
+ private def writeUnsafeData(ctx: CodegenContext, input: String, rowWriter: String) = {
val sizeInBytes = ctx.freshName("sizeInBytes")
s"""
final int $sizeInBytes = $input.getSizeInBytes();
// grow the global buffer before writing data.
- $bufferHolder.grow($sizeInBytes);
- $input.writeToMemory($bufferHolder.buffer, $bufferHolder.cursor);
- $bufferHolder.cursor += $sizeInBytes;
+ $rowWriter.grow($sizeInBytes);
+ $input.writeToMemory($rowWriter.getBuffer(), $rowWriter.cursor());
+ $rowWriter.increaseCursor($sizeInBytes);
"""
}
@@ -317,38 +318,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => true
}
- val result = ctx.addMutableState("UnsafeRow", "result",
- v => s"$v = new UnsafeRow(${expressions.length});")
-
- val holderClass = classOf[BufferHolder].getName
- val holder = ctx.addMutableState(holderClass, "holder",
- v => s"$v = new $holderClass($result, ${numVarLenFields * 32});")
-
- val resetBufferHolder = if (numVarLenFields == 0) {
- ""
- } else {
- s"$holder.reset();"
- }
- val updateRowSize = if (numVarLenFields == 0) {
- ""
- } else {
- s"$result.setTotalSize($holder.totalSize());"
- }
+ val rowWriterClass = classOf[UnsafeRowWriter].getName
+ val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter",
+ v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});")
// Evaluate all the subexpression.
val evalSubexpr = ctx.subexprFunctions.mkString("\n")
- val writeExpressions =
- writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true)
+ val writeExpressions = writeExpressionsToBuffer(
+ ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true)
val code =
s"""
- $resetBufferHolder
+ $rowWriter.reset();
$evalSubexpr
$writeExpressions
- $updateRowSize
"""
- ExprCode(code, "false", result)
+ ExprCode(code, "false", s"$rowWriter.getRow()")
}
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index adf9ddf327c96..0e9d357c19c63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects
import java.lang.reflect.Modifier
+import scala.collection.JavaConverters._
import scala.collection.mutable.Builder
import scala.language.existentials
import scala.reflect.ClassTag
@@ -501,12 +502,22 @@ case class LambdaVariable(
value: String,
isNull: String,
dataType: DataType,
- nullable: Boolean = true) extends LeafExpression
- with Unevaluable with NonSQLExpression {
+ nullable: Boolean = true) extends LeafExpression with NonSQLExpression {
+
+ // Interpreted execution of `LambdaVariable` always gets the 0th element from the input row.
+ override def eval(input: InternalRow): Any = {
+ assert(input.numFields == 1,
+ "The input row of interpreted LambdaVariable should have only 1 field.")
+ input.get(0, dataType)
+ }
override def genCode(ctx: CodegenContext): ExprCode = {
ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false")
}
+
+ // This won't be called as `genCode` is overridden; it is implemented only to make
+ // `LambdaVariable` non-abstract.
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev
}
/**
@@ -599,8 +610,92 @@ case class MapObjects private(
override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil
- override def eval(input: InternalRow): Any =
- throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+ // Data with a UserDefinedType is actually stored using the data type of its sqlType.
+ // When we apply MapObjects to such data, we have to use that underlying sqlType.
+ lazy private val inputDataType = inputData.dataType match {
+ case u: UserDefinedType[_] => u.sqlType
+ case _ => inputData.dataType
+ }
+
+ private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = {
+ val row = new GenericInternalRow(1)
+ inputCollection.toIterator.map { element =>
+ row.update(0, element)
+ lambdaFunction.eval(row)
+ }
+ }
+
+ private lazy val convertToSeq: Any => Seq[_] = inputDataType match {
+ case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
+ _.asInstanceOf[Seq[_]]
+ case ObjectType(cls) if cls.isArray =>
+ _.asInstanceOf[Array[_]].toSeq
+ case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
+ _.asInstanceOf[java.util.List[_]].asScala
+ case ObjectType(cls) if cls == classOf[Object] =>
+ (inputCollection) => {
+ if (inputCollection.getClass.isArray) {
+ inputCollection.asInstanceOf[Array[_]].toSeq
+ } else {
+ inputCollection.asInstanceOf[Seq[_]]
+ }
+ }
+ case ArrayType(et, _) =>
+ _.asInstanceOf[ArrayData].array
+ }
+
+ private lazy val mapElements: Seq[_] => Any = customCollectionCls match {
+ case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
+ // Scala sequence
+ executeFuncOnCollection(_).toSeq
+ case Some(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
+ // Scala set
+ executeFuncOnCollection(_).toSet
+ case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
+ // Java list
+ if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] ||
+ cls == classOf[java.util.AbstractSequentialList[_]]) {
+ // Non-concrete implementations of `java.util.List`
+ executeFuncOnCollection(_).toSeq.asJava
+ } else {
+ val constructors = cls.getConstructors()
+ val intParamConstructor = constructors.find { constructor =>
+ constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int]
+ }
+ val noParamConstructor = constructors.find { constructor =>
+ constructor.getParameterCount == 0
+ }
+
+ val constructor = intParamConstructor.map { intConstructor =>
+ (len: Int) => intConstructor.newInstance(len.asInstanceOf[Object])
+ }.getOrElse {
+ (_: Int) => noParamConstructor.get.newInstance()
+ }
+
+ // Concrete implementations of `java.util.List`
+ (inputs) => {
+ val results = executeFuncOnCollection(inputs)
+ val builder = constructor(inputs.length).asInstanceOf[java.util.List[Any]]
+ results.foreach(builder.add(_))
+ builder
+ }
+ }
+ case None =>
+ // array
+ x => new GenericArrayData(executeFuncOnCollection(x).toArray)
+ case Some(cls) =>
+ throw new RuntimeException(s"class `${cls.getName}` is not supported by `MapObjects` as " +
+ "resulting collection.")
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val inputCollection = inputData.eval(input)
+
+ if (inputCollection == null) {
+ return null
+ }
+ mapElements(convertToSeq(inputCollection))
+ }
override def dataType: DataType =
customCollectionCls.map(ObjectType.apply).getOrElse(
@@ -647,13 +742,6 @@ case class MapObjects private(
case _ => ""
}
- // The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
- // When we want to apply MapObjects on it, we have to use it.
- val inputDataType = inputData.dataType match {
- case p: PythonUserDefinedType => p.sqlType
- case _ => inputData.dataType
- }
-
// `MapObjects` generates a while loop to traverse the elements of the input collection. We
// need to take care of Seq and List because they may have O(n) complexity for indexed accessing
// like `list.get(1)`. Here we use Iterator to traverse Seq and List.
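The interpreted path added above can be exercised without codegen. A condensed Scala sketch along the lines of the new ObjectExpressionsSuite test later in this patch (expression names as used there):

    val function = (lambda: Expression) => Add(lambda, Literal(1))
    val inputObject = BoundReference(0, ObjectType(classOf[Seq[Int]]), nullable = true)
    val mapObj = MapObjects(function, inputObject, IntegerType, true, Some(classOf[Seq[Int]]))

    // eval goes through convertToSeq and mapElements instead of generated code.
    assert(mapObj.eval(InternalRow.fromSeq(Seq(Seq(1, 2, 3)))) == Seq(2, 3, 4))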
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 2829d1d81eb1a..9a1bbc675e397 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -153,7 +153,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
RewritePredicateSubquery,
ColumnPruning,
CollapseProject,
- RemoveRedundantProject)
+ RemoveRedundantProject) :+
+ Batch("UpdateAttributeReferences", Once,
+ UpdateNullabilityInAttributeReferences)
}
/**
@@ -1309,3 +1311,18 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
}
}
}
+
+/**
+ * Updates the nullability of [[AttributeReference]]s in a non-leaf plan's expressions when it
+ * differs from the nullability of the corresponding attributes in the children's output.
+ */
+object UpdateNullabilityInAttributeReferences extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ case p if !p.isInstanceOf[LeafNode] =>
+ val nullabilityMap = AttributeMap(p.children.flatMap(_.output).map { x => x -> x.nullable })
+ p transformExpressions {
+ case ar: AttributeReference if nullabilityMap.contains(ar) =>
+ ar.withNullability(nullabilityMap(ar))
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index a6e5aa6daca65..c3fdb924243df 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -17,10 +17,12 @@
package org.apache.spark.sql.catalyst.optimizer
+import org.apache.spark.sql.catalyst.analysis.CastSupport
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
/**
* Collapse plans consisting empty local relations generated by [[PruneFilters]].
@@ -32,7 +34,7 @@ import org.apache.spark.sql.catalyst.rules._
* - Aggregate with all empty children and at least one grouping expression.
* - Generate(Explode) with all empty children. Others like Hive UDTF may return results.
*/
-object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
+object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper with CastSupport {
private def isEmptyLocalRelation(plan: LogicalPlan): Boolean = plan match {
case p: LocalRelation => p.data.isEmpty
case _ => false
@@ -43,7 +45,9 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
// Construct a project list from plan's output, while the value is always NULL.
private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] =
- plan.output.map{ a => Alias(Literal(null), a.name)(a.exprId) }
+ plan.output.map{ a => Alias(cast(Literal(null), a.dataType), a.name)(a.exprId) }
+
+ override def conf: SQLConf = SQLConf.get
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case p: Union if p.children.forall(isEmptyLocalRelation) =>
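Casting the null literal matters because a bare Literal(null) has NullType; substituting it for, say, an Int column changes the output schema and can leave the rewritten plan unresolved, which is what the new "propagate empty relation keeps the plan resolved" test below guards against. A minimal sketch of the typed-null alias for one attribute (illustrative; it uses an explicit Cast rather than the conf-aware CastSupport.cast helper):

    // For the empty side of an outer join, each output column becomes a typed NULL.
    val a = AttributeReference("a", IntegerType)()
    val typedNull = Alias(Cast(Literal(null), a.dataType), a.name)(a.exprId)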
diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java
index fb3dbe8ed1996..2da87113c6229 100644
--- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java
+++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java
@@ -27,7 +27,6 @@
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.DataTypes;
-import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
import org.apache.spark.unsafe.types.UTF8String;
@@ -55,36 +54,27 @@ private String getRandomString(int length) {
}
private UnsafeRow makeKeyRow(long k1, String k2) {
- UnsafeRow row = new UnsafeRow(2);
- BufferHolder holder = new BufferHolder(row, 32);
- UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2);
- holder.reset();
+ UnsafeRowWriter writer = new UnsafeRowWriter(2);
+ writer.reset();
writer.write(0, k1);
writer.write(1, UTF8String.fromString(k2));
- row.setTotalSize(holder.totalSize());
- return row;
+ return writer.getRow();
}
private UnsafeRow makeKeyRow(long k1, long k2) {
- UnsafeRow row = new UnsafeRow(2);
- BufferHolder holder = new BufferHolder(row, 0);
- UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2);
- holder.reset();
+ UnsafeRowWriter writer = new UnsafeRowWriter(2);
+ writer.reset();
writer.write(0, k1);
writer.write(1, k2);
- row.setTotalSize(holder.totalSize());
- return row;
+ return writer.getRow();
}
private UnsafeRow makeValueRow(long v1, long v2) {
- UnsafeRow row = new UnsafeRow(2);
- BufferHolder holder = new BufferHolder(row, 0);
- UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2);
- holder.reset();
+ UnsafeRowWriter writer = new UnsafeRowWriter(2);
+ writer.reset();
writer.write(0, v1);
writer.write(1, v2);
- row.setTotalSize(holder.totalSize());
- return row;
+ return writer.getRow();
}
private UnsafeRow appendRow(RowBasedKeyValueBatch batch, UnsafeRow key, UnsafeRow value) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 1f6964dfef598..0edd27c8241e8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import org.apache.spark.{SparkConf, SparkFunSuite}
@@ -25,7 +26,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
@@ -135,6 +136,70 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
+ test("SPARK-23587: MapObjects should support interpreted execution") {
+ def testMapObjects(collection: Any, collectionCls: Class[_], inputType: DataType): Unit = {
+ val function = (lambda: Expression) => Add(lambda, Literal(1))
+ val elementType = IntegerType
+ val expected = Seq(2, 3, 4)
+
+ val inputObject = BoundReference(0, inputType, nullable = true)
+ val optClass = Option(collectionCls)
+ val mapObj = MapObjects(function, inputObject, elementType, true, optClass)
+ val row = InternalRow.fromSeq(Seq(collection))
+ val result = mapObj.eval(row)
+
+ collectionCls match {
+ case null =>
+ assert(result.asInstanceOf[ArrayData].array.toSeq == expected)
+ case l if classOf[java.util.List[_]].isAssignableFrom(l) =>
+ assert(result.asInstanceOf[java.util.List[_]].asScala.toSeq == expected)
+ case s if classOf[Seq[_]].isAssignableFrom(s) =>
+ assert(result.asInstanceOf[Seq[_]].toSeq == expected)
+ case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) =>
+ assert(result.asInstanceOf[scala.collection.Set[_]] == expected.toSet)
+ }
+ }
+
+ val customCollectionClasses = Seq(classOf[Seq[Int]], classOf[scala.collection.Set[Int]],
+ classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]],
+ classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]],
+ classOf[java.util.Stack[Int]], null)
+
+ val list = new java.util.ArrayList[Int]()
+ list.add(1)
+ list.add(2)
+ list.add(3)
+ val arrayData = new GenericArrayData(Array(1, 2, 3))
+ val vector = new java.util.Vector[Int]()
+ vector.add(1)
+ vector.add(2)
+ vector.add(3)
+ val stack = new java.util.Stack[Int]()
+ stack.add(1)
+ stack.add(2)
+ stack.add(3)
+
+ Seq(
+ (Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])),
+ (Array(1, 2, 3), ObjectType(classOf[Array[Int]])),
+ (Seq(1, 2, 3), ObjectType(classOf[Object])),
+ (Array(1, 2, 3), ObjectType(classOf[Object])),
+ (list, ObjectType(classOf[java.util.List[Int]])),
+ (vector, ObjectType(classOf[java.util.Vector[Int]])),
+ (stack, ObjectType(classOf[java.util.Stack[Int]])),
+ (arrayData, ArrayType(IntegerType))
+ ).foreach { case (collection, inputType) =>
+ customCollectionClasses.foreach(testMapObjects(collection, _, inputType))
+
+ // Unsupported custom collection class
+ val errMsg = intercept[RuntimeException] {
+ testMapObjects(collection, classOf[scala.collection.Map[Int, Int]], inputType)
+ }.getMessage()
+ assert(errMsg.contains("`scala.collection.Map` is not supported by `MapObjects` " +
+ "as resulting collection."))
+ }
+ }
+
test("SPARK-23592: DecodeUsingSerializer should support interpreted execution") {
val cls = classOf[java.lang.Integer]
val inputObject = BoundReference(0, ObjectType(classOf[Array[Byte]]), nullable = true)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
index 3964508e3a55e..f1ce7543ffdc1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{IntegerType, StructType}
class PropagateEmptyRelationSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
@@ -37,7 +37,8 @@ class PropagateEmptyRelationSuite extends PlanTest {
ReplaceIntersectWithSemiJoin,
PushDownPredicate,
PruneFilters,
- PropagateEmptyRelation) :: Nil
+ PropagateEmptyRelation,
+ CollapseProject) :: Nil
}
object OptimizeWithoutPropagateEmptyRelation extends RuleExecutor[LogicalPlan] {
@@ -48,7 +49,8 @@ class PropagateEmptyRelationSuite extends PlanTest {
ReplaceExceptWithAntiJoin,
ReplaceIntersectWithSemiJoin,
PushDownPredicate,
- PruneFilters) :: Nil
+ PruneFilters,
+ CollapseProject) :: Nil
}
val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1)))
@@ -79,9 +81,11 @@ class PropagateEmptyRelationSuite extends PlanTest {
(true, false, Inner, Some(LocalRelation('a.int, 'b.int))),
(true, false, Cross, Some(LocalRelation('a.int, 'b.int))),
- (true, false, LeftOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)),
+ (true, false, LeftOuter,
+ Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)),
(true, false, RightOuter, Some(LocalRelation('a.int, 'b.int))),
- (true, false, FullOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)),
+ (true, false, FullOuter,
+ Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)),
(true, false, LeftAnti, Some(testRelation1)),
(true, false, LeftSemi, Some(LocalRelation('a.int))),
@@ -89,8 +93,9 @@ class PropagateEmptyRelationSuite extends PlanTest {
(false, true, Cross, Some(LocalRelation('a.int, 'b.int))),
(false, true, LeftOuter, Some(LocalRelation('a.int, 'b.int))),
(false, true, RightOuter,
- Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)),
- (false, true, FullOuter, Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)),
+ Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)),
+ (false, true, FullOuter,
+ Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)),
(false, true, LeftAnti, Some(LocalRelation('a.int))),
(false, true, LeftSemi, Some(LocalRelation('a.int))),
@@ -209,4 +214,11 @@ class PropagateEmptyRelationSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("propagate empty relation keeps the plan resolved") {
+ val query = testRelation1.join(
+ LocalRelation('a.int, 'b.int), UsingJoin(FullOuter, "a" :: Nil), None)
+ val optimized = Optimize.execute(query.analyze)
+ assert(optimized.resolved)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala
new file mode 100644
index 0000000000000..09b11f5aba2a0
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{CreateArray, GetArrayItem}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+
+class UpdateNullabilityInAttributeReferencesSuite extends PlanTest {
+
+ object Optimizer extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Constant Folding", FixedPoint(10),
+ NullPropagation,
+ ConstantFolding,
+ BooleanSimplification,
+ SimplifyConditionals,
+ SimplifyBinaryComparison,
+ SimplifyExtractValueOps) ::
+ Batch("UpdateAttributeReferences", Once,
+ UpdateNullabilityInAttributeReferences) :: Nil
+ }
+
+ test("update nullability in AttributeReference") {
+ val rel = LocalRelation('a.long.notNull)
+ // In the 'original' plan below, the Aggregate node produced by groupBy() has a
+ // nullable AttributeReference to `b`, because array indexing (`GetArrayItem`) is a
+ // nullable expression. After optimization the aliased expression becomes non-nullable,
+ // but the AttributeReference is not updated to reflect this, so we update the nullability
+ // with the `UpdateNullabilityInAttributeReferences` rule.
+ val original = rel
+ .select(GetArrayItem(CreateArray(Seq('a, 'a + 1L)), 0) as "b")
+ .groupBy($"b")("1")
+ val expected = rel.select('a as "b").groupBy($"b")("1").analyze
+ val optimized = Optimizer.execute(original.analyze)
+ comparePlans(optimized, expected)
+ }
+}
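A hedged standalone sketch of the nullability mismatch this new rule and suite target; it restates what the test comment above asserts, with illustrative assertions:

```scala
// Sketch under the same assumptions as the test above: GetArrayItem is nullable, so the
// analyzed alias `b` starts out nullable; after SimplifyExtractValueOps the projection
// collapses to `'a as "b"` (non-nullable), and UpdateNullabilityInAttributeReferences
// propagates that change into the AttributeReferences that still refer to `b`.
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{CreateArray, GetArrayItem}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val rel = LocalRelation('a.long.notNull)
val analyzed = rel.select(GetArrayItem(CreateArray(Seq('a, 'a + 1L)), 0) as "b").analyze
assert(analyzed.output.head.nullable)           // `b` is nullable before optimization

val simplified = rel.select('a as "b").analyze  // what the optimizer rewrites the select into
assert(!simplified.output.head.nullable)        // `b` is non-nullable afterwards
```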
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
index 21ed987627b3b..633d86d495581 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
@@ -378,15 +378,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
.groupBy($"foo")("1")
checkRule(structRel, structExpected)
- // These tests must use nullable attributes from the base relation for the following reason:
- // in the 'original' plans below, the Aggregate node produced by groupBy() has a
- // nullable AttributeReference to a1, because both array indexing and map lookup are
- // nullable expressions. After optimization, the same attribute is now non-nullable,
- // but the AttributeReference is not updated to reflect this. In the 'expected' plans,
- // the grouping expressions have the same nullability as the original attribute in the
- // relation. If that attribute is non-nullable, the tests will fail as the plans will
- // compare differently, so for these tests we must use a nullable attribute. See
- // SPARK-23634.
val arrayRel = relation
.select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1")
.groupBy($"a1")("1")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 11bfaa0a726a7..4f0e2f58843a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -952,7 +952,8 @@ object SparkSession {
session = new SparkSession(sparkContext, None, None, extensions)
options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) }
- defaultSession.set(session)
+ setDefaultSession(session)
+ setActiveSession(session)
// Register a successfully instantiated context to the singleton. This should be at the
// end of the class definition so that the singleton is updated only if there is no
@@ -1028,6 +1029,17 @@ object SparkSession {
*/
def getDefaultSession: Option[SparkSession] = Option(defaultSession.get)
+ /**
+ * Returns the currently active SparkSession, otherwise the default one. If there is no
+ * default SparkSession, throws an `IllegalStateException`.
+ *
+ * @since 2.4.0
+ */
+ def active: SparkSession = {
+ getActiveSession.getOrElse(getDefaultSession.getOrElse(
+ throw new IllegalStateException("No active or default Spark session found")))
+ }
+
////////////////////////////////////////////////////////////////////////////////////////
// Private methods from now on
////////////////////////////////////////////////////////////////////////////////////////
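A hedged usage sketch of the `active` accessor added above, mirroring the new SparkSessionBuilderSuite test further down; the local master setting is illustrative only:

```scala
// Sketch only: getOrCreate() now registers the session as both default and active, and
// SparkSession.active resolves active -> default -> IllegalStateException, in that order.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local").getOrCreate()
assert(SparkSession.active == spark)      // resolved from the active session

SparkSession.clearActiveSession()
assert(SparkSession.active == spark)      // falls back to the default session

SparkSession.clearDefaultSession()
// SparkSession.active now throws IllegalStateException("No active or default Spark session found")
spark.stop()
```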
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
index 8617be88f3570..d5508275c48c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
@@ -165,18 +165,14 @@ class RowBasedHashMapGenerator(
| if (buckets[idx] == -1) {
| if (numRows < capacity && !isBatchFull) {
| // creating the unsafe for new entry
- | UnsafeRow agg_result = new UnsafeRow(${groupingKeySchema.length});
- | org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder
- | = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result,
- | ${numVarLenFields * 32});
| org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter
| = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(
- | agg_holder,
- | ${groupingKeySchema.length});
- | agg_holder.reset(); //TODO: investigate if reset or zeroout are actually needed
+ | ${groupingKeySchema.length}, ${numVarLenFields * 32});
+ | agg_rowWriter.reset(); //TODO: investigate if reset or zeroout are actually needed
| agg_rowWriter.zeroOutNullBytes();
| ${createUnsafeRowForKey};
- | agg_result.setTotalSize(agg_holder.totalSize());
+ | org.apache.spark.sql.catalyst.expressions.UnsafeRow agg_result
+ | = agg_rowWriter.getRow();
| Object kbase = agg_result.getBaseObject();
| long koff = agg_result.getBaseOffset();
| int klen = agg_result.getSizeInBytes();
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index 3b5655ba0582e..2d699e8a9d088 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -165,9 +165,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
private ByteOrder nativeOrder = null;
private byte[][] buffers = null;
- private UnsafeRow unsafeRow = new UnsafeRow($numFields);
- private BufferHolder bufferHolder = new BufferHolder(unsafeRow);
- private UnsafeRowWriter rowWriter = new UnsafeRowWriter(bufferHolder, $numFields);
+ private UnsafeRowWriter rowWriter = new UnsafeRowWriter($numFields);
private MutableUnsafeRow mutableRow = null;
private int currentRow = 0;
@@ -212,11 +210,10 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
public InternalRow next() {
currentRow += 1;
- bufferHolder.reset();
+ rowWriter.reset();
rowWriter.zeroOutNullBytes();
${extractorCalls}
- unsafeRow.setTotalSize(bufferHolder.totalSize());
- return unsafeRow;
+ return rowWriter.getRow();
}
${ctx.declareAddedFunctions()}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index 9647f09867643..e93908da43535 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -26,7 +26,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
+import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
@@ -130,16 +130,13 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
val emptyUnsafeRow = new UnsafeRow(0)
reader.map(_ => emptyUnsafeRow)
} else {
- val unsafeRow = new UnsafeRow(1)
- val bufferHolder = new BufferHolder(unsafeRow)
- val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1)
+ val unsafeRowWriter = new UnsafeRowWriter(1)
reader.map { line =>
// Writes to an UnsafeRow directly
- bufferHolder.reset()
+ unsafeRowWriter.reset()
unsafeRowWriter.write(0, line.getBytes, 0, line.getLength)
- unsafeRow.setTotalSize(bufferHolder.totalSize())
- unsafeRow
+ unsafeRowWriter.getRow()
}
}
}
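The three hunks above follow one migration pattern: `UnsafeRowWriter` now owns its buffer and row, so the explicit `BufferHolder`/`UnsafeRow` pair disappears. A hedged sketch of the new call sequence (field count and payload are illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
import org.apache.spark.unsafe.types.UTF8String

// One-field writer; an initial buffer size may also be passed, e.g. new UnsafeRowWriter(n, 32 * n).
val writer = new UnsafeRowWriter(1)
writer.reset()                                   // was bufferHolder.reset()
writer.zeroOutNullBytes()
writer.write(0, UTF8String.fromString("line"))
val row: UnsafeRow = writer.getRow()             // was unsafeRow.setTotalSize(holder.totalSize()); unsafeRow
```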
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
deleted file mode 100644
index 2cc54107f8b83..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import org.apache.spark.TaskContext
-import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
-import org.apache.spark.sql.catalyst.encoders.encoderFor
-
-/**
- * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by
- * [[ForeachWriter]].
- *
- * @param writer The [[ForeachWriter]] to process all data.
- * @tparam T The expected type of the sink.
- */
-class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable {
-
- override def addBatch(batchId: Long, data: DataFrame): Unit = {
- // This logic should've been as simple as:
- // ```
- // data.as[T].foreachPartition { iter => ... }
- // ```
- //
- // Unfortunately, doing that would just break the incremental planing. The reason is,
- // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` will
- // create a new plan. Because StreamExecution uses the existing plan to collect metrics and
- // update watermark, we should never create a new plan. Otherwise, metrics and watermark are
- // updated in the new plan, and StreamExecution cannot retrieval them.
- //
- // Hence, we need to manually convert internal rows to objects using encoder.
- val encoder = encoderFor[T].resolveAndBind(
- data.logicalPlan.output,
- data.sparkSession.sessionState.analyzer)
- data.queryExecution.toRdd.foreachPartition { iter =>
- if (writer.open(TaskContext.getPartitionId(), batchId)) {
- try {
- while (iter.hasNext) {
- writer.process(encoder.fromRow(iter.next()))
- }
- } catch {
- case e: Throwable =>
- writer.close(e)
- throw e
- }
- writer.close(null)
- } else {
- writer.close(null)
- }
- }
- }
-
- override def toString(): String = "ForeachSink"
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
new file mode 100644
index 0000000000000..df5d69d57e36f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A [[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified
+ * [[ForeachWriter]].
+ *
+ * @param writer The [[ForeachWriter]] to process all data.
+ * @tparam T The expected type of the sink.
+ */
+case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport {
+ override def createStreamWriter(
+ queryId: String,
+ schema: StructType,
+ mode: OutputMode,
+ options: DataSourceOptions): StreamWriter = {
+ new StreamWriter with SupportsWriteInternalRow {
+ override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+ override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+
+ override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = {
+ val encoder = encoderFor[T].resolveAndBind(
+ schema.toAttributes,
+ SparkSession.getActiveSession.get.sessionState.analyzer)
+ ForeachWriterFactory(writer, encoder)
+ }
+
+ override def toString: String = "ForeachSink"
+ }
+ }
+}
+
+case class ForeachWriterFactory[T: Encoder](
+ writer: ForeachWriter[T],
+ encoder: ExpressionEncoder[T])
+ extends DataWriterFactory[InternalRow] {
+ override def createDataWriter(
+ partitionId: Int,
+ attemptNumber: Int,
+ epochId: Long): ForeachDataWriter[T] = {
+ new ForeachDataWriter(writer, encoder, partitionId, epochId)
+ }
+}
+
+/**
+ * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]].
+ * @param writer The [[ForeachWriter]] to process all data.
+ * @param encoder An encoder which can convert [[InternalRow]] to the required type `T`.
+ * @param partitionId The id of the partition this writer handles.
+ * @param epochId The id of the epoch (batch) this writer is writing.
+ * @tparam T The type expected by the writer.
+ */
+class ForeachDataWriter[T: Encoder](
+ writer: ForeachWriter[T],
+ encoder: ExpressionEncoder[T],
+ partitionId: Int,
+ epochId: Long)
+ extends DataWriter[InternalRow] {
+
+ // If open returns false, we should skip writing rows.
+ private val opened = writer.open(partitionId, epochId)
+
+ override def write(record: InternalRow): Unit = {
+ if (!opened) return
+
+ try {
+ writer.process(encoder.fromRow(record))
+ } catch {
+ case t: Throwable =>
+ writer.close(t)
+ throw t
+ }
+ }
+
+ override def commit(): WriterCommitMessage = {
+ writer.close(null)
+ ForeachWriterCommitMessage
+ }
+
+ override def abort(): Unit = {}
+}
+
+/**
+ * An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination.
+ */
+case object ForeachWriterCommitMessage extends WriterCommitMessage
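For reference, a hedged sketch of the user-facing `ForeachWriter` lifecycle that `ForeachDataWriter` above drives once per partition and epoch; the body is illustrative only:

```scala
import org.apache.spark.sql.ForeachWriter

val writer = new ForeachWriter[Int] {
  // Returning false makes ForeachDataWriter skip every row in this partition/epoch.
  override def open(partitionId: Long, version: Long): Boolean = true

  // Exceptions thrown here surface through ForeachDataWriter.write, which calls close(error)
  // and rethrows, failing the task.
  override def process(value: Int): Unit = println(value)

  // Called with null on successful commit, or with the error that aborted processing.
  override def close(errorOrNull: Throwable): Unit = ()
}
// Attached through the existing API, e.g. dataset.writeStream.foreach(writer).start()
```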
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 2fc903168cfa0..effc1471e8e12 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
-import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2}
+import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2}
import org.apache.spark.sql.sources.v2.StreamWriteSupport
/**
@@ -269,7 +269,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
query
} else if (source == "foreach") {
assertNotPartitioned("foreach")
- val sink = new ForeachSink[T](foreachWriter)(ds.exprEnc)
+ val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc)
df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
extraOptions.get("checkpointLocation"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 6f43f18ffd0d4..36e25428a1b93 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2056,11 +2056,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
expr: String,
expectedNonNullableColumns: Seq[String]): Unit = {
val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr)
- // In the logical plan, all the output columns of input dataframe are nullable
- dfWithFilter.queryExecution.optimizedPlan.collect {
- case e: Filter => assert(e.output.forall(_.nullable))
- }
-
dfWithFilter.queryExecution.executedPlan.collect {
// When the child expression in isnotnull is null-intolerant (i.e. any null input will
// result in null output), the involved columns are converted to not nullable;
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
index c0301f2ce2d66..44bf8624a6bcd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
@@ -50,6 +50,24 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach {
assert(SparkSession.builder().getOrCreate() == session)
}
+ test("sets default and active session") {
+ assert(SparkSession.getDefaultSession == None)
+ assert(SparkSession.getActiveSession == None)
+ val session = SparkSession.builder().master("local").getOrCreate()
+ assert(SparkSession.getDefaultSession == Some(session))
+ assert(SparkSession.getActiveSession == Some(session))
+ }
+
+ test("get active or default session") {
+ val session = SparkSession.builder().master("local").getOrCreate()
+ assert(SparkSession.active == session)
+ SparkSession.clearActiveSession()
+ assert(SparkSession.active == session)
+ SparkSession.clearDefaultSession()
+ intercept[IllegalStateException](SparkSession.active)
+ session.stop()
+ }
+
test("config options are propagated to existing SparkSession") {
val session1 = SparkSession.builder().master("local").config("spark-config1", "a").getOrCreate()
assert(session1.conf.get("spark-config1") == "a")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala
similarity index 77%
rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala
index b249dd41a84a6..03bf71b3f4b78 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution.streaming
+package org.apache.spark.sql.execution.streaming.sources
import java.util.concurrent.ConcurrentLinkedQueue
@@ -25,11 +25,12 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkException
import org.apache.spark.sql.ForeachWriter
+import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions.{count, window}
import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest}
import org.apache.spark.sql.test.SharedSQLContext
-class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
+class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
import testImplicits._
@@ -47,9 +48,9 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
.start()
def verifyOutput(expectedVersion: Int, expectedData: Seq[Int]): Unit = {
- import ForeachSinkSuite._
+ import ForeachWriterSuite._
- val events = ForeachSinkSuite.allEvents()
+ val events = ForeachWriterSuite.allEvents()
assert(events.size === 2) // one seq of events for each of the 2 partitions
// Verify both seq of events have an Open event as the first event
@@ -64,13 +65,13 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
}
// -- batch 0 ---------------------------------------
- ForeachSinkSuite.clear()
+ ForeachWriterSuite.clear()
input.addData(1, 2, 3, 4)
query.processAllAvailable()
verifyOutput(expectedVersion = 0, expectedData = 1 to 4)
// -- batch 1 ---------------------------------------
- ForeachSinkSuite.clear()
+ ForeachWriterSuite.clear()
input.addData(5, 6, 7, 8)
query.processAllAvailable()
verifyOutput(expectedVersion = 1, expectedData = 5 to 8)
@@ -95,27 +96,27 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
input.addData(1, 2, 3, 4)
query.processAllAvailable()
- var allEvents = ForeachSinkSuite.allEvents()
+ var allEvents = ForeachWriterSuite.allEvents()
assert(allEvents.size === 1)
var expectedEvents = Seq(
- ForeachSinkSuite.Open(partition = 0, version = 0),
- ForeachSinkSuite.Process(value = 4),
- ForeachSinkSuite.Close(None)
+ ForeachWriterSuite.Open(partition = 0, version = 0),
+ ForeachWriterSuite.Process(value = 4),
+ ForeachWriterSuite.Close(None)
)
assert(allEvents === Seq(expectedEvents))
- ForeachSinkSuite.clear()
+ ForeachWriterSuite.clear()
// -- batch 1 ---------------------------------------
input.addData(5, 6, 7, 8)
query.processAllAvailable()
- allEvents = ForeachSinkSuite.allEvents()
+ allEvents = ForeachWriterSuite.allEvents()
assert(allEvents.size === 1)
expectedEvents = Seq(
- ForeachSinkSuite.Open(partition = 0, version = 1),
- ForeachSinkSuite.Process(value = 8),
- ForeachSinkSuite.Close(None)
+ ForeachWriterSuite.Open(partition = 0, version = 1),
+ ForeachWriterSuite.Process(value = 8),
+ ForeachWriterSuite.Close(None)
)
assert(allEvents === Seq(expectedEvents))
@@ -131,7 +132,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
.foreach(new TestForeachWriter() {
override def process(value: Int): Unit = {
super.process(value)
- throw new RuntimeException("error")
+ throw new RuntimeException("ForeachSinkSuite error")
}
}).start()
input.addData(1, 2, 3, 4)
@@ -141,18 +142,18 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
query.processAllAvailable()
}
assert(e.getCause.isInstanceOf[SparkException])
- assert(e.getCause.getCause.getMessage === "error")
+ assert(e.getCause.getCause.getCause.getMessage === "ForeachSinkSuite error")
assert(query.isActive === false)
- val allEvents = ForeachSinkSuite.allEvents()
+ val allEvents = ForeachWriterSuite.allEvents()
assert(allEvents.size === 1)
- assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0))
- assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1))
+ assert(allEvents(0)(0) === ForeachWriterSuite.Open(partition = 0, version = 0))
+ assert(allEvents(0)(1) === ForeachWriterSuite.Process(value = 1))
// `close` should be called with the error
- val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close]
+ val errorEvent = allEvents(0)(2).asInstanceOf[ForeachWriterSuite.Close]
assert(errorEvent.error.get.isInstanceOf[RuntimeException])
- assert(errorEvent.error.get.getMessage === "error")
+ assert(errorEvent.error.get.getMessage === "ForeachSinkSuite error")
}
}
@@ -177,12 +178,12 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
inputData.addData(10, 11, 12)
query.processAllAvailable()
- val allEvents = ForeachSinkSuite.allEvents()
+ val allEvents = ForeachWriterSuite.allEvents()
assert(allEvents.size === 1)
val expectedEvents = Seq(
- ForeachSinkSuite.Open(partition = 0, version = 0),
- ForeachSinkSuite.Process(value = 3),
- ForeachSinkSuite.Close(None)
+ ForeachWriterSuite.Open(partition = 0, version = 0),
+ ForeachWriterSuite.Process(value = 3),
+ ForeachWriterSuite.Close(None)
)
assert(allEvents === Seq(expectedEvents))
} finally {
@@ -216,21 +217,21 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
query.processAllAvailable()
// There should be 3 batches and only does the last batch contain a value.
- val allEvents = ForeachSinkSuite.allEvents()
+ val allEvents = ForeachWriterSuite.allEvents()
assert(allEvents.size === 3)
val expectedEvents = Seq(
Seq(
- ForeachSinkSuite.Open(partition = 0, version = 0),
- ForeachSinkSuite.Close(None)
+ ForeachWriterSuite.Open(partition = 0, version = 0),
+ ForeachWriterSuite.Close(None)
),
Seq(
- ForeachSinkSuite.Open(partition = 0, version = 1),
- ForeachSinkSuite.Close(None)
+ ForeachWriterSuite.Open(partition = 0, version = 1),
+ ForeachWriterSuite.Close(None)
),
Seq(
- ForeachSinkSuite.Open(partition = 0, version = 2),
- ForeachSinkSuite.Process(value = 3),
- ForeachSinkSuite.Close(None)
+ ForeachWriterSuite.Open(partition = 0, version = 2),
+ ForeachWriterSuite.Process(value = 3),
+ ForeachWriterSuite.Close(None)
)
)
assert(allEvents === expectedEvents)
@@ -258,7 +259,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
}
/** A global object to collect events in the executor */
-object ForeachSinkSuite {
+object ForeachWriterSuite {
trait Event
@@ -285,21 +286,21 @@ object ForeachSinkSuite {
/** A [[ForeachWriter]] that writes collected events to ForeachSinkSuite */
class TestForeachWriter extends ForeachWriter[Int] {
- ForeachSinkSuite.clear()
+ ForeachWriterSuite.clear()
- private val events = mutable.ArrayBuffer[ForeachSinkSuite.Event]()
+ private val events = mutable.ArrayBuffer[ForeachWriterSuite.Event]()
override def open(partitionId: Long, version: Long): Boolean = {
- events += ForeachSinkSuite.Open(partition = partitionId, version = version)
+ events += ForeachWriterSuite.Open(partition = partitionId, version = version)
true
}
override def process(value: Int): Unit = {
- events += ForeachSinkSuite.Process(value)
+ events += ForeachWriterSuite.Process(value)
}
override def close(errorOrNull: Throwable): Unit = {
- events += ForeachSinkSuite.Close(error = Option(errorOrNull))
- ForeachSinkSuite.addEvents(events)
+ events += ForeachWriterSuite.Close(error = Option(errorOrNull))
+ ForeachWriterSuite.addEvents(events)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 08749b49997e0..20942ed93897c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 3038b822beb4a..17603deacdcdd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -35,6 +35,7 @@ private[spark] class TestSparkSession(sc: SparkContext) extends SparkSession(sc)
}
SparkSession.setDefaultSession(this)
+ SparkSession.setActiveSession(this)
@transient
override lazy val sessionState: SessionState = {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 814038d4ef7af..965aea2b61456 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -159,10 +159,6 @@ private[hive] class TestHiveSparkSession(
private val loadTestTables: Boolean)
extends SparkSession(sc) with Logging { self =>
- // TODO(SPARK-23826): TestHiveSparkSession should set default session the same way as
- // TestSparkSession, but doing this the same way breaks many tests in the package. We need
- // to investigate and find a different strategy.
-
def this(sc: SparkContext, loadTestTables: Boolean) {
this(
sc,
@@ -179,6 +175,9 @@ private[hive] class TestHiveSparkSession(
loadTestTables)
}
+ SparkSession.setDefaultSession(this)
+ SparkSession.setActiveSession(this)
+
{ // set the metastore temporary configuration
val metastoreTempConf = HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false) ++ Map(
ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true",