diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
index 1b7e031ee0678..ccb30e205ca40 100644
--- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
@@ -19,6 +19,7 @@ package org.apache.spark.deploy
import java.io.File
import java.net.{InetAddress, URI}
+import java.nio.file.Files
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
@@ -48,7 +49,7 @@ object PythonRunner {
// Format python file paths before adding them to the PYTHONPATH
val formattedPythonFile = formatPath(pythonFile)
- val formattedPyFiles = formatPaths(pyFiles)
+ val formattedPyFiles = resolvePyFiles(formatPaths(pyFiles))
// Launch a Py4J gateway server for the process to connect to; this will let it see our
// Java system properties and such
@@ -153,4 +154,30 @@ object PythonRunner {
.map { p => formatPath(p, testWindows) }
}
+ /**
+ * Resolves the ".py" files. A ".py" file cannot be added to PYTHONPATH as is, because
+ * PYTHONPATH expects directories (or archives) rather than individual files. This method
+ * creates a temporary directory and copies into it any ".py" files found in the given paths.
+ */
+ private def resolvePyFiles(pyFiles: Array[String]): Array[String] = {
+ lazy val dest = Utils.createTempDir(namePrefix = "localPyFiles")
+ pyFiles.flatMap { pyFile =>
+ // In client mode with spark-submit, the Python paths have to be set before the context is
+ // initialized, because context initialization may happen later. We copy local ".py" files
+ // because a ".py" file itself cannot be added to PYTHONPATH; its parent directory has to be
+ // added instead. See SPARK-24384.
+ if (pyFile.endsWith(".py")) {
+ val source = new File(pyFile)
+ if (source.exists() && source.isFile && source.canRead) {
+ Files.copy(source.toPath, new File(dest, source.getName).toPath)
+ Some(dest.getAbsolutePath)
+ } else {
+ // Don't have to add it if it doesn't exist or isn't readable.
+ None
+ }
+ } else {
+ Some(pyFile)
+ }
+ }.distinct
+ }
}
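
A quick illustration of what the new helper produces for a mixed --py-files list (the paths are hypothetical, and the temp directory name follows from the "localPyFiles" prefix used above):

```scala
// Hypothetical inputs: two readable ".py" files and one archive.
val resolved = resolvePyFiles(Array("/tmp/job/util.py", "/tmp/job/helpers.py", "/deps/lib.zip"))

// Both ".py" files are copied into the same lazily-created temp directory, so after
// ".distinct" the result contains that directory once, plus the untouched archive,
// something like:
//   Array("/tmp/localPyFiles-<uuid>", "/deps/lib.zip")
```
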
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 8e97b3da33820..598b62f85a1fa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -42,7 +42,7 @@ import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils}
* up to launch speculative tasks, etc.
*
* Clients should first call initialize() and start(), then submit task sets through the
- * runTasks method.
+ * submitTasks method.
*
* THREADING: [[SchedulerBackend]]s and task-submitting clients can call this class from multiple
* threads, so it needs locks in public API methods to maintain its state. In addition, some
@@ -62,7 +62,7 @@ private[spark] class TaskSchedulerImpl(
this(sc, sc.conf.get(config.MAX_TASK_FAILURES))
}
- // Lazily initializing blackListTrackOpt to avoid getting empty ExecutorAllocationClient,
+ // Lazily initializing blacklistTrackerOpt to avoid getting empty ExecutorAllocationClient,
// because ExecutorAllocationClient is created after this TaskSchedulerImpl.
private[scheduler] lazy val blacklistTrackerOpt = maybeCreateBlacklistTracker(sc)
@@ -228,7 +228,7 @@ private[spark] class TaskSchedulerImpl(
// 1. The task set manager has been created and some tasks have been scheduled.
// In this case, send a kill signal to the executors to kill the task and then abort
// the stage.
- // 2. The task set manager has been created but no tasks has been scheduled. In this case,
+ // 2. The task set manager has been created but no tasks have been scheduled. In this case,
// simply abort the stage.
tsm.runningTasksSet.foreach { tid =>
taskIdToExecutorId.get(tid).foreach(execId =>
@@ -694,7 +694,7 @@ private[spark] class TaskSchedulerImpl(
*
* After stage failure and retry, there may be multiple TaskSetManagers for the stage.
* If an earlier attempt of a stage completes a task, we should ensure that the later attempts
- * do not also submit those same tasks. That also means that a task completion from an earlier
+ * do not also submit those same tasks. That also means that a task completion from an earlier
* attempt can lead to the entire stage getting marked as successful.
*/
private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = {
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
index c2668a7ff832f..a1bc93e8f6781 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
@@ -206,7 +206,9 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We
jobs: Seq[v1.JobData],
killEnabled: Boolean): Seq[Node] = {
// stripXSS is called to remove suspicious characters used in XSS attacks
- val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS))
+ val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) =>
+ UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq
+ }
val parameterOtherTable = allParameters.filterNot(_._1.startsWith(jobTag))
.map(para => para._1 + "=" + para._2(0))
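
Compared with the old `mapValues` call, the rewritten expression (here and in the identical StageTable.scala change below) also runs the parameter names through `UIUtils.stripXSS` and materializes the values as strict `Seq`s instead of leaving a lazy `mapValues` view. A minimal standalone sketch of the transformation, with a stand-in sanitizer:

```scala
// Stand-in sanitizer for illustration only; the real implementation is UIUtils.stripXSS.
def stripXSS(s: String): String = s.replaceAll("""[<>"'%;)(&+]""", "")

val raw: Map[String, Array[String]] = Map("job.sort" -> Array("Job Id", "<script>alert(1)</script>"))

// Sanitize both keys and values, and turn each value array into an eager Seq.
val sanitized: Map[String, Seq[String]] = raw.map { case (k, v) =>
  stripXSS(k) -> v.map(stripXSS).toSeq
}
// sanitized("job.sort") == Seq("Job Id", "scriptalert1/script")
```
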
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 08a927a8b4885..7ab433655233e 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -117,8 +117,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
val localitySummary = store.localitySummary(stageData.stageId, stageData.attemptId)
- val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks +
- stageData.numFailedTasks + stageData.numKilledTasks
+ val totalTasks = taskCount(stageData)
if (totalTasks == 0) {
val content =
@@ -133,7 +132,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
val totalTasksNumStr = if (totalTasks == storedTasks) {
s"$totalTasks"
} else {
- s"$totalTasks, showing ${storedTasks}"
+ s"$storedTasks, showing ${totalTasks}"
}
val summary =
@@ -678,7 +677,7 @@ private[ui] class TaskDataSource(
private var _tasksToShow: Seq[TaskData] = null
- override def dataSize: Int = stage.numTasks
+ override def dataSize: Int = taskCount(stage)
override def sliceData(from: Int, to: Int): Seq[TaskData] = {
if (_tasksToShow == null) {
@@ -1044,4 +1043,9 @@ private[ui] object ApiHelper {
(stage.map(_.name).getOrElse(""), stage.flatMap(_.description).getOrElse(job.name))
}
+ def taskCount(stageData: StageData): Int = {
+ stageData.numActiveTasks + stageData.numCompleteTasks + stageData.numFailedTasks +
+ stageData.numKilledTasks
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index 18a4926f2f6c0..f001a01de3952 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -43,7 +43,9 @@ private[ui] class StageTableBase(
killEnabled: Boolean,
isFailedStage: Boolean) {
// stripXSS is called to remove suspicious characters used in XSS attacks
- val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS))
+ val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) =>
+ UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq
+ }
val parameterOtherTable = allParameters.filterNot(_._1.startsWith(stageTag))
.map(para => para._1 + "=" + para._2(0))
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 14bc5e626771c..461806a659965 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1737,6 +1737,15 @@ To use `groupBy().apply()`, the user needs to define the following:
* A Python function that defines the computation for each group.
* A `StructType` object or a string that defines the schema of the output `DataFrame`.
+The output schema will be applied to the columns of the returned `pandas.DataFrame` in order by position,
+not by name. This means that the columns in the `pandas.DataFrame` must be indexed so that their
+position matches the corresponding field in the schema.
+
+Note that when creating a new `pandas.DataFrame` using a dictionary, the actual position of the column
+can differ from the order that it was placed in the dictionary. It is recommended in this case to
+explicitly define the column order using the `columns` keyword, e.g.
+`pandas.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])`, or alternatively use an `OrderedDict`.
+
Note that all data for a group will be loaded into memory before the function is applied. This can
lead to out of memory exceptions, especially if the group sizes are skewed. The configuration for
[maxRecordsPerBatch](#setting-arrow-batch-size) is not applied on groups and it is up to the user
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 7469f11df0294..c2e5137645d76 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -728,7 +728,8 @@ object Unidoc {
scalacOptions in (ScalaUnidoc, unidoc) ++= Seq(
"-groups", // Group similar methods together based on the @group annotation.
- "-skip-packages", "org.apache.hadoop"
+ "-skip-packages", "org.apache.hadoop",
+ "-sourcepath", (baseDirectory in ThisBuild).value.getAbsolutePath
) ++ (
// Add links to sources when generating Scaladoc for a non-snapshot release
if (!isSnapshot.value) {
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index cf26523b3cb45..9c02982e4ae22 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2216,7 +2216,8 @@ def pandas_udf(f=None, returnType=None, functionType=None):
A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame`
The returnType should be a :class:`StructType` describing the schema of the returned
`pandas.DataFrame`.
- The length of the returned `pandas.DataFrame` can be arbitrary.
+ The length of the returned `pandas.DataFrame` can be arbitrary and the columns must be
+ indexed so that their position matches the corresponding field in the schema.
Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`.
@@ -2239,6 +2240,12 @@ def pandas_udf(f=None, returnType=None, functionType=None):
| 2| 1.1094003924504583|
+---+-------------------+
+ .. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is
+ recommended to explicitly index the columns by name to ensure the positions are correct,
+ or alternatively use an `OrderedDict`.
+ For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or
+ `pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`.
+
.. seealso:: :meth:`pyspark.sql.GroupedData.apply`
.. note:: The user-defined functions are considered deterministic by default. Due to
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 6a974d558403f..4210737310c6b 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -263,16 +263,11 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
"PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator),
"PYTHONPATH" -> pythonPath.mkString(File.pathSeparator)) ++ extraEnv
- val moduleDir =
- if (clientMode) {
- // In client-mode, .py files added with --py-files are not visible in the driver.
- // This is something that the launcher library would have to handle.
- tempDir
- } else {
- val subdir = new File(tempDir, "pyModules")
- subdir.mkdir()
- subdir
- }
+ val moduleDir = {
+ val subdir = new File(tempDir, "pyModules")
+ subdir.mkdir()
+ subdir
+ }
val pyModule = new File(moduleDir, "mod1.py")
Files.write(TEST_PYMODULE, pyModule, StandardCharsets.UTF_8)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index ef3b67c0d48d0..dbf51c398fa47 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -161,13 +161,17 @@ object DecimalType extends AbstractDataType {
* This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to true.
*/
private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = {
- // Assumptions:
+ // Assumption:
assert(precision >= scale)
- assert(scale >= 0)
if (precision <= MAX_PRECISION) {
// Adjustment only needed when we exceed max precision
DecimalType(precision, scale)
+ } else if (scale < 0) {
+ // Decimal can have negative scale (SPARK-24468). In this case, we cannot allow a precision
+ // loss, since that would drop digits from the integer part. Instead, the operation is
+ // likely to overflow.
+ DecimalType(MAX_PRECISION, scale)
} else {
// Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION.
val intDigits = precision - scale
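
A simplified, self-contained sketch of the adjustment after this change (the surrounding rule for decimal multiplication is result precision = p1 + p2 + 1 and result scale = s1 + s2; MAX_PRECISION is 38 and the minimum guaranteed scale is 6):

```scala
object DecimalAdjustSketch {
  val MaxPrecision = 38
  val MinimumAdjustedScale = 6

  // Mirrors the branching of adjustPrecisionScale with the new negative-scale case.
  def adjust(precision: Int, scale: Int): (Int, Int) = {
    if (precision <= MaxPrecision) {
      (precision, scale)                       // already fits: keep as computed
    } else if (scale < 0) {
      (MaxPrecision, scale)                    // SPARK-24468: never drop integer digits
    } else {
      val intDigits = precision - scale        // keep the integer part, shrink the scale
      val minScale = math.min(scale, MinimumAdjustedScale)
      (MaxPrecision, math.max(MaxPrecision - intDigits, minScale))
    }
  }

  def main(args: Array[String]): Unit = {
    // decimal(3,-10) * decimal(35,1): precision = 3 + 35 + 1 = 39 > 38 and
    // scale = -10 + 1 = -9 < 0, so the new branch yields decimal(38,-9)
    // instead of trading away integer digits.
    println(adjust(39, -9))   // (38,-9)
  }
}
```
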
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index c86dc18dfa680..bd87ca6017e99 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -272,6 +272,15 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter {
}
}
+ test("SPARK-24468: operations on decimals with negative scale") {
+ val a = AttributeReference("a", DecimalType(3, -10))()
+ val b = AttributeReference("b", DecimalType(1, -1))()
+ val c = AttributeReference("c", DecimalType(35, 1))()
+ checkType(Multiply(a, b), DecimalType(5, -11))
+ checkType(Multiply(a, c), DecimalType(38, -9))
+ checkType(Multiply(b, c), DecimalType(37, 0))
+ }
+
/** strength reduction for integer/decimal comparisons */
def ruleTest(initial: Expression, transformed: Expression): Unit = {
val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
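
For reference, the expected types in the new test follow directly from the multiplication rule (precision = p1 + p2 + 1, scale = s1 + s2), with the adjustment sketched above kicking in only when the precision exceeds 38:

```scala
// decimal(3,-10) * decimal(1,-1):  precision = 3 + 1 + 1 = 5,        scale = -10 + (-1) = -11 => decimal(5,-11)
// decimal(3,-10) * decimal(35,1):  precision = 3 + 35 + 1 = 39 > 38, scale = -10 + 1 = -9 < 0 => decimal(38,-9)
// decimal(1,-1)  * decimal(35,1):  precision = 1 + 35 + 1 = 37,      scale = -1 + 1 = 0       => decimal(37,0)
```
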
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 25436e1c8dd45..c6565fcf66559 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -372,9 +372,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
val (functionsWithDistinct, functionsWithoutDistinct) =
aggregateExpressions.partition(_.isDistinct)
- if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
+ if (functionsWithDistinct.map(_.aggregateFunction.children.toSet).distinct.length > 1) {
// This is a sanity check. We should not reach here when we have multiple distinct
- // column sets. Our MultipleDistinctRewriter should take care this case.
+ // column sets. Our `RewriteDistinctAggregates` should take care of this case.
sys.error("You hit a query analyzer bug. Please report your query to " +
"Spark user mailing list.")
}
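
Comparing the distinct functions' children as sets rather than as sequences is what lets aggregates such as `corr(DISTINCT x, y)` and `corr(DISTINCT y, x)` count as a single distinct column set (SPARK-24369). A minimal sketch of the difference, using strings as stand-ins for the child expressions:

```scala
// Stand-ins for the children of the two distinct aggregate functions.
val children1 = Seq("x", "y")   // corr(DISTINCT x, y)
val children2 = Seq("y", "x")   // corr(DISTINCT y, x)

// Sequence comparison is order-sensitive: the old check saw two distinct column sets
// and triggered the "query analyzer bug" error.
Seq(children1, children2).distinct.length                  // 2

// Set comparison ignores ordering (and duplicates), so both map to one column set.
Seq(children1.toSet, children2.toSet).distinct.length      // 1
```
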
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
index c5070b734d521..2c18d6aaabdba 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -68,4 +68,8 @@ SELECT 1 from (
FROM (select 1 as x) a
WHERE false
) b
-where b.z != b.z
+where b.z != b.z;
+
+-- SPARK-24369 multiple distinct aggregations having the same argument set
+SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*)
+ FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql
index 9be7fcdadfea8..28a0e20c0f495 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql
@@ -40,12 +40,14 @@ select 10.3000 * 3.0;
select 10.30000 * 30.0;
select 10.300000000000000000 * 3.000000000000000000;
select 10.300000000000000000 * 3.0000000000000000000;
+select 2.35E10 * 1.0;
-- arithmetic operations causing an overflow return NULL
select (5e36 + 0.1) + 5e36;
select (-4e36 - 0.1) - 7e36;
select 12345678901234567890.0 * 12345678901234567890.0;
select 1e35 / 0.1;
+select 1.2345678901234567890E30 * 1.2345678901234567890E25;
-- arithmetic operations causing a precision loss are truncated
select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345;
@@ -67,12 +69,14 @@ select 10.3000 * 3.0;
select 10.30000 * 30.0;
select 10.300000000000000000 * 3.000000000000000000;
select 10.300000000000000000 * 3.0000000000000000000;
+select 2.35E10 * 1.0;
-- arithmetic operations causing an overflow return NULL
select (5e36 + 0.1) + 5e36;
select (-4e36 - 0.1) - 7e36;
select 12345678901234567890.0 * 12345678901234567890.0;
select 1e35 / 0.1;
+select 1.2345678901234567890E30 * 1.2345678901234567890E25;
-- arithmetic operations causing a precision loss return NULL
select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345;
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index c1abc6dff754b..581aa1754ce14 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 26
+-- Number of queries: 27
-- !query 0
@@ -241,3 +241,12 @@ where b.z != b.z
struct<1:int>
-- !query 25 output
+
+
+-- !query 26
+SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*)
+ FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y)
+-- !query 26 schema
+struct
+-- !query 26 output
+1.0 1.0 3
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out
index 6bfdb84548d4d..cbf44548b3cce 100644
--- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 36
+-- Number of queries: 40
-- !query 0
@@ -114,190 +114,222 @@ struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.00000000000000000
-- !query 13
-select (5e36 + 0.1) + 5e36
+select 2.35E10 * 1.0
-- !query 13 schema
-struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)>
+struct<(CAST(2.35E+10 AS DECIMAL(12,1)) * CAST(1.0 AS DECIMAL(12,1))):decimal(6,-7)>
-- !query 13 output
-NULL
+23500000000
-- !query 14
-select (-4e36 - 0.1) - 7e36
+select (5e36 + 0.1) + 5e36
-- !query 14 schema
-struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)>
+struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)>
-- !query 14 output
NULL
-- !query 15
-select 12345678901234567890.0 * 12345678901234567890.0
+select (-4e36 - 0.1) - 7e36
-- !query 15 schema
-struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)>
+struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)>
-- !query 15 output
NULL
-- !query 16
-select 1e35 / 0.1
+select 12345678901234567890.0 * 12345678901234567890.0
-- !query 16 schema
-struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)>
+struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)>
-- !query 16 output
NULL
-- !query 17
-select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345
+select 1e35 / 0.1
-- !query 17 schema
-struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,6)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,6))):decimal(38,6)>
+struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)>
-- !query 17 output
-10012345678912345678912345678911.246907
+NULL
-- !query 18
-select 123456789123456789.1234567890 * 1.123456789123456789
+select 1.2345678901234567890E30 * 1.2345678901234567890E25
-- !query 18 schema
-struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)>
+struct<(CAST(1.2345678901234567890E+30 AS DECIMAL(25,-6)) * CAST(1.2345678901234567890E+25 AS DECIMAL(25,-6))):decimal(38,-17)>
-- !query 18 output
-138698367904130467.654320988515622621
+NULL
-- !query 19
-select 12345678912345.123456789123 / 0.000000012345678
+select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345
-- !query 19 schema
-struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,9)>
+struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,6)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,6))):decimal(38,6)>
-- !query 19 output
-1000000073899961059796.725866332
+10012345678912345678912345678911.246907
-- !query 20
-set spark.sql.decimalOperations.allowPrecisionLoss=false
+select 123456789123456789.1234567890 * 1.123456789123456789
-- !query 20 schema
-struct
+struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)>
-- !query 20 output
-spark.sql.decimalOperations.allowPrecisionLoss false
+138698367904130467.654320988515622621
-- !query 21
-select id, a+b, a-b, a*b, a/b from decimals_test order by id
+select 12345678912345.123456789123 / 0.000000012345678
-- !query 21 schema
-struct
+struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,9)>
-- !query 21 output
-1 1099 -899 NULL 0.1001001001001001
-2 24690.246 0 NULL 1
-3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123
-4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436
+1000000073899961059796.725866332
-- !query 22
-select id, a*10, b/10 from decimals_test order by id
+set spark.sql.decimalOperations.allowPrecisionLoss=false
-- !query 22 schema
-struct
+struct
-- !query 22 output
-1 1000 99.9
-2 123451.23 1234.5123
-3 1.234567891011 123.41
-4 1234567891234567890 0.1123456789123456789
+spark.sql.decimalOperations.allowPrecisionLoss false
-- !query 23
-select 10.3 * 3.0
+select id, a+b, a-b, a*b, a/b from decimals_test order by id
-- !query 23 schema
-struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)>
+struct
-- !query 23 output
-30.9
+1 1099 -899 NULL 0.1001001001001001
+2 24690.246 0 NULL 1
+3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123
+4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436
-- !query 24
-select 10.3000 * 3.0
+select id, a*10, b/10 from decimals_test order by id
-- !query 24 schema
-struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)>
+struct
-- !query 24 output
-30.9
+1 1000 99.9
+2 123451.23 1234.5123
+3 1.234567891011 123.41
+4 1234567891234567890 0.1123456789123456789
-- !query 25
-select 10.30000 * 30.0
+select 10.3 * 3.0
-- !query 25 schema
-struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)>
+struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)>
-- !query 25 output
-309
+30.9
-- !query 26
-select 10.300000000000000000 * 3.000000000000000000
+select 10.3000 * 3.0
-- !query 26 schema
-struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)>
+struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)>
-- !query 26 output
30.9
-- !query 27
-select 10.300000000000000000 * 3.0000000000000000000
+select 10.30000 * 30.0
-- !query 27 schema
-struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,37)>
+struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)>
-- !query 27 output
-NULL
+309
-- !query 28
-select (5e36 + 0.1) + 5e36
+select 10.300000000000000000 * 3.000000000000000000
-- !query 28 schema
-struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)>
+struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)>
-- !query 28 output
-NULL
+30.9
-- !query 29
-select (-4e36 - 0.1) - 7e36
+select 10.300000000000000000 * 3.0000000000000000000
-- !query 29 schema
-struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)>
+struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,37)>
-- !query 29 output
NULL
-- !query 30
-select 12345678901234567890.0 * 12345678901234567890.0
+select 2.35E10 * 1.0
-- !query 30 schema
-struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)>
+struct<(CAST(2.35E+10 AS DECIMAL(12,1)) * CAST(1.0 AS DECIMAL(12,1))):decimal(6,-7)>
-- !query 30 output
-NULL
+23500000000
-- !query 31
-select 1e35 / 0.1
+select (5e36 + 0.1) + 5e36
-- !query 31 schema
-struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)>
+struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)>
-- !query 31 output
NULL
-- !query 32
-select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345
+select (-4e36 - 0.1) - 7e36
-- !query 32 schema
-struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,7)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,7))):decimal(38,7)>
+struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)>
-- !query 32 output
NULL
-- !query 33
-select 123456789123456789.1234567890 * 1.123456789123456789
+select 12345678901234567890.0 * 12345678901234567890.0
-- !query 33 schema
-struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)>
+struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)>
-- !query 33 output
NULL
-- !query 34
-select 12345678912345.123456789123 / 0.000000012345678
+select 1e35 / 0.1
-- !query 34 schema
-struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,18)>
+struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)>
-- !query 34 output
NULL
-- !query 35
-drop table decimals_test
+select 1.2345678901234567890E30 * 1.2345678901234567890E25
-- !query 35 schema
-struct<>
+struct<(CAST(1.2345678901234567890E+30 AS DECIMAL(25,-6)) * CAST(1.2345678901234567890E+25 AS DECIMAL(25,-6))):decimal(38,-17)>
-- !query 35 output
+NULL
+
+
+-- !query 36
+select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345
+-- !query 36 schema
+struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,7)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,7))):decimal(38,7)>
+-- !query 36 output
+NULL
+
+
+-- !query 37
+select 123456789123456789.1234567890 * 1.123456789123456789
+-- !query 37 schema
+struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)>
+-- !query 37 output
+NULL
+
+
+-- !query 38
+select 12345678912345.123456789123 / 0.000000012345678
+-- !query 38 schema
+struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,18)>
+-- !query 38 output
+NULL
+
+
+-- !query 39
+drop table decimals_test
+-- !query 39 schema
+struct<>
+-- !query 39 output
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala
index d66a6902b0510..cbef1c7828319 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala
@@ -30,11 +30,15 @@ trait LocalSparkSession extends BeforeAndAfterEach with BeforeAndAfterAll { self
override def beforeAll() {
super.beforeAll()
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
+ SparkSession.clearActiveSession()
+ SparkSession.clearDefaultSession()
}
override def afterEach() {
try {
resetSparkContext()
+ SparkSession.clearActiveSession()
+ SparkSession.clearDefaultSession()
} finally {
super.afterEach()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index f8b26f5b28cc7..dfbc0346cb247 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -69,6 +69,27 @@ class PlannerSuite extends SharedSQLContext {
testPartialAggregationPlan(query)
}
+ test("mixed aggregates with same distinct columns") {
+ def assertNoExpand(plan: SparkPlan): Unit = {
+ assert(plan.collect { case e: ExpandExec => e }.isEmpty)
+ }
+
+ withTempView("v") {
+ Seq((1, 1.0, 1.0), (1, 2.0, 2.0)).toDF("i", "j", "k").createTempView("v")
+ // one distinct column
+ val query1 = sql("SELECT sum(DISTINCT j), max(DISTINCT j) FROM v GROUP BY i")
+ assertNoExpand(query1.queryExecution.executedPlan)
+
+ // 2 distinct columns
+ val query2 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT j, k) FROM v GROUP BY i")
+ assertNoExpand(query2.queryExecution.executedPlan)
+
+ // 2 distinct columns with different order
+ val query3 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT k, j) FROM v GROUP BY i")
+ assertNoExpand(query3.queryExecution.executedPlan)
+ }
+ }
+
test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
def checkPlan(fieldTypes: Seq[DataType]): Unit = {
withTempView("testLimit") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index a3ae93810aa3c..d305ce3e698ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -21,15 +21,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File}
import java.util.Properties
import org.apache.spark._
-import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{LocalSparkSession, Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types._
import org.apache.spark.storage.ShuffleBlockId
-import org.apache.spark.util.Utils
import org.apache.spark.util.collection.ExternalSorter
/**
@@ -43,7 +41,7 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea
}
}
-class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
+class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession {
private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
val converter = unsafeRowConverter(schema)
@@ -58,7 +56,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
}
test("toUnsafeRow() test helper method") {
- // This currently doesnt work because the generic getter throws an exception.
+ // This currently doesn't work because the generic getter throws an exception.
val row = Row("Hello", 123)
val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
assert(row.getString(0) === unsafeRow.getUTF8String(0).toString)
@@ -97,59 +95,43 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
}
test("SPARK-10466: external sorter spilling with unsafe row serializer") {
- var sc: SparkContext = null
- var outputFile: File = null
- val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten
- Utils.tryWithSafeFinally {
- val conf = new SparkConf()
- .set("spark.shuffle.spill.initialMemoryThreshold", "1")
- .set("spark.shuffle.sort.bypassMergeThreshold", "0")
- .set("spark.testing.memory", "80000")
-
- sc = new SparkContext("local", "test", conf)
- outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
- // prepare data
- val converter = unsafeRowConverter(Array(IntegerType))
- val data = (1 to 10000).iterator.map { i =>
- (i, converter(Row(i)))
- }
- val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0)
- val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null)
-
- val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
- taskContext,
- partitioner = Some(new HashPartitioner(10)),
- serializer = new UnsafeRowSerializer(numFields = 1))
-
- // Ensure we spilled something and have to merge them later
- assert(sorter.numSpills === 0)
- sorter.insertAll(data)
- assert(sorter.numSpills > 0)
+ val conf = new SparkConf()
+ .set("spark.shuffle.spill.initialMemoryThreshold", "1")
+ .set("spark.shuffle.sort.bypassMergeThreshold", "0")
+ .set("spark.testing.memory", "80000")
+ spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate()
+ val outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
+ outputFile.deleteOnExit()
+ // prepare data
+ val converter = unsafeRowConverter(Array(IntegerType))
+ val data = (1 to 10000).iterator.map { i =>
+ (i, converter(Row(i)))
+ }
+ val taskMemoryManager = new TaskMemoryManager(spark.sparkContext.env.memoryManager, 0)
+ val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null)
- // Merging spilled files should not throw assertion error
- sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile)
- } {
- // Clean up
- if (sc != null) {
- sc.stop()
- }
+ val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
+ taskContext,
+ partitioner = Some(new HashPartitioner(10)),
+ serializer = new UnsafeRowSerializer(numFields = 1))
- // restore the spark env
- SparkEnv.set(oldEnv)
+ // Ensure we spilled something and have to merge them later
+ assert(sorter.numSpills === 0)
+ sorter.insertAll(data)
+ assert(sorter.numSpills > 0)
- if (outputFile != null) {
- outputFile.delete()
- }
- }
+ // Merging spilled files should not throw assertion error
+ sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile)
}
test("SPARK-10403: unsafe row serializer with SortShuffleManager") {
val conf = new SparkConf().set("spark.shuffle.manager", "sort")
- sc = new SparkContext("local", "test", conf)
+ spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate()
val row = Row("Hello", 123)
val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
- val rowsRDD = sc.parallelize(Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow)))
- .asInstanceOf[RDD[Product2[Int, InternalRow]]]
+ val rowsRDD = spark.sparkContext.parallelize(
+ Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow))
+ ).asInstanceOf[RDD[Product2[Int, InternalRow]]]
val dependency =
new ShuffleDependency[Int, InternalRow, InternalRow](
rowsRDD,